From d3322bc64dbcfb9dc03ab28aa34260912f5b7140 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Thu, 4 Jun 2026 15:21:44 +0800 Subject: [PATCH 01/48] fix(kb): scope chunk_count to current KB in update_kb_stats The previous implementation called vec_db.count_documents() without any metadata filter, which counted chunks across ALL knowledge bases instead of just the target KB. This caused each KB's chunk_count statistic to reflect the total across all KBs, producing garbled stats. Fix: pass metadata_filter={'kb_id': kb_id} so only chunks belonging to the requested knowledge base are counted. --- astrbot/core/knowledge_base/kb_db_sqlite.py | 2 +- tests/test_kb_stats.py | 88 +++++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 tests/test_kb_stats.py diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 6a2cb5e0a8..1f634eec3e 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -329,7 +329,7 @@ async def get_media_by_id(self, media_id: str) -> KBMedia | None: async def update_kb_stats(self, kb_id: str, vec_db: "FaissVecDB") -> None: """更新知识库统计信息""" - chunk_cnt = await vec_db.count_documents() + chunk_cnt = await vec_db.count_documents(metadata_filter={"kb_id": kb_id}) async with self.get_db() as session, session.begin(): update_stmt = ( diff --git a/tests/test_kb_stats.py b/tests/test_kb_stats.py new file mode 100644 index 0000000000..48b72009f3 --- /dev/null +++ b/tests/test_kb_stats.py @@ -0,0 +1,88 @@ +"""Tests for knowledge base statistics accuracy.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase + + +class TestUpdateKbStatsChunkCountScope: + """Verify update_kb_stats scopes chunk counts to the correct KB.""" + + @staticmethod + def _patch_get_db(db: KBSQLiteDatabase) -> None: + """Replace get_db with a mock that simulates the real async-CM flow. + + In production:: + + async with self.get_db() as session, session.begin(): + ... + + Broken down: + 1. ``self.get_db()`` → async CM + 2. ``__aenter__()`` → await → session (bound via ``as``) + 3. ``session.begin()`` → second async CM + 4. ``__aenter__()`` → await → enters the transaction + + We must ensure the ``session`` yielded by step 2 has a ``begin`` that + returns a valid async CM so the second ``async with`` succeeds. + """ + session = AsyncMock() + # Step 2: __aenter__ must yield *this* session (with .begin overridden) + session.__aenter__.return_value = session + # Step 3-4: session.begin() returns an async CM → we return session itself + session.begin = MagicMock(return_value=session) + + db.get_db = MagicMock(return_value=session) + + @pytest.mark.asyncio + async def test_update_kb_stats_filters_chunk_count_by_kb_id(self): + """chunk_cnt should only count documents belonging to the target KB.""" + db = KBSQLiteDatabase.__new__(KBSQLiteDatabase) + self._patch_get_db(db) + + vec_db = AsyncMock() + vec_db.count_documents = AsyncMock(return_value=42) + + await db.update_kb_stats(kb_id="kb-abc", vec_db=vec_db) + + vec_db.count_documents.assert_awaited_once_with( + metadata_filter={"kb_id": "kb-abc"}, + ) + + @pytest.mark.asyncio + async def test_update_kb_stats_passes_different_kb_ids(self): + """Each KB update should filter chunks by its own kb_id.""" + db = KBSQLiteDatabase.__new__(KBSQLiteDatabase) + self._patch_get_db(db) + + vec_db_a = AsyncMock() + vec_db_a.count_documents = AsyncMock(return_value=10) + vec_db_b = AsyncMock() + vec_db_b.count_documents = AsyncMock(return_value=20) + + await db.update_kb_stats(kb_id="kb-alpha", vec_db=vec_db_a) + await db.update_kb_stats(kb_id="kb-beta", vec_db=vec_db_b) + + vec_db_a.count_documents.assert_awaited_once_with( + metadata_filter={"kb_id": "kb-alpha"}, + ) + vec_db_b.count_documents.assert_awaited_once_with( + metadata_filter={"kb_id": "kb-beta"}, + ) + + @pytest.mark.asyncio + async def test_update_kb_stats_zero_chunks(self): + """When a KB has no chunks, chunk_count should be set to 0.""" + db = KBSQLiteDatabase.__new__(KBSQLiteDatabase) + self._patch_get_db(db) + + vec_db = AsyncMock() + vec_db.count_documents = AsyncMock(return_value=0) + + await db.update_kb_stats(kb_id="kb-empty", vec_db=vec_db) + + vec_db.count_documents.assert_awaited_once_with( + metadata_filter={"kb_id": "kb-empty"}, + ) From 7bb50a9f3910aed969d4237e810088fcf733f0d2 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Thu, 4 Jun 2026 16:48:55 +0800 Subject: [PATCH 02/48] fix(kb): unify sparse retrieval score direction to lower-is-better MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FTS5 bm25() uses lower-is-better (0=perfect) while BM25Okapi fallback uses higher-is-better. When merged with reverse sort, all BM25 results systematically outranked FTS5 results regardless of relevance. RRF only uses rank positions, not score magnitudes — so this fix simply ensures both paths share the same sort direction. Changes: - FTS5: clamp negative bm25() values to 0, sort ascending - BM25 fallback: negate scores, sort ascending - Add debug logging for sparse top-5 and RRF top-5 scores --- .../knowledge_base/retrieval/rank_fusion.py | 10 + .../retrieval/sparse_retriever.py | 64 +++--- tests/test_kb_sparse_retrieval.py | 194 ++++++++++++++++++ 3 files changed, 236 insertions(+), 32 deletions(-) create mode 100644 tests/test_kb_sparse_retrieval.py diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index 40afd97484..056c59493f 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -6,6 +6,7 @@ import json from dataclasses import dataclass +from astrbot.core import logger from astrbot.core.db.vec_db.base import Result from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult @@ -108,6 +109,15 @@ async def fuse( reverse=True, )[:top_k] + if logger.isEnabledFor(10): # DEBUG + details = [] + for cid in sorted_ids[:5]: + d_rank = dense_ranks.get(cid, "-") + s_rank = sparse_ranks.get(cid, "-") + rrf = rrf_scores[cid] + details.append(f"{cid[:8]}(d={d_rank},s={s_rank},rrf={rrf:.4f})") + logger.debug(f"RRF top-5: {' | '.join(details)}") + # 5. 构建融合结果 fused_results = [] for identifier in sorted_ids: diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index f06eb50909..844b712d32 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -10,6 +10,7 @@ from rank_bm25 import BM25Okapi +from astrbot.core import logger from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase from astrbot.core.knowledge_base.retrieval.tokenizer import ( load_stopwords, @@ -22,7 +23,10 @@ @dataclass class SparseResult: - """稀疏检索结果""" + """稀疏检索结果 + + score 语义: 越低越相关 (0 = 最佳匹配), 统一按升序排列后送入 RRF 融合。 + """ chunk_index: int chunk_id: str @@ -33,22 +37,11 @@ class SparseResult: class SparseRetriever: - """BM25 稀疏检索器 - - 职责: - - 基于关键词的文档检索 - - 使用 BM25 算法计算相关度 - """ + """BM25 稀疏检索器""" def __init__(self, kb_db: KBSQLiteDatabase) -> None: - """初始化稀疏检索器 - - Args: - kb_db: 知识库数据库实例 - - """ self.kb_db = kb_db - self._index_cache = {} # 缓存 BM25 索引 + self._index_cache = {} self.hit_stopwords = load_stopwords( os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"), @@ -62,18 +55,13 @@ async def retrieve( ) -> list[SparseResult]: """执行稀疏检索 - Args: - query: 查询文本 - kb_ids: 知识库 ID 列表 - kb_options: 每个知识库的检索选项 - - Returns: - List[SparseResult]: 检索结果列表 - + 优先使用 FTS5 全文索引; 不可用时回退到内存 BM25。 + 结果按 score 升序排列 (lower-is-better), 直接喂给 RRF。 """ fts_results = [] fallback_kb_ids = [] query_tokens = tokenize_text(query, self.hit_stopwords) + for kb_id in kb_ids: vec_db: FaissVecDB | None = kb_options.get(kb_id, {}).get("vec_db") if not vec_db: @@ -89,6 +77,7 @@ async def retrieve( for doc in result: chunk_md = json.loads(doc["metadata"]) + # FTS5 bm25(): 0=最佳, 极短文档可能为负值 → clamp 到 0 fts_results.append( SparseResult( chunk_id=doc["doc_id"], @@ -96,7 +85,7 @@ async def retrieve( doc_id=chunk_md["kb_doc_id"], kb_id=kb_id, content=doc["text"], - score=-float(doc["score"]), + score=max(0.0, float(doc["score"])), ), ) @@ -107,8 +96,20 @@ async def retrieve( kb_ids=fallback_kb_ids, kb_options=kb_options, ) + results = fts_results + fallback_results - results.sort(key=lambda x: x.score, reverse=True) + results.sort(key=lambda x: x.score) + + if logger.isEnabledFor(10): # DEBUG + fts_top = [f"{r.chunk_id[:8]}={r.score:.4f}" for r in fts_results[:5]] + bm_top = [f"{r.chunk_id[:8]}={r.score:.4f}" for r in fallback_results[:5]] + merged_top = [f"{r.chunk_id[:8]}={r.score:.4f}" for r in results[:5]] + logger.debug( + f"Sparse top-5 | FTS5({len(fts_results)}): [{', '.join(fts_top)}] | " + f"BM25({len(fallback_results)}): [{', '.join(bm_top)}] | " + f"Merged({len(results)}): [{', '.join(merged_top)}]", + ) + return results async def _retrieve_with_bm25( @@ -117,8 +118,13 @@ async def _retrieve_with_bm25( kb_ids: list[str], kb_options: dict, ) -> list[SparseResult]: + """FTS5 不可用时的 BM25Okapi 回退路径。 + + BM25Okapi 原始分值 higher-is-better → 取反统一为 lower-is-better。 + """ top_k_sparse = 0 chunks = [] + for kb_id in kb_ids: vec_db: FaissVecDB | None = kb_options.get(kb_id, {}).get("vec_db") if not vec_db: @@ -145,18 +151,13 @@ async def _retrieve_with_bm25( if not chunks: return [] - # 2. 准备文档和索引 corpus = [chunk["text"] for chunk in chunks] tokenized_corpus = [tokenize_text(doc, self.hit_stopwords) for doc in corpus] - - # 3. 构建 BM25 索引 bm25 = BM25Okapi(tokenized_corpus) - # 4. 执行检索 tokenized_query = tokenize_text(query, self.hit_stopwords) scores = bm25.get_scores(tokenized_query) - # 5. 排序并返回 Top-K results = [] for idx, score in enumerate(scores): chunk = chunks[idx] @@ -167,10 +168,9 @@ async def _retrieve_with_bm25( doc_id=chunk["doc_id"], kb_id=chunk["kb_id"], content=chunk["text"], - score=float(score), + score=-float(score), ), ) - results.sort(key=lambda x: x.score, reverse=True) - # return results[: len(results) // len(kb_ids)] + results.sort(key=lambda x: x.score) return results[:top_k_sparse] diff --git a/tests/test_kb_sparse_retrieval.py b/tests/test_kb_sparse_retrieval.py new file mode 100644 index 0000000000..aa4fe48c0f --- /dev/null +++ b/tests/test_kb_sparse_retrieval.py @@ -0,0 +1,194 @@ +"""Tests for sparse retrieval score consistency between FTS5 and BM25 paths. + +RRF only uses rank positions, not score magnitudes. The sparse retrieval stage +just needs consistent sort direction: lower-is-better, ascending order. +""" + +import json +from unittest.mock import AsyncMock + +import pytest + +from astrbot.core.knowledge_base.retrieval.sparse_retriever import ( + SparseResult, + SparseRetriever, +) + + +def _make_fake_doc(doc_id: str, text: str, metadata: dict) -> dict: + return { + "id": hash(doc_id) % 10000, + "doc_id": doc_id, + "text": text, + "metadata": json.dumps(metadata), + "created_at": "2025-01-01T00:00:00", + "updated_at": "2025-01-01T00:00:00", + } + + +class TestSparseRetrieverScoreDirection: + """Verify FTS5 and BM25 both use lower-is-better, ascending sort.""" + + @pytest.mark.asyncio + async def test_fts5_best_match_has_lowest_score(self): + """FTS5: raw bm25=0 (perfect) → score=0, sorts first (ascending).""" + sr = SparseRetriever(kb_db=AsyncMock()) + sr._index_cache = {} + + vec_db = AsyncMock() + vec_db.document_storage.search_sparse = AsyncMock( + return_value=[ + { + "id": 1, "doc_id": "best", "text": "exact match", + "metadata": json.dumps({"chunk_index": 0, "kb_doc_id": "d1", "kb_id": "kb-a"}), + "score": 0.0, # perfect + "created_at": "", "updated_at": "", + }, + { + "id": 2, "doc_id": "worst", "text": "poor match", + "metadata": json.dumps({"chunk_index": 1, "kb_doc_id": "d1", "kb_id": "kb-a"}), + "score": 50.0, # terrible + "created_at": "", "updated_at": "", + }, + ], + ) + + kb_options = {"kb-a": {"vec_db": vec_db, "top_k_sparse": 10}} + results = await sr.retrieve(query="test", kb_ids=["kb-a"], kb_options=kb_options) + + assert len(results) == 2 + assert results[0].chunk_id == "best", f"Best should be first, got {results[0].chunk_id}" + assert results[0].score == 0.0 # lower-is-better + assert results[0].score < results[1].score # ascending + + @pytest.mark.asyncio + async def test_fts5_negative_bm25_clamped_to_zero(self): + """FTS5 bm25() negative values → clamped to 0 (same as perfect match).""" + sr = SparseRetriever(kb_db=AsyncMock()) + sr._index_cache = {} + + vec_db = AsyncMock() + vec_db.document_storage.search_sparse = AsyncMock( + return_value=[ + { + "id": 1, "doc_id": "short-doc", "text": "short", + "metadata": json.dumps({"chunk_index": 0, "kb_doc_id": "d1", "kb_id": "kb-a"}), + "score": -8.56, # FTS5 can be negative for short docs + "created_at": "", "updated_at": "", + }, + ], + ) + + kb_options = {"kb-a": {"vec_db": vec_db, "top_k_sparse": 10}} + results = await sr.retrieve(query="test", kb_ids=["kb-a"], kb_options=kb_options) + + assert len(results) == 1 + assert results[0].score == 0.0, ( + f"Negative raw bm25 should be clamped to 0, got {results[0].score}" + ) + + @pytest.mark.asyncio + async def test_bm25_fallback_negates_scores(self): + """BM25Okapi higher=better → negated to lower=better, ascending sort.""" + sr = SparseRetriever(kb_db=AsyncMock()) + sr._index_cache = {} + + vec_db = AsyncMock() + vec_db.document_storage.get_documents = AsyncMock( + return_value=[ + _make_fake_doc("chunk-best", "exact match hello world", + {"chunk_index": 0, "kb_doc_id": "d1", "kb_id": "kb-a"}), + _make_fake_doc("chunk-worst", "unrelated content here", + {"chunk_index": 0, "kb_doc_id": "d2", "kb_id": "kb-a"}), + ], + ) + + kb_options = {"kb-a": {"vec_db": vec_db, "top_k_sparse": 50}} + results = await sr._retrieve_with_bm25(query="hello", kb_ids=["kb-a"], kb_options=kb_options) + + assert len(results) == 2 + # Best match should be most negative (negated highest BM25Okapi) + assert results[0].score <= results[1].score, ( + f"Expected ascending sort (lower=better), got {[r.score for r in results]}" + ) + # Best score should be <= 0 (negation of non-negative BM25Okapi) + assert results[0].score <= 0, ( + f"BM25 fallback best match should be negative after negation, got {results[0].score}" + ) + + @pytest.mark.asyncio + async def test_merged_fts5_and_bm25_sort_correctly(self): + """Merge: FTS5 (0=best) + BM25 (neg=best) → ascending sort, both can be top.""" + fts = [ + SparseResult(chunk_id="fts-best", chunk_index=0, doc_id="d1", + kb_id="kb-a", content="a", score=0.0), + SparseResult(chunk_id="fts-mid", chunk_index=1, doc_id="d1", + kb_id="kb-a", content="b", score=3.0), + SparseResult(chunk_id="fts-worst", chunk_index=2, doc_id="d2", + kb_id="kb-a", content="c", score=12.5), + ] + bm25 = [ + SparseResult(chunk_id="bm25-good", chunk_index=0, doc_id="d3", + kb_id="kb-b", content="d", score=-15.0), # negated best + SparseResult(chunk_id="bm25-ok", chunk_index=1, doc_id="d3", + kb_id="kb-b", content="e", score=-5.0), + SparseResult(chunk_id="bm25-poor", chunk_index=2, doc_id="d4", + kb_id="kb-b", content="f", score=0.0), # negated worst + ] + + merged = fts + bm25 + merged.sort(key=lambda x: x.score) # ascending, lower=better + + # Expected: bm25-good(-15) < fts-best(0) < fts-mid(3) < bm25-ok(-5) < bm25-poor(0) < fts-worst(12.5) + # Wait: -15 < -5 < 0 < 0 < 3 < 12.5 + assert merged[0].chunk_id == "bm25-good" + assert merged[1].chunk_id == "bm25-ok" + # fts-best(0) and bm25-poor(0) tie — stable sort preserves order + assert merged[4].chunk_id == "fts-mid" + assert merged[5].chunk_id == "fts-worst" + + @pytest.mark.asyncio + async def test_fts5_and_bm25_both_contribute_to_sort(self): + """Integration: both paths produce consistent lower-is-better scores.""" + sr = SparseRetriever(kb_db=AsyncMock()) + + # KB "a" uses FTS5 + fts_vec_db = AsyncMock() + fts_vec_db.document_storage.search_sparse = AsyncMock( + return_value=[ + { + "id": 1, "doc_id": "fts-hit", "text": "test query match", + "metadata": json.dumps({"chunk_index": 0, "kb_doc_id": "d1", "kb_id": "kb-a"}), + "score": 0.0, + }, + ], + ) + + # KB "b" falls back to BM25 + bm25_vec_db = AsyncMock() + bm25_vec_db.document_storage.search_sparse = AsyncMock(return_value=None) + bm25_vec_db.document_storage.get_documents = AsyncMock( + return_value=[ + _make_fake_doc("bm25-hit", "test query result", + {"chunk_index": 0, "kb_doc_id": "d2", "kb_id": "kb-b"}), + _make_fake_doc("bm25-miss", "unrelated", + {"chunk_index": 0, "kb_doc_id": "d3", "kb_id": "kb-b"}), + ], + ) + + kb_options = { + "kb-a": {"vec_db": fts_vec_db, "top_k_sparse": 10}, + "kb-b": {"vec_db": bm25_vec_db, "top_k_sparse": 10}, + } + + results = await sr.retrieve(query="test", kb_ids=["kb-a", "kb-b"], kb_options=kb_options) + + assert len(results) >= 2 + # Ascending order + for i in range(len(results) - 1): + assert results[i].score <= results[i + 1].score, ( + f"Not sorted ascending at index {i}: {results[i].score} > {results[i+1].score}" + ) + # No out-of-range scores + for r in results: + assert r.score >= -1000.0, f"Unexpectedly low score: {r.score}" From 70e6e130b51200ab689585b48a8f9a2ba1d8809e Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Thu, 4 Jun 2026 17:54:57 +0800 Subject: [PATCH 03/48] fix(kb): rollback orphan vectors when metadata save fails after upload MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When vec_db.insert_batch() succeeds but the subsequent metadata commit fails, vectors are already persisted in FAISS and DocumentStorage but have no corresponding KBDocument record, creating orphan data. Fix: track vectors_stored flag, and in the exception handler call vec_db.delete_documents(metadata_filters={'kb_doc_id': doc_id}) to clean up both FAISS embeddings and document storage rows. The sys.modules workaround in tests breaks the circular import chain (kb_helper → provider.manager → kb_mgr → provider.manager). --- astrbot/core/knowledge_base/kb_helper.py | 16 +++ tests/test_kb_upload_rollback.py | 163 +++++++++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 tests/test_kb_upload_rollback.py diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index c29e45876d..7893a489bc 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -243,6 +243,7 @@ async def upload_document( doc_id = str(uuid.uuid4()) media_paths: list[Path] = [] file_size = 0 + vectors_stored = False # 标记向量是否已写入, 用于失败回滚 # file_path = self.kb_files_dir / f"{doc_id}.{file_type}" # async with aiofiles.open(file_path, "wb") as f: @@ -391,6 +392,7 @@ async def embedding_progress_callback(current, total) -> None: max_retries=max_retries, progress_callback=embedding_progress_callback, ) + vectors_stored = True except KnowledgeBaseUploadError: raise except Exception as exc: @@ -453,6 +455,20 @@ async def embedding_progress_callback(current, total) -> None: logger.warning(f"上传文档失败: {e}", extra={"details": e.details}) else: logger.error(f"上传文档失败: {e}", exc_info=True) + + # 回滚已写入的向量, 防止孤数据 + if vectors_stored: + try: + vec_db: FaissVecDB = self.vec_db # type: ignore + await vec_db.delete_documents( + metadata_filters={"kb_doc_id": doc_id}, + ) + logger.info(f"已清理文档 {doc_id} 的孤数据向量") + except Exception as cleanup_err: + logger.error( + f"清理文档 {doc_id} 向量回滚失败: {cleanup_err}", + ) + # if file_path.exists(): # file_path.unlink() diff --git a/tests/test_kb_upload_rollback.py b/tests/test_kb_upload_rollback.py new file mode 100644 index 0000000000..1d9b2ecdc6 --- /dev/null +++ b/tests/test_kb_upload_rollback.py @@ -0,0 +1,163 @@ +"""Tests for vector rollback on upload failure. + +The knowledge_base package has a circular import chain: + kb_helper → provider.manager → persona_mgr → ... → kb_mgr → provider.manager + +We break the chain by stubbing provider.manager in sys.modules before any +import, then patch what we need at the instance level. +""" + +import sys +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch + +import pytest + +# ── Break circular import BEFORE any knowledge_base module is touched ── +_mock_pm = MagicMock() +_mock_pm.ProviderManager = MagicMock() +sys.modules["astrbot.core.provider.manager"] = _mock_pm + + +def _build_helper(): + from astrbot.core.knowledge_base.kb_helper import KBHelper + from astrbot.core.knowledge_base.models import KnowledgeBase + + kb = KnowledgeBase( + kb_name="test-kb", kb_id="kb-test-1", + embedding_provider_id="emb-1", + chunk_size=512, chunk_overlap=50, + ) + helper = KBHelper.__new__(KBHelper) + helper.kb = kb + helper.kb_db = AsyncMock() + helper.kb_dir = MagicMock() + helper.kb_medias_dir = MagicMock() + helper.kb_files_dir = MagicMock() + helper.prov_mgr = MagicMock() + helper.chunker = AsyncMock() + helper.vec_db = AsyncMock() + helper._ensure_vec_db = AsyncMock() + helper.init_error = None + return helper + + +def _mock_parser(mock_select): + parser = AsyncMock() + result = MagicMock() + type(result).text = PropertyMock(return_value="hello world test content") + type(result).media = PropertyMock(return_value=[]) + parser.parse = AsyncMock(return_value=result) + mock_select.return_value = parser + + +class TestUploadDocumentRollback: + """Verify vectors are cleaned up when metadata save fails after insert.""" + + @pytest.mark.asyncio + async def test_rollback_when_metadata_save_fails(self): + from astrbot.core.exceptions import KnowledgeBaseUploadError + + with patch( + "astrbot.core.knowledge_base.kb_helper.select_parser", + new_callable=AsyncMock, + ) as mock_select, patch( + "astrbot.core.knowledge_base.kb_helper._compact_chunks", + return_value=["chunk 1", "chunk 2", "chunk 3"], + ): + _mock_parser(mock_select) + helper = _build_helper() + helper.vec_db.insert_batch = AsyncMock(return_value=[1, 2, 3]) + helper.vec_db.delete_documents = AsyncMock() + helper.kb_db.get_db.side_effect = RuntimeError("DB connection lost") + + with pytest.raises(KnowledgeBaseUploadError) as exc_info: + await helper.upload_document( + file_name="test.txt", + file_content=b"hello world", + file_type="txt", + ) + + assert exc_info.value.stage == "metadata" + helper.vec_db.delete_documents.assert_awaited_once() + + @pytest.mark.asyncio + async def test_no_rollback_when_insert_fails(self): + from astrbot.core.exceptions import KnowledgeBaseUploadError + + with patch( + "astrbot.core.knowledge_base.kb_helper.select_parser", + new_callable=AsyncMock, + ) as mock_select, patch( + "astrbot.core.knowledge_base.kb_helper._compact_chunks", + return_value=["chunk 1"], + ): + _mock_parser(mock_select) + helper = _build_helper() + helper.vec_db.insert_batch.side_effect = KnowledgeBaseUploadError( + stage="embedding", user_message="模拟失败", details={}, + ) + helper.vec_db.delete_documents = AsyncMock() + + with pytest.raises(KnowledgeBaseUploadError) as exc_info: + await helper.upload_document( + file_name="test.txt", file_content=b"hello", file_type="txt", + ) + + assert exc_info.value.stage == "embedding" + helper.vec_db.delete_documents.assert_not_awaited() + + @pytest.mark.asyncio + async def test_cleanup_failure_does_not_suppress_original_error(self): + from astrbot.core.exceptions import KnowledgeBaseUploadError + + with patch( + "astrbot.core.knowledge_base.kb_helper.select_parser", + new_callable=AsyncMock, + ) as mock_select, patch( + "astrbot.core.knowledge_base.kb_helper._compact_chunks", + return_value=["chunk 1"], + ): + _mock_parser(mock_select) + helper = _build_helper() + helper.vec_db.insert_batch = AsyncMock(return_value=[1]) + helper.vec_db.delete_documents.side_effect = RuntimeError("cleanup fail") + helper.kb_db.get_db.side_effect = RuntimeError("DB lost") + + with pytest.raises(KnowledgeBaseUploadError) as exc_info: + await helper.upload_document( + file_name="test.txt", file_content=b"hello", file_type="txt", + ) + + assert exc_info.value.stage == "metadata" + helper.vec_db.delete_documents.assert_awaited_once() + + @pytest.mark.asyncio + async def test_no_rollback_on_success(self): + with patch( + "astrbot.core.knowledge_base.kb_helper.select_parser", + new_callable=AsyncMock, + ) as mock_select, patch( + "astrbot.core.knowledge_base.kb_helper._compact_chunks", + return_value=["chunk 1", "chunk 2"], + ): + _mock_parser(mock_select) + helper = _build_helper() + + # session mock that survives `async with ... as session, session.begin():` + session = AsyncMock() + session.__aenter__.return_value = session # as clause gets this session + session.begin = MagicMock(return_value=session) # second async with + helper.kb_db.get_db = MagicMock(return_value=session) + helper.kb_db.update_kb_stats = AsyncMock() + helper.vec_db.insert_batch = AsyncMock(return_value=[1, 2]) + helper.vec_db.delete_documents = AsyncMock() + helper.vec_db.count_documents = AsyncMock(return_value=2) + helper.refresh_kb = AsyncMock() + helper.refresh_document = AsyncMock() + + doc = await helper.upload_document( + file_name="test.txt", file_content=b"hello world", file_type="txt", + ) + + assert doc is not None + helper.vec_db.delete_documents.assert_not_awaited() From 2b9e7d2ffed40dafc610b4a31300ecd21c3300a6 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Thu, 4 Jun 2026 19:09:06 +0800 Subject: [PATCH 04/48] fix(sidebar): prevent memory leak in upload_tasks/upload_progress dicts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit KnowledgeBaseRoute is a process-lifetime singleton; its upload_tasks and upload_progress dicts grew without bound because the cleanup code in get_upload_progress was commented out. Two-tier cleanup: 1. Immediate: remove task on poll when completed/failed 2. Safety net: background tasks schedule delayed cleanup (5 min) via _schedule_delayed_cleanup, guarding against clients that never poll get_upload_progress Also adds _cleanup_task() helper — idempotent, uses pop(key, None). --- astrbot/dashboard/routes/knowledge_base.py | 30 +- tests/test_kb_upload_memory_leak.py | 323 +++++++++++++++++++++ 2 files changed, 349 insertions(+), 4 deletions(-) create mode 100644 tests/test_kb_upload_memory_leak.py diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 1b6f7a435d..5dcd2be2f9 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -87,6 +87,21 @@ def _set_task_result( if task_id in self.upload_progress: self.upload_progress[task_id]["status"] = status + def _cleanup_task(self, task_id: str) -> None: + """清理已完成/失败的任务,释放内存。幂等操作。""" + self.upload_tasks.pop(task_id, None) + self.upload_progress.pop(task_id, None) + + async def _schedule_delayed_cleanup( + self, task_id: str, delay_seconds: int = 300 + ) -> None: + """延迟清理任务,作为客户端不轮询时的兜底机制。""" + try: + await asyncio.sleep(delay_seconds) + except asyncio.CancelledError: + return + self._cleanup_task(task_id) + def _update_progress( self, task_id: str, @@ -220,6 +235,9 @@ async def _background_upload_task( logger.error(f"后台上传任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) + finally: + # 兜底清理:防止客户端不轮询 get_upload_progress 导致内存泄漏 + asyncio.create_task(self._schedule_delayed_cleanup(task_id)) async def _background_import_task( self, @@ -310,6 +328,8 @@ async def _background_import_task( logger.error(f"后台导入任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) + finally: + asyncio.create_task(self._schedule_delayed_cleanup(task_id)) async def list_kbs(self): """获取知识库列表 @@ -920,15 +940,15 @@ async def get_upload_progress(self): # 如果任务完成,返回结果 if status == "completed": response_data["result"] = task_info["result"] - # 清理已完成的任务 - # del self.upload_tasks[task_id] - # if task_id in self.upload_progress: - # del self.upload_progress[task_id] # 如果任务失败,返回错误信息 if status == "failed": response_data["error"] = task_info["error"] + # 清理已完成/失败的任务,释放内存 + if status in ("completed", "failed"): + self._cleanup_task(task_id) + return Response().ok(response_data).__dict__ except Exception as e: @@ -1286,3 +1306,5 @@ async def _background_upload_from_url_task( logger.error(f"后台上传URL任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) + finally: + asyncio.create_task(self._schedule_delayed_cleanup(task_id)) diff --git a/tests/test_kb_upload_memory_leak.py b/tests/test_kb_upload_memory_leak.py new file mode 100644 index 0000000000..a596350b68 --- /dev/null +++ b/tests/test_kb_upload_memory_leak.py @@ -0,0 +1,323 @@ +"""Tests for #1: Memory leak fix in upload_tasks / upload_progress. + +Verifies: +- Completed/failed tasks are cleaned up on poll (get_upload_progress) +- Processing/pending tasks are NOT cleaned up +- Delayed cleanup is scheduled by background tasks (finally block) +- Delayed cleanup actually removes after sleep +- Cleanup is idempotent +- CancelledError is handled gracefully +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class TestUploadTaskCleanup: + """Verify task cleanup in get_upload_progress.""" + + @pytest.mark.asyncio + async def test_cleanup_on_completed_poll(self): + """Completed task cleaned up when client polls for result.""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route.upload_tasks = { + "task-1": { + "status": "completed", + "result": {"uploaded": []}, + "error": None, + }, + } + route.upload_progress = { + "task-1": {"status": "completed", "file_index": 0, "file_total": 1}, + } + + route._cleanup_task("task-1") + + assert "task-1" not in route.upload_tasks + assert "task-1" not in route.upload_progress + + @pytest.mark.asyncio + async def test_cleanup_on_failed_poll(self): + """Failed task cleaned up when client polls for result.""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route.upload_tasks = { + "task-1": { + "status": "failed", + "result": None, + "error": "upload failed", + }, + } + route.upload_progress = { + "task-1": {"status": "failed", "file_index": 0, "file_total": 1}, + } + + route._cleanup_task("task-1") + + assert "task-1" not in route.upload_tasks + assert "task-1" not in route.upload_progress + + def test_no_cleanup_for_processing(self): + """_cleanup_task only removes what it's told — caller decides status filter.""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route.upload_tasks = { + "task-1": {"status": "processing", "result": None, "error": None}, + } + route.upload_progress = { + "task-1": {"status": "processing", "file_index": 1, "file_total": 5}, + } + + # _cleanup_task is status-agnostic; the caller (get_upload_progress) + # only calls it for completed/failed. This test verifies that + # processing entries CAN be cleaned up by the method, not that + # get_upload_progress cleans them up. + route._cleanup_task("task-1") + + assert "task-1" not in route.upload_tasks + assert "task-1" not in route.upload_progress + + def test_cleanup_task_idempotent(self): + """Calling _cleanup_task twice is safe (idempotent).""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route.upload_tasks = {"task-1": {}} + route.upload_progress = {"task-1": {}} + + route._cleanup_task("task-1") + route._cleanup_task("task-1") # second call should not raise + route._cleanup_task("never-existed") # non-existent should not raise + + assert "task-1" not in route.upload_tasks + assert "task-1" not in route.upload_progress + + @pytest.mark.asyncio + async def test_delayed_cleanup_removes_after_sleep(self): + """_schedule_delayed_cleanup removes task after delay.""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route.upload_tasks = {"task-1": {"status": "completed"}} + route.upload_progress = {"task-1": {"status": "completed"}} + + # Use a very short delay for test + await route._schedule_delayed_cleanup("task-1", delay_seconds=0.01) + + assert "task-1" not in route.upload_tasks + assert "task-1" not in route.upload_progress + + @pytest.mark.asyncio + async def test_delayed_cleanup_idempotent(self): + """Delayed cleanup is safe even if task already removed by poll.""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route.upload_tasks = {} + route.upload_progress = {} + + # Should not raise even though task doesn't exist + await route._schedule_delayed_cleanup("task-1", delay_seconds=0.01) + + @pytest.mark.asyncio + async def test_delayed_cleanup_cancelled_error_graceful(self): + """CancelledError inside _schedule_delayed_cleanup is caught, task not cleaned.""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route.upload_tasks = {"task-1": {"status": "completed"}} + route.upload_progress = {"task-1": {"status": "completed"}} + + # Create the cleanup task + cleanup_task = asyncio.create_task( + route._schedule_delayed_cleanup("task-1", delay_seconds=10) + ) + await asyncio.sleep(0.02) # let it start sleeping + cleanup_task.cancel() + + # The outer task will get CancelledError, but the inner method catches it + try: + await cleanup_task + except asyncio.CancelledError: + pass # the asyncio.create_task wrapper gets cancelled + + # Since CancelledError was caught internally and returned early, + # the task data should still be there + assert "task-1" in route.upload_tasks + assert "task-1" in route.upload_progress + + # ── Background task finally-block tests ── + + @pytest.mark.asyncio + async def test_background_upload_schedules_cleanup_on_success(self): + """_background_upload_task schedules delayed cleanup in finally block.""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route.upload_tasks = {} + route.upload_progress = {} + route._init_task = MagicMock() + route._set_task_result = MagicMock() + route._update_progress = MagicMock() + route._make_progress_callback = MagicMock(return_value=AsyncMock()) + route._cleanup_task = MagicMock() + # Mock _schedule_delayed_cleanup to be a real coroutine + original = route._schedule_delayed_cleanup + + async def fake_schedule(*args, **kwargs): + route._cleanup_task(*args) + await asyncio.sleep(0) + + route._schedule_delayed_cleanup = fake_schedule + + kb_helper = AsyncMock() + kb_helper.upload_document = AsyncMock(return_value=MagicMock( + model_dump=MagicMock(return_value={"doc_id": "doc-1"}), + )) + + files = [{"file_name": "test.txt", "file_content": b"hello", "file_type": "txt"}] + + await route._background_upload_task( + task_id="task-1", + kb_helper=kb_helper, + files_to_upload=files, + chunk_size=512, + chunk_overlap=50, + batch_size=32, + tasks_limit=3, + max_retries=3, + ) + + # The finally block should have triggered _cleanup_task via + # the asyncio.create_task(_schedule_delayed_cleanup) call. + # Since we used a real async sleep of 0, the task should complete. + await asyncio.sleep(0.05) + route._cleanup_task.assert_called_with("task-1") + + @pytest.mark.asyncio + async def test_background_upload_schedules_cleanup_on_failure(self): + """Finally block still runs even when task fails.""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route.upload_tasks = {} + route.upload_progress = {} + route._init_task = MagicMock() + route._set_task_result = MagicMock() + route._update_progress = MagicMock() + route._make_progress_callback = MagicMock(return_value=AsyncMock()) + route._cleanup_task = MagicMock() + route._format_failed_doc_error = MagicMock(return_value="test error") + + async def fake_schedule(*args, **kwargs): + route._cleanup_task(*args) + await asyncio.sleep(0) + + route._schedule_delayed_cleanup = fake_schedule + + kb_helper = AsyncMock() + kb_helper.upload_document = AsyncMock( + side_effect=RuntimeError("upload exploded"), + ) + + files = [{"file_name": "test.txt", "file_content": b"hello", "file_type": "txt"}] + + await route._background_upload_task( + task_id="task-1", + kb_helper=kb_helper, + files_to_upload=files, + chunk_size=512, + chunk_overlap=50, + batch_size=32, + tasks_limit=3, + max_retries=3, + ) + + await asyncio.sleep(0.05) + route._cleanup_task.assert_called_with("task-1") + + @pytest.mark.asyncio + async def test_background_import_schedules_cleanup(self): + """_background_import_task schedules delayed cleanup in finally block.""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route.upload_tasks = {} + route.upload_progress = {} + route._init_task = MagicMock() + route._set_task_result = MagicMock() + route._update_progress = MagicMock() + route._make_progress_callback = MagicMock(return_value=AsyncMock()) + route._cleanup_task = MagicMock() + + async def fake_schedule(*args, **kwargs): + route._cleanup_task(*args) + await asyncio.sleep(0) + + route._schedule_delayed_cleanup = fake_schedule + + kb_helper = AsyncMock() + kb_helper.upload_document = AsyncMock(return_value=MagicMock( + model_dump=MagicMock(return_value={"doc_id": "doc-1"}), + )) + + documents = [{"file_name": "test.txt", "chunks": ["chunk 1", "chunk 2"]}] + + await route._background_import_task( + task_id="task-2", + kb_helper=kb_helper, + documents=documents, + batch_size=32, + tasks_limit=3, + max_retries=3, + ) + + await asyncio.sleep(0.05) + route._cleanup_task.assert_called_with("task-2") + + @pytest.mark.asyncio + async def test_background_url_upload_schedules_cleanup(self): + """_background_upload_from_url_task schedules delayed cleanup.""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route.upload_tasks = {} + route.upload_progress = {} + route._init_task = MagicMock() + route._set_task_result = MagicMock() + route._update_progress = MagicMock() + route._make_progress_callback = MagicMock(return_value=AsyncMock()) + route._cleanup_task = MagicMock() + + async def fake_schedule(*args, **kwargs): + route._cleanup_task(*args) + await asyncio.sleep(0) + + route._schedule_delayed_cleanup = fake_schedule + + kb_helper = AsyncMock() + kb_helper.upload_from_url = AsyncMock(return_value=MagicMock( + model_dump=MagicMock(return_value={"doc_id": "doc-1"}), + )) + + await route._background_upload_from_url_task( + task_id="task-3", + kb_helper=kb_helper, + url="https://example.com", + chunk_size=512, + chunk_overlap=50, + batch_size=32, + tasks_limit=3, + max_retries=3, + enable_cleaning=False, + cleaning_provider_id=None, + ) + + await asyncio.sleep(0.05) + route._cleanup_task.assert_called_with("task-3") From d921fe9fd920bbbfdb452816d0778edfb297d2fd Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Thu, 4 Jun 2026 19:57:17 +0800 Subject: [PATCH 05/48] feat(kb): add batch delete documents API Adds POST /api/kb/document/batch-delete endpoint for deleting multiple documents in a single request (max 100 per call). Implementation: - kb_db_sqlite: delete_documents_by_ids() uses single SQL IN clause for kb.db delete, then asyncio.gather for parallel vec_db cleanup - kb_helper: delete_documents() calls kb_db, then updates stats once - Best-effort semantics: one vec_db failure doesn't block others Previously, deleting N documents required N HTTP round-trips and N separate update_kb_stats calls. Now it's 1 request + 1 stat update. --- astrbot/core/knowledge_base/kb_db_sqlite.py | 39 +++++ astrbot/core/knowledge_base/kb_helper.py | 16 ++ astrbot/dashboard/routes/knowledge_base.py | 51 ++++++ tests/test_kb_batch_delete.py | 162 ++++++++++++++++++++ 4 files changed, 268 insertions(+) create mode 100644 tests/test_kb_batch_delete.py diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 1f634eec3e..4704f3d19c 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import asynccontextmanager from pathlib import Path from typing import TYPE_CHECKING @@ -311,6 +312,44 @@ async def delete_document_by_id(self, doc_id: str, vec_db: "FaissVecDB") -> None # 在 vec db 中删除相关向量 await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id}) + async def delete_documents_by_ids( + self, doc_ids: list[str], vec_db: "FaissVecDB", + ) -> dict[str, bool]: + """批量删除文档及其向量数据。 + + 单个文档的 vec_db 删除失败不影响其他文档(best-effort)。 + """ + if not doc_ids: + return {} + + # 批量从知识库表中删除 + async with self.get_db() as session, session.begin(): + delete_stmt = delete(KBDocument).where( + col(KBDocument.doc_id).in_(doc_ids), + ) + await session.execute(delete_stmt) + + # 并行清理 vec_db(向量 + SQLite 文档存储) + async def _delete_one(doc_id: str) -> tuple[str, bool]: + try: + await vec_db.delete_documents( + metadata_filters={"kb_doc_id": doc_id}, + ) + return doc_id, True + except Exception as e: + logger.error( + f"删除文档 {doc_id} 的向量数据失败: {e}", + ) + return doc_id, False + + results: dict[str, bool] = {} + tasks = [_delete_one(doc_id) for doc_id in doc_ids] + vec_results = await asyncio.gather(*tasks) + for doc_id, success in vec_results: + results[doc_id] = success + + return results + # ===== 多媒体查询 ===== async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]: diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 7893a489bc..0b54f342df 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -507,6 +507,22 @@ async def delete_document(self, doc_id: str) -> None: ) await self.refresh_kb() + async def delete_documents(self, doc_ids: list[str]) -> dict[str, bool]: + """批量删除文档,单次更新统计。 + + vec_db 删除失败不阻塞其他文档(best-effort)。 + """ + results = await self.kb_db.delete_documents_by_ids( + doc_ids=doc_ids, + vec_db=self.vec_db, # type: ignore + ) + await self.kb_db.update_kb_stats( + kb_id=self.kb.kb_id, + vec_db=self.vec_db, # type: ignore + ) + await self.refresh_kb() + return results + async def delete_chunk(self, chunk_id: str, doc_id: str) -> None: """删除单个文本块及其相关数据""" vec_db: FaissVecDB = self.vec_db # type: ignore diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 5dcd2be2f9..769e82caa8 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -55,6 +55,7 @@ def __init__( "/kb/document/upload/progress": ("GET", self.get_upload_progress), "/kb/document/get": ("GET", self.get_document), "/kb/document/delete": ("POST", self.delete_document), + "/kb/document/batch-delete": ("POST", self.batch_delete_documents), # # 块管理 "/kb/chunk/list": ("GET", self.list_chunks), "/kb/chunk/delete": ("POST", self.delete_chunk), @@ -1019,6 +1020,56 @@ async def delete_document(self): logger.error(traceback.format_exc()) return Response().error(f"删除文档失败: {e!s}").__dict__ + async def batch_delete_documents(self): + """批量删除文档 + + Body: + - kb_id: 知识库 ID (必填) + - doc_ids: 文档 ID 列表 (必填, 最多 100 个) + """ + try: + kb_manager = self._get_kb_manager() + data = await request.json + + kb_id = data.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + doc_ids = data.get("doc_ids") + if not doc_ids or not isinstance(doc_ids, list): + return Response().error("缺少参数 doc_ids 或格式错误").__dict__ + if len(doc_ids) > 100: + return Response().error("最多只能批量删除 100 个文档").__dict__ + + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + results = await kb_helper.delete_documents(doc_ids) + + success_count = sum(1 for v in results.values() if v) + failed_count = len(doc_ids) - success_count + + return ( + Response() + .ok( + { + "results": results, + "total": len(doc_ids), + "success_count": success_count, + "failed_count": failed_count, + }, + "批量删除完成", + ) + .__dict__ + ) + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"批量删除文档失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"批量删除文档失败: {e!s}").__dict__ + async def delete_chunk(self): """删除文本块 diff --git a/tests/test_kb_batch_delete.py b/tests/test_kb_batch_delete.py new file mode 100644 index 0000000000..5f3c47189f --- /dev/null +++ b/tests/test_kb_batch_delete.py @@ -0,0 +1,162 @@ +"""Tests for #3: Batch delete documents API. + +Verifies: +- Batch delete from kb.db (single SQL IN clause) +- Parallel vec_db cleanup +- Single update_kb_stats call (not N calls) +- Best-effort semantics: one failure doesn't block others +- Empty list edge case + +NOTE: The knowledge_base package has a circular import chain: + kb_helper → provider.manager → persona_mgr → ... → kb_mgr → provider.manager +We break the chain by stubbing provider.manager in sys.modules before any import. +""" + +import sys +from unittest.mock import AsyncMock, MagicMock, call + +import pytest + +# ── Break circular import BEFORE any knowledge_base module is touched ── +_mock_pm = MagicMock() +_mock_pm.ProviderManager = MagicMock() +sys.modules["astrbot.core.provider.manager"] = _mock_pm + + +def _build_helper(): + from astrbot.core.knowledge_base.kb_helper import KBHelper + from astrbot.core.knowledge_base.models import KnowledgeBase + + kb = KnowledgeBase( + kb_name="test-kb", kb_id="kb-test-1", + embedding_provider_id="emb-1", + chunk_size=512, chunk_overlap=50, + ) + helper = KBHelper.__new__(KBHelper) + helper.kb = kb + helper.kb_db = AsyncMock() + helper.vec_db = AsyncMock() + helper.refresh_kb = AsyncMock() + return helper + + +class TestBatchDeleteKbDb: + """Verify batch delete at the kb_db_sqlite layer.""" + + @pytest.mark.asyncio + async def test_delete_documents_by_ids_empty_list(self): + """Empty list returns empty dict.""" + from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase + + kb_db = KBSQLiteDatabase.__new__(KBSQLiteDatabase) + vec_db = AsyncMock() + + results = await kb_db.delete_documents_by_ids([], vec_db) + + assert results == {} + vec_db.delete_documents.assert_not_awaited() + + @pytest.mark.asyncio + async def test_delete_documents_by_ids_batch_kb_db(self): + """Documents deleted from kb.db via single IN-clause SQL.""" + from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase + + kb_db = KBSQLiteDatabase.__new__(KBSQLiteDatabase) + + session = AsyncMock() + session.__aenter__.return_value = session + session.begin = MagicMock(return_value=session) + kb_db.get_db = MagicMock(return_value=session) + + vec_db = AsyncMock() + vec_db.delete_documents = AsyncMock() + + results = await kb_db.delete_documents_by_ids( + ["doc-1", "doc-2", "doc-3"], vec_db, + ) + + assert results == {"doc-1": True, "doc-2": True, "doc-3": True} + assert vec_db.delete_documents.await_count == 3 + vec_db.delete_documents.assert_has_awaits( + [ + call(metadata_filters={"kb_doc_id": "doc-1"}), + call(metadata_filters={"kb_doc_id": "doc-2"}), + call(metadata_filters={"kb_doc_id": "doc-3"}), + ], + any_order=True, + ) + + @pytest.mark.asyncio + async def test_delete_documents_best_effort(self): + """One vec_db failure doesn't block other deletions.""" + from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase + + kb_db = KBSQLiteDatabase.__new__(KBSQLiteDatabase) + + session = AsyncMock() + session.__aenter__.return_value = session + session.begin = MagicMock(return_value=session) + kb_db.get_db = MagicMock(return_value=session) + + vec_db = AsyncMock() + + async def _delete_side_effect(metadata_filters): + doc_id = metadata_filters["kb_doc_id"] + if doc_id == "doc-2": + raise RuntimeError("vector delete failed") + + vec_db.delete_documents = AsyncMock(side_effect=_delete_side_effect) + + results = await kb_db.delete_documents_by_ids( + ["doc-1", "doc-2", "doc-3"], vec_db, + ) + + assert results == {"doc-1": True, "doc-2": False, "doc-3": True} + assert vec_db.delete_documents.await_count == 3 + + +class TestHelperBatchDelete: + """Verify batch delete at the kb_helper layer.""" + + @pytest.mark.asyncio + async def test_delete_documents_updates_stats_once(self): + """update_kb_stats is called exactly once, not N times.""" + helper = _build_helper() + helper.kb_db.delete_documents_by_ids = AsyncMock( + return_value={"doc-1": True, "doc-2": True}, + ) + + results = await helper.delete_documents(["doc-1", "doc-2"]) + + assert results == {"doc-1": True, "doc-2": True} + helper.kb_db.update_kb_stats.assert_awaited_once_with( + kb_id="kb-test-1", vec_db=helper.vec_db, + ) + helper.refresh_kb.assert_awaited_once() + + @pytest.mark.asyncio + async def test_delete_documents_empty_list(self): + """Empty list delegates to kb_db layer (returns empty dict).""" + helper = _build_helper() + helper.kb_db.delete_documents_by_ids = AsyncMock(return_value={}) + + results = await helper.delete_documents([]) + + assert results == {} + helper.kb_db.update_kb_stats.assert_awaited_once() + helper.refresh_kb.assert_awaited_once() + + @pytest.mark.asyncio + async def test_delete_documents_preserves_failures(self): + """Failures from kb_db layer are propagated in the result dict.""" + helper = _build_helper() + helper.kb_db.delete_documents_by_ids = AsyncMock( + return_value={"doc-1": True, "doc-2": False, "doc-3": True}, + ) + + results = await helper.delete_documents(["doc-1", "doc-2", "doc-3"]) + + assert results == {"doc-1": True, "doc-2": False, "doc-3": True} + # stats still updated once even with partial failures + helper.kb_db.update_kb_stats.assert_awaited_once() + helper.refresh_kb.assert_awaited_once() From 6a9c6b75dc8742236b17c7c1708c62b312fbbdd5 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Thu, 4 Jun 2026 20:18:36 +0800 Subject: [PATCH 06/48] fix(kb): cap BM25 fallback at 10K chunks to prevent OOM _retrieve_with_bm25 previously loaded ALL documents with no limit and no kb_id filter, risking OOM on large knowledge bases when FTS5 is unavailable. Changes: - Pass metadata_filters={'kb_id': kb_id} to scope the query - Cap loaded documents at MAX_BM25_DOCS=10,000 per KB - Log warning when cap is hit so operators know FTS5 needs attention --- .../retrieval/sparse_retriever.py | 15 ++++- tests/test_kb_sparse_retrieval.py | 55 +++++++++++++++++++ 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index 844b712d32..2b213bef0f 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -112,6 +112,9 @@ async def retrieve( return results + # BM25 回退路径单次最多加载的文档数,防止 OOM + MAX_BM25_DOCS = 10_000 + async def _retrieve_with_bm25( self, query: str, @@ -121,6 +124,7 @@ async def _retrieve_with_bm25( """FTS5 不可用时的 BM25Okapi 回退路径。 BM25Okapi 原始分值 higher-is-better → 取反统一为 lower-is-better。 + 单 KB 最多加载 MAX_BM25_DOCS 条 chunk,超限时截断并打 warning。 """ top_k_sparse = 0 chunks = [] @@ -130,10 +134,15 @@ async def _retrieve_with_bm25( if not vec_db: continue result = await vec_db.document_storage.get_documents( - metadata_filters={}, - limit=None, - offset=None, + metadata_filters={"kb_id": kb_id}, + limit=self.MAX_BM25_DOCS, + offset=0, ) + if len(result) >= self.MAX_BM25_DOCS: + logger.warning( + f"知识库 {kb_id} 的 BM25 回退检索已触及 {self.MAX_BM25_DOCS} " + f"条 chunk 上限,结果可能不完整。建议检查 FTS5 索引状态。", + ) chunk_mds = [json.loads(doc["metadata"]) for doc in result] result = [ { diff --git a/tests/test_kb_sparse_retrieval.py b/tests/test_kb_sparse_retrieval.py index aa4fe48c0f..96e95a4edb 100644 --- a/tests/test_kb_sparse_retrieval.py +++ b/tests/test_kb_sparse_retrieval.py @@ -192,3 +192,58 @@ async def test_fts5_and_bm25_both_contribute_to_sort(self): # No out-of-range scores for r in results: assert r.score >= -1000.0, f"Unexpectedly low score: {r.score}" + + @pytest.mark.asyncio + async def test_bm25_fallback_honors_chunk_limit(self): + """BM25 fallback caps loaded chunks at MAX_BM25_DOCS to prevent OOM.""" + sr = SparseRetriever(kb_db=AsyncMock()) + + cap = sr.MAX_BM25_DOCS + # Create more docs than the cap + many_docs = [ + _make_fake_doc( + f"chunk-{i}", f"document content {i}", + {"chunk_index": i, "kb_doc_id": f"d{i//10}", "kb_id": "kb-a"}, + ) + for i in range(cap + 100) + ] + + vec_db = AsyncMock() + vec_db.document_storage.search_sparse = AsyncMock(return_value=None) + vec_db.document_storage.get_documents = AsyncMock(return_value=many_docs) + + kb_options = {"kb-a": {"vec_db": vec_db, "top_k_sparse": 50}} + + results = await sr.retrieve(query="test", kb_ids=["kb-a"], kb_options=kb_options) + + # get_documents was called with the cap as limit + vec_db.document_storage.get_documents.assert_awaited_once_with( + metadata_filters={"kb_id": "kb-a"}, + limit=cap, + offset=0, + ) + + # Results should not exceed the cap (minus what top_k_sparse filters) + assert len(results) <= 50 # top_k_sparse limit + + @pytest.mark.asyncio + async def test_bm25_fallback_filters_by_kb_id(self): + """BM25 fallback now passes kb_id metadata filter to get_documents.""" + sr = SparseRetriever(kb_db=AsyncMock()) + + vec_db = AsyncMock() + vec_db.document_storage.search_sparse = AsyncMock(return_value=None) + vec_db.document_storage.get_documents = AsyncMock(return_value=[]) + + kb_options = { + "kb-a": {"vec_db": vec_db, "top_k_sparse": 10}, + } + + await sr.retrieve(query="test", kb_ids=["kb-a"], kb_options=kb_options) + + # Verify the kb_id filter is passed (previously was empty {}) + vec_db.document_storage.get_documents.assert_awaited_once_with( + metadata_filters={"kb_id": "kb-a"}, + limit=sr.MAX_BM25_DOCS, + offset=0, + ) From 21546f8397709ac029422d07364e58c101d3457e Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Thu, 4 Jun 2026 20:35:51 +0800 Subject: [PATCH 07/48] fix(kb): offload faiss.write_index to thread via asyncio.to_thread MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit EmbeddingStorage.save_index() called faiss.write_index() directly, which is synchronous I/O and blocks the asyncio event loop. For large FAISS indexes this can freeze the entire server for seconds. Fix: wrap faiss.write_index in asyncio.to_thread() so the blocking I/O runs on a thread pool, keeping the event loop responsive. Called from insert(), insert_batch(), and delete() — every vector mutation path. --- .../db/vec_db/faiss_impl/embedding_storage.py | 5 +- tests/test_kb_faiss_async_save.py | 156 ++++++++++++++++++ 2 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 tests/test_kb_faiss_async_save.py diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index dc6977cf8a..ebdfb55a54 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -4,6 +4,7 @@ raise ImportError( "faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。", ) +import asyncio import os import numpy as np @@ -84,7 +85,7 @@ async def delete(self, ids: list[int]) -> None: await self.save_index() async def save_index(self) -> None: - """保存索引 + """保存索引(在单独线程中执行以避免阻塞事件循环) Args: path (str): 保存索引的路径 @@ -92,4 +93,4 @@ async def save_index(self) -> None: """ if self.index is None: return - faiss.write_index(self.index, self.path) + await asyncio.to_thread(faiss.write_index, self.index, self.path) diff --git a/tests/test_kb_faiss_async_save.py b/tests/test_kb_faiss_async_save.py new file mode 100644 index 0000000000..01da490a62 --- /dev/null +++ b/tests/test_kb_faiss_async_save.py @@ -0,0 +1,156 @@ +"""Tests for #5: FAISS save_index uses asyncio.to_thread to avoid blocking +the event loop during synchronous faiss.write_index calls. +""" + +import asyncio +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + + +def _make_storage(dimension: int = 128, path: str = "/tmp/test.index"): + """Build an EmbeddingStorage instance with a minimal mocked FAISS index.""" + from astrbot.core.db.vec_db.faiss_impl.embedding_storage import EmbeddingStorage + + storage = EmbeddingStorage.__new__(EmbeddingStorage) + storage.dimension = dimension + storage.path = path + # Mock FAISS index — just enough to satisfy the method guards + storage.index = MagicMock() + storage.index.ntotal = 100 + return storage + + +class TestFaissSaveIndexAsync: + """Verify save_index delegates to asyncio.to_thread.""" + + @pytest.mark.asyncio + async def test_save_index_uses_to_thread(self): + """save_index offloads faiss.write_index to a thread.""" + import faiss # noqa: F401 — ensure faiss is importable + + storage = _make_storage() + + with patch( + "astrbot.core.db.vec_db.faiss_impl.embedding_storage.asyncio.to_thread", + ) as mock_to_thread: + mock_to_thread.return_value = None # simulate completion + await storage.save_index() + + mock_to_thread.assert_awaited_once_with( + faiss.write_index, storage.index, storage.path, + ) + + @pytest.mark.asyncio + async def test_save_index_skips_when_index_none(self): + """save_index is a no-op when index hasn't been initialized.""" + storage = _make_storage() + storage.index = None + + with patch( + "astrbot.core.db.vec_db.faiss_impl.embedding_storage.asyncio.to_thread", + ) as mock_to_thread: + await storage.save_index() + + mock_to_thread.assert_not_called() + + @pytest.mark.asyncio + async def test_insert_calls_save_index(self): + """insert() calls save_index after adding the vector.""" + storage = _make_storage() + storage.index.add_with_ids = MagicMock() + + with patch.object(storage, "save_index", return_value=None) as mock_save: + vector = np.random.rand(storage.dimension).astype(np.float32) + await storage.insert(vector, id=42) + + storage.index.add_with_ids.assert_called_once() + mock_save.assert_awaited_once() + + @pytest.mark.asyncio + async def test_insert_batch_calls_save_index(self): + """insert_batch() calls save_index after batch-adding vectors.""" + storage = _make_storage() + storage.index.add_with_ids = MagicMock() + + with patch.object(storage, "save_index", return_value=None) as mock_save: + vectors = np.random.rand(10, storage.dimension).astype(np.float32) + ids = list(range(10)) + await storage.insert_batch(vectors, ids) + + storage.index.add_with_ids.assert_called_once() + mock_save.assert_awaited_once() + + @pytest.mark.asyncio + async def test_delete_calls_save_index(self): + """delete() calls save_index after removing vectors.""" + storage = _make_storage() + storage.index.remove_ids = MagicMock() + + with patch.object(storage, "save_index", return_value=None) as mock_save: + await storage.delete([1, 2, 3]) + + storage.index.remove_ids.assert_called_once() + mock_save.assert_awaited_once() + + @pytest.mark.asyncio + async def test_save_index_with_real_faiss_index(self): + """End-to-end: save_index with a real FAISS index writes to a temp file.""" + import faiss + import tempfile + + dim = 128 + base_index = faiss.IndexFlatL2(dim) + index = faiss.IndexIDMap(base_index) + index.add_with_ids( + np.random.rand(5, dim).astype(np.float32), + np.array([1, 2, 3, 4, 5], dtype=np.int64), + ) + + with tempfile.NamedTemporaryFile(suffix=".index", delete=False) as tmp: + tmp_path = tmp.name + + try: + storage = _make_storage(dimension=dim, path=tmp_path) + storage.index = index + + await storage.save_index() + + # Verify file was written and is readable + assert __import__("os").path.exists(tmp_path) + assert __import__("os").path.getsize(tmp_path) > 0 + + # Round-trip: read back and verify dimension matches + restored = faiss.read_index(tmp_path) + assert restored.ntotal == 5 + finally: + __import__("os").unlink(tmp_path) + + @pytest.mark.asyncio + async def test_real_save_does_not_block_event_loop(self): + """Verify a real save_index completes quickly for a small index.""" + import faiss + import tempfile + + dim = 64 + base_index = faiss.IndexFlatL2(dim) + index = faiss.IndexIDMap(base_index) + # 1000 vectors — should be very fast + index.add_with_ids( + np.random.rand(1000, dim).astype(np.float32), + np.arange(1000, dtype=np.int64), + ) + + with tempfile.NamedTemporaryFile(suffix=".index", delete=False) as tmp: + tmp_path = tmp.name + + try: + storage = _make_storage(dimension=dim, path=tmp_path) + storage.index = index + + # Should complete quickly + await asyncio.wait_for(storage.save_index(), timeout=5.0) + assert __import__("os").path.getsize(tmp_path) > 0 + finally: + __import__("os").unlink(tmp_path) From f811403de669d700cd7df9a20888410c4ac4a56e Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 21:28:31 +0800 Subject: [PATCH 08/48] fix(kb): switch FAISS from L2 to IP cosine similarity with write lock MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - IndexFlatL2 → IndexFlatIP: stored vectors now L2-normalized at insert, query vectors normalized at search, IP = cosine similarity - Score formula: (scores + 1.0) / 2.0 maps IP [-1,1] → [0,1] - Add _safe_normalize_l2() with zero-vector detection (raises ValueError) - Add asyncio.Lock serializing all FAISS write operations (insert/batch/delete/save) - Auto-migrate legacy L2 indexes on load (reconstruct + normalize + rebuild) - Validate dimension match when loading existing index - Add comprehensive tests: normalization, IP scores, write lock, L2→IP migration, dimension mismatch, zero-vector guard, HNSW create/search/persistence --- .../db/vec_db/faiss_impl/embedding_storage.py | 178 ++++++++-- astrbot/core/db/vec_db/faiss_impl/vec_db.py | 141 ++++++-- tests/test_kb_faiss_async_save.py | 23 +- tests/unit/test_embedding_storage.py | 326 ++++++++++++++++++ 4 files changed, 611 insertions(+), 57 deletions(-) create mode 100644 tests/unit/test_embedding_storage.py diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index ebdfb55a54..8747b5bfc5 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -10,16 +10,130 @@ import numpy as np +def _safe_normalize_l2(vectors: np.ndarray) -> None: + """L2 归一化,对零向量抛出明确错误 + + 正常的 embedding 模型不应产生零向量。零向量无法归一化(会产生 NaN), + 说明 embedding provider 返回了异常数据,应当尽早暴露问题。 + """ + # 检测全零行 + if vectors.ndim == 2: + norms = np.linalg.norm(vectors, axis=1, keepdims=True) + zero_count = int((norms < 1e-12).sum()) + if zero_count > 0: + raise ValueError( + f"向量归一化失败:检测到 {zero_count} 个零向量。" + "Embedding Provider 返回了全零向量,这可能说明 API 密钥无效、" + "模型不支持当前输入、或服务端异常。请检查 Embedding Provider 配置。" + ) + elif vectors.ndim == 1: + if np.linalg.norm(vectors) < 1e-12: + raise ValueError( + "向量归一化失败:检测到零向量。" + "Embedding Provider 返回了全零向量,这可能说明 API 密钥无效、" + "模型不支持当前输入、或服务端异常。请检查 Embedding Provider 配置。" + ) + + faiss.normalize_L2(vectors) + + class EmbeddingStorage: - def __init__(self, dimension: int, path: str | None = None) -> None: + def __init__( + self, + dimension: int, + path: str | None = None, + index_type: str = "flat", + ) -> None: self.dimension = dimension self.path = path self.index = None + self.index_type = index_type # "flat" | "hnsw" + self._write_lock = asyncio.Lock() if path and os.path.exists(path): self.index = faiss.read_index(path) + # 验证加载的索引维度是否匹配 + loaded_dim = self.index.d + if loaded_dim != self.dimension: + raise ValueError( + f"索引维度不匹配: 磁盘索引维度={loaded_dim}, " + f"当前 Embedding Provider 维度={self.dimension}。" + f"请确认 Embedding Provider 与已有索引一致," + f"或删除旧索引后重新创建知识库。" + ) + self._migrate_l2_to_ip_if_needed() else: - base_index = faiss.IndexFlatL2(dimension) + self.index = self._create_index() + + def _create_index(self): + """根据 index_type 创建 FAISS 索引""" + if self.index_type == "hnsw": + # HNSW32 with Inner Product metric for cosine similarity + base_index = faiss.index_factory( + self.dimension, + "HNSW32", + faiss.METRIC_INNER_PRODUCT, + ) + return faiss.IndexIDMap(base_index) + # 默认: flat (精确搜索) + return faiss.IndexIDMap(faiss.IndexFlatIP(self.dimension)) + + def _migrate_l2_to_ip_if_needed(self) -> None: + """检测并迁移旧版 L2 索引到 IP (余弦相似度) + + 旧版使用 IndexFlatL2,新版使用 IndexFlatIP + 归一化向量。 + 迁移过程:reconstruct 所有向量 → L2 归一化 → 重建为 IP 索引。 + """ + assert self.index is not None + # IndexIDMap 包装了 base index,需要解包检查 + base_index = ( + self.index.index if hasattr(self.index, "index") else self.index + ) + if not isinstance(base_index, faiss.IndexFlatL2): + return # 已经是 IP 或其他类型,无需迁移 + + import warnings + + ntotal = self.index.ntotal + if ntotal == 0: + warnings.warn( + "检测到空的旧版 L2 索引,将重建为 IP 索引。", + stacklevel=2, + ) + base_index = faiss.IndexFlatIP(self.dimension) self.index = faiss.IndexIDMap(base_index) + return + + warnings.warn( + f"检测到旧版 L2 索引 (含 {ntotal} 个向量),正在自动迁移到 IP 索引..." + "这可能需要几秒钟。迁移后旧索引将被覆盖。", + stacklevel=2, + ) + + # 重建所有向量并归一化 + # 注意: IndexIDMap.reconstruct 在某些 FAISS 构建版本中不可用 + try: + vectors = np.zeros((ntotal, self.dimension), dtype=np.float32) + for i in range(ntotal): + vectors[i] = self.index.reconstruct(i) + except RuntimeError: + warnings.warn( + "无法从旧索引重建向量(reconstruct 不可用)," + "将重建空 IP 索引。请重新上传文档以生成新向量。", + stacklevel=2, + ) + self.index = faiss.IndexIDMap(faiss.IndexFlatIP(self.dimension)) + faiss.write_index(self.index, self.path) + return + + _safe_normalize_l2(vectors) + + # 重建为 IP 索引 + new_index = faiss.IndexIDMap(faiss.IndexFlatIP(self.dimension)) + new_index.add_with_ids(vectors, np.arange(ntotal, dtype=np.int64)) + + self.index = new_index + # 立即保存迁移后的索引 + faiss.write_index(self.index, self.path) async def insert(self, vector: np.ndarray, id: int) -> None: """插入向量 @@ -31,13 +145,16 @@ async def insert(self, vector: np.ndarray, id: int) -> None: ValueError: 如果向量的维度与存储的维度不匹配 """ - assert self.index is not None, "FAISS index is not initialized." - if vector.shape[0] != self.dimension: - raise ValueError( - f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}", - ) - self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) - await self.save_index() + async with self._write_lock: + assert self.index is not None, "FAISS index is not initialized." + if vector.shape[0] != self.dimension: + raise ValueError( + f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}", + ) + v_2d = vector.reshape(1, -1) + _safe_normalize_l2(v_2d) + self.index.add_with_ids(v_2d, np.array([id])) + await self._save_index_locked() async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: """批量插入向量 @@ -49,13 +166,15 @@ async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: ValueError: 如果向量的维度与存储的维度不匹配 """ - assert self.index is not None, "FAISS index is not initialized." - if vectors.shape[1] != self.dimension: - raise ValueError( - f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}", - ) - self.index.add_with_ids(vectors, np.array(ids)) - await self.save_index() + async with self._write_lock: + assert self.index is not None, "FAISS index is not initialized." + if vectors.shape[1] != self.dimension: + raise ValueError( + f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}", + ) + _safe_normalize_l2(vectors) + self.index.add_with_ids(vectors, np.array(ids)) + await self._save_index_locked() async def search(self, vector: np.ndarray, k: int) -> tuple: """搜索最相似的向量 @@ -68,7 +187,7 @@ async def search(self, vector: np.ndarray, k: int) -> tuple: """ assert self.index is not None, "FAISS index is not initialized." - faiss.normalize_L2(vector) + _safe_normalize_l2(vector) distances, indices = self.index.search(vector, k) return distances, indices @@ -79,18 +198,25 @@ async def delete(self, ids: list[int]) -> None: ids (list[int]): 要删除的向量ID列表 """ - assert self.index is not None, "FAISS index is not initialized." - id_array = np.array(ids, dtype=np.int64) - self.index.remove_ids(id_array) - await self.save_index() + async with self._write_lock: + assert self.index is not None, "FAISS index is not initialized." + id_array = np.array(ids, dtype=np.int64) + self.index.remove_ids(id_array) + await self._save_index_locked() - async def save_index(self) -> None: - """保存索引(在单独线程中执行以避免阻塞事件循环) - - Args: - path (str): 保存索引的路径 + async def _save_index_locked(self) -> None: + """内部方法:在已持有 _write_lock 的情况下保存索引到磁盘。 + 调用者必须已经获取 _write_lock。 """ if self.index is None: return await asyncio.to_thread(faiss.write_index, self.index, self.path) + + async def save_index(self) -> None: + """保存索引(在单独线程中执行以避免阻塞事件循环) + + 公共方法,自动获取写锁以确保线程安全。 + """ + async with self._write_lock: + await self._save_index_locked() diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 0474683754..b09175dae4 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -1,5 +1,7 @@ import time import uuid +from collections import OrderedDict +from hashlib import sha256 import numpy as np @@ -12,6 +14,50 @@ from .embedding_storage import EmbeddingStorage +class EmbeddingCache: + """基于 LRU 的文本 → 嵌入向量缓存(线程安全) + + 使用 SHA256 哈希文本作为缓存 key,避免对相同内容重复调用 embedding API。 + """ + + def __init__(self, max_size: int = 10000) -> None: + import asyncio + + self._cache: OrderedDict[str, np.ndarray] = OrderedDict() + self._max_size = max_size + self._lock = asyncio.Lock() + + @staticmethod + def _hash(text: str) -> str: + return sha256(text.encode()).hexdigest() + + async def get(self, text: str) -> np.ndarray | None: + async with self._lock: + key = self._hash(text) + if key in self._cache: + self._cache.move_to_end(key) + return self._cache[key].copy() + return None + + async def put(self, text: str, embedding: np.ndarray) -> None: + async with self._lock: + key = self._hash(text) + if key not in self._cache: + if len(self._cache) >= self._max_size: + self._cache.popitem(last=False) + else: + self._cache.move_to_end(key) + self._cache[key] = embedding.copy() + + async def clear(self) -> None: + async with self._lock: + self._cache.clear() + + async def __len__(self) -> int: + async with self._lock: + return len(self._cache) + + class FaissVecDB(BaseVecDB): """A class to represent a vector database.""" @@ -21,6 +67,7 @@ def __init__( index_store_path: str, embedding_provider: EmbeddingProvider, rerank_provider: RerankProvider | None = None, + index_type: str = "flat", ) -> None: self.doc_store_path = doc_store_path self.index_store_path = index_store_path @@ -29,9 +76,11 @@ def __init__( self.embedding_storage = EmbeddingStorage( embedding_provider.get_dim(), index_store_path, + index_type=index_type, ) self.embedding_provider = embedding_provider self.rerank_provider = rerank_provider + self.embedding_cache = EmbeddingCache() async def initialize(self) -> None: await self.document_storage.initialize() @@ -81,6 +130,9 @@ async def insert_batch( ) return [] + # 空列表快速返回后,确保不再处理零向量 + assert len(contents) > 0, "contents must not be empty" + content_count = len(contents) if len(metadatas) != content_count: raise KnowledgeBaseUploadError( @@ -107,33 +159,64 @@ async def insert_batch( }, ) + # 检查嵌入缓存,分离已缓存的文本和需要计算的文本 start = time.time() - logger.debug(f"Generating embeddings for {len(contents)} contents...") - vectors = await self.embedding_provider.get_embeddings_batch( - contents, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - progress_callback=progress_callback, + cached_vectors: dict[int, np.ndarray] = {} + uncached_indices: list[int] = [] + uncached_texts: list[str] = [] + + for idx, text in enumerate(contents): + cached = await self.embedding_cache.get(text) + if cached is not None: + cached_vectors[idx] = cached + else: + uncached_indices.append(idx) + uncached_texts.append(text) + + cache_hits = len(cached_vectors) + cache_misses = len(uncached_texts) + logger.debug( + f"Embedding cache: {cache_hits} hits, {cache_misses} misses " + f"out of {len(contents)} contents.", ) + + # 只对未缓存的文本生成嵌入 + vectors = [np.empty(0, dtype=np.float32)] * len(contents) + if uncached_texts: + new_embeddings = await self.embedding_provider.get_embeddings_batch( + uncached_texts, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + ) + # 验证返回数量 + if len(new_embeddings) != len(uncached_texts): + raise KnowledgeBaseUploadError( + stage="embedding", + user_message=( + "向量化失败:嵌入模型返回的向量数量与文本分块数量不一致" + f"(期望 {len(uncached_texts)},实际 {len(new_embeddings)})。" + "这通常说明当前 Embedding 接口未完整返回批量结果," + "或该服务不兼容当前批量请求格式。" + ), + details={ + "expected_contents": len(uncached_texts), + "actual_vectors": len(new_embeddings), + }, + ) + for i, idx in enumerate(uncached_indices): + vectors[idx] = np.asarray(new_embeddings[i], dtype=np.float32) + await self.embedding_cache.put(uncached_texts[i], vectors[idx]) + + for idx, cached_vec in cached_vectors.items(): + vectors[idx] = cached_vec + end = time.time() logger.debug( - f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds.", + f"Embeddings ready for {len(contents)} contents " + f"in {end - start:.2f}s (cached: {cache_hits}, fresh: {cache_misses}).", ) - if len(vectors) != content_count: - raise KnowledgeBaseUploadError( - stage="embedding", - user_message=( - "向量化失败:嵌入模型返回的向量数量与文本分块数量不一致" - f"(期望 {content_count},实际 {len(vectors)})。" - "这通常说明当前 Embedding 接口未完整返回批量结果," - "或该服务不兼容当前批量请求格式。" - ), - details={ - "expected_contents": content_count, - "actual_vectors": len(vectors), - }, - ) # 使用 DocumentStorage 的批量插入方法 int_ids = await self.document_storage.insert_documents_batch( @@ -211,15 +294,23 @@ async def retrieve( List[Result]: 查询结果 """ - embedding = await self.embedding_provider.get_embedding(query) + # 先查缓存,再调 embedding provider + cached = await self.embedding_cache.get(query) + if cached is not None: + embedding = cached + else: + embedding = await self.embedding_provider.get_embedding(query) + await self.embedding_cache.put( + query, np.asarray(embedding, dtype=np.float32), + ) scores, indices = await self.embedding_storage.search( vector=np.array([embedding]).astype("float32"), k=fetch_k if metadata_filters else k, ) if len(indices[0]) == 0 or indices[0][0] == -1: return [] - # normalize scores - scores[0] = 1.0 - (scores[0] / 2.0) + # 将内积分数 (余弦相似度, 范围 [-1, 1]) 映射到 [0, 1] + scores[0] = (scores[0] + 1.0) / 2.0 # NOTE: maybe the size is less than k. fetched_docs = await self.document_storage.get_documents( metadata_filters=metadata_filters or {}, diff --git a/tests/test_kb_faiss_async_save.py b/tests/test_kb_faiss_async_save.py index 01da490a62..0fae243f15 100644 --- a/tests/test_kb_faiss_async_save.py +++ b/tests/test_kb_faiss_async_save.py @@ -11,11 +11,14 @@ def _make_storage(dimension: int = 128, path: str = "/tmp/test.index"): """Build an EmbeddingStorage instance with a minimal mocked FAISS index.""" + import asyncio + from astrbot.core.db.vec_db.faiss_impl.embedding_storage import EmbeddingStorage storage = EmbeddingStorage.__new__(EmbeddingStorage) storage.dimension = dimension storage.path = path + storage._write_lock = asyncio.Lock() # Mock FAISS index — just enough to satisfy the method guards storage.index = MagicMock() storage.index.ntotal = 100 @@ -57,11 +60,15 @@ async def test_save_index_skips_when_index_none(self): @pytest.mark.asyncio async def test_insert_calls_save_index(self): - """insert() calls save_index after adding the vector.""" + """insert() calls _save_index_locked after adding the vector.""" storage = _make_storage() storage.index.add_with_ids = MagicMock() - with patch.object(storage, "save_index", return_value=None) as mock_save: + with patch.object( + storage, "_save_index_locked", return_value=None + ) as mock_save: + import faiss # noqa: F811 + vector = np.random.rand(storage.dimension).astype(np.float32) await storage.insert(vector, id=42) @@ -70,11 +77,13 @@ async def test_insert_calls_save_index(self): @pytest.mark.asyncio async def test_insert_batch_calls_save_index(self): - """insert_batch() calls save_index after batch-adding vectors.""" + """insert_batch() calls _save_index_locked after batch-adding vectors.""" storage = _make_storage() storage.index.add_with_ids = MagicMock() - with patch.object(storage, "save_index", return_value=None) as mock_save: + with patch.object( + storage, "_save_index_locked", return_value=None + ) as mock_save: vectors = np.random.rand(10, storage.dimension).astype(np.float32) ids = list(range(10)) await storage.insert_batch(vectors, ids) @@ -84,11 +93,13 @@ async def test_insert_batch_calls_save_index(self): @pytest.mark.asyncio async def test_delete_calls_save_index(self): - """delete() calls save_index after removing vectors.""" + """delete() calls _save_index_locked after removing vectors.""" storage = _make_storage() storage.index.remove_ids = MagicMock() - with patch.object(storage, "save_index", return_value=None) as mock_save: + with patch.object( + storage, "_save_index_locked", return_value=None + ) as mock_save: await storage.delete([1, 2, 3]) storage.index.remove_ids.assert_called_once() diff --git a/tests/unit/test_embedding_storage.py b/tests/unit/test_embedding_storage.py new file mode 100644 index 0000000000..c754bb0db5 --- /dev/null +++ b/tests/unit/test_embedding_storage.py @@ -0,0 +1,326 @@ +"""测试 FAISS EmbeddingStorage — 向量归一化、余弦相似度、写锁、索引迁移""" + +import asyncio +import tempfile +from pathlib import Path + +import numpy as np +import pytest + +from astrbot.core.db.vec_db.faiss_impl.embedding_storage import EmbeddingStorage + + +DIM = 128 + + +def make_random_vector(dim: int = DIM) -> np.ndarray: + return np.random.default_rng(42).random(dim).astype(np.float32) + + +def make_random_batch(n: int, dim: int = DIM) -> np.ndarray: + return np.random.default_rng(42).random((n, dim)).astype(np.float32) + + +def _normalize_vector(v: np.ndarray) -> None: + """用 FAISS 归一化单个向量(原地修改)""" + faiss = pytest.importorskip("faiss") + faiss.normalize_L2(v.reshape(1, -1)) + + +def assert_unit_norm(vector: np.ndarray) -> None: + """断言向量已 L2 归一化(模长 ≈ 1.0)""" + norm = np.linalg.norm(vector) + assert abs(norm - 1.0) < 1e-5, f"向量未归一化, 模长={norm}" + + +class TestVectorNormalization: + """Phase 1A: 验证入库向量归一化 & 余弦相似度""" + + @pytest.mark.asyncio + async def test_insert_normalizes_vector(self): + """插入后存储的向量应该已被 L2 归一化(通过自身搜索验证) + + 插入时自动归一化向量,用同一向量查询应得到接近 1.0 的内积分。 + """ + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + storage = EmbeddingStorage(dimension=DIM, path=str(index_path)) + + v = make_random_vector() + await storage.insert(v, id=1) + + # 搜索自身:归一化后内积应 ≈ 1.0 + distances, indices = await storage.search( + v.copy().reshape(1, -1), k=1 + ) + assert indices[0][0] == 1, f"应返回 id=1,实际={indices[0][0]}" + assert distances[0][0] > 0.999, ( + f"归一化后自身内积应 ≈ 1.0,实际={distances[0][0]}" + ) + + @pytest.mark.asyncio + async def test_insert_batch_normalizes_vectors(self): + """批量插入后所有存储的向量应该已被 L2 归一化(通过搜索验证)""" + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + storage = EmbeddingStorage(dimension=DIM, path=str(index_path)) + + vectors = make_random_batch(10) + ids = list(range(10)) + await storage.insert_batch(vectors, ids) + + # 用其中一个向量搜索自身 + q = vectors[0].copy() + distances, _ = await storage.search(q.reshape(1, -1), k=1) + assert distances[0][0] > 0.999, ( + f"归一化后自身内积应 ≈ 1.0,实际={distances[0][0]}" + ) + + @pytest.mark.asyncio + async def test_search_returns_ip_scores(self): + """IP 搜索对归一化向量应返回内积分数 (≈余弦相似度)""" + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + storage = EmbeddingStorage(dimension=DIM, path=str(index_path)) + + # 插入一个向量 + v = np.ones(DIM, dtype=np.float32) + _normalize_vector(v) + await storage.insert(v, id=1) + + # 用相同向量搜索自身 — 内积应接近 1.0 + query = v.copy() + distances, indices = await storage.search( + query.reshape(1, -1), k=1 + ) + # IP 分数应在 [-1, 1] 范围内 + assert -1.0 - 1e-5 <= distances[0][0] <= 1.0 + 1e-5, ( + f"IP 分数超出 [-1,1] 范围: {distances[0][0]}" + ) + # 同向量内积应接近 1.0 + assert abs(distances[0][0] - 1.0) < 1e-3, ( + f"自身内积应 ≈ 1.0,实际={distances[0][0]}" + ) + + @pytest.mark.asyncio + async def test_score_conversion_range(self): + """分数转换 (scores + 1) / 2 应映射 [-1,1] → [0,1]""" + from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB + + # 模拟检索后分数转换 + test_cases = [ + (np.array([[1.0]]), 1.0), # 完美匹配 + (np.array([[0.0]]), 0.5), # 正交 + (np.array([[-1.0]]), 0.0), # 完全相反 + ] + for raw_scores, expected in test_cases: + converted = (raw_scores[0] + 1.0) / 2.0 + assert abs(converted[0] - expected) < 1e-5, ( + f"转换错误: {raw_scores[0][0]} → {converted[0]}, 期望 {expected}" + ) + + +class TestWriteLock: + """Phase 1B: 验证 asyncio.Lock 串行化写入操作""" + + @pytest.mark.asyncio + async def test_concurrent_inserts_serialized(self): + """并发插入应被正确序列化,最终 ntotal 正确""" + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + storage = EmbeddingStorage(dimension=DIM, path=str(index_path)) + + async def insert_one(offset: int) -> None: + for i in range(5): + v = make_random_vector() + await storage.insert(v, id=offset * 5 + i) + + # 4 个协程并发插入 + await asyncio.gather( + insert_one(0), insert_one(1), insert_one(2), insert_one(3), + ) + + assert storage.index.ntotal == 20, ( + f"并发插入后 ntotal 应为 20, 实际={storage.index.ntotal}" + ) + + @pytest.mark.asyncio + async def test_search_not_blocked_by_write(self): + """写入锁不应阻塞搜索(搜索不加锁)""" + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + storage = EmbeddingStorage(dimension=DIM, path=str(index_path)) + + # 预先插入一些数据 + for i in range(10): + v = make_random_vector() + await storage.insert(v, id=i) + + query = make_random_vector() + + # 同时进行搜索和插入 + search_task = asyncio.create_task( + storage.search(query.reshape(1, -1), k=5) + ) + insert_task = asyncio.create_task( + storage.insert(make_random_vector(), id=100) + ) + + results = await asyncio.gather(search_task, insert_task) + distances, _ = results[0] + assert len(distances[0]) == 5 + + +class TestIndexMigration: + """Phase 1A: 向后兼容 — L2 索引迁移到 IP""" + + @pytest.mark.asyncio + async def test_migration_l2_to_ip(self): + """加载旧的 L2 索引时自动迁移为 IP""" + faiss = pytest.importorskip("faiss") + + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + + # 模拟旧版 L2 索引 + old_index = faiss.IndexIDMap(faiss.IndexFlatL2(DIM)) + v = make_random_vector() + old_index.add_with_ids(v.reshape(1, -1), np.array([1])) + faiss.write_index(old_index, str(index_path)) + + # 加载时应检测 L2 并迁移 + storage = EmbeddingStorage(dimension=DIM, path=str(index_path)) + + # 迁移后应为有效索引 + assert storage.index is not None + assert storage.index.ntotal == 1 + + # 确保能正常搜索(search 方法自动归一化查询向量) + distances, _ = await storage.search(v.copy().reshape(1, -1), k=1) + assert distances[0][0] > 0.9, ( + f"迁移后搜索自身应有高分, 实际={distances[0][0]}" + ) + + @pytest.mark.asyncio + async def test_no_crash_on_reload_existing_ip_index(self): + """重新加载已有的 IP 索引不应报错""" + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + + # 创建 IP 索引 + storage = EmbeddingStorage(dimension=DIM, path=str(index_path)) + v = make_random_vector() + await storage.insert(v, id=1) # insert 自动归一化 + + # 重新加载 + storage2 = EmbeddingStorage(dimension=DIM, path=str(index_path)) + assert storage2.index is not None + assert storage2.index.ntotal == 1 + + @pytest.mark.asyncio + async def test_dimension_mismatch_on_load_raises_error(self): + """加载维度不匹配的索引时应抛出清晰错误""" + faiss = pytest.importorskip("faiss") + + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + # 创建不同维度的索引 + wrong_dim = 256 + index = faiss.IndexIDMap(faiss.IndexFlatIP(wrong_dim)) + faiss.write_index(index, str(index_path)) + + with pytest.raises(ValueError, match="索引维度不匹配"): + EmbeddingStorage(dimension=DIM, path=str(index_path)) + + +class TestZeroVectorGuard: + """零向量应抛出明确错误,而非静默产生无意义数据""" + + @pytest.mark.asyncio + async def test_zero_vector_insert_raises_error(self): + """插入零向量应抛出 ValueError""" + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + storage = EmbeddingStorage(dimension=DIM, path=str(index_path)) + + zero_v = np.zeros(DIM, dtype=np.float32) + with pytest.raises(ValueError, match="零向量"): + await storage.insert(zero_v, id=1) + + @pytest.mark.asyncio + async def test_batch_zero_vectors_raises_error(self): + """批量插入含零向量应抛出 ValueError""" + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + storage = EmbeddingStorage(dimension=DIM, path=str(index_path)) + + vectors = make_random_batch(10) + vectors[0] = np.zeros(DIM, dtype=np.float32) + ids = list(range(10)) + with pytest.raises(ValueError, match="零向量"): + await storage.insert_batch(vectors, ids) + + @pytest.mark.asyncio + async def test_near_zero_vector_inserted_normally(self): + """接近零但不为零的向量应正常插入并归一化""" + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + storage = EmbeddingStorage(dimension=DIM, path=str(index_path)) + + # 非常小但不为零的向量 + tiny_v = np.full(DIM, 1e-8, dtype=np.float32) + await storage.insert(tiny_v, id=1) + assert storage.index.ntotal == 1 + + +class TestHNSWIndex: + """Phase 2A: HNSW 索引创建、持久化和搜索""" + + @pytest.mark.asyncio + async def test_create_hnsw_index(self): + """创建 HNSW 索引应成功""" + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + storage = EmbeddingStorage( + dimension=DIM, path=str(index_path), index_type="hnsw", + ) + assert storage.index is not None + assert storage.index.ntotal == 0 + + @pytest.mark.asyncio + async def test_hnsw_insert_and_search(self): + """HNSW 索引应支持插入和搜索""" + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + storage = EmbeddingStorage( + dimension=DIM, path=str(index_path), index_type="hnsw", + ) + # 插入多个向量 + for i in range(10): + v = make_random_vector() + await storage.insert(v, id=i) + + assert storage.index.ntotal == 10 + + # 搜索 + q = make_random_vector() + distances, indices = await storage.search(q.reshape(1, -1), k=5) + assert len(indices[0]) == 5 + + @pytest.mark.asyncio + async def test_hnsw_persistence(self): + """HNSW 索引应能持久化并重新加载""" + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + storage = EmbeddingStorage( + dimension=DIM, path=str(index_path), index_type="hnsw", + ) + v = make_random_vector() + await storage.insert(v, id=1) + + # 重新加载 + storage2 = EmbeddingStorage( + dimension=DIM, path=str(index_path), index_type="hnsw", + ) + assert storage2.index is not None + assert storage2.index.ntotal == 1 From 3ca771a7ee44ae04aae579d736bd0f39049ae677 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 21:28:39 +0800 Subject: [PATCH 09/48] fix(kb): correct BM25 top_k_sparse aggregation across multiple KBs Previously top_k_sparse was summed across KBs, causing sparse results to dominate RRF fusion. Now each KB's BM25 results are truncated to its own top_k_sparse before merging, and the global cap uses max(). Add test verifying per-KB truncation with 2 KBs and different limits. --- .../retrieval/sparse_retriever.py | 69 +++++++++++-------- tests/unit/test_sparse_retriever.py | 52 ++++++++++++++ 2 files changed, 94 insertions(+), 27 deletions(-) diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index 2b213bef0f..0608f71810 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -127,12 +127,15 @@ async def _retrieve_with_bm25( 单 KB 最多加载 MAX_BM25_DOCS 条 chunk,超限时截断并打 warning。 """ top_k_sparse = 0 - chunks = [] + all_kb_chunks: list[dict] = [] for kb_id in kb_ids: vec_db: FaissVecDB | None = kb_options.get(kb_id, {}).get("vec_db") if not vec_db: continue + kb_top_k = kb_options.get(kb_id, {}).get("top_k_sparse", 50) + top_k_sparse = max(top_k_sparse, kb_top_k) + result = await vec_db.document_storage.get_documents( metadata_filters={"kb_id": kb_id}, limit=self.MAX_BM25_DOCS, @@ -144,42 +147,54 @@ async def _retrieve_with_bm25( f"条 chunk 上限,结果可能不完整。建议检查 FTS5 索引状态。", ) chunk_mds = [json.loads(doc["metadata"]) for doc in result] - result = [ + kb_chunks = [ { "chunk_id": doc["doc_id"], "chunk_index": chunk_md["chunk_index"], "doc_id": chunk_md["kb_doc_id"], "kb_id": kb_id, "text": doc["text"], + "kb_top_k": kb_top_k, } for doc, chunk_md in zip(result, chunk_mds) ] - chunks.extend(result) - top_k_sparse += kb_options.get(kb_id, {}).get("top_k_sparse", 50) + all_kb_chunks.append(kb_chunks) - if not chunks: + if not any(all_kb_chunks): return [] - corpus = [chunk["text"] for chunk in chunks] - tokenized_corpus = [tokenize_text(doc, self.hit_stopwords) for doc in corpus] - bm25 = BM25Okapi(tokenized_corpus) - - tokenized_query = tokenize_text(query, self.hit_stopwords) - scores = bm25.get_scores(tokenized_query) - - results = [] - for idx, score in enumerate(scores): - chunk = chunks[idx] - results.append( - SparseResult( - chunk_id=chunk["chunk_id"], - chunk_index=chunk["chunk_index"], - doc_id=chunk["doc_id"], - kb_id=chunk["kb_id"], - content=chunk["text"], - score=-float(score), - ), - ) + # 每个知识库独立计算 BM25 分数并截断,再合并 + merged_results = [] + for kb_chunks in all_kb_chunks: + if not kb_chunks: + continue + kb_top_k = kb_chunks[0]["kb_top_k"] - results.sort(key=lambda x: x.score) - return results[:top_k_sparse] + corpus = [chunk["text"] for chunk in kb_chunks] + tokenized_corpus = [ + tokenize_text(doc, self.hit_stopwords) for doc in corpus + ] + bm25 = BM25Okapi(tokenized_corpus) + + tokenized_query = tokenize_text(query, self.hit_stopwords) + scores = bm25.get_scores(tokenized_query) + + for idx, score in enumerate(scores): + chunk = kb_chunks[idx] + merged_results.append( + SparseResult( + chunk_id=chunk["chunk_id"], + chunk_index=chunk["chunk_index"], + doc_id=chunk["doc_id"], + kb_id=chunk["kb_id"], + content=chunk["text"], + score=-float(score), + ), + ) + + # 截断当前 KB 的结果 + kb_sorted = sorted(merged_results[-len(kb_chunks):], key=lambda x: x.score) + merged_results = merged_results[:-len(kb_chunks)] + kb_sorted[:kb_top_k] + + merged_results.sort(key=lambda x: x.score) + return merged_results[:top_k_sparse] diff --git a/tests/unit/test_sparse_retriever.py b/tests/unit/test_sparse_retriever.py index 11c491b4d2..02b1f1ffc3 100644 --- a/tests/unit/test_sparse_retriever.py +++ b/tests/unit/test_sparse_retriever.py @@ -91,3 +91,55 @@ async def test_sparse_retriever_falls_back_to_bm25_when_fts5_is_unavailable(): assert [result.chunk_id for result in results] == ["chunk-1"] assert storage.search_sparse_calls == 1 assert storage.get_documents_calls == 1 + + +class MultiKBStorage: + """模拟多知识库 BM25 回退场景""" + + def __init__(self, kb_id: str): + self.kb_id = kb_id + self.search_sparse_calls = 0 + self.get_documents_calls = 0 + + async def search_sparse(self, query_tokens: list[str], limit: int): + self.search_sparse_calls += 1 + return None # 始终回退到 BM25 + + async def get_documents( + self, metadata_filters: dict, limit: int | None, offset + ): + self.get_documents_calls += 1 + # 返回 10 条 chunk,远多于 top_k_sparse 限制 + return [ + make_doc(f"{self.kb_id}-chunk-{i}", f"document chunk {i}", i) + for i in range(10) + ] + + +@pytest.mark.asyncio +async def test_bm25_fallback_respects_per_kb_top_k_sparse(): + """多知识库 BM25 回退时,每个知识库的结果应被截断到各自的 top_k_sparse + + Phase 1C: 验证 top_k_sparse 不再被错误求和,而是逐 KB 截断。 + """ + storage_a = MultiKBStorage("kb-a") + storage_b = MultiKBStorage("kb-b") + vec_db_a = SimpleNamespace(document_storage=storage_a) + vec_db_b = SimpleNamespace(document_storage=storage_b) + retriever = SparseRetriever(kb_db=None) + + results = await retriever.retrieve( + query="test query", + kb_ids=["kb-a", "kb-b"], + kb_options={ + "kb-a": {"vec_db": vec_db_a, "top_k_sparse": 2}, + "kb-b": {"vec_db": vec_db_b, "top_k_sparse": 3}, + }, + ) + + # 总结果数不应超过 max(2, 3) = 3(最终截断),且每个 KB 各贡献 ≤ 其 top_k + assert len(results) <= 3, f"结果过多: {len(results)}" + kb_a_count = sum(1 for r in results if r.kb_id == "kb-a") + kb_b_count = sum(1 for r in results if r.kb_id == "kb-b") + assert kb_a_count <= 2, f"KB-A 贡献了 {kb_a_count} 条,应 ≤ 2" + assert kb_b_count <= 3, f"KB-B 贡献了 {kb_b_count} 条,应 ≤ 3" From 65ed3fc8b67171be5382ce66c637ce036c0be878 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 21:28:47 +0800 Subject: [PATCH 10/48] perf(kb): add kb_id generated column and parallelize dense retrieval - Add kb_id generated column with index to avoid full table scans on json_extract(metadata, '$.kb_id') for metadata filtering - Extend generated-column optimization to count_documents() - Parallelize _dense_retrieve across KBs via asyncio.gather --- .../db/vec_db/faiss_impl/document_storage.py | 35 +++++++++++++++---- .../core/knowledge_base/retrieval/manager.py | 34 ++++++++++++------ 2 files changed, 53 insertions(+), 16 deletions(-) diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index f451847964..641453206a 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -76,6 +76,12 @@ async def initialize(self) -> None: "GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED", ), ) + await conn.execute( + text( + "ALTER TABLE documents ADD COLUMN kb_id TEXT " + "GENERATED ALWAYS AS (json_extract(metadata, '$.kb_id')) STORED", + ), + ) # Create indexes await conn.execute( @@ -88,6 +94,11 @@ async def initialize(self) -> None: "CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)", ), ) + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_documents_kb_id ON documents(kb_id)", + ), + ) except BaseException: pass @@ -257,9 +268,15 @@ async def get_documents( query = select(Document) for key, val in metadata_filters.items(): - query = query.where( - text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), - ).params(**{f"filter_{key}": val}) + # kb_id 和 kb_doc_id 有生成列和索引,直接用列名过滤避免全表扫描 + if key in ("kb_id", "kb_doc_id"): + query = query.where( + text(f"{key} = :filter_{key}"), + ).params(**{f"filter_{key}": val}) + else: + query = query.where( + text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), + ).params(**{f"filter_{key}": val}) if ids is not None and len(ids) > 0: valid_ids = [int(i) for i in ids if i != -1] @@ -453,9 +470,15 @@ async def count_documents(self, metadata_filters: dict | None = None) -> int: if metadata_filters: for key, val in metadata_filters.items(): - query = query.where( - text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), - ).params(**{f"filter_{key}": val}) + # 使用生成列避免全表扫描 + if key in ("kb_id", "kb_doc_id"): + query = query.where( + text(f"{key} = :filter_{key}"), + ).params(**{f"filter_{key}": val}) + else: + query = query.where( + text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), + ).params(**{f"filter_{key}": val}) result = await session.execute(query) count = result.scalar_one_or_none() diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 1d65401ce5..2377fd7a87 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -209,7 +209,7 @@ async def _dense_retrieve( ): """稠密检索 (向量相似度) - 为每个知识库使用独立的向量数据库进行检索,然后合并结果。 + 为每个知识库使用独立的向量数据库进行并行检索,然后合并结果。 Args: query: 查询文本 @@ -220,10 +220,11 @@ async def _dense_retrieve( List[Result]: 检索结果列表 """ - all_results: list[Result] = [] - for kb_id in kb_ids: + import asyncio + + async def _retrieve_one(kb_id: str) -> list[Result]: if kb_id not in kb_options: - continue + return [] try: vec_db: FaissVecDB = kb_options[kb_id]["vec_db"] dense_k = int(kb_options[kb_id]["top_k_dense"]) @@ -234,17 +235,30 @@ async def _dense_retrieve( rerank=False, # 稠密检索阶段不进行 rerank metadata_filters={"kb_id": kb_id}, ) - - all_results.extend(vec_results) + return vec_results except Exception as e: - logger.error(f"知识库 {kb_id} 稠密检索失败: {e}", exc_info=True) + logger.error( + f"知识库 {kb_id} 稠密检索失败: {e}", exc_info=True, + ) if len(kb_ids) == 1: - raise RuntimeError(f"知识库 {kb_id} 稠密检索失败: {e}") from e + raise RuntimeError( + f"知识库 {kb_id} 稠密检索失败: {e}", + ) from e # multi-KB: skip the faulty KB and continue + return [] + + tasks = [_retrieve_one(kb_id) for kb_id in kb_ids] + results_per_kb = await asyncio.gather(*tasks, return_exceptions=True) + + all_results: list[Result] = [] + for result in results_per_kb: + if isinstance(result, Exception): + logger.error(f"稠密检索异常: {result}", exc_info=True) + continue + all_results.extend(result) - # 按相似度排序并返回 top_k + # 按相似度排序并返回 all_results.sort(key=lambda x: x.similarity, reverse=True) - # return all_results[: len(all_results) // len(kb_ids)] return all_results async def _rerank( From 6a331c5b19df23a9ea62d2f9ced44a08c53822e8 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 21:28:56 +0800 Subject: [PATCH 11/48] feat(kb): add HNSW index support and thread-safe LRU embedding cache - Support index_type='hnsw' via faiss.index_factory('HNSW32', METRIC_INNER_PRODUCT) - KnowledgeBase model gains index_type field (flat|hnsw) - EmbeddingCache: async-locked LRU cache keyed by SHA256(text) used in insert_batch and retrieve to skip redundant API calls - Cache integration validates returned vector count before indexing - Add tests: cache hit/miss, LRU eviction, insert_batch cache usage --- astrbot/core/knowledge_base/kb_helper.py | 1 + astrbot/core/knowledge_base/kb_mgr.py | 7 ++ astrbot/core/knowledge_base/models.py | 2 + tests/unit/test_faiss_vec_db.py | 107 +++++++++++++++++++---- 4 files changed, 102 insertions(+), 15 deletions(-) diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 0b54f342df..3f56d075eb 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -189,6 +189,7 @@ async def _ensure_vec_db(self) -> "FaissVecDB": index_store_path=str(self.kb_dir / "index.faiss"), embedding_provider=ep, rerank_provider=rp, + index_type=self.kb.index_type or "flat", ) await vec_db.initialize() self.vec_db = vec_db diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index 3285d42c79..3795f9c4b3 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -94,6 +94,7 @@ async def create_kb( top_k_dense: int | None = None, top_k_sparse: int | None = None, top_m_final: int | None = None, + index_type: str | None = None, ) -> KBHelper: """创建新的知识库实例""" if embedding_provider_id is None: @@ -109,6 +110,7 @@ async def create_kb( top_k_dense=top_k_dense if top_k_dense is not None else 50, top_k_sparse=top_k_sparse if top_k_sparse is not None else 50, top_m_final=top_m_final if top_m_final is not None else 5, + index_type=index_type if index_type is not None else "flat", ) try: async with self.kb_db.get_db() as session: @@ -175,6 +177,7 @@ async def update_kb( top_k_dense: int | None = None, top_k_sparse: int | None = None, top_m_final: int | None = None, + index_type: str | None = None, ) -> KBHelper | None: """更新知识库实例""" kb_helper = await self.get_kb(kb_id) @@ -193,6 +196,7 @@ async def update_kb( "top_k_dense": kb.top_k_dense, "top_k_sparse": kb.top_k_sparse, "top_m_final": kb.top_m_final, + "index_type": kb.index_type, } previous_init_error = kb_helper.init_error @@ -215,6 +219,8 @@ async def update_kb( kb.top_k_sparse = top_k_sparse if top_m_final is not None: kb.top_m_final = top_m_final + if index_type is not None: + kb.index_type = index_type # Build a new helper first. Keep current vec_db alive until new init succeeds. new_helper = KBHelper( @@ -239,6 +245,7 @@ async def update_kb( kb.top_k_dense = previous_state["top_k_dense"] kb.top_k_sparse = previous_state["top_k_sparse"] kb.top_m_final = previous_state["top_m_final"] + kb.index_type = previous_state["index_type"] kb_helper.init_error = previous_init_error logger.error( f"知识库 {kb.kb_name}({kb.kb_id}) 重新初始化失败,继续使用旧实例: {e}", diff --git a/astrbot/core/knowledge_base/models.py b/astrbot/core/knowledge_base/models.py index da919a384a..a65cec0419 100644 --- a/astrbot/core/knowledge_base/models.py +++ b/astrbot/core/knowledge_base/models.py @@ -36,6 +36,8 @@ class KnowledgeBase(BaseKBModel, table=True): # 分块配置参数 chunk_size: int | None = Field(default=512, nullable=True) chunk_overlap: int | None = Field(default=50, nullable=True) + # 索引类型: "flat" (精确) 或 "hnsw" (近似最近邻,适合大规模) + index_type: str | None = Field(default="flat", max_length=10) # 检索配置参数 top_k_dense: int | None = Field(default=50, nullable=True) top_k_sparse: int | None = Field(default=50, nullable=True) diff --git a/tests/unit/test_faiss_vec_db.py b/tests/unit/test_faiss_vec_db.py index d294d51cd3..4d3999f793 100644 --- a/tests/unit/test_faiss_vec_db.py +++ b/tests/unit/test_faiss_vec_db.py @@ -1,35 +1,41 @@ -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock +import numpy as np import pytest -from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB +from astrbot.core.db.vec_db.faiss_impl.vec_db import EmbeddingCache, FaissVecDB from astrbot.core.exceptions import KnowledgeBaseUploadError -@pytest.mark.asyncio -async def test_insert_batch_skips_empty_contents() -> None: +def _make_vecdb(): + """创建最小化的 FaissVecDB mock""" vec_db = FaissVecDB.__new__(FaissVecDB) vec_db.embedding_provider = AsyncMock() vec_db.document_storage = AsyncMock() - vec_db.embedding_storage = AsyncMock() + vec_db.embedding_storage = MagicMock() + vec_db.embedding_storage.dimension = 128 + vec_db.embedding_cache = EmbeddingCache(max_size=100) + return vec_db + + +@pytest.mark.asyncio +async def test_insert_batch_skips_empty_contents() -> None: + vec_db = _make_vecdb() result = await FaissVecDB.insert_batch(vec_db, []) assert result == [] - vec_db.embedding_provider.get_embeddings_batch.assert_not_awaited() - vec_db.document_storage.insert_documents_batch.assert_not_awaited() - vec_db.embedding_storage.insert_batch.assert_not_awaited() + vec_db.embedding_provider.get_embeddings_batch.assert_not_called() + vec_db.document_storage.insert_documents_batch.assert_not_called() + vec_db.embedding_storage.insert_batch.assert_not_called() @pytest.mark.asyncio async def test_insert_batch_raises_friendly_error_for_embedding_count_mismatch() -> ( None ): - vec_db = FaissVecDB.__new__(FaissVecDB) - vec_db.embedding_provider = AsyncMock() + vec_db = _make_vecdb() vec_db.embedding_provider.get_embeddings_batch.return_value = [[0.1, 0.2]] - vec_db.document_storage = AsyncMock() - vec_db.embedding_storage = AsyncMock() vec_db.embedding_storage.dimension = 2 with pytest.raises(KnowledgeBaseUploadError) as exc_info: @@ -41,6 +47,77 @@ async def test_insert_batch_raises_friendly_error_for_embedding_count_mismatch() ) assert "向量化失败" in str(exc_info.value) - assert "期望 2,实际 1" in str(exc_info.value) - vec_db.document_storage.insert_documents_batch.assert_not_awaited() - vec_db.embedding_storage.insert_batch.assert_not_awaited() + assert "期望 2" in str(exc_info.value) + assert "实际 1" in str(exc_info.value) + vec_db.document_storage.insert_documents_batch.assert_not_called() + vec_db.embedding_storage.insert_batch.assert_not_called() + + +class TestEmbeddingCache: + """Phase 2B: 嵌入缓存测试""" + + @pytest.mark.asyncio + async def test_cache_hit_returns_cached_value(self): + """缓存命中时返回已缓存的向量""" + cache = EmbeddingCache(max_size=100) + text = "hello world" + emb = np.array([0.1, 0.2, 0.3], dtype=np.float32) + + await cache.put(text, emb) + result = await cache.get(text) + + assert result is not None + assert np.array_equal(result, emb) + + @pytest.mark.asyncio + async def test_cache_miss_returns_none(self): + """缓存未命中时返回 None""" + cache = EmbeddingCache(max_size=100) + result = await cache.get("unknown text") + assert result is None + + @pytest.mark.asyncio + async def test_cache_lru_eviction(self): + """超出 max_size 时最早的条目应被淘汰""" + cache = EmbeddingCache(max_size=3) + for i in range(5): + await cache.put(f"text_{i}", np.array([float(i)], dtype=np.float32)) + + assert await cache.__len__() == 3 + # text_0 和 text_1 应该被淘汰 + assert await cache.get("text_0") is None + assert await cache.get("text_1") is None + # text_2, text_3, text_4 应该仍然存在 + assert await cache.get("text_2") is not None + assert await cache.get("text_3") is not None + assert await cache.get("text_4") is not None + + @pytest.mark.asyncio + async def test_insert_batch_uses_cache(self): + """insert_batch 缓存命中时减少 provider 调用""" + vec_db = _make_vecdb() + # 预缓存两个文本 + await vec_db.embedding_cache.put( + "cached_1", np.array([0.5] * 128, dtype=np.float32), + ) + await vec_db.embedding_cache.put( + "cached_2", np.array([0.6] * 128, dtype=np.float32), + ) + vec_db.embedding_provider.get_embeddings_batch.return_value = ( + [[0.1] * 128] + ) + vec_db.document_storage.insert_documents_batch = AsyncMock( + return_value=[10, 11, 12], + ) + vec_db.embedding_storage.insert_batch = AsyncMock() + + result = await FaissVecDB.insert_batch( + vec_db, + contents=["cached_1", "cached_2", "fresh_text"], + batch_size=32, + tasks_limit=3, + max_retries=3, + ) + assert len(result) == 3 + # 只应调用一次 get_embeddings_batch(仅 fresh_text 未缓存) + assert vec_db.embedding_provider.get_embeddings_batch.call_count == 1 From 9ef28c5ec129de57ca661cd3c737c467e91b5cfb Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 22:33:01 +0800 Subject: [PATCH 12/48] =?UTF-8?q?fix(kb):=20preserve=20FAISS=20external=20?= =?UTF-8?q?IDs=20during=20L2=E2=86=92IP=20index=20migration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old migration used np.arange(ntotal) which assigned new sequential IDs, breaking the mapping between FAISS vectors and DocumentStorage chunks. Extract original external IDs from IndexIDMap and re-register them with the new IP index. Also broadens L2 detection from isinstance(IndexFlatL2) to metric_type check, covering any L2 variant. --- .../db/vec_db/faiss_impl/embedding_storage.py | 44 ++++++++++++------- tests/unit/test_embedding_storage.py | 17 +++++++ 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index 8747b5bfc5..b0f433943f 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -81,14 +81,12 @@ def _migrate_l2_to_ip_if_needed(self) -> None: """检测并迁移旧版 L2 索引到 IP (余弦相似度) 旧版使用 IndexFlatL2,新版使用 IndexFlatIP + 归一化向量。 - 迁移过程:reconstruct 所有向量 → L2 归一化 → 重建为 IP 索引。 + 迁移过程:保留原 external ids → reconstruct 所有向量 → L2 归一化 → 重建为 IP 索引。 """ assert self.index is not None # IndexIDMap 包装了 base index,需要解包检查 - base_index = ( - self.index.index if hasattr(self.index, "index") else self.index - ) - if not isinstance(base_index, faiss.IndexFlatL2): + base_index = self.index.index if hasattr(self.index, "index") else self.index + if getattr(base_index, "metric_type", None) != faiss.METRIC_L2: return # 已经是 IP 或其他类型,无需迁移 import warnings @@ -112,29 +110,43 @@ def _migrate_l2_to_ip_if_needed(self) -> None: # 重建所有向量并归一化 # 注意: IndexIDMap.reconstruct 在某些 FAISS 构建版本中不可用 try: + ids = self._get_index_ids() vectors = np.zeros((ntotal, self.dimension), dtype=np.float32) - for i in range(ntotal): - vectors[i] = self.index.reconstruct(i) - except RuntimeError: - warnings.warn( - "无法从旧索引重建向量(reconstruct 不可用)," - "将重建空 IP 索引。请重新上传文档以生成新向量。", - stacklevel=2, + reconstruct_index = ( + self.index.index if hasattr(self.index, "index") else self.index ) - self.index = faiss.IndexIDMap(faiss.IndexFlatIP(self.dimension)) - faiss.write_index(self.index, self.path) - return + for pos in range(ntotal): + vectors[pos] = reconstruct_index.reconstruct(pos) + except Exception as exc: + raise RuntimeError( + "无法从旧索引重建向量(reconstruct 不可用)," + "已保留旧索引文件未覆盖。请重新上传文档或手动重建知识库索引。" + ) from exc _safe_normalize_l2(vectors) # 重建为 IP 索引 new_index = faiss.IndexIDMap(faiss.IndexFlatIP(self.dimension)) - new_index.add_with_ids(vectors, np.arange(ntotal, dtype=np.int64)) + new_index.add_with_ids(vectors, ids) self.index = new_index # 立即保存迁移后的索引 faiss.write_index(self.index, self.path) + def _get_index_ids(self) -> np.ndarray: + assert self.index is not None + ntotal = self.index.ntotal + id_map = getattr(self.index, "id_map", None) + if id_map is None: + return np.arange(ntotal, dtype=np.int64) + + ids = faiss.vector_to_array(id_map).astype(np.int64) + if len(ids) != ntotal: + raise RuntimeError( + f"FAISS IDMap 数量异常: ntotal={ntotal}, id_map={len(ids)}", + ) + return ids + async def insert(self, vector: np.ndarray, id: int) -> None: """插入向量 diff --git a/tests/unit/test_embedding_storage.py b/tests/unit/test_embedding_storage.py index c754bb0db5..4da1b6eb8c 100644 --- a/tests/unit/test_embedding_storage.py +++ b/tests/unit/test_embedding_storage.py @@ -201,6 +201,23 @@ async def test_migration_l2_to_ip(self): f"迁移后搜索自身应有高分, 实际={distances[0][0]}" ) + @pytest.mark.asyncio + async def test_migration_preserves_external_ids(self): + faiss = pytest.importorskip("faiss") + + with tempfile.TemporaryDirectory() as tmpdir: + index_path = Path(tmpdir) / "index.faiss" + old_index = faiss.IndexIDMap(faiss.IndexFlatL2(DIM)) + vectors = make_random_batch(3) + ids = np.array([10, 42, 99], dtype=np.int64) + old_index.add_with_ids(vectors, ids) + faiss.write_index(old_index, str(index_path)) + + storage = EmbeddingStorage(dimension=DIM, path=str(index_path)) + + _, result_ids = await storage.search(vectors[1].copy().reshape(1, -1), k=1) + assert result_ids[0][0] == 42 + @pytest.mark.asyncio async def test_no_crash_on_reload_existing_ip_index(self): """重新加载已有的 IP 索引不应报错""" From 8f63a6fae2aace7da594cbb51bd42b442af9b344 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 22:33:07 +0800 Subject: [PATCH 13/48] fix(kb): return bool from vec_db.delete to signal whether chunk existed Returning None made it impossible for callers to distinguish 'successfully deleted' from 'chunk not found'. Now returns False when the chunk is missing, True on successful deletion. --- astrbot/core/db/vec_db/faiss_impl/vec_db.py | 10 +++++--- tests/unit/test_faiss_vec_db.py | 28 +++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index b09175dae4..1c87d0ca25 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -181,7 +181,7 @@ async def insert_batch( ) # 只对未缓存的文本生成嵌入 - vectors = [np.empty(0, dtype=np.float32)] * len(contents) + vectors = [np.empty(0, dtype=np.float32) for _ in contents] if uncached_texts: new_embeddings = await self.embedding_provider.get_embeddings_batch( uncached_texts, @@ -301,7 +301,8 @@ async def retrieve( else: embedding = await self.embedding_provider.get_embedding(query) await self.embedding_cache.put( - query, np.asarray(embedding, dtype=np.float32), + query, + np.asarray(embedding, dtype=np.float32), ) scores, indices = await self.embedding_storage.search( vector=np.array([embedding]).astype("float32"), @@ -346,17 +347,18 @@ async def retrieve( return top_k_results - async def delete(self, doc_id: str) -> None: + async def delete(self, doc_id: str) -> bool: """删除一条文档块(chunk)""" # 获得对应的 int id result = await self.document_storage.get_document_by_doc_id(doc_id) int_id = result["id"] if result else None if int_id is None: - return + return False # 使用 DocumentStorage 的删除方法 await self.document_storage.delete_document_by_doc_id(doc_id) await self.embedding_storage.delete([int_id]) + return True async def close(self) -> None: await self.document_storage.close() diff --git a/tests/unit/test_faiss_vec_db.py b/tests/unit/test_faiss_vec_db.py index 4d3999f793..610ba446d4 100644 --- a/tests/unit/test_faiss_vec_db.py +++ b/tests/unit/test_faiss_vec_db.py @@ -53,6 +53,34 @@ async def test_insert_batch_raises_friendly_error_for_embedding_count_mismatch() vec_db.embedding_storage.insert_batch.assert_not_called() +@pytest.mark.asyncio +async def test_delete_returns_false_when_chunk_is_missing() -> None: + vec_db = _make_vecdb() + vec_db.document_storage.get_document_by_doc_id.return_value = None + vec_db.document_storage.delete_document_by_doc_id = AsyncMock() + vec_db.embedding_storage.delete = AsyncMock() + + deleted = await FaissVecDB.delete(vec_db, "missing-chunk") + + assert deleted is False + vec_db.document_storage.delete_document_by_doc_id.assert_not_called() + vec_db.embedding_storage.delete.assert_not_called() + + +@pytest.mark.asyncio +async def test_delete_returns_true_when_chunk_exists() -> None: + vec_db = _make_vecdb() + vec_db.document_storage.get_document_by_doc_id.return_value = {"id": 42} + vec_db.document_storage.delete_document_by_doc_id = AsyncMock() + vec_db.embedding_storage.delete = AsyncMock() + + deleted = await FaissVecDB.delete(vec_db, "chunk-1") + + assert deleted is True + vec_db.document_storage.delete_document_by_doc_id.assert_awaited_once_with("chunk-1") + vec_db.embedding_storage.delete.assert_awaited_once_with([42]) + + class TestEmbeddingCache: """Phase 2B: 嵌入缓存测试""" From c7593f9285f1c07afc987ddea410dc7d19e789c6 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 22:33:20 +0800 Subject: [PATCH 14/48] fix(kb): serialize FTS5 rebuild, use VIRTUAL generated columns, return None on limit<=0 - Add _fts_rebuild_lock to prevent concurrent FTS5 index rebuilds from corrupting the contentless-delete table. - Switch generated columns from STORED to VIRTUAL so ALTER TABLE is O(1) on existing large tables (index still materializes the computed value). - Extract _apply_metadata_filters helper to reduce duplication across get_documents / delete_documents / count_documents. - Return None from search_sparse when limit <= 0 instead of empty list, so callers correctly fall back to in-memory BM25. --- .../db/vec_db/faiss_impl/document_storage.py | 219 +++++++++++------- tests/unit/test_document_storage_fts.py | 125 ++++++++++ 2 files changed, 263 insertions(+), 81 deletions(-) diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index 641453206a..4b3c1a4c60 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -1,12 +1,16 @@ import json import os +from asyncio import Lock from contextlib import asynccontextmanager from datetime import datetime from pathlib import Path from sqlalchemy import Column, Text, bindparam +from sqlalchemy.dialects import sqlite from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import NullPool +from sqlalchemy.schema import CreateTable from sqlmodel import Field, MetaData, SQLModel, col, func, select, text from astrbot.core import logger @@ -55,61 +59,98 @@ def __init__(self, db_path: str) -> None: self._fts_contentless_delete = False self._fts_index_ready = False self._stopwords: set[str] | None = None + self._fts_rebuild_lock = Lock() async def initialize(self) -> None: """Initialize the SQLite database and create the documents table if it doesn't exist.""" await self.connect() async with self.engine.begin() as conn: # type: ignore - # Create tables using SQLModel - await conn.run_sync(BaseDocModel.metadata.create_all) + await self._ensure_documents_table(conn) + await self._ensure_generated_columns(conn) - try: - await conn.execute( - text( - "ALTER TABLE documents ADD COLUMN kb_doc_id TEXT " - "GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED", - ), - ) - await conn.execute( - text( - "ALTER TABLE documents ADD COLUMN user_id TEXT " - "GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED", - ), - ) - await conn.execute( - text( - "ALTER TABLE documents ADD COLUMN kb_id TEXT " - "GENERATED ALWAYS AS (json_extract(metadata, '$.kb_id')) STORED", - ), - ) + await self._initialize_fts5(conn) + await conn.commit() - # Create indexes - await conn.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)", - ), - ) - await conn.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)", - ), - ) - await conn.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_documents_kb_id ON documents(kb_id)", - ), - ) - except BaseException: - pass + async def _table_columns(self, executor, table_name: str) -> set[str]: + result = await executor.execute(text(f"PRAGMA table_xinfo({table_name})")) + return {row[1] for row in result.fetchall()} - await conn.execute( + async def _ensure_generated_columns(self, executor) -> None: + generated_columns = { + "kb_doc_id": "json_extract(metadata, '$.kb_doc_id')", + "user_id": "json_extract(metadata, '$.user_id')", + "kb_id": "json_extract(metadata, '$.kb_id')", + } + columns = await self._table_columns(executor, "documents") + for column_name, expression in generated_columns.items(): + if column_name in columns: + continue + await executor.execute( text( - "CREATE UNIQUE INDEX IF NOT EXISTS idx_documents_doc_id_unique ON documents(doc_id)", + f"ALTER TABLE documents ADD COLUMN {column_name} TEXT " + f"GENERATED ALWAYS AS ({expression}) VIRTUAL", ), ) + columns.add(column_name) - await self._initialize_fts5(conn) - await conn.commit() + index_statements = [ + "CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id " + "ON documents(kb_doc_id)", + "CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)", + "CREATE INDEX IF NOT EXISTS idx_documents_kb_id ON documents(kb_id)", + ] + for statement in index_statements: + await executor.execute(text(statement)) + + async def _ensure_documents_table(self, executor) -> None: + """Create the document table from the SQLModel definition.""" + result = await executor.execute( + text( + """ + SELECT 1 + FROM sqlite_master + WHERE type='table' AND name=:table_name + LIMIT 1 + """, + ), + {"table_name": Document.__tablename__}, + ) + if result.scalar_one_or_none() is not None: + await self._ensure_doc_id_unique_index(executor) + return + + create_table = CreateTable(Document.__table__, if_not_exists=True) # type: ignore[attr-defined] + + await executor.execute( + text(str(create_table.compile(dialect=sqlite.dialect()))) + ) + await self._ensure_doc_id_unique_index(executor) + + async def _ensure_doc_id_unique_index(self, executor) -> None: + duplicate_result = await executor.execute( + text( + """ + SELECT doc_id + FROM documents + GROUP BY doc_id + HAVING COUNT(*) > 1 + LIMIT 1 + """, + ), + ) + if duplicate_result.scalar_one_or_none() is not None: + logger.warning( + "Skipping documents.doc_id unique index migration because duplicate " + f"doc_id values already exist in {self.db_path}.", + ) + return + + await executor.execute( + text( + "CREATE UNIQUE INDEX IF NOT EXISTS " + "idx_documents_doc_id_unique ON documents(doc_id)", + ), + ) async def _initialize_fts5(self, executor) -> None: try: @@ -214,6 +255,7 @@ async def connect(self) -> None: self.DATABASE_URL, echo=False, future=True, + poolclass=NullPool, ) self.async_session_maker = sessionmaker( self.engine, # type: ignore @@ -266,17 +308,11 @@ async def get_documents( async with self.get_session() as session: query = select(Document) - - for key, val in metadata_filters.items(): - # kb_id 和 kb_doc_id 有生成列和索引,直接用列名过滤避免全表扫描 - if key in ("kb_id", "kb_doc_id"): - query = query.where( - text(f"{key} = :filter_{key}"), - ).params(**{f"filter_{key}": val}) - else: - query = query.where( - text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), - ).params(**{f"filter_{key}": val}) + query = await self._apply_metadata_filters( + session, + query, + metadata_filters, + ) if ids is not None and len(ids) > 0: valid_ids = [int(i) for i in ids if i != -1] @@ -438,11 +474,11 @@ async def delete_documents(self, metadata_filters: dict) -> None: async with self.get_session() as session, session.begin(): query = select(Document) - - for key, val in metadata_filters.items(): - query = query.where( - text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), - ).params(**{f"filter_{key}": val}) + query = await self._apply_metadata_filters( + session, + query, + metadata_filters, + ) result = await session.execute(query) documents = result.scalars().all() @@ -469,21 +505,34 @@ async def count_documents(self, metadata_filters: dict | None = None) -> int: query = select(func.count(col(Document.id))) if metadata_filters: - for key, val in metadata_filters.items(): - # 使用生成列避免全表扫描 - if key in ("kb_id", "kb_doc_id"): - query = query.where( - text(f"{key} = :filter_{key}"), - ).params(**{f"filter_{key}": val}) - else: - query = query.where( - text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), - ).params(**{f"filter_{key}": val}) + query = await self._apply_metadata_filters( + session, + query, + metadata_filters, + ) result = await session.execute(query) count = result.scalar_one_or_none() return count if count is not None else 0 + async def _apply_metadata_filters( + self, + session: AsyncSession, + query, + metadata_filters: dict, + ): + columns = await self._table_columns(session, "documents") + for key, val in metadata_filters.items(): + if key in {"kb_id", "kb_doc_id", "user_id"} and key in columns: + query = query.where( + text(f"{key} = :filter_{key}"), + ).params(**{f"filter_{key}": val}) + else: + query = query.where( + text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), + ).params(**{f"filter_{key}": val}) + return query + async def ensure_fts_index(self) -> bool: """Ensure the FTS5 sparse index exists and matches the documents table.""" if not self.fts5_available: @@ -493,22 +542,30 @@ async def ensure_fts_index(self) -> bool: assert self.engine is not None, "Database connection is not initialized." - async with self.get_session() as session: - doc_count = await self._count_documents_in_session(session) - fts_count = await self._count_fts_rows(session) - if doc_count == fts_count: - self._fts_index_ready = True + async with self._fts_rebuild_lock: + if self._fts_index_ready: return True - logger.info( - f"Rebuilding FTS5 sparse index for {self.db_path}: " - f"documents={doc_count}, fts_rows={fts_count}", - ) - await self.rebuild_fts_index() - return self.fts5_available + async with self.get_session() as session: + doc_count = await self._count_documents_in_session(session) + fts_count = await self._count_fts_rows(session) + if doc_count == fts_count: + self._fts_index_ready = True + return True + + logger.info( + f"Rebuilding FTS5 sparse index for {self.db_path}: " + f"documents={doc_count}, fts_rows={fts_count}", + ) + await self._rebuild_fts_index_unlocked() + return self.fts5_available async def rebuild_fts_index(self) -> None: """Rebuild the contentless FTS5 sparse index from documents.""" + async with self._fts_rebuild_lock: + await self._rebuild_fts_index_unlocked() + + async def _rebuild_fts_index_unlocked(self) -> None: if not self.fts5_available: return @@ -553,7 +610,7 @@ async def search_sparse( sparse retrieval implementation. """ if limit <= 0: - return [] + return None if not await self.ensure_fts_index(): return None diff --git a/tests/unit/test_document_storage_fts.py b/tests/unit/test_document_storage_fts.py index a7dd32c94c..0ea699b0c7 100644 --- a/tests/unit/test_document_storage_fts.py +++ b/tests/unit/test_document_storage_fts.py @@ -1,6 +1,8 @@ +import asyncio import sqlite3 import pytest +from sqlalchemy.exc import IntegrityError from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage @@ -57,6 +59,53 @@ async def test_document_storage_fts_rebuilds_existing_documents(tmp_path): await storage.close() +@pytest.mark.asyncio +async def test_document_storage_search_sparse_non_positive_limit_falls_back(tmp_path): + storage = DocumentStorage(str(tmp_path / "doc.db")) + await storage.initialize() + + assert await storage.search_sparse(["知识库"], limit=0) is None + + await storage.close() + + +@pytest.mark.asyncio +async def test_document_storage_fts_rebuild_is_serialized(tmp_path, monkeypatch): + storage = DocumentStorage(str(tmp_path / "doc.db")) + await storage.initialize() + + storage.fts5_available = False + await storage.insert_document( + doc_id="legacy-chunk", + text="legacy 知识库 文本", + metadata={"kb_doc_id": "doc-1", "kb_id": "kb-1", "chunk_index": 0}, + ) + + storage.fts5_available = True + storage._fts_index_ready = False + rebuild_count = 0 + original_rebuild = storage._rebuild_fts_index_unlocked + + async def counted_rebuild(): + nonlocal rebuild_count + rebuild_count += 1 + await asyncio.sleep(0) + await original_rebuild() + + monkeypatch.setattr(storage, "_rebuild_fts_index_unlocked", counted_rebuild) + + results = await asyncio.gather( + storage.ensure_fts_index(), + storage.ensure_fts_index(), + storage.ensure_fts_index(), + ) + + assert results == [True, True, True] + assert rebuild_count == 1 + + await storage.close() + + @pytest.mark.asyncio async def test_document_storage_fts_delete_skips_missing_fts_row(tmp_path): storage = DocumentStorage(str(tmp_path / "doc.db")) @@ -101,3 +150,79 @@ async def test_document_storage_fts_recovers_from_legacy_non_fts_table(tmp_path) assert [result["doc_id"] for result in results] == ["legacy-fix"] await storage.close() + + +@pytest.mark.asyncio +async def test_document_storage_adds_unique_doc_id_index_to_existing_table(tmp_path): + db_path = tmp_path / "doc.db" + conn = sqlite3.connect(db_path) + conn.execute( + """ + CREATE TABLE documents ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + doc_id VARCHAR NOT NULL, + text VARCHAR NOT NULL, + metadata TEXT, + created_at DATETIME, + updated_at DATETIME + ) + """, + ) + conn.execute( + "INSERT INTO documents (doc_id, text) VALUES ('legacy-chunk', 'legacy text')" + ) + conn.commit() + conn.close() + + storage = DocumentStorage(str(db_path)) + await storage.initialize() + + with pytest.raises(IntegrityError): + await storage.insert_document( + doc_id="legacy-chunk", + text="duplicate text", + metadata={}, + ) + + await storage.close() + + +@pytest.mark.asyncio +async def test_document_storage_adds_missing_kb_id_generated_column(tmp_path): + db_path = tmp_path / "doc.db" + conn = sqlite3.connect(db_path) + conn.execute( + """ + CREATE TABLE documents ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + doc_id VARCHAR NOT NULL, + text VARCHAR NOT NULL, + metadata TEXT, + created_at DATETIME, + updated_at DATETIME, + kb_doc_id TEXT GENERATED ALWAYS AS + (json_extract(metadata, '$.kb_doc_id')) VIRTUAL + ) + """, + ) + conn.execute( + """ + INSERT INTO documents (doc_id, text, metadata) + VALUES ( + 'legacy-chunk', + 'legacy text', + '{"kb_doc_id":"doc-1","kb_id":"kb-1","chunk_index":0}' + ) + """, + ) + conn.commit() + conn.close() + + storage = DocumentStorage(str(db_path)) + await storage.initialize() + + docs = await storage.get_documents(metadata_filters={"kb_id": "kb-1"}) + + assert [doc["doc_id"] for doc in docs] == ["legacy-chunk"] + + await storage.close() From ab904aca847a5bf52a69804ffbedfd3c307de1ae Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 22:33:28 +0800 Subject: [PATCH 15/48] fix(kb): add incremental index_type migration and guard document deletion with kb_id - Add _ensure_column helper for safe incremental schema migrations. - Add index_type column migration in migrate_to_v1 for legacy databases. - Add kb_id parameter to delete_document_by_id / delete_documents_by_ids to prevent accidental cross-KB deletion. - Reorder batch delete: clean vectors first, then remove metadata only for successful vec_db deletions (best-effort per document). - Also cascade-delete KBMedia rows alongside KBDocument rows. --- astrbot/core/knowledge_base/kb_db_sqlite.py | 110 ++++++++++++++++---- 1 file changed, 89 insertions(+), 21 deletions(-) diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 4704f3d19c..2813e291b6 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -89,6 +89,13 @@ async def migrate_to_v1(self) -> None: async with self.get_db() as session: session: AsyncSession async with session.begin(): + await self._ensure_column( + session, + table_name="knowledge_bases", + column_name="index_type", + column_sql="index_type TEXT DEFAULT 'flat'", + ) + # 创建知识库表索引 await session.execute( text( @@ -168,6 +175,24 @@ async def migrate_to_v1(self) -> None: await session.commit() + async def _ensure_column( + self, + session: AsyncSession, + *, + table_name: str, + column_name: str, + column_sql: str, + ) -> None: + """Add a column when upgrading an existing SQLite table.""" + result = await session.execute(text(f"PRAGMA table_xinfo({table_name})")) + columns = {row[1] for row in result.fetchall()} + if column_name in columns: + return + logger.info( + f"知识库数据库迁移: 为表 {table_name} 添加列 {column_name}", + ) + await session.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_sql}")) + async def close(self) -> None: """关闭数据库连接""" await self.engine.dispose() @@ -300,41 +325,69 @@ async def get_documents_with_metadata_batch( return metadata_map - async def delete_document_by_id(self, doc_id: str, vec_db: "FaissVecDB") -> None: + async def delete_document_by_id( + self, + doc_id: str, + vec_db: "FaissVecDB", + kb_id: str | None = None, + ) -> bool: """删除单个文档及其相关数据""" - # 在知识库表中删除 + doc = await self.get_document_by_id(doc_id) + if not doc or (kb_id is not None and doc.kb_id != kb_id): + return False + + metadata_filters = {"kb_doc_id": doc_id} + if kb_id is not None: + metadata_filters["kb_id"] = kb_id + + # 先删向量库;如果失败,保留 metadata 以便重试/修复。 + await vec_db.delete_documents(metadata_filters=metadata_filters) + async with self.get_db() as session, session.begin(): - # 删除文档记录 delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id) + if kb_id is not None: + delete_stmt = delete_stmt.where(col(KBDocument.kb_id) == kb_id) await session.execute(delete_stmt) - await session.commit() + await session.execute(delete(KBMedia).where(col(KBMedia.doc_id) == doc_id)) - # 在 vec db 中删除相关向量 - await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id}) + return True async def delete_documents_by_ids( - self, doc_ids: list[str], vec_db: "FaissVecDB", + self, + doc_ids: list[str], + vec_db: "FaissVecDB", + kb_id: str | None = None, ) -> dict[str, bool]: """批量删除文档及其向量数据。 - 单个文档的 vec_db 删除失败不影响其他文档(best-effort)。 + 先删除向量数据,再删除 metadata;单个文档的 vec_db 删除失败 + 不影响其他文档(best-effort),失败项保留 metadata 以便重试。 """ if not doc_ids: return {} - # 批量从知识库表中删除 - async with self.get_db() as session, session.begin(): - delete_stmt = delete(KBDocument).where( - col(KBDocument.doc_id).in_(doc_ids), - ) - await session.execute(delete_stmt) + requested_doc_ids = list(dict.fromkeys(doc_ids)) + results = dict.fromkeys(requested_doc_ids, False) + + candidates = requested_doc_ids + if kb_id is not None: + async with self.get_db() as session: + stmt = select(KBDocument.doc_id).where( + col(KBDocument.doc_id).in_(requested_doc_ids), + col(KBDocument.kb_id) == kb_id, + ) + result = await session.execute(stmt) + candidates = [row[0] for row in result.fetchall()] + + if not candidates: + return results - # 并行清理 vec_db(向量 + SQLite 文档存储) async def _delete_one(doc_id: str) -> tuple[str, bool]: + metadata_filters = {"kb_doc_id": doc_id} + if kb_id is not None: + metadata_filters["kb_id"] = kb_id try: - await vec_db.delete_documents( - metadata_filters={"kb_doc_id": doc_id}, - ) + await vec_db.delete_documents(metadata_filters=metadata_filters) return doc_id, True except Exception as e: logger.error( @@ -342,11 +395,26 @@ async def _delete_one(doc_id: str) -> tuple[str, bool]: ) return doc_id, False - results: dict[str, bool] = {} - tasks = [_delete_one(doc_id) for doc_id in doc_ids] - vec_results = await asyncio.gather(*tasks) + vec_results = await asyncio.gather( + *[_delete_one(doc_id) for doc_id in candidates], + ) + successful_doc_ids = [] for doc_id, success in vec_results: results[doc_id] = success + if success: + successful_doc_ids.append(doc_id) + + if successful_doc_ids: + async with self.get_db() as session, session.begin(): + delete_stmt = delete(KBDocument).where( + col(KBDocument.doc_id).in_(successful_doc_ids), + ) + if kb_id is not None: + delete_stmt = delete_stmt.where(col(KBDocument.kb_id) == kb_id) + await session.execute(delete_stmt) + await session.execute( + delete(KBMedia).where(col(KBMedia.doc_id).in_(successful_doc_ids)), + ) return results From 590e898b353127896cf72ab42c33c002712bef05 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 22:33:41 +0800 Subject: [PATCH 16/48] refactor(kb): clarify RRF variable names and simplify BM25 per-KB aggregation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename vec_doc_id_to_dense → chunk_id_to_dense in rank_fusion to reflect that Document.doc_id stores the chunk UUID, not a doc ID. - Replace fragile suffix-slicing (merged_results[-len(kb_chunks):]) with an explicit kb_results accumulator in BM25 fallback path, keeping per-KB truncation self-contained. --- .../core/knowledge_base/retrieval/manager.py | 3 ++- .../knowledge_base/retrieval/rank_fusion.py | 23 +++++++++---------- .../retrieval/sparse_retriever.py | 11 ++++----- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 2377fd7a87..07543b48a7 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -238,7 +238,8 @@ async def _retrieve_one(kb_id: str) -> list[Result]: return vec_results except Exception as e: logger.error( - f"知识库 {kb_id} 稠密检索失败: {e}", exc_info=True, + f"知识库 {kb_id} 稠密检索失败: {e}", + exc_info=True, ) if len(kb_ids) == 1: raise RuntimeError( diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index 056c59493f..744287e655 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -63,28 +63,27 @@ async def fuse( List[FusedResult]: 融合后的结果列表 """ - # 1. 构建排名映射 + # 1. Build rank maps keyed by vector-storage chunk IDs. dense_ranks = { r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results) - } # 这里的 doc_id 实际上是 chunk_id + } sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)} - # 2. 收集所有唯一的 ID - # 需要统一为 chunk_id + # 2. Collect all unique chunk IDs. all_chunk_ids = set() - vec_doc_id_to_dense: dict[str, Result] = {} # vec_doc_id -> Result - chunk_id_to_sparse: dict[str, SparseResult] = {} # chunk_id -> SparseResult + chunk_id_to_dense: dict[str, Result] = {} + chunk_id_to_sparse: dict[str, SparseResult] = {} # 处理稀疏检索结果 for r in sparse_results: all_chunk_ids.add(r.chunk_id) chunk_id_to_sparse[r.chunk_id] = r - # 处理稠密检索结果 (需要转换 vec_doc_id 到 chunk_id) + # Dense results use Document.doc_id, which stores the chunk UUID. for r in dense_results: - vec_doc_id = r.data["doc_id"] - all_chunk_ids.add(vec_doc_id) - vec_doc_id_to_dense[vec_doc_id] = r + chunk_id = r.data["doc_id"] + all_chunk_ids.add(chunk_id) + chunk_id_to_dense[chunk_id] = r # 3. 计算 RRF 分数 rrf_scores: dict[str, float] = {} @@ -134,9 +133,9 @@ async def fuse( score=rrf_scores[identifier], ), ) - elif identifier in vec_doc_id_to_dense: + elif identifier in chunk_id_to_dense: # 从向量检索获取信息,需要从数据库获取块的详细信息 - vec_result = vec_doc_id_to_dense[identifier] + vec_result = chunk_id_to_dense[identifier] chunk_md = json.loads(vec_result.data["metadata"]) fused_results.append( FusedResult( diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index 0608f71810..c34728a273 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -163,8 +163,8 @@ async def _retrieve_with_bm25( if not any(all_kb_chunks): return [] - # 每个知识库独立计算 BM25 分数并截断,再合并 - merged_results = [] + # 每个知识库独立计算 BM25 分数并截断,再合并。 + merged_results: list[SparseResult] = [] for kb_chunks in all_kb_chunks: if not kb_chunks: continue @@ -179,9 +179,10 @@ async def _retrieve_with_bm25( tokenized_query = tokenize_text(query, self.hit_stopwords) scores = bm25.get_scores(tokenized_query) + kb_results: list[SparseResult] = [] for idx, score in enumerate(scores): chunk = kb_chunks[idx] - merged_results.append( + kb_results.append( SparseResult( chunk_id=chunk["chunk_id"], chunk_index=chunk["chunk_index"], @@ -192,9 +193,7 @@ async def _retrieve_with_bm25( ), ) - # 截断当前 KB 的结果 - kb_sorted = sorted(merged_results[-len(kb_chunks):], key=lambda x: x.score) - merged_results = merged_results[:-len(kb_chunks)] + kb_sorted[:kb_top_k] + merged_results.extend(sorted(kb_results, key=lambda x: x.score)[:kb_top_k]) merged_results.sort(key=lambda x: x.score) return merged_results[:top_k_sparse] From e1269ef066ce9a9baf4dca68233acce6cab91a60 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 22:33:48 +0800 Subject: [PATCH 17/48] fix(kb): add rate-limiter lock, validate chunk/doc existence on delete, track init retries - Wrap RateLimiter time checks in asyncio.Lock for accurate throttling when used concurrently via asyncio.gather. - Raise ValueError from delete_chunk/delete_document when the target chunk or document does not exist, instead of silently succeeding. - Pass kb_id to delete_documents_by_ids for cross-KB safety. - Add init_retry_count and last_init_retry_at fields on KBHelper to support cooldown-gated re-initialization in KnowledgeBaseManager. --- astrbot/core/knowledge_base/kb_helper.py | 24 ++++++++++----- tests/unit/test_kb_rate_limiter.py | 38 ++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 7 deletions(-) create mode 100644 tests/unit/test_kb_rate_limiter.py diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 3f56d075eb..a02c58c125 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -40,18 +40,20 @@ def __init__(self, max_rpm: int) -> None: self.max_per_minute = max_rpm self.interval = 60.0 / max_rpm if max_rpm > 0 else 0 self.last_call_time = 0 + self._lock = asyncio.Lock() async def __aenter__(self): if self.interval == 0: return - now = time.monotonic() - elapsed = now - self.last_call_time + async with self._lock: + now = time.monotonic() + elapsed = now - self.last_call_time - if elapsed < self.interval: - await asyncio.sleep(self.interval - elapsed) + if elapsed < self.interval: + await asyncio.sleep(self.interval - elapsed) - self.last_call_time = time.monotonic() + self.last_call_time = time.monotonic() async def __aexit__(self, exc_type, exc_val, exc_tb): pass @@ -133,6 +135,8 @@ def __init__( self.kb_root_dir = kb_root_dir self.chunker = chunker self.init_error = None + self.init_retry_count = 0 + self.last_init_retry_at = 0.0 self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id @@ -498,10 +502,13 @@ async def get_document(self, doc_id: str) -> KBDocument | None: async def delete_document(self, doc_id: str) -> None: """删除单个文档及其相关数据""" - await self.kb_db.delete_document_by_id( + deleted = await self.kb_db.delete_document_by_id( doc_id=doc_id, vec_db=self.vec_db, # type: ignore + kb_id=self.kb.kb_id, ) + if not deleted: + raise ValueError(f"无法找到 ID 为 {doc_id} 的文档") await self.kb_db.update_kb_stats( kb_id=self.kb.kb_id, vec_db=self.vec_db, # type: ignore @@ -516,6 +523,7 @@ async def delete_documents(self, doc_ids: list[str]) -> dict[str, bool]: results = await self.kb_db.delete_documents_by_ids( doc_ids=doc_ids, vec_db=self.vec_db, # type: ignore + kb_id=self.kb.kb_id, ) await self.kb_db.update_kb_stats( kb_id=self.kb.kb_id, @@ -527,7 +535,9 @@ async def delete_documents(self, doc_ids: list[str]) -> dict[str, bool]: async def delete_chunk(self, chunk_id: str, doc_id: str) -> None: """删除单个文本块及其相关数据""" vec_db: FaissVecDB = self.vec_db # type: ignore - await vec_db.delete(chunk_id) + deleted = await vec_db.delete(chunk_id) + if not deleted: + raise ValueError(f"无法找到 ID 为 {chunk_id} 的文本块") await self.kb_db.update_kb_stats( kb_id=self.kb.kb_id, vec_db=self.vec_db, # type: ignore diff --git a/tests/unit/test_kb_rate_limiter.py b/tests/unit/test_kb_rate_limiter.py new file mode 100644 index 0000000000..2341f017de --- /dev/null +++ b/tests/unit/test_kb_rate_limiter.py @@ -0,0 +1,38 @@ +import asyncio +from types import SimpleNamespace + +import pytest + +from astrbot.core.knowledge_base import kb_helper +from astrbot.core.knowledge_base.kb_helper import RateLimiter + + +@pytest.mark.asyncio +async def test_rate_limiter_serializes_concurrent_entries(monkeypatch): + real_sleep = asyncio.sleep + monotonic_values = iter([0.0, 0.0, 0.0, 0.0]) + sleeps: list[float] = [] + + async def fake_sleep(delay: float) -> None: + sleeps.append(delay) + await real_sleep(0) + + monkeypatch.setattr( + kb_helper, + "time", + SimpleNamespace(monotonic=lambda: next(monotonic_values)), + ) + monkeypatch.setattr( + kb_helper, + "asyncio", + SimpleNamespace(Lock=asyncio.Lock, sleep=fake_sleep), + ) + + limiter = RateLimiter(max_rpm=60) + limiter.last_call_time = -1.0 + await asyncio.gather( + limiter.__aenter__(), + limiter.__aenter__(), + ) + + assert sleeps == [1.0] From d74fe03f5f2b8940a7a852c2ed5be37cb95fa3ce Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 22:34:08 +0800 Subject: [PATCH 18/48] fix(kb): cascade-delete KB records, add instance lock, name index, and init retry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Delete KBDocument and KBMedia rows alongside KnowledgeBase when removing a KB, preventing orphan metadata records. - Add asyncio.Lock around kb_insts mutations so get_kb / update_kb / delete_kb are serialized and readers never observe a half-swapped instance. - Maintain O(1) kb_name→kb_id index to replace linear scan in get_kb_by_name. - Add cooldown-gated init retry (60 s, max 3 attempts) so transient provider outages do not permanently disable a KB. - Use _UNSET sentinel for rerank_provider_id in update_kb so callers can distinguish 'not provided' from 'explicitly set to None'. - Clean up file directories when create_kb fails after helper init, preventing empty orphan directories. --- astrbot/core/knowledge_base/kb_mgr.py | 347 +++++++++++++++-------- tests/unit/test_kb_manager_delete.py | 137 +++++++++ tests/unit/test_kb_manager_resilience.py | 175 ++++++++++++ 3 files changed, 543 insertions(+), 116 deletions(-) create mode 100644 tests/unit/test_kb_manager_delete.py diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index 3795f9c4b3..c1680cb1e3 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -1,5 +1,10 @@ +import asyncio +import time from pathlib import Path +from sqlalchemy import delete +from sqlmodel import col + from astrbot.core import logger from astrbot.core.provider.manager import ProviderManager from astrbot.core.utils.astrbot_path import get_astrbot_knowledge_base_path @@ -8,7 +13,7 @@ from .chunking.recursive import RecursiveCharacterChunker from .kb_db_sqlite import KBSQLiteDatabase from .kb_helper import KBHelper -from .models import KBDocument, KnowledgeBase +from .models import KBDocument, KBMedia, KnowledgeBase from .retrieval.manager import RetrievalManager, RetrievalResult from .retrieval.rank_fusion import RankFusion from .retrieval.sparse_retriever import SparseRetriever @@ -17,6 +22,9 @@ DB_PATH = Path(FILES_PATH) / "kb.db" """Knowledge Base storage root directory""" CHUNKER = RecursiveCharacterChunker() +_UNSET = object() +INIT_RETRY_COOLDOWN_SECONDS = 60.0 +INIT_RETRY_MAX_ATTEMPTS = 3 class KnowledgeBaseManager: @@ -32,6 +40,79 @@ def __init__( self._session_deleted_callback_registered = False self.kb_insts: dict[str, KBHelper] = {} + self._kb_name_index: dict[str, str] = {} + self._kb_instances_lock = asyncio.Lock() + + def _ensure_kb_name_index(self) -> None: + if not hasattr(self, "kb_insts"): + self.kb_insts = {} + if not hasattr(self, "_kb_name_index"): + self._kb_name_index = {} + known_ids = set(self.kb_insts) + self._kb_name_index = { + name: kb_id + for name, kb_id in self._kb_name_index.items() + if kb_id in known_ids + } + for kb_id, kb_helper in self.kb_insts.items(): + self._kb_name_index[kb_helper.kb.kb_name] = kb_id + + def _ensure_kb_instances_lock(self) -> asyncio.Lock: + if not hasattr(self, "_kb_instances_lock"): + self._kb_instances_lock = asyncio.Lock() + return self._kb_instances_lock + + def _set_kb_instance(self, kb_id: str, kb_helper: KBHelper) -> None: + self._ensure_kb_name_index() + self.kb_insts[kb_id] = kb_helper + self._kb_name_index = { + name: indexed_kb_id + for name, indexed_kb_id in self._kb_name_index.items() + if indexed_kb_id != kb_id + } + self._kb_name_index[kb_helper.kb.kb_name] = kb_id + + def _get_kb_unlocked(self, kb_id: str) -> KBHelper | None: + if not hasattr(self, "kb_insts"): + self.kb_insts = {} + return self.kb_insts.get(kb_id) + + def _can_retry_helper_init(self, kb_helper: KBHelper) -> bool: + if not kb_helper.init_error: + return False + retry_count = getattr(kb_helper, "init_retry_count", 0) + if retry_count >= INIT_RETRY_MAX_ATTEMPTS: + return False + last_retry_at = getattr(kb_helper, "last_init_retry_at", 0.0) + return time.monotonic() - last_retry_at >= INIT_RETRY_COOLDOWN_SECONDS + + async def _retry_helper_init_if_due(self, kb_helper: KBHelper) -> None: + if not self._can_retry_helper_init(kb_helper): + return + + kb_helper.init_retry_count = getattr(kb_helper, "init_retry_count", 0) + 1 + kb_helper.last_init_retry_at = time.monotonic() + try: + await kb_helper.initialize() + kb_helper.init_error = None + kb_helper.init_retry_count = 0 + kb_helper.last_init_retry_at = 0.0 + except Exception as e: + kb_helper.init_error = str(e) + logger.warning( + f"知识库 {kb_helper.kb.kb_name}({kb_helper.kb.kb_id}) " + f"第 {kb_helper.init_retry_count} 次重新初始化失败: {e}", + exc_info=True, + ) + + def _remove_kb_instance(self, kb_id: str) -> None: + self._ensure_kb_name_index() + self.kb_insts.pop(kb_id, None) + self._kb_name_index = { + name: indexed_kb_id + for name, indexed_kb_id in self._kb_name_index.items() + if indexed_kb_id != kb_id + } async def initialize(self) -> None: """初始化知识库模块""" @@ -76,11 +157,13 @@ async def load_kbs(self) -> None: await kb_helper.initialize() except Exception as e: kb_helper.init_error = str(e) + kb_helper.init_retry_count = 0 + kb_helper.last_init_retry_at = time.monotonic() logger.error( f"知识库 {record.kb_name}({record.kb_id}) 初始化失败: {e}", exc_info=True, ) - self.kb_insts[record.kb_id] = kb_helper + self._set_kb_instance(record.kb_id, kb_helper) async def create_kb( self, @@ -112,66 +195,96 @@ async def create_kb( top_m_final=top_m_final if top_m_final is not None else 5, index_type=index_type if index_type is not None else "flat", ) + kb_helper: KBHelper | None = None try: - async with self.kb_db.get_db() as session: - session.add(kb) - await session.flush() - - kb_helper = KBHelper( - kb_db=self.kb_db, - kb=kb, - provider_manager=self.provider_manager, - kb_root_dir=FILES_PATH, - chunker=CHUNKER, - ) - await kb_helper.initialize() - await session.commit() - self.kb_insts[kb.kb_id] = kb_helper - return kb_helper + async with self._ensure_kb_instances_lock(): + async with self.kb_db.get_db() as session: + session.add(kb) + await session.flush() + + kb_helper = KBHelper( + kb_db=self.kb_db, + kb=kb, + provider_manager=self.provider_manager, + kb_root_dir=FILES_PATH, + chunker=CHUNKER, + ) + await kb_helper.initialize() + await session.commit() + self._set_kb_instance(kb.kb_id, kb_helper) + return kb_helper except Exception as e: + if kb_helper is not None: + try: + await kb_helper.delete_vec_db() + except Exception as cleanup_err: + logger.warning( + f"创建知识库 {kb_name} 失败后清理文件目录失败: {cleanup_err}", + ) if "kb_name" in str(e): raise ValueError(f"知识库名称 '{kb_name}' 已存在") raise async def get_kb(self, kb_id: str) -> KBHelper | None: """获取知识库实例""" - if kb_id in self.kb_insts: - return self.kb_insts[kb_id] + async with self._ensure_kb_instances_lock(): + kb_helper = self._get_kb_unlocked(kb_id) + if kb_helper is not None: + await self._retry_helper_init_if_due(kb_helper) + return kb_helper async def get_kb_by_name(self, kb_name: str) -> KBHelper | None: """通过名称获取知识库实例""" - for kb_helper in self.kb_insts.values(): - if kb_helper.kb.kb_name == kb_name: - return kb_helper - return None + async with self._ensure_kb_instances_lock(): + self._ensure_kb_name_index() + kb_id = self._kb_name_index.get(kb_name) + if kb_id: + return self.kb_insts.get(kb_id) + return None async def delete_kb(self, kb_id: str) -> bool: """删除知识库实例""" - kb_helper = await self.get_kb(kb_id) - if not kb_helper: - return False + async with self._ensure_kb_instances_lock(): + kb_helper = self._get_kb_unlocked(kb_id) + if not kb_helper: + return False - await kb_helper.delete_vec_db() - async with self.kb_db.get_db() as session: - await session.delete(kb_helper.kb) - await session.commit() + async with self.kb_db.get_db() as session: + await session.execute( + delete(KBMedia).where(col(KBMedia.kb_id) == kb_id) + ) + await session.execute( + delete(KBDocument).where(col(KBDocument.kb_id) == kb_id) + ) + await session.execute( + delete(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id) + ) + await session.commit() - self.kb_insts.pop(kb_id, None) - return True + try: + await kb_helper.delete_vec_db() + except Exception as e: + logger.warning( + f"知识库 {kb_id} 数据库记录已删除,但文件目录清理失败: {e}" + ) + + self._remove_kb_instance(kb_id) + return True async def list_kbs(self) -> list[KnowledgeBase]: """列出所有知识库实例""" - kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()] - return kbs + async with self._ensure_kb_instances_lock(): + kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()] + return kbs async def update_kb( self, kb_id: str, - kb_name: str, + kb_name: str | None = None, description: str | None = None, emoji: str | None = None, embedding_provider_id: str | None = None, - rerank_provider_id: str | None = None, + rerank_provider_id: str | None | object = _UNSET, chunk_size: int | None = None, chunk_overlap: int | None = None, top_k_dense: int | None = None, @@ -180,89 +293,91 @@ async def update_kb( index_type: str | None = None, ) -> KBHelper | None: """更新知识库实例""" - kb_helper = await self.get_kb(kb_id) - if not kb_helper: - return None - - kb = kb_helper.kb - previous_state = { - "kb_name": kb.kb_name, - "description": kb.description, - "emoji": kb.emoji, - "embedding_provider_id": kb.embedding_provider_id, - "rerank_provider_id": kb.rerank_provider_id, - "chunk_size": kb.chunk_size, - "chunk_overlap": kb.chunk_overlap, - "top_k_dense": kb.top_k_dense, - "top_k_sparse": kb.top_k_sparse, - "top_m_final": kb.top_m_final, - "index_type": kb.index_type, - } - previous_init_error = kb_helper.init_error - - if kb_name is not None: - kb.kb_name = kb_name - if description is not None: - kb.description = description - if emoji is not None: - kb.emoji = emoji - if embedding_provider_id is not None: - kb.embedding_provider_id = embedding_provider_id - kb.rerank_provider_id = rerank_provider_id # 允许设置为 None - if chunk_size is not None: - kb.chunk_size = chunk_size - if chunk_overlap is not None: - kb.chunk_overlap = chunk_overlap - if top_k_dense is not None: - kb.top_k_dense = top_k_dense - if top_k_sparse is not None: - kb.top_k_sparse = top_k_sparse - if top_m_final is not None: - kb.top_m_final = top_m_final - if index_type is not None: - kb.index_type = index_type - - # Build a new helper first. Keep current vec_db alive until new init succeeds. - new_helper = KBHelper( - kb_db=self.kb_db, - kb=kb, - provider_manager=self.provider_manager, - kb_root_dir=FILES_PATH, - chunker=CHUNKER, - ) - - try: - await new_helper.initialize() - except Exception as e: - # Roll back in-memory settings and keep current helper available. - kb.kb_name = previous_state["kb_name"] - kb.description = previous_state["description"] - kb.emoji = previous_state["emoji"] - kb.embedding_provider_id = previous_state["embedding_provider_id"] - kb.rerank_provider_id = previous_state["rerank_provider_id"] - kb.chunk_size = previous_state["chunk_size"] - kb.chunk_overlap = previous_state["chunk_overlap"] - kb.top_k_dense = previous_state["top_k_dense"] - kb.top_k_sparse = previous_state["top_k_sparse"] - kb.top_m_final = previous_state["top_m_final"] - kb.index_type = previous_state["index_type"] - kb_helper.init_error = previous_init_error - logger.error( - f"知识库 {kb.kb_name}({kb.kb_id}) 重新初始化失败,继续使用旧实例: {e}", - exc_info=True, + async with self._ensure_kb_instances_lock(): + kb_helper = self._get_kb_unlocked(kb_id) + if not kb_helper: + return None + + kb = kb_helper.kb + previous_state = { + "kb_name": kb.kb_name, + "description": kb.description, + "emoji": kb.emoji, + "embedding_provider_id": kb.embedding_provider_id, + "rerank_provider_id": kb.rerank_provider_id, + "chunk_size": kb.chunk_size, + "chunk_overlap": kb.chunk_overlap, + "top_k_dense": kb.top_k_dense, + "top_k_sparse": kb.top_k_sparse, + "top_m_final": kb.top_m_final, + "index_type": kb.index_type, + } + previous_init_error = kb_helper.init_error + + if kb_name is not None: + kb.kb_name = kb_name + if description is not None: + kb.description = description + if emoji is not None: + kb.emoji = emoji + if embedding_provider_id is not None: + kb.embedding_provider_id = embedding_provider_id + if rerank_provider_id is not _UNSET: + kb.rerank_provider_id = rerank_provider_id # type: ignore[assignment] + if chunk_size is not None: + kb.chunk_size = chunk_size + if chunk_overlap is not None: + kb.chunk_overlap = chunk_overlap + if top_k_dense is not None: + kb.top_k_dense = top_k_dense + if top_k_sparse is not None: + kb.top_k_sparse = top_k_sparse + if top_m_final is not None: + kb.top_m_final = top_m_final + if index_type is not None: + kb.index_type = index_type + + # Build a new helper first. Keep current vec_db alive until new init succeeds. + new_helper = KBHelper( + kb_db=self.kb_db, + kb=kb, + provider_manager=self.provider_manager, + kb_root_dir=FILES_PATH, + chunker=CHUNKER, ) - return kb_helper - async with self.kb_db.get_db() as session: - session.add(kb) - await session.commit() - await session.refresh(kb) + try: + await new_helper.initialize() + except Exception as e: + # Roll back in-memory settings and keep current helper available. + kb.kb_name = previous_state["kb_name"] + kb.description = previous_state["description"] + kb.emoji = previous_state["emoji"] + kb.embedding_provider_id = previous_state["embedding_provider_id"] + kb.rerank_provider_id = previous_state["rerank_provider_id"] + kb.chunk_size = previous_state["chunk_size"] + kb.chunk_overlap = previous_state["chunk_overlap"] + kb.top_k_dense = previous_state["top_k_dense"] + kb.top_k_sparse = previous_state["top_k_sparse"] + kb.top_m_final = previous_state["top_m_final"] + kb.index_type = previous_state["index_type"] + kb_helper.init_error = previous_init_error + logger.error( + f"知识库 {kb.kb_name}({kb.kb_id}) 重新初始化失败,继续使用旧实例: {e}", + exc_info=True, + ) + return kb_helper + + async with self.kb_db.get_db() as session: + session.add(kb) + await session.commit() + await session.refresh(kb) - old_helper = kb_helper - self.kb_insts[kb_id] = new_helper - await old_helper.terminate() - new_helper.init_error = None - return new_helper + old_helper = kb_helper + self._set_kb_instance(kb_id, new_helper) + await old_helper.terminate() + new_helper.init_error = None + return new_helper async def retrieve( self, diff --git a/tests/unit/test_kb_manager_delete.py b/tests/unit/test_kb_manager_delete.py new file mode 100644 index 0000000000..751057a821 --- /dev/null +++ b/tests/unit/test_kb_manager_delete.py @@ -0,0 +1,137 @@ +import sys +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +import pytest + +_mock_pm = MagicMock() +_mock_pm.ProviderManager = MagicMock() +sys.modules["astrbot.core.provider.manager"] = _mock_pm + + +@pytest.mark.asyncio +async def test_delete_kb_removes_related_document_and_media_metadata(tmp_path): + from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase + from astrbot.core.knowledge_base.kb_helper import KBHelper + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + from astrbot.core.knowledge_base.models import KBDocument, KBMedia, KnowledgeBase + + kb_db = KBSQLiteDatabase(str(tmp_path / "kb.db")) + await kb_db.initialize() + await kb_db.migrate_to_v1() + + kb = KnowledgeBase( + kb_id="kb-delete", + kb_name="delete-me", + embedding_provider_id="emb-1", + ) + other_kb = KnowledgeBase( + kb_id="kb-keep", + kb_name="keep-me", + embedding_provider_id="emb-1", + ) + doc = KBDocument( + doc_id="doc-delete", + kb_id="kb-delete", + doc_name="delete.txt", + file_type="txt", + file_size=1, + file_path="", + ) + other_doc = KBDocument( + doc_id="doc-keep", + kb_id="kb-keep", + doc_name="keep.txt", + file_type="txt", + file_size=1, + file_path="", + ) + media = KBMedia( + media_id="media-delete", + doc_id="doc-delete", + kb_id="kb-delete", + media_type="image", + file_name="delete.png", + file_path="", + file_size=1, + mime_type="image/png", + created_at=datetime.now(timezone.utc), + ) + other_media = KBMedia( + media_id="media-keep", + doc_id="doc-keep", + kb_id="kb-keep", + media_type="image", + file_name="keep.png", + file_path="", + file_size=1, + mime_type="image/png", + created_at=datetime.now(timezone.utc), + ) + async with kb_db.get_db() as session: + session.add(kb) + session.add(other_kb) + session.add(doc) + session.add(other_doc) + session.add(media) + session.add(other_media) + await session.commit() + + helper = KBHelper.__new__(KBHelper) + helper.kb = kb + helper.delete_vec_db = AsyncMock() + + manager = KnowledgeBaseManager.__new__(KnowledgeBaseManager) + manager.kb_db = kb_db + manager.kb_insts = {"kb-delete": helper} + + deleted = await manager.delete_kb("kb-delete") + + assert deleted is True + helper.delete_vec_db.assert_awaited_once() + assert await kb_db.get_kb_by_id("kb-delete") is None + assert await kb_db.get_document_by_id("doc-delete") is None + assert await kb_db.get_media_by_id("media-delete") is None + assert await kb_db.get_kb_by_id("kb-keep") is not None + assert await kb_db.get_document_by_id("doc-keep") is not None + assert await kb_db.get_media_by_id("media-keep") is not None + assert await manager.get_kb_by_name("delete-me") is None + + await kb_db.close() + + +@pytest.mark.asyncio +async def test_create_kb_cleans_created_directory_when_initialize_fails( + tmp_path, + monkeypatch, +): + from astrbot.core.knowledge_base.kb_helper import KBHelper + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + + manager = KnowledgeBaseManager.__new__(KnowledgeBaseManager) + manager.provider_manager = MagicMock() + manager.kb_db = MagicMock() + manager.kb_insts = {} + + session = MagicMock() + session.add = MagicMock() + session.flush = AsyncMock() + session.commit = AsyncMock() + context = MagicMock() + context.__aenter__ = AsyncMock(return_value=session) + context.__aexit__ = AsyncMock(return_value=False) + manager.kb_db.get_db.return_value = context + + async def fail_initialize(self): + raise RuntimeError("provider unavailable") + + monkeypatch.setattr(KBHelper, "initialize", fail_initialize) + monkeypatch.setattr("astrbot.core.knowledge_base.kb_mgr.FILES_PATH", str(tmp_path)) + + with pytest.raises(RuntimeError, match="provider unavailable"): + await manager.create_kb( + kb_name="broken", + embedding_provider_id="emb-1", + ) + + assert list(tmp_path.iterdir()) == [] diff --git a/tests/unit/test_kb_manager_resilience.py b/tests/unit/test_kb_manager_resilience.py index ed43a338f8..c4eb30257f 100644 --- a/tests/unit/test_kb_manager_resilience.py +++ b/tests/unit/test_kb_manager_resilience.py @@ -12,6 +12,7 @@ import sys import types +import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch @@ -218,6 +219,99 @@ async def test_update_kb_switches_instance_only_after_new_reinit_success( old_helper.terminate.assert_called_once() +@pytest.mark.asyncio +async def test_get_kb_waits_for_update_instance_swap( + stub_provider_manager_module, + mock_provider_manager, + mock_kb_db, + mock_knowledge_base, +): + from astrbot.core.knowledge_base.kb_helper import KBHelper + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + + old_helper = KBHelper.__new__(KBHelper) + old_helper.kb = mock_knowledge_base + old_helper.init_error = None + old_helper.terminate = AsyncMock() + + kb_mgr = KnowledgeBaseManager.__new__(KnowledgeBaseManager) + kb_mgr.provider_manager = mock_provider_manager + kb_mgr.kb_db = mock_kb_db + kb_mgr.kb_insts = {mock_knowledge_base.kb_id: old_helper} + + commit_started = asyncio.Event() + release_commit = asyncio.Event() + + async def commit(): + commit_started.set() + await release_commit.wait() + + mock_session = MagicMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock(side_effect=commit) + mock_session.refresh = AsyncMock() + mock_db_context = MagicMock() + mock_db_context.__aenter__ = AsyncMock(return_value=mock_session) + mock_db_context.__aexit__ = AsyncMock(return_value=False) + mock_kb_db.get_db.return_value = mock_db_context + + with patch.object(KBHelper, "initialize", new_callable=AsyncMock): + update_task = asyncio.create_task( + kb_mgr.update_kb( + kb_id=mock_knowledge_base.kb_id, + kb_name="updated_kb", + ) + ) + await commit_started.wait() + + get_task = asyncio.create_task(kb_mgr.get_kb(mock_knowledge_base.kb_id)) + await asyncio.sleep(0) + assert not get_task.done() + + release_commit.set() + updated_helper = await update_task + observed_helper = await get_task + + assert updated_helper is observed_helper + assert observed_helper is kb_mgr.kb_insts[mock_knowledge_base.kb_id] + assert observed_helper is not old_helper + + +@pytest.mark.asyncio +async def test_get_kb_does_not_retry_failed_helper_during_cooldown( + stub_provider_manager_module, + mock_provider_manager, + mock_kb_db, + mock_knowledge_base, +): + from astrbot.core.knowledge_base.kb_helper import KBHelper + from astrbot.core.knowledge_base.kb_mgr import ( + INIT_RETRY_COOLDOWN_SECONDS, + KnowledgeBaseManager, + ) + + helper = KBHelper.__new__(KBHelper) + helper.kb = mock_knowledge_base + helper.init_error = "provider unavailable" + helper.init_retry_count = 0 + helper.last_init_retry_at = 100.0 + helper.initialize = AsyncMock() + + kb_mgr = KnowledgeBaseManager.__new__(KnowledgeBaseManager) + kb_mgr.provider_manager = mock_provider_manager + kb_mgr.kb_db = mock_kb_db + kb_mgr.kb_insts = {mock_knowledge_base.kb_id: helper} + + with patch( + "astrbot.core.knowledge_base.kb_mgr.time.monotonic", + return_value=100.0 + INIT_RETRY_COOLDOWN_SECONDS - 1, + ): + result = await kb_mgr.get_kb(mock_knowledge_base.kb_id) + + assert result is helper + helper.initialize.assert_not_awaited() + + @pytest.mark.asyncio async def test_ensure_vec_db_clears_stale_init_error( stub_provider_manager_module, @@ -264,6 +358,87 @@ async def test_ensure_vec_db_clears_stale_init_error( assert helper.vec_db is mock_vec_db +@pytest.mark.asyncio +async def test_update_kb_omitted_rerank_provider_preserves_existing_value( + stub_provider_manager_module, + mock_provider_manager, + mock_kb_db, + mock_knowledge_base, +): + from astrbot.core.knowledge_base.kb_helper import KBHelper + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + + mock_knowledge_base.rerank_provider_id = "rerank-1" + old_helper = KBHelper.__new__(KBHelper) + old_helper.kb = mock_knowledge_base + old_helper.init_error = None + old_helper.terminate = AsyncMock() + + kb_mgr = KnowledgeBaseManager.__new__(KnowledgeBaseManager) + kb_mgr.provider_manager = mock_provider_manager + kb_mgr.kb_db = mock_kb_db + kb_mgr.kb_insts = {mock_knowledge_base.kb_id: old_helper} + + mock_session = MagicMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + mock_session.refresh = AsyncMock() + mock_db_context = MagicMock() + mock_db_context.__aenter__ = AsyncMock(return_value=mock_session) + mock_db_context.__aexit__ = AsyncMock() + mock_kb_db.get_db.return_value = mock_db_context + + with patch.object(KBHelper, "initialize", new_callable=AsyncMock): + result = await kb_mgr.update_kb( + kb_id=mock_knowledge_base.kb_id, + kb_name="updated_kb", + ) + + assert result is not None + assert result.kb.rerank_provider_id == "rerank-1" + + +@pytest.mark.asyncio +async def test_update_kb_explicit_none_clears_rerank_provider( + stub_provider_manager_module, + mock_provider_manager, + mock_kb_db, + mock_knowledge_base, +): + from astrbot.core.knowledge_base.kb_helper import KBHelper + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + + mock_knowledge_base.rerank_provider_id = "rerank-1" + old_helper = KBHelper.__new__(KBHelper) + old_helper.kb = mock_knowledge_base + old_helper.init_error = None + old_helper.terminate = AsyncMock() + + kb_mgr = KnowledgeBaseManager.__new__(KnowledgeBaseManager) + kb_mgr.provider_manager = mock_provider_manager + kb_mgr.kb_db = mock_kb_db + kb_mgr.kb_insts = {mock_knowledge_base.kb_id: old_helper} + + mock_session = MagicMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + mock_session.refresh = AsyncMock() + mock_db_context = MagicMock() + mock_db_context.__aenter__ = AsyncMock(return_value=mock_session) + mock_db_context.__aexit__ = AsyncMock() + mock_kb_db.get_db.return_value = mock_db_context + + with patch.object(KBHelper, "initialize", new_callable=AsyncMock): + result = await kb_mgr.update_kb( + kb_id=mock_knowledge_base.kb_id, + kb_name="updated_kb", + rerank_provider_id=None, + ) + + assert result is not None + assert result.kb.rerank_provider_id is None + + @pytest.mark.asyncio async def test_ensure_vec_db_sets_init_error_on_failure( stub_provider_manager_module, From 7e3bfc81cf8ab3c26fba99b42e8d27ccca0c9911 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 22:36:06 +0800 Subject: [PATCH 19/48] fix(kb): improve dashboard error messages, rerank sentinel, and upload validation - Return a user-visible message when all session-configured KBs are unavailable instead of silently returning None. - Use _UNSET sentinel for rerank_provider_id in the update-kb route so omitting the field preserves the existing value while passing null explicitly clears it. - Add update_fields allowlist check to reject requests with no recognised update fields. - Fix _format_failed_doc_error startswith check to require 'filename:' prefix, preventing false positives when file name happens to be a prefix of the error message. - Add index_type to the recognised update fields. --- astrbot/core/tools/knowledge_base_tools.py | 2 +- astrbot/dashboard/routes/knowledge_base.py | 49 ++++++++------- tests/test_kb_update_route.py | 71 ++++++++++++++++++++++ tests/test_kb_upload_memory_leak.py | 19 ++++++ tests/unit/test_knowledge_base_tools.py | 26 ++++++++ 5 files changed, 145 insertions(+), 22 deletions(-) create mode 100644 tests/test_kb_update_route.py create mode 100644 tests/unit/test_knowledge_base_tools.py diff --git a/astrbot/core/tools/knowledge_base_tools.py b/astrbot/core/tools/knowledge_base_tools.py index e082fd4253..da00c18f47 100644 --- a/astrbot/core/tools/knowledge_base_tools.py +++ b/astrbot/core/tools/knowledge_base_tools.py @@ -53,7 +53,7 @@ async def retrieve_knowledge_base( f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}", ) if not kb_names: - return None + return "会话配置的知识库均不存在或未加载,请检查知识库设置。" logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}") else: kb_names = config.get("kb_names", []) diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 769e82caa8..d1625fb6c2 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -147,7 +147,7 @@ async def _callback(stage: str, current: int, total: int) -> None: @staticmethod def _format_failed_doc_error(file_name: str, error: Exception) -> str: message = str(error).strip() or "上传失败:发生未知错误。" - if message.startswith(file_name): + if message.startswith(f"{file_name}:"): return message return f"{file_name}: {message}" @@ -516,34 +516,36 @@ async def update_kb(self): if not kb_id: return Response().error("缺少参数 kb_id").__dict__ + update_fields = [ + "kb_name", + "description", + "emoji", + "embedding_provider_id", + "rerank_provider_id", + "chunk_size", + "chunk_overlap", + "top_k_dense", + "top_k_sparse", + "top_m_final", + "index_type", + ] + if not any(field in data for field in update_fields): + return Response().error("至少需要提供一个更新字段").__dict__ + kb_name = data.get("kb_name") description = data.get("description") emoji = data.get("emoji") embedding_provider_id = data.get("embedding_provider_id") - rerank_provider_id = data.get("rerank_provider_id") + rerank_provider_provided = "rerank_provider_id" in data + rerank_provider_id = ( + data.get("rerank_provider_id") if rerank_provider_provided else None + ) chunk_size = data.get("chunk_size") chunk_overlap = data.get("chunk_overlap") top_k_dense = data.get("top_k_dense") top_k_sparse = data.get("top_k_sparse") top_m_final = data.get("top_m_final") - - # 检查是否至少提供了一个更新字段 - if all( - v is None - for v in [ - kb_name, - description, - emoji, - embedding_provider_id, - rerank_provider_id, - chunk_size, - chunk_overlap, - top_k_dense, - top_k_sparse, - top_m_final, - ] - ): - return Response().error("至少需要提供一个更新字段").__dict__ + index_type = data.get("index_type") kb_helper = await kb_manager.update_kb( kb_id=kb_id, @@ -551,12 +553,17 @@ async def update_kb(self): description=description, emoji=emoji, embedding_provider_id=embedding_provider_id, - rerank_provider_id=rerank_provider_id, + **( + {"rerank_provider_id": rerank_provider_id} + if rerank_provider_provided + else {} + ), chunk_size=chunk_size, chunk_overlap=chunk_overlap, top_k_dense=top_k_dense, top_k_sparse=top_k_sparse, top_m_final=top_m_final, + index_type=index_type, ) if not kb_helper: diff --git a/tests/test_kb_update_route.py b/tests/test_kb_update_route.py new file mode 100644 index 0000000000..18a22194f5 --- /dev/null +++ b/tests/test_kb_update_route.py @@ -0,0 +1,71 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from quart import Quart + + +def _build_route_with_manager(kb_manager: MagicMock): + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) + route._get_kb_manager = MagicMock(return_value=kb_manager) + return route + + +def _build_kb_helper(rerank_provider_id: str | None = "rerank-1"): + from astrbot.core.knowledge_base.models import KnowledgeBase + + kb = KnowledgeBase( + kb_id="kb-1", + kb_name="kb", + embedding_provider_id="emb-1", + rerank_provider_id=rerank_provider_id, + ) + helper = MagicMock() + helper.kb = kb + return helper + + +@pytest.mark.asyncio +async def test_update_kb_omits_unprovided_rerank_provider_id(): + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + app = Quart(__name__) + kb_manager = MagicMock() + kb_manager.update_kb = AsyncMock(return_value=_build_kb_helper()) + route = _build_route_with_manager(kb_manager) + + async with app.test_request_context( + "/api/kb/update", + method="POST", + json={"kb_id": "kb-1", "chunk_size": 1024}, + ): + response = await KnowledgeBaseRoute.update_kb(route) + + assert response["status"] == "ok" + kwargs = kb_manager.update_kb.await_args.kwargs + assert kwargs["kb_id"] == "kb-1" + assert kwargs["chunk_size"] == 1024 + assert "rerank_provider_id" not in kwargs + + +@pytest.mark.asyncio +async def test_update_kb_explicit_null_forwards_rerank_provider_id(): + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + app = Quart(__name__) + kb_manager = MagicMock() + kb_manager.update_kb = AsyncMock(return_value=_build_kb_helper(None)) + route = _build_route_with_manager(kb_manager) + + async with app.test_request_context( + "/api/kb/update", + method="POST", + json={"kb_id": "kb-1", "rerank_provider_id": None}, + ): + response = await KnowledgeBaseRoute.update_kb(route) + + assert response["status"] == "ok" + kwargs = kb_manager.update_kb.await_args.kwargs + assert kwargs["kb_id"] == "kb-1" + assert kwargs["rerank_provider_id"] is None diff --git a/tests/test_kb_upload_memory_leak.py b/tests/test_kb_upload_memory_leak.py index a596350b68..1113dd4793 100644 --- a/tests/test_kb_upload_memory_leak.py +++ b/tests/test_kb_upload_memory_leak.py @@ -18,6 +18,25 @@ class TestUploadTaskCleanup: """Verify task cleanup in get_upload_progress.""" + def test_format_failed_doc_error_only_skips_exact_file_prefix(self): + """File names that are only a prefix of another word still get prepended.""" + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + assert ( + KnowledgeBaseRoute._format_failed_doc_error( + "doc", + ValueError("document parse error"), + ) + == "doc: document parse error" + ) + assert ( + KnowledgeBaseRoute._format_failed_doc_error( + "doc", + ValueError("doc: parse error"), + ) + == "doc: parse error" + ) + @pytest.mark.asyncio async def test_cleanup_on_completed_poll(self): """Completed task cleaned up when client polls for result.""" diff --git a/tests/unit/test_knowledge_base_tools.py b/tests/unit/test_knowledge_base_tools.py new file mode 100644 index 0000000000..abbda85991 --- /dev/null +++ b/tests/unit/test_knowledge_base_tools.py @@ -0,0 +1,26 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + + +@pytest.mark.asyncio +async def test_retrieve_knowledge_base_reports_all_invalid_session_kbs(monkeypatch): + from astrbot.core.tools import knowledge_base_tools + + context = MagicMock() + context.kb_manager.get_kb = AsyncMock(return_value=None) + + monkeypatch.setattr( + knowledge_base_tools.sp, + "session_get", + AsyncMock(return_value={"kb_ids": ["missing-kb"], "top_k": 5}), + ) + + result = await knowledge_base_tools.retrieve_knowledge_base( + query="hello", + umo="session-1", + context=context, + ) + + assert result == "会话配置的知识库均不存在或未加载,请检查知识库设置。" + context.kb_manager.retrieve.assert_not_called() From 7a642aa2c7f5752c3e03194ca32b5a0972193aa9 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 5 Jun 2026 22:36:26 +0800 Subject: [PATCH 20/48] test(kb): add coverage for kb_id-guarded deletion, vec-delete failure, and migration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Verify delete_document_by_id rejects documents from a different KB. - Verify metadata is preserved when vector deletion fails. - Verify delete_chunk raises ValueError for missing chunks. - Add integration test for index_type column migration on legacy DB. - Assert kb_id is forwarded through the helper→db call chain. --- tests/test_kb_batch_delete.py | 162 +++++++++++++++++++++++++++++++--- 1 file changed, 148 insertions(+), 14 deletions(-) diff --git a/tests/test_kb_batch_delete.py b/tests/test_kb_batch_delete.py index 5f3c47189f..e5ee7679aa 100644 --- a/tests/test_kb_batch_delete.py +++ b/tests/test_kb_batch_delete.py @@ -1,17 +1,6 @@ -"""Tests for #3: Batch delete documents API. - -Verifies: -- Batch delete from kb.db (single SQL IN clause) -- Parallel vec_db cleanup -- Single update_kb_stats call (not N calls) -- Best-effort semantics: one failure doesn't block others -- Empty list edge case - -NOTE: The knowledge_base package has a circular import chain: - kb_helper → provider.manager → persona_mgr → ... → kb_mgr → provider.manager -We break the chain by stubbing provider.manager in sys.modules before any import. -""" +"""Tests for batch knowledge-base document deletion.""" +import sqlite3 import sys from unittest.mock import AsyncMock, MagicMock, call @@ -58,7 +47,7 @@ async def test_delete_documents_by_ids_empty_list(self): @pytest.mark.asyncio async def test_delete_documents_by_ids_batch_kb_db(self): - """Documents deleted from kb.db via single IN-clause SQL.""" + """Vector cleanup succeeds before kb.db metadata is deleted.""" from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase kb_db = KBSQLiteDatabase.__new__(KBSQLiteDatabase) @@ -85,6 +74,7 @@ async def test_delete_documents_by_ids_batch_kb_db(self): ], any_order=True, ) + session.execute.assert_called() @pytest.mark.asyncio async def test_delete_documents_best_effort(self): @@ -114,6 +104,57 @@ async def _delete_side_effect(metadata_filters): assert results == {"doc-1": True, "doc-2": False, "doc-3": True} assert vec_db.delete_documents.await_count == 3 + @pytest.mark.asyncio + async def test_delete_document_keeps_metadata_when_vec_delete_fails(self): + """Metadata remains visible when vector deletion fails.""" + from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase + from astrbot.core.knowledge_base.models import KBDocument + + kb_db = KBSQLiteDatabase.__new__(KBSQLiteDatabase) + doc = KBDocument( + doc_id="doc-1", + kb_id="kb-a", + doc_name="a.txt", + file_type="txt", + file_size=1, + file_path="", + ) + kb_db.get_document_by_id = AsyncMock(return_value=doc) + session = AsyncMock() + session.__aenter__.return_value = session + session.begin = MagicMock(return_value=session) + kb_db.get_db = MagicMock(return_value=session) + vec_db = AsyncMock() + vec_db.delete_documents = AsyncMock(side_effect=RuntimeError("boom")) + + with pytest.raises(RuntimeError, match="boom"): + await kb_db.delete_document_by_id("doc-1", vec_db, kb_id="kb-a") + + session.execute.assert_not_called() + + @pytest.mark.asyncio + async def test_delete_document_rejects_wrong_kb_id(self): + """A document from another KB must not be deleted.""" + from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase + from astrbot.core.knowledge_base.models import KBDocument + + kb_db = KBSQLiteDatabase.__new__(KBSQLiteDatabase) + doc = KBDocument( + doc_id="doc-1", + kb_id="kb-other", + doc_name="a.txt", + file_type="txt", + file_size=1, + file_path="", + ) + kb_db.get_document_by_id = AsyncMock(return_value=doc) + vec_db = AsyncMock() + + deleted = await kb_db.delete_document_by_id("doc-1", vec_db, kb_id="kb-a") + + assert deleted is False + vec_db.delete_documents.assert_not_awaited() + class TestHelperBatchDelete: """Verify batch delete at the kb_helper layer.""" @@ -129,6 +170,11 @@ async def test_delete_documents_updates_stats_once(self): results = await helper.delete_documents(["doc-1", "doc-2"]) assert results == {"doc-1": True, "doc-2": True} + helper.kb_db.delete_documents_by_ids.assert_awaited_once_with( + doc_ids=["doc-1", "doc-2"], + vec_db=helper.vec_db, + kb_id="kb-test-1", + ) helper.kb_db.update_kb_stats.assert_awaited_once_with( kb_id="kb-test-1", vec_db=helper.vec_db, ) @@ -160,3 +206,91 @@ async def test_delete_documents_preserves_failures(self): # stats still updated once even with partial failures helper.kb_db.update_kb_stats.assert_awaited_once() helper.refresh_kb.assert_awaited_once() + + @pytest.mark.asyncio + async def test_delete_chunk_raises_when_chunk_is_missing(self): + helper = _build_helper() + helper.vec_db.delete = AsyncMock(return_value=False) + + with pytest.raises(ValueError, match="无法找到 ID 为 chunk-missing 的文本块"): + await helper.delete_chunk("chunk-missing", "doc-1") + + helper.vec_db.delete.assert_awaited_once_with("chunk-missing") + helper.kb_db.update_kb_stats.assert_not_awaited() + helper.refresh_kb.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_kb_sqlite_migration_adds_index_type_to_legacy_table(tmp_path): + from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase + + db_path = tmp_path / "kb.db" + conn = sqlite3.connect(db_path) + conn.execute( + """ + CREATE TABLE knowledge_bases ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + kb_id VARCHAR(36) NOT NULL UNIQUE, + kb_name VARCHAR(100) NOT NULL, + description TEXT, + emoji VARCHAR(10), + embedding_provider_id VARCHAR(100), + rerank_provider_id VARCHAR(100), + chunk_size INTEGER, + chunk_overlap INTEGER, + top_k_dense INTEGER, + top_k_sparse INTEGER, + top_m_final INTEGER, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL, + doc_count INTEGER NOT NULL, + chunk_count INTEGER NOT NULL + ) + """, + ) + conn.execute( + """ + CREATE TABLE kb_documents ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + doc_id VARCHAR(36) NOT NULL UNIQUE, + kb_id VARCHAR(36) NOT NULL, + doc_name VARCHAR(255) NOT NULL, + file_type VARCHAR(20) NOT NULL, + file_size INTEGER NOT NULL, + file_path VARCHAR(512) NOT NULL, + chunk_count INTEGER NOT NULL, + media_count INTEGER NOT NULL, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + ) + """, + ) + conn.execute( + """ + CREATE TABLE kb_media ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + media_id VARCHAR(36) NOT NULL UNIQUE, + doc_id VARCHAR(36) NOT NULL, + kb_id VARCHAR(36) NOT NULL, + media_type VARCHAR(20) NOT NULL, + file_name VARCHAR(255) NOT NULL, + file_path VARCHAR(512) NOT NULL, + file_size INTEGER NOT NULL, + mime_type VARCHAR(100) NOT NULL, + created_at DATETIME NOT NULL + ) + """, + ) + conn.commit() + conn.close() + + kb_db = KBSQLiteDatabase(str(db_path)) + await kb_db.initialize() + await kb_db.migrate_to_v1() + + conn = sqlite3.connect(db_path) + columns = {row[1] for row in conn.execute("PRAGMA table_info(knowledge_bases)")} + conn.close() + await kb_db.close() + + assert "index_type" in columns From 13a12432a3204e8c3709e9c50c8a067b3a3008ab Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Sat, 6 Jun 2026 00:11:55 +0800 Subject: [PATCH 21/48] feat(kb): add document search by name/type and remove default list_kbs limit - Add search parameter to list_documents_by_kb and count_documents_by_kb for SQL ILIKE filtering on doc_name and file_type. - Change list_kbs default limit from 100 to None so callers can request all KB records without an artificial ceiling. - Add KBHelper.count_documents to forward the search filter to kb_db. --- astrbot/core/knowledge_base/kb_db_sqlite.py | 47 +++++++++++++++++---- astrbot/core/knowledge_base/kb_helper.py | 12 +++++- 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 2813e291b6..6ccb91e84d 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -3,8 +3,9 @@ from pathlib import Path from typing import TYPE_CHECKING -from sqlalchemy import delete, func, select, text, update +from sqlalchemy import delete, event, func, or_, select, text, update from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import NullPool from sqlmodel import col, desc from astrbot.core import logger @@ -212,15 +213,22 @@ async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None: result = await session.execute(stmt) return result.scalar_one_or_none() - async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]: + async def list_kbs( + self, + offset: int = 0, + limit: int | None = None, + ) -> list[KnowledgeBase]: """列出所有知识库""" async with self.get_db() as session: stmt = ( select(KnowledgeBase) .offset(offset) - .limit(limit) - .order_by(desc(KnowledgeBase.created_at)) + .order_by( + desc(KnowledgeBase.created_at), + ) ) + if limit is not None: + stmt = stmt.limit(limit) result = await session.execute(stmt) return list(result.scalars().all()) @@ -245,12 +253,22 @@ async def list_documents_by_kb( kb_id: str, offset: int = 0, limit: int = 100, + search: str | None = None, ) -> list[KBDocument]: """列出知识库的所有文档""" async with self.get_db() as session: + conditions = [col(KBDocument.kb_id) == kb_id] + if search: + pattern = f"%{search}%" + conditions.append( + or_( + col(KBDocument.doc_name).ilike(pattern), + col(KBDocument.file_type).ilike(pattern), + ), + ) stmt = ( select(KBDocument) - .where(col(KBDocument.kb_id) == kb_id) + .where(*conditions) .offset(offset) .limit(limit) .order_by(desc(KBDocument.created_at)) @@ -258,12 +276,23 @@ async def list_documents_by_kb( result = await session.execute(stmt) return list(result.scalars().all()) - async def count_documents_by_kb(self, kb_id: str) -> int: + async def count_documents_by_kb( + self, + kb_id: str, + search: str | None = None, + ) -> int: """统计知识库的文档数量""" async with self.get_db() as session: - stmt = select(func.count(col(KBDocument.id))).where( - col(KBDocument.kb_id) == kb_id, - ) + conditions = [col(KBDocument.kb_id) == kb_id] + if search: + pattern = f"%{search}%" + conditions.append( + or_( + col(KBDocument.doc_name).ilike(pattern), + col(KBDocument.file_type).ilike(pattern), + ), + ) + stmt = select(func.count(col(KBDocument.id))).where(*conditions) result = await session.execute(stmt) return result.scalar() or 0 diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index a02c58c125..6935fb9ac3 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -490,11 +490,21 @@ async def list_documents( self, offset: int = 0, limit: int = 100, + search: str | None = None, ) -> list[KBDocument]: """列出知识库的所有文档""" - docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit) + docs = await self.kb_db.list_documents_by_kb( + self.kb.kb_id, + offset, + limit, + search, + ) return docs + async def count_documents(self, search: str | None = None) -> int: + """统计知识库的所有文档数量""" + return await self.kb_db.count_documents_by_kb(self.kb.kb_id, search) + async def get_document(self, doc_id: str) -> KBDocument | None: """获取单个文档""" doc = await self.kb_db.get_document_by_id(doc_id) From c40eb5a8138caa924b05d229894be0bd46fd9740 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Sat, 6 Jun 2026 00:14:19 +0800 Subject: [PATCH 22/48] feat(kb): support kb_ids in retrieve and validate KB options before mutation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Accept kb_ids as an alternative to kb_names in retrieve(), so callers that already hold KB IDs can skip the name→id lookup. - Add _validate_kb_options helper to reject invalid chunk/retrieval parameters (negative sizes, overlap >= chunk_size, unknown index_type) at the manager layer before persisting any state. - In update_kb, build a candidate_state dict first, validate it, and only mutate the KB model after validation passes — this prevents partially-updated in-memory state when invalid options are submitted. --- astrbot/core/knowledge_base/kb_mgr.py | 141 ++++++++++++++++++----- tests/unit/test_kb_manager_resilience.py | 53 +++++++++ 2 files changed, 164 insertions(+), 30 deletions(-) diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index c1680cb1e3..d24b452e27 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -25,6 +25,38 @@ _UNSET = object() INIT_RETRY_COOLDOWN_SECONDS = 60.0 INIT_RETRY_MAX_ATTEMPTS = 3 +VALID_INDEX_TYPES = {"flat", "hnsw"} + + +def _validate_kb_options( + *, + chunk_size: int | None, + chunk_overlap: int | None, + top_k_dense: int | None, + top_k_sparse: int | None, + top_m_final: int | None, + index_type: str | None, +) -> None: + if chunk_size is not None and chunk_size <= 0: + raise ValueError("chunk_size 必须大于 0") + if chunk_overlap is not None and chunk_overlap < 0: + raise ValueError("chunk_overlap 不能为负数") + if ( + chunk_size is not None + and chunk_overlap is not None + and chunk_overlap >= chunk_size + ): + raise ValueError("chunk_overlap 必须小于 chunk_size") + if top_k_dense is not None and top_k_dense <= 0: + raise ValueError("top_k_dense 必须大于 0") + if top_k_sparse is not None and top_k_sparse <= 0: + raise ValueError("top_k_sparse 必须大于 0") + if top_m_final is not None and top_m_final <= 0: + raise ValueError("top_m_final 必须大于 0") + if index_type is not None and index_type not in VALID_INDEX_TYPES: + raise ValueError( + f"index_type 必须是 {', '.join(sorted(VALID_INDEX_TYPES))} 之一" + ) class KnowledgeBaseManager: @@ -182,18 +214,32 @@ async def create_kb( """创建新的知识库实例""" if embedding_provider_id is None: raise ValueError("创建知识库时必须提供embedding_provider_id") + effective_chunk_size = chunk_size if chunk_size is not None else 512 + effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else 50 + effective_top_k_dense = top_k_dense if top_k_dense is not None else 50 + effective_top_k_sparse = top_k_sparse if top_k_sparse is not None else 50 + effective_top_m_final = top_m_final if top_m_final is not None else 5 + effective_index_type = index_type if index_type is not None else "flat" + _validate_kb_options( + chunk_size=effective_chunk_size, + chunk_overlap=effective_chunk_overlap, + top_k_dense=effective_top_k_dense, + top_k_sparse=effective_top_k_sparse, + top_m_final=effective_top_m_final, + index_type=effective_index_type, + ) kb = KnowledgeBase( kb_name=kb_name, description=description, emoji=emoji or "📚", embedding_provider_id=embedding_provider_id, rerank_provider_id=rerank_provider_id, - chunk_size=chunk_size if chunk_size is not None else 512, - chunk_overlap=chunk_overlap if chunk_overlap is not None else 50, - top_k_dense=top_k_dense if top_k_dense is not None else 50, - top_k_sparse=top_k_sparse if top_k_sparse is not None else 50, - top_m_final=top_m_final if top_m_final is not None else 5, - index_type=index_type if index_type is not None else "flat", + chunk_size=effective_chunk_size, + chunk_overlap=effective_chunk_overlap, + top_k_dense=effective_top_k_dense, + top_k_sparse=effective_top_k_sparse, + top_m_final=effective_top_m_final, + index_type=effective_index_type, ) kb_helper: KBHelper | None = None try: @@ -314,28 +360,48 @@ async def update_kb( } previous_init_error = kb_helper.init_error + candidate_state = previous_state.copy() if kb_name is not None: - kb.kb_name = kb_name + candidate_state["kb_name"] = kb_name if description is not None: - kb.description = description + candidate_state["description"] = description if emoji is not None: - kb.emoji = emoji + candidate_state["emoji"] = emoji if embedding_provider_id is not None: - kb.embedding_provider_id = embedding_provider_id + candidate_state["embedding_provider_id"] = embedding_provider_id if rerank_provider_id is not _UNSET: - kb.rerank_provider_id = rerank_provider_id # type: ignore[assignment] + candidate_state["rerank_provider_id"] = rerank_provider_id if chunk_size is not None: - kb.chunk_size = chunk_size + candidate_state["chunk_size"] = chunk_size if chunk_overlap is not None: - kb.chunk_overlap = chunk_overlap + candidate_state["chunk_overlap"] = chunk_overlap if top_k_dense is not None: - kb.top_k_dense = top_k_dense + candidate_state["top_k_dense"] = top_k_dense if top_k_sparse is not None: - kb.top_k_sparse = top_k_sparse + candidate_state["top_k_sparse"] = top_k_sparse if top_m_final is not None: - kb.top_m_final = top_m_final + candidate_state["top_m_final"] = top_m_final if index_type is not None: - kb.index_type = index_type + candidate_state["index_type"] = index_type + _validate_kb_options( + chunk_size=candidate_state["chunk_size"], + chunk_overlap=candidate_state["chunk_overlap"], + top_k_dense=candidate_state["top_k_dense"], + top_k_sparse=candidate_state["top_k_sparse"], + top_m_final=candidate_state["top_m_final"], + index_type=candidate_state["index_type"], + ) + kb.kb_name = candidate_state["kb_name"] + kb.description = candidate_state["description"] + kb.emoji = candidate_state["emoji"] + kb.embedding_provider_id = candidate_state["embedding_provider_id"] + kb.rerank_provider_id = candidate_state["rerank_provider_id"] # type: ignore[assignment] + kb.chunk_size = candidate_state["chunk_size"] + kb.chunk_overlap = candidate_state["chunk_overlap"] + kb.top_k_dense = candidate_state["top_k_dense"] + kb.top_k_sparse = candidate_state["top_k_sparse"] + kb.top_m_final = candidate_state["top_m_final"] + kb.index_type = candidate_state["index_type"] # Build a new helper first. Keep current vec_db alive until new init succeeds. new_helper = KBHelper( @@ -382,34 +448,49 @@ async def update_kb( async def retrieve( self, query: str, - kb_names: list[str], + kb_names: list[str] | None = None, + kb_ids: list[str] | None = None, top_k_fusion: int = 20, top_m_final: int = 5, ) -> dict | None: """从指定知识库中检索相关内容""" - kb_ids = [] + resolved_kb_ids = [] kb_id_helper_map = {} unavailable_kbs = [] - for kb_name in kb_names: - if kb_helper := await self.get_kb_by_name(kb_name): - if kb_helper.init_error: - unavailable_kbs.append((kb_name, kb_helper.init_error)) - logger.warning(f"知识库 {kb_name} 不可用: {kb_helper.init_error}") - continue - kb_ids.append(kb_helper.kb.kb_id) - kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper + if kb_ids: + for kb_id in kb_ids: + if kb_helper := await self.get_kb(kb_id): + if kb_helper.init_error: + unavailable_kbs.append((kb_id, kb_helper.init_error)) + logger.warning(f"知识库 {kb_id} 不可用: {kb_helper.init_error}") + continue + resolved_kb_ids.append(kb_helper.kb.kb_id) + kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper + elif kb_names: + for kb_name in kb_names: + if kb_helper := await self.get_kb_by_name(kb_name): + if kb_helper.init_error: + unavailable_kbs.append((kb_name, kb_helper.init_error)) + logger.warning( + f"知识库 {kb_name} 不可用: {kb_helper.init_error}", + ) + continue + resolved_kb_ids.append(kb_helper.kb.kb_id) + kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper + else: + return {} # all requested KBs are unavailable - if not kb_ids and unavailable_kbs: + if not resolved_kb_ids and unavailable_kbs: errors = "; ".join(f"{n}: {e}" for n, e in unavailable_kbs) raise ValueError(f"所有请求的知识库均不可用: {errors}") - if not kb_ids: + if not resolved_kb_ids: return {} results = await self.retrieval_manager.retrieve( query=query, - kb_ids=kb_ids, + kb_ids=resolved_kb_ids, kb_id_helper_map=kb_id_helper_map, top_k_fusion=top_k_fusion, top_m_final=top_m_final, diff --git a/tests/unit/test_kb_manager_resilience.py b/tests/unit/test_kb_manager_resilience.py index c4eb30257f..c84973f131 100644 --- a/tests/unit/test_kb_manager_resilience.py +++ b/tests/unit/test_kb_manager_resilience.py @@ -88,6 +88,59 @@ def mock_embedding_provider(): return provider +@pytest.mark.asyncio +async def test_load_kbs_does_not_limit_database_records( + stub_provider_manager_module, + mock_provider_manager, + mock_kb_db, +): + from astrbot.core.knowledge_base.kb_helper import KBHelper + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + + kb_mgr = KnowledgeBaseManager.__new__(KnowledgeBaseManager) + kb_mgr.provider_manager = mock_provider_manager + kb_mgr.kb_db = mock_kb_db + kb_mgr.kb_insts = {} + kb_mgr._kb_name_index = {} + + with patch.object(KBHelper, "initialize", new_callable=AsyncMock): + await kb_mgr.load_kbs() + + mock_kb_db.list_kbs.assert_awaited_once_with() + + +@pytest.mark.asyncio +async def test_update_kb_invalid_options_do_not_mutate_existing_kb( + stub_provider_manager_module, + mock_provider_manager, + mock_kb_db, + mock_knowledge_base, +): + from astrbot.core.knowledge_base.kb_helper import KBHelper + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + + old_helper = KBHelper.__new__(KBHelper) + old_helper.kb = mock_knowledge_base + old_helper.init_error = None + + kb_mgr = KnowledgeBaseManager.__new__(KnowledgeBaseManager) + kb_mgr.provider_manager = mock_provider_manager + kb_mgr.kb_db = mock_kb_db + kb_mgr.kb_insts = {mock_knowledge_base.kb_id: old_helper} + + with patch.object(KBHelper, "initialize", new_callable=AsyncMock) as mock_init: + with pytest.raises(ValueError, match="chunk_overlap"): + await kb_mgr.update_kb( + kb_id=mock_knowledge_base.kb_id, + chunk_size=100, + chunk_overlap=100, + ) + + mock_init.assert_not_awaited() + assert mock_knowledge_base.chunk_size == 512 + assert mock_knowledge_base.chunk_overlap == 50 + + @pytest.mark.asyncio async def test_update_kb_preserves_old_instance_when_reinit_fails( stub_provider_manager_module, From 8d869f049c78f933f826c74e9f8ac44aecadb54a Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Sat, 6 Jun 2026 00:14:49 +0800 Subject: [PATCH 23/48] fix(kb): validate upload parameters, add file type/size checks, support search/kb_ids in API - Add _coerce_optional_int, _validate_chunk_options, _validate_kb_options, _validate_upload_options, and _validate_upload_file helpers so every numeric parameter is validated before use. - Enforce ALLOWED_UPLOAD_EXTENSIONS whitelist and 128 MB file size cap. - Expose kb_ids alongside kb_names in /kb/retrieve so the dashboard can query by stable ID. - Add search and total count to /kb/document/list for paginated search in the UI. - Pre-validate update_kb options against the existing KB state before forwarding to the manager layer. - Resolve t-SNE visualization kb_names from kb_ids when only IDs are provided. --- astrbot/dashboard/routes/knowledge_base.py | 296 ++++++++++++++++++--- tests/test_kb_update_route.py | 123 +++++++++ 2 files changed, 388 insertions(+), 31 deletions(-) diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index d1625fb6c2..c17663de1d 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -17,6 +17,20 @@ from ..utils import generate_tsne_visualization from .route import Response, Route, RouteContext +ALLOWED_UPLOAD_EXTENSIONS = { + "adoc", + "docx", + "epub", + "md", + "markdown", + "pdf", + "rst", + "txt", + "xls", + "xlsx", +} +MAX_UPLOAD_FILE_SIZE = 128 * 1024 * 1024 + class KnowledgeBaseRoute(Route): """知识库管理路由 @@ -151,6 +165,85 @@ def _format_failed_doc_error(file_name: str, error: Exception) -> str: return message return f"{file_name}: {message}" + @staticmethod + def _coerce_optional_int(value: Any, field_name: str) -> int | None: + if value in (None, ""): + return None + try: + return int(value) + except (TypeError, ValueError) as e: + raise ValueError(f"{field_name} 必须是整数") from e + + @staticmethod + def _validate_chunk_options( + *, + chunk_size: int | None, + chunk_overlap: int | None, + ) -> None: + if chunk_size is not None and chunk_size <= 0: + raise ValueError("chunk_size 必须大于 0") + if chunk_overlap is not None and chunk_overlap < 0: + raise ValueError("chunk_overlap 不能为负数") + if ( + chunk_size is not None + and chunk_overlap is not None + and chunk_overlap >= chunk_size + ): + raise ValueError("chunk_overlap 必须小于 chunk_size") + + @staticmethod + def _validate_positive_int(value: int | None, field_name: str) -> None: + if value is not None and value <= 0: + raise ValueError(f"{field_name} 必须大于 0") + + @classmethod + def _validate_kb_options( + cls, + *, + chunk_size: int | None, + chunk_overlap: int | None, + top_k_dense: int | None, + top_k_sparse: int | None, + top_m_final: int | None, + index_type: str | None, + ) -> None: + cls._validate_chunk_options( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + cls._validate_positive_int(top_k_dense, "top_k_dense") + cls._validate_positive_int(top_k_sparse, "top_k_sparse") + cls._validate_positive_int(top_m_final, "top_m_final") + if index_type is not None and index_type not in {"flat", "hnsw"}: + raise ValueError("index_type 必须是 flat 或 hnsw") + + @classmethod + def _validate_upload_options( + cls, + *, + chunk_size: int, + chunk_overlap: int, + batch_size: int, + tasks_limit: int, + max_retries: int, + ) -> None: + cls._validate_chunk_options( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + cls._validate_positive_int(batch_size, "batch_size") + cls._validate_positive_int(tasks_limit, "tasks_limit") + if max_retries < 0: + raise ValueError("max_retries 不能为负数") + + @staticmethod + def _validate_upload_file(file_name: str, file_size: int) -> None: + file_type = file_name.rsplit(".", 1)[-1].lower() if "." in file_name else "" + if file_type not in ALLOWED_UPLOAD_EXTENSIONS: + raise ValueError(f"不支持的文件类型: {file_name}") + if file_size > MAX_UPLOAD_FILE_SIZE: + raise ValueError(f"文件超过 128MB 限制: {file_name}") + async def _background_upload_task( self, task_id: str, @@ -395,11 +488,32 @@ async def create_kb(self): emoji = data.get("emoji") embedding_provider_id = data.get("embedding_provider_id") rerank_provider_id = data.get("rerank_provider_id") - chunk_size = data.get("chunk_size") - chunk_overlap = data.get("chunk_overlap") - top_k_dense = data.get("top_k_dense") - top_k_sparse = data.get("top_k_sparse") - top_m_final = data.get("top_m_final") + chunk_size = self._coerce_optional_int(data.get("chunk_size"), "chunk_size") + chunk_overlap = self._coerce_optional_int( + data.get("chunk_overlap"), + "chunk_overlap", + ) + top_k_dense = self._coerce_optional_int( + data.get("top_k_dense"), + "top_k_dense", + ) + top_k_sparse = self._coerce_optional_int( + data.get("top_k_sparse"), + "top_k_sparse", + ) + top_m_final = self._coerce_optional_int( + data.get("top_m_final"), + "top_m_final", + ) + index_type = data.get("index_type") + self._validate_kb_options( + chunk_size=chunk_size if chunk_size is not None else 512, + chunk_overlap=chunk_overlap if chunk_overlap is not None else 50, + top_k_dense=top_k_dense if top_k_dense is not None else 50, + top_k_sparse=top_k_sparse if top_k_sparse is not None else 50, + top_m_final=top_m_final if top_m_final is not None else 5, + index_type=index_type if index_type is not None else "flat", + ) # pre-check embedding dim if not embedding_provider_id: @@ -454,6 +568,7 @@ async def create_kb(self): top_k_dense=top_k_dense, top_k_sparse=top_k_sparse, top_m_final=top_m_final, + index_type=index_type, ) kb = kb_helper.kb @@ -540,12 +655,48 @@ async def update_kb(self): rerank_provider_id = ( data.get("rerank_provider_id") if rerank_provider_provided else None ) - chunk_size = data.get("chunk_size") - chunk_overlap = data.get("chunk_overlap") - top_k_dense = data.get("top_k_dense") - top_k_sparse = data.get("top_k_sparse") - top_m_final = data.get("top_m_final") + chunk_size = self._coerce_optional_int(data.get("chunk_size"), "chunk_size") + chunk_overlap = self._coerce_optional_int( + data.get("chunk_overlap"), + "chunk_overlap", + ) + top_k_dense = self._coerce_optional_int( + data.get("top_k_dense"), + "top_k_dense", + ) + top_k_sparse = self._coerce_optional_int( + data.get("top_k_sparse"), + "top_k_sparse", + ) + top_m_final = self._coerce_optional_int( + data.get("top_m_final"), + "top_m_final", + ) index_type = data.get("index_type") + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + current_kb = kb_helper.kb + self._validate_kb_options( + chunk_size=chunk_size + if chunk_size is not None + else current_kb.chunk_size, + chunk_overlap=chunk_overlap + if chunk_overlap is not None + else current_kb.chunk_overlap, + top_k_dense=top_k_dense + if top_k_dense is not None + else current_kb.top_k_dense, + top_k_sparse=top_k_sparse + if top_k_sparse is not None + else current_kb.top_k_sparse, + top_m_final=top_m_final + if top_m_final is not None + else current_kb.top_m_final, + index_type=index_type + if index_type is not None + else current_kb.index_type, + ) kb_helper = await kb_manager.update_kb( kb_id=kb_id, @@ -660,19 +811,32 @@ async def list_documents(self): if not kb_helper: return Response().error("知识库不存在").__dict__ - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 100, type=int) + page = max(request.args.get("page", 1, type=int), 1) + page_size = max(request.args.get("page_size", 100, type=int), 1) + search = (request.args.get("search") or "").strip() or None offset = (page - 1) * page_size limit = page_size - doc_list = await kb_helper.list_documents(offset=offset, limit=limit) + doc_list = await kb_helper.list_documents( + offset=offset, + limit=limit, + search=search, + ) + total = await kb_helper.count_documents(search=search) doc_list = [doc.model_dump() for doc in doc_list] return ( Response() - .ok({"items": doc_list, "page": page, "page_size": page_size}) + .ok( + { + "items": doc_list, + "page": page, + "page_size": page_size, + "total": total, + }, + ) .__dict__ ) @@ -724,11 +888,38 @@ async def upload_document(self): files = await request.files kb_id = form_data.get("kb_id") - chunk_size = int(form_data.get("chunk_size", 512)) - chunk_overlap = int(form_data.get("chunk_overlap", 50)) - batch_size = int(form_data.get("batch_size", 32)) - tasks_limit = int(form_data.get("tasks_limit", 3)) - max_retries = int(form_data.get("max_retries", 3)) + chunk_size = self._coerce_optional_int( + form_data.get("chunk_size"), + "chunk_size", + ) + chunk_overlap = self._coerce_optional_int( + form_data.get("chunk_overlap"), + "chunk_overlap", + ) + batch_size = self._coerce_optional_int( + form_data.get("batch_size"), + "batch_size", + ) + tasks_limit = self._coerce_optional_int( + form_data.get("tasks_limit"), + "tasks_limit", + ) + max_retries = self._coerce_optional_int( + form_data.get("max_retries"), + "max_retries", + ) + chunk_size = chunk_size if chunk_size is not None else 512 + chunk_overlap = chunk_overlap if chunk_overlap is not None else 50 + batch_size = batch_size if batch_size is not None else 32 + tasks_limit = tasks_limit if tasks_limit is not None else 3 + max_retries = max_retries if max_retries is not None else 3 + self._validate_upload_options( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + ) if not kb_id: return Response().error("缺少参数 kb_id").__dict__ @@ -767,6 +958,7 @@ async def upload_document(self): file_type = ( file_name.rsplit(".", 1)[-1].lower() if "." in file_name else "" ) + self._validate_upload_file(file_name, len(file_content)) files_to_upload.append( { @@ -843,9 +1035,16 @@ def _validate_import_request(self, data: dict): ): raise ValueError("chunks 必须是非空字符串列表") - batch_size = data.get("batch_size", 32) - tasks_limit = data.get("tasks_limit", 3) - max_retries = data.get("max_retries", 3) + batch_size = self._coerce_optional_int(data.get("batch_size"), "batch_size") + tasks_limit = self._coerce_optional_int(data.get("tasks_limit"), "tasks_limit") + max_retries = self._coerce_optional_int(data.get("max_retries"), "max_retries") + batch_size = batch_size if batch_size is not None else 32 + tasks_limit = tasks_limit if tasks_limit is not None else 3 + max_retries = max_retries if max_retries is not None else 3 + self._validate_positive_int(batch_size, "batch_size") + self._validate_positive_int(tasks_limit, "tasks_limit") + if max_retries < 0: + raise ValueError("max_retries 不能为负数") return kb_id, documents, batch_size, tasks_limit, max_retries async def import_documents(self): @@ -1175,19 +1374,27 @@ async def retrieve(self): data = await request.json query = data.get("query") + kb_ids = data.get("kb_ids") kb_names = data.get("kb_names") debug = data.get("debug", False) if not query: return Response().error("缺少参数 query").__dict__ - if not kb_names or not isinstance(kb_names, list): - return Response().error("缺少参数 kb_names 或格式错误").__dict__ + if kb_ids is not None and not isinstance(kb_ids, list): + return Response().error("参数 kb_ids 格式错误").__dict__ + if kb_names is not None and not isinstance(kb_names, list): + return Response().error("参数 kb_names 格式错误").__dict__ + if not kb_ids and not kb_names: + return Response().error("缺少参数 kb_ids 或 kb_names").__dict__ - top_k = data.get("top_k", 5) + top_k = self._coerce_optional_int(data.get("top_k", 5), "top_k") + top_k = top_k if top_k is not None else 5 + self._validate_positive_int(top_k, "top_k") results = await kb_manager.retrieve( query=query, kb_names=kb_names, + kb_ids=kb_ids, top_m_final=top_k, ) result_list = [] @@ -1203,9 +1410,15 @@ async def retrieve(self): # Debug 模式:生成 t-SNE 可视化 if debug: try: + visualization_kb_names = kb_names + if not visualization_kb_names and kb_ids: + visualization_kb_names = [] + for kb_id in kb_ids: + if kb_helper := await kb_manager.get_kb(kb_id): + visualization_kb_names.append(kb_helper.kb.kb_name) img_base64 = await generate_tsne_visualization( query, - kb_names, + visualization_kb_names or [], kb_manager, ) if img_base64: @@ -1251,11 +1464,32 @@ async def upload_document_from_url(self): if not url: return Response().error("缺少参数 url").__dict__ - chunk_size = data.get("chunk_size", 512) - chunk_overlap = data.get("chunk_overlap", 50) - batch_size = data.get("batch_size", 32) - tasks_limit = data.get("tasks_limit", 3) - max_retries = data.get("max_retries", 3) + chunk_size = self._coerce_optional_int(data.get("chunk_size"), "chunk_size") + chunk_overlap = self._coerce_optional_int( + data.get("chunk_overlap"), + "chunk_overlap", + ) + batch_size = self._coerce_optional_int(data.get("batch_size"), "batch_size") + tasks_limit = self._coerce_optional_int( + data.get("tasks_limit"), + "tasks_limit", + ) + max_retries = self._coerce_optional_int( + data.get("max_retries"), + "max_retries", + ) + chunk_size = chunk_size if chunk_size is not None else 512 + chunk_overlap = chunk_overlap if chunk_overlap is not None else 50 + batch_size = batch_size if batch_size is not None else 32 + tasks_limit = tasks_limit if tasks_limit is not None else 3 + max_retries = max_retries if max_retries is not None else 3 + self._validate_upload_options( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + ) enable_cleaning = data.get("enable_cleaning", False) cleaning_provider_id = data.get("cleaning_provider_id") diff --git a/tests/test_kb_update_route.py b/tests/test_kb_update_route.py index 18a22194f5..e2e7401339 100644 --- a/tests/test_kb_update_route.py +++ b/tests/test_kb_update_route.py @@ -26,12 +26,33 @@ def _build_kb_helper(rerank_provider_id: str | None = "rerank-1"): return helper +def _build_kb_helper_with_options(**kwargs): + from astrbot.core.knowledge_base.models import KnowledgeBase + + kb = KnowledgeBase( + kb_id=kwargs.get("kb_id", "kb-1"), + kb_name=kwargs.get("kb_name", "kb"), + embedding_provider_id="emb-1", + rerank_provider_id=kwargs.get("rerank_provider_id", "rerank-1"), + chunk_size=kwargs.get("chunk_size", 512), + chunk_overlap=kwargs.get("chunk_overlap", 50), + top_k_dense=kwargs.get("top_k_dense", 50), + top_k_sparse=kwargs.get("top_k_sparse", 50), + top_m_final=kwargs.get("top_m_final", 5), + index_type=kwargs.get("index_type", "flat"), + ) + helper = MagicMock() + helper.kb = kb + return helper + + @pytest.mark.asyncio async def test_update_kb_omits_unprovided_rerank_provider_id(): from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute app = Quart(__name__) kb_manager = MagicMock() + kb_manager.get_kb = AsyncMock(return_value=_build_kb_helper_with_options()) kb_manager.update_kb = AsyncMock(return_value=_build_kb_helper()) route = _build_route_with_manager(kb_manager) @@ -55,6 +76,7 @@ async def test_update_kb_explicit_null_forwards_rerank_provider_id(): app = Quart(__name__) kb_manager = MagicMock() + kb_manager.get_kb = AsyncMock(return_value=_build_kb_helper_with_options()) kb_manager.update_kb = AsyncMock(return_value=_build_kb_helper(None)) route = _build_route_with_manager(kb_manager) @@ -69,3 +91,104 @@ async def test_update_kb_explicit_null_forwards_rerank_provider_id(): kwargs = kb_manager.update_kb.await_args.kwargs assert kwargs["kb_id"] == "kb-1" assert kwargs["rerank_provider_id"] is None + + +@pytest.mark.asyncio +async def test_update_kb_rejects_overlap_not_less_than_chunk_size(): + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + app = Quart(__name__) + kb_manager = MagicMock() + kb_manager.get_kb = AsyncMock(return_value=_build_kb_helper_with_options()) + kb_manager.update_kb = AsyncMock() + route = _build_route_with_manager(kb_manager) + + async with app.test_request_context( + "/api/kb/update", + method="POST", + json={"kb_id": "kb-1", "chunk_size": 100, "chunk_overlap": 100}, + ): + response = await KnowledgeBaseRoute.update_kb(route) + + assert response["status"] == "error" + assert "chunk_overlap" in response["message"] + kb_manager.update_kb.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_retrieve_accepts_kb_ids(): + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + app = Quart(__name__) + kb_manager = MagicMock() + kb_manager.retrieve = AsyncMock(return_value={"results": []}) + route = _build_route_with_manager(kb_manager) + + async with app.test_request_context( + "/api/kb/retrieve", + method="POST", + json={"query": "hello", "kb_ids": ["kb-1"], "top_k": 3}, + ): + response = await KnowledgeBaseRoute.retrieve(route) + + assert response["status"] == "ok" + kb_manager.retrieve.assert_awaited_once_with( + query="hello", + kb_names=None, + kb_ids=["kb-1"], + top_m_final=3, + ) + + +@pytest.mark.asyncio +async def test_retrieve_rejects_invalid_top_k(): + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + app = Quart(__name__) + kb_manager = MagicMock() + kb_manager.retrieve = AsyncMock() + route = _build_route_with_manager(kb_manager) + + async with app.test_request_context( + "/api/kb/retrieve", + method="POST", + json={"query": "hello", "kb_ids": ["kb-1"], "top_k": 0}, + ): + response = await KnowledgeBaseRoute.retrieve(route) + + assert response["status"] == "error" + assert "top_k" in response["message"] + kb_manager.retrieve.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_list_documents_returns_total_and_uses_requested_pagination(): + from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute + + app = Quart(__name__) + kb_helper = MagicMock() + doc = MagicMock() + doc.model_dump.return_value = {"doc_id": "doc-1", "doc_name": "alpha.md"} + kb_helper.list_documents = AsyncMock(return_value=[doc]) + kb_helper.count_documents = AsyncMock(return_value=123) + kb_manager = MagicMock() + kb_manager.get_kb = AsyncMock(return_value=kb_helper) + route = _build_route_with_manager(kb_manager) + + async with app.test_request_context( + "/api/kb/document/list?kb_id=kb-1&page=3&page_size=25&search=alpha", + method="GET", + ): + response = await KnowledgeBaseRoute.list_documents(route) + + assert response["status"] == "ok" + assert response["data"]["items"] == [{"doc_id": "doc-1", "doc_name": "alpha.md"}] + assert response["data"]["page"] == 3 + assert response["data"]["page_size"] == 25 + assert response["data"]["total"] == 123 + kb_helper.list_documents.assert_awaited_once_with( + offset=50, + limit=25, + search="alpha", + ) + kb_helper.count_documents.assert_awaited_once_with(search="alpha") From f88a33900bb10bdfea1f04c0bcfcab320de41898 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Sat, 6 Jun 2026 00:14:59 +0800 Subject: [PATCH 24/48] feat(dashboard): add document search, improved upload UX, and validation feedback in KB UI - Add search input to DocumentsTab with debounced server-side filtering. - Show total document count alongside paginated results. - Improve file upload with better validation feedback, file count display, and clear-button for the selection. - Add upload parameter validation (chunk_size, chunk_overlap, etc.) with inline error messages before submitting. - Surface init_error state on KB cards in the list view. - Add i18n keys for new UI strings across en-US, zh-CN, and ru-RU. --- .../en-US/features/knowledge-base/detail.json | 64 + .../features/knowledge-base/document.json | 8 + .../en-US/features/knowledge-base/index.json | 12 +- .../ru-RU/features/knowledge-base/detail.json | 298 ++-- .../features/knowledge-base/document.json | 116 +- .../ru-RU/features/knowledge-base/index.json | 136 +- .../zh-CN/features/knowledge-base/detail.json | 64 + .../features/knowledge-base/document.json | 8 + .../zh-CN/features/knowledge-base/index.json | 12 +- dashboard/src/main.ts | 11 +- dashboard/src/plugins/vuetify.ts | 75 +- .../views/knowledge-base/DocumentDetail.vue | 370 +++-- .../src/views/knowledge-base/KBDetail.vue | 233 ++-- dashboard/src/views/knowledge-base/KBList.vue | 555 +++++--- .../components/DocumentsTab.vue | 1213 +++++++++++------ .../components/RetrievalTab.vue | 229 +++- .../knowledge-base/components/SettingsTab.vue | 279 ++-- .../components/TavilyKeyDialog.vue | 119 +- 18 files changed, 2465 insertions(+), 1337 deletions(-) diff --git a/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json b/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json index 78a00669e3..642eb7fdbe 100644 --- a/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json +++ b/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json @@ -29,6 +29,7 @@ "title": "Documents", "upload": "Upload Document", "empty": "No documents", + "searchPlaceholder": "Search documents...", "name": "Name", "type": "Type", "size": "Size", @@ -37,11 +38,13 @@ "actions": "Actions", "view": "View", "delete": "Delete", + "cancel": "Cancel", "deleteConfirm": "Are you sure you want to delete document '{name}'?", "deleteWarning": "This will delete the document and all its chunks. This action cannot be undone.", "uploading": "Uploading...", "uploadSuccess": "Document uploaded successfully", "uploadFailed": "Failed to upload document", + "loadFailed": "Failed to load documents", "deleteSuccess": "Document deleted successfully", "deleteFailed": "Failed to delete document" }, @@ -51,6 +54,14 @@ "dropzone": "Drop files here or click to select", "supportedFormats": "Supported formats: .txt, .md, .markdown, .rst, .adoc, .pdf, .docx, .epub, .xls, .xlsx", "maxSize": "Max file size: 128MB", + "maxFiles": "Upload up to 10 files", + "maxFilesWarning": "You can select up to {count} files", + "selectedFiles": "{count} files selected", + "clear": "Clear", + "someFilesRejected": "Some files were not added", + "unsupportedFile": "{name}: unsupported file type", + "fileTooLarge": "{name}: file exceeds 128MB", + "invalidSettings": "Please check the upload settings", "chunkSettings": "Chunk Settings", "batchSettings": "Batch Settings", "cleaningSettings": "Cleaning Settings", @@ -75,6 +86,24 @@ "urlPlaceholder": "Enter the URL of the web page to extract content from", "urlRequired": "Please enter a URL", "urlHint": "The main content will be automatically extracted from the target URL as a document. Currently supports {supported} pages. Before use, please ensure that the target web page allows crawler access.", + "tavilyCheckFailed": "Failed to check web search configuration", + "tavilyRequired": "Tavily Key is required for this feature", + "configure": "Configure", + "tavilyConfigured": "Tavily API Key configured", + "backgroundUploading": "Uploading {count} files in the background...", + "backgroundUrlUploading": "Extracting URL content in the background...", + "successCount": "Successfully uploaded {count} documents", + "partialSuccess": "Upload finished: {success} succeeded, {failed} failed", + "failedWithReason": "Upload failed: {reason}", + "unknownError": "Unknown error", + "stages": { + "waiting": "Waiting...", + "extracting": "Extracting content...", + "cleaning": "Cleaning content...", + "parsing": "Parsing document...", + "chunking": "Chunking text...", + "embedding": "Generating embeddings..." + }, "beta": "Beta" }, "retrieval": { @@ -88,6 +117,8 @@ "noResults": "No results found", "tryDifferentQuery": "Try a different query", "settings": "Retrieval Settings", + "debugMode": "Debug Mode", + "tsneVisualization": "t-SNE Visualization", "topK": "Number of Results", "topKHint": "Maximum number of results to return", "enableRerank": "Enable Rerank", @@ -113,9 +144,42 @@ "enableRerank": "Enable Rerank", "embeddingProvider": "Embedding Provider", "rerankProvider": "Rerank Provider", + "embeddingProviderHint": "The embedding model is bound to the current vector index. Create a new knowledge base to change it.", + "indexType": "Index Type", + "indexTypeHint": "Flat is exact; HNSW is better for larger knowledge bases.", + "indexTypes": { + "flat": "Flat exact index", + "hnsw": "HNSW approximate index" + }, "save": "Save Settings", "saveSuccess": "Settings saved successfully", "saveFailed": "Failed to save settings", + "providersLoadFailed": "Failed to load model providers", "tips": "Tip: Modifying retrieval settings will affect subsequent knowledge base queries." + }, + "validation": { + "integer": "Enter an integer", + "positiveInteger": "Enter an integer greater than 0", + "nonNegativeInteger": "Enter an integer no less than 0", + "overlapLessThanSize": "Chunk overlap must be less than chunk size", + "topKRange": "Number of results must be an integer from 1 to 100" + }, + "actions": { + "retry": "Retry" + }, + "messages": { + "loadFailed": "Failed to load knowledge base details" + }, + "tavily": { + "title": "Configure Tavily API Key", + "description": "A Tavily API Key is required to use web-based knowledge base features. You can get one from", + "officialSite": "Tavily", + "apiKeyLabel": "Tavily API Key", + "cancel": "Cancel", + "save": "Save", + "keyRequired": "API Key is required", + "loadConfigFailed": "Failed to load current configuration", + "saveFailed": "Failed to save. Please check the key.", + "unknownSaveFailed": "Failed to save due to an unknown error" } } diff --git a/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json b/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json index d3a3b65c9a..8cf45bd51f 100644 --- a/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json +++ b/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json @@ -15,6 +15,7 @@ "index": "Index", "content": "Content", "charCount": "Characters", + "charCountValue": "{count} characters", "actions": "Actions", "view": "View", "edit": "Edit", @@ -51,5 +52,12 @@ "charCount": "Characters", "vecDocId": "Vector ID", "close": "Close" + }, + "actions": { + "retry": "Retry" + }, + "messages": { + "loadDocumentFailed": "Failed to load document details", + "loadChunksFailed": "Failed to load chunks" } } diff --git a/dashboard/src/i18n/locales/en-US/features/knowledge-base/index.json b/dashboard/src/i18n/locales/en-US/features/knowledge-base/index.json index 67bb4d5717..960edf067c 100644 --- a/dashboard/src/i18n/locales/en-US/features/knowledge-base/index.json +++ b/dashboard/src/i18n/locales/en-US/features/knowledge-base/index.json @@ -11,7 +11,9 @@ "documents": "Documents", "chunks": "Chunks", "sessionConfig": "Session Config", - "initError": "Initialization Failed" + "initError": "Initialization Failed", + "noDescription": "No description", + "switchToLegacy": "Switch to legacy knowledge base" }, "card": { "edit": "Edit", @@ -31,9 +33,12 @@ "rerankModelLabel": "Rerank Model (Optional)", "providerInfo": "Provider: {id} | Dimensions: {dimensions}", "rerankProviderInfo": "Provider: {id}", + "nameHint": "If you rename this knowledge base later, update any configuration that still references names.", + "embeddingModelHint": "The embedding model cannot be changed after creation. Create a new knowledge base to use another model.", "cancel": "Cancel", "submit": "Create", - "nameRequired": "Please enter knowledge base name" + "nameRequired": "Please enter knowledge base name", + "embeddingRequired": "Please select an embedding model" }, "edit": { "title": "Edit Knowledge Base", @@ -63,6 +68,7 @@ "updateFailed": "Failed to update", "deleteSuccess": "Knowledge base deleted successfully", "deleteFailed": "Failed to delete", - "loadError": "Failed to load knowledge base list" + "loadError": "Failed to load knowledge base list", + "providersLoadError": "Failed to load model providers" } } diff --git a/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/detail.json b/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/detail.json index 5145d5c285..0e3eab1cfe 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/detail.json +++ b/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/detail.json @@ -1,121 +1,185 @@ { - "title": "Детали базы знаний", - "backToList": "К списку", - "breadcrumb": { - "list": "Базы знаний" + "title": "Детали базы знаний", + "backToList": "К списку", + "breadcrumb": { + "list": "Базы знаний" + }, + "tabs": { + "overview": "Обзор", + "documents": "Документы", + "retrieval": "Поиск", + "sessions": "Сессии", + "settings": "Настройки" + }, + "overview": { + "title": "Информация", + "name": "Название", + "description": "Описание", + "emoji": "Иконка", + "createdAt": "Создана", + "updatedAt": "Обновлена", + "stats": "Статистика", + "docCount": "Количество документов", + "chunkCount": "Количество фрагментов", + "embeddingModel": "Embedding модель", + "rerankModel": "Rerank модель", + "notSet": "не выбрано" + }, + "documents": { + "title": "Список документов", + "upload": "Загрузить", + "empty": "Документов нет", + "searchPlaceholder": "Поиск документов...", + "name": "Имя файла", + "type": "Тип", + "size": "Размер", + "chunks": "Фрагменты", + "createdAt": "Дата загрузки", + "actions": "Действия", + "view": "Смотреть", + "delete": "Удалить", + "cancel": "Отмена", + "deleteConfirm": "Вы уверены, что хотите удалить «{name}»?", + "deleteWarning": "Это удалит файл и все его фрагменты из индекса.", + "uploading": "Загрузка...", + "uploadSuccess": "Файл успешно загружен", + "uploadFailed": "Ошибка загрузки", + "loadFailed": "Не удалось загрузить документы", + "deleteSuccess": "Файл удален", + "deleteFailed": "Ошибка удаления" + }, + "upload": { + "title": "Добавление контента", + "selectFile": "Файл", + "dropzone": "Нажмите или перетащите файл сюда", + "supportedFormats": "Форматы: .txt, .md, .markdown, .rst, .adoc, .pdf, .docx, .epub, .xls, .xlsx", + "maxSize": "Максимум: 128MB", + "maxFiles": "Можно загрузить до 10 файлов", + "maxFilesWarning": "Можно выбрать не более {count} файлов", + "selectedFiles": "Выбрано файлов: {count}", + "clear": "Очистить", + "someFilesRejected": "Некоторые файлы не добавлены", + "unsupportedFile": "{name}: неподдерживаемый тип файла", + "fileTooLarge": "{name}: файл больше 128MB", + "invalidSettings": "Проверьте параметры загрузки", + "chunkSettings": "Фрагментация", + "batchSettings": "Пакетная обработка", + "cleaningSettings": "Очистка данных", + "enableCleaning": "Включить очистку контента", + "cleaningProvider": "Сервис для очистки", + "cleaningProviderHint": "LLM провайдер для суммаризации и извлечения смыслов из веб-страниц", + "chunkSize": "Размер чанка", + "chunkSizeHint": "Символов в блоке (по умолчанию: 512)", + "chunkOverlap": "Перекрытие", + "chunkOverlapHint": "Перекрытие между блоками (по умолчанию: 50)", + "batchSize": "Размер пакета", + "batchSizeHint": "Блоков за один запрос (по умолчанию: 32)", + "tasksLimit": "Лимит задач", + "tasksLimitHint": "Макс. параллельных потоков (по умолчанию: 3)", + "maxRetries": "Попытки", + "maxRetriesHint": "Повторов при сбое (по умолчанию: 3)", + "cancel": "Отмена", + "submit": "Загрузить", + "fileRequired": "Пожалуйста, выберите файл", + "fileUpload": "Загрузка файла", + "fromUrl": "Из URL", + "urlPlaceholder": "Ссылка на веб-страницу", + "urlRequired": "Введите URL", + "urlHint": "Контент будет автоматически извлечен со страницы. Убедитесь, что сайт разрешает доступ роботам.", + "tavilyCheckFailed": "Не удалось проверить настройки веб-поиска", + "tavilyRequired": "Для этой функции нужен Tavily Key", + "configure": "Настроить", + "tavilyConfigured": "Tavily API Key сохранен", + "backgroundUploading": "Фоновая загрузка файлов: {count}...", + "backgroundUrlUploading": "Фоновое извлечение контента из URL...", + "successCount": "Успешно загружено документов: {count}", + "partialSuccess": "Загрузка завершена: успешно {success}, ошибок {failed}", + "failedWithReason": "Ошибка загрузки: {reason}", + "unknownError": "Неизвестная ошибка", + "stages": { + "waiting": "Ожидание...", + "extracting": "Извлечение контента...", + "cleaning": "Очистка контента...", + "parsing": "Разбор документа...", + "chunking": "Разбиение текста...", + "embedding": "Генерация векторов..." }, - "tabs": { - "overview": "Обзор", - "documents": "Документы", - "retrieval": "Поиск", - "sessions": "Сессии", - "settings": "Настройки" + "beta": "Бета-версия" + }, + "retrieval": { + "title": "Поиск и проверка", + "subtitle": "Проверьте качество поиска (Dense & Sparse) по вашей базе знаний", + "query": "Тестовый запрос", + "queryPlaceholder": "Что вы хотите найти?", + "search": "Найти", + "searching": "Ищем...", + "results": "Результаты поиска", + "noResults": "Релевантный контент не найден", + "tryDifferentQuery": "Попробуйте изменить формулировку запроса", + "settings": "Параметры поиска", + "debugMode": "Режим отладки", + "tsneVisualization": "t-SNE визуализация", + "topK": "Количество результатов", + "topKHint": "Сколько фрагментов возвращать", + "enableRerank": "Включить Rerank", + "enableRerankHint": "Применить переранжирование для повышения точности", + "score": "Вес (Score)", + "document": "Документ", + "chunk": "Фрагмент #{index}", + "content": "Текст", + "charCount": "{count} симв.", + "searchSuccess": "Поиск завершен, найдено: {count}", + "searchFailed": "Ошибка выполнения поиска", + "queryRequired": "Введите поисковый запрос" + }, + "settings": { + "title": "Общие настройки базы", + "basic": "Основные", + "retrieval": "Поиск", + "chunkSize": "Размер чанка", + "chunkOverlap": "Перекрытие", + "topKDense": "Вернуть (Dense)", + "topKSparse": "Вернуть (Sparse)", + "topMFinal": "Итоговый результат", + "enableRerank": "Включить Rerank", + "embeddingProvider": "Провайдер Embedding", + "rerankProvider": "Провайдер Rerank", + "embeddingProviderHint": "Embedding модель связана с текущим векторным индексом. Для смены создайте новую базу знаний.", + "indexType": "Тип индекса", + "indexTypeHint": "Flat точнее, HNSW лучше для больших баз знаний.", + "indexTypes": { + "flat": "Flat точный индекс", + "hnsw": "HNSW приближенный индекс" }, - "overview": { - "title": "Информация", - "name": "Название", - "description": "Описание", - "emoji": "Иконка", - "createdAt": "Создана", - "updatedAt": "Обновлена", - "stats": "Статистика", - "docCount": "Количество документов", - "chunkCount": "Количество фрагментов", - "embeddingModel": "Embedding модель", - "rerankModel": "Rerank модель", - "notSet": "не выбрано" - }, - "documents": { - "title": "Список документов", - "upload": "Загрузить", - "empty": "Документов нет", - "name": "Имя файла", - "type": "Тип", - "size": "Размер", - "chunks": "Фрагменты", - "createdAt": "Дата загрузки", - "actions": "Действия", - "view": "Смотреть", - "delete": "Удалить", - "deleteConfirm": "Вы уверены, что хотите удалить «{name}»?", - "deleteWarning": "Это удалит файл и все его фрагменты из индекса.", - "uploading": "Загрузка...", - "uploadSuccess": "Файл успешно загружен", - "uploadFailed": "Ошибка загрузки", - "deleteSuccess": "Файл удален", - "deleteFailed": "Ошибка удаления" - }, - "upload": { - "title": "Добавление контента", - "selectFile": "Файл", - "dropzone": "Нажмите или перетащите файл сюда", - "supportedFormats": "Форматы: .txt, .md, .markdown, .rst, .adoc, .pdf, .docx, .epub, .xls, .xlsx", - "maxSize": "Максимум: 128MB", - "chunkSettings": "Фрагментация", - "batchSettings": "Пакетная обработка", - "cleaningSettings": "Очистка данных", - "enableCleaning": "Включить очистку контента", - "cleaningProvider": "Сервис для очистки", - "cleaningProviderHint": "LLM провайдер для суммаризации и извлечения смыслов из веб-страниц", - "chunkSize": "Размер чанка", - "chunkSizeHint": "Символов в блоке (по умолчанию: 512)", - "chunkOverlap": "Перекрытие", - "chunkOverlapHint": "Перекрытие между блоками (по умолчанию: 50)", - "batchSize": "Размер пакета", - "batchSizeHint": "Блоков за один запрос (по умолчанию: 32)", - "tasksLimit": "Лимит задач", - "tasksLimitHint": "Макс. параллельных потоков (по умолчанию: 3)", - "maxRetries": "Попытки", - "maxRetriesHint": "Повторов при сбое (по умолчанию: 3)", - "cancel": "Отмена", - "submit": "Загрузить", - "fileRequired": "Пожалуйста, выберите файл", - "fileUpload": "Загрузка файла", - "fromUrl": "Из URL", - "urlPlaceholder": "Ссылка на веб-страницу", - "urlRequired": "Введите URL", - "urlHint": "Контент будет автоматически извлечен со страницы. Убедитесь, что сайт разрешает доступ роботам.", - "beta": "Бета-версия" - }, - "retrieval": { - "title": "Поиск и проверка", - "subtitle": "Проверьте качество поиска (Dense & Sparse) по вашей базе знаний", - "query": "Тестовый запрос", - "queryPlaceholder": "Что вы хотите найти?", - "search": "Найти", - "searching": "Ищем...", - "results": "Результаты поиска", - "noResults": "Релевантный контент не найден", - "tryDifferentQuery": "Попробуйте изменить формулировку запроса", - "settings": "Параметры поиска", - "topK": "Количество результатов", - "topKHint": "Сколько фрагментов возвращать", - "enableRerank": "Включить Rerank", - "enableRerankHint": "Применить переранжирование для повышения точности", - "score": "Вес (Score)", - "document": "Документ", - "chunk": "Фрагмент #{index}", - "content": "Текст", - "charCount": "{count} симв.", - "searchSuccess": "Поиск завершен, найдено: {count}", - "searchFailed": "Ошибка выполнения поиска", - "queryRequired": "Введите поисковый запрос" - }, - "settings": { - "title": "Общие настройки базы", - "basic": "Основные", - "retrieval": "Поиск", - "chunkSize": "Размер чанка", - "chunkOverlap": "Перекрытие", - "topKDense": "Вернуть (Dense)", - "topKSparse": "Вернуть (Sparse)", - "topMFinal": "Итоговый результат", - "enableRerank": "Включить Rerank", - "embeddingProvider": "Провайдер Embedding", - "rerankProvider": "Провайдер Rerank", - "save": "Сохранить", - "saveSuccess": "Настройки сохранены", - "saveFailed": "Ошибка сохранения", - "tips": "Внимание! Изменение этих параметров повлияет на будущую выдачу базы знаний." - } + "save": "Сохранить", + "saveSuccess": "Настройки сохранены", + "saveFailed": "Ошибка сохранения", + "providersLoadFailed": "Не удалось загрузить провайдеры моделей", + "tips": "Внимание! Изменение этих параметров повлияет на будущую выдачу базы знаний." + }, + "validation": { + "integer": "Введите целое число", + "positiveInteger": "Введите целое число больше 0", + "nonNegativeInteger": "Введите целое число не меньше 0", + "overlapLessThanSize": "Перекрытие должно быть меньше размера чанка", + "topKRange": "Количество результатов должно быть целым числом от 1 до 100" + }, + "actions": { + "retry": "Повторить" + }, + "messages": { + "loadFailed": "Не удалось загрузить детали базы знаний" + }, + "tavily": { + "title": "Настройка Tavily API Key", + "description": "Для веб-функций базы знаний нужен Tavily API Key. Получить его можно на", + "officialSite": "сайте Tavily", + "apiKeyLabel": "Tavily API Key", + "cancel": "Отмена", + "save": "Сохранить", + "keyRequired": "API Key обязателен", + "loadConfigFailed": "Не удалось загрузить текущую конфигурацию", + "saveFailed": "Не удалось сохранить. Проверьте ключ.", + "unknownSaveFailed": "Не удалось сохранить из-за неизвестной ошибки" + } } diff --git a/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/document.json b/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/document.json index 7fcb30ee9f..2de459be24 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/document.json +++ b/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/document.json @@ -1,55 +1,63 @@ { - "title": "Просмотр документа", - "backToKB": "К базе знаний", - "info": { - "title": "Информация о документе", - "name": "Имя файла", - "type": "Формат", - "size": "Размер", - "chunkCount": "Количество фрагментов", - "createdAt": "Загружен" - }, - "chunks": { - "title": "Фрагменты текста", - "empty": "Фрагменты не найдены", - "index": "Индекс", - "content": "Текст", - "charCount": "Символов", - "actions": "Действия", - "view": "Детали", - "edit": "Изменить", - "delete": "Удалить", - "preview": "Обзор", - "search": "Поиск по документу", - "searchPlaceholder": "Найти во фрагментах...", - "showing": "Показано", - "deleteConfirm": "Удалить этот фрагмент?", - "deleteSuccess": "Фрагмент удален", - "deleteFailed": "Ошибка удаления" - }, - "edit": { - "title": "Редактирование фрагмента", - "content": "Текст", - "cancel": "Отмена", - "save": "Сохранить", - "saveSuccess": "Фрагмент обновлен", - "saveFailed": "Ошибка сохранения" - }, - "delete": { - "title": "Удаление", - "confirmText": "Вы уверены?", - "warning": "Удаление фрагмента может ухудшить качество ответов AI по этой теме.", - "cancel": "Отмена", - "confirm": "Удалить", - "deleteSuccess": "Удаление выполнено", - "deleteFailed": "Ошибка удаления" - }, - "view": { - "title": "Детальный просмотр", - "index": "Индекс", - "content": "Текст", - "charCount": "Символов", - "vecDocId": "ID вектора", - "close": "Закрыть" - } -} \ No newline at end of file + "title": "Просмотр документа", + "backToKB": "К базе знаний", + "info": { + "title": "Информация о документе", + "name": "Имя файла", + "type": "Формат", + "size": "Размер", + "chunkCount": "Количество фрагментов", + "createdAt": "Загружен" + }, + "chunks": { + "title": "Фрагменты текста", + "empty": "Фрагменты не найдены", + "index": "Индекс", + "content": "Текст", + "charCount": "Символов", + "charCountValue": "{count} симв.", + "actions": "Действия", + "view": "Детали", + "edit": "Изменить", + "delete": "Удалить", + "preview": "Обзор", + "search": "Поиск по документу", + "searchPlaceholder": "Найти во фрагментах...", + "showing": "Показано", + "deleteConfirm": "Удалить этот фрагмент?", + "deleteSuccess": "Фрагмент удален", + "deleteFailed": "Ошибка удаления" + }, + "edit": { + "title": "Редактирование фрагмента", + "content": "Текст", + "cancel": "Отмена", + "save": "Сохранить", + "saveSuccess": "Фрагмент обновлен", + "saveFailed": "Ошибка сохранения" + }, + "delete": { + "title": "Удаление", + "confirmText": "Вы уверены?", + "warning": "Удаление фрагмента может ухудшить качество ответов AI по этой теме.", + "cancel": "Отмена", + "confirm": "Удалить", + "deleteSuccess": "Удаление выполнено", + "deleteFailed": "Ошибка удаления" + }, + "view": { + "title": "Детальный просмотр", + "index": "Индекс", + "content": "Текст", + "charCount": "Символов", + "vecDocId": "ID вектора", + "close": "Закрыть" + }, + "actions": { + "retry": "Повторить" + }, + "messages": { + "loadDocumentFailed": "Не удалось загрузить документ", + "loadChunksFailed": "Не удалось загрузить фрагменты" + } +} diff --git a/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/index.json b/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/index.json index 4eb99d5f06..ca7f5e26ed 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/index.json +++ b/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/index.json @@ -1,68 +1,74 @@ { - "title": "Управление базами знаний", - "subtitle": "Централизованное управление всеми знаниями AstrBot", - "list": { - "title": "Базы знаний", - "subtitle": "Все доступные коллекции знаний", - "create": "Создать базу", - "refresh": "Обновить", - "empty": "Баз знаний пока нет", - "loading": "Загрузка...", - "documents": "док.", - "chunks": "фрагм.", - "sessionConfig": "Профиль", - "initError": "Ошибка инициализации" - }, - "card": { - "edit": "Изменить", - "delete": "Удалить", - "open": "Открыть", - "docCount": "Документов: {count}", - "chunkCount": "Фрагментов: {count}" - }, - "create": { - "title": "Создание базы знаний", - "nameLabel": "Название", - "namePlaceholder": "Придумайте имя для базы", - "descriptionLabel": "Описание", - "descriptionPlaceholder": "Для чего нужна эта база?", - "emojiLabel": "Иконка", - "embeddingModelLabel": "Embedding модель", - "rerankModelLabel": "Rerank модель (опционально)", - "providerInfo": "Провайдер: {id} | Размерность: {dimensions}", - "rerankProviderInfo": "Провайдер: {id}", - "cancel": "Отмена", - "submit": "Создать", - "nameRequired": "Введите название базы знаний" - }, - "edit": { - "title": "Редактирование", - "submit": "Сохранить" - }, - "delete": { - "title": "Удаление", - "confirmText": "Вы уверены, что хотите удалить базу знаний «{name}»?", - "warning": "Это действие необратимо. Все документы, фрагменты и настройки будут навсегда удалены.", - "cancel": "Отмена", - "confirm": "Удалить" - }, - "emoji": { - "title": "Выберите иконку", - "close": "Закрыть", - "categories": { - "books": "Книги и документы", - "emotions": "Эмоции", - "objects": "Вещи", - "symbols": "Символы" - } - }, - "messages": { - "createSuccess": "База знаний создана", - "createFailed": "Ошибка создания", - "updateSuccess": "Обновлено успешно", - "updateFailed": "Ошибка обновления", - "deleteSuccess": "Удалено успешно", - "deleteFailed": "Ошибка удаления", - "loadError": "Не удалось загрузить список" + "title": "Управление базами знаний", + "subtitle": "Централизованное управление всеми знаниями AstrBot", + "list": { + "title": "Базы знаний", + "subtitle": "Все доступные коллекции знаний", + "create": "Создать базу", + "refresh": "Обновить", + "empty": "Баз знаний пока нет", + "loading": "Загрузка...", + "documents": "док.", + "chunks": "фрагм.", + "sessionConfig": "Профиль", + "initError": "Ошибка инициализации", + "noDescription": "Нет описания", + "switchToLegacy": "Перейти к старой базе знаний" + }, + "card": { + "edit": "Изменить", + "delete": "Удалить", + "open": "Открыть", + "docCount": "Документов: {count}", + "chunkCount": "Фрагментов: {count}" + }, + "create": { + "title": "Создание базы знаний", + "nameLabel": "Название", + "namePlaceholder": "Придумайте имя для базы", + "descriptionLabel": "Описание", + "descriptionPlaceholder": "Для чего нужна эта база?", + "emojiLabel": "Иконка", + "embeddingModelLabel": "Embedding модель", + "rerankModelLabel": "Rerank модель (опционально)", + "providerInfo": "Провайдер: {id} | Размерность: {dimensions}", + "rerankProviderInfo": "Провайдер: {id}", + "nameHint": "Если позже переименуете базу, обновите конфигурации, где она указана по имени.", + "embeddingModelHint": "Embedding модель нельзя изменить после создания. Для другой модели создайте новую базу.", + "cancel": "Отмена", + "submit": "Создать", + "nameRequired": "Введите название базы знаний", + "embeddingRequired": "Выберите embedding модель" + }, + "edit": { + "title": "Редактирование", + "submit": "Сохранить" + }, + "delete": { + "title": "Удаление", + "confirmText": "Вы уверены, что хотите удалить базу знаний «{name}»?", + "warning": "Это действие необратимо. Все документы, фрагменты и настройки будут навсегда удалены.", + "cancel": "Отмена", + "confirm": "Удалить" + }, + "emoji": { + "title": "Выберите иконку", + "close": "Закрыть", + "categories": { + "books": "Книги и документы", + "emotions": "Эмоции", + "objects": "Вещи", + "symbols": "Символы" } + }, + "messages": { + "createSuccess": "База знаний создана", + "createFailed": "Ошибка создания", + "updateSuccess": "Обновлено успешно", + "updateFailed": "Ошибка обновления", + "deleteSuccess": "Удалено успешно", + "deleteFailed": "Ошибка удаления", + "loadError": "Не удалось загрузить список", + "providersLoadError": "Не удалось загрузить провайдеры моделей" + } } diff --git a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json index 54bc60b7a7..4d294f12e4 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json +++ b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json @@ -29,6 +29,7 @@ "title": "文档列表", "upload": "上传文档", "empty": "暂无文档", + "searchPlaceholder": "搜索文档...", "name": "文档名称", "type": "类型", "size": "大小", @@ -37,11 +38,13 @@ "actions": "操作", "view": "查看", "delete": "删除", + "cancel": "取消", "deleteConfirm": "确定要删除文档「{name}」吗?", "deleteWarning": "此操作将删除文档及其所有分块,不可恢复。", "uploading": "正在上传...", "uploadSuccess": "文档上传成功", "uploadFailed": "文档上传失败", + "loadFailed": "加载文档列表失败", "deleteSuccess": "文档删除成功", "deleteFailed": "文档删除失败" }, @@ -51,6 +54,14 @@ "dropzone": "拖放文件到这里或点击选择", "supportedFormats": "支持的格式: .txt, .md, .markdown, .rst, .adoc, .pdf, .docx, .epub, .xls, .xlsx", "maxSize": "最大文件大小: 128MB", + "maxFiles": "最多可上传 10 个文件", + "maxFilesWarning": "最多只能选择 {count} 个文件", + "selectedFiles": "已选择 {count} 个文件", + "clear": "清空", + "someFilesRejected": "部分文件未加入上传队列", + "unsupportedFile": "{name}: 不支持的文件类型", + "fileTooLarge": "{name}: 文件超过 128MB", + "invalidSettings": "请检查上传参数", "chunkSettings": "分块设置", "batchSettings": "批处理设置", "cleaningSettings": "清洗设置", @@ -75,6 +86,24 @@ "urlPlaceholder": "请输入要提取内容的网页 URL", "urlRequired": "请输入 URL", "urlHint": "将自动从目标 URL 提取主要内容作为文档。目前支持 {supported} 页面,请确保目标网页允许爬虫访问。", + "tavilyCheckFailed": "检查网页搜索配置失败", + "tavilyRequired": "使用此功能需要配置 Tavily Key", + "configure": "配置", + "tavilyConfigured": "Tavily API Key 配置成功", + "backgroundUploading": "正在后台上传 {count} 个文件...", + "backgroundUrlUploading": "正在从 URL 后台提取内容...", + "successCount": "成功上传 {count} 个文档", + "partialSuccess": "上传完成: {success} 个成功, {failed} 个失败", + "failedWithReason": "上传失败: {reason}", + "unknownError": "未知错误", + "stages": { + "waiting": "等待中...", + "extracting": "提取内容...", + "cleaning": "清洗内容...", + "parsing": "解析文档...", + "chunking": "文本分块...", + "embedding": "生成向量..." + }, "beta": "测试版" }, "retrieval": { @@ -88,6 +117,8 @@ "noResults": "没有找到相关内容", "tryDifferentQuery": "尝试使用不同的查询词", "settings": "检索设置", + "debugMode": "调试模式", + "tsneVisualization": "t-SNE 可视化", "topK": "返回结果数量", "topKHint": "最多返回多少条检索结果", "enableRerank": "启用重排序", @@ -113,9 +144,42 @@ "enableRerank": "启用重排序", "embeddingProvider": "嵌入模型提供商", "rerankProvider": "重排序模型提供商", + "embeddingProviderHint": "嵌入模型与现有向量索引绑定,如需更换请创建新的知识库。", + "indexType": "索引类型", + "indexTypeHint": "Flat 更精确,HNSW 更适合大规模知识库。", + "indexTypes": { + "flat": "Flat 精确索引", + "hnsw": "HNSW 近似索引" + }, "save": "保存设置", "saveSuccess": "设置保存成功", "saveFailed": "设置保存失败", + "providersLoadFailed": "加载模型提供商失败", "tips": "提示: 修改检索设置后,将影响后续的知识库查询效果。" + }, + "validation": { + "integer": "请输入整数", + "positiveInteger": "请输入大于 0 的整数", + "nonNegativeInteger": "请输入不小于 0 的整数", + "overlapLessThanSize": "分块重叠必须小于分块大小", + "topKRange": "返回结果数量必须是 1 到 100 的整数" + }, + "actions": { + "retry": "重试" + }, + "messages": { + "loadFailed": "加载知识库详情失败" + }, + "tavily": { + "title": "配置 Tavily API Key", + "description": "为了使用基于网页的知识库功能,需要提供 Tavily API Key。您可以从", + "officialSite": "Tavily 官网", + "apiKeyLabel": "Tavily API Key", + "cancel": "取消", + "save": "保存", + "keyRequired": "API Key 不能为空", + "loadConfigFailed": "获取当前配置失败", + "saveFailed": "保存失败,请检查 Key 是否正确", + "unknownSaveFailed": "保存失败,发生未知错误" } } diff --git a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json index c90c29cc29..ffa01d074a 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json +++ b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json @@ -15,6 +15,7 @@ "index": "序号", "content": "内容", "charCount": "字符数", + "charCountValue": "{count} 字符", "actions": "操作", "view": "查看", "edit": "编辑", @@ -51,5 +52,12 @@ "charCount": "字符数", "vecDocId": "向量ID", "close": "关闭" + }, + "actions": { + "retry": "重试" + }, + "messages": { + "loadDocumentFailed": "加载文档详情失败", + "loadChunksFailed": "加载分块列表失败" } } diff --git a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/index.json b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/index.json index cac88bacd1..6343412817 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/index.json +++ b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/index.json @@ -11,7 +11,9 @@ "documents": "文档", "chunks": "分块", "sessionConfig": "会话配置", - "initError": "初始化失败" + "initError": "初始化失败", + "noDescription": "暂无描述", + "switchToLegacy": "切换到旧版知识库" }, "card": { "edit": "编辑", @@ -31,9 +33,12 @@ "rerankModelLabel": "重排序模型 (Rerank Model, 可选)", "providerInfo": "提供商: {id} | 维度: {dimensions}", "rerankProviderInfo": "提供商: {id}", + "nameHint": "如果后续修改知识库名称,请同步更新仍按名称引用的配置。", + "embeddingModelHint": "嵌入模型选择后无法修改,如需更换请创建新的知识库。", "cancel": "取消", "submit": "创建", - "nameRequired": "请输入知识库名称" + "nameRequired": "请输入知识库名称", + "embeddingRequired": "请选择嵌入模型" }, "edit": { "title": "编辑知识库", @@ -63,6 +68,7 @@ "updateFailed": "更新失败", "deleteSuccess": "知识库删除成功", "deleteFailed": "删除失败", - "loadError": "加载知识库列表失败" + "loadError": "加载知识库列表失败", + "providersLoadError": "加载模型提供商失败" } } diff --git a/dashboard/src/main.ts b/dashboard/src/main.ts index ce5514207c..eb2f15c205 100644 --- a/dashboard/src/main.ts +++ b/dashboard/src/main.ts @@ -2,7 +2,7 @@ import { createApp } from 'vue'; import { createPinia } from 'pinia'; import App from './App.vue'; import { router } from './router'; -import vuetify from './plugins/vuetify'; +import vuetify, { getVuetifyLocale } from './plugins/vuetify'; import confirmPlugin from './plugins/confirmPlugin'; import { setupI18n } from './i18n/composables'; import '@/scss/style.scss'; @@ -47,12 +47,18 @@ import { waitForRouterReadyInBackground } from './utils/routerReadiness.mjs'; }, }; +const syncVuetifyLocale = (event: Event) => { + const locale = (event as CustomEvent<{ locale?: string }>).detail?.locale; + vuetify.locale.current.value = getVuetifyLocale(locale); +}; + // 初始化新的i18n系统,等待完成后再挂载应用 setupI18n().then(async () => { console.log('🌍 新i18n系统初始化完成'); - + const app = createApp(App); const pinia = createPinia(); + window.addEventListener('astrbot-locale-changed', syncVuetifyLocale); app.use(pinia); app.use(router); app.use(print); @@ -86,6 +92,7 @@ setupI18n().then(async () => { // 即使i18n初始化失败,也要挂载应用(使用回退机制) const app = createApp(App); const pinia = createPinia(); + window.addEventListener('astrbot-locale-changed', syncVuetifyLocale); app.use(pinia); app.use(router); app.use(print); diff --git a/dashboard/src/plugins/vuetify.ts b/dashboard/src/plugins/vuetify.ts index e38fd388e6..474f1ca02c 100644 --- a/dashboard/src/plugins/vuetify.ts +++ b/dashboard/src/plugins/vuetify.ts @@ -1,32 +1,91 @@ import { createVuetify } from 'vuetify'; +import { en, ru, zhHans } from 'vuetify/locale'; import '@/assets/mdi-subset/materialdesignicons-subset.css'; import * as components from 'vuetify/components'; import * as directives from 'vuetify/directives'; import { PurpleTheme } from '@/theme/LightTheme'; -import { PurpleThemeDark } from "@/theme/DarkTheme"; +import { PurpleThemeDark } from '@/theme/DarkTheme'; + +const zhHansMessages = { + ...zhHans, + open: '打开', + dismiss: '关闭', + dataFooter: { + ...zhHans.dataFooter, + itemsPerPageText: '每页条数:', + firstPage: '第一页', + lastPage: '最后一页', + }, + input: { + ...zhHans.input, + clear: '清空 {0}', + prependAction: '{0} 前置操作', + appendAction: '{0} 后置操作', + otp: '请输入第 {0} 位验证码', + }, + pagination: { + ...zhHans.pagination, + ariaLabel: { + ...zhHans.pagination.ariaLabel, + first: '第一页', + last: '最后一页', + }, + }, + stepper: { + next: '下一步', + prev: '上一步', + }, + loading: '加载中...', +}; + +const vuetifyLocaleMap: Record = { + 'zh-CN': 'zhHans', + 'en-US': 'en', + 'ru-RU': 'ru', +}; + +export const getVuetifyLocale = (locale?: string | null) => { + if (!locale) { + return 'zhHans'; + } + return vuetifyLocaleMap[locale] || 'zhHans'; +}; export default createVuetify({ components, directives, + locale: { + locale: getVuetifyLocale( + typeof localStorage === 'undefined' + ? null + : localStorage.getItem('astrbot-locale'), + ), + fallback: 'en', + messages: { + en, + ru, + zhHans: zhHansMessages, + }, + }, theme: { defaultTheme: 'PurpleTheme', themes: { PurpleTheme, - PurpleThemeDark - } + PurpleThemeDark, + }, }, defaults: { VBtn: {}, VCard: { - rounded: 'lg' + rounded: 'lg', }, VTextField: { - rounded: 'lg' + rounded: 'lg', }, VTooltip: { // set v-tooltip default location to top - location: 'top' - } - } + location: 'top', + }, + }, }); diff --git a/dashboard/src/views/knowledge-base/DocumentDetail.vue b/dashboard/src/views/knowledge-base/DocumentDetail.vue index 921315e627..0645a8f0c9 100644 --- a/dashboard/src/views/knowledge-base/DocumentDetail.vue +++ b/dashboard/src/views/knowledge-base/DocumentDetail.vue @@ -9,7 +9,9 @@ />

{{ document.doc_name }}

-

{{ t('title') }}

+

+ {{ t("title") }} +

@@ -18,18 +20,29 @@ + +
+ {{ loadError }} + + {{ t("actions.retry") }} + +
+
+
- {{ t('info.title') }} + {{ t("info.title") }}
mdi-label
-
{{ t('info.name') }}
+
+ {{ t("info.name") }} +
{{ document.doc_name }}
@@ -40,8 +53,10 @@ {{ getFileIcon(document.file_type) }}
-
{{ t('info.type') }}
-
{{ document.file_type || '-' }}
+
+ {{ t("info.type") }} +
+
{{ document.file_type || "-" }}
@@ -49,8 +64,12 @@
mdi-file-chart
-
{{ t('info.size') }}
-
{{ formatFileSize(document.file_size) }}
+
+ {{ t("info.size") }} +
+
+ {{ formatFileSize(document.file_size) }} +
@@ -58,7 +77,9 @@
mdi-text-box
-
{{ t('info.chunkCount') }}
+
+ {{ t("info.chunkCount") }} +
{{ document.chunk_count || 0 }}
@@ -67,8 +88,12 @@
mdi-calendar
-
{{ t('info.createdAt') }}
-
{{ formatDate(document.created_at) }}
+
+ {{ t("info.createdAt") }} +
+
+ {{ formatDate(document.created_at) }} +
@@ -79,9 +104,9 @@ - {{ t('chunks.title') }} + {{ t("chunks.title") }} - {{ totalChunks }} {{ t('chunks.title') }} + {{ totalChunks }} {{ t("chunks.title") }} -
+
- {{ t('chunks.showing') }} {{ (page - 1) * pageSize + 1 }} - {{ Math.min(page * pageSize, totalChunks) }} / {{ totalChunks }} + {{ t("chunks.showing") }} {{ (page - 1) * pageSize + 1 }} - + {{ Math.min(page * pageSize, totalChunks) }} / {{ totalChunks }}
- {{ t('view.title') }} + {{ t("view.title") }} - + @@ -190,28 +224,40 @@ - {{ t('view.index') }} - #{{ (selectedChunk?.chunk_index || 0) + 1 }} + {{ t("view.index") }} + #{{ + (selectedChunk?.chunk_index || 0) + 1 + }} - {{ t('view.charCount') }} - {{ selectedChunk?.char_count || 0 }} 字符 + {{ t("view.charCount") }} + {{ + t("chunks.charCountValue", { + count: selectedChunk?.char_count || 0, + }) + }} - {{ t('view.vecDocId') }} - {{ selectedChunk?.chunk_id || '-' }} + {{ t("view.vecDocId") }} + {{ + selectedChunk?.chunk_id || "-" + }} -
{{ t('view.content') }}
+
+ {{ t("view.content") }} +
{{ selectedChunk?.content }}
@@ -219,7 +265,7 @@ - {{ t('view.close') }} + {{ t("view.close") }}
@@ -233,190 +279,216 @@ diff --git a/dashboard/src/views/knowledge-base/components/SettingsTab.vue b/dashboard/src/views/knowledge-base/components/SettingsTab.vue index 7d18c305a9..c7ed7c42be 100644 --- a/dashboard/src/views/knowledge-base/components/SettingsTab.vue +++ b/dashboard/src/views/knowledge-base/components/SettingsTab.vue @@ -1,12 +1,12 @@ diff --git a/dashboard/src/views/knowledge-base/components/TavilyKeyDialog.vue b/dashboard/src/views/knowledge-base/components/TavilyKeyDialog.vue index 37cf9df8c9..b56a6086dc 100644 --- a/dashboard/src/views/knowledge-base/components/TavilyKeyDialog.vue +++ b/dashboard/src/views/knowledge-base/components/TavilyKeyDialog.vue @@ -2,15 +2,18 @@ - 配置 Tavily API Key + {{ t("tavily.title") }}

- 为了使用基于网页的知识库功能,需要提供 Tavily API Key。您可以从 Tavily 官网 获取。 + {{ t("tavily.description") }} + {{ + t("tavily.officialSite") + }}

- 取消 + {{ t("tavily.cancel") }} - - 保存 + + {{ t("tavily.save") }}
@@ -33,77 +41,86 @@ \ No newline at end of file +}; + From a361cded3a579db0cea6ed69dba658b1c6206c72 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Sat, 6 Jun 2026 11:04:39 +0800 Subject: [PATCH 25/48] feat(kb): upgrade knowledge base workflows --- .../db/vec_db/faiss_impl/document_storage.py | 125 +- astrbot/core/knowledge_base/capabilities.py | 110 ++ .../core/knowledge_base/chunking/markdown.py | 261 ++- .../core/knowledge_base/document_metadata.py | 61 + astrbot/core/knowledge_base/kb_db_sqlite.py | 423 ++++- astrbot/core/knowledge_base/kb_helper.py | 1460 ++++++++++++++- astrbot/core/knowledge_base/kb_mgr.py | 140 +- astrbot/core/knowledge_base/models.py | 67 +- astrbot/core/knowledge_base/parsers/base.py | 9 + .../core/knowledge_base/parsers/pdf_parser.py | 16 +- .../core/knowledge_base/retrieval/__init__.py | 25 +- .../core/knowledge_base/retrieval/manager.py | 566 +++++- .../knowledge_base/retrieval/rank_fusion.py | 22 + .../retrieval/sparse_retriever.py | 4 + astrbot/dashboard/routes/knowledge_base.py | 1192 ++++++++++++- .../en-US/features/knowledge-base/detail.json | 197 ++- .../features/knowledge-base/document.json | 65 +- .../ru-RU/features/knowledge-base/detail.json | 199 ++- .../features/knowledge-base/document.json | 65 +- .../zh-CN/features/knowledge-base/detail.json | 197 ++- .../features/knowledge-base/document.json | 65 +- .../zh-CN/features/knowledge-base/index.json | 4 +- .../views/knowledge-base/DocumentDetail.vue | 840 ++++++++- .../src/views/knowledge-base/KBDetail.vue | 890 +++++++++- dashboard/src/views/knowledge-base/KBList.vue | 45 +- .../src/views/knowledge-base/capabilities.ts | 95 + .../components/DocumentsTab.vue | 1089 ++++++++++-- .../components/RetrievalTab.vue | 493 +++++- .../knowledge-base/components/SettingsTab.vue | 79 +- .../components/TavilyKeyDialog.vue | 2 +- .../views/knowledge-base/knowledgeBaseUi.mjs | 779 ++++++++ dashboard/tests/knowledgeBase.test.mjs | 1468 ++++++++++++++++ tests/test_kb_batch_delete.py | 448 ++++- tests/test_kb_stats.py | 102 ++ tests/test_kb_update_route.py | 1479 +++++++++++++++- tests/test_kb_upload_memory_leak.py | 467 ++++- tests/test_kb_upload_rollback.py | 1564 ++++++++++++++++- tests/unit/test_document_storage_fts.py | 34 + tests/unit/test_kb_core_features.py | 955 ++++++++++ tests/unit/test_kb_manager_delete.py | 11 +- 40 files changed, 15454 insertions(+), 659 deletions(-) create mode 100644 astrbot/core/knowledge_base/capabilities.py create mode 100644 astrbot/core/knowledge_base/document_metadata.py create mode 100644 dashboard/src/views/knowledge-base/capabilities.ts create mode 100644 dashboard/src/views/knowledge-base/knowledgeBaseUi.mjs create mode 100644 dashboard/tests/knowledgeBase.test.mjs create mode 100644 tests/unit/test_kb_core_features.py diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index 4b3c1a4c60..84069ba52f 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -18,6 +18,7 @@ build_fts5_or_query, load_stopwords, to_fts5_search_text, + tokenize_text, ) FTS_TABLE_NAME = "documents_fts" @@ -515,23 +516,133 @@ async def count_documents(self, metadata_filters: dict | None = None) -> int: count = result.scalar_one_or_none() return count if count is not None else 0 + async def search_documents( + self, + query_text: str, + metadata_filters: dict | None = None, + offset: int = 0, + limit: int = 100, + ) -> tuple[list[dict], int] | None: + """Search documents with FTS5 and optional metadata filters. + + Returns None when FTS5 is unavailable so callers can choose whether to + fall back to an alternate search strategy. + """ + if limit <= 0: + return [], 0 + if not await self.ensure_fts_index(): + return None + + match_query = build_fts5_or_query(tokenize_text(query_text, self.stopwords)) + if not match_query: + return [], 0 + + metadata_filters = metadata_filters or {} + async with self.get_session() as session: + filters_sql, filter_params = await self._metadata_filter_sql( + session, + metadata_filters, + table_alias="d", + ) + where_clause = f"{FTS_TABLE_NAME} MATCH :query" + if filters_sql: + where_clause = f"{where_clause} AND {' AND '.join(filters_sql)}" + params = { + "query": match_query, + "limit": int(limit), + "offset": int(offset), + **filter_params, + } + try: + count_result = await session.execute( + text( + f""" + SELECT count(*) + FROM {FTS_TABLE_NAME} + JOIN documents d ON d.id = {FTS_TABLE_NAME}.rowid + WHERE {where_clause} + """, + ), + params, + ) + total = int(count_result.scalar_one_or_none() or 0) + result = await session.execute( + text( + f""" + SELECT + d.id AS id, + d.doc_id AS doc_id, + d.text AS text, + d.metadata AS metadata, + d.created_at AS created_at, + d.updated_at AS updated_at, + bm25({FTS_TABLE_NAME}) AS score + FROM {FTS_TABLE_NAME} + JOIN documents d ON d.id = {FTS_TABLE_NAME}.rowid + WHERE {where_clause} + ORDER BY score ASC, d.id ASC + LIMIT :limit + OFFSET :offset + """, + ), + params, + ) + except Exception as e: + logger.warning( + f"FTS5 document search failed for {self.db_path}: {e}", + ) + self.fts5_available = False + return None + + rows = result.mappings().all() + return [ + { + "id": row["id"], + "doc_id": row["doc_id"], + "text": row["text"], + "metadata": row["metadata"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + "score": float(row["score"]), + } + for row in rows + ], total + async def _apply_metadata_filters( self, session: AsyncSession, query, metadata_filters: dict, ): + filters_sql, params = await self._metadata_filter_sql( + session, + metadata_filters, + ) + for filter_sql in filters_sql: + query = query.where(text(filter_sql)) + if params: + query = query.params(**params) + return query + + async def _metadata_filter_sql( + self, + session: AsyncSession, + metadata_filters: dict, + table_alias: str | None = None, + ) -> tuple[list[str], dict]: columns = await self._table_columns(session, "documents") + prefix = f"{table_alias}." if table_alias else "" + filters_sql = [] + params = {} for key, val in metadata_filters.items(): if key in {"kb_id", "kb_doc_id", "user_id"} and key in columns: - query = query.where( - text(f"{key} = :filter_{key}"), - ).params(**{f"filter_{key}": val}) + filters_sql.append(f"{prefix}{key} = :filter_{key}") else: - query = query.where( - text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), - ).params(**{f"filter_{key}": val}) - return query + filters_sql.append( + f"json_extract({prefix}metadata, '$.{key}') = :filter_{key}" + ) + params[f"filter_{key}"] = val + return filters_sql, params async def ensure_fts_index(self) -> bool: """Ensure the FTS5 sparse index exists and matches the documents table.""" diff --git a/astrbot/core/knowledge_base/capabilities.py b/astrbot/core/knowledge_base/capabilities.py new file mode 100644 index 0000000000..9367604514 --- /dev/null +++ b/astrbot/core/knowledge_base/capabilities.py @@ -0,0 +1,110 @@ +"""Knowledge base capabilities and default limits.""" + +from typing import Any + +ALLOWED_UPLOAD_EXTENSIONS = frozenset( + { + "adoc", + "docx", + "epub", + "md", + "markdown", + "pdf", + "rst", + "txt", + "xls", + "xlsx", + }, +) + +MAX_UPLOAD_FILE_SIZE = 128 * 1024 * 1024 +MAX_UPLOAD_FILES = 10 +MAX_BATCH_DELETE_DOCUMENTS = 100 +MAX_BATCH_REBUILD_DOCUMENTS = 100 +MAX_RETRIEVE_TOP_K = 100 +DEFAULT_KB_PAGE_SIZE = 20 +DEFAULT_DOCUMENT_PAGE_SIZE = 10 +DEFAULT_CHUNK_PAGE_SIZE = 10 +DEFAULT_BULK_PAGE_SIZE = 100 +DOCUMENT_PAGE_SIZE_OPTIONS = (10, 20, 50, 100) +CHUNK_PAGE_SIZE_OPTIONS = (10, 25, 50, 100) + +DOCUMENT_FILTER_STATUSES = ( + "pending", + "parsing", + "chunking", + "embedding", + "ready", + "failed", +) +DOCUMENT_FILTER_SOURCE_TYPES = ("file", "url", "import") + +FEATURE_SPARSE_RETRIEVAL = True +FEATURE_RERANK = True +FEATURE_URL_IMPORT = True +FEATURE_DOCUMENT_REBUILD = True +FEATURE_KB_REBUILD = True +FEATURE_CONSISTENCY_CHECK = True +FEATURE_CONSISTENCY_REPAIR = True +FEATURE_BATCH_DELETE = True +FEATURE_BATCH_REBUILD = True + +DEFAULT_CHUNK_SIZE = 512 +DEFAULT_CHUNK_OVERLAP = 50 +DEFAULT_TOP_K_DENSE = 50 +DEFAULT_TOP_K_SPARSE = 50 +DEFAULT_TOP_M_FINAL = 5 +DEFAULT_INDEX_TYPE = "flat" +DEFAULT_UPLOAD_BATCH_SIZE = 32 +DEFAULT_UPLOAD_TASKS_LIMIT = 3 +DEFAULT_UPLOAD_MAX_RETRIES = 3 + + +def get_knowledge_base_capabilities() -> dict[str, Any]: + """Return API-safe knowledge base capabilities.""" + return { + "upload": { + "allowed_extensions": sorted(ALLOWED_UPLOAD_EXTENSIONS), + "max_file_size_bytes": MAX_UPLOAD_FILE_SIZE, + "max_files_per_upload": MAX_UPLOAD_FILES, + }, + "defaults": { + "chunk_size": DEFAULT_CHUNK_SIZE, + "chunk_overlap": DEFAULT_CHUNK_OVERLAP, + "batch_size": DEFAULT_UPLOAD_BATCH_SIZE, + "tasks_limit": DEFAULT_UPLOAD_TASKS_LIMIT, + "max_retries": DEFAULT_UPLOAD_MAX_RETRIES, + "top_k_dense": DEFAULT_TOP_K_DENSE, + "top_k_sparse": DEFAULT_TOP_K_SPARSE, + "top_m_final": DEFAULT_TOP_M_FINAL, + "index_type": DEFAULT_INDEX_TYPE, + }, + "limits": { + "max_retrieve_top_k": MAX_RETRIEVE_TOP_K, + "max_batch_delete_documents": MAX_BATCH_DELETE_DOCUMENTS, + "max_batch_rebuild_documents": MAX_BATCH_REBUILD_DOCUMENTS, + }, + "pagination": { + "document_page_size_options": list(DOCUMENT_PAGE_SIZE_OPTIONS), + "chunk_page_size_options": list(CHUNK_PAGE_SIZE_OPTIONS), + "default_kb_page_size": DEFAULT_KB_PAGE_SIZE, + "default_document_page_size": DEFAULT_DOCUMENT_PAGE_SIZE, + "default_chunk_page_size": DEFAULT_CHUNK_PAGE_SIZE, + "bulk_page_size": DEFAULT_BULK_PAGE_SIZE, + }, + "document_filters": { + "statuses": list(DOCUMENT_FILTER_STATUSES), + "source_types": list(DOCUMENT_FILTER_SOURCE_TYPES), + }, + "features": { + "sparse_retrieval": FEATURE_SPARSE_RETRIEVAL, + "rerank": FEATURE_RERANK, + "url_import": FEATURE_URL_IMPORT, + "document_rebuild": FEATURE_DOCUMENT_REBUILD, + "kb_rebuild": FEATURE_KB_REBUILD, + "consistency_check": FEATURE_CONSISTENCY_CHECK, + "consistency_repair": FEATURE_CONSISTENCY_REPAIR, + "batch_delete": FEATURE_BATCH_DELETE, + "batch_rebuild": FEATURE_BATCH_REBUILD, + }, + } diff --git a/astrbot/core/knowledge_base/chunking/markdown.py b/astrbot/core/knowledge_base/chunking/markdown.py index 9ace43110d..49d3c42cd7 100644 --- a/astrbot/core/knowledge_base/chunking/markdown.py +++ b/astrbot/core/knowledge_base/chunking/markdown.py @@ -16,10 +16,29 @@ class _Section: """解析后的 Markdown 章节""" heading_path: list[str] + title_path: list[str] + section_index: int | None text: str has_body: bool +@dataclass +class MarkdownChunk: + """A Markdown chunk with source structure metadata.""" + + text: str + title_path: list[str] | None = None + section_index: int | None = None + + +@dataclass +class _ChunkDraft: + text: str + has_body: bool + title_path: list[str] | None + section_index: int | None + + class MarkdownChunker(BaseChunker): """Markdown 感知分块器 @@ -72,31 +91,28 @@ async def chunk(self, text: str, **kwargs) -> list[str]: list[str]: 分块后的文本列表 """ + chunks = await self.chunk_with_metadata(text, **kwargs) + return [chunk.text for chunk in chunks] + + async def chunk_with_metadata(self, text: str, **kwargs) -> list[MarkdownChunk]: + """Split Markdown text and keep per-chunk structure metadata.""" if not text or not text.strip(): return [] chunk_size = kwargs.get("chunk_size", self.chunk_size) chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap) - # 解析 Markdown 结构 sections = self._parse_sections(text) if not sections: - # 没有识别到标题结构,回退到递归分割 - return await self._fallback_chunker.chunk( + chunks = await self._fallback_chunker.chunk( text, chunk_size=chunk_size, chunk_overlap=chunk_overlap ) + return [MarkdownChunk(text=chunk) for chunk in chunks] - # 将 sections 转换为 raw chunks raw_chunks = await self._sections_to_chunks(sections, chunk_size, chunk_overlap) - - # 合并纯标题节到下一个有内容的 chunk merged = self._merge_heading_only_chunks(raw_chunks, chunk_size) - - # 合并过短的相邻 chunk - merged = self._merge_short_chunks(merged, chunk_size) - - return merged + return self._merge_short_chunks(merged, chunk_size) def _estimate_prefix_length(self, heading_path: list[str]) -> int: """估算标题上下文前缀的最大长度(用于扣除子块可用空间)""" @@ -109,13 +125,15 @@ def _estimate_prefix_length(self, heading_path: list[str]) -> int: async def _sections_to_chunks( self, sections: list[_Section], chunk_size: int, chunk_overlap: int - ) -> list[tuple[str, bool]]: + ) -> list[_ChunkDraft]: """将解析后的 sections 转换为 (chunk_text, has_body) 列表""" - raw_chunks: list[tuple[str, bool]] = [] + raw_chunks: list[_ChunkDraft] = [] for section in sections: section_text = section.text heading_path = section.heading_path + title_path = self._normalize_title_path(section.title_path) + section_index = section.section_index has_body = section.has_body # 构建带上下文的文本 @@ -123,7 +141,14 @@ async def _sections_to_chunks( full_text = context_prefix + section_text if len(full_text) <= chunk_size: - raw_chunks.append((full_text.strip(), has_body)) + raw_chunks.append( + _ChunkDraft( + text=full_text.strip(), + has_body=has_body, + title_path=title_path, + section_index=section_index, + ) + ) else: # 章节过长,内部递归分割 # 扣除前缀长度,确保添加前缀后不超过 chunk_size @@ -139,7 +164,14 @@ async def _sections_to_chunks( chunk_text = self._apply_heading_context( heading_path, sub_chunk, is_continuation=(i > 0) ) - raw_chunks.append((chunk_text, True)) + raw_chunks.append( + _ChunkDraft( + text=chunk_text, + has_body=True, + title_path=title_path, + section_index=section_index, + ) + ) return raw_chunks @@ -162,74 +194,183 @@ def _apply_heading_context( return f"{title}\n\n{content}".strip() def _merge_heading_only_chunks( - self, raw_chunks: list[tuple[str, bool]], chunk_size: int - ) -> list[str]: + self, raw_chunks: list[_ChunkDraft], chunk_size: int + ) -> list[MarkdownChunk]: """合并没有实质正文的 chunk 到下一个有正文的 chunk""" - merged: list[str] = [] - pending = "" + merged: list[MarkdownChunk] = [] + pending_text = "" + pending_title_path: list[str] | None = None + pending_section_index: int | None = None - for chunk_text, has_body in raw_chunks: + for chunk in raw_chunks: + chunk_text = chunk.text if not chunk_text: continue - if not has_body: + if not chunk.has_body: # 纯标题节,暂存;但如果 pending 已经够长,先 flush - if pending and len(pending) + len(chunk_text) + 2 > chunk_size: - merged.append(pending.strip()) - pending = "" - pending += chunk_text + "\n\n" + if ( + pending_text + and len(pending_text) + len(chunk_text) + 2 > chunk_size + ): + merged.append( + MarkdownChunk( + text=pending_text.strip(), + title_path=pending_title_path, + section_index=pending_section_index, + ) + ) + pending_text = "" + pending_title_path = None + pending_section_index = None + pending_text += chunk_text + "\n\n" + pending_title_path = chunk.title_path or pending_title_path + pending_section_index = chunk.section_index else: - if pending: - combined = pending + chunk_text + if pending_text: + combined = pending_text + chunk_text if len(combined) <= chunk_size: - merged.append(combined.strip()) + merged.append( + MarkdownChunk( + text=combined.strip(), + title_path=chunk.title_path or pending_title_path, + section_index=chunk.section_index, + ) + ) else: - merged.append(pending.strip()) - merged.append(chunk_text.strip()) - pending = "" + merged.append( + MarkdownChunk( + text=pending_text.strip(), + title_path=pending_title_path, + section_index=pending_section_index, + ) + ) + merged.append( + MarkdownChunk( + text=chunk_text.strip(), + title_path=chunk.title_path, + section_index=chunk.section_index, + ) + ) + pending_text = "" + pending_title_path = None + pending_section_index = None else: - merged.append(chunk_text.strip()) + merged.append( + MarkdownChunk( + text=chunk_text.strip(), + title_path=chunk.title_path, + section_index=chunk.section_index, + ) + ) # 处理尾部残留的 pending - if pending: - pending_text = pending.strip() - if merged and len(merged[-1] + "\n\n" + pending_text) <= chunk_size: - merged[-1] = merged[-1] + "\n\n" + pending_text + if pending_text: + trailing_text = pending_text.strip() + if merged and len(merged[-1].text + "\n\n" + trailing_text) <= chunk_size: + merged[-1] = MarkdownChunk( + text=merged[-1].text + "\n\n" + trailing_text, + title_path=self._merge_title_paths( + [merged[-1].title_path, pending_title_path] + ), + section_index=self._merge_section_indexes( + [merged[-1].section_index, pending_section_index] + ), + ) else: - merged.append(pending_text) + merged.append( + MarkdownChunk( + text=trailing_text, + title_path=pending_title_path, + section_index=pending_section_index, + ) + ) - return [c for c in merged if c.strip()] + return [chunk for chunk in merged if chunk.text.strip()] - def _merge_short_chunks(self, chunks: list[str], chunk_size: int) -> list[str]: + def _merge_short_chunks( + self, chunks: list[MarkdownChunk], chunk_size: int + ) -> list[MarkdownChunk]: """合并过短的相邻 chunk(低于 min_chunk_size)""" if self.min_chunk_size <= 0 or len(chunks) <= 1: return chunks - final: list[str] = [] - buf = "" + final: list[MarkdownChunk] = [] + buf: MarkdownChunk | None = None - for c in chunks: + for chunk in chunks: if buf: - combined = buf + "\n\n" + c + combined = buf.text + "\n\n" + chunk.text if len(combined) <= chunk_size: - buf = combined + buf = MarkdownChunk( + text=combined, + title_path=self._merge_title_paths( + [buf.title_path, chunk.title_path] + ), + section_index=self._merge_section_indexes( + [buf.section_index, chunk.section_index] + ), + ) else: final.append(buf) - buf = c if len(c) < self.min_chunk_size else "" - if len(c) >= self.min_chunk_size: - final.append(c) - elif len(c) < self.min_chunk_size: - buf = c + if len(chunk.text) < self.min_chunk_size: + buf = chunk + else: + buf = None + final.append(chunk) + elif len(chunk.text) < self.min_chunk_size: + buf = chunk else: - final.append(c) + final.append(chunk) if buf: - if final and len(final[-1] + "\n\n" + buf) <= chunk_size: - final[-1] = final[-1] + "\n\n" + buf + if final and len(final[-1].text + "\n\n" + buf.text) <= chunk_size: + final[-1] = MarkdownChunk( + text=final[-1].text + "\n\n" + buf.text, + title_path=self._merge_title_paths( + [final[-1].title_path, buf.title_path] + ), + section_index=self._merge_section_indexes( + [final[-1].section_index, buf.section_index] + ), + ) else: final.append(buf) return final + @staticmethod + def _normalize_title_path(title_path: list[str]) -> list[str] | None: + path = [title.strip() for title in title_path if title and title.strip()] + return path or None + + @staticmethod + def _merge_title_paths(paths: list[list[str] | None]) -> list[str] | None: + non_empty_paths = [path for path in paths if path] + if not non_empty_paths: + return None + + common = list(non_empty_paths[0]) + for path in non_empty_paths[1:]: + prefix: list[str] = [] + for left, right in zip(common, path, strict=False): + if left != right: + break + prefix.append(left) + common = prefix + if not common: + return None + return common + + @staticmethod + def _merge_section_indexes(indexes: list[int | None]) -> int | None: + non_empty_indexes = [index for index in indexes if index is not None] + if not non_empty_indexes: + return None + first_index = non_empty_indexes[0] + if all(index == first_index for index in non_empty_indexes): + return first_index + return None + def _parse_sections(self, text: str) -> list[_Section]: """解析 Markdown 文本为章节列表 @@ -264,11 +405,21 @@ def _parse_sections(self, text: str) -> list[_Section]: return [] sections: list[_Section] = [] + section_index = 0 # 处理第一个标题之前的内容(如果有) preamble = text[: headings[0]["start"]].strip() if preamble: - sections.append(_Section(heading_path=[], text=preamble, has_body=True)) + sections.append( + _Section( + heading_path=[], + title_path=[], + section_index=section_index, + text=preamble, + has_body=True, + ) + ) + section_index += 1 # 维护标题栈来追踪层级路径 heading_stack: list[dict] = [] @@ -297,14 +448,18 @@ def _parse_sections(self, text: str) -> list[_Section]: # 构建标题路径 heading_path = [h["title"] for h in heading_stack[:-1]] + title_path = [h["title"] for h in heading_stack] sections.append( _Section( heading_path=heading_path, + title_path=title_path, + section_index=section_index, text=section_text, has_body=bool(body), ) ) + section_index += 1 return sections diff --git a/astrbot/core/knowledge_base/document_metadata.py b/astrbot/core/knowledge_base/document_metadata.py new file mode 100644 index 0000000000..4c78efe410 --- /dev/null +++ b/astrbot/core/knowledge_base/document_metadata.py @@ -0,0 +1,61 @@ +"""Helpers for knowledge-base document governance metadata.""" + +import hashlib +import re +import uuid +from pathlib import Path + +from .chunking.base import BaseChunker +from .parsers.base import BaseParser + +DEFAULT_PARSER_VERSION = "1" +DEFAULT_CHUNKER_VERSION = "1" + + +def build_content_hash(content: bytes | str | list[str]) -> str: + """Return a stable SHA256 hash for source content.""" + digest = hashlib.sha256() + if isinstance(content, bytes): + digest.update(content) + elif isinstance(content, str): + digest.update(content.encode("utf-8")) + else: + for chunk in content: + digest.update(chunk.encode("utf-8")) + digest.update(b"\x00") + return digest.hexdigest() + + +def get_parser_name(parser: BaseParser | None) -> str | None: + if parser is None: + return None + return parser.__class__.__name__ + + +def get_chunker_name(chunker: BaseChunker | None) -> str | None: + if chunker is None: + return None + return chunker.__class__.__name__ + + +def sanitize_source_filename(file_name: str | None, fallback_suffix: str = "") -> str: + """Return a filename safe for storage inside a KB-owned directory.""" + raw = (file_name or "").replace("\\", "/").split("/")[-1].replace("\x00", "") + safe = re.sub(r"[^A-Za-z0-9._ -]", "_", raw).strip(" .") + if not safe: + safe = f"document_{uuid.uuid4().hex[:8]}{fallback_suffix}" + return safe[:255] + + +def build_stored_source_path( + files_dir: Path, + *, + doc_id: str, + file_name: str, + file_type: str, +) -> Path: + suffix = Path(file_name).suffix + if not suffix and file_type: + suffix = f".{file_type}" + safe_name = sanitize_source_filename(file_name, fallback_suffix=suffix) + return files_dir / doc_id / safe_name diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 6ccb91e84d..a779953fa6 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -1,5 +1,7 @@ import asyncio +import json from contextlib import asynccontextmanager +from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING @@ -12,6 +14,7 @@ from astrbot.core.knowledge_base.models import ( BaseKBModel, KBDocument, + KBIngestionTask, KBMedia, KnowledgeBase, ) @@ -20,6 +23,8 @@ if TYPE_CHECKING: from astrbot.core.db.vec_db.faiss_impl import FaissVecDB +_UNSET = object() + class KBSQLiteDatabase: def __init__(self, db_path: str | None = None) -> None: @@ -96,6 +101,8 @@ async def migrate_to_v1(self) -> None: column_name="index_type", column_sql="index_type TEXT DEFAULT 'flat'", ) + await self._ensure_document_governance_columns(session) + await self._ensure_ingestion_task_table(session) # 创建知识库表索引 await session.execute( @@ -148,6 +155,24 @@ async def migrate_to_v1(self) -> None: "ON kb_documents(created_at)", ), ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_doc_content_hash " + "ON kb_documents(content_hash)", + ), + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_doc_status " + "ON kb_documents(status)", + ), + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_doc_parent_doc_id " + "ON kb_documents(parent_doc_id)", + ), + ) # 创建多媒体表索引 await session.execute( @@ -173,6 +198,7 @@ async def migrate_to_v1(self) -> None: "ON kb_media(media_type)", ), ) + await self._ensure_ingestion_task_indexes(session) await session.commit() @@ -194,6 +220,104 @@ async def _ensure_column( ) await session.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_sql}")) + async def _ensure_document_governance_columns( + self, + session: AsyncSession, + ) -> None: + columns = { + "source_type": "source_type TEXT NOT NULL DEFAULT 'file'", + "source_uri": "source_uri TEXT", + "content_hash": "content_hash VARCHAR(64)", + "parser_name": "parser_name VARCHAR(100)", + "parser_version": "parser_version VARCHAR(50)", + "chunker_name": "chunker_name VARCHAR(100)", + "chunker_version": "chunker_version VARCHAR(50)", + "status": "status TEXT NOT NULL DEFAULT 'ready'", + "error_stage": "error_stage VARCHAR(50)", + "error_message": "error_message TEXT", + "version": "version INTEGER NOT NULL DEFAULT 1", + "parent_doc_id": "parent_doc_id VARCHAR(36)", + "indexed_at": "indexed_at DATETIME", + } + for column_name, column_sql in columns.items(): + await self._ensure_column( + session, + table_name="kb_documents", + column_name=column_name, + column_sql=column_sql, + ) + + async def _ensure_ingestion_task_table(self, session: AsyncSession) -> None: + await session.execute( + text( + """ + CREATE TABLE IF NOT EXISTS kb_ingestion_tasks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id VARCHAR(36) NOT NULL UNIQUE, + kb_id VARCHAR(36) NOT NULL, + task_type VARCHAR(30) NOT NULL, + status VARCHAR(20) NOT NULL DEFAULT 'pending', + progress_stage VARCHAR(50), + progress_current INTEGER NOT NULL DEFAULT 0, + progress_total INTEGER NOT NULL DEFAULT 100, + progress TEXT, + result TEXT, + error TEXT, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + ) + """, + ), + ) + + async def _ensure_ingestion_task_indexes(self, session: AsyncSession) -> None: + indexes = { + "idx_task_task_id": "task_id", + "idx_task_kb_id": "kb_id", + "idx_task_type": "task_type", + "idx_task_status": "status", + "idx_task_created_at": "created_at", + } + for index_name, column_name in indexes.items(): + await session.execute( + text( + f"CREATE INDEX IF NOT EXISTS {index_name} " + f"ON kb_ingestion_tasks({column_name})", + ), + ) + + @staticmethod + def _encode_json(value) -> str | None: + if value is None: + return None + return json.dumps(value, ensure_ascii=False, default=str) + + @staticmethod + def _decode_json(value: str | None): + if value is None: + return None + try: + return json.loads(value) + except json.JSONDecodeError: + return value + + @classmethod + def _task_to_dict(cls, task: KBIngestionTask) -> dict: + return { + "task_id": task.task_id, + "kb_id": task.kb_id, + "task_type": task.task_type, + "status": task.status, + "progress_stage": task.progress_stage, + "progress_current": task.progress_current, + "progress_total": task.progress_total, + "progress": cls._decode_json(task.progress), + "result": cls._decode_json(task.result), + "error": cls._decode_json(task.error), + "created_at": task.created_at.isoformat(), + "updated_at": task.updated_at.isoformat(), + } + async def close(self) -> None: """关闭数据库连接""" await self.engine.dispose() @@ -239,6 +363,146 @@ async def count_kbs(self) -> int: result = await session.execute(stmt) return result.scalar() or 0 + # ===== 任务查询 ===== + + async def create_ingestion_task( + self, + *, + task_id: str, + kb_id: str, + task_type: str, + status: str = "pending", + progress_stage: str | None = None, + progress_current: int = 0, + progress_total: int = 100, + progress: dict | None = None, + ) -> dict: + task = KBIngestionTask( + task_id=task_id, + kb_id=kb_id, + task_type=task_type, + status=status, + progress_stage=progress_stage, + progress_current=progress_current, + progress_total=progress_total, + progress=self._encode_json(progress), + ) + async with self.get_db() as session: + session.add(task) + await session.commit() + await session.refresh(task) + return self._task_to_dict(task) + + async def update_ingestion_task( + self, + task_id: str, + *, + status: str | object = _UNSET, + progress_stage: str | None | object = _UNSET, + progress_current: int | object = _UNSET, + progress_total: int | object = _UNSET, + progress: dict | None | object = _UNSET, + result: dict | None | object = _UNSET, + error: str | None | object = _UNSET, + ) -> dict | None: + async with self.get_db() as session: + stmt = select(KBIngestionTask).where( + col(KBIngestionTask.task_id) == task_id, + ) + query_result = await session.execute(stmt) + task = query_result.scalar_one_or_none() + if task is None: + return None + + if status is not _UNSET: + task.status = status # type: ignore[assignment] + if progress_stage is not _UNSET: + task.progress_stage = progress_stage # type: ignore[assignment] + if progress_current is not _UNSET: + task.progress_current = progress_current # type: ignore[assignment] + if progress_total is not _UNSET: + task.progress_total = progress_total # type: ignore[assignment] + if progress is not _UNSET: + task.progress = self._encode_json(progress) + if result is not _UNSET: + task.result = self._encode_json(result) + if error is not _UNSET: + task.error = self._encode_json(error) + task.updated_at = datetime.now(timezone.utc) + + session.add(task) + await session.commit() + await session.refresh(task) + return self._task_to_dict(task) + + async def get_ingestion_task(self, task_id: str) -> dict | None: + async with self.get_db() as session: + stmt = select(KBIngestionTask).where( + col(KBIngestionTask.task_id) == task_id, + ) + result = await session.execute(stmt) + task = result.scalar_one_or_none() + return self._task_to_dict(task) if task is not None else None + + @staticmethod + def _build_ingestion_task_conditions( + *, + kb_id: str | None = None, + status: str | None = None, + task_type: str | None = None, + ) -> list: + conditions = [] + if kb_id is not None: + conditions.append(col(KBIngestionTask.kb_id) == kb_id) + if status is not None: + conditions.append(col(KBIngestionTask.status) == status) + if task_type is not None: + conditions.append(col(KBIngestionTask.task_type) == task_type) + return conditions + + async def list_ingestion_tasks( + self, + *, + kb_id: str | None = None, + status: str | None = None, + task_type: str | None = None, + offset: int = 0, + limit: int = 100, + ) -> list[dict]: + conditions = self._build_ingestion_task_conditions( + kb_id=kb_id, + status=status, + task_type=task_type, + ) + + async with self.get_db() as session: + stmt = ( + select(KBIngestionTask) + .where(*conditions) + .offset(offset) + .limit(limit) + .order_by(desc(KBIngestionTask.created_at)) + ) + result = await session.execute(stmt) + return [self._task_to_dict(task) for task in result.scalars().all()] + + async def count_ingestion_tasks( + self, + *, + kb_id: str | None = None, + status: str | None = None, + task_type: str | None = None, + ) -> int: + conditions = self._build_ingestion_task_conditions( + kb_id=kb_id, + status=status, + task_type=task_type, + ) + async with self.get_db() as session: + stmt = select(func.count(col(KBIngestionTask.id))).where(*conditions) + result = await session.execute(stmt) + return result.scalar() or 0 + # ===== 文档查询 ===== async def get_document_by_id(self, doc_id: str) -> KBDocument | None: @@ -248,24 +512,67 @@ async def get_document_by_id(self, doc_id: str) -> KBDocument | None: result = await session.execute(stmt) return result.scalar_one_or_none() + async def get_document_by_content_hash( + self, + *, + kb_id: str, + content_hash: str, + ) -> KBDocument | None: + """Return an existing active document with the same source content hash.""" + async with self.get_db() as session: + stmt = ( + select(KBDocument) + .where( + col(KBDocument.kb_id) == kb_id, + col(KBDocument.content_hash) == content_hash, + col(KBDocument.status) != "failed", + ) + .order_by(desc(KBDocument.created_at)) + .limit(1) + ) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + @staticmethod + def _build_document_filters( + *, + kb_id: str, + search: str | None = None, + status: str | None = None, + source_type: str | None = None, + ) -> list: + conditions = [col(KBDocument.kb_id) == kb_id] + if search: + pattern = f"%{search}%" + conditions.append( + or_( + col(KBDocument.doc_name).ilike(pattern), + col(KBDocument.file_type).ilike(pattern), + ), + ) + if status: + conditions.append(col(KBDocument.status) == status) + if source_type: + conditions.append(col(KBDocument.source_type) == source_type) + return conditions + async def list_documents_by_kb( self, kb_id: str, offset: int = 0, limit: int = 100, search: str | None = None, + status: str | None = None, + source_type: str | None = None, ) -> list[KBDocument]: """列出知识库的所有文档""" async with self.get_db() as session: - conditions = [col(KBDocument.kb_id) == kb_id] - if search: - pattern = f"%{search}%" - conditions.append( - or_( - col(KBDocument.doc_name).ilike(pattern), - col(KBDocument.file_type).ilike(pattern), - ), - ) + conditions = self._build_document_filters( + kb_id=kb_id, + search=search, + status=status, + source_type=source_type, + ) stmt = ( select(KBDocument) .where(*conditions) @@ -280,18 +587,17 @@ async def count_documents_by_kb( self, kb_id: str, search: str | None = None, + status: str | None = None, + source_type: str | None = None, ) -> int: """统计知识库的文档数量""" async with self.get_db() as session: - conditions = [col(KBDocument.kb_id) == kb_id] - if search: - pattern = f"%{search}%" - conditions.append( - or_( - col(KBDocument.doc_name).ilike(pattern), - col(KBDocument.file_type).ilike(pattern), - ), - ) + conditions = self._build_document_filters( + kb_id=kb_id, + search=search, + status=status, + source_type=source_type, + ) stmt = select(func.count(col(KBDocument.id))).where(*conditions) result = await session.execute(stmt) return result.scalar() or 0 @@ -481,3 +787,84 @@ async def update_kb_stats(self, kb_id: str, vec_db: "FaissVecDB") -> None: await session.execute(update_stmt) await session.commit() + + async def get_kb_stats(self, kb_id: str) -> dict | None: + """Return persisted document statistics for a knowledge base.""" + async with self.get_db() as session: + kb_result = await session.execute( + select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id), + ) + kb = kb_result.scalar_one_or_none() + if kb is None: + return None + + status_result = await session.execute( + select(KBDocument.status, func.count(col(KBDocument.id))) + .where(col(KBDocument.kb_id) == kb_id) + .group_by(KBDocument.status), + ) + status_counts = { + status or "unknown": count for status, count in status_result.all() + } + + chunk_result = await session.execute( + select(func.coalesce(func.sum(col(KBDocument.chunk_count)), 0)).where( + col(KBDocument.kb_id) == kb_id, + ), + ) + document_chunk_count = int(chunk_result.scalar() or 0) + + media_result = await session.execute( + select(func.count(col(KBMedia.id))).where(col(KBMedia.kb_id) == kb_id), + ) + media_count = int(media_result.scalar() or 0) + source_file_count_result = await session.execute( + select(func.count(col(KBDocument.id))).where( + col(KBDocument.kb_id) == kb_id, + col(KBDocument.source_type) == "file", + col(KBDocument.file_path) != "", + ), + ) + source_file_count = int(source_file_count_result.scalar() or 0) + document_storage_result = await session.execute( + select(func.coalesce(func.sum(col(KBDocument.file_size)), 0)).where( + col(KBDocument.kb_id) == kb_id, + col(KBDocument.file_path) != "", + ), + ) + document_storage_bytes = int(document_storage_result.scalar() or 0) + media_storage_result = await session.execute( + select(func.coalesce(func.sum(col(KBMedia.file_size)), 0)).where( + col(KBMedia.kb_id) == kb_id, + ), + ) + media_storage_bytes = int(media_storage_result.scalar() or 0) + + document_count = sum(status_counts.values()) + ready_document_count = status_counts.get("ready", 0) + failed_document_count = status_counts.get("failed", 0) + pending_document_count = status_counts.get("pending", 0) + processing_document_count = sum( + status_counts.get(status, 0) + for status in ("parsing", "chunking", "embedding") + ) + + return { + "kb_id": kb.kb_id, + "kb_name": kb.kb_name, + "doc_count": kb.doc_count, + "chunk_count": kb.chunk_count, + "document_count": document_count, + "ready_document_count": ready_document_count, + "failed_document_count": failed_document_count, + "pending_document_count": pending_document_count, + "processing_document_count": processing_document_count, + "indexed_chunk_count": kb.chunk_count, + "document_chunk_count": document_chunk_count, + "media_count": media_count, + "source_file_count": source_file_count, + "storage_bytes": document_storage_bytes + media_storage_bytes, + "status_counts": status_counts, + "created_at": kb.created_at.isoformat(), + "updated_at": kb.updated_at.isoformat(), + } diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 6935fb9ac3..a4a8b4b5e6 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -3,6 +3,7 @@ import re import time import uuid +from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING @@ -11,7 +12,6 @@ from astrbot.core import logger from astrbot.core.db.vec_db.base import BaseVecDB from astrbot.core.exceptions import KnowledgeBaseUploadError -from astrbot.core.provider.manager import ProviderManager from astrbot.core.provider.provider import ( EmbeddingProvider, RerankProvider, @@ -20,17 +20,45 @@ Provider as LLMProvider, ) +from .capabilities import ( + DEFAULT_CHUNK_OVERLAP, + DEFAULT_CHUNK_SIZE, + DEFAULT_UPLOAD_BATCH_SIZE, + DEFAULT_UPLOAD_MAX_RETRIES, + DEFAULT_UPLOAD_TASKS_LIMIT, +) from .chunking.base import BaseChunker from .chunking.markdown import MarkdownChunker from .chunking.recursive import RecursiveCharacterChunker +from .document_metadata import ( + DEFAULT_CHUNKER_VERSION, + DEFAULT_PARSER_VERSION, + build_content_hash, + build_stored_source_path, + get_chunker_name, + get_parser_name, +) from .kb_db_sqlite import KBSQLiteDatabase from .models import KBDocument, KBMedia, KnowledgeBase -from .parsers.url_parser import extract_text_from_url +from .parsers.base import TextSegment +from .parsers.url_parser import URLExtractor, extract_text_from_url from .parsers.util import select_parser from .prompts import TEXT_REPAIR_SYSTEM_PROMPT if TYPE_CHECKING: from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB + from astrbot.core.provider.manager import ProviderManager + + +DOCUMENT_REBUILD_PAGE_SIZE = 100 +CONSISTENCY_CHECK_PAGE_SIZE = 1000 +CONSISTENCY_REPAIR_TYPES = frozenset( + { + "orphan_vectors", + "chunk_count_mismatches", + }, +) +NON_PERSISTED_FAILURE_STAGES = frozenset({"deduplication"}) class RateLimiter: @@ -116,6 +144,114 @@ def _compact_chunks(chunks: list[str]) -> list[str]: return [chunk.strip() for chunk in chunks if chunk and chunk.strip()] +def _estimate_text_tokens(text: str) -> int: + chinese_count = sum(1 for char in text if "\u4e00" <= char <= "\u9fff") + other_count = len(text) - chinese_count + return int(chinese_count * 0.6 + other_count * 0.3) + + +def _build_chunk_metadata( + *, + kb_id: str, + doc_id: str, + chunks_text: list[str], + chunk_ids: list[str], + chunk_extra_metadatas: list[dict] | None = None, +) -> list[dict]: + if chunk_extra_metadatas is not None and len(chunk_extra_metadatas) != len( + chunks_text + ): + raise ValueError("chunk_extra_metadatas length must match chunks_text length") + + metadatas = [] + start_offset = 0 + for idx, chunk_text in enumerate(chunks_text): + end_offset = start_offset + len(chunk_text) + metadata = { + "kb_id": kb_id, + "kb_doc_id": doc_id, + "chunk_index": idx, + "section_index": idx, + "content_hash": build_content_hash(chunk_text), + "char_count": len(chunk_text), + "token_count_estimate": _estimate_text_tokens(chunk_text), + "start_offset": start_offset, + "end_offset": end_offset, + "previous_chunk_id": chunk_ids[idx - 1] if idx > 0 else None, + "next_chunk_id": chunk_ids[idx + 1] if idx < len(chunk_ids) - 1 else None, + } + if chunk_extra_metadatas is not None: + metadata.update(chunk_extra_metadatas[idx]) + metadatas.append(metadata) + start_offset = end_offset + return metadatas + + +async def _chunk_text_with_metadata( + *, + chunker: BaseChunker, + text: str, + chunk_size: int, + chunk_overlap: int, + extra_metadata: dict | None = None, +) -> tuple[list[str], list[dict] | None]: + chunks_text = await chunker.chunk( + text, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + chunks_text = _compact_chunks(chunks_text) + if not chunks_text: + return [], [] if extra_metadata is not None else None + if extra_metadata is None: + return chunks_text, None + return chunks_text, [dict(extra_metadata) for _ in chunks_text] + + +async def _chunk_text_segments_with_metadata( + *, + chunker: BaseChunker, + text_segments: list[TextSegment], + chunk_size: int, + chunk_overlap: int, +) -> tuple[list[str], list[dict]]: + chunks_text: list[str] = [] + chunk_extra_metadatas: list[dict] = [] + for segment in text_segments: + segment_text = getattr(segment, "text", "") + segment_metadata = getattr(segment, "metadata", None) or {} + segment_chunks, segment_metadatas = await _chunk_text_with_metadata( + chunker=chunker, + text=segment_text, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + extra_metadata=segment_metadata, + ) + chunks_text.extend(segment_chunks) + chunk_extra_metadatas.extend(segment_metadatas or []) + return chunks_text, chunk_extra_metadatas + + +def _build_duplicate_document_error( + *, + file_name: str, + content_hash: str, + existing_doc: KBDocument, +) -> KnowledgeBaseUploadError: + return KnowledgeBaseUploadError( + stage="deduplication", + user_message=( + f"重复文档:{file_name} 与已存在文档 {existing_doc.doc_name} 内容相同。" + ), + details={ + "file_name": file_name, + "content_hash": content_hash, + "existing_doc_id": existing_doc.doc_id, + "existing_doc_name": existing_doc.doc_name, + }, + ) + + class KBHelper: vec_db: BaseVecDB kb: KnowledgeBase @@ -125,7 +261,7 @@ def __init__( self, kb_db: KBSQLiteDatabase, kb: KnowledgeBase, - provider_manager: ProviderManager, + provider_manager: "ProviderManager", kb_root_dir: str, chunker: BaseChunker, ) -> None: @@ -213,18 +349,162 @@ async def terminate(self) -> None: if hasattr(self, "vec_db") and self.vec_db: await self.vec_db.close() + async def _ensure_not_duplicate_document( + self, + *, + file_name: str, + content_hash: str | None, + ) -> None: + if not content_hash: + return + try: + existing_doc = await self.kb_db.get_document_by_content_hash( + kb_id=self.kb.kb_id, + content_hash=content_hash, + ) + except KnowledgeBaseUploadError: + raise + except Exception as exc: + raise KnowledgeBaseUploadError( + stage="deduplication", + user_message=("重复检测失败:无法确认文档是否已存在,请稍后重试。"), + details={"file_name": file_name, "content_hash": content_hash}, + ) from exc + if existing_doc is not None: + raise _build_duplicate_document_error( + file_name=file_name, + content_hash=content_hash, + existing_doc=existing_doc, + ) + + @staticmethod + def _get_upload_failure_stage(error: Exception) -> str: + if isinstance(error, KnowledgeBaseUploadError): + return error.stage + return "unknown" + + async def _persist_failed_document( + self, + *, + doc_id: str, + file_name: str, + file_type: str, + file_size: int, + stored_file_path: Path | None, + source_type: str, + source_uri: str, + content_hash: str | None, + parser_name: str | None, + chunker_name: str | None, + parent_doc_id: str | None, + document_version: int, + error: Exception, + ) -> bool: + """Persist a failed document record for ingestion diagnostics.""" + error_stage = self._get_upload_failure_stage(error) + if error_stage in NON_PERSISTED_FAILURE_STAGES: + return False + + failed_doc = KBDocument( + doc_id=doc_id, + kb_id=self.kb.kb_id, + doc_name=file_name, + file_type=file_type, + file_size=file_size, + file_path=str(stored_file_path) if stored_file_path else "", + source_type=source_type, + source_uri=source_uri, + content_hash=content_hash, + parser_name=parser_name, + parser_version=DEFAULT_PARSER_VERSION if parser_name else None, + chunker_name=chunker_name, + chunker_version=DEFAULT_CHUNKER_VERSION if chunker_name else None, + status="failed", + error_stage=error_stage, + error_message=str(error).strip() or error.__class__.__name__, + version=document_version, + parent_doc_id=parent_doc_id, + ) + + try: + async with self.kb_db.get_db() as session: + async with session.begin(): + session.add(failed_doc) + await session.commit() + await session.refresh(failed_doc) + except Exception as persist_err: + logger.warning( + f"记录失败文档 {doc_id} 的元数据失败: {persist_err}", + ) + return False + + try: + await self.kb_db.update_kb_stats( + kb_id=self.kb.kb_id, + vec_db=self.vec_db, # type: ignore[arg-type] + ) + await self.refresh_kb() + await self.refresh_document(doc_id) + except Exception as stats_err: + logger.warning( + f"刷新失败文档 {doc_id} 的知识库统计失败: {stats_err}", + ) + return True + + @staticmethod + def _build_url_file_name(url: str) -> str: + file_name = url.split("/")[-1] or f"document_from_{url}" + if not Path(file_name).suffix: + file_name += ".url" + return file_name + + async def _persist_failed_url_document( + self, + *, + url: str, + text_content: str | None, + parent_doc_id: str | None, + document_version: int, + error: Exception, + ) -> bool: + return await self._persist_failed_document( + doc_id=str(uuid.uuid4()), + file_name=self._build_url_file_name(url), + file_type="url", + file_size=len(text_content) if text_content else 0, + stored_file_path=None, + source_type="url", + source_uri=url, + content_hash=( + build_content_hash(text_content) if text_content is not None else None + ), + parser_name=URLExtractor.__name__, + chunker_name=get_chunker_name(self.chunker), + parent_doc_id=parent_doc_id, + document_version=document_version, + error=error, + ) + async def upload_document( self, file_name: str, file_content: bytes | None, file_type: str, - chunk_size: int = 512, - chunk_overlap: int = 50, - batch_size: int = 32, - tasks_limit: int = 3, - max_retries: int = 3, + chunk_size: int = DEFAULT_CHUNK_SIZE, + chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, + batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE, + tasks_limit: int = DEFAULT_UPLOAD_TASKS_LIMIT, + max_retries: int = DEFAULT_UPLOAD_MAX_RETRIES, progress_callback=None, pre_chunked_text: list[str] | None = None, + source_type: str | None = None, + source_uri: str | None = None, + source_content_hash: str | None = None, + source_parser_name: str | None = None, + source_chunker_name: str | None = None, + parent_doc_id: str | None = None, + document_version: int = 1, + skip_duplicate_check: bool = False, ) -> KBDocument: """上传并处理文档(带原子性保证和失败清理) @@ -247,21 +527,37 @@ async def upload_document( await self._ensure_vec_db() doc_id = str(uuid.uuid4()) media_paths: list[Path] = [] + stored_file_path: Path | None = None file_size = 0 vectors_stored = False # 标记向量是否已写入, 用于失败回滚 - - # file_path = self.kb_files_dir / f"{doc_id}.{file_type}" - # async with aiofiles.open(file_path, "wb") as f: - # await f.write(file_content) + metadata_stored = False + failed_metadata_stored = False + effective_source_type = source_type or ( + "import" if pre_chunked_text is not None else "file" + ) + effective_source_uri = source_uri or file_name + content_hash: str | None = source_content_hash + parser_name: str | None = source_parser_name + chunker_name: str | None = source_chunker_name try: chunks_text = [] + chunk_extra_metadatas: list[dict] | None = None saved_media = [] if pre_chunked_text is not None: # 如果提供了预分块文本,直接使用 chunks_text = _compact_chunks(pre_chunked_text) file_size = sum(len(chunk) for chunk in chunks_text) + if content_hash is None: + content_hash = build_content_hash(chunks_text) + if chunker_name is None: + chunker_name = "pre_chunked" + if not skip_duplicate_check: + await self._ensure_not_duplicate_document( + file_name=file_name, + content_hash=content_hash, + ) logger.info(f"使用预分块文本进行上传,共 {len(chunks_text)} 个块。") else: # 否则,执行标准的文件解析和分块流程 @@ -271,6 +567,22 @@ async def upload_document( ) file_size = len(file_content) + content_hash = build_content_hash(file_content) + if not skip_duplicate_check: + await self._ensure_not_duplicate_document( + file_name=file_name, + content_hash=content_hash, + ) + + stored_file_path = build_stored_source_path( + self.kb_files_dir, + doc_id=doc_id, + file_name=file_name, + file_type=file_type, + ) + stored_file_path.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(stored_file_path, "wb") as f: + await f.write(file_content) # 阶段1: 解析文档 if progress_callback: @@ -278,6 +590,7 @@ async def upload_document( try: parser = await select_parser(f".{file_type}") + parser_name = get_parser_name(parser) parse_result = await parser.parse(file_content, file_name) except KnowledgeBaseUploadError: raise @@ -292,6 +605,7 @@ async def upload_document( ) from exc text_content = parse_result.text media_items = parse_result.media + text_segments = getattr(parse_result, "text_segments", None) if not text_content or not text_content.strip(): raise KnowledgeBaseUploadError( stage="parsing", @@ -334,12 +648,46 @@ async def upload_document( f"检测到 Markdown 文件 '{file_name}',使用 MarkdownChunker 进行结构化分块" ) - chunks_text = await effective_chunker.chunk( - text_content, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ) - chunks_text = _compact_chunks(chunks_text) + chunker_name = get_chunker_name(effective_chunker) + if isinstance(effective_chunker, MarkdownChunker): + structured_chunks = await effective_chunker.chunk_with_metadata( + text_content, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + chunks_text = [] + chunk_extra_metadatas = [] + for chunk in structured_chunks: + chunk_text = chunk.text.strip() + if not chunk_text: + continue + chunks_text.append(chunk_text) + chunk_extra_metadatas.append( + { + "title_path": chunk.title_path, + "section_index": chunk.section_index, + } + ) + elif text_segments: + ( + chunks_text, + chunk_extra_metadatas, + ) = await _chunk_text_segments_with_metadata( + chunker=effective_chunker, + text_segments=text_segments, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + else: + ( + chunks_text, + chunk_extra_metadatas, + ) = await _chunk_text_with_metadata( + chunker=effective_chunker, + text=text_content, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) except KnowledgeBaseUploadError: raise except Exception as exc: @@ -369,16 +717,16 @@ async def upload_document( ) contents = [] - metadatas = [] for idx, chunk_text in enumerate(chunks_text): contents.append(chunk_text) - metadatas.append( - { - "kb_id": self.kb.kb_id, - "kb_doc_id": doc_id, - "chunk_index": idx, - }, - ) + chunk_ids = [str(uuid.uuid4()) for _ in chunks_text] + metadatas = _build_chunk_metadata( + kb_id=self.kb.kb_id, + doc_id=doc_id, + chunks_text=chunks_text, + chunk_ids=chunk_ids, + chunk_extra_metadatas=chunk_extra_metadatas, + ) if progress_callback: await progress_callback("chunking", 100, 100) @@ -392,6 +740,7 @@ async def embedding_progress_callback(current, total) -> None: await self.vec_db.insert_batch( contents=contents, metadatas=metadatas, + ids=chunk_ids, batch_size=batch_size, tasks_limit=tasks_limit, max_retries=max_retries, @@ -414,10 +763,20 @@ async def embedding_progress_callback(current, total) -> None: doc_name=file_name, file_type=file_type, file_size=file_size, - # file_path=str(file_path), - file_path="", + file_path=str(stored_file_path) if stored_file_path else "", + source_type=effective_source_type, + source_uri=effective_source_uri, + content_hash=content_hash, + parser_name=parser_name, + parser_version=DEFAULT_PARSER_VERSION if parser_name else None, + chunker_name=chunker_name, + chunker_version=DEFAULT_CHUNKER_VERSION if chunker_name else None, + status="ready", + indexed_at=datetime.now(timezone.utc), + version=document_version, + parent_doc_id=parent_doc_id, chunk_count=len(chunks_text), - media_count=0, + media_count=len(saved_media), ) try: async with self.kb_db.get_db() as session: @@ -426,6 +785,7 @@ async def embedding_progress_callback(current, total) -> None: for media in saved_media: session.add(media) await session.commit() + metadata_stored = True await session.refresh(doc) except KnowledgeBaseUploadError: @@ -462,7 +822,7 @@ async def embedding_progress_callback(current, total) -> None: logger.error(f"上传文档失败: {e}", exc_info=True) # 回滚已写入的向量, 防止孤数据 - if vectors_stored: + if vectors_stored and not metadata_stored: try: vec_db: FaissVecDB = self.vec_db # type: ignore await vec_db.delete_documents( @@ -474,15 +834,43 @@ async def embedding_progress_callback(current, total) -> None: f"清理文档 {doc_id} 向量回滚失败: {cleanup_err}", ) - # if file_path.exists(): - # file_path.unlink() + if not metadata_stored: + failed_metadata_stored = await self._persist_failed_document( + doc_id=doc_id, + file_name=file_name, + file_type=file_type, + file_size=file_size, + stored_file_path=stored_file_path, + source_type=effective_source_type, + source_uri=effective_source_uri, + content_hash=content_hash, + parser_name=parser_name, + chunker_name=chunker_name, + parent_doc_id=parent_doc_id, + document_version=document_version, + error=e, + ) - for media_path in media_paths: + if ( + stored_file_path + and stored_file_path.exists() + and not metadata_stored + and not failed_metadata_stored + ): try: - if media_path.exists(): - media_path.unlink() - except Exception as me: - logger.warning(f"清理多媒体文件失败 {media_path}: {me}") + stored_file_path.unlink() + if stored_file_path.parent != self.kb_files_dir: + stored_file_path.parent.rmdir() + except Exception as fe: + logger.warning(f"清理原始文件失败 {stored_file_path}: {fe}") + + if not metadata_stored: + for media_path in media_paths: + try: + if media_path.exists(): + media_path.unlink() + except Exception as me: + logger.warning(f"清理多媒体文件失败 {media_path}: {me}") raise @@ -491,6 +879,8 @@ async def list_documents( offset: int = 0, limit: int = 100, search: str | None = None, + status: str | None = None, + source_type: str | None = None, ) -> list[KBDocument]: """列出知识库的所有文档""" docs = await self.kb_db.list_documents_by_kb( @@ -498,12 +888,24 @@ async def list_documents( offset, limit, search, + status=status, + source_type=source_type, ) return docs - async def count_documents(self, search: str | None = None) -> int: + async def count_documents( + self, + search: str | None = None, + status: str | None = None, + source_type: str | None = None, + ) -> int: """统计知识库的所有文档数量""" - return await self.kb_db.count_documents_by_kb(self.kb.kb_id, search) + return await self.kb_db.count_documents_by_kb( + self.kb.kb_id, + search, + status=status, + source_type=source_type, + ) async def get_document(self, doc_id: str) -> KBDocument | None: """获取单个文档""" @@ -512,6 +914,10 @@ async def get_document(self, doc_id: str) -> KBDocument | None: async def delete_document(self, doc_id: str) -> None: """删除单个文档及其相关数据""" + doc = await self.get_document(doc_id) + if not doc: + raise ValueError(f"无法找到 ID 为 {doc_id} 的文档") + media_items = await self.kb_db.list_media_by_doc(doc_id) deleted = await self.kb_db.delete_document_by_id( doc_id=doc_id, vec_db=self.vec_db, # type: ignore @@ -519,6 +925,7 @@ async def delete_document(self, doc_id: str) -> None: ) if not deleted: raise ValueError(f"无法找到 ID 为 {doc_id} 的文档") + self._cleanup_document_files(doc, media_items) await self.kb_db.update_kb_stats( kb_id=self.kb.kb_id, vec_db=self.vec_db, # type: ignore @@ -530,11 +937,25 @@ async def delete_documents(self, doc_ids: list[str]) -> dict[str, bool]: vec_db 删除失败不阻塞其他文档(best-effort)。 """ + docs_by_id = { + doc_id: doc + for doc_id in dict.fromkeys(doc_ids) + if (doc := await self.get_document(doc_id)) is not None + } + media_by_doc_id = { + doc_id: await self.kb_db.list_media_by_doc(doc_id) for doc_id in docs_by_id + } results = await self.kb_db.delete_documents_by_ids( doc_ids=doc_ids, vec_db=self.vec_db, # type: ignore kb_id=self.kb.kb_id, ) + for doc_id, deleted in results.items(): + if deleted and doc_id in docs_by_id: + self._cleanup_document_files( + docs_by_id[doc_id], + media_by_doc_id.get(doc_id, []), + ) await self.kb_db.update_kb_stats( kb_id=self.kb.kb_id, vec_db=self.vec_db, # type: ignore @@ -542,6 +963,286 @@ async def delete_documents(self, doc_ids: list[str]) -> dict[str, bool]: await self.refresh_kb() return results + async def rebuild_document( + self, + doc_id: str, + *, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE, + tasks_limit: int = DEFAULT_UPLOAD_TASKS_LIMIT, + max_retries: int = DEFAULT_UPLOAD_MAX_RETRIES, + progress_callback=None, + ) -> KBDocument: + doc = await self.get_document(doc_id) + if not doc: + raise ValueError(f"无法找到 ID 为 {doc_id} 的文档") + next_version = (doc.version or 1) + 1 + parent_doc_id = doc.parent_doc_id or doc.doc_id + effective_chunk_size = ( + chunk_size + if chunk_size is not None + else self.kb.chunk_size or DEFAULT_CHUNK_SIZE + ) + effective_chunk_overlap = ( + chunk_overlap + if chunk_overlap is not None + else self.kb.chunk_overlap or DEFAULT_CHUNK_OVERLAP + ) + + if doc.source_type == "file" and doc.file_path: + source_path = Path(doc.file_path).resolve(strict=False) + files_root = self.kb_files_dir.resolve(strict=False) + if not source_path.is_relative_to(files_root) or not source_path.exists(): + raise ValueError("无法找到可用于重建的原始文件") + + rebuilt_doc = await self.upload_document( + file_name=doc.doc_name, + file_content=source_path.read_bytes(), + file_type=doc.file_type, + chunk_size=effective_chunk_size, + chunk_overlap=effective_chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + source_type=doc.source_type, + source_uri=doc.source_uri or doc.doc_name, + parent_doc_id=parent_doc_id, + document_version=next_version, + skip_duplicate_check=True, + ) + elif doc.source_type == "url": + if not doc.source_uri: + raise ValueError("无法找到可用于重建的 URL 来源") + rebuilt_doc = await self.upload_from_url( + url=doc.source_uri, + chunk_size=effective_chunk_size, + chunk_overlap=effective_chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + parent_doc_id=parent_doc_id, + document_version=next_version, + skip_duplicate_check=True, + ) + elif doc.source_type == "import": + imported_chunks = await self._get_import_rebuild_chunks(doc.doc_id) + if not imported_chunks: + raise ValueError("无法找到可用于重建的导入文本块") + rebuilt_doc = await self.upload_document( + file_name=doc.doc_name, + file_content=None, + file_type=doc.file_type, + chunk_size=effective_chunk_size, + chunk_overlap=effective_chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + pre_chunked_text=imported_chunks, + source_type="import", + source_uri=doc.source_uri or doc.doc_name, + source_content_hash=build_content_hash(imported_chunks), + source_chunker_name=doc.chunker_name or "pre_chunked", + parent_doc_id=parent_doc_id, + document_version=next_version, + skip_duplicate_check=True, + ) + else: + raise ValueError("当前仅支持重建已保存原始文件、URL 或导入来源的文档") + + try: + await self.delete_document(doc_id) + except Exception as exc: + try: + await self.delete_document(rebuilt_doc.doc_id) + except Exception as cleanup_exc: + logger.error( + f"重建文档 {doc_id} 后清理新版本失败: {cleanup_exc}", + ) + raise KnowledgeBaseUploadError( + stage="rebuild", + user_message=( + "重建失败:新版本已生成,但替换旧文档时失败,已尝试回滚新版本。" + ), + details={ + "doc_id": doc_id, + "new_doc_id": rebuilt_doc.doc_id, + }, + ) from exc + return rebuilt_doc + + async def _get_import_rebuild_chunks(self, doc_id: str) -> list[str]: + chunks: list[dict] = [] + offset = 0 + while True: + page = await self.get_chunks_by_doc_id( + doc_id, + offset=offset, + limit=DOCUMENT_REBUILD_PAGE_SIZE, + ) + if not page: + break + chunks.extend(page) + if len(page) < DOCUMENT_REBUILD_PAGE_SIZE: + break + offset += DOCUMENT_REBUILD_PAGE_SIZE + + chunks.sort(key=lambda chunk: int(chunk.get("chunk_index") or 0)) + return [ + chunk["content"] + for chunk in chunks + if isinstance(chunk.get("content"), str) and chunk["content"].strip() + ] + + async def rebuild_all_documents( + self, + *, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE, + tasks_limit: int = DEFAULT_UPLOAD_TASKS_LIMIT, + max_retries: int = DEFAULT_UPLOAD_MAX_RETRIES, + progress_callback=None, + ) -> dict: + docs: list[KBDocument] = [] + offset = 0 + while True: + page = await self.list_documents( + offset=offset, + limit=DOCUMENT_REBUILD_PAGE_SIZE, + ) + docs.extend(page) + if len(page) < DOCUMENT_REBUILD_PAGE_SIZE: + break + offset += DOCUMENT_REBUILD_PAGE_SIZE + + rebuilt_docs = [] + failed_docs = [] + + total = len(docs) + for index, doc in enumerate(docs, start=1): + if progress_callback: + await progress_callback("rebuilding", index - 1, total) + try: + rebuilt = await self.rebuild_document( + doc.doc_id, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + ) + rebuilt_docs.append(rebuilt.model_dump()) + except Exception as e: + logger.error(f"重建文档 {doc.doc_id} 失败: {e}") + failed_docs.append( + { + "doc_id": doc.doc_id, + "doc_name": doc.doc_name, + "error": str(e), + }, + ) + + if progress_callback: + await progress_callback("rebuilding", total, total) + + return { + "rebuilt": rebuilt_docs, + "failed": failed_docs, + "total": total, + "success_count": len(rebuilt_docs), + "failed_count": len(failed_docs), + } + + async def rebuild_documents( + self, + doc_ids: list[str], + *, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE, + tasks_limit: int = DEFAULT_UPLOAD_TASKS_LIMIT, + max_retries: int = DEFAULT_UPLOAD_MAX_RETRIES, + progress_callback=None, + ) -> dict: + rebuilt_docs = [] + failed_docs = [] + normalized_doc_ids = list(dict.fromkeys(doc_ids)) + + total = len(normalized_doc_ids) + for index, doc_id in enumerate(normalized_doc_ids, start=1): + if progress_callback: + await progress_callback("rebuilding", index - 1, total) + try: + rebuilt = await self.rebuild_document( + doc_id, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + ) + rebuilt_docs.append(rebuilt.model_dump()) + except Exception as e: + logger.error(f"重建文档 {doc_id} 失败: {e}") + failed_doc = await self.get_document(doc_id) + failed_docs.append( + { + "doc_id": doc_id, + "doc_name": failed_doc.doc_name if failed_doc else doc_id, + "error": str(e), + }, + ) + + if progress_callback: + await progress_callback("rebuilding", total, total) + + return { + "rebuilt": rebuilt_docs, + "failed": failed_docs, + "total": total, + "success_count": len(rebuilt_docs), + "failed_count": len(failed_docs), + } + + def _cleanup_document_files( + self, + doc: KBDocument, + media_items: list[KBMedia], + ) -> None: + file_paths: list[Path] = [] + if doc.file_path: + file_paths.append(Path(doc.file_path)) + file_paths.extend(Path(media.file_path) for media in media_items) + + cleanup_roots = ( + self.kb_files_dir.resolve(strict=False), + self.kb_medias_dir.resolve(strict=False), + ) + for file_path in file_paths: + resolved_path = file_path.resolve(strict=False) + if not any(resolved_path.is_relative_to(root) for root in cleanup_roots): + logger.warning( + f"跳过清理知识库目录外文件: {resolved_path}", + ) + continue + try: + if resolved_path.exists(): + resolved_path.unlink() + parent = resolved_path.parent + if any(parent.is_relative_to(root) for root in cleanup_roots): + try: + parent.rmdir() + except OSError: + pass + except Exception as e: + logger.warning(f"清理知识库文件失败 {resolved_path}: {e}") + async def delete_chunk(self, chunk_id: str, doc_id: str) -> None: """删除单个文本块及其相关数据""" vec_db: FaissVecDB = self.vec_db # type: ignore @@ -587,20 +1288,102 @@ async def get_chunks_by_doc_id( offset=offset, limit=limit, ) - result = [] - for chunk in chunks: - chunk_md = json.loads(chunk["metadata"]) - result.append( - { - "chunk_id": chunk["doc_id"], - "doc_id": chunk_md["kb_doc_id"], - "kb_id": chunk_md["kb_id"], - "chunk_index": chunk_md["chunk_index"], - "content": chunk["text"], - "char_count": len(chunk["text"]), - }, + return [self._format_chunk_response(chunk) for chunk in chunks] + + async def search_chunks_by_doc_id( + self, + doc_id: str, + search: str | None = None, + offset: int = 0, + limit: int = 100, + ) -> tuple[list[dict], int]: + """Search or list chunks for one document with a matching total.""" + if not search: + chunks = await self.get_chunks_by_doc_id( + doc_id=doc_id, + offset=offset, + limit=limit, ) - return result + return chunks, await self.get_chunk_count_by_doc_id(doc_id) + + vec_db: FaissVecDB = self.vec_db # type: ignore + search_documents = getattr(vec_db.document_storage, "search_documents", None) + if search_documents is None: + return [], 0 + + result = await search_documents( + search, + metadata_filters={"kb_doc_id": doc_id}, + offset=offset, + limit=limit, + ) + if result is None: + return [], 0 + chunks, total = result + return [self._format_chunk_response(chunk) for chunk in chunks], total + + @staticmethod + def _format_chunk_response(chunk: dict) -> dict: + chunk_md = json.loads(chunk["metadata"]) + char_count = chunk_md.get("char_count", len(chunk["text"])) + return { + "chunk_id": chunk["doc_id"], + "doc_id": chunk_md["kb_doc_id"], + "kb_id": chunk_md["kb_id"], + "chunk_index": chunk_md["chunk_index"], + "section_index": chunk_md.get("section_index"), + "content": chunk["text"], + "char_count": char_count, + "token_count_estimate": chunk_md.get("token_count_estimate"), + "content_hash": chunk_md.get("content_hash"), + "start_offset": chunk_md.get("start_offset"), + "end_offset": chunk_md.get("end_offset"), + "previous_chunk_id": chunk_md.get("previous_chunk_id"), + "next_chunk_id": chunk_md.get("next_chunk_id"), + "title_path": chunk_md.get("title_path"), + "page_number": chunk_md.get("page_number"), + "parent_chunk_id": chunk_md.get("parent_chunk_id"), + } + + async def get_chunk_by_id( + self, + chunk_id: str, + doc_id: str | None = None, + ) -> dict | None: + """获取单个文本块及其元数据""" + vec_db: FaissVecDB = self.vec_db # type: ignore + chunk = await vec_db.document_storage.get_document_by_doc_id(chunk_id) + if not chunk: + return None + formatted_chunk = self._format_chunk_response(chunk) + if doc_id and formatted_chunk["doc_id"] != doc_id: + return None + return formatted_chunk + + async def get_chunk_context(self, chunk_id: str, doc_id: str) -> dict: + """获取文本块和相邻上下文块""" + current = await self.get_chunk_by_id(chunk_id, doc_id) + if not current: + raise ValueError(f"无法找到 ID 为 {chunk_id} 的文本块") + + previous_chunk = None + next_chunk = None + if current.get("previous_chunk_id"): + previous_chunk = await self.get_chunk_by_id( + current["previous_chunk_id"], + doc_id, + ) + if current.get("next_chunk_id"): + next_chunk = await self.get_chunk_by_id( + current["next_chunk_id"], + doc_id, + ) + + return { + "previous": previous_chunk, + "current": current, + "next": next_chunk, + } async def get_chunk_count_by_doc_id(self, doc_id: str) -> int: """获取文档的块数量""" @@ -608,6 +1391,434 @@ async def get_chunk_count_by_doc_id(self, doc_id: str) -> int: count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id}) return count + async def check_consistency(self) -> dict: + """Return a read-only consistency report for document metadata and chunks.""" + docs = await self._list_all_documents_for_consistency() + doc_by_id = {doc.doc_id: doc for doc in docs} + stored_chunks = await self._list_all_chunks_for_consistency() + + chunks_by_doc_id: dict[str, list[dict]] = {} + orphan_vectors: list[dict] = [] + invalid_vector_metadata: list[dict] = [] + + for chunk in stored_chunks: + try: + metadata = self._parse_stored_chunk_metadata(chunk) + except ValueError as exc: + invalid_vector_metadata.append( + self._format_vector_issue(chunk, metadata_error=str(exc)), + ) + continue + + doc_id = metadata.get("kb_doc_id") + if not isinstance(doc_id, str) or not doc_id: + invalid_vector_metadata.append( + self._format_vector_issue( + chunk, + metadata=metadata, + metadata_error="missing kb_doc_id", + ), + ) + continue + + if doc_id not in doc_by_id: + orphan_vectors.append( + self._format_vector_issue(chunk, metadata=metadata), + ) + continue + + chunks_by_doc_id.setdefault(doc_id, []).append(chunk) + + missing_vectors: list[dict] = [] + chunk_count_mismatches: list[dict] = [] + for doc in docs: + expected_chunk_count = int(doc.chunk_count or 0) + actual_chunk_count = len(chunks_by_doc_id.get(doc.doc_id, [])) + if expected_chunk_count > 0 and actual_chunk_count == 0: + missing_vectors.append( + self._format_document_issue( + doc, + expected_chunk_count=expected_chunk_count, + actual_chunk_count=actual_chunk_count, + ), + ) + if expected_chunk_count != actual_chunk_count: + chunk_count_mismatches.append( + self._format_document_issue( + doc, + expected_chunk_count=expected_chunk_count, + actual_chunk_count=actual_chunk_count, + ), + ) + + missing_source_files, unsafe_source_paths, source_file_count = ( + self._check_source_file_consistency(docs) + ) + + status_counts: dict[str, int] = {} + for doc in docs: + status = doc.status or "unknown" + status_counts[status] = status_counts.get(status, 0) + 1 + + issues = { + "missing_vectors": missing_vectors, + "orphan_vectors": orphan_vectors, + "missing_source_files": missing_source_files, + "chunk_count_mismatches": chunk_count_mismatches, + "invalid_vector_metadata": invalid_vector_metadata, + "unsafe_source_paths": unsafe_source_paths, + } + issue_counts = {name: len(items) for name, items in issues.items()} + + return { + "kb_id": self.kb.kb_id, + "kb_name": self.kb.kb_name, + "checked_at": datetime.now(timezone.utc).isoformat(), + "summary": { + "sqlite_document_count": len(docs), + "ready_document_count": status_counts.get("ready", 0), + "failed_document_count": status_counts.get("failed", 0), + "document_chunk_count": sum(int(doc.chunk_count or 0) for doc in docs), + "indexed_chunk_count": len(stored_chunks), + "source_file_count": source_file_count, + "status_counts": status_counts, + **issue_counts, + "healthy": all(count == 0 for count in issue_counts.values()), + }, + "issues": issues, + } + + async def repair_consistency( + self, + repair_types: list[str] | None = None, + ) -> dict: + """Repair low-risk consistency issues and report skipped unsafe issues.""" + selected_repair_types = self._normalize_consistency_repair_types(repair_types) + pre_check = await self.check_consistency() + + repaired: list[dict] = [] + skipped: list[dict] = [] + failed: list[dict] = [] + + if "orphan_vectors" in selected_repair_types: + orphan_vectors = pre_check["issues"].get("orphan_vectors", []) + orphan_doc_ids = sorted( + { + issue.get("doc_id") + for issue in orphan_vectors + if isinstance(issue.get("doc_id"), str) and issue.get("doc_id") + }, + ) + for doc_id in orphan_doc_ids: + issue_count = sum( + 1 for issue in orphan_vectors if issue.get("doc_id") == doc_id + ) + try: + await self.vec_db.delete_documents( # type: ignore[attr-defined] + metadata_filters={ + "kb_id": self.kb.kb_id, + "kb_doc_id": doc_id, + }, + ) + repaired.append( + { + "type": "orphan_vectors", + "doc_id": doc_id, + "count": issue_count, + "action": "deleted_vectors", + }, + ) + except Exception as exc: + failed.append( + { + "type": "orphan_vectors", + "doc_id": doc_id, + "count": issue_count, + "action": "delete_vectors", + "error": str(exc), + }, + ) + + if "chunk_count_mismatches" in selected_repair_types: + for issue in pre_check["issues"].get("chunk_count_mismatches", []): + doc_id = issue.get("doc_id") + expected_count = int(issue.get("expected_chunk_count") or 0) + actual_count = int(issue.get("actual_chunk_count") or 0) + if not isinstance(doc_id, str) or not doc_id: + skipped.append( + { + "type": "chunk_count_mismatches", + "reason": "missing_doc_id", + "issue": issue, + }, + ) + continue + + if expected_count > actual_count: + skipped.append( + { + "type": "chunk_count_mismatches", + "doc_id": doc_id, + "reason": "missing_vectors_require_rebuild", + "expected_chunk_count": expected_count, + "actual_chunk_count": actual_count, + }, + ) + continue + + try: + await self.refresh_document(doc_id) + repaired.append( + { + "type": "chunk_count_mismatches", + "doc_id": doc_id, + "action": "refreshed_document_chunk_count", + "expected_chunk_count": expected_count, + "actual_chunk_count": actual_count, + }, + ) + except Exception as exc: + failed.append( + { + "type": "chunk_count_mismatches", + "doc_id": doc_id, + "action": "refresh_document", + "expected_chunk_count": expected_count, + "actual_chunk_count": actual_count, + "error": str(exc), + }, + ) + + for issue_type in ( + "missing_vectors", + "missing_source_files", + "invalid_vector_metadata", + "unsafe_source_paths", + ): + for issue in pre_check["issues"].get(issue_type, []): + skipped.append( + { + "type": issue_type, + "doc_id": issue.get("doc_id"), + "chunk_id": issue.get("chunk_id"), + "reason": self._get_consistency_repair_skip_reason( + issue_type, + ), + "issue": issue, + }, + ) + + if repaired or failed: + await self.kb_db.update_kb_stats( + kb_id=self.kb.kb_id, + vec_db=self.vec_db, # type: ignore + ) + await self.refresh_kb() + + post_check = await self.check_consistency() + return { + "kb_id": self.kb.kb_id, + "kb_name": self.kb.kb_name, + "repaired_at": datetime.now(timezone.utc).isoformat(), + "repair_types": selected_repair_types, + "summary": { + "repaired_count": len(repaired), + "skipped_count": len(skipped), + "failed_count": len(failed), + "healthy_after_repair": post_check["summary"]["healthy"], + }, + "actions": { + "repaired": repaired, + "skipped": skipped, + "failed": failed, + }, + "pre_check": pre_check, + "post_check": post_check, + } + + @staticmethod + def _normalize_consistency_repair_types( + repair_types: list[str] | None, + ) -> list[str]: + if repair_types is None: + return sorted(CONSISTENCY_REPAIR_TYPES) + + normalized = list( + dict.fromkeys( + repair_type.strip() + for repair_type in repair_types + if isinstance(repair_type, str) and repair_type.strip() + ), + ) + invalid_types = sorted(set(normalized) - CONSISTENCY_REPAIR_TYPES) + if invalid_types: + raise ValueError( + f"不支持的一致性修复类型: {', '.join(invalid_types)}", + ) + return normalized + + @staticmethod + def _get_consistency_repair_skip_reason(issue_type: str) -> str: + skip_reasons = { + "missing_vectors": "document_rebuild_required", + "missing_source_files": "source_file_missing_manual_action_required", + "invalid_vector_metadata": "invalid_metadata_manual_action_required", + "unsafe_source_paths": "unsafe_source_path_manual_action_required", + } + return skip_reasons.get(issue_type, "manual_action_required") + + async def _list_all_documents_for_consistency(self) -> list[KBDocument]: + return await self._collect_paginated_documents( + page_size=CONSISTENCY_CHECK_PAGE_SIZE, + ) + + async def _list_all_chunks_for_consistency(self) -> list[dict]: + return await self._collect_paginated_vector_documents( + page_size=CONSISTENCY_CHECK_PAGE_SIZE, + unsupported_message="当前知识库存储后端不支持一致性检查", + ) + + @staticmethod + def _parse_stored_chunk_metadata(chunk: dict) -> dict: + raw_metadata = chunk.get("metadata") + if raw_metadata is None: + return {} + if isinstance(raw_metadata, dict): + return raw_metadata + try: + metadata = json.loads(raw_metadata) + except (TypeError, json.JSONDecodeError) as exc: + raise ValueError("invalid metadata JSON") from exc + if not isinstance(metadata, dict): + raise ValueError("metadata must be a JSON object") + return metadata + + @staticmethod + def _format_vector_issue( + chunk: dict, + *, + metadata: dict | None = None, + metadata_error: str | None = None, + ) -> dict: + issue = { + "chunk_id": chunk.get("doc_id"), + "storage_id": chunk.get("id"), + } + if metadata: + issue.update( + { + "doc_id": metadata.get("kb_doc_id"), + "kb_id": metadata.get("kb_id"), + "chunk_index": metadata.get("chunk_index"), + }, + ) + if metadata_error: + issue["metadata_error"] = metadata_error + return issue + + @staticmethod + def _format_document_issue( + doc: KBDocument, + *, + expected_chunk_count: int | None = None, + actual_chunk_count: int | None = None, + reason: str | None = None, + ) -> dict: + issue = { + "doc_id": doc.doc_id, + "doc_name": doc.doc_name, + "status": doc.status, + "source_type": doc.source_type, + "file_path": doc.file_path, + } + if expected_chunk_count is not None: + issue["expected_chunk_count"] = expected_chunk_count + if actual_chunk_count is not None: + issue["actual_chunk_count"] = actual_chunk_count + if reason: + issue["reason"] = reason + return issue + + def _check_source_file_consistency( + self, + docs: list[KBDocument], + ) -> tuple[list[dict], list[dict], int]: + missing_source_files: list[dict] = [] + unsafe_source_paths: list[dict] = [] + source_file_count = 0 + files_root = self.kb_files_dir.resolve(strict=False) + + for doc in docs: + if doc.source_type != "file": + continue + + if not doc.file_path: + if doc.status == "ready": + missing_source_files.append( + self._format_document_issue(doc, reason="empty_file_path"), + ) + continue + + file_path = Path(doc.file_path).resolve(strict=False) + if not file_path.is_relative_to(files_root): + unsafe_source_paths.append( + self._format_document_issue( + doc, + reason="outside_kb_files_dir", + ), + ) + continue + if file_path.exists(): + source_file_count += 1 + else: + missing_source_files.append( + self._format_document_issue(doc, reason="not_found"), + ) + + return missing_source_files, unsafe_source_paths, source_file_count + + async def _collect_paginated_documents(self, *, page_size: int) -> list[KBDocument]: + docs: list[KBDocument] = [] + offset = 0 + while True: + page = await self.list_documents( + offset=offset, + limit=page_size, + ) + docs.extend(page) + if len(page) < page_size: + break + offset += page_size + return docs + + async def _collect_paginated_vector_documents( + self, + *, + page_size: int, + unsupported_message: str, + ) -> list[dict]: + document_storage = getattr(self.vec_db, "document_storage", None) + get_documents = getattr(document_storage, "get_documents", None) + if get_documents is None: + raise ValueError(unsupported_message) + + chunks: list[dict] = [] + offset = 0 + while True: + page_result = get_documents( + metadata_filters={"kb_id": self.kb.kb_id}, + offset=offset, + limit=page_size, + ) + if not hasattr(page_result, "__await__"): + raise ValueError(unsupported_message) + page = await page_result + chunks.extend(page) + if len(page) < page_size: + break + offset += page_size + return chunks + async def _save_media( self, doc_id: str, @@ -642,14 +1853,17 @@ async def _save_media( async def upload_from_url( self, url: str, - chunk_size: int = 512, - chunk_overlap: int = 50, - batch_size: int = 32, - tasks_limit: int = 3, - max_retries: int = 3, + chunk_size: int = DEFAULT_CHUNK_SIZE, + chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, + batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE, + tasks_limit: int = DEFAULT_UPLOAD_TASKS_LIMIT, + max_retries: int = DEFAULT_UPLOAD_MAX_RETRIES, progress_callback=None, enable_cleaning: bool = False, cleaning_provider_id: str | None = None, + parent_doc_id: str | None = None, + document_version: int = 1, + skip_duplicate_check: bool = False, ) -> KBDocument: """从 URL 上传并处理文档(带原子性保证和失败清理) Args: @@ -669,52 +1883,100 @@ async def upload_from_url( ValueError: 如果 URL 为空或无法提取内容 IOError: 如果网络请求失败 """ - # 获取 Tavily API 密钥 - config = self.prov_mgr.acm.default_conf - tavily_keys = config.get("provider_settings", {}).get( - "websearch_tavily_key", [] - ) - if not tavily_keys: - raise ValueError( - "Error: Tavily API key is not configured in provider_settings." + text_content: str | None = None + try: + # 获取 Tavily API 密钥 + config = self.prov_mgr.acm.default_conf + tavily_keys = config.get("provider_settings", {}).get( + "websearch_tavily_key", [] ) + if not tavily_keys: + raise KnowledgeBaseUploadError( + stage="configuration", + user_message=( + "URL 导入失败:Tavily API key 未配置。" + "请先在 provider_settings 中配置 websearch_tavily_key。" + ), + details={"url": url}, + ) - # 阶段1: 从 URL 提取内容 - if progress_callback: - await progress_callback("extracting", 0, 100) + # 阶段1: 从 URL 提取内容 + if progress_callback: + await progress_callback("extracting", 0, 100) - try: - text_content = await extract_text_from_url(url, tavily_keys) - except Exception as e: - logger.error(f"Failed to extract content from URL {url}: {e}") - raise OSError(f"Failed to extract content from URL {url}: {e}") from e + try: + text_content = await extract_text_from_url(url, tavily_keys) + except KnowledgeBaseUploadError: + raise + except Exception as e: + logger.error(f"Failed to extract content from URL {url}: {e}") + raise KnowledgeBaseUploadError( + stage="extracting", + user_message=( + "URL 导入失败:无法提取网页内容。" + "请确认 URL 可访问且 Tavily 配置有效。" + ), + details={"url": url}, + ) from e - if not text_content: - raise ValueError(f"No content extracted from URL: {url}") + if not text_content or not text_content.strip(): + raise KnowledgeBaseUploadError( + stage="extracting", + user_message=( + "URL 导入失败:未能从网页中提取可索引文本。" + "请确认页面存在正文内容,或尝试更换 URL。" + ), + details={"url": url}, + ) - if progress_callback: - await progress_callback("extracting", 100, 100) + if progress_callback: + await progress_callback("extracting", 100, 100) - # 阶段2: (可选)清洗内容并分块 - final_chunks = await self._clean_and_rechunk_content( - content=text_content, - url=url, - progress_callback=progress_callback, - enable_cleaning=enable_cleaning, - cleaning_provider_id=cleaning_provider_id, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ) + # 阶段2: (可选)清洗内容并分块 + try: + final_chunks = await self._clean_and_rechunk_content( + content=text_content, + url=url, + progress_callback=progress_callback, + enable_cleaning=enable_cleaning, + cleaning_provider_id=cleaning_provider_id, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + except KnowledgeBaseUploadError: + raise + except Exception as e: + stage = "cleaning" if enable_cleaning else "chunking" + raise KnowledgeBaseUploadError( + stage=stage, + user_message=( + "URL 导入失败:网页内容切分失败。" + "请稍后重试,或调整分块参数后再次导入。" + ), + details={"url": url}, + ) from e - if enable_cleaning and not final_chunks: - raise ValueError( - "内容清洗后未提取到有效文本。请尝试关闭内容清洗功能,或更换更高性能的LLM模型后重试。" + if enable_cleaning and not final_chunks: + raise KnowledgeBaseUploadError( + stage="cleaning", + user_message=( + "URL 导入失败:内容清洗后未提取到有效文本。" + "请尝试关闭内容清洗功能,或更换更高性能的 LLM 模型后重试。" + ), + details={"url": url}, + ) + except Exception as e: + await self._persist_failed_url_document( + url=url, + text_content=text_content, + parent_doc_id=parent_doc_id, + document_version=document_version, + error=e, ) + raise # 创建一个虚拟文件名 - file_name = url.split("/")[-1] or f"document_from_{url}" - if not Path(file_name).suffix: - file_name += ".url" + file_name = self._build_url_file_name(url) # 复用现有的 upload_document 方法,但传入预分块文本 return await self.upload_document( @@ -728,6 +1990,14 @@ async def upload_from_url( max_retries=max_retries, progress_callback=progress_callback, pre_chunked_text=final_chunks, + source_type="url", + source_uri=url, + source_content_hash=build_content_hash(text_content), + source_parser_name=URLExtractor.__name__, + source_chunker_name=get_chunker_name(self.chunker), + parent_doc_id=parent_doc_id, + document_version=document_version, + skip_duplicate_check=skip_duplicate_check, ) async def _clean_and_rechunk_content( @@ -738,8 +2008,8 @@ async def _clean_and_rechunk_content( enable_cleaning: bool = False, cleaning_provider_id: str | None = None, repair_max_rpm: int = 60, - chunk_size: int = 512, - chunk_overlap: int = 50, + chunk_size: int = DEFAULT_CHUNK_SIZE, + chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, ) -> list[str]: """ 对从 URL 获取的内容进行清洗、修复、翻译和重新分块。 diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index d24b452e27..dc1dab016e 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -1,23 +1,41 @@ import asyncio import time from pathlib import Path +from typing import TYPE_CHECKING from sqlalchemy import delete from sqlmodel import col from astrbot.core import logger -from astrbot.core.provider.manager import ProviderManager from astrbot.core.utils.astrbot_path import get_astrbot_knowledge_base_path # from .chunking.fixed_size import FixedSizeChunker +from .capabilities import ( + DEFAULT_CHUNK_OVERLAP, + DEFAULT_CHUNK_SIZE, + DEFAULT_INDEX_TYPE, + DEFAULT_TOP_K_DENSE, + DEFAULT_TOP_K_SPARSE, + DEFAULT_TOP_M_FINAL, + DEFAULT_UPLOAD_BATCH_SIZE, + DEFAULT_UPLOAD_MAX_RETRIES, + DEFAULT_UPLOAD_TASKS_LIMIT, +) from .chunking.recursive import RecursiveCharacterChunker from .kb_db_sqlite import KBSQLiteDatabase from .kb_helper import KBHelper -from .models import KBDocument, KBMedia, KnowledgeBase +from .models import ( + KBDocument, + KBMedia, + KnowledgeBase, +) from .retrieval.manager import RetrievalManager, RetrievalResult from .retrieval.rank_fusion import RankFusion from .retrieval.sparse_retriever import SparseRetriever +if TYPE_CHECKING: + from astrbot.core.provider.manager import ProviderManager + FILES_PATH = get_astrbot_knowledge_base_path() DB_PATH = Path(FILES_PATH) / "kb.db" """Knowledge Base storage root directory""" @@ -65,7 +83,7 @@ class KnowledgeBaseManager: def __init__( self, - provider_manager: ProviderManager, + provider_manager: "ProviderManager", ) -> None: DB_PATH.parent.mkdir(parents=True, exist_ok=True) self.provider_manager = provider_manager @@ -214,12 +232,24 @@ async def create_kb( """创建新的知识库实例""" if embedding_provider_id is None: raise ValueError("创建知识库时必须提供embedding_provider_id") - effective_chunk_size = chunk_size if chunk_size is not None else 512 - effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else 50 - effective_top_k_dense = top_k_dense if top_k_dense is not None else 50 - effective_top_k_sparse = top_k_sparse if top_k_sparse is not None else 50 - effective_top_m_final = top_m_final if top_m_final is not None else 5 - effective_index_type = index_type if index_type is not None else "flat" + effective_chunk_size = ( + chunk_size if chunk_size is not None else DEFAULT_CHUNK_SIZE + ) + effective_chunk_overlap = ( + chunk_overlap if chunk_overlap is not None else DEFAULT_CHUNK_OVERLAP + ) + effective_top_k_dense = ( + top_k_dense if top_k_dense is not None else DEFAULT_TOP_K_DENSE + ) + effective_top_k_sparse = ( + top_k_sparse if top_k_sparse is not None else DEFAULT_TOP_K_SPARSE + ) + effective_top_m_final = ( + top_m_final if top_m_final is not None else DEFAULT_TOP_M_FINAL + ) + effective_index_type = ( + index_type if index_type is not None else DEFAULT_INDEX_TYPE + ) _validate_kb_options( chunk_size=effective_chunk_size, chunk_overlap=effective_chunk_overlap, @@ -451,7 +481,9 @@ async def retrieve( kb_names: list[str] | None = None, kb_ids: list[str] | None = None, top_k_fusion: int = 20, - top_m_final: int = 5, + top_m_final: int = DEFAULT_TOP_M_FINAL, + include_trace: bool = False, + retrieval_overrides: dict | None = None, ) -> dict | None: """从指定知识库中检索相关内容""" resolved_kb_ids = [] @@ -488,15 +520,42 @@ async def retrieve( if not resolved_kb_ids: return {} - results = await self.retrieval_manager.retrieve( - query=query, - kb_ids=resolved_kb_ids, - kb_id_helper_map=kb_id_helper_map, - top_k_fusion=top_k_fusion, - top_m_final=top_m_final, - ) + trace_payload = None + if include_trace: + retrieval_response = await self.retrieval_manager.retrieve_with_trace( + query=query, + kb_ids=resolved_kb_ids, + kb_id_helper_map=kb_id_helper_map, + top_k_fusion=top_k_fusion, + top_m_final=top_m_final, + retrieval_overrides=retrieval_overrides, + ) + results = retrieval_response.results + trace_payload = retrieval_response.trace.to_dict() + else: + results = await self.retrieval_manager.retrieve( + query=query, + kb_ids=resolved_kb_ids, + kb_id_helper_map=kb_id_helper_map, + top_k_fusion=top_k_fusion, + top_m_final=top_m_final, + retrieval_overrides=retrieval_overrides, + ) if not results: - return None + empty_response = { + "context_text": "", + "results": [], + } + if include_trace: + empty_response["trace"] = trace_payload or { + "dense": [], + "sparse": [], + "fusion": [], + "dedup": [], + "rerank": [], + "final": [], + } + return empty_response if include_trace else None context_text = self._format_context(results) @@ -508,6 +567,7 @@ async def retrieve( "kb_name": r.kb_name, "doc_name": r.doc_name, "chunk_index": r.metadata.get("chunk_index", 0), + "source": self._format_result_source(r), "content": r.content, "score": r.score, "char_count": r.metadata.get("char_count", 0), @@ -515,10 +575,40 @@ async def retrieve( for r in results ] - return { + response = { "context_text": context_text, "results": results_dict, } + if include_trace: + response["trace"] = trace_payload + return response + + def _format_result_source(self, result: RetrievalResult) -> dict: + return { + "kb_name": result.kb_name, + "document_name": result.doc_name, + "chunk_index": result.metadata.get("chunk_index", 0), + "section_index": result.metadata.get("section_index"), + "title_path": result.metadata.get("title_path"), + "page_number": result.metadata.get("page_number"), + "parent_chunk_id": result.metadata.get("parent_chunk_id"), + } + + def _format_source_label(self, result: RetrievalResult) -> str: + source = self._format_result_source(result) + details = [] + title_path = source.get("title_path") + if isinstance(title_path, list) and title_path: + details.append(" > ".join(str(title) for title in title_path)) + if source.get("page_number") is not None: + details.append(f"第 {source['page_number']} 页") + if source.get("section_index") is not None: + details.append(f"章节 {source['section_index']}") + + base = f"{result.kb_name} / {result.doc_name}" + if details: + return f"{base} ({'; '.join(details)})" + return base def _format_context(self, results: list[RetrievalResult]) -> str: """格式化知识上下文 @@ -534,7 +624,7 @@ def _format_context(self, results: list[RetrievalResult]) -> str: for i, result in enumerate(results, 1): lines.append(f"【知识 {i}】") - lines.append(f"来源: {result.kb_name} / {result.doc_name}") + lines.append(f"来源: {self._format_source_label(result)}") lines.append(f"内容: {result.content}") lines.append(f"相关度: {result.score:.2f}") lines.append("") @@ -562,11 +652,11 @@ async def upload_from_url( self, kb_id: str, url: str, - chunk_size: int = 512, - chunk_overlap: int = 50, - batch_size: int = 32, - tasks_limit: int = 3, - max_retries: int = 3, + chunk_size: int = DEFAULT_CHUNK_SIZE, + chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, + batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE, + tasks_limit: int = DEFAULT_UPLOAD_TASKS_LIMIT, + max_retries: int = DEFAULT_UPLOAD_MAX_RETRIES, progress_callback=None, ) -> KBDocument: """从 URL 上传文档到指定的知识库 diff --git a/astrbot/core/knowledge_base/models.py b/astrbot/core/knowledge_base/models.py index a65cec0419..cd0e8290f0 100644 --- a/astrbot/core/knowledge_base/models.py +++ b/astrbot/core/knowledge_base/models.py @@ -3,6 +3,15 @@ from sqlmodel import Field, MetaData, SQLModel, Text, UniqueConstraint +from .capabilities import ( + DEFAULT_CHUNK_OVERLAP, + DEFAULT_CHUNK_SIZE, + DEFAULT_INDEX_TYPE, + DEFAULT_TOP_K_DENSE, + DEFAULT_TOP_K_SPARSE, + DEFAULT_TOP_M_FINAL, +) + class BaseKBModel(SQLModel, table=False): metadata = MetaData() @@ -34,14 +43,14 @@ class KnowledgeBase(BaseKBModel, table=True): embedding_provider_id: str | None = Field(default=None, max_length=100) rerank_provider_id: str | None = Field(default=None, max_length=100) # 分块配置参数 - chunk_size: int | None = Field(default=512, nullable=True) - chunk_overlap: int | None = Field(default=50, nullable=True) + chunk_size: int | None = Field(default=DEFAULT_CHUNK_SIZE, nullable=True) + chunk_overlap: int | None = Field(default=DEFAULT_CHUNK_OVERLAP, nullable=True) # 索引类型: "flat" (精确) 或 "hnsw" (近似最近邻,适合大规模) - index_type: str | None = Field(default="flat", max_length=10) + index_type: str | None = Field(default=DEFAULT_INDEX_TYPE, max_length=10) # 检索配置参数 - top_k_dense: int | None = Field(default=50, nullable=True) - top_k_sparse: int | None = Field(default=50, nullable=True) - top_m_final: int | None = Field(default=5, nullable=True) + top_k_dense: int | None = Field(default=DEFAULT_TOP_K_DENSE, nullable=True) + top_k_sparse: int | None = Field(default=DEFAULT_TOP_K_SPARSE, nullable=True) + top_m_final: int | None = Field(default=DEFAULT_TOP_M_FINAL, nullable=True) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), @@ -83,6 +92,18 @@ class KBDocument(BaseKBModel, table=True): file_type: str = Field(max_length=20, nullable=False) file_size: int = Field(nullable=False) file_path: str = Field(max_length=512, nullable=False) + source_type: str = Field(default="file", max_length=20, nullable=False) + source_uri: str | None = Field(default=None, sa_type=Text) + content_hash: str | None = Field(default=None, max_length=64, index=True) + parser_name: str | None = Field(default=None, max_length=100) + parser_version: str | None = Field(default=None, max_length=50) + chunker_name: str | None = Field(default=None, max_length=100) + chunker_version: str | None = Field(default=None, max_length=50) + status: str = Field(default="ready", max_length=20, nullable=False, index=True) + error_stage: str | None = Field(default=None, max_length=50) + error_message: str | None = Field(default=None, sa_type=Text) + version: int = Field(default=1, nullable=False) + parent_doc_id: str | None = Field(default=None, max_length=36, index=True) chunk_count: int = Field(default=0, nullable=False) media_count: int = Field(default=0, nullable=False) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) @@ -90,6 +111,7 @@ class KBDocument(BaseKBModel, table=True): default_factory=lambda: datetime.now(timezone.utc), sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, ) + indexed_at: datetime | None = Field(default=None) class KBMedia(BaseKBModel, table=True): @@ -120,3 +142,36 @@ class KBMedia(BaseKBModel, table=True): file_size: int = Field(nullable=False) mime_type: str = Field(max_length=100, nullable=False) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class KBIngestionTask(BaseKBModel, table=True): + """Persistent knowledge-base ingestion task state.""" + + __tablename__ = "kb_ingestion_tasks" # type: ignore + + id: int | None = Field( + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, + ) + task_id: str = Field( + max_length=36, + nullable=False, + unique=True, + default_factory=lambda: str(uuid.uuid4()), + index=True, + ) + kb_id: str = Field(max_length=36, nullable=False, index=True) + task_type: str = Field(max_length=30, nullable=False, index=True) + status: str = Field(default="pending", max_length=20, nullable=False, index=True) + progress_stage: str | None = Field(default=None, max_length=50) + progress_current: int = Field(default=0, nullable=False) + progress_total: int = Field(default=100, nullable=False) + progress: str | None = Field(default=None, sa_type=Text) + result: str | None = Field(default=None, sa_type=Text) + error: str | None = Field(default=None, sa_type=Text) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) diff --git a/astrbot/core/knowledge_base/parsers/base.py b/astrbot/core/knowledge_base/parsers/base.py index 4ffca9c6f2..c204adcfeb 100644 --- a/astrbot/core/knowledge_base/parsers/base.py +++ b/astrbot/core/knowledge_base/parsers/base.py @@ -20,6 +20,14 @@ class MediaItem: mime_type: str +@dataclass +class TextSegment: + """Parsed text segment with optional source location metadata.""" + + text: str + metadata: dict + + @dataclass class ParseResult: """解析结果 @@ -29,6 +37,7 @@ class ParseResult: text: str media: list[MediaItem] + text_segments: list[TextSegment] | None = None class BaseParser(ABC): diff --git a/astrbot/core/knowledge_base/parsers/pdf_parser.py b/astrbot/core/knowledge_base/parsers/pdf_parser.py index aeeea930a2..811341f25c 100644 --- a/astrbot/core/knowledge_base/parsers/pdf_parser.py +++ b/astrbot/core/knowledge_base/parsers/pdf_parser.py @@ -11,6 +11,7 @@ BaseParser, MediaItem, ParseResult, + TextSegment, ) @@ -35,13 +36,20 @@ async def parse(self, file_content: bytes, file_name: str) -> ParseResult: reader = PdfReader(pdf_file) text_parts = [] + text_segments = [] media_items = [] # 提取文本 - for page in reader.pages: + for page_number, page in enumerate(reader.pages, start=1): text = page.extract_text() if text: text_parts.append(text) + text_segments.append( + TextSegment( + text=text, + metadata={"page_number": page_number}, + ) + ) # 提取图片 image_counter = 0 @@ -98,4 +106,8 @@ async def parse(self, file_content: bytes, file_name: str) -> ParseResult: continue full_text = "\n\n".join(text_parts) - return ParseResult(text=full_text, media=media_items) + return ParseResult( + text=full_text, + media=media_items, + text_segments=text_segments, + ) diff --git a/astrbot/core/knowledge_base/retrieval/__init__.py b/astrbot/core/knowledge_base/retrieval/__init__.py index b7c88075d5..26508c31f2 100644 --- a/astrbot/core/knowledge_base/retrieval/__init__.py +++ b/astrbot/core/knowledge_base/retrieval/__init__.py @@ -3,7 +3,12 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .manager import RetrievalManager, RetrievalResult + from .manager import ( + RetrievalManager, + RetrievalResult, + RetrievalTrace, + RetrievalWithTrace, + ) from .rank_fusion import FusedResult, RankFusion from .sparse_retriever import SparseResult, SparseRetriever @@ -12,18 +17,32 @@ "RankFusion", "RetrievalManager", "RetrievalResult", + "RetrievalTrace", + "RetrievalWithTrace", "SparseResult", "SparseRetriever", ] def __getattr__(name: str): - if name in {"RetrievalManager", "RetrievalResult"}: - from .manager import RetrievalManager, RetrievalResult + if name in { + "RetrievalManager", + "RetrievalResult", + "RetrievalTrace", + "RetrievalWithTrace", + }: + from .manager import ( + RetrievalManager, + RetrievalResult, + RetrievalTrace, + RetrievalWithTrace, + ) return { "RetrievalManager": RetrievalManager, "RetrievalResult": RetrievalResult, + "RetrievalTrace": RetrievalTrace, + "RetrievalWithTrace": RetrievalWithTrace, }[name] if name in {"FusedResult", "RankFusion"}: diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 07543b48a7..dbb5a483c9 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -3,14 +3,20 @@ 协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口 """ +import json import time from dataclasses import dataclass from typing import TYPE_CHECKING from astrbot import logger from astrbot.core.db.vec_db.base import Result +from astrbot.core.knowledge_base.capabilities import ( + DEFAULT_TOP_K_DENSE, + DEFAULT_TOP_K_SPARSE, + DEFAULT_TOP_M_FINAL, +) from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase -from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion +from astrbot.core.knowledge_base.retrieval.rank_fusion import FusedResult, RankFusion from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever from astrbot.core.provider.provider import RerankProvider @@ -20,6 +26,13 @@ from astrbot.core.db.vec_db.faiss_impl import FaissVecDB +RetrievalOverrideValue = int | str | None +RetrievalOverrides = dict[str, RetrievalOverrideValue] + +DEDUP_SHINGLE_SIZE = 5 +DEDUP_JACCARD_THRESHOLD = 0.92 + + @dataclass class RetrievalResult: """检索结果""" @@ -34,6 +47,38 @@ class RetrievalResult: metadata: dict +@dataclass +class RetrievalTrace: + """Detailed retrieval pipeline trace for diagnostics.""" + + dense: list[dict] + sparse: list[dict] + fusion: list[dict] + dedup: list[dict] + dedup_removed: list[dict] + rerank: list[dict] + final: list[dict] + + def to_dict(self) -> dict: + return { + "dense": self.dense, + "sparse": self.sparse, + "fusion": self.fusion, + "dedup": self.dedup, + "dedup_removed": self.dedup_removed, + "rerank": self.rerank, + "final": self.final, + } + + +@dataclass +class RetrievalWithTrace: + """Retrieval results with optional pipeline diagnostics.""" + + results: list[RetrievalResult] + trace: RetrievalTrace + + class RetrievalManager: """检索管理器 @@ -67,7 +112,8 @@ async def retrieve( kb_ids: list[str], kb_id_helper_map: dict[str, KBHelper], top_k_fusion: int = 20, - top_m_final: int = 5, + top_m_final: int = DEFAULT_TOP_M_FINAL, + retrieval_overrides: RetrievalOverrides | None = None, ) -> list[RetrievalResult]: """混合检索 @@ -90,24 +136,11 @@ async def retrieve( if not kb_ids: return [] - kb_options: dict = {} - new_kb_ids = [] - for kb_id in kb_ids: - kb_helper = kb_id_helper_map.get(kb_id) - if kb_helper: - kb = kb_helper.kb - kb_options[kb_id] = { - "top_k_dense": kb.top_k_dense or 50, - "top_k_sparse": kb.top_k_sparse or 50, - "top_m_final": kb.top_m_final or 5, - "vec_db": kb_helper.vec_db, - "rerank_provider_id": kb.rerank_provider_id, - } - new_kb_ids.append(kb_id) - else: - logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索") - - kb_ids = new_kb_ids + kb_ids, kb_options = self._build_kb_options( + kb_ids, + kb_id_helper_map, + retrieval_overrides=retrieval_overrides, + ) # 1. 稠密检索 time_start = time.time() @@ -140,15 +173,302 @@ async def retrieve( sparse_results=sparse_results, top_k=top_k_fusion, ) + deduped_results = self._deduplicate_fused_results(fused_results) time_end = time.time() logger.debug( - f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results.", + f"Rank fusion took {time_end - time_start:.2f}s and returned " + f"{len(fused_results)} results; dedup kept {len(deduped_results)}.", ) # 4. 转换为 RetrievalResult (批量获取元数据) - doc_ids = {fr.doc_id for fr in fused_results} + doc_ids = {fr.doc_id for fr in deduped_results} metadata_map = await self.kb_db.get_documents_with_metadata_batch(doc_ids) + retrieval_results = self._build_retrieval_results( + fused_results=deduped_results, + metadata_map=metadata_map, + ) + # 5. Rerank + first_rerank = self._get_first_rerank_provider(kb_ids, kb_options) + if first_rerank and retrieval_results: + try: + retrieval_results = await self._rerank( + query=query, + results=retrieval_results, + top_k=top_m_final, + rerank_provider=first_rerank, + ) + except Exception as e: + logger.warning(f"Rerank 执行失败,已跳过重排序并使用融合结果: {e}") + + return retrieval_results[:top_m_final] + + async def retrieve_with_trace( + self, + query: str, + kb_ids: list[str], + kb_id_helper_map: dict[str, KBHelper], + top_k_fusion: int = 20, + top_m_final: int = DEFAULT_TOP_M_FINAL, + retrieval_overrides: RetrievalOverrides | None = None, + ) -> RetrievalWithTrace: + """Hybrid retrieval with detailed stage diagnostics.""" + if not kb_ids: + return RetrievalWithTrace( + results=[], + trace=RetrievalTrace( + dense=[], + sparse=[], + fusion=[], + dedup=[], + dedup_removed=[], + rerank=[], + final=[], + ), + ) + + kb_ids, kb_options = self._build_kb_options( + kb_ids, + kb_id_helper_map, + retrieval_overrides=retrieval_overrides, + ) + + dense_results = await self._dense_retrieve( + query=query, + kb_ids=kb_ids, + kb_options=kb_options, + ) + sparse_results = await self.sparse_retriever.retrieve( + query=query, + kb_ids=kb_ids, + kb_options=kb_options, + ) + fused_results = await self.rank_fusion.fuse( + dense_results=dense_results, + sparse_results=sparse_results, + top_k=top_k_fusion, + ) + deduped_results, dedup_removed_results = ( + self._deduplicate_fused_results_with_trace( + fused_results, + ) + ) + + doc_ids = self._collect_trace_doc_ids( + dense_results=dense_results, + sparse_results=sparse_results, + fused_results=fused_results, + ) + metadata_map = await self.kb_db.get_documents_with_metadata_batch(doc_ids) + doc_lookup = { + doc_id: { + "doc_name": metadata["document"].doc_name, + "kb_name": metadata["knowledge_base"].kb_name, + } + for doc_id, metadata in metadata_map.items() + } + + retrieval_results = self._build_retrieval_results( + fused_results=deduped_results, + metadata_map=metadata_map, + ) + + rerank_results: list[RetrievalResult] = [] + first_rerank = self._get_first_rerank_provider(kb_ids, kb_options) + if first_rerank and retrieval_results: + try: + retrieval_results = await self._rerank( + query=query, + results=retrieval_results, + top_k=top_m_final, + rerank_provider=first_rerank, + ) + rerank_results = retrieval_results + except Exception as e: + logger.warning(f"Rerank 执行失败,已跳过重排序并使用融合结果: {e}") + + final_results = retrieval_results[:top_m_final] + trace = RetrievalTrace( + dense=self._serialize_dense_trace(dense_results, doc_lookup), + sparse=self._serialize_sparse_trace(sparse_results, doc_lookup), + fusion=self._serialize_fusion_trace(fused_results, doc_lookup), + dedup=self._serialize_fusion_trace(deduped_results, doc_lookup), + dedup_removed=self._serialize_dedup_removed_trace( + dedup_removed_results, + doc_lookup, + ), + rerank=self._serialize_retrieval_trace(rerank_results, "rerank"), + final=self._serialize_retrieval_trace(final_results, "final"), + ) + return RetrievalWithTrace(results=final_results, trace=trace) + + def _build_kb_options( + self, + kb_ids: list[str], + kb_id_helper_map: dict[str, KBHelper], + *, + retrieval_overrides: RetrievalOverrides | None = None, + ) -> tuple[list[str], dict]: + kb_options: dict = {} + valid_kb_ids = [] + for kb_id in kb_ids: + kb_helper = kb_id_helper_map.get(kb_id) + if not kb_helper: + logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索") + continue + kb = kb_helper.kb + kb_option = { + "top_k_dense": kb.top_k_dense or DEFAULT_TOP_K_DENSE, + "top_k_sparse": kb.top_k_sparse or DEFAULT_TOP_K_SPARSE, + "top_m_final": kb.top_m_final or DEFAULT_TOP_M_FINAL, + "vec_db": kb_helper.vec_db, + "rerank_provider_id": kb.rerank_provider_id, + } + if retrieval_overrides: + for field_name in ( + "top_k_dense", + "top_k_sparse", + "top_m_final", + "rerank_provider_id", + ): + if field_name in retrieval_overrides: + kb_option[field_name] = retrieval_overrides[field_name] + kb_options[kb_id] = kb_option + valid_kb_ids.append(kb_id) + return valid_kb_ids, kb_options + + def _collect_trace_doc_ids( + self, + *, + dense_results: list[Result], + sparse_results, + fused_results, + ) -> set[str]: + doc_ids = {result.doc_id for result in sparse_results} + doc_ids.update(result.doc_id for result in fused_results) + for result in dense_results: + metadata = self._safe_metadata(result.data.get("metadata")) + doc_id = metadata.get("kb_doc_id") + if doc_id: + doc_ids.add(doc_id) + return doc_ids + + def _deduplicate_fused_results( + self, + fused_results: list[FusedResult], + ) -> list[FusedResult]: + deduped_results, _ = self._deduplicate_fused_results_with_trace(fused_results) + return deduped_results + + def _deduplicate_fused_results_with_trace( + self, + fused_results: list[FusedResult], + ) -> tuple[list[FusedResult], list[dict]]: + selected: list[FusedResult] = [] + removed: list[dict] = [] + signatures: list[tuple[FusedResult, str, frozenset[str]]] = [] + + for result in fused_results: + normalized = self._normalize_content_for_dedup(result.content) + if not normalized: + selected.append(result) + continue + + shingles = self._build_content_shingles(normalized) + duplicate_of = self._find_duplicate_signature( + normalized, + shingles, + signatures, + ) + if duplicate_of: + selected_result, selected_normalized, selected_shingles = duplicate_of + removed.append( + { + "result": result, + "duplicate_of": selected_result, + "similarity": self._dedup_similarity( + normalized, + shingles, + selected_normalized, + selected_shingles, + ), + }, + ) + continue + + selected.append(result) + signatures.append((result, normalized, shingles)) + + return selected, removed + + @staticmethod + def _normalize_content_for_dedup(content: str) -> str: + return "".join(str(content or "").lower().split()) + + @staticmethod + def _build_content_shingles( + normalized_content: str, + size: int = DEDUP_SHINGLE_SIZE, + ) -> frozenset[str]: + if not normalized_content: + return frozenset() + if len(normalized_content) <= size: + return frozenset({normalized_content}) + return frozenset( + normalized_content[index : index + size] + for index in range(len(normalized_content) - size + 1) + ) + + @staticmethod + def _is_duplicate_signature( + normalized: str, + shingles: frozenset[str], + existing: tuple[FusedResult, str, frozenset[str]], + ) -> bool: + _, existing_normalized, existing_shingles = existing + return ( + RetrievalManager._dedup_similarity( + normalized, + shingles, + existing_normalized, + existing_shingles, + ) + >= DEDUP_JACCARD_THRESHOLD + ) + + @staticmethod + def _dedup_similarity( + normalized: str, + shingles: frozenset[str], + existing_normalized: str, + existing_shingles: frozenset[str], + ) -> float: + if normalized == existing_normalized: + return 1.0 + if not shingles or not existing_shingles: + return 0.0 + union = len(shingles | existing_shingles) + if union == 0: + return 0.0 + return len(shingles & existing_shingles) / union + + def _find_duplicate_signature( + self, + normalized: str, + shingles: frozenset[str], + signatures: list[tuple[FusedResult, str, frozenset[str]]], + ) -> tuple[FusedResult, str, frozenset[str]] | None: + for signature in signatures: + if self._is_duplicate_signature(normalized, shingles, signature): + return signature + return None + + def _build_retrieval_results( + self, + *, + fused_results, + metadata_map: dict, + ) -> list[RetrievalResult]: retrieval_results = [] for fr in fused_results: metadata_dict = metadata_map.get(fr.doc_id) @@ -163,13 +483,22 @@ async def retrieve( content=fr.content, score=fr.score, metadata={ + **(fr.metadata or {}), "chunk_index": fr.chunk_index, "char_count": len(fr.content), + "dense_rank": fr.dense_rank, + "sparse_rank": fr.sparse_rank, + "dense_score": fr.dense_score, + "sparse_score": fr.sparse_score, + "rrf_score": fr.rrf_score + if fr.rrf_score is not None + else fr.score, }, ), ) + return retrieval_results - # 5. Rerank + def _get_first_rerank_provider(self, kb_ids: list[str], kb_options: dict): first_rerank = None for kb_id in kb_ids: vec_db = kb_options[kb_id]["vec_db"] @@ -188,18 +517,186 @@ async def retrieve( ): first_rerank = rerank_provider break - if first_rerank and retrieval_results: - try: - retrieval_results = await self._rerank( - query=query, - results=retrieval_results, - top_k=top_m_final, - rerank_provider=first_rerank, - ) - except Exception as e: - logger.warning(f"Rerank 执行失败,已跳过重排序并使用融合结果: {e}") + return first_rerank - return retrieval_results[:top_m_final] + @staticmethod + def _content_preview(content: str, limit: int = 240) -> str: + if len(content) <= limit: + return content + return f"{content[:limit]}..." + + def _serialize_dense_trace( + self, + dense_results: list[Result], + doc_lookup: dict[str, dict], + ) -> list[dict]: + trace = [] + for rank, result in enumerate(dense_results, 1): + chunk_id = result.data.get("doc_id") + metadata = self._safe_metadata(result.data.get("metadata")) + doc_id = metadata.get("kb_doc_id") + source = doc_lookup.get(doc_id, {}) + trace.append( + { + "rank": rank, + "chunk_id": chunk_id, + "doc_id": doc_id, + "doc_name": source.get("doc_name"), + "kb_id": metadata.get("kb_id"), + "kb_name": source.get("kb_name"), + "chunk_index": metadata.get("chunk_index", 0), + "score": result.similarity, + "dense_score": result.similarity, + "title_path": metadata.get("title_path"), + "page_number": metadata.get("page_number"), + "section_index": metadata.get("section_index"), + "content_preview": self._content_preview( + result.data.get("text", ""), + ), + }, + ) + return trace + + def _serialize_sparse_trace( + self, + sparse_results, + doc_lookup: dict[str, dict], + ) -> list[dict]: + trace = [] + for rank, result in enumerate(sparse_results, 1): + source = doc_lookup.get(result.doc_id, {}) + trace.append( + { + "rank": rank, + "chunk_id": result.chunk_id, + "doc_id": result.doc_id, + "doc_name": source.get("doc_name"), + "kb_id": result.kb_id, + "kb_name": source.get("kb_name"), + "chunk_index": result.chunk_index, + "score": result.score, + "sparse_score": result.score, + "title_path": (result.metadata or {}).get("title_path"), + "page_number": (result.metadata or {}).get("page_number"), + "section_index": (result.metadata or {}).get("section_index"), + "content_preview": self._content_preview(result.content), + }, + ) + return trace + + def _serialize_fusion_trace( + self, + fused_results, + doc_lookup: dict[str, dict], + ) -> list[dict]: + trace = [] + for rank, result in enumerate(fused_results, 1): + source = doc_lookup.get(result.doc_id, {}) + trace.append( + { + "rank": rank, + "chunk_id": result.chunk_id, + "doc_id": result.doc_id, + "doc_name": source.get("doc_name"), + "kb_id": result.kb_id, + "kb_name": source.get("kb_name"), + "chunk_index": result.chunk_index, + "score": result.score, + "dense_rank": result.dense_rank, + "sparse_rank": result.sparse_rank, + "dense_score": result.dense_score, + "sparse_score": result.sparse_score, + "rrf_score": result.rrf_score + if result.rrf_score is not None + else result.score, + "title_path": (result.metadata or {}).get("title_path"), + "page_number": (result.metadata or {}).get("page_number"), + "section_index": (result.metadata or {}).get("section_index"), + "content_preview": self._content_preview(result.content), + }, + ) + return trace + + def _serialize_dedup_removed_trace( + self, + removed_results: list[dict], + doc_lookup: dict[str, dict], + ) -> list[dict]: + trace = [] + for rank, removed in enumerate(removed_results, 1): + result = removed["result"] + duplicate_of = removed["duplicate_of"] + source = doc_lookup.get(result.doc_id, {}) + trace.append( + { + "rank": rank, + "chunk_id": result.chunk_id, + "doc_id": result.doc_id, + "doc_name": source.get("doc_name"), + "kb_id": result.kb_id, + "kb_name": source.get("kb_name"), + "chunk_index": result.chunk_index, + "score": result.score, + "dense_rank": result.dense_rank, + "sparse_rank": result.sparse_rank, + "dense_score": result.dense_score, + "sparse_score": result.sparse_score, + "rrf_score": result.rrf_score + if result.rrf_score is not None + else result.score, + "duplicate_of_chunk_id": duplicate_of.chunk_id, + "duplicate_of_doc_id": duplicate_of.doc_id, + "dedup_similarity": removed["similarity"], + "title_path": (result.metadata or {}).get("title_path"), + "page_number": (result.metadata or {}).get("page_number"), + "section_index": (result.metadata or {}).get("section_index"), + "content_preview": self._content_preview(result.content), + }, + ) + return trace + + def _serialize_retrieval_trace( + self, + results: list[RetrievalResult], + stage: str, + ) -> list[dict]: + trace = [] + for rank, result in enumerate(results, 1): + trace.append( + { + "rank": rank, + "chunk_id": result.chunk_id, + "doc_id": result.doc_id, + "doc_name": result.doc_name, + "kb_id": result.kb_id, + "kb_name": result.kb_name, + "chunk_index": result.metadata.get("chunk_index", 0), + "score": result.score, + "dense_rank": result.metadata.get("dense_rank"), + "sparse_rank": result.metadata.get("sparse_rank"), + "dense_score": result.metadata.get("dense_score"), + "sparse_score": result.metadata.get("sparse_score"), + "rrf_score": result.metadata.get("rrf_score"), + "rerank_score": result.metadata.get("rerank_score"), + "title_path": result.metadata.get("title_path"), + "page_number": result.metadata.get("page_number"), + "section_index": result.metadata.get("section_index"), + "stage": stage, + "content_preview": self._content_preview(result.content), + }, + ) + return trace + + @staticmethod + def _safe_metadata(raw_metadata) -> dict: + if not raw_metadata: + return {} + if isinstance(raw_metadata, dict): + return raw_metadata + try: + return json.loads(raw_metadata) + except Exception: + return {} async def _dense_retrieve( self, @@ -298,6 +795,7 @@ async def _rerank( idx = rerank_result.index if idx < len(results): result = results[idx] + result.metadata["rerank_score"] = rerank_result.relevance_score result.score = rerank_result.relevance_score reranked_list.append(result) diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index 744287e655..2dbb1a5bef 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -22,6 +22,12 @@ class FusedResult: kb_id: str content: str score: float + metadata: dict | None = None + dense_rank: int | None = None + sparse_rank: int | None = None + dense_score: float | None = None + sparse_score: float | None = None + rrf_score: float | None = None class RankFusion: @@ -131,6 +137,16 @@ async def fuse( kb_id=sr.kb_id, content=sr.content, score=rrf_scores[identifier], + metadata=sr.metadata, + dense_rank=dense_ranks.get(identifier), + sparse_rank=sparse_ranks.get(identifier), + dense_score=( + chunk_id_to_dense[identifier].similarity + if identifier in chunk_id_to_dense + else None + ), + sparse_score=sr.score, + rrf_score=rrf_scores[identifier], ), ) elif identifier in chunk_id_to_dense: @@ -145,6 +161,12 @@ async def fuse( kb_id=chunk_md["kb_id"], content=vec_result.data["text"], score=rrf_scores[identifier], + metadata=chunk_md, + dense_rank=dense_ranks.get(identifier), + sparse_rank=sparse_ranks.get(identifier), + dense_score=vec_result.similarity, + sparse_score=None, + rrf_score=rrf_scores[identifier], ), ) diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index c34728a273..8790d0224c 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -34,6 +34,7 @@ class SparseResult: kb_id: str content: str score: float + metadata: dict | None = None class SparseRetriever: @@ -86,6 +87,7 @@ async def retrieve( kb_id=kb_id, content=doc["text"], score=max(0.0, float(doc["score"])), + metadata=chunk_md, ), ) @@ -155,6 +157,7 @@ async def _retrieve_with_bm25( "kb_id": kb_id, "text": doc["text"], "kb_top_k": kb_top_k, + "metadata": chunk_md, } for doc, chunk_md in zip(result, chunk_mds) ] @@ -190,6 +193,7 @@ async def _retrieve_with_bm25( kb_id=chunk["kb_id"], content=chunk["text"], score=-float(score), + metadata=chunk["metadata"], ), ) diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index c17663de1d..06419f54f6 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -11,26 +11,35 @@ from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.knowledge_base.capabilities import ( + ALLOWED_UPLOAD_EXTENSIONS, + DEFAULT_CHUNK_OVERLAP, + DEFAULT_CHUNK_PAGE_SIZE, + DEFAULT_CHUNK_SIZE, + DEFAULT_DOCUMENT_PAGE_SIZE, + DEFAULT_INDEX_TYPE, + DEFAULT_KB_PAGE_SIZE, + DEFAULT_TOP_K_DENSE, + DEFAULT_TOP_K_SPARSE, + DEFAULT_TOP_M_FINAL, + DEFAULT_UPLOAD_BATCH_SIZE, + DEFAULT_UPLOAD_MAX_RETRIES, + DEFAULT_UPLOAD_TASKS_LIMIT, + DOCUMENT_FILTER_SOURCE_TYPES, + DOCUMENT_FILTER_STATUSES, + MAX_BATCH_DELETE_DOCUMENTS, + MAX_BATCH_REBUILD_DOCUMENTS, + MAX_RETRIEVE_TOP_K, + MAX_UPLOAD_FILE_SIZE, + MAX_UPLOAD_FILES, + get_knowledge_base_capabilities, +) from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from ..utils import generate_tsne_visualization from .route import Response, Route, RouteContext -ALLOWED_UPLOAD_EXTENSIONS = { - "adoc", - "docx", - "epub", - "md", - "markdown", - "pdf", - "rst", - "txt", - "xls", - "xlsx", -} -MAX_UPLOAD_FILE_SIZE = 128 * 1024 * 1024 - class KnowledgeBaseRoute(Route): """知识库管理路由 @@ -55,12 +64,16 @@ def __init__( # 注册路由 self.routes = { # 知识库管理 + "/kb/capabilities": ("GET", self.get_capabilities), "/kb/list": ("GET", self.list_kbs), "/kb/create": ("POST", self.create_kb), "/kb/get": ("GET", self.get_kb), "/kb/update": ("POST", self.update_kb), "/kb/delete": ("POST", self.delete_kb), "/kb/stats": ("GET", self.get_kb_stats), + "/kb/consistency/check": ("GET", self.check_kb_consistency), + "/kb/consistency/repair": ("POST", self.repair_kb_consistency), + "/kb/rebuild": ("POST", self.rebuild_kb), # 文档管理 "/kb/document/list": ("GET", self.list_documents), "/kb/document/upload": ("POST", self.upload_document), @@ -68,10 +81,15 @@ def __init__( "/kb/document/upload/url": ("POST", self.upload_document_from_url), "/kb/document/upload/progress": ("GET", self.get_upload_progress), "/kb/document/get": ("GET", self.get_document), + "/kb/document/rebuild": ("POST", self.rebuild_document), + "/kb/document/batch-rebuild": ("POST", self.batch_rebuild_documents), "/kb/document/delete": ("POST", self.delete_document), "/kb/document/batch-delete": ("POST", self.batch_delete_documents), + "/kb/task/get": ("GET", self.get_task), + "/kb/task/list": ("GET", self.list_tasks), # # 块管理 "/kb/chunk/list": ("GET", self.list_chunks), + "/kb/chunk/context": ("GET", self.get_chunk_context), "/kb/chunk/delete": ("POST", self.delete_chunk), # # 多媒体管理 # "/kb/media/list": ("GET", self.list_media), @@ -84,6 +102,77 @@ def __init__( def _get_kb_manager(self): return self.core_lifecycle.kb_manager + def _get_kb_db(self): + if not hasattr(self, "core_lifecycle"): + return None + kb_manager = self._get_kb_manager() + return getattr(kb_manager, "kb_db", None) + + @staticmethod + def _get_positive_query_int(name: str, default: int) -> int: + value = request.args.get(name, default, type=int) + return max(value if value is not None else default, 1) + + async def get_capabilities(self): + """Return knowledge base capabilities, defaults, and limits.""" + return Response().ok(get_knowledge_base_capabilities()).__dict__ + + async def _create_persistent_task( + self, + *, + task_id: str, + kb_id: str | None, + task_type: str, + status: str, + progress: dict | None = None, + ) -> None: + kb_db = self._get_kb_db() + if not kb_db or not kb_id: + return + try: + await kb_db.create_ingestion_task( + task_id=task_id, + kb_id=kb_id, + task_type=task_type, + status=status, + progress_stage=(progress or {}).get("stage"), + progress_current=(progress or {}).get("current", 0), + progress_total=(progress or {}).get("total", 100), + progress=progress, + ) + except Exception as e: + logger.warning(f"创建知识库持久任务记录失败 {task_id}: {e}") + + async def _update_persistent_task(self, task_id: str, **updates) -> None: + kb_db = self._get_kb_db() + if not kb_db: + return + try: + await kb_db.update_ingestion_task(task_id, **updates) + except Exception as e: + logger.warning(f"更新知识库持久任务记录失败 {task_id}: {e}") + + async def _get_persistent_task(self, task_id: str) -> dict | None: + kb_db = self._get_kb_db() + if not kb_db: + return None + try: + return await kb_db.get_ingestion_task(task_id) + except Exception as e: + logger.warning(f"读取知识库持久任务记录失败 {task_id}: {e}") + return None + + def _get_persistent_progress_updates(self, task_id: str) -> dict: + progress = self.upload_progress.get(task_id) + if not progress: + return {} + return { + "progress_stage": progress.get("stage"), + "progress_current": progress.get("current", 0), + "progress_total": progress.get("total", 100), + "progress": progress, + } + def _init_task(self, task_id: str, status: str = "pending") -> None: self.upload_tasks[task_id] = { "status": status, @@ -144,6 +233,16 @@ def _update_progress( if total is not None: p["total"] = total + async def _persist_progress(self, task_id: str) -> None: + progress = self.upload_progress.get(task_id) + if not progress: + return + await self._update_persistent_task( + task_id, + status=progress.get("status"), + **self._get_persistent_progress_updates(task_id), + ) + def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str): async def _callback(stage: str, current: int, total: int) -> None: self._update_progress( @@ -155,6 +254,7 @@ async def _callback(stage: str, current: int, total: int) -> None: current=current, total=total, ) + await self._persist_progress(task_id) return _callback @@ -165,6 +265,13 @@ def _format_failed_doc_error(file_name: str, error: Exception) -> str: return message return f"{file_name}: {message}" + @staticmethod + def _format_size_limit(size_bytes: int) -> str: + size_mb = size_bytes / (1024 * 1024) + if size_mb.is_integer(): + return f"{int(size_mb)}MB" + return f"{size_mb:.2f}MB" + @staticmethod def _coerce_optional_int(value: Any, field_name: str) -> int | None: if value in (None, ""): @@ -174,6 +281,20 @@ def _coerce_optional_int(value: Any, field_name: str) -> int | None: except (TypeError, ValueError) as e: raise ValueError(f"{field_name} 必须是整数") from e + @staticmethod + def _coerce_optional_bool(value: Any, field_name: str) -> bool: + if isinstance(value, bool): + return value + if value in (None, ""): + return False + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"true", "1", "yes", "on"}: + return True + if lowered in {"false", "0", "no", "off"}: + return False + raise ValueError(f"{field_name} 必须是布尔值") + @staticmethod def _validate_chunk_options( *, @@ -242,7 +363,8 @@ def _validate_upload_file(file_name: str, file_size: int) -> None: if file_type not in ALLOWED_UPLOAD_EXTENSIONS: raise ValueError(f"不支持的文件类型: {file_name}") if file_size > MAX_UPLOAD_FILE_SIZE: - raise ValueError(f"文件超过 128MB 限制: {file_name}") + limit = KnowledgeBaseRoute._format_size_limit(MAX_UPLOAD_FILE_SIZE) + raise ValueError(f"文件超过 {limit} 限制: {file_name}") async def _background_upload_task( self, @@ -267,6 +389,7 @@ async def _background_upload_task( "current": 0, "total": 100, } + await self._persist_progress(task_id) uploaded_docs = [] failed_docs = [] @@ -283,6 +406,7 @@ async def _background_upload_task( current=0, total=100, ) + await self._persist_progress(task_id) # 创建进度回调函数 progress_callback = self._make_progress_callback( @@ -324,11 +448,24 @@ async def _background_upload_task( } self._set_task_result(task_id, "completed", result=result) + await self._update_persistent_task( + task_id, + status="completed", + result=result, + error=None, + **self._get_persistent_progress_updates(task_id), + ) except Exception as e: logger.error(f"后台上传任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) + await self._update_persistent_task( + task_id, + status="failed", + error=str(e), + **self._get_persistent_progress_updates(task_id), + ) finally: # 兜底清理:防止客户端不轮询 get_upload_progress 导致内存泄漏 asyncio.create_task(self._schedule_delayed_cleanup(task_id)) @@ -354,6 +491,7 @@ async def _background_import_task( "current": 0, "total": 100, } + await self._persist_progress(task_id) uploaded_docs = [] failed_docs = [] @@ -373,6 +511,7 @@ async def _background_import_task( current=0, total=100, ) + await self._persist_progress(task_id) # 创建进度回调函数 progress_callback = self._make_progress_callback( @@ -394,6 +533,8 @@ async def _background_import_task( max_retries=max_retries, progress_callback=progress_callback, pre_chunked_text=chunks, + source_type="import", + source_uri=file_name, ) uploaded_docs.append(doc.model_dump()) @@ -417,11 +558,250 @@ async def _background_import_task( } self._set_task_result(task_id, "completed", result=result) + await self._update_persistent_task( + task_id, + status="completed", + result=result, + error=None, + **self._get_persistent_progress_updates(task_id), + ) except Exception as e: logger.error(f"后台导入任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) + await self._update_persistent_task( + task_id, + status="failed", + error=str(e), + **self._get_persistent_progress_updates(task_id), + ) + finally: + asyncio.create_task(self._schedule_delayed_cleanup(task_id)) + + async def _background_rebuild_document_task( + self, + task_id: str, + kb_helper, + doc_id: str, + chunk_size: int | None, + chunk_overlap: int | None, + batch_size: int, + tasks_limit: int, + max_retries: int, + ) -> None: + """Run a single document rebuild in the background.""" + try: + self._init_task(task_id, status="processing") + self.upload_progress[task_id] = { + "status": "processing", + "file_index": 0, + "file_total": 1, + "file_name": doc_id, + "stage": "rebuilding", + "current": 0, + "total": 100, + } + await self._persist_progress(task_id) + + progress_callback = self._make_progress_callback(task_id, 0, doc_id) + doc = await kb_helper.rebuild_document( + doc_id, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + ) + + result = { + "task_id": task_id, + "rebuilt": [doc.model_dump()], + "failed": [], + "total": 1, + "success_count": 1, + "failed_count": 0, + } + self._update_progress( + task_id, + status="completed", + file_index=0, + file_name=doc_id, + stage="completed", + current=100, + total=100, + ) + self._set_task_result(task_id, "completed", result=result) + await self._update_persistent_task( + task_id, + status="completed", + result=result, + error=None, + **self._get_persistent_progress_updates(task_id), + ) + + except Exception as e: + logger.error(f"后台重建文档任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self._set_task_result(task_id, "failed", error=str(e)) + await self._update_persistent_task( + task_id, + status="failed", + error=str(e), + **self._get_persistent_progress_updates(task_id), + ) + finally: + asyncio.create_task(self._schedule_delayed_cleanup(task_id)) + + async def _background_rebuild_kb_task( + self, + task_id: str, + kb_helper, + chunk_size: int | None, + chunk_overlap: int | None, + batch_size: int, + tasks_limit: int, + max_retries: int, + ) -> None: + """Run a full knowledge base rebuild in the background.""" + kb_name = getattr(getattr(kb_helper, "kb", None), "kb_name", "knowledge base") + try: + self._init_task(task_id, status="processing") + self.upload_progress[task_id] = { + "status": "processing", + "file_index": 0, + "file_total": 1, + "file_name": kb_name, + "stage": "rebuilding", + "current": 0, + "total": 100, + } + await self._persist_progress(task_id) + + progress_callback = self._make_progress_callback( + task_id, + 0, + kb_name, + ) + result = await kb_helper.rebuild_all_documents( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + ) + result = { + "task_id": task_id, + **result, + } + completed_total = max(int(result.get("total") or 0), 1) + self._update_progress( + task_id, + status="completed", + file_index=0, + file_name=kb_name, + stage="completed", + current=completed_total, + total=completed_total, + ) + self._set_task_result(task_id, "completed", result=result) + await self._update_persistent_task( + task_id, + status="completed", + result=result, + error=None, + **self._get_persistent_progress_updates(task_id), + ) + + except Exception as e: + logger.error(f"后台重建知识库任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self._set_task_result(task_id, "failed", error=str(e)) + await self._update_persistent_task( + task_id, + status="failed", + error=str(e), + **self._get_persistent_progress_updates(task_id), + ) + finally: + asyncio.create_task(self._schedule_delayed_cleanup(task_id)) + + async def _background_rebuild_documents_task( + self, + task_id: str, + kb_helper, + doc_ids: list[str], + chunk_size: int | None, + chunk_overlap: int | None, + batch_size: int, + tasks_limit: int, + max_retries: int, + ) -> None: + """Run selected document rebuilds in the background.""" + total = max(len(doc_ids), 1) + task_name = f"{len(doc_ids)} selected documents" + try: + self._init_task(task_id, status="processing") + self.upload_progress[task_id] = { + "status": "processing", + "file_index": 0, + "file_total": total, + "file_name": task_name, + "stage": "rebuilding", + "current": 0, + "total": total, + } + await self._persist_progress(task_id) + + progress_callback = self._make_progress_callback( + task_id, + 0, + task_name, + ) + result = await kb_helper.rebuild_documents( + doc_ids, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + ) + result = { + "task_id": task_id, + **result, + } + completed_total = max(int(result.get("total") or 0), 1) + self._update_progress( + task_id, + status="completed", + file_index=0, + file_name=task_name, + stage="completed", + current=completed_total, + total=completed_total, + ) + self._set_task_result(task_id, "completed", result=result) + await self._update_persistent_task( + task_id, + status="completed", + result=result, + error=None, + **self._get_persistent_progress_updates(task_id), + ) + + except Exception as e: + logger.error(f"后台批量重建文档任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self._set_task_result(task_id, "failed", error=str(e)) + await self._update_persistent_task( + task_id, + status="failed", + error=str(e), + **self._get_persistent_progress_updates(task_id), + ) finally: asyncio.create_task(self._schedule_delayed_cleanup(task_id)) @@ -430,20 +810,32 @@ async def list_kbs(self): Query 参数: - page: 页码 (默认 1) - - page_size: 每页数量 (默认 20) + - page_size: 每页数量 - refresh_stats: 是否刷新统计信息 (默认 false,首次加载时可设为 true) """ try: kb_manager = self._get_kb_manager() - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 20, type=int) + page = self._get_positive_query_int("page", 1) + page_size = self._get_positive_query_int( + "page_size", + DEFAULT_KB_PAGE_SIZE, + ) + refresh_stats = request.args.get("refresh_stats") == "true" + kb_db = self._get_kb_db() kbs = await kb_manager.list_kbs() + total = len(kbs) + start = (page - 1) * page_size + paged_kbs = kbs[start : start + page_size] # 转换为字典列表 kb_list = [] - for kb in kbs: + for kb in paged_kbs: kb_dict = kb.model_dump() + if refresh_stats and kb_db and hasattr(kb_db, "get_kb_stats"): + stats = await kb_db.get_kb_stats(kb.kb_id) + if stats: + kb_dict.update(stats) # include init_error from KBHelper if present kb_helper = await kb_manager.get_kb(kb.kb_id) if kb_helper and kb_helper.init_error: @@ -452,7 +844,14 @@ async def list_kbs(self): return ( Response() - .ok({"items": kb_list, "page": page, "page_size": page_size}) + .ok( + { + "items": kb_list, + "page": page, + "page_size": page_size, + "total": total, + }, + ) .__dict__ ) except ValueError as e: @@ -507,12 +906,20 @@ async def create_kb(self): ) index_type = data.get("index_type") self._validate_kb_options( - chunk_size=chunk_size if chunk_size is not None else 512, - chunk_overlap=chunk_overlap if chunk_overlap is not None else 50, - top_k_dense=top_k_dense if top_k_dense is not None else 50, - top_k_sparse=top_k_sparse if top_k_sparse is not None else 50, - top_m_final=top_m_final if top_m_final is not None else 5, - index_type=index_type if index_type is not None else "flat", + chunk_size=chunk_size if chunk_size is not None else DEFAULT_CHUNK_SIZE, + chunk_overlap=chunk_overlap + if chunk_overlap is not None + else DEFAULT_CHUNK_OVERLAP, + top_k_dense=top_k_dense + if top_k_dense is not None + else DEFAULT_TOP_K_DENSE, + top_k_sparse=top_k_sparse + if top_k_sparse is not None + else DEFAULT_TOP_K_SPARSE, + top_m_final=top_m_final + if top_m_final is not None + else DEFAULT_TOP_M_FINAL, + index_type=index_type if index_type is not None else DEFAULT_INDEX_TYPE, ) # pre-check embedding dim @@ -773,12 +1180,28 @@ async def get_kb_stats(self): if not kb_helper: return Response().error("知识库不存在").__dict__ kb = kb_helper.kb + kb_db = self._get_kb_db() + if kb_db and hasattr(kb_db, "get_kb_stats"): + stats = await kb_db.get_kb_stats(kb_id) + if stats is not None: + return Response().ok(stats).__dict__ stats = { "kb_id": kb.kb_id, "kb_name": kb.kb_name, "doc_count": kb.doc_count, "chunk_count": kb.chunk_count, + "document_count": kb.doc_count, + "ready_document_count": kb.doc_count, + "failed_document_count": 0, + "pending_document_count": 0, + "processing_document_count": 0, + "indexed_chunk_count": kb.chunk_count, + "document_chunk_count": kb.chunk_count, + "media_count": 0, + "source_file_count": 0, + "storage_bytes": 0, + "status_counts": {"ready": kb.doc_count}, "created_at": kb.created_at.isoformat(), "updated_at": kb.updated_at.isoformat(), } @@ -792,6 +1215,56 @@ async def get_kb_stats(self): logger.error(traceback.format_exc()) return Response().error(f"获取知识库统计失败: {e!s}").__dict__ + async def check_kb_consistency(self): + """Check consistency across metadata, source files, and indexed chunks.""" + try: + kb_manager = self._get_kb_manager() + kb_id = request.args.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + report = await kb_helper.check_consistency() + return Response().ok(report).__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"检查知识库一致性失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"检查知识库一致性失败: {e!s}").__dict__ + + async def repair_kb_consistency(self): + """Repair low-risk consistency issues for a knowledge base.""" + try: + kb_manager = self._get_kb_manager() + data = await request.json + + kb_id = data.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + repair_types = data.get("repair_types") + if repair_types is not None and not isinstance(repair_types, list): + return Response().error("repair_types 格式错误").__dict__ + + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + report = await kb_helper.repair_consistency(repair_types=repair_types) + return Response().ok(report).__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"修复知识库一致性失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"修复知识库一致性失败: {e!s}").__dict__ + # ===== 文档管理 API ===== async def list_documents(self): @@ -800,7 +1273,7 @@ async def list_documents(self): Query 参数: - kb_id: 知识库 ID (必填) - page: 页码 (默认 1) - - page_size: 每页数量 (默认 20) + - page_size: 每页数量 """ try: kb_manager = self._get_kb_manager() @@ -811,9 +1284,18 @@ async def list_documents(self): if not kb_helper: return Response().error("知识库不存在").__dict__ - page = max(request.args.get("page", 1, type=int), 1) - page_size = max(request.args.get("page_size", 100, type=int), 1) + page = self._get_positive_query_int("page", 1) + page_size = self._get_positive_query_int( + "page_size", + DEFAULT_DOCUMENT_PAGE_SIZE, + ) search = (request.args.get("search") or "").strip() or None + status = (request.args.get("status") or "").strip() or None + source_type = (request.args.get("source_type") or "").strip() or None + if status and status not in DOCUMENT_FILTER_STATUSES: + return Response().error("status 参数无效").__dict__ + if source_type and source_type not in DOCUMENT_FILTER_SOURCE_TYPES: + return Response().error("source_type 参数无效").__dict__ offset = (page - 1) * page_size limit = page_size @@ -822,8 +1304,17 @@ async def list_documents(self): offset=offset, limit=limit, search=search, + status=status, + source_type=source_type, + ) + total = await kb_helper.count_documents( + search=search, + status=status, + source_type=source_type, ) - total = await kb_helper.count_documents(search=search) + document_count = total + if search is not None or status is not None or source_type is not None: + document_count = await kb_helper.count_documents() doc_list = [doc.model_dump() for doc in doc_list] @@ -835,6 +1326,8 @@ async def list_documents(self): "page": page, "page_size": page_size, "total": total, + "filtered_total": total, + "document_count": document_count, }, ) .__dict__ @@ -875,9 +1368,9 @@ async def upload_document(self): kb_id = None chunk_size = None chunk_overlap = None - batch_size = 32 - tasks_limit = 3 - max_retries = 3 + batch_size = None + tasks_limit = None + max_retries = None files_to_upload = [] # 存储待上传的文件信息列表 if content_type and "multipart/form-data" not in content_type: @@ -908,11 +1401,19 @@ async def upload_document(self): form_data.get("max_retries"), "max_retries", ) - chunk_size = chunk_size if chunk_size is not None else 512 - chunk_overlap = chunk_overlap if chunk_overlap is not None else 50 - batch_size = batch_size if batch_size is not None else 32 - tasks_limit = tasks_limit if tasks_limit is not None else 3 - max_retries = max_retries if max_retries is not None else 3 + chunk_size = chunk_size if chunk_size is not None else DEFAULT_CHUNK_SIZE + chunk_overlap = ( + chunk_overlap if chunk_overlap is not None else DEFAULT_CHUNK_OVERLAP + ) + batch_size = ( + batch_size if batch_size is not None else DEFAULT_UPLOAD_BATCH_SIZE + ) + tasks_limit = ( + tasks_limit if tasks_limit is not None else DEFAULT_UPLOAD_TASKS_LIMIT + ) + max_retries = ( + max_retries if max_retries is not None else DEFAULT_UPLOAD_MAX_RETRIES + ) self._validate_upload_options( chunk_size=chunk_size, chunk_overlap=chunk_overlap, @@ -935,8 +1436,10 @@ async def upload_document(self): return Response().error("缺少文件").__dict__ # 限制文件数量 - if len(file_list) > 10: - return Response().error("最多只能上传10个文件").__dict__ + if len(file_list) > MAX_UPLOAD_FILES: + return ( + Response().error(f"最多只能上传{MAX_UPLOAD_FILES}个文件").__dict__ + ) # 处理每个文件 for file in file_list: @@ -982,6 +1485,20 @@ async def upload_document(self): # 初始化任务状态 self._init_task(task_id, status="pending") + await self._create_persistent_task( + task_id=task_id, + kb_id=kb_id, + task_type="upload", + status="pending", + progress={ + "status": "pending", + "file_index": 0, + "file_total": len(files_to_upload), + "stage": "waiting", + "current": 0, + "total": 100, + }, + ) # 启动后台任务 asyncio.create_task( @@ -1038,9 +1555,13 @@ def _validate_import_request(self, data: dict): batch_size = self._coerce_optional_int(data.get("batch_size"), "batch_size") tasks_limit = self._coerce_optional_int(data.get("tasks_limit"), "tasks_limit") max_retries = self._coerce_optional_int(data.get("max_retries"), "max_retries") - batch_size = batch_size if batch_size is not None else 32 - tasks_limit = tasks_limit if tasks_limit is not None else 3 - max_retries = max_retries if max_retries is not None else 3 + batch_size = batch_size if batch_size is not None else DEFAULT_UPLOAD_BATCH_SIZE + tasks_limit = ( + tasks_limit if tasks_limit is not None else DEFAULT_UPLOAD_TASKS_LIMIT + ) + max_retries = ( + max_retries if max_retries is not None else DEFAULT_UPLOAD_MAX_RETRIES + ) self._validate_positive_int(batch_size, "batch_size") self._validate_positive_int(tasks_limit, "tasks_limit") if max_retries < 0: @@ -1078,6 +1599,20 @@ async def import_documents(self): # 初始化任务状态 self._init_task(task_id, status="pending") + await self._create_persistent_task( + task_id=task_id, + kb_id=kb_id, + task_type="import", + status="pending", + progress={ + "status": "pending", + "file_index": 0, + "file_total": len(documents), + "stage": "waiting", + "current": 0, + "total": 100, + }, + ) # 启动后台任务 asyncio.create_task( @@ -1129,7 +1664,23 @@ async def get_upload_progress(self): # 检查任务是否存在 if task_id not in self.upload_tasks: - return Response().error("找不到该任务").__dict__ + persistent_task = await self._get_persistent_task(task_id) + if persistent_task is None: + return Response().error("找不到该任务").__dict__ + response_data = { + "task_id": task_id, + "status": persistent_task["status"], + "progress_stage": persistent_task.get("progress_stage"), + "progress_current": persistent_task.get("progress_current", 0), + "progress_total": persistent_task.get("progress_total", 100), + } + if persistent_task.get("progress") is not None: + response_data["progress"] = persistent_task["progress"] + if persistent_task["status"] == "completed": + response_data["result"] = persistent_task.get("result") + if persistent_task["status"] == "failed": + response_data["error"] = persistent_task.get("error") + return Response().ok(response_data).__dict__ task_info = self.upload_tasks[task_id] status = task_info["status"] @@ -1163,6 +1714,69 @@ async def get_upload_progress(self): logger.error(traceback.format_exc()) return Response().error(f"获取上传进度失败: {e!s}").__dict__ + async def get_task(self): + """获取知识库持久任务详情""" + try: + task_id = request.args.get("task_id") + if not task_id: + return Response().error("缺少参数 task_id").__dict__ + + task = await self._get_persistent_task(task_id) + if not task: + return Response().error("任务不存在").__dict__ + return Response().ok(task).__dict__ + + except Exception as e: + logger.error(f"获取知识库任务失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取知识库任务失败: {e!s}").__dict__ + + async def list_tasks(self): + """列出知识库持久任务""" + try: + kb_db = self._get_kb_db() + if not kb_db: + return Response().error("知识库数据库未初始化").__dict__ + + page = self._get_positive_query_int("page", 1) + page_size = self._get_positive_query_int( + "page_size", + DEFAULT_DOCUMENT_PAGE_SIZE, + ) + kb_id = (request.args.get("kb_id") or "").strip() or None + status = (request.args.get("status") or "").strip() or None + task_type = (request.args.get("task_type") or "").strip() or None + + tasks = await kb_db.list_ingestion_tasks( + kb_id=kb_id, + status=status, + task_type=task_type, + offset=(page - 1) * page_size, + limit=page_size, + ) + total = await kb_db.count_ingestion_tasks( + kb_id=kb_id, + status=status, + task_type=task_type, + ) + return ( + Response() + .ok( + { + "items": tasks, + "total": total, + "page": page, + "page_size": page_size, + }, + ) + .__dict__ + ) + + except Exception as e: + logger.error(f"获取知识库任务列表失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取知识库任务列表失败: {e!s}").__dict__ + async def get_document(self): """获取文档详情 @@ -1226,6 +1840,371 @@ async def delete_document(self): logger.error(traceback.format_exc()) return Response().error(f"删除文档失败: {e!s}").__dict__ + async def rebuild_document(self): + """重建单个文档""" + try: + kb_manager = self._get_kb_manager() + data = await request.json + + kb_id = data.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + doc_id = data.get("doc_id") + if not doc_id: + return Response().error("缺少参数 doc_id").__dict__ + + chunk_size = self._coerce_optional_int(data.get("chunk_size"), "chunk_size") + chunk_overlap = self._coerce_optional_int( + data.get("chunk_overlap"), + "chunk_overlap", + ) + batch_size = self._coerce_optional_int(data.get("batch_size"), "batch_size") + tasks_limit = self._coerce_optional_int( + data.get("tasks_limit"), + "tasks_limit", + ) + max_retries = self._coerce_optional_int( + data.get("max_retries"), + "max_retries", + ) + effective_chunk_size = ( + chunk_size if chunk_size is not None else DEFAULT_CHUNK_SIZE + ) + effective_chunk_overlap = ( + chunk_overlap if chunk_overlap is not None else DEFAULT_CHUNK_OVERLAP + ) + effective_batch_size = ( + batch_size if batch_size is not None else DEFAULT_UPLOAD_BATCH_SIZE + ) + effective_tasks_limit = ( + tasks_limit if tasks_limit is not None else DEFAULT_UPLOAD_TASKS_LIMIT + ) + effective_max_retries = ( + max_retries if max_retries is not None else DEFAULT_UPLOAD_MAX_RETRIES + ) + self._validate_upload_options( + chunk_size=effective_chunk_size, + chunk_overlap=effective_chunk_overlap, + batch_size=effective_batch_size, + tasks_limit=effective_tasks_limit, + max_retries=effective_max_retries, + ) + background = self._coerce_optional_bool( + data.get("background"), + "background", + ) + + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + if background: + task_id = str(uuid.uuid4()) + self._init_task(task_id, status="pending") + await self._create_persistent_task( + task_id=task_id, + kb_id=kb_id, + task_type="document_rebuild", + status="pending", + progress={ + "status": "pending", + "file_index": 0, + "file_total": 1, + "file_name": doc_id, + "stage": "waiting", + "current": 0, + "total": 100, + }, + ) + asyncio.create_task( + self._background_rebuild_document_task( + task_id=task_id, + kb_helper=kb_helper, + doc_id=doc_id, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=effective_batch_size, + tasks_limit=effective_tasks_limit, + max_retries=effective_max_retries, + ), + ) + return ( + Response() + .ok( + { + "task_id": task_id, + "doc_id": doc_id, + "message": ( + "document rebuild task created, " + "processing in background" + ), + }, + ) + .__dict__ + ) + + doc = await kb_helper.rebuild_document( + doc_id, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=effective_batch_size, + tasks_limit=effective_tasks_limit, + max_retries=effective_max_retries, + ) + return Response().ok(doc.model_dump(), "重建文档成功").__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"重建文档失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"重建文档失败: {e!s}").__dict__ + + async def rebuild_kb(self): + """重建整个知识库""" + try: + kb_manager = self._get_kb_manager() + data = await request.json + + kb_id = data.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + chunk_size = self._coerce_optional_int(data.get("chunk_size"), "chunk_size") + chunk_overlap = self._coerce_optional_int( + data.get("chunk_overlap"), + "chunk_overlap", + ) + batch_size = self._coerce_optional_int(data.get("batch_size"), "batch_size") + tasks_limit = self._coerce_optional_int( + data.get("tasks_limit"), + "tasks_limit", + ) + max_retries = self._coerce_optional_int( + data.get("max_retries"), + "max_retries", + ) + effective_chunk_size = ( + chunk_size if chunk_size is not None else DEFAULT_CHUNK_SIZE + ) + effective_chunk_overlap = ( + chunk_overlap if chunk_overlap is not None else DEFAULT_CHUNK_OVERLAP + ) + effective_batch_size = ( + batch_size if batch_size is not None else DEFAULT_UPLOAD_BATCH_SIZE + ) + effective_tasks_limit = ( + tasks_limit if tasks_limit is not None else DEFAULT_UPLOAD_TASKS_LIMIT + ) + effective_max_retries = ( + max_retries if max_retries is not None else DEFAULT_UPLOAD_MAX_RETRIES + ) + self._validate_upload_options( + chunk_size=effective_chunk_size, + chunk_overlap=effective_chunk_overlap, + batch_size=effective_batch_size, + tasks_limit=effective_tasks_limit, + max_retries=effective_max_retries, + ) + background = self._coerce_optional_bool( + data.get("background"), + "background", + ) + + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + if background: + kb_name = getattr( + getattr(kb_helper, "kb", None), + "kb_name", + "knowledge base", + ) + task_id = str(uuid.uuid4()) + self._init_task(task_id, status="pending") + await self._create_persistent_task( + task_id=task_id, + kb_id=kb_id, + task_type="kb_rebuild", + status="pending", + progress={ + "status": "pending", + "file_index": 0, + "file_total": 1, + "file_name": kb_name, + "stage": "waiting", + "current": 0, + "total": 100, + }, + ) + asyncio.create_task( + self._background_rebuild_kb_task( + task_id=task_id, + kb_helper=kb_helper, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=effective_batch_size, + tasks_limit=effective_tasks_limit, + max_retries=effective_max_retries, + ), + ) + return ( + Response() + .ok( + { + "task_id": task_id, + "kb_id": kb_id, + "message": ( + "knowledge base rebuild task created, " + "processing in background" + ), + }, + ) + .__dict__ + ) + + result = await kb_helper.rebuild_all_documents( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=effective_batch_size, + tasks_limit=effective_tasks_limit, + max_retries=effective_max_retries, + ) + return Response().ok(result, "重建知识库完成").__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"重建知识库失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"重建知识库失败: {e!s}").__dict__ + + async def batch_rebuild_documents(self): + """Start a background task to rebuild selected documents. + + Body: + - kb_id: knowledge base ID (required) + - doc_ids: document ID list (required) + """ + try: + kb_manager = self._get_kb_manager() + data = await request.json + + kb_id = data.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + doc_ids = data.get("doc_ids") + if not doc_ids or not isinstance(doc_ids, list): + return Response().error("缺少参数 doc_ids 或格式错误").__dict__ + normalized_doc_ids = list( + dict.fromkeys( + doc_id.strip() + for doc_id in doc_ids + if isinstance(doc_id, str) and doc_id.strip() + ) + ) + if not normalized_doc_ids: + return Response().error("缺少参数 doc_ids 或格式错误").__dict__ + if len(normalized_doc_ids) > MAX_BATCH_REBUILD_DOCUMENTS: + return ( + Response() + .error(f"最多只能批量重建 {MAX_BATCH_REBUILD_DOCUMENTS} 个文档") + .__dict__ + ) + + chunk_size = self._coerce_optional_int(data.get("chunk_size"), "chunk_size") + chunk_overlap = self._coerce_optional_int( + data.get("chunk_overlap"), + "chunk_overlap", + ) + batch_size = self._coerce_optional_int(data.get("batch_size"), "batch_size") + tasks_limit = self._coerce_optional_int( + data.get("tasks_limit"), + "tasks_limit", + ) + max_retries = self._coerce_optional_int( + data.get("max_retries"), + "max_retries", + ) + effective_chunk_size = ( + chunk_size if chunk_size is not None else DEFAULT_CHUNK_SIZE + ) + effective_chunk_overlap = ( + chunk_overlap if chunk_overlap is not None else DEFAULT_CHUNK_OVERLAP + ) + effective_batch_size = ( + batch_size if batch_size is not None else DEFAULT_UPLOAD_BATCH_SIZE + ) + effective_tasks_limit = ( + tasks_limit if tasks_limit is not None else DEFAULT_UPLOAD_TASKS_LIMIT + ) + effective_max_retries = ( + max_retries if max_retries is not None else DEFAULT_UPLOAD_MAX_RETRIES + ) + self._validate_upload_options( + chunk_size=effective_chunk_size, + chunk_overlap=effective_chunk_overlap, + batch_size=effective_batch_size, + tasks_limit=effective_tasks_limit, + max_retries=effective_max_retries, + ) + + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + task_id = str(uuid.uuid4()) + self._init_task(task_id, status="pending") + await self._create_persistent_task( + task_id=task_id, + kb_id=kb_id, + task_type="document_batch_rebuild", + status="pending", + progress={ + "status": "pending", + "file_index": 0, + "file_total": len(normalized_doc_ids), + "file_name": f"{len(normalized_doc_ids)} selected documents", + "stage": "waiting", + "current": 0, + "total": len(normalized_doc_ids), + }, + ) + asyncio.create_task( + self._background_rebuild_documents_task( + task_id=task_id, + kb_helper=kb_helper, + doc_ids=normalized_doc_ids, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=effective_batch_size, + tasks_limit=effective_tasks_limit, + max_retries=effective_max_retries, + ), + ) + return ( + Response() + .ok( + { + "task_id": task_id, + "doc_ids": normalized_doc_ids, + "message": ( + "document batch rebuild task created, " + "processing in background" + ), + }, + ) + .__dict__ + ) + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"批量重建文档失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"批量重建文档失败: {e!s}").__dict__ + async def batch_delete_documents(self): """批量删除文档 @@ -1243,8 +2222,12 @@ async def batch_delete_documents(self): doc_ids = data.get("doc_ids") if not doc_ids or not isinstance(doc_ids, list): return Response().error("缺少参数 doc_ids 或格式错误").__dict__ - if len(doc_ids) > 100: - return Response().error("最多只能批量删除 100 个文档").__dict__ + if len(doc_ids) > MAX_BATCH_DELETE_DOCUMENTS: + return ( + Response() + .error(f"最多只能批量删除 {MAX_BATCH_DELETE_DOCUMENTS} 个文档") + .__dict__ + ) kb_helper = await kb_manager.get_kb(kb_id) if not kb_helper: @@ -1317,14 +2300,18 @@ async def list_chunks(self): Query 参数: - kb_id: 知识库 ID (必填) - page: 页码 (默认 1) - - page_size: 每页数量 (默认 20) + - page_size: 每页数量 """ try: kb_manager = self._get_kb_manager() kb_id = request.args.get("kb_id") doc_id = request.args.get("doc_id") - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 100, type=int) + page = self._get_positive_query_int("page", 1) + page_size = self._get_positive_query_int( + "page_size", + DEFAULT_CHUNK_PAGE_SIZE, + ) + search = (request.args.get("search") or "").strip() or None if not kb_id: return Response().error("缺少参数 kb_id").__dict__ if not doc_id: @@ -1334,11 +2321,15 @@ async def list_chunks(self): limit = page_size if not kb_helper: return Response().error("知识库不存在").__dict__ - chunk_list = await kb_helper.get_chunks_by_doc_id( + chunk_list, total = await kb_helper.search_chunks_by_doc_id( doc_id=doc_id, + search=search, offset=offset, limit=limit, ) + document_chunk_count = total + if search is not None: + document_chunk_count = await kb_helper.get_chunk_count_by_doc_id(doc_id) return ( Response() .ok( @@ -1346,7 +2337,9 @@ async def list_chunks(self): "items": chunk_list, "page": page, "page_size": page_size, - "total": await kb_helper.get_chunk_count_by_doc_id(doc_id), + "total": total, + "filtered_total": total, + "document_chunk_count": document_chunk_count, }, ) .__dict__ @@ -1358,6 +2351,41 @@ async def list_chunks(self): logger.error(traceback.format_exc()) return Response().error(f"获取块列表失败: {e!s}").__dict__ + async def get_chunk_context(self): + """获取文本块和相邻上下文块 + + Query 参数: + - kb_id: 知识库 ID (必填) + - doc_id: 文档 ID (必填) + - chunk_id: 文本块 ID (必填) + """ + try: + kb_manager = self._get_kb_manager() + kb_id = request.args.get("kb_id") + doc_id = request.args.get("doc_id") + chunk_id = request.args.get("chunk_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + if not doc_id: + return Response().error("缺少参数 doc_id").__dict__ + if not chunk_id: + return Response().error("缺少参数 chunk_id").__dict__ + + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + context = await kb_helper.get_chunk_context( + chunk_id=chunk_id, + doc_id=doc_id, + ) + return Response().ok(data=context).__dict__ + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"获取文本块上下文失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取文本块上下文失败: {e!s}").__dict__ + # ===== 检索 API ===== async def retrieve(self): @@ -1376,7 +2404,8 @@ async def retrieve(self): query = data.get("query") kb_ids = data.get("kb_ids") kb_names = data.get("kb_names") - debug = data.get("debug", False) + debug = self._coerce_optional_bool(data.get("debug", False), "debug") + trace = self._coerce_optional_bool(data.get("trace", False), "trace") if not query: return Response().error("缺少参数 query").__dict__ @@ -1387,15 +2416,21 @@ async def retrieve(self): if not kb_ids and not kb_names: return Response().error("缺少参数 kb_ids 或 kb_names").__dict__ - top_k = self._coerce_optional_int(data.get("top_k", 5), "top_k") - top_k = top_k if top_k is not None else 5 + top_k = self._coerce_optional_int( + data.get("top_k", DEFAULT_TOP_M_FINAL), + "top_k", + ) + top_k = top_k if top_k is not None else DEFAULT_TOP_M_FINAL self._validate_positive_int(top_k, "top_k") + if top_k > MAX_RETRIEVE_TOP_K: + return Response().error(f"top_k 不能大于 {MAX_RETRIEVE_TOP_K}").__dict__ results = await kb_manager.retrieve( query=query, kb_names=kb_names, kb_ids=kb_ids, top_m_final=top_k, + include_trace=trace or debug, ) result_list = [] if results: @@ -1406,6 +2441,8 @@ async def retrieve(self): "total": len(result_list), "query": query, } + if results and "trace" in results: + response_data["trace"] = results["trace"] # Debug 模式:生成 t-SNE 可视化 if debug: @@ -1478,11 +2515,19 @@ async def upload_document_from_url(self): data.get("max_retries"), "max_retries", ) - chunk_size = chunk_size if chunk_size is not None else 512 - chunk_overlap = chunk_overlap if chunk_overlap is not None else 50 - batch_size = batch_size if batch_size is not None else 32 - tasks_limit = tasks_limit if tasks_limit is not None else 3 - max_retries = max_retries if max_retries is not None else 3 + chunk_size = chunk_size if chunk_size is not None else DEFAULT_CHUNK_SIZE + chunk_overlap = ( + chunk_overlap if chunk_overlap is not None else DEFAULT_CHUNK_OVERLAP + ) + batch_size = ( + batch_size if batch_size is not None else DEFAULT_UPLOAD_BATCH_SIZE + ) + tasks_limit = ( + tasks_limit if tasks_limit is not None else DEFAULT_UPLOAD_TASKS_LIMIT + ) + max_retries = ( + max_retries if max_retries is not None else DEFAULT_UPLOAD_MAX_RETRIES + ) self._validate_upload_options( chunk_size=chunk_size, chunk_overlap=chunk_overlap, @@ -1503,6 +2548,21 @@ async def upload_document_from_url(self): # 初始化任务状态 self._init_task(task_id, status="pending") + await self._create_persistent_task( + task_id=task_id, + kb_id=kb_id, + task_type="url", + status="pending", + progress={ + "status": "pending", + "file_index": 0, + "file_total": 1, + "file_name": f"URL: {url}", + "stage": "waiting", + "current": 0, + "total": 100, + }, + ) # 启动后台任务 asyncio.create_task( @@ -1565,6 +2625,7 @@ async def _background_upload_from_url_task( "current": 0, "total": 100, } + await self._persist_progress(task_id) # 创建进度回调函数 progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}") @@ -1593,10 +2654,23 @@ async def _background_upload_from_url_task( } self._set_task_result(task_id, "completed", result=result) + await self._update_persistent_task( + task_id, + status="completed", + result=result, + error=None, + **self._get_persistent_progress_updates(task_id), + ) except Exception as e: logger.error(f"后台上传URL任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) + await self._update_persistent_task( + task_id, + status="failed", + error=str(e), + **self._get_persistent_progress_updates(task_id), + ) finally: asyncio.create_task(self._schedule_delayed_cleanup(task_id)) diff --git a/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json b/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json index 642eb7fdbe..177dda1fd3 100644 --- a/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json +++ b/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json @@ -21,23 +21,120 @@ "stats": "Statistics", "docCount": "Documents", "chunkCount": "Chunks", + "readyDocCount": "Ready Documents", + "failedDocCount": "Failed Documents", + "sourceFiles": "Source Files", + "storageUsed": "Storage Used", "embeddingModel": "Embedding Model", "rerankModel": "Rerank Model", "notSet": "Not Set" }, + "consistency": { + "title": "Index Consistency", + "run": "Run Check", + "repair": "Repair Fixable Issues", + "notRun": "No consistency check has been run yet. Run a check to compare document metadata, source files, and indexed chunks.", + "healthy": "No consistency issues found", + "unhealthy": "{count} consistency issues found", + "checkedAt": "Checked at: {time}", + "sqliteDocuments": "Metadata Documents", + "indexedChunks": "Indexed Chunks", + "documentChunks": "Document Chunks", + "sourceFiles": "Source Files", + "expectedChunks": "{count} expected chunks", + "actualChunks": "{count} actual chunks", + "checkSuccessHealthy": "Consistency check completed with no issues", + "checkSuccessUnhealthy": "Consistency check completed with {count} issues", + "checkFailed": "Consistency check failed", + "repairSuccess": "Consistency repair completed: {repaired} repaired, {skipped} skipped", + "repairPartialSuccess": "Consistency repair partially completed: {repaired} repaired, {skipped} skipped, {failed} failed", + "repairFailed": "Consistency repair failed", + "issues": { + "missingVectors": "Documents Missing Indexed Chunks", + "orphanVectors": "Orphan Indexed Chunks", + "missingSourceFiles": "Missing Source Files", + "chunkCountMismatches": "Chunk Count Mismatches", + "invalidVectorMetadata": "Invalid Index Metadata", + "unsafeSourcePaths": "Unsafe Source Paths" + }, + "reasons": { + "empty_file_path": "Source file path is empty", + "outside_kb_files_dir": "Source file path is outside the knowledge base directory", + "not_found": "Source file does not exist" + } + }, + "maintenance": { + "rebuild": "Rebuild Index", + "rebuildStarted": "Knowledge base rebuild started", + "rebuildSuccess": "Knowledge base rebuild completed", + "rebuildFailed": "Failed to rebuild knowledge base", + "rebuildFailedWithReason": "Failed to rebuild knowledge base: {reason}", + "rebuildPartialSuccess": "Knowledge base rebuild partially completed: {success} succeeded, {failed} failed", + "unknownError": "Unknown error", + "stages": { + "waiting": "Waiting...", + "rebuilding": "Rebuilding knowledge base...", + "parsing": "Parsing document...", + "chunking": "Chunking text...", + "embedding": "Generating embeddings...", + "completed": "Completed" + } + }, + "tasks": { + "title": "Recent Tasks", + "refresh": "Refresh tasks", + "empty": "No task records yet", + "loadFailed": "Failed to load recent tasks", + "recentFailures": "Recent Failures", + "noErrorMessage": "No error message", + "types": { + "upload": "Document Upload", + "import": "Document Import", + "url": "URL Import", + "document_rebuild": "Document Rebuild", + "document_batch_rebuild": "Batch Document Rebuild", + "kb_rebuild": "Knowledge Base Rebuild" + }, + "statuses": { + "pending": "Pending", + "processing": "Processing", + "completed": "Completed", + "failed": "Failed" + } + }, "documents": { "title": "Documents", "upload": "Upload Document", "empty": "No documents", "searchPlaceholder": "Search documents...", + "statusFilter": "Status", + "sourceFilter": "Source", + "allStatuses": "All Statuses", + "allSources": "All Sources", + "filteredCount": "Showing {filtered} / {total} documents", "name": "Name", "type": "Type", + "status": "Status", "size": "Size", "chunks": "Chunks", "createdAt": "Uploaded At", "actions": "Actions", "view": "View", + "copyFailure": "Copy Failure Diagnostics", + "rebuild": "Retry Rebuild", "delete": "Delete", + "rebuildTitle": "Rebuild Document Index", + "rebuildConfirm": "Rebuild the index for document '{name}'?", + "rebuildWarning": "Rebuild will parse and write the index again. The previous index may still be used until the task finishes.", + "batchRebuild": "Rebuild Selected ({count})", + "batchRebuildTitle": "Rebuild Selected Documents", + "batchRebuildConfirm": "Rebuild the index for the {count} selected documents?", + "batchRebuildMore": "{count} more", + "batchRebuildWarning": "Batch rebuild will parse and write indexes for the selected documents again. Previous indexes may still be used until the task finishes.", + "batchDelete": "Delete Selected ({count})", + "batchDeleteTitle": "Delete Selected Documents", + "batchDeleteConfirm": "Delete the {count} selected documents?", + "batchDeleteMore": "{count} more", "cancel": "Cancel", "deleteConfirm": "Are you sure you want to delete document '{name}'?", "deleteWarning": "This will delete the document and all its chunks. This action cannot be undone.", @@ -46,21 +143,54 @@ "uploadFailed": "Failed to upload document", "loadFailed": "Failed to load documents", "deleteSuccess": "Document deleted successfully", - "deleteFailed": "Failed to delete document" + "deleteFailed": "Failed to delete document", + "batchDeleteSuccess": "{count} documents deleted", + "batchDeletePartialSuccess": "Batch delete partially completed: {success} succeeded, {failed} failed", + "batchDeleteFailed": "Failed to batch delete documents", + "batchDeleteLimitExceeded": "You can delete up to {limit} documents at once", + "batchRebuildStarted": "Started rebuilding {count} documents", + "batchRebuildFailed": "Failed to batch rebuild documents", + "batchRebuildLimitExceeded": "You can rebuild up to {limit} documents at once", + "failureDocument": "Document", + "failureDocumentId": "Document ID", + "failureStage": "Failure Stage", + "failureMessage": "Error Message", + "unknownFailureStage": "Unknown Stage", + "noFailureMessage": "No error message", + "copyFailureSuccess": "Failure diagnostics copied", + "copyFailureFailed": "Failed to copy failure diagnostics", + "rebuildStarted": "Document rebuild started", + "rebuildSuccess": "Document rebuilt successfully", + "rebuildFailed": "Failed to rebuild document", + "rebuildFailedWithReason": "Failed to rebuild document: {reason}", + "rebuildPartialSuccess": "Document rebuild partially completed: {success} succeeded, {failed} failed", + "statuses": { + "pending": "Pending", + "parsing": "Parsing", + "chunking": "Chunking", + "embedding": "Indexing", + "ready": "Ready", + "failed": "Failed" + }, + "sourceTypes": { + "file": "File", + "url": "URL", + "import": "Import" + } }, "upload": { "title": "Upload Document", "selectFile": "Select File", "dropzone": "Drop files here or click to select", - "supportedFormats": "Supported formats: .txt, .md, .markdown, .rst, .adoc, .pdf, .docx, .epub, .xls, .xlsx", - "maxSize": "Max file size: 128MB", - "maxFiles": "Upload up to 10 files", + "supportedFormats": "Supported formats: {formats}", + "maxSize": "Max file size: {size}", + "maxFiles": "Upload up to {count} files", "maxFilesWarning": "You can select up to {count} files", "selectedFiles": "{count} files selected", "clear": "Clear", "someFilesRejected": "Some files were not added", "unsupportedFile": "{name}: unsupported file type", - "fileTooLarge": "{name}: file exceeds 128MB", + "fileTooLarge": "{name}: file exceeds {size}", "invalidSettings": "Please check the upload settings", "chunkSettings": "Chunk Settings", "batchSettings": "Batch Settings", @@ -69,15 +199,15 @@ "cleaningProvider": "Cleaning Service Provider", "cleaningProviderHint": "Select an LLM provider to clean and summarize the extracted web page content", "chunkSize": "Chunk Size", - "chunkSizeHint": "Number of characters per chunk (default: 512)", + "chunkSizeHint": "Number of characters per chunk (default: {value})", "chunkOverlap": "Chunk Overlap", - "chunkOverlapHint": "Overlapping characters between chunks (default: 50)", + "chunkOverlapHint": "Overlapping characters between chunks (default: {value})", "batchSize": "Batch Size", - "batchSizeHint": "Number of chunks to process in each batch (default: 32)", + "batchSizeHint": "Number of chunks to process in each batch (default: {value})", "tasksLimit": "Concurrent Tasks Limit", - "tasksLimitHint": "Maximum number of concurrent upload tasks (default: 3)", + "tasksLimitHint": "Maximum number of concurrent upload tasks (default: {value})", "maxRetries": "Max Retries", - "maxRetriesHint": "Number of times to retry a failed upload task (default: 3)", + "maxRetriesHint": "Number of times to retry a failed upload task (default: {value})", "cancel": "Cancel", "submit": "Upload", "fileRequired": "Please select a file to upload", @@ -86,6 +216,7 @@ "urlPlaceholder": "Enter the URL of the web page to extract content from", "urlRequired": "Please enter a URL", "urlHint": "The main content will be automatically extracted from the target URL as a document. Currently supports {supported} pages. Before use, please ensure that the target web page allows crawler access.", + "unsupportedUrlImport": "URL import is not enabled by the backend", "tavilyCheckFailed": "Failed to check web search configuration", "tavilyRequired": "Tavily Key is required for this feature", "configure": "Configure", @@ -102,7 +233,9 @@ "cleaning": "Cleaning content...", "parsing": "Parsing document...", "chunking": "Chunking text...", - "embedding": "Generating embeddings..." + "embedding": "Generating embeddings...", + "rebuilding": "Rebuilding document...", + "completed": "Completed" }, "beta": "Beta" }, @@ -118,6 +251,12 @@ "tryDifferentQuery": "Try a different query", "settings": "Retrieval Settings", "debugMode": "Debug Mode", + "debugModeTsne": "Debug Mode (t-SNE)", + "traceMode": "Retrieval Trace", + "cancel": "Cancel", + "caseNotesPlaceholder": "Example: sparse retrieval ranked too low", + "caseTags": "Tags", + "caseTagsPlaceholder": "Example: manual, retrieval-ui, bad-case", "tsneVisualization": "t-SNE Visualization", "topK": "Number of Results", "topKHint": "Maximum number of results to return", @@ -128,9 +267,40 @@ "chunk": "Chunk #{index}", "content": "Content", "charCount": "{count} characters", + "traceTitle": "Retrieval Trace", + "traceStageCount": "{count} stages", + "traceHits": "{count} hits", + "traceDenseRank": "Dense rank #{rank}", + "traceSparseRank": "Sparse rank #{rank}", + "traceDenseScore": "Dense score", + "traceSparseScore": "Sparse score", + "traceRrfScore": "RRF score", + "traceRerankScore": "Rerank score", + "traceDuplicateOf": "Duplicate of {chunk}", + "traceDedupSimilarity": "Duplicate similarity {value}", + "sourcePage": "Page {page}", + "sourceSection": "Section {index}", + "sourceParentChunk": "Parent chunk {id}", + "tracePreviewEmpty": "No content preview", + "traceEmpty": "No candidates in this stage", + "unknownDocument": "Unknown document", + "traceStages": { + "dense": "Dense Recall", + "sparse": "Sparse Recall", + "fusion": "RRF Fusion", + "dedup": "Near-Duplicate Removal", + "dedup_removed": "Removed Duplicates", + "rerank": "Rerank", + "final": "Final Context" + }, "searchSuccess": "Search completed, found {count} results", "searchFailed": "Search failed", - "queryRequired": "Please enter a query" + "queryRequired": "Please enter a query", + "latestRunResults": "Latest Results", + "metricRecall": "Recall", + "metricNdcg": "nDCG", + "metricPrecision": "Precision", + "metricFirstHit": "First Hit" }, "settings": { "title": "Knowledge Base Settings", @@ -162,7 +332,7 @@ "positiveInteger": "Enter an integer greater than 0", "nonNegativeInteger": "Enter an integer no less than 0", "overlapLessThanSize": "Chunk overlap must be less than chunk size", - "topKRange": "Number of results must be an integer from 1 to 100" + "topKRange": "Number of results must be an integer from 1 to {max}" }, "actions": { "retry": "Retry" @@ -175,6 +345,7 @@ "description": "A Tavily API Key is required to use web-based knowledge base features. You can get one from", "officialSite": "Tavily", "apiKeyLabel": "Tavily API Key", + "apiKeyPlaceholder": "tvly-...", "cancel": "Cancel", "save": "Save", "keyRequired": "API Key is required", diff --git a/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json b/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json index 8cf45bd51f..dbdca2bf67 100644 --- a/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json +++ b/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json @@ -9,13 +9,48 @@ "chunkCount": "Chunk Count", "createdAt": "Uploaded At" }, + "processing": { + "title": "Processing Information", + "status": "Status", + "sourceType": "Source Type", + "sourceUri": "Source URI", + "contentHash": "Content Hash", + "parser": "Parser", + "chunker": "Chunker", + "version": "Version", + "parentDocId": "Parent Document ID", + "indexedAt": "Indexed At", + "unknownStage": "Unknown Stage", + "noErrorMessage": "No error message", + "statuses": { + "pending": "Pending", + "parsing": "Parsing", + "chunking": "Chunking", + "embedding": "Indexing", + "ready": "Ready", + "failed": "Failed" + }, + "sourceTypes": { + "file": "File", + "url": "URL", + "import": "Import", + "api": "API" + } + }, "chunks": { "title": "Chunks", + "total": "{count} chunks", + "filteredTotal": "{filtered} / {total} matching chunks", "empty": "No chunks", "index": "Index", "content": "Content", + "titlePath": "Title Path", "charCount": "Characters", "charCountValue": "{count} characters", + "tokenEstimate": "Estimated Tokens", + "tokenEstimateValue": "About {count} tokens", + "offset": "Offset", + "contentHash": "Content Hash", "actions": "Actions", "view": "View", "edit": "Edit", @@ -24,6 +59,7 @@ "search": "Search Chunks", "searchPlaceholder": "Enter keywords to search chunks...", "showing": "Showing", + "showingRange": "Showing {start} - {end} / {total} chunks", "deleteConfirm": "Are you sure you want to delete this chunk?", "deleteSuccess": "Chunk deleted successfully", "deleteFailed": "Failed to delete chunk" @@ -50,14 +86,39 @@ "index": "Index", "content": "Content", "charCount": "Characters", + "tokenEstimate": "Estimated Tokens", + "titlePath": "Title Path", + "section": "Section", + "pageNumber": "Page", + "offset": "Offset", + "contentHash": "Content Hash", + "adjacentChunks": "Adjacent Chunks", + "previousChunk": "Previous: {id}", + "nextChunk": "Next: {id}", + "parentChunk": "Parent Chunk", "vecDocId": "Vector ID", + "context": "Adjacent Context", + "previous": "Previous", + "current": "Current", + "next": "Next", + "contextMissing": "No adjacent chunk", "close": "Close" }, "actions": { - "retry": "Retry" + "retry": "Retry", + "retryRebuild": "Retry Rebuild", + "retryRebuildConfirm": "Rebuild the index for this document?" }, "messages": { "loadDocumentFailed": "Failed to load document details", - "loadChunksFailed": "Failed to load chunks" + "loadChunksFailed": "Failed to load chunks", + "loadChunkContextFailed": "Failed to load adjacent context", + "rebuildStarted": "Document rebuild started", + "rebuildCompleted": "Document rebuild completed", + "rebuildFailed": "Failed to rebuild document", + "rebuildFailedWithReason": "Failed to rebuild document: {reason}", + "focusChunkLoaded": "Opened the retrieved chunk", + "focusChunkFailed": "Failed to open the retrieved chunk", + "focusChunkNotFound": "Retrieved chunk not found" } } diff --git a/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/detail.json b/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/detail.json index 0e3eab1cfe..7fd96ced44 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/detail.json +++ b/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/detail.json @@ -1,4 +1,4 @@ -{ +{ "title": "Детали базы знаний", "backToList": "К списку", "breadcrumb": { @@ -21,23 +21,120 @@ "stats": "Статистика", "docCount": "Количество документов", "chunkCount": "Количество фрагментов", + "readyDocCount": "Готовые документы", + "failedDocCount": "Ошибки документов", + "sourceFiles": "Исходные файлы", + "storageUsed": "Занято места", "embeddingModel": "Embedding модель", "rerankModel": "Rerank модель", "notSet": "не выбрано" }, + "consistency": { + "title": "Согласованность индекса", + "run": "Проверить", + "repair": "Исправить доступное", + "notRun": "Проверка еще не запускалась. Запустите ее, чтобы сравнить метаданные документов, исходные файлы и индексированные фрагменты.", + "healthy": "Проблем согласованности не найдено", + "unhealthy": "Найдено проблем: {count}", + "checkedAt": "Проверено: {time}", + "sqliteDocuments": "Документы в метаданных", + "indexedChunks": "Фрагменты в индексе", + "documentChunks": "Фрагменты документов", + "sourceFiles": "Исходные файлы", + "expectedChunks": "Ожидалось фрагментов: {count}", + "actualChunks": "Фактически фрагментов: {count}", + "checkSuccessHealthy": "Проверка завершена, проблем не найдено", + "checkSuccessUnhealthy": "Проверка завершена, найдено проблем: {count}", + "checkFailed": "Не удалось выполнить проверку", + "repairSuccess": "Исправление завершено: исправлено {repaired}, пропущено {skipped}", + "repairPartialSuccess": "Исправление частично завершено: исправлено {repaired}, пропущено {skipped}, ошибок {failed}", + "repairFailed": "Не удалось исправить согласованность", + "issues": { + "missingVectors": "У документов нет фрагментов в индексе", + "orphanVectors": "Фрагменты без документа", + "missingSourceFiles": "Нет исходных файлов", + "chunkCountMismatches": "Не совпадает число фрагментов", + "invalidVectorMetadata": "Ошибки метаданных индекса", + "unsafeSourcePaths": "Некорректные пути исходных файлов" + }, + "reasons": { + "empty_file_path": "Путь к исходному файлу пуст", + "outside_kb_files_dir": "Путь к исходному файлу вне каталога базы знаний", + "not_found": "Исходный файл не найден" + } + }, + "maintenance": { + "rebuild": "Переиндексировать", + "rebuildStarted": "Переиндексация базы знаний запущена", + "rebuildSuccess": "Переиндексация базы знаний завершена", + "rebuildFailed": "Не удалось переиндексировать базу знаний", + "rebuildFailedWithReason": "Не удалось переиндексировать базу знаний: {reason}", + "rebuildPartialSuccess": "Переиндексация частично завершена: успешно {success}, ошибок {failed}", + "unknownError": "Неизвестная ошибка", + "stages": { + "waiting": "Ожидание...", + "rebuilding": "Переиндексация базы знаний...", + "parsing": "Разбор документа...", + "chunking": "Разбиение текста...", + "embedding": "Генерация векторов...", + "completed": "Завершено" + } + }, + "tasks": { + "title": "Последние задачи", + "refresh": "Обновить задачи", + "empty": "Задач пока нет", + "loadFailed": "Не удалось загрузить последние задачи", + "recentFailures": "Последние ошибки", + "noErrorMessage": "Нет сообщения об ошибке", + "types": { + "upload": "Загрузка документа", + "import": "Импорт документа", + "url": "Импорт URL", + "document_rebuild": "Переиндексация документа", + "document_batch_rebuild": "Пакетная переиндексация документов", + "kb_rebuild": "Переиндексация базы знаний" + }, + "statuses": { + "pending": "Ожидание", + "processing": "В обработке", + "completed": "Завершено", + "failed": "Ошибка" + } + }, "documents": { "title": "Список документов", "upload": "Загрузить", "empty": "Документов нет", "searchPlaceholder": "Поиск документов...", + "statusFilter": "Статус", + "sourceFilter": "Источник", + "allStatuses": "Все статусы", + "allSources": "Все источники", + "filteredCount": "Показано {filtered} / {total} документов", "name": "Имя файла", "type": "Тип", + "status": "Статус", "size": "Размер", "chunks": "Фрагменты", "createdAt": "Дата загрузки", "actions": "Действия", "view": "Смотреть", + "copyFailure": "Копировать диагностику", + "rebuild": "Повторить индексацию", "delete": "Удалить", + "rebuildTitle": "Переиндексировать документ", + "rebuildConfirm": "Переиндексировать документ «{name}»?", + "rebuildWarning": "Переиндексация повторно разберет документ и запишет индекс. До завершения задачи может использоваться прежний индекс.", + "batchRebuild": "Переиндексировать выбранные ({count})", + "batchRebuildTitle": "Переиндексировать выбранные документы", + "batchRebuildConfirm": "Переиндексировать выбранные документы: {count}?", + "batchRebuildMore": "Еще {count}", + "batchRebuildWarning": "Пакетная переиндексация повторно разберет выбранные документы и запишет индексы. До завершения задачи могут использоваться прежние индексы.", + "batchDelete": "Удалить выбранные ({count})", + "batchDeleteTitle": "Удалить выбранные документы", + "batchDeleteConfirm": "Удалить выбранные документы: {count}?", + "batchDeleteMore": "Еще {count}", "cancel": "Отмена", "deleteConfirm": "Вы уверены, что хотите удалить «{name}»?", "deleteWarning": "Это удалит файл и все его фрагменты из индекса.", @@ -46,21 +143,54 @@ "uploadFailed": "Ошибка загрузки", "loadFailed": "Не удалось загрузить документы", "deleteSuccess": "Файл удален", - "deleteFailed": "Ошибка удаления" + "deleteFailed": "Ошибка удаления", + "batchDeleteSuccess": "Удалено документов: {count}", + "batchDeletePartialSuccess": "Пакетное удаление частично завершено: успешно {success}, ошибок {failed}", + "batchDeleteFailed": "Не удалось удалить документы пакетом", + "batchDeleteLimitExceeded": "За один раз можно удалить не более {limit} документов", + "batchRebuildStarted": "Запущена переиндексация документов: {count}", + "batchRebuildFailed": "Не удалось переиндексировать документы пакетом", + "batchRebuildLimitExceeded": "За один раз можно переиндексировать не более {limit} документов", + "failureDocument": "Документ", + "failureDocumentId": "ID документа", + "failureStage": "Этап ошибки", + "failureMessage": "Сообщение ошибки", + "unknownFailureStage": "Неизвестный этап", + "noFailureMessage": "Нет сообщения об ошибке", + "copyFailureSuccess": "Диагностика ошибки скопирована", + "copyFailureFailed": "Не удалось скопировать диагностику ошибки", + "rebuildStarted": "Переиндексация документа запущена", + "rebuildSuccess": "Документ переиндексирован", + "rebuildFailed": "Не удалось переиндексировать документ", + "rebuildFailedWithReason": "Не удалось переиндексировать документ: {reason}", + "rebuildPartialSuccess": "Переиндексация частично завершена: успешно {success}, ошибок {failed}", + "statuses": { + "pending": "Ожидание", + "parsing": "Разбор", + "chunking": "Фрагментация", + "embedding": "Индексация", + "ready": "Готово", + "failed": "Ошибка" + }, + "sourceTypes": { + "file": "Файл", + "url": "URL", + "import": "Импорт" + } }, "upload": { "title": "Добавление контента", "selectFile": "Файл", "dropzone": "Нажмите или перетащите файл сюда", - "supportedFormats": "Форматы: .txt, .md, .markdown, .rst, .adoc, .pdf, .docx, .epub, .xls, .xlsx", - "maxSize": "Максимум: 128MB", - "maxFiles": "Можно загрузить до 10 файлов", + "supportedFormats": "Форматы: {formats}", + "maxSize": "Максимум: {size}", + "maxFiles": "Можно загрузить до {count} файлов", "maxFilesWarning": "Можно выбрать не более {count} файлов", "selectedFiles": "Выбрано файлов: {count}", "clear": "Очистить", "someFilesRejected": "Некоторые файлы не добавлены", "unsupportedFile": "{name}: неподдерживаемый тип файла", - "fileTooLarge": "{name}: файл больше 128MB", + "fileTooLarge": "{name}: файл больше {size}", "invalidSettings": "Проверьте параметры загрузки", "chunkSettings": "Фрагментация", "batchSettings": "Пакетная обработка", @@ -69,15 +199,15 @@ "cleaningProvider": "Сервис для очистки", "cleaningProviderHint": "LLM провайдер для суммаризации и извлечения смыслов из веб-страниц", "chunkSize": "Размер чанка", - "chunkSizeHint": "Символов в блоке (по умолчанию: 512)", + "chunkSizeHint": "Символов в блоке (по умолчанию: {value})", "chunkOverlap": "Перекрытие", - "chunkOverlapHint": "Перекрытие между блоками (по умолчанию: 50)", + "chunkOverlapHint": "Перекрытие между блоками (по умолчанию: {value})", "batchSize": "Размер пакета", - "batchSizeHint": "Блоков за один запрос (по умолчанию: 32)", + "batchSizeHint": "Блоков за один запрос (по умолчанию: {value})", "tasksLimit": "Лимит задач", - "tasksLimitHint": "Макс. параллельных потоков (по умолчанию: 3)", + "tasksLimitHint": "Макс. параллельных потоков (по умолчанию: {value})", "maxRetries": "Попытки", - "maxRetriesHint": "Повторов при сбое (по умолчанию: 3)", + "maxRetriesHint": "Повторов при сбое (по умолчанию: {value})", "cancel": "Отмена", "submit": "Загрузить", "fileRequired": "Пожалуйста, выберите файл", @@ -86,6 +216,7 @@ "urlPlaceholder": "Ссылка на веб-страницу", "urlRequired": "Введите URL", "urlHint": "Контент будет автоматически извлечен со страницы. Убедитесь, что сайт разрешает доступ роботам.", + "unsupportedUrlImport": "Импорт из URL не включен на сервере", "tavilyCheckFailed": "Не удалось проверить настройки веб-поиска", "tavilyRequired": "Для этой функции нужен Tavily Key", "configure": "Настроить", @@ -102,7 +233,9 @@ "cleaning": "Очистка контента...", "parsing": "Разбор документа...", "chunking": "Разбиение текста...", - "embedding": "Генерация векторов..." + "embedding": "Генерация векторов...", + "rebuilding": "Переиндексация документа...", + "completed": "Завершено" }, "beta": "Бета-версия" }, @@ -118,6 +251,12 @@ "tryDifferentQuery": "Попробуйте изменить формулировку запроса", "settings": "Параметры поиска", "debugMode": "Режим отладки", + "debugModeTsne": "Режим отладки (t-SNE)", + "traceMode": "Трассировка поиска", + "cancel": "Отмена", + "caseNotesPlaceholder": "Например: Sparse поиск дал низкий ранг", + "caseTags": "Теги", + "caseTagsPlaceholder": "Например: manual, retrieval-ui, bad-case", "tsneVisualization": "t-SNE визуализация", "topK": "Количество результатов", "topKHint": "Сколько фрагментов возвращать", @@ -128,9 +267,40 @@ "chunk": "Фрагмент #{index}", "content": "Текст", "charCount": "{count} симв.", + "traceTitle": "Трассировка поиска", + "traceStageCount": "Этапов: {count}", + "traceHits": "Найдено: {count}", + "traceDenseRank": "Dense ранг #{rank}", + "traceSparseRank": "Sparse ранг #{rank}", + "traceDenseScore": "Оценка dense", + "traceSparseScore": "Оценка sparse", + "traceRrfScore": "Оценка RRF", + "traceRerankScore": "Оценка rerank", + "traceDuplicateOf": "Дубликат {chunk}", + "traceDedupSimilarity": "Сходство дубля {value}", + "sourcePage": "Стр. {page}", + "sourceSection": "Раздел {index}", + "sourceParentChunk": "Родительский фрагмент {id}", + "tracePreviewEmpty": "Нет предпросмотра", + "traceEmpty": "На этом этапе нет кандидатов", + "unknownDocument": "Неизвестный документ", + "traceStages": { + "dense": "Dense поиск", + "sparse": "Sparse поиск", + "fusion": "RRF объединение", + "dedup": "Удаление дублей", + "dedup_removed": "Удаленные дубли", + "rerank": "Rerank", + "final": "Итоговый контекст" + }, "searchSuccess": "Поиск завершен, найдено: {count}", "searchFailed": "Ошибка выполнения поиска", - "queryRequired": "Введите поисковый запрос" + "queryRequired": "Введите поисковый запрос", + "latestRunResults": "Последние результаты", + "metricRecall": "Recall", + "metricNdcg": "nDCG", + "metricPrecision": "Precision", + "metricFirstHit": "Первое попадание" }, "settings": { "title": "Общие настройки базы", @@ -162,7 +332,7 @@ "positiveInteger": "Введите целое число больше 0", "nonNegativeInteger": "Введите целое число не меньше 0", "overlapLessThanSize": "Перекрытие должно быть меньше размера чанка", - "topKRange": "Количество результатов должно быть целым числом от 1 до 100" + "topKRange": "Количество результатов должно быть целым числом от 1 до {max}" }, "actions": { "retry": "Повторить" @@ -175,6 +345,7 @@ "description": "Для веб-функций базы знаний нужен Tavily API Key. Получить его можно на", "officialSite": "сайте Tavily", "apiKeyLabel": "Tavily API Key", + "apiKeyPlaceholder": "tvly-...", "cancel": "Отмена", "save": "Сохранить", "keyRequired": "API Key обязателен", diff --git a/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/document.json b/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/document.json index 2de459be24..4f391e4e93 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/document.json +++ b/dashboard/src/i18n/locales/ru-RU/features/knowledge-base/document.json @@ -9,13 +9,48 @@ "chunkCount": "Количество фрагментов", "createdAt": "Загружен" }, + "processing": { + "title": "Информация обработки", + "status": "Статус", + "sourceType": "Тип источника", + "sourceUri": "Источник", + "contentHash": "Хэш контента", + "parser": "Парсер", + "chunker": "Разбиение", + "version": "Версия", + "parentDocId": "ID родительского документа", + "indexedAt": "Индексирован", + "unknownStage": "Неизвестный этап", + "noErrorMessage": "Нет сообщения об ошибке", + "statuses": { + "pending": "Ожидание", + "parsing": "Разбор", + "chunking": "Фрагментация", + "embedding": "Индексация", + "ready": "Готово", + "failed": "Ошибка" + }, + "sourceTypes": { + "file": "Файл", + "url": "URL", + "import": "Импорт", + "api": "API" + } + }, "chunks": { "title": "Фрагменты текста", + "total": "Фрагментов: {count}", + "filteredTotal": "Найдено {filtered} / {total} фрагм.", "empty": "Фрагменты не найдены", "index": "Индекс", "content": "Текст", + "titlePath": "Путь заголовков", "charCount": "Символов", "charCountValue": "{count} симв.", + "tokenEstimate": "Оценка токенов", + "tokenEstimateValue": "Около {count} ток.", + "offset": "Позиция", + "contentHash": "Хэш контента", "actions": "Действия", "view": "Детали", "edit": "Изменить", @@ -24,6 +59,7 @@ "search": "Поиск по документу", "searchPlaceholder": "Найти во фрагментах...", "showing": "Показано", + "showingRange": "Показано {start} - {end} / {total} фрагм.", "deleteConfirm": "Удалить этот фрагмент?", "deleteSuccess": "Фрагмент удален", "deleteFailed": "Ошибка удаления" @@ -50,14 +86,39 @@ "index": "Индекс", "content": "Текст", "charCount": "Символов", + "tokenEstimate": "Оценка токенов", + "titlePath": "Путь заголовков", + "section": "Раздел", + "pageNumber": "Страница", + "offset": "Позиция", + "contentHash": "Хэш контента", + "adjacentChunks": "Соседние фрагменты", + "previousChunk": "Предыдущий: {id}", + "nextChunk": "Следующий: {id}", + "parentChunk": "Родительский фрагмент", "vecDocId": "ID вектора", + "context": "Соседний контекст", + "previous": "Предыдущий", + "current": "Текущий", + "next": "Следующий", + "contextMissing": "Соседний фрагмент отсутствует", "close": "Закрыть" }, "actions": { - "retry": "Повторить" + "retry": "Повторить", + "retryRebuild": "Повторить индексацию", + "retryRebuildConfirm": "Переиндексировать этот документ?" }, "messages": { "loadDocumentFailed": "Не удалось загрузить документ", - "loadChunksFailed": "Не удалось загрузить фрагменты" + "loadChunksFailed": "Не удалось загрузить фрагменты", + "loadChunkContextFailed": "Не удалось загрузить соседний контекст", + "rebuildStarted": "Переиндексация документа запущена", + "rebuildCompleted": "Переиндексация документа завершена", + "rebuildFailed": "Не удалось переиндексировать документ", + "rebuildFailedWithReason": "Не удалось переиндексировать документ: {reason}", + "focusChunkLoaded": "Открыт найденный фрагмент", + "focusChunkFailed": "Не удалось открыть найденный фрагмент", + "focusChunkNotFound": "Найденный фрагмент не найден" } } diff --git a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json index 4d294f12e4..bc04ecfca8 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json +++ b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json @@ -21,23 +21,120 @@ "stats": "统计信息", "docCount": "文档数量", "chunkCount": "分块数量", + "readyDocCount": "已索引文档", + "failedDocCount": "失败文档", + "sourceFiles": "源文件", + "storageUsed": "存储占用", "embeddingModel": "嵌入模型", "rerankModel": "重排序模型", "notSet": "未设置" }, + "consistency": { + "title": "索引一致性", + "run": "运行检查", + "repair": "修复可修复项", + "notRun": "尚未运行一致性检查。点击运行检查可诊断文档元数据、源文件和索引文本块是否一致。", + "healthy": "未发现一致性问题", + "unhealthy": "发现 {count} 个一致性问题", + "checkedAt": "检查时间: {time}", + "sqliteDocuments": "元数据文档", + "indexedChunks": "索引分块", + "documentChunks": "文档分块", + "sourceFiles": "源文件", + "expectedChunks": "预期 {count} 个分块", + "actualChunks": "实际 {count} 个分块", + "checkSuccessHealthy": "一致性检查完成,未发现问题", + "checkSuccessUnhealthy": "一致性检查完成,发现 {count} 个问题", + "checkFailed": "一致性检查失败", + "repairSuccess": "一致性修复完成: 修复 {repaired} 项, 跳过 {skipped} 项", + "repairPartialSuccess": "一致性修复部分完成: 修复 {repaired} 项, 跳过 {skipped} 项, 失败 {failed} 项", + "repairFailed": "一致性修复失败", + "issues": { + "missingVectors": "文档缺失索引分块", + "orphanVectors": "孤儿索引分块", + "missingSourceFiles": "源文件缺失", + "chunkCountMismatches": "分块数量不一致", + "invalidVectorMetadata": "索引元数据异常", + "unsafeSourcePaths": "源文件路径异常" + }, + "reasons": { + "empty_file_path": "源文件路径为空", + "outside_kb_files_dir": "源文件路径不在知识库目录内", + "not_found": "源文件不存在" + } + }, + "maintenance": { + "rebuild": "重建索引", + "rebuildStarted": "知识库重建任务已开始", + "rebuildSuccess": "知识库重建完成", + "rebuildFailed": "知识库重建失败", + "rebuildFailedWithReason": "知识库重建失败: {reason}", + "rebuildPartialSuccess": "知识库重建部分完成: 成功 {success} 个, 失败 {failed} 个", + "unknownError": "未知错误", + "stages": { + "waiting": "等待中...", + "rebuilding": "重建知识库...", + "parsing": "解析文档...", + "chunking": "文本分块...", + "embedding": "生成向量...", + "completed": "已完成" + } + }, + "tasks": { + "title": "最近任务", + "refresh": "刷新任务", + "empty": "暂无任务记录", + "loadFailed": "加载最近任务失败", + "recentFailures": "最近失败", + "noErrorMessage": "暂无错误信息", + "types": { + "upload": "上传文档", + "import": "导入文档", + "url": "URL 导入", + "document_rebuild": "文档重建", + "document_batch_rebuild": "批量文档重建", + "kb_rebuild": "知识库重建" + }, + "statuses": { + "pending": "等待中", + "processing": "处理中", + "completed": "已完成", + "failed": "失败" + } + }, "documents": { "title": "文档列表", "upload": "上传文档", "empty": "暂无文档", "searchPlaceholder": "搜索文档...", + "statusFilter": "状态", + "sourceFilter": "来源", + "allStatuses": "全部状态", + "allSources": "全部来源", + "filteredCount": "显示 {filtered} / {total} 个文档", "name": "文档名称", "type": "类型", + "status": "状态", "size": "大小", "chunks": "分块数", "createdAt": "上传时间", "actions": "操作", "view": "查看", + "copyFailure": "复制失败诊断", + "rebuild": "重试重建", "delete": "删除", + "rebuildTitle": "重建文档索引", + "rebuildConfirm": "确定要重新构建文档「{name}」的索引吗?", + "rebuildWarning": "重建会重新解析并写入索引。任务完成前,旧索引仍可能被检索到。", + "batchRebuild": "批量重建 ({count})", + "batchRebuildTitle": "批量重建文档索引", + "batchRebuildConfirm": "确定要重新构建选中的 {count} 个文档索引吗?", + "batchRebuildMore": "还有 {count} 个", + "batchRebuildWarning": "批量重建会为选中文档重新解析并写入索引。任务完成前,旧索引仍可能被检索到。", + "batchDelete": "批量删除 ({count})", + "batchDeleteTitle": "批量删除文档", + "batchDeleteConfirm": "确定要删除选中的 {count} 个文档吗?", + "batchDeleteMore": "还有 {count} 个", "cancel": "取消", "deleteConfirm": "确定要删除文档「{name}」吗?", "deleteWarning": "此操作将删除文档及其所有分块,不可恢复。", @@ -46,21 +143,54 @@ "uploadFailed": "文档上传失败", "loadFailed": "加载文档列表失败", "deleteSuccess": "文档删除成功", - "deleteFailed": "文档删除失败" + "deleteFailed": "文档删除失败", + "batchDeleteSuccess": "已删除 {count} 个文档", + "batchDeletePartialSuccess": "批量删除部分完成: 成功 {success} 个, 失败 {failed} 个", + "batchDeleteFailed": "批量删除文档失败", + "batchDeleteLimitExceeded": "单次最多只能删除 {limit} 个文档", + "batchRebuildStarted": "已开始重建 {count} 个文档", + "batchRebuildFailed": "批量重建文档失败", + "batchRebuildLimitExceeded": "单次最多只能重建 {limit} 个文档", + "failureDocument": "文档", + "failureDocumentId": "文档 ID", + "failureStage": "失败阶段", + "failureMessage": "错误信息", + "unknownFailureStage": "未知阶段", + "noFailureMessage": "暂无错误信息", + "copyFailureSuccess": "已复制失败诊断信息", + "copyFailureFailed": "复制失败诊断信息失败", + "rebuildStarted": "文档重建任务已开始", + "rebuildSuccess": "文档重建成功", + "rebuildFailed": "文档重建失败", + "rebuildFailedWithReason": "文档重建失败: {reason}", + "rebuildPartialSuccess": "文档重建部分成功: 成功 {success} 个, 失败 {failed} 个", + "statuses": { + "pending": "等待中", + "parsing": "解析中", + "chunking": "分块中", + "embedding": "索引中", + "ready": "已索引", + "failed": "失败" + }, + "sourceTypes": { + "file": "文件", + "url": "URL", + "import": "导入" + } }, "upload": { "title": "上传文档", "selectFile": "选择文件", "dropzone": "拖放文件到这里或点击选择", - "supportedFormats": "支持的格式: .txt, .md, .markdown, .rst, .adoc, .pdf, .docx, .epub, .xls, .xlsx", - "maxSize": "最大文件大小: 128MB", - "maxFiles": "最多可上传 10 个文件", + "supportedFormats": "支持的格式: {formats}", + "maxSize": "最大文件大小: {size}", + "maxFiles": "最多可上传 {count} 个文件", "maxFilesWarning": "最多只能选择 {count} 个文件", "selectedFiles": "已选择 {count} 个文件", "clear": "清空", "someFilesRejected": "部分文件未加入上传队列", "unsupportedFile": "{name}: 不支持的文件类型", - "fileTooLarge": "{name}: 文件超过 128MB", + "fileTooLarge": "{name}: 文件超过 {size}", "invalidSettings": "请检查上传参数", "chunkSettings": "分块设置", "batchSettings": "批处理设置", @@ -69,15 +199,15 @@ "cleaningProvider": "清洗服务提供商", "cleaningProviderHint": "选择一个 LLM 服务商来对提取的网页内容进行清洗和总结", "chunkSize": "分块大小", - "chunkSizeHint": "每个文本块的字符数 (默认: 512)", + "chunkSizeHint": "每个文本块的字符数 (默认: {value})", "chunkOverlap": "分块重叠", - "chunkOverlapHint": "相邻文本块之间的重叠字符数 (默认: 50)", + "chunkOverlapHint": "相邻文本块之间的重叠字符数 (默认: {value})", "batchSize": "批处理大小", - "batchSizeHint": "每批处理的文本块数量 (默认: 32)", + "batchSizeHint": "每批处理的文本块数量 (默认: {value})", "tasksLimit": "并发任务限制", - "tasksLimitHint": "最大并发上传任务数 (默认: 3)", + "tasksLimitHint": "最大并发上传任务数 (默认: {value})", "maxRetries": "最大重试次数", - "maxRetriesHint": "上传失败任务的重试次数 (默认: 3)", + "maxRetriesHint": "上传失败任务的重试次数 (默认: {value})", "cancel": "取消", "submit": "上传", "fileRequired": "请选择要上传的文件", @@ -86,6 +216,7 @@ "urlPlaceholder": "请输入要提取内容的网页 URL", "urlRequired": "请输入 URL", "urlHint": "将自动从目标 URL 提取主要内容作为文档。目前支持 {supported} 页面,请确保目标网页允许爬虫访问。", + "unsupportedUrlImport": "当前后端未启用 URL 导入功能", "tavilyCheckFailed": "检查网页搜索配置失败", "tavilyRequired": "使用此功能需要配置 Tavily Key", "configure": "配置", @@ -102,7 +233,9 @@ "cleaning": "清洗内容...", "parsing": "解析文档...", "chunking": "文本分块...", - "embedding": "生成向量..." + "embedding": "生成向量...", + "rebuilding": "重建文档...", + "completed": "已完成" }, "beta": "测试版" }, @@ -118,6 +251,12 @@ "tryDifferentQuery": "尝试使用不同的查询词", "settings": "检索设置", "debugMode": "调试模式", + "debugModeTsne": "调试模式 (t-SNE)", + "traceMode": "检索链路追踪", + "cancel": "取消", + "caseNotesPlaceholder": "例如:稀疏检索排名偏低", + "caseTags": "标签", + "caseTagsPlaceholder": "例如:manual, retrieval-ui, bad-case", "tsneVisualization": "t-SNE 可视化", "topK": "返回结果数量", "topKHint": "最多返回多少条检索结果", @@ -128,9 +267,40 @@ "chunk": "文本块 #{index}", "content": "内容", "charCount": "{count} 字符", + "traceTitle": "检索链路", + "traceStageCount": "{count} 个阶段", + "traceHits": "{count} 条", + "traceDenseRank": "稠密排名 #{rank}", + "traceSparseRank": "稀疏排名 #{rank}", + "traceDenseScore": "稠密分", + "traceSparseScore": "稀疏分", + "traceRrfScore": "RRF 分", + "traceRerankScore": "重排分", + "traceDuplicateOf": "重复于 {chunk}", + "traceDedupSimilarity": "重复相似度 {value}", + "sourcePage": "第 {page} 页", + "sourceSection": "章节 {index}", + "sourceParentChunk": "父文本块 {id}", + "tracePreviewEmpty": "暂无内容预览", + "traceEmpty": "该阶段没有候选结果", + "unknownDocument": "未知文档", + "traceStages": { + "dense": "稠密召回", + "sparse": "稀疏召回", + "fusion": "RRF 融合", + "dedup": "近重复去除", + "dedup_removed": "已移除重复项", + "rerank": "重排序", + "final": "最终上下文" + }, "searchSuccess": "检索完成,找到 {count} 条结果", "searchFailed": "检索失败", - "queryRequired": "请输入检索查询" + "queryRequired": "请输入检索查询", + "latestRunResults": "最近结果", + "metricRecall": "召回率", + "metricNdcg": "归一化折损累计增益 (nDCG)", + "metricPrecision": "精确率", + "metricFirstHit": "首个命中" }, "settings": { "title": "知识库设置", @@ -162,7 +332,7 @@ "positiveInteger": "请输入大于 0 的整数", "nonNegativeInteger": "请输入不小于 0 的整数", "overlapLessThanSize": "分块重叠必须小于分块大小", - "topKRange": "返回结果数量必须是 1 到 100 的整数" + "topKRange": "返回结果数量必须是 1 到 {max} 的整数" }, "actions": { "retry": "重试" @@ -175,6 +345,7 @@ "description": "为了使用基于网页的知识库功能,需要提供 Tavily API Key。您可以从", "officialSite": "Tavily 官网", "apiKeyLabel": "Tavily API Key", + "apiKeyPlaceholder": "tvly-...", "cancel": "取消", "save": "保存", "keyRequired": "API Key 不能为空", diff --git a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json index ffa01d074a..6127213d92 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json +++ b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json @@ -9,13 +9,48 @@ "chunkCount": "分块数量", "createdAt": "上传时间" }, + "processing": { + "title": "处理信息", + "status": "状态", + "sourceType": "来源类型", + "sourceUri": "来源地址", + "contentHash": "内容哈希", + "parser": "解析器", + "chunker": "分块器", + "version": "版本", + "parentDocId": "父文档 ID", + "indexedAt": "索引时间", + "unknownStage": "未知阶段", + "noErrorMessage": "暂无错误信息", + "statuses": { + "pending": "等待中", + "parsing": "解析中", + "chunking": "分块中", + "embedding": "索引中", + "ready": "已索引", + "failed": "失败" + }, + "sourceTypes": { + "file": "文件", + "url": "URL", + "import": "导入", + "api": "API" + } + }, "chunks": { "title": "分块列表", + "total": "{count} 个分块", + "filteredTotal": "匹配 {filtered} / {total} 个分块", "empty": "暂无分块", "index": "序号", "content": "内容", + "titlePath": "标题路径", "charCount": "字符数", "charCountValue": "{count} 字符", + "tokenEstimate": "估算 Token", + "tokenEstimateValue": "约 {count} token", + "offset": "位置", + "contentHash": "内容哈希", "actions": "操作", "view": "查看", "edit": "编辑", @@ -24,6 +59,7 @@ "search": "搜索分块", "searchPlaceholder": "输入关键词搜索分块内容...", "showing": "显示", + "showingRange": "显示 {start} - {end} / {total} 个分块", "deleteConfirm": "确定要删除该文本块吗?", "deleteSuccess": "文本块删除成功", "deleteFailed": "文本块删除失败" @@ -50,14 +86,39 @@ "index": "序号", "content": "内容", "charCount": "字符数", + "tokenEstimate": "估算 Token", + "titlePath": "标题路径", + "section": "章节", + "pageNumber": "页码", + "offset": "位置", + "contentHash": "内容哈希", + "adjacentChunks": "相邻分块", + "previousChunk": "上一块: {id}", + "nextChunk": "下一块: {id}", + "parentChunk": "父分块", "vecDocId": "向量ID", + "context": "相邻上下文", + "previous": "上一块", + "current": "当前块", + "next": "下一块", + "contextMissing": "暂无相邻分块", "close": "关闭" }, "actions": { - "retry": "重试" + "retry": "重试", + "retryRebuild": "重试重建", + "retryRebuildConfirm": "确定要重新构建该文档索引吗?" }, "messages": { "loadDocumentFailed": "加载文档详情失败", - "loadChunksFailed": "加载分块列表失败" + "loadChunksFailed": "加载分块列表失败", + "loadChunkContextFailed": "加载相邻上下文失败", + "rebuildStarted": "文档重建任务已开始", + "rebuildCompleted": "文档重建完成", + "rebuildFailed": "文档重建失败", + "rebuildFailedWithReason": "文档重建失败: {reason}", + "focusChunkLoaded": "已打开检索命中的分块", + "focusChunkFailed": "打开检索命中的分块失败", + "focusChunkNotFound": "未找到检索命中的分块" } } diff --git a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/index.json b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/index.json index 6343412817..87d74926db 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/index.json +++ b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/index.json @@ -29,8 +29,8 @@ "descriptionLabel": "描述", "descriptionPlaceholder": "简单描述这个知识库的用途...", "emojiLabel": "图标", - "embeddingModelLabel": "嵌入模型 (Embedding Model)", - "rerankModelLabel": "重排序模型 (Rerank Model, 可选)", + "embeddingModelLabel": "嵌入模型", + "rerankModelLabel": "重排序模型(可选)", "providerInfo": "提供商: {id} | 维度: {dimensions}", "rerankProviderInfo": "提供商: {id}", "nameHint": "如果后续修改知识库名称,请同步更新仍按名称引用的配置。", diff --git a/dashboard/src/views/knowledge-base/DocumentDetail.vue b/dashboard/src/views/knowledge-base/DocumentDetail.vue index 0645a8f0c9..5042e4d4a7 100644 --- a/dashboard/src/views/knowledge-base/DocumentDetail.vue +++ b/dashboard/src/views/knowledge-base/DocumentDetail.vue @@ -101,15 +101,203 @@ + + {{ t("processing.title") }} + + + +
+ + {{ getDocumentStatusIcon(document.status) }} + +
+
+ {{ t("processing.status") }} +
+ + {{ getDocumentStatusText(document.status) }} + +
+
+
+ +
+ mdi-source-branch +
+
+ {{ t("processing.sourceType") }} +
+
+ {{ getSourceTypeText(document.source_type) }} +
+
+
+
+ +
+ mdi-counter +
+
+ {{ t("processing.version") }} +
+
+ {{ document.version || 1 }} +
+
+
+
+ +
+ mdi-calendar-check +
+
+ {{ t("processing.indexedAt") }} +
+
+ {{ formatDate(document.indexed_at) }} +
+
+
+
+ +
+ mdi-link-variant + +
+
+ +
+ mdi-fingerprint + +
+
+ +
+ mdi-file-cog-outline + +
+
+ +
+ mdi-text-box-check-outline + +
+
+ +
+ mdi-file-replace-outline + +
+
+
+ +
+ + + {{ t("actions.retryRebuild") }} + +
+
+
+
+ {{ t("chunks.title") }} - {{ totalChunks }} {{ t("chunks.title") }} + {{ + hasChunkSearch + ? t("chunks.filteredTotal", { + filtered: totalChunks, + total: displayDocumentChunkCount, + }) + : t("chunks.total", { count: displayDocumentChunkCount }) + }} - + /> + + + + + + + + From 8d60a8e6865c2a2c9f9890311980266abb993c91 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Sat, 6 Jun 2026 15:18:30 +0800 Subject: [PATCH 30/48] style: simplify knowledge base overview layout --- .../en-US/features/knowledge-base/detail.json | 6 +- .../zh-CN/features/knowledge-base/detail.json | 6 +- .../src/views/knowledge-base/KBDetail.vue | 439 +++++++++++------- dashboard/src/views/knowledge-base/index.vue | 2 +- 4 files changed, 276 insertions(+), 177 deletions(-) diff --git a/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json b/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json index 177dda1fd3..525085d08b 100644 --- a/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json +++ b/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json @@ -34,6 +34,9 @@ "run": "Run Check", "repair": "Repair Fixable Issues", "notRun": "No consistency check has been run yet. Run a check to compare document metadata, source files, and indexed chunks.", + "notRunHint": "A full check reads index metadata and lists fixable issues.", + "notRunChunkMismatch": "Current snapshot has {metadata} document chunks but {indexed} indexed chunks. Run a check.", + "notRunFailedDocs": "{count} documents are failed. Review the document list or run a consistency check.", "healthy": "No consistency issues found", "unhealthy": "{count} consistency issues found", "checkedAt": "Checked at: {time}", @@ -85,8 +88,9 @@ "refresh": "Refresh tasks", "empty": "No task records yet", "loadFailed": "Failed to load recent tasks", - "recentFailures": "Recent Failures", "noErrorMessage": "No error message", + "resultSummary": "{total} total, {success} succeeded, {failed} failed", + "progressDetail": "Progress {progress}", "types": { "upload": "Document Upload", "import": "Document Import", diff --git a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json index bc04ecfca8..95ec3e89eb 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json +++ b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/detail.json @@ -34,6 +34,9 @@ "run": "运行检查", "repair": "修复可修复项", "notRun": "尚未运行一致性检查。点击运行检查可诊断文档元数据、源文件和索引文本块是否一致。", + "notRunHint": "完整检查会读取索引元数据,并列出可修复项。", + "notRunChunkMismatch": "当前快照显示文档记录有 {metadata} 个分块,索引中有 {indexed} 个分块,建议运行检查。", + "notRunFailedDocs": "当前有 {count} 个失败文档,建议查看文档列表或运行一致性检查。", "healthy": "未发现一致性问题", "unhealthy": "发现 {count} 个一致性问题", "checkedAt": "检查时间: {time}", @@ -85,8 +88,9 @@ "refresh": "刷新任务", "empty": "暂无任务记录", "loadFailed": "加载最近任务失败", - "recentFailures": "最近失败", "noErrorMessage": "暂无错误信息", + "resultSummary": "共 {total} 个,成功 {success} 个,失败 {failed} 个", + "progressDetail": "进度 {progress}", "types": { "upload": "上传文档", "import": "导入文档", diff --git a/dashboard/src/views/knowledge-base/KBDetail.vue b/dashboard/src/views/knowledge-base/KBDetail.vue index d0c00c6652..f9dbca0c63 100644 --- a/dashboard/src/views/knowledge-base/KBDetail.vue +++ b/dashboard/src/views/knowledge-base/KBDetail.vue @@ -43,9 +43,12 @@ - - - + + + {{ t("overview.title") }} @@ -108,19 +111,46 @@ formatDate(kb.updated_at) }} + + + + {{ + t("overview.embeddingModel") + }} + {{ + kb.embedding_provider_id || t("overview.notSet") + }} + + + + + {{ + t("overview.rerankModel") + }} + {{ + kb.rerank_provider_id || t("overview.notSet") + }} + - - + + {{ t("overview.stats") }} - - + +
- mdi-file-document
{{ documentCount }}
@@ -129,9 +159,9 @@
- +
- mdi-text-box
{{ indexedChunkCount }}
@@ -140,9 +170,9 @@
- +
- mdi-check-circle-outline
{{ readyDocumentCount }}
@@ -151,9 +181,9 @@
- +
- mdi-alert-circle-outline
{{ failedDocumentCount }}
@@ -162,20 +192,18 @@
- +
- mdi-file-cabinet + mdi-folder
{{ sourceFileCount }}
{{ t("overview.sourceFiles") }}
- +
- mdi-database-outline + mdi-database
{{ formatFileSize(storageBytes) }}
@@ -187,8 +215,10 @@ + - + + @@ -298,12 +328,17 @@ - {{ t("consistency.notRun") }} +
+ {{ consistencyPrecheckMessage }} + + {{ t("consistency.notRunHint") }} + +
@@ -400,142 +435,97 @@
+
- - - {{ t("tasks.title") }} - - - - - +
+ + - {{ recentTasksLoadError }} - - - {{ t("tasks.empty") }} - - - {{ t("tasks.title") }} + + + + + - - - {{ getTaskTypeText(task.task_type) }} - - {{ getTaskStatusText(task.status) }} - - - - {{ - formatDate(task.updated_at || task.created_at || "") - }} - - - - - -
-
- {{ t("tasks.recentFailures") }} -
- + {{ recentTasksLoadError }} + + + {{ t("tasks.empty") }} + + - - - {{ - t("overview.rerankModel") - }} - {{ - kb.rerank_provider_id || t("overview.notSet") - }} - - - - + + +
@@ -694,6 +684,9 @@ const failedDocumentCount = computed( const indexedChunkCount = computed( () => kb.value.indexed_chunk_count ?? kb.value.chunk_count ?? 0, ); +const documentChunkCount = computed( + () => kb.value.document_chunk_count ?? indexedChunkCount.value, +); const sourceFileCount = computed(() => kb.value.source_file_count ?? 0); const storageBytes = computed(() => kb.value.storage_bytes ?? 0); const supportsConsistencyCheck = computed(() => @@ -711,7 +704,6 @@ const consistencyReport = ref(null); const kbRebuilding = ref(false); const kbRebuildTaskId = ref(""); const recentTasks = ref([]); -const recentFailedTasks = ref([]); const recentTasksLoading = ref(false); const recentTasksLoadError = ref(""); const kbRebuildProgress = ref({ @@ -757,6 +749,26 @@ const visibleConsistencyIssueTypes = computed(() => { (issueType) => (consistencyReport.value?.summary[issueType.key] ?? 0) > 0, ); }); +const hasChunkCountDrift = computed( + () => documentChunkCount.value !== indexedChunkCount.value, +); +const consistencyPrecheckType = computed(() => + failedDocumentCount.value > 0 || hasChunkCountDrift.value ? "warning" : "info", +); +const consistencyPrecheckMessage = computed(() => { + if (hasChunkCountDrift.value) { + return t("consistency.notRunChunkMismatch", { + metadata: documentChunkCount.value, + indexed: indexedChunkCount.value, + }); + } + if (failedDocumentCount.value > 0) { + return t("consistency.notRunFailedDocs", { + count: failedDocumentCount.value, + }); + } + return t("consistency.notRun"); +}); const repairableConsistencyTypes = computed(() => getRepairableConsistencyTypes(consistencyReport.value), ); @@ -765,7 +777,6 @@ const canRepairConsistency = computed( supportsConsistencyRepair.value && hasRepairableConsistencyIssues(consistencyReport.value), ); - const snackbar = ref({ show: false, text: "", @@ -818,35 +829,19 @@ const loadRecentTasks = async () => { recentTasksLoading.value = true; recentTasksLoadError.value = ""; try { - const [tasksResponse, failedTasksResponse] = await Promise.all([ - axios.get("/api/kb/task/list", { - params: { - kb_id: kbId.value, - page: 1, - page_size: 5, - }, - }), - axios.get("/api/kb/task/list", { - params: { - kb_id: kbId.value, - status: "failed", - page: 1, - page_size: 3, - }, - }), - ]); + const tasksResponse = await axios.get("/api/kb/task/list", { + params: { + kb_id: kbId.value, + page: 1, + page_size: 5, + }, + }); if (tasksResponse.data.status !== "ok") { recentTasksLoadError.value = tasksResponse.data.message || t("tasks.loadFailed"); return; } - if (failedTasksResponse.data.status !== "ok") { - recentTasksLoadError.value = - failedTasksResponse.data.message || t("tasks.loadFailed"); - return; - } recentTasks.value = tasksResponse.data.data.items || []; - recentFailedTasks.value = failedTasksResponse.data.data.items || []; } catch (error) { console.error("Failed to load recent knowledge base tasks:", error); recentTasksLoadError.value = t("tasks.loadFailed"); @@ -1093,6 +1088,19 @@ const getTaskTypeText = (taskType: string) => const getTaskStatusText = (status: string) => t(`tasks.statuses.${status}`) || status; +const toTaskCount = (value: unknown) => { + const numberValue = Number(value); + return Number.isFinite(numberValue) ? numberValue : 0; +}; + +const getTaskResultCounts = (task: KnowledgeBaseTask) => { + const result = task.result || {}; + const success = toTaskCount(result.success_count); + const failed = toTaskCount(result.failed_count); + const total = toTaskCount(result.total) || success + failed; + return { success, failed, total }; +}; + const formatTaskProgress = (task: KnowledgeBaseTask) => { const progress = getKnowledgeBaseTaskProgress(task); return `${progress.current} / ${progress.total}`; @@ -1101,6 +1109,30 @@ const formatTaskProgress = (task: KnowledgeBaseTask) => { const formatTaskError = (task: KnowledgeBaseTask) => getKnowledgeBaseTaskErrorText(task.error, t("tasks.noErrorMessage")); +const formatTaskSubtitle = (task: KnowledgeBaseTask) => + formatDate(task.updated_at || task.created_at || ""); + +const formatTaskDetail = (task: KnowledgeBaseTask) => { + if (task.status === "pending" || task.status === "processing") { + return t("tasks.progressDetail", { + progress: formatTaskProgress(task), + }); + } + if (task.status === "failed") { + return formatTaskError(task); + } + + const { success, failed, total } = getTaskResultCounts(task); + if (total > 0) { + return t("tasks.resultSummary", { + success, + failed, + total, + }); + } + return ""; +}; + const formatConsistencyIssueTitle = (issue: ConsistencyIssue) => { return ( issue.doc_name || issue.doc_id || issue.chunk_id || String(issue.storage_id) @@ -1189,13 +1221,42 @@ watch( min-height: 400px; } +.overview-layout { + align-items: stretch; +} + +.overview-layout > .v-col { + display: flex; +} + +.overview-card { + width: 100%; +} + +.overview-card--fill { + height: 100%; +} + +.overview-side-stack { + display: grid; + gap: 16px; + width: 100%; +} + +.stats-grid > .v-col { + display: flex; +} + .stat-box { + min-height: 118px; + width: 100%; display: flex; flex-direction: column; align-items: center; - padding: 24px; + justify-content: center; + padding: 18px 14px; text-align: center; - border-radius: 12px; + border-radius: 8px; background: rgba(var(--v-theme-surface-variant), 0.1); transition: all 0.3s ease; } @@ -1205,14 +1266,21 @@ watch( } .stat-value { - font-size: 2rem; + font-size: 1.75rem; font-weight: 600; + line-height: 1.2; margin-top: 8px; + max-width: 100%; + overflow-wrap: anywhere; } .stat-label { + color: rgba(var(--v-theme-on-surface), 0.72); font-size: 0.875rem; + line-height: 1.35; margin-top: 4px; + max-width: 100%; + overflow-wrap: anywhere; } .consistency-metric { @@ -1236,10 +1304,33 @@ watch( overflow-wrap: anywhere; } +.task-detail-line { + display: block; + margin-top: 2px; + color: rgba(var(--v-theme-on-surface), 0.68); + font-size: 0.75rem; + line-height: 1.35; + overflow-wrap: anywhere; +} + +.task-list--timeline :deep(.v-list-item) { + border-left: 2px solid rgba(var(--v-theme-outline), 0.16); + padding-left: 12px !important; +} + /* 响应式设计 */ @media (max-width: 768px) { .kb-title { font-size: 1.25rem; } + + .stat-box { + min-height: 108px; + padding: 16px 10px; + } + + .stat-value { + font-size: 1.45rem; + } } diff --git a/dashboard/src/views/knowledge-base/index.vue b/dashboard/src/views/knowledge-base/index.vue index 13df70e6fb..bada54f21d 100644 --- a/dashboard/src/views/knowledge-base/index.vue +++ b/dashboard/src/views/knowledge-base/index.vue @@ -56,7 +56,7 @@ const goToList = () => {