From 2a6ba11dee541e2dd4ebccc4a4f129228f96bae2 Mon Sep 17 00:00:00 2001 From: jinyx5 Date: Thu, 18 Jun 2026 15:48:03 +0800 Subject: [PATCH 1/4] feat: integrate cacheseek cross-request cache (KV/latent reuse) Replace the in-tree telefuser/cache_mem cache with cacheseek as the cross-request cache middleware. - service (container/task_service/api_server): build and drive (CacheService, TeleFuserCacheAdapter); per request build_query -> lookup -> apply_resume -> on_response -> save - lingbot_world_fast: world_kv hooks (on_runtime_created / on_chunk_finalized) + decode-only fast path for exact-prefix KV reuse; enable rolling KV window (local_attn_size=7, sink_size=3) - remove legacy telefuser/cache_mem + service/cache/cache_factory| cache_service and the cache_mem unit tests - pin torch==2.7.0 + torchvision==0.22.0 - docs: update latent_cache (en/zh) --- docs/en/latent_cache.md | 55 +- docs/zh/latent_cache.md | 40 +- .../wan22_14b_text_to_video_service.py | 68 +- ...wan22_14b_text_to_video_service_nocache.py | 256 ++++++ pyproject.toml | 10 +- telefuser/cache_mem/__init__.py | 27 - telefuser/cache_mem/cache_types.py | 40 - telefuser/cache_mem/config.py | 83 -- telefuser/cache_mem/connection.py | 197 ----- telefuser/cache_mem/encoders.py | 398 --------- telefuser/cache_mem/encoding/__init__.py | 0 telefuser/cache_mem/encoding/interfaces.py | 27 - telefuser/cache_mem/latent_cache.py | 213 ----- telefuser/cache_mem/log_monitor.py | 77 -- telefuser/cache_mem/metadata.py | 268 ------ telefuser/cache_mem/src/__init__.py | 0 telefuser/cache_mem/src/models/__init__.py | 0 .../src/models/qwen3_vl_embedding.py | 346 -------- .../cache_mem/src/models/qwen3_vl_reranker.py | 437 ---------- telefuser/cache_mem/state/__init__.py | 0 telefuser/cache_mem/state/interfaces.py | 67 -- telefuser/cache_mem/storage/__init__.py | 11 - telefuser/cache_mem/storage/fluxon.py | 24 - telefuser/cache_mem/storage/interfaces.py | 25 - telefuser/cache_mem/storage/local_file.py | 112 --- telefuser/cache_mem/storage/memory.py | 24 - telefuser/cache_mem/strategies.py | 819 ------------------ telefuser/cache_mem/vector_store/__init__.py | 5 - telefuser/cache_mem/vector_store/faiss.py | 298 ------- .../cache_mem/vector_store/interfaces.py | 42 - telefuser/cache_mem/vector_store/qdrant.py | 46 - .../pipelines/lingbot_world_fast/pipeline.py | 93 +- .../pipelines/lingbot_world_fast/session.py | 8 + telefuser/service/api/api_server.py | 6 +- telefuser/service/cache/__init__.py | 12 +- telefuser/service/cache/cache_factory.py | 176 ---- telefuser/service/cache/cache_service.py | 389 --------- telefuser/service/core/container.py | 24 +- telefuser/service/core/task_service.py | 46 +- tests/unit/cache_mem/__init__.py | 1 - tests/unit/cache_mem/test_concurrency.py | 355 -------- tests/unit/cache_mem/test_metadata.py | 150 ---- tests/unit/cache_mem/test_storage.py | 105 --- tests/unit/cache_mem/test_types_and_config.py | 96 -- 44 files changed, 475 insertions(+), 5001 deletions(-) create mode 100644 examples/wan_video/wan22_14b_text_to_video_service_nocache.py delete mode 100644 telefuser/cache_mem/__init__.py delete mode 100644 telefuser/cache_mem/cache_types.py delete mode 100644 telefuser/cache_mem/config.py delete mode 100644 telefuser/cache_mem/connection.py delete mode 100644 telefuser/cache_mem/encoders.py delete mode 100644 telefuser/cache_mem/encoding/__init__.py delete mode 100644 telefuser/cache_mem/encoding/interfaces.py delete mode 100644 telefuser/cache_mem/latent_cache.py delete mode 100644 telefuser/cache_mem/log_monitor.py delete mode 100644 telefuser/cache_mem/metadata.py delete mode 100644 telefuser/cache_mem/src/__init__.py delete mode 100644 telefuser/cache_mem/src/models/__init__.py delete mode 100644 telefuser/cache_mem/src/models/qwen3_vl_embedding.py delete mode 100644 telefuser/cache_mem/src/models/qwen3_vl_reranker.py delete mode 100644 telefuser/cache_mem/state/__init__.py delete mode 100644 telefuser/cache_mem/state/interfaces.py delete mode 100644 telefuser/cache_mem/storage/__init__.py delete mode 100644 telefuser/cache_mem/storage/fluxon.py delete mode 100644 telefuser/cache_mem/storage/interfaces.py delete mode 100644 telefuser/cache_mem/storage/local_file.py delete mode 100644 telefuser/cache_mem/storage/memory.py delete mode 100644 telefuser/cache_mem/strategies.py delete mode 100644 telefuser/cache_mem/vector_store/__init__.py delete mode 100644 telefuser/cache_mem/vector_store/faiss.py delete mode 100644 telefuser/cache_mem/vector_store/interfaces.py delete mode 100644 telefuser/cache_mem/vector_store/qdrant.py delete mode 100644 telefuser/service/cache/cache_factory.py delete mode 100644 telefuser/service/cache/cache_service.py delete mode 100644 tests/unit/cache_mem/__init__.py delete mode 100644 tests/unit/cache_mem/test_concurrency.py delete mode 100644 tests/unit/cache_mem/test_metadata.py delete mode 100644 tests/unit/cache_mem/test_storage.py delete mode 100644 tests/unit/cache_mem/test_types_and_config.py diff --git a/docs/en/latent_cache.md b/docs/en/latent_cache.md index 251f429..b424e52 100644 --- a/docs/en/latent_cache.md +++ b/docs/en/latent_cache.md @@ -128,14 +128,14 @@ to full denoising, so the main path is never poisoned by a bad cache entry. ## Factory Function -The production path does not construct `LatentCache` directly. Instead, -`CacheServiceFactory` builds a `CacheService` from CLI arguments and the -`CACHE_CONFIG` declared in the pipeline file: +The production path no longer constructs `LatentCache` directly. Instead, +the Cacheseek TeleFuser adapter builds a `(CacheService, TeleFuserCacheAdapter)` +pair from CLI arguments and the `CACHE_CONFIG` declared in the pipeline file: ```python -from telefuser.service.cache import CacheServiceFactory +from cacheseek.adapters.telefuser.cache_factory import CacheServiceFactory -cache_service = CacheServiceFactory.create_cache_service( +cache_service, cache_adapter = CacheServiceFactory.create_cache_service( ppl_file="examples/wan_video/wan22_14b_text_to_video_service.py", enable_latent_cache=True, cache_mode="read_write", # "read_write" / "read_only" / "write_only" @@ -149,18 +149,17 @@ cache_service = CacheServiceFactory.create_cache_service( 2. Overrides the final config with the CLI's `enable_latent_cache` / `cache_mode`. 3. Initializes the cache log sink up front. -4. Loads `build_latent_data` from `ppl_file` (**must exist**, otherwise it - raises an error). -5. Instantiates `LatentCache(cache_dir, config)` and wraps it inside - `CacheService`. +4. Builds the Cacheseek storage, vector store, metadata manager, and strategy. +5. Returns the framework-agnostic `CacheService` plus the TeleFuser adapter + used for `build_query`, `apply_resume`, and `on_response`. -Manual construction is also supported when needed: +Manual construction should use Cacheseek primitives directly when needed: ```python from pathlib import Path -from telefuser.cache_mem.config import CacheConfig -from telefuser.cache_mem.latent_cache import LatentCache +from cacheseek.core.config import CacheConfig +from cacheseek.core.lifecycle import CacheService config = CacheConfig( enable_latent_cache=True, @@ -168,13 +167,15 @@ config = CacheConfig( cache_strategy_type="video_approximate", vector_dim=2048, ) -cache = LatentCache(Path(config.latent_cache_dir), config) +cache_service = CacheService.from_config(config) ``` -The strategy class is looked up in the registry via `cache_strategy_type`: +The strategy class is selected by Cacheseek's TeleFuser factory from +`cache_strategy_type`. Custom strategies should be registered in Cacheseek, +not under `telefuser.cache_mem`. ```python -from telefuser.cache_mem.strategies import register_strategy, get_strategy_class +from cacheseek.core.strategies import get_strategy_class, register_strategy register_strategy("video_approximate", VideoBasedApproximateCache) # already registered by default strategy_cls = get_strategy_class("video_approximate") @@ -203,17 +204,17 @@ The only production strategy implementation is #### Write Path -When a request finishes, the pipeline hands its `latent_payload` (containing -the per-step latents plus video frames used for prompt similarity) to -`CacheService.save_latent_payload`, which enqueues it onto the -`cache-save-worker` background thread. The thread invokes -`LatentCache.save`: +When a request finishes, the pipeline returns `latent_payload` containing +the per-step latents plus video frames used for prompt similarity. The service +layer passes that payload through `TeleFuserCacheAdapter.on_response`, then +calls `CacheService.save(cache_query, outputs)`, which enqueues it onto the +Cacheseek async save worker: 1. Writes each step's latent to the KV store under a key shaped like `f"{cache_id}_step{step}"`. 2. Encodes the video frames with `Qwen3-VL-Embedding` and upserts the vector into the vector store (default collection name `video`). -3. Registers `cache_id → {prompt, saved_steps, size_mb, …}` in metadata, +3. Registers `cache_id -> {prompt, saved_steps, size_mb, ...}` in metadata, persisting `prompt_index.json` and `cache_meta.json`. If any step fails, all the latents / vectors / metadata that were already @@ -221,11 +222,12 @@ written are rolled back cleanly to avoid an inconsistent state. #### Hit Path -When a new request arrives, `CacheService.build_latent_data`: +When a new request arrives, the service layer runs +`adapter.build_query -> cache_service.lookup -> adapter.apply_resume`: 1. Waits on `vector_update_idle` to make sure the vector upsert from the previous async save has been committed. -2. Calls `LatentCache.lookup`: encodes the new prompt, queries the top-k +2. Calls Cacheseek lookup: encodes the new prompt, queries the top-k approximate caches in the vector store, optionally reranks with Qwen3-VL-Reranker, and compares against the threshold to decide on a hit. @@ -287,12 +289,9 @@ CACHE_CONFIG = dict( ) ``` -The pipeline file also has to provide two hooks the service layer relies on -to wire the cache into the main path: +The pipeline file only needs to expose `run_with_file` plus a `CACHE_CONFIG` +dict for Cacheseek configuration: -- `build_latent_data(task_data: dict, cache_result=None) -> dict`: converts - `cache_result` into the `latent_data` dict the pipeline expects (with - `hit / skip_step / cached_latent / saved_steps`). - `run_with_file(pipeline, **task_data) -> dict`: feeds `latent_data` into the pipeline and returns `latent_payload` as part of the result so the service layer can write it back to the cache. diff --git a/docs/zh/latent_cache.md b/docs/zh/latent_cache.md index 4606e5a..1c05dd7 100644 --- a/docs/zh/latent_cache.md +++ b/docs/zh/latent_cache.md @@ -121,13 +121,13 @@ shape / 范围校验,shape 不一致或 `skip_step` 越界时会自动丢弃 ## 工厂函数 -线上路径不直接构造 `LatentCache`,而是由 `CacheServiceFactory` 根据 -CLI 参数和 pipeline 文件中的 `CACHE_CONFIG` 生成 `CacheService`: +线上路径不再直接构造 `LatentCache`,而是由 Cacheseek 的 TeleFuser 适配器根据 +CLI 参数和 pipeline 文件中的 `CACHE_CONFIG` 生成 `(CacheService, TeleFuserCacheAdapter)`: ```python -from telefuser.service.cache import CacheServiceFactory +from cacheseek.adapters.telefuser.cache_factory import CacheServiceFactory -cache_service = CacheServiceFactory.create_cache_service( +cache_service, cache_adapter = CacheServiceFactory.create_cache_service( ppl_file="examples/wan_video/wan22_14b_text_to_video_service.py", enable_latent_cache=True, cache_mode="read_write", # "read_write" / "read_only" / "write_only" @@ -139,16 +139,17 @@ cache_service = CacheServiceFactory.create_cache_service( 1. 从 `ppl_file` 加载 `CACHE_CONFIG`(dict 或 `CacheConfig` 实例)作为默认配置基础。 2. 用 CLI 的 `enable_latent_cache` / `cache_mode` 覆盖最终配置。 3. 提前初始化 cache 日志 sink。 -4. 加载 `ppl_file` 中的 `build_latent_data` 函数(**必须存在**,否则报错)。 -5. 实例化 `LatentCache(cache_dir, config)`,再包装为 `CacheService`。 +4. 构造 Cacheseek 的存储、向量库、元数据管理器和策略。 +5. 返回框架无关的 `CacheService`,以及用于 `build_query`、`apply_resume` + 和 `on_response` 的 TeleFuser 适配器。 -需要直接构造时也支持手动接入: +需要直接构造时应使用 Cacheseek 原语: ```python from pathlib import Path -from telefuser.cache_mem.config import CacheConfig -from telefuser.cache_mem.latent_cache import LatentCache +from cacheseek.core.config import CacheConfig +from cacheseek.core.lifecycle import CacheService config = CacheConfig( enable_latent_cache=True, @@ -156,13 +157,14 @@ config = CacheConfig( cache_strategy_type="video_approximate", vector_dim=2048, ) -cache = LatentCache(Path(config.latent_cache_dir), config) +cache_service = CacheService.from_config(config) ``` -策略类通过 `cache_strategy_type` 在注册表中查找: +策略类由 Cacheseek 的 TeleFuser factory 根据 `cache_strategy_type` 选择。 +自定义策略应注册到 Cacheseek,而不是 `telefuser.cache_mem`: ```python -from telefuser.cache_mem.strategies import register_strategy, get_strategy_class +from cacheseek.core.strategies import get_strategy_class, register_strategy register_strategy("video_approximate", VideoBasedApproximateCache) # 默认已注册 strategy_cls = get_strategy_class("video_approximate") @@ -185,24 +187,26 @@ strategy_cls = get_strategy_class("video_approximate") #### 写入路径 -请求结束、pipeline 把 `latent_payload`(含按步存储的 latent + 用于 prompt -相似度的视频帧)传给 `CacheService.save_latent_payload`,后者放入 -`cache-save-worker` 后台线程;线程调用 `LatentCache.save`: +请求结束后,pipeline 返回 `latent_payload`(含按步存储的 latent + 用于 prompt +相似度的视频帧)。服务层先经过 `TeleFuserCacheAdapter.on_response` 打包, +再调用 `CacheService.save(cache_query, outputs)`,由 Cacheseek 异步保存 worker +处理: 1. 将每个 step 的 latent 写到 KV,key 形如 `f"{cache_id}_step{step}"`。 2. 通过 `Qwen3-VL-Embedding` 将视频帧编码成向量,upsert 至 向量检索库(默认 collection 名 `video`)。 -3. 在 metadata 里登记 `cache_id → {prompt, saved_steps, size_mb, …}`, +3. 在 metadata 里登记 `cache_id -> {prompt, saved_steps, size_mb, ...}`, 持久化 `prompt_index.json` 和 `cache_meta.json`。 任何一步失败,已写入的 latent / 向量 / metadata 都会回滚干净,避免状态不一致。 #### 命中路径 -新请求到达,`CacheService.build_latent_data`: +新请求到达后,服务层执行 +`adapter.build_query -> cache_service.lookup -> adapter.apply_resume`: 1. 等待 `vector_update_idle`——确保上一笔异步 save 的向量 upsert 已落库。 -2. 调用 `LatentCache.lookup`:对新 prompt 编码,在向量检索库中查 top-k 近似 +2. 调用 Cacheseek lookup:对新 prompt 编码,在向量检索库中查 top-k 近似 缓存;可选用 Qwen3-VL-Reranker 重排,跟阈值比对决定是否命中。 3. 命中后从 KV 读出 `skip_step` 对应的 latent 张量,封装成 `CacheResult` 返回。 diff --git a/examples/wan_video/wan22_14b_text_to_video_service.py b/examples/wan_video/wan22_14b_text_to_video_service.py index 7d3e590..ba8605e 100644 --- a/examples/wan_video/wan22_14b_text_to_video_service.py +++ b/examples/wan_video/wan22_14b_text_to_video_service.py @@ -3,17 +3,18 @@ Service-mode counterpart of ``wan22_14b_text_to_video_h100.py``. Exposes: - ``get_pipeline`` for service startup - ``run_with_file`` for TeleFuser PipelineService (must return dict with ``output_path``) -- ``build_latent_data`` for CacheServiceFactory cache lookup path - ``CACHE_CONFIG`` for CacheServiceFactory config overrides -Cross-request latent cache is wired via: -1. Service layer -> ``cache_service.build_latent_data(task_request, task_data)`` - -> this module's ``build_latent_data`` merges cache_result into ``latent_data`` +Cross-request latent cache is wired through Cacheseek: +1. Service layer -> ``adapter.build_query(task_request)`` + -> ``cache_service.lookup(cache_query)`` + -> ``adapter.apply_resume(lookup_result, engine_ctx=task_data)`` 2. ``run_with_file`` forwards ``latent_data`` to ``pipeline.__call__`` 3. Pipeline returns ``(frames, latent_payload)`` when ``latent_data`` is not None 4. ``run_with_file`` samples a few frames and writes back ``latent_payload["embedding_video_frames"]`` to satisfy VideoBasedApproximateCache.save's precondition -5. Service layer -> ``cache_service.save_latent_payload(task_request, latent_payload)`` +5. Service layer -> ``adapter.on_response(task_request, latent_payload)`` + -> ``cache_service.save(cache_query, outputs)`` """ from __future__ import annotations @@ -21,8 +22,8 @@ import os import torch +from cacheseek.core.config import CacheConfig -from telefuser.cache_mem.config import CacheConfig from telefuser.core.config import AttentionConfig, AttnImplType, FeatureCacheConfig, WeightOffloadType from telefuser.core.module_manager import ModuleManager from telefuser.pipelines.wan_video.wan22_video import ( @@ -59,18 +60,26 @@ CACHE_CONFIG = dict( enable_latent_cache=True, latent_cache_dir=os.getenv("TELEFUSER_LATENT_CACHE_DIR", "./latent_cache/wan22_t2v"), - # write_only: skip lookup (every prompt is unique in W03 dataset_500), - # force save_latent_payload so all 500 prompts persist a latent snapshot. - # See cache-evolution/design/w03 · latents 分布可视化数据集构建(design).md - cache_mode="write_only", - # KV store: Fluxon is stubbed in MVP -> use local file backend. - kv_store_type="local_file", - # Vector store: Qdrant is stubbed in MVP -> use faiss backend. - vector_store_type="faiss", + # Default write_only for dataset_500 (every prompt is unique, so always save). + # Set TELEFUSER_CACHE_MODE=read_write to also serve from cache (verifying hits). + # See the latent distribution dataset design notes. + cache_mode=os.getenv("TELEFUSER_CACHE_MODE", "write_only"), + # KV store: default local_file. Switch to fluxon by setting + # TELEFUSER_KV_STORE_TYPE=fluxon + # FLUXON_CONFIG_PATH=/path/to/external_config.yaml + # Fluxon adapter requires a running master + kvclient owner on this host; + # See the Fluxon backend integration notes for the deploy procedure. + kv_store_type=os.getenv("TELEFUSER_KV_STORE_TYPE", "local_file"), + fluxon_config_path=os.getenv("FLUXON_CONFIG_PATH", ""), + # Vector store: default faiss; switch to qdrant via env when a real + # Qdrant server is running (see the Qdrant deployment notes). + # TELEFUSER_VECTOR_STORE_TYPE=qdrant QDRANT_URL=http://127.0.0.1:6333 + vector_store_type=os.getenv("TELEFUSER_VECTOR_STORE_TYPE", "faiss"), + qdrant_url=os.getenv("QDRANT_URL", ""), # Qwen3-VL-Embedding-2B hidden_size=2048. connection.py default is 512 (too small). # MUST match encoder output dim or FAISSVectorStore.search raises dim mismatch. vector_dim=2048, - # Steps to snapshot. Aligned with W03 dataset design §L145/255 + # Steps to snapshot. Aligned with the dataset design notes. # (5 mid-to-late steps, no step 0). key_steps=[5, 10, 15, 20, 25], # Video embedding: required by VideoBasedApproximateCache.save (else rollback). @@ -88,7 +97,7 @@ video_vector_collection="video", # Reranker: Qwen3-VL-Reranker-2B is a text-only cross-encoder (score_mm over # {query_text, candidate_texts}). Adds ~4GB to logical GPU 1, shared with - # video_encoder — together ~22GB on an 80GB H100. + # video_encoder; together they use about 22GB on an 80GB H100. rerank_enabled=True, rerank_model_path=os.getenv("QWEN3VL_RERANKER_PATH", "/storage/model_zoo/Qwen3-VL-Reranker-2B"), # Under parallelism=2 + CVD=2,3, logical 1 already has video_encoder + dit rank 1, @@ -99,7 +108,7 @@ rerank_top_k=5, # Used by _determine_skip_step when rerank_enabled=True (rerank score path, # strategies.py:361-364). bf16 fp noise gives sim~0.87 for identical prompts - # via vector, but rerank cross-encoder is usually tighter — 0.85 leaves room. + # via vector, but rerank cross-encoder is usually tighter; 0.85 leaves room. rerank_score_threshold=0.85, ) @@ -130,31 +139,6 @@ def _sample_video_frames(video_frames, max_frames: int | None = None): return [video_frames[idx] for idx in indices if 0 <= idx < total] -def build_latent_data(task_data: dict, cache_result=None) -> dict: - """Build latent_data consumed by pipeline (hit or miss both call this). - - Follows teleai_pipe reference implementation: always return a dict so - the pipeline goes through the cache-aware code path (save snapshots on - miss, skip steps on hit). - """ - saved_steps = CACHE_CONFIG.get("key_steps") - if not saved_steps: - saved_steps = CacheConfig().key_steps - cached_latent = None - skip_step = 0 - hit = False - if cache_result is not None and getattr(cache_result, "hit", False): - cached_latent = getattr(cache_result, "latent_state", None) - skip_step = int(getattr(cache_result, "skip_step", 0) or 0) - hit = cached_latent is not None and skip_step > 0 - return { - "hit": hit, - "skip_step": skip_step if hit else 0, - "cached_latent": cached_latent if hit else None, - "saved_steps": saved_steps, - } - - def get_pipeline(parallelism: int = 1, model_root: str | None = None): """Build Wan22VideoPipeline for service startup. diff --git a/examples/wan_video/wan22_14b_text_to_video_service_nocache.py b/examples/wan_video/wan22_14b_text_to_video_service_nocache.py new file mode 100644 index 0000000..592ac80 --- /dev/null +++ b/examples/wan_video/wan22_14b_text_to_video_service_nocache.py @@ -0,0 +1,256 @@ +"""Wan2.2 14B Text-to-Video service pipeline with latent cache support. + +Service-mode counterpart of ``wan22_14b_text_to_video_h100.py`` without latent cache. +Exposes: +- ``get_pipeline`` for service startup +- ``run_with_file`` for TeleFuser PipelineService (must return dict with ``output_path``) +- ``CACHE_CONFIG`` for CacheServiceFactory config overrides + +``CACHE_CONFIG["enable_latent_cache"]`` is False here, so the Cacheseek +lookup/resume/save lifecycle is not engaged. +""" + +from __future__ import annotations + +import os + +import torch +from cacheseek.core.config import CacheConfig + +from telefuser.core.config import AttentionConfig, AttnImplType, FeatureCacheConfig, WeightOffloadType +from telefuser.core.module_manager import ModuleManager +from telefuser.pipelines.wan_video.wan22_video import ( + Wan22VideoPipeline, + Wan22VideoPipelineConfig, +) +from telefuser.utils.video import get_target_video_size_from_ratio, save_video + +TF_MODEL_ZOO_PATH = os.environ.get("TF_MODEL_ZOO_PATH", "model_zoo") +PPL_CONFIG = dict( + name="wan22_A14B_t2v_service", + model_root="/storage/model_zoo/Wan2.2-T2V-A14B", + negative_prompt="Overly saturated colors, overexposed, static, blurry details, subtitles, style, artwork, painting, frame, still, overall grayish, worst quality, low quality, JPEG compression artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, static frames, cluttered background, three legs, crowded background, walking backwards", + num_inference_steps=40, + num_frames=81, + resolution="720p", + aspect_ratio="16:9", + cfg_scale_high=5.0, + cfg_scale_low=5.0, + seed=42, + tiled=False, + sigma_shift=5.0, + boundary=0.9, + sample_solver="euler", + attn_impl=AttnImplType.TORCH_SDPA, + dit_high_path_list="high_noise_model/diffusion_pytorch_model-0000*-of-00006.safetensors", + dit_low_path_list="low_noise_model/diffusion_pytorch_model-0000*-of-00006.safetensors", + enable_feature_cache_dit_high=True, + enable_feature_cache_dit_low=True, + model_type="Wan2_2-T2V-A14B", + target_fps=16, +) + +CACHE_CONFIG = dict( + enable_latent_cache=False, + latent_cache_dir=os.getenv("TELEFUSER_LATENT_CACHE_DIR", "./latent_cache/wan22_t2v"), + # Cache is disabled for this pipeline variant; the remaining values are + # kept aligned with the cache-enabled sibling for easy toggling. + cache_mode="write_only", + # KV store: Fluxon is stubbed in MVP -> use local file backend. + kv_store_type="local_file", + # Vector store: Qdrant is stubbed in MVP -> use faiss backend. + vector_store_type="faiss", + # Qwen3-VL-Embedding-2B hidden_size=2048. connection.py default is 512 (too small). + # MUST match encoder output dim or FAISSVectorStore.search raises dim mismatch. + vector_dim=2048, + # Steps to snapshot. Aligned with the dataset design notes. + # (5 mid-to-late steps, no step 0). + key_steps=[5, 10, 15, 20, 25], + # Video embedding: required by VideoBasedApproximateCache.save (else rollback). + video_embedding_enabled=True, + video_embedding_model_path=os.getenv("QWEN3VL_EMBEDDING_PATH", ""), + video_embedding_max_frames=16, + # CacheConfig defaults assume 4 visible GPUs (text=1, video=2, rerank=3). + # Under CUDA_VISIBLE_DEVICES=2,3 the logical range is 0,1 only -> override. + # Both encoders colocated on logical 1 (GPU 3) so strategies.py can share + # a single Qwen3VLEncoder instance for text+video (saves 5GB, see + # strategies.py video_encoder sharing branch). Reranker takes logical 0 + # (GPU 2, alone with DiT rank 0); shared encoder + DiT rank 1 on logical 1. + text_embedding_device_id=1, + video_embedding_device_id=1, + video_vector_collection="video", + # Reranker: Qwen3-VL-Reranker-2B is a text-only cross-encoder (score_mm over + # {query_text, candidate_texts}). Adds ~4GB to logical GPU 1, shared with + # video_encoder; together they use about 22GB on an 80GB H100. + rerank_enabled=True, + rerank_model_path=os.getenv("QWEN3VL_RERANKER_PATH", "/storage/model_zoo/Qwen3-VL-Reranker-2B"), + # Under parallelism=2 + CVD=2,3, logical 1 already has video_encoder + dit rank 1, + # putting reranker there too overflows GPU 3 (~80GB H100). Default to logical 0 + # (GPU 2, shared with prompt_encoder + dit rank 0, ~14GB headroom remaining). + # Override via env TELEFUSER_RERANK_DEVICE_ID when running parallelism=4 etc. + rerank_device_id=int(os.getenv("TELEFUSER_RERANK_DEVICE_ID", "0")), + rerank_top_k=5, + # Used by _determine_skip_step when rerank_enabled=True (rerank score path, + # strategies.py:361-364). bf16 fp noise gives sim~0.87 for identical prompts + # via vector, but rerank cross-encoder is usually tighter; 0.85 leaves room. + rerank_score_threshold=0.85, +) + + +def _sample_indices(total: int, max_frames: int) -> list[int]: + if total <= 0: + return [] + max_frames = max(1, int(max_frames or 1)) + if total <= max_frames: + return list(range(total)) + step = float(total) / float(max_frames) + return [min(int(i * step), total - 1) for i in range(max_frames)] + + +def _sample_video_frames(video_frames, max_frames: int | None = None): + """Sample representative frames from the output video for embedding.""" + if video_frames is None: + return [] + if max_frames is None: + max_frames = CACHE_CONFIG.get( + "video_embedding_max_frames", + CacheConfig().video_embedding_max_frames, + ) + total = len(video_frames) + if total <= 0: + return [] + indices = _sample_indices(total, max_frames) + return [video_frames[idx] for idx in indices if 0 <= idx < total] + + +# build_latent_data ppl-file hook removed (kept consistent +# with the cache-enabled sibling file). enable_latent_cache=False here, so the +# adapter path is never engaged anyway, but removing the dead hook keeps the +# two ppl files structurally aligned. + + +def get_pipeline(parallelism: int = 1, model_root: str | None = None): + """Build Wan22VideoPipeline for service startup. + + Args: + parallelism: Number of parallel GPUs (1/2/4/8). + model_root: Override for ``PPL_CONFIG["model_root"]``. + """ + ppl_config = PPL_CONFIG + model_root = model_root or ppl_config["model_root"] + + module_manager = ModuleManager(device="cpu") + module_manager.load_model(f"{model_root}/Wan2.1_VAE.pth", torch_dtype=torch.bfloat16) + module_manager.load_model( + os.path.join(model_root, ppl_config["dit_high_path_list"]), + torch_dtype=torch.bfloat16, + ) + module_manager.load_model( + os.path.join(model_root, ppl_config["dit_low_path_list"]), + torch_dtype=torch.bfloat16, + ) + module_manager.load_model( + f"{model_root}/models_t5_umt5-xxl-enc-bf16.pth", + torch_dtype=torch.bfloat16, + ) + + pipe = Wan22VideoPipeline(device="cuda", torch_dtype=torch.bfloat16) + pipe_config = Wan22VideoPipelineConfig() + pipe_config.text_encoding_config.offload_config.offload_type = WeightOffloadType.MODEL_CPU_OFFLOAD + pipe_config.vae_config.offload_config.offload_type = WeightOffloadType.MODEL_CPU_OFFLOAD + pipe_config.dit_high_config.offload_config.offload_type = WeightOffloadType.MODEL_CPU_OFFLOAD + pipe_config.dit_low_config.offload_config.offload_type = WeightOffloadType.MODEL_CPU_OFFLOAD + pipe_config.dit_high_config.attention_config = AttentionConfig.dense_attention(ppl_config["attn_impl"]) + pipe_config.dit_low_config.attention_config = AttentionConfig.dense_attention(ppl_config["attn_impl"]) + pipe_config.sample_solver = ppl_config["sample_solver"] + + if ppl_config.get("enable_feature_cache_dit_high", False): + pipe_config.dit_high_config.feature_cache_config = FeatureCacheConfig( + enabled=True, model_type=ppl_config["model_type"] + ) + if ppl_config.get("enable_feature_cache_dit_low", False): + pipe_config.dit_low_config.feature_cache_config = FeatureCacheConfig( + enabled=True, model_type=ppl_config["model_type"] + ) + + if parallelism > 1: + cfg_scale_high = ppl_config["cfg_scale_high"] + cfg_scale_low = ppl_config["cfg_scale_low"] + if cfg_scale_high > 1: + pipe_config.dit_high_config.parallel_config.cfg_degree = 2 + pipe_config.dit_high_config.parallel_config.sp_ulysses_degree = parallelism // 2 + else: + pipe_config.dit_high_config.parallel_config.sp_ulysses_degree = parallelism + if cfg_scale_low > 1: + pipe_config.dit_low_config.parallel_config.cfg_degree = 2 + pipe_config.dit_low_config.parallel_config.sp_ulysses_degree = parallelism // 2 + else: + pipe_config.dit_low_config.parallel_config.sp_ulysses_degree = parallelism + pipe_config.dit_high_config.parallel_config.device_ids = list(range(parallelism)) + pipe_config.dit_low_config.parallel_config.device_ids = list(range(parallelism)) + pipe_config.enable_denoising_parallel = True + + pipe.init(module_manager, pipe_config) + return pipe + + +def run_with_file(pipeline, **task_data) -> dict: + """Service entrypoint invoked by PipelineService. + + Returns a dict with ``output_path`` (required) and optionally + ``latent_payload`` (consumed by task_service's post-inference hook). + + ``**task_data`` is preferred over explicit args: this guarantees + ``latent_data`` (injected by task_service's pre-inference hook) survives + ``_select_kwargs`` signature filtering in pipeline_runner.py. + """ + prompt = task_data["prompt"] + output_path = task_data["output_path"] + negative_prompt = task_data.get("negative_prompt", "") or "" + seed = int(task_data.get("seed", PPL_CONFIG["seed"])) + resolution = task_data.get("resolution") or PPL_CONFIG["resolution"] + aspect_ratio = task_data.get("aspect_ratio") or PPL_CONFIG["aspect_ratio"] + latent_data = task_data.get("latent_data") + + width, height = get_target_video_size_from_ratio( + aspect_ratio, + resolution=resolution, + height_division_factor=16, + width_division_factor=16, + ) + + result = pipeline( + prompt=prompt, + negative_prompt=f"{negative_prompt} {PPL_CONFIG['negative_prompt']}", + num_inference_steps=PPL_CONFIG["num_inference_steps"], + num_frames=PPL_CONFIG["num_frames"], + cfg_scale_high=PPL_CONFIG["cfg_scale_high"], + cfg_scale_low=PPL_CONFIG["cfg_scale_low"], + seed=seed, + tiled=PPL_CONFIG["tiled"], + height=height, + width=width, + sigma_shift=PPL_CONFIG["sigma_shift"], + boundary=PPL_CONFIG["boundary"], + latent_data=latent_data, + ) + + latent_payload: dict | None = None + if isinstance(result, tuple): + frames, latent_payload = result + else: + frames = result + + # Back-fill embedding_video_frames so VideoBasedApproximateCache.save + # can upsert to vector_store without rolling back the KV write. + if latent_payload is not None: + sampled = _sample_video_frames(frames) + if sampled: + latent_payload["embedding_video_frames"] = sampled + + save_video(frames, output_path, fps=PPL_CONFIG["target_fps"], quality=6) + + ret: dict = {"output_path": str(output_path)} + if latent_payload is not None: + ret["latent_payload"] = latent_payload + return ret diff --git a/pyproject.toml b/pyproject.toml index 695d1ed..56f6500 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,14 +44,18 @@ dependencies = [ "ray", "pydantic>=2.0.0", "pydantic-settings>=2.0.0", - "torchvision", + # Pin the torch stack to the validated combo (2.7.0 + cu126). Without this a + # fresh `pip install -e .` resolves to the latest torch (2.12 / cuda 13) via + # torchvision, which the H100 deployment + cacheseek repro were NOT validated on. + "torch==2.7.0", + "torchvision==0.22.0", ] [project.optional-dependencies] -# Latent cache subsystem (telefuser.cache_mem). +# Latent cache subsystem used by Cacheseek-backed service integrations. # Required only when the service is started with enable_latent_cache=true -# (or when running the wan22 service example with cache_mem). Installing +# (or when running the Wan2.2 service example with latent cache). Installing # without this extra leaves the cache layer disabled and adds zero overhead. cache = [ "faiss-cpu", diff --git a/telefuser/cache_mem/__init__.py b/telefuser/cache_mem/__init__.py deleted file mode 100644 index 32662d9..0000000 --- a/telefuser/cache_mem/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from importlib import import_module -from typing import Any - -__all__ = ["CacheConfig", "CacheResult", "LatentCache"] - - -def __getattr__(name: str) -> Any: - """Lazily expose heavy symbols to keep lightweight imports usable.""" - if name == "CacheResult": - module = import_module("telefuser.cache_mem.cache_types") - return getattr(module, "CacheResult") - if name == "LatentCache": - module = import_module("telefuser.cache_mem.latent_cache") - return getattr(module, "LatentCache") - if name == "CacheConfig": - try: - module = import_module("telefuser.cache_mem.config") - return getattr(module, "CacheConfig") - except (ImportError, ModuleNotFoundError): - return None - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -def __dir__() -> list[str]: - return sorted(set(globals().keys()) | set(__all__)) diff --git a/telefuser/cache_mem/cache_types.py b/telefuser/cache_mem/cache_types.py deleted file mode 100644 index 0c398d8..0000000 --- a/telefuser/cache_mem/cache_types.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - -import torch - - -@dataclass -class CacheResult: - """缓存查询结果。""" - - hit: bool - skip_step: int = 0 - cache_type: str = "none" # "approximate", "continue", "exact", "none" - similarity: float = 0.0 - latent_state: Optional[torch.Tensor] = None - cached_prompt: str = "" - session_id: Optional[str] = None - - -@dataclass -class IndexEntry: - """索引条目。""" - - cache_id: str - prompt: str - saved_steps: List[int] - cache_type: str = "approximate_cache" - - -@dataclass -class VectorSearchResult: - """向量检索结果。""" - - cache_id: str - similarity: float - prompt: str - saved_steps: List[int] - payload: Dict[str, Any] diff --git a/telefuser/cache_mem/config.py b/telefuser/cache_mem/config.py deleted file mode 100644 index 365ac26..0000000 --- a/telefuser/cache_mem/config.py +++ /dev/null @@ -1,83 +0,0 @@ -from dataclasses import dataclass, field -from enum import Enum -from typing import List, Optional - - -class CacheMode(Enum): - READ_WRITE = "read_write" # 读取和写入缓存(默认) - READ_ONLY = "read_only" # 仅读取缓存 - WRITE_ONLY = "write_only" # 仅写入缓存 - - -@dataclass -class CacheConfig: - """Cache configuration shared across stages/pipelines.""" - - # 基础缓存 (Basic cache) - enable_latent_cache: bool = False - cache_mode: CacheMode = CacheMode.READ_WRITE # read_write | read_only | write_only - latent_cache_dir: str = "./latent_cache" - max_cache_size_gb: int = 10 - cache_log_enabled: bool = True - cache_log_dir: Optional[str] = None # 默认: {latent_cache_dir}/logs - cache_log_level: str = "DEBUG" - cache_log_rotation: str = "100 MB" - cache_log_retention: str = "7 days" - - # KV 存储 (KV store,用于 latent 等键值缓存) - kv_store_type: str = "local_file" # "local_file" | "fluxon" - fluxon_config_path: Optional[str] = "" - - # 向量存储 (Vector store,用于 embedding 检索) - vector_store_type: str = "faiss" # "qdrant" | "faiss" - qdrant_url: Optional[str] = "" - qdrant_api_key: Optional[str] = None - faiss_index_dir: Optional[str] = None - vector_dim: int = 2048 # 向量维度(FAISS 初始化需要,应与 embedding 模型输出维度一致) - cache_strategy_type: str = "video_approximate" # 策略类型,对应 STRATEGY_REGISTRY 中的 key - - # 相似度与检索策略 (Similarity & lookup strategy) - key_steps: List[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5]) # 参与缓存复用的 step - lookup_mode: str = "video" # 检索模式,如 "video" - - # 文本嵌入 (Prompt/text embedding 模型) - text_embedding_model_path: str = "" - text_embedding_instruction: str = "Represent the user's input" - text_embedding_device_id: Optional[int] = None - text_embedding_torch_dtype: Optional[str] = None - text_embedding_attn_impl: Optional[str] = None - - # 视频嵌入 (Video embedding 模型) - video_embedding_enabled: bool = True - video_embedding_model_path: str = "Qwen/Qwen3-VL-Embedding-2B" - video_embedding_instruction: str = "Represent the user's input" - video_embedding_fps: float = 1.0 - video_embedding_max_frames: int = 16 - video_embedding_max_length: int = 8192 - video_embedding_min_pixels: int = 4096 - video_embedding_max_pixels: int = 1843200 - video_embedding_total_pixels: int = 7864320 - video_embedding_device_id: Optional[int] = None - video_embedding_torch_dtype: Optional[str] = None - video_embedding_attn_impl: Optional[str] = None - - # 视频向量检索与重排 (Video vector search & rerank) - video_similarity_threshold: Optional[float] = 0.10 - video_vector_collection: str = "video" - rerank_enabled: bool = False - rerank_model_path: str = "Qwen/Qwen3-VL-Reranker-2B" - rerank_top_k: int = 5 - rerank_batch_size: int = 2 - rerank_device_id: Optional[int] = None - rerank_torch_dtype: Optional[str] = None - rerank_score_threshold: float = 0.90 - - # 异步保存 (Async save / write-behind) - save_async_enabled: bool = True - save_queue_size: int = 2 - save_on_full: str = "drop" # drop | sync | downgrade - save_queue_warn_threshold: int = 8 - vector_wait_warn_s: float = 2.0 - vector_wait_poll_s: float = 0.05 - vector_wait_timeout_s: float = 120.0 - flush_on_shutdown: bool = True diff --git a/telefuser/cache_mem/connection.py b/telefuser/cache_mem/connection.py deleted file mode 100644 index e50f9ad..0000000 --- a/telefuser/cache_mem/connection.py +++ /dev/null @@ -1,197 +0,0 @@ -from __future__ import annotations - -import threading -from pathlib import Path -from typing import Any, Optional - -from loguru import logger - -from .storage.interfaces import KVStore -from .vector_store.interfaces import VectorStore - - -class ConnectionManager: - def __init__( - self, - config: Any, - storage_dir: Optional[Path] = None, - ) -> None: - self._config = config - self._storage_dir = Path(storage_dir) if storage_dir else None - self._lock = threading.Lock() - - self._vector_store: Optional[VectorStore] = None - self._vector_store_created = False - self._kv_store: Optional[KVStore] = None - self._kv_store_created = False - - # ── public properties ────────────────────────────────────────── - - @property - def vector_store(self) -> Optional[VectorStore]: - """延迟创建 VectorStore 连接(Qdrant / FAISS)。""" - if not self._vector_store_created: - with self._lock: - if not self._vector_store_created: - self._vector_store = self._create_vector_store() - self._vector_store_created = True - return self._vector_store - - @property - def kv_store(self) -> Optional[KVStore]: - """延迟创建 KVStore 连接(Fluxon / LocalFile)。""" - if not self._kv_store_created: - with self._lock: - if not self._kv_store_created: - self._kv_store = self._create_kv_store() - self._kv_store_created = True - return self._kv_store - - # ── health check ─────────────────────────────────────────────── - - def health_check(self) -> dict: - result: dict = {} - - # vector_store - vs = self._vector_store - if vs is None and not self._vector_store_created: - result["vector_store"] = {"status": "not_initialized"} - elif vs is None: - result["vector_store"] = {"status": "disabled"} - else: - vs_status: dict = { - "status": "connected", - "type": type(vs).__name__, - } - # Qdrant: 尝试获取 collections 列表验证连通性 - if hasattr(vs, "client"): - try: - vs.client.get_collections() - vs_status["reachable"] = True - except Exception as exc: - logger.exception( - "ConnectionManager.health_check vector_store reachability failed: {}", - exc, - ) - vs_status["reachable"] = False - vs_status["error"] = str(exc) - result["vector_store"] = vs_status - - # kv_store - kvs = self._kv_store - if kvs is None and not self._kv_store_created: - result["kv_store"] = {"status": "not_initialized"} - elif kvs is None: - result["kv_store"] = {"status": "disabled"} - else: - result["kv_store"] = { - "status": "connected", - "type": type(kvs).__name__, - } - - return result - - # ── shutdown ─────────────────────────────────────────────────── - - def shutdown(self) -> None: - with self._lock: - for name, store in [ - ("vector_store", self._vector_store), - ("kv_store", self._kv_store), - ]: - if store is None: - continue - for method_name in ("shutdown", "close"): - if hasattr(store, method_name): - try: - getattr(store, method_name)() - except Exception as exc: - logger.exception( - "ConnectionManager.{}.{} failed: {}", - name, - method_name, - exc, - ) - break - self._vector_store = None - self._vector_store_created = False - self._kv_store = None - self._kv_store_created = False - - # ── private: 创建逻辑(从 LatentCache._build_* 迁移) ────────── - - def _create_vector_store(self) -> Optional[VectorStore]: - from .vector_store.qdrant import QdrantVectorStore - - config = self._config - store_type = (getattr(config, "vector_store_type", "") or "").lower() - - if store_type == "faiss": - return self._build_faiss_store() - - if store_type == "qdrant": - qdrant_url = getattr(config, "qdrant_url", None) - if not qdrant_url: - logger.debug("Qdrant vector store selected without qdrant_url; using in-memory Qdrant") - try: - return QdrantVectorStore( - url=qdrant_url or "", - api_key=getattr(config, "qdrant_api_key", None), - ) - except NotImplementedError as exc: - # TODO(qdrant): drop this fallback once QdrantVectorStore lands; - # otherwise it silently masks regressions in the qdrant backend. - logger.warning( - "Qdrant vector store is not implemented yet ({}); " - "falling back to FAISSVectorStore. " - "Set vector_store_type='faiss' in CacheConfig to silence this warning.", - exc, - ) - return self._build_faiss_store() - - if store_type: - logger.debug( - "Unknown vector_store_type '{}'; vector store disabled", - store_type, - ) - else: - logger.debug("vector_store_type not set; vector store disabled") - return None - - def _build_faiss_store(self) -> "VectorStore": - from .vector_store.faiss import FAISSVectorStore - - config = self._config - cache_dir = self._storage_dir.parent if self._storage_dir else Path(".") - index_dir = getattr(config, "faiss_index_dir", None) or str(cache_dir / "faiss") - vector_dim = int(getattr(config, "vector_dim", 512)) - return FAISSVectorStore(Path(index_dir), vector_dim=vector_dim, index_type="L2") - - def _create_kv_store(self) -> KVStore: - from .storage.fluxon import FluxonKVStore - from .storage.local_file import LocalFileKVStore - - config = self._config - store_type = (getattr(config, "kv_store_type", "") or "").lower() - - if store_type == "fluxon": - config_path = getattr(config, "fluxon_config_path", None) - try: - return FluxonKVStore(config_path=config_path) - except Exception as exc: - logger.exception( - "FluxonKVStore init failed config_path={}: {}", - config_path, - exc, - ) - raise RuntimeError( - f"FluxonKVStore init failed config_path={config_path} err_type={type(exc).__name__} err={exc}" - ) from exc - - if store_type and store_type not in {"local", "local_file"}: - logger.debug( - "Unknown kv_store_type '{}'; falling back to LocalFileKVStore", - store_type, - ) - storage_dir = self._storage_dir or Path("./storage") - return LocalFileKVStore(storage_dir) diff --git a/telefuser/cache_mem/encoders.py b/telefuser/cache_mem/encoders.py deleted file mode 100644 index 5c337c8..0000000 --- a/telefuser/cache_mem/encoders.py +++ /dev/null @@ -1,398 +0,0 @@ -from __future__ import annotations - -import importlib -import inspect -import os -from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -from loguru import logger - -if TYPE_CHECKING: - from PIL import Image - -from .encoding.interfaces import PromptEncoder, VideoEncoder - -_QWEN3_VL_EMBEDDER_MODULES = [ - "telefuser.cache_mem.src.models.qwen3_vl_embedding", - "scripts.qwen3_vl_embedding", - "qwen3_vl_embedding", -] - -_QWEN3_VL_RERANKER_MODULES = [ - "telefuser.cache_mem.src.models.qwen3_vl_reranker", - "scripts.qwen3_vl_reranker", - "qwen3_vl_reranker", -] - - -def _try_import_symbol( - module_candidates: List[str], - symbol_name: str, - label: str, -) -> Optional[Any]: - for module_path in module_candidates: - try: - module = importlib.import_module(module_path) - except ModuleNotFoundError: - continue - except Exception as exc: - logger.exception(f"{label} import failed for {module_path}: {exc}") - continue - symbol = getattr(module, symbol_name, None) - if symbol is None: - logger.warning(f"{label} missing symbol {symbol_name} in {module_path}") - continue - return symbol - return None - - -def _process_inputs(processor: object, inputs: object) -> object: - processor_type = type(processor).__name__ - has_process = hasattr(processor, "process") - is_callable = callable(processor) - try: - if has_process: - return processor.process(inputs) - if is_callable: - return processor(inputs) - except Exception as exc: - logger.exception( - "processor invocation failed processor_type={} has_process={} callable={} err={}", - processor_type, - has_process, - is_callable, - exc, - ) - raise RuntimeError( - "processor invocation failed " - f"processor_type={processor_type} " - f"has_process={has_process} callable={is_callable} " - f"err_type={type(exc).__name__} err={exc}" - ) from exc - raise TypeError(f"processor is neither callable nor provides process() processor_type={processor_type}") - - -def _extract_first_vector(vectors: Any) -> List[float]: - if vectors is None: - return [] - if isinstance(vectors, list): - if not vectors: - return [] - first = vectors[0] - if isinstance(first, (int, float)): - return [float(value) for value in vectors] - try: - import torch - - if isinstance(first, torch.Tensor): - return first.detach().cpu().tolist() - except ModuleNotFoundError: - pass - return list(first) - try: - import torch - - if isinstance(vectors, torch.Tensor): - if vectors.numel() == 0: - return [] - if vectors.dim() == 1: - return vectors.detach().cpu().tolist() - return vectors[0].detach().cpu().tolist() - except ModuleNotFoundError: - pass - try: - import numpy as np - - if isinstance(vectors, np.ndarray): - if vectors.size == 0: - return [] - if vectors.ndim == 1: - return vectors.tolist() - return vectors[0].tolist() - except ModuleNotFoundError: - pass - return [] - - -class Qwen3VLEncoder(VideoEncoder): - def __init__( - self, - model_path: str = "Qwen/Qwen3-VL-Embedding-2B", - instruction: str = "Represent the user's input", - max_frames: int = 16, - fps: float = 1.0, - torch_dtype: Optional[str] = None, - attn_implementation: Optional[str] = None, - device_id: Optional[int] = None, - embedder: Optional[object] = None, - ) -> None: - self.model_path = model_path - self.instruction = instruction - self.max_frames = int(max_frames) - self.fps = float(fps) - self.torch_dtype = torch_dtype - self.attn_implementation = attn_implementation - self.device_id = device_id - self._embedder = embedder - self._embedder_init_error: Optional[BaseException] = None - self._embedder_init_attempted = embedder is not None - if self._embedder is None: - self._get_embedder() - - def encode(self, prompt: str) -> List[float]: - return self._encode_inputs([{"text": prompt or "", "instruction": self.instruction}]) - - def encode_video( - self, - frames: List["Image.Image"], - prompt: Optional[str] = None, - ) -> List[float]: - if not frames: - return [] - item: Dict[str, object] = { - "video": frames, - "instruction": self.instruction, - } - if prompt is not None: - item["text"] = prompt or "" - return self._encode_inputs([item]) - - def decompose_prompt(self, prompt: str) -> Dict[str, str]: - return {"whole": prompt or ""} - - def _encode_inputs(self, inputs: List[Dict[str, object]]) -> List[float]: - embedder = self._get_embedder() - if embedder is None: - return [] - return _extract_first_vector(_process_inputs(embedder, inputs)) - - def _get_embedder(self) -> Optional[object]: - if self._embedder is not None: - return self._embedder - if self._embedder_init_error is not None: - raise self._embedder_init_error - if self._embedder_init_attempted: - return None - self._embedder_init_attempted = True - - embedder_cls = _try_import_symbol( - _QWEN3_VL_EMBEDDER_MODULES, - "Qwen3VLEmbedder", - "Qwen3VLEncoder", - ) - if embedder_cls is None: - self._embedder_init_error = ImportError( - f"Qwen3VLEncoder embedder not found in candidates={_QWEN3_VL_EMBEDDER_MODULES}" - ) - raise self._embedder_init_error - - init_kwargs: Dict[str, object] = { - "model_name_or_path": self.model_path, - "fps": self.fps, - "max_frames": self.max_frames, - } - try: - params = inspect.signature(embedder_cls.__init__).parameters - if "torch_dtype" in params and self.torch_dtype is not None: - init_kwargs["torch_dtype"] = self.torch_dtype - if "attn_implementation" in params and self.attn_implementation is not None: - init_kwargs["attn_implementation"] = self.attn_implementation - if self.device_id is not None: - did = int(self.device_id) - if "device_id" in params: - init_kwargs["device_id"] = did - elif "device" in params: - init_kwargs["device"] = "cpu" if did < 0 else f"cuda:{did}" - except (TypeError, ValueError) as exc: - logger.exception( - "Qwen3VLEncoder could not inspect __init__ signature embedder_cls={} err_type={} err={}", - getattr(embedder_cls, "__name__", repr(embedder_cls)), - type(exc).__name__, - exc, - ) - - try: - self._embedder = embedder_cls(**init_kwargs) - except Exception as exc: - logger.exception( - "Qwen3VLEncoder init failed model_path={} embedder_cls={} err={}", - self.model_path, - getattr(embedder_cls, "__name__", repr(embedder_cls)), - exc, - ) - self._embedder_init_error = RuntimeError( - "Qwen3VLEncoder init failed " - f"model_path={self.model_path} " - f"embedder_cls={getattr(embedder_cls, '__name__', repr(embedder_cls))} " - f"type={type(exc).__name__} err={exc}" - ) - raise self._embedder_init_error from exc - return self._embedder - - -class Qwen3VLReranker: - def __init__( - self, - model_path: str = "Qwen/Qwen3-VL-Reranker-8B", - instruction: str = "Retrieval relevant image or text with user's query", - fps: float = 1.0, - device_id: Optional[int] = None, - batch_size: int = 2, - torch_dtype: Optional[str] = None, - attn_implementation: Optional[str] = None, - reranker: Optional[object] = None, - ) -> None: - self.model_path = model_path - self.instruction = instruction - self.fps = float(fps) - self.device_id = device_id - self.batch_size = max(1, int(batch_size or 1)) - self.torch_dtype = torch_dtype - self.attn_implementation = attn_implementation - self._reranker = reranker - self._reranker_init_error: Optional[BaseException] = None - self._reranker_init_attempted = reranker is not None - if self._reranker is None: - self._init_reranker() - - def score_mm(self, query: Dict[str, object], documents: List[Dict[str, object]]) -> List[float]: - if not isinstance(query, dict) or not documents: - return [] - self._init_reranker() - if self._reranker is None: - return [] - scores = self._score_with_reranker_mm(self._reranker, query, documents) - return self._normalize_scores(scores, len(documents)) - - def _init_reranker(self) -> None: - if self._reranker is not None: - return - if self._reranker_init_error is not None: - raise self._reranker_init_error - if self._reranker_init_attempted: - return - self._reranker_init_attempted = True - - os.environ.setdefault("HF_HUB_OFFLINE", "1") - os.environ.setdefault("TRANSFORMERS_OFFLINE", "1") - if not Path(self.model_path).exists(): - logger.warning("Qwen3VLReranker model_path is not local; offline mode requires cached files") - - reranker_cls = _try_import_symbol( - _QWEN3_VL_RERANKER_MODULES, - "Qwen3VLReranker", - "Qwen3VLReranker", - ) - if reranker_cls is None: - self._reranker_init_error = ImportError( - f"Qwen3VLReranker implementation not found in candidates={_QWEN3_VL_RERANKER_MODULES}" - ) - raise self._reranker_init_error - - init_kwargs: Dict[str, object] = {"model_name_or_path": self.model_path} - try: - params = inspect.signature(reranker_cls.__init__).parameters - if "batch_size" in params: - init_kwargs["batch_size"] = self.batch_size - if "torch_dtype" in params and self.torch_dtype is not None: - init_kwargs["torch_dtype"] = self.torch_dtype - if "attn_implementation" in params and self.attn_implementation is not None: - init_kwargs["attn_implementation"] = self.attn_implementation - if self.device_id is not None: - did = int(self.device_id) - if "device_id" in params: - init_kwargs["device_id"] = did - elif "device" in params: - init_kwargs["device"] = "cpu" if did < 0 else f"cuda:{did}" - except (TypeError, ValueError) as exc: - logger.exception( - "Qwen3VLReranker could not inspect __init__ signature reranker_cls={} err_type={} err={}", - getattr(reranker_cls, "__name__", repr(reranker_cls)), - type(exc).__name__, - exc, - ) - - try: - self._reranker = reranker_cls(**init_kwargs) - except Exception as exc: - logger.exception( - "Qwen3VLReranker init failed model_path={} reranker_cls={} err={}", - self.model_path, - getattr(reranker_cls, "__name__", repr(reranker_cls)), - exc, - ) - self._reranker_init_error = RuntimeError( - "Qwen3VLReranker init failed " - f"model_path={self.model_path} " - f"reranker_cls={getattr(reranker_cls, '__name__', repr(reranker_cls))} " - f"type={type(exc).__name__} err={exc}" - ) - raise self._reranker_init_error from exc - - def _score_with_reranker_mm( - self, - reranker: object, - query: Dict[str, object], - documents: List[Dict[str, object]], - ) -> Optional[object]: - if not hasattr(reranker, "process"): - raise AttributeError(f"Qwen3VLReranker backend is missing process() type={type(reranker).__name__}") - inputs = { - "instruction": self.instruction, - "query": query, - "documents": documents, - "fps": self.fps, - } - try: - return reranker.process(inputs) - except TypeError as exc: - logger.exception( - "Qwen3VLReranker.process rejected multimodal payload reranker_type={} err={}", - type(reranker).__name__, - exc, - ) - raise RuntimeError( - "Qwen3VLReranker.process rejected multimodal payload " - f"reranker_type={type(reranker).__name__} " - f"err_type={type(exc).__name__} err={exc}" - ) from exc - - def _normalize_scores(self, scores: Optional[object], expected_len: int) -> List[float]: - if scores is None: - return [] - try: - import torch - - if isinstance(scores, torch.Tensor): - scores = scores.detach().cpu().tolist() - except ModuleNotFoundError: - pass - try: - import numpy as np - - if isinstance(scores, np.ndarray): - scores = scores.tolist() - except ModuleNotFoundError: - pass - if not isinstance(scores, list) or not scores: - return [] - if isinstance(scores[0], dict): - values = [] - for item in scores: - if not isinstance(item, dict): - continue - if "score" in item: - values.append(float(item["score"])) - elif "relevance" in item: - values.append(float(item["relevance"])) - elif "logit" in item: - values.append(float(item["logit"])) - return values - if isinstance(scores[0], (list, tuple)) and len(scores[0]) == 2 and isinstance(scores[0][0], int): - ordered = [float("-inf")] * expected_len - for idx, value in scores: - if 0 <= int(idx) < expected_len: - ordered[int(idx)] = float(value) - return ordered - return [float(item) for item in scores] diff --git a/telefuser/cache_mem/encoding/__init__.py b/telefuser/cache_mem/encoding/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/telefuser/cache_mem/encoding/interfaces.py b/telefuser/cache_mem/encoding/interfaces.py deleted file mode 100644 index 70e2546..0000000 --- a/telefuser/cache_mem/encoding/interfaces.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, List - -if TYPE_CHECKING: - from PIL import Image - - -class PromptEncoder(ABC): - """Prompt 编码器接口。""" - - @abstractmethod - def encode(self, prompt: str) -> List[float]: - pass - - @abstractmethod - def decompose_prompt(self, prompt: str) -> Dict[str, str]: - pass - - -class VideoEncoder(ABC): - """Video 编码器接口。""" - - @abstractmethod - def encode_video(self, frames: List["Image.Image"]) -> List[float]: - pass diff --git a/telefuser/cache_mem/latent_cache.py b/telefuser/cache_mem/latent_cache.py deleted file mode 100644 index 3d65083..0000000 --- a/telefuser/cache_mem/latent_cache.py +++ /dev/null @@ -1,213 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import Any, Dict, List, Optional - -import torch -from loguru import logger - -from .cache_types import CacheResult -from .connection import ConnectionManager -from .metadata import LocalCacheMetadataManager -from .state.interfaces import CacheMetadataManager -from .storage.interfaces import KVStore -from .strategies import BaseCacheStrategy, get_strategy_class -from .vector_store.interfaces import VectorStore - -try: - from telefuser.cache_mem.config import CacheConfig -except (ImportError, ModuleNotFoundError): # optional dependency for cache service - CacheConfig = None - - -class LatentCache: - def __init__( - self, - cache_dir: Optional[Path] = None, - config: Optional["CacheConfig"] = None, - kv_store: Optional[KVStore] = None, - vector_store: Optional[VectorStore] = None, - metadata_manager: Optional[CacheMetadataManager] = None, - strategy: Optional[BaseCacheStrategy] = None, - ): - # Initialize config and directories. - if config is None: - if CacheConfig is None: - raise ValueError("LatentCache requires CacheConfig but it is unavailable") - config = CacheConfig() - self.config = config - if hasattr(self.config, "latent_cache_dir"): - self.cache_dir = Path(cache_dir or self.config.latent_cache_dir) - else: - self.cache_dir = Path(cache_dir or self.config.cache_dir) - self.cache_dir.mkdir(parents=True, exist_ok=True) - self.storage_dir = self.cache_dir / "storage" - self.dit_cache_dir = self.cache_dir / "dit_cache" - self.metadata_dir = self.cache_dir / "metadata" - self.storage_dir.mkdir(parents=True, exist_ok=True) - self.dit_cache_dir.mkdir(parents=True, exist_ok=True) - self.metadata_dir.mkdir(parents=True, exist_ok=True) - - self._conn_mgr = ConnectionManager( - config, - storage_dir=self.storage_dir, - ) - self._managed_kv = kv_store is None - self._managed_vector = vector_store is None - - # Initialize kv_store - self.kv_store = kv_store or self._conn_mgr.kv_store - - # Initialize metadata_manager - self.metadata_manager = metadata_manager or LocalCacheMetadataManager(self.metadata_dir) - - # Initialize vector_store - self.vector_store = vector_store or self._conn_mgr.vector_store - - # Initialize strategy - if strategy is not None: - self.strategy = strategy - else: - strategy_name = getattr(self.config, "cache_strategy_type", "") - strategy_cls = get_strategy_class(strategy_name) if strategy_name else None - if strategy_cls is not None: - self.strategy = strategy_cls( - self.config, - self.kv_store, - self.vector_store, - self.metadata_manager, - ) - else: - self.strategy = None - - async def lookup(self, task_request: Any) -> CacheResult: - task_type = getattr(task_request, "task", "t2v") - prompt = getattr(task_request, "prompt", "") - if self.strategy is None: - return CacheResult(hit=False) - result = await self.strategy.lookup(prompt, task_type) - return await self.strategy.apply(result) - - async def save( - self, - task_request: Any, - latent_states_dict: Dict[int, torch.Tensor], - num_frames: int, - final_step: int, - saved_steps: List[int], - embedding_video_frames: Optional[List[Any]] = None, - ) -> None: - task_type = getattr(task_request, "task", "t2v") - prompt = getattr(task_request, "prompt", "") - if self.strategy is None: - return - await self.strategy.save( - prompt, - latent_states_dict, - num_frames, - task_type, - saved_steps, - embedding_video_frames=embedding_video_frames, - ) - - def shutdown(self) -> None: - """Release internal resources (best effort).""" - # Shut down strategy and metadata_manager directly. - for name in ("strategy", "metadata_manager"): - obj = getattr(self, name, None) - if obj is None: - continue - for method_name in ("shutdown", "close"): - if hasattr(obj, method_name): - try: - getattr(obj, method_name)() - except Exception as exc: - logger.exception(f"LatentCache.{name}.{method_name} failed: {exc}") - break - - if self._conn_mgr is not None and (self._managed_kv or self._managed_vector): - try: - self._conn_mgr.shutdown() - except Exception as exc: - logger.exception(f"LatentCache.ConnectionManager.shutdown failed: {exc}") - - for name, managed in [ - ("kv_store", self._managed_kv), - ("vector_store", self._managed_vector), - ]: - if managed: - continue # Already handled by ConnectionManager - obj = getattr(self, name, None) - if obj is None: - continue - for method_name in ("shutdown", "close"): - if hasattr(obj, method_name): - try: - getattr(obj, method_name)() - except Exception as exc: - logger.exception(f"LatentCache.{name}.{method_name} failed: {exc}") - break - - self.strategy = None - self.vector_store = None - self.kv_store = None - self.metadata_manager = None - - def purge_by_prompt(self, prompt: str, collection: str = "whole") -> bool: - """Delete cache by prompt (metadata / vector_store / kv_store).""" - prompt = prompt or "" - if not prompt: - return False - entry = self.metadata_manager.lookup_prompt( - prompt, - cache_type="video_approximate_cache", - ) - if entry is None: - return False - cache_id = entry.cache_id - errors: List[str] = [] - for step in entry.saved_steps: - try: - self.kv_store.remove(f"{cache_id}_step{int(step)}") - except Exception as exc: - logger.exception( - "LatentCache.purge_by_prompt kv remove failed prompt={} cache_id={} step={} err={}", - prompt, - cache_id, - int(step), - exc, - ) - errors.append( - f"kv remove failed cache_id={cache_id} step={int(step)} type={type(exc).__name__} err={exc}" - ) - if self.vector_store is not None: - try: - self.vector_store.delete(collection, [cache_id]) - except Exception as exc: - logger.exception( - "LatentCache.purge_by_prompt vector delete failed prompt={} collection={} cache_id={} err={}", - prompt, - collection, - cache_id, - exc, - ) - errors.append( - "vector delete failed " - f"collection={collection} cache_id={cache_id} " - f"type={type(exc).__name__} err={exc}" - ) - try: - self.metadata_manager.remove_cache(cache_id) - except Exception as exc: - logger.exception( - "LatentCache.purge_by_prompt metadata remove failed prompt={} cache_id={} err={}", - prompt, - cache_id, - exc, - ) - errors.append(f"metadata remove failed cache_id={cache_id} type={type(exc).__name__} err={exc}") - if errors: - raise RuntimeError( - f"LatentCache.purge_by_prompt failed prompt={prompt!r} cache_id={cache_id}: {'; '.join(errors)}" - ) - return True diff --git a/telefuser/cache_mem/log_monitor.py b/telefuser/cache_mem/log_monitor.py deleted file mode 100644 index 2cff527..0000000 --- a/telefuser/cache_mem/log_monitor.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import Optional - -from loguru import logger - -# 被拦截的模块名前缀列表 -_CACHE_MODULE_PREFIXES = ( - "telefuser.cache_mem", - "telefuser.service.cache.cache_service", - "telefuser.service.cache.cache_factory", -) - -_sink_id: Optional[int] = None - - -def _cache_module_filter(record: dict) -> bool: - name = record.get("name", "") - return any(name.startswith(prefix) for prefix in _CACHE_MODULE_PREFIXES) - - -def setup_cache_log_sink( - log_dir: str | Path, - *, - level: str = "DEBUG", - rotation: str = "100 MB", - retention: str = "7 days", - fmt: str = ("[CACHE] {time:YYYY-MM-DD HH:mm:ss.SSS} | {level:<8} | {name}:{function}:{line} | {message}"), -) -> Path: - global _sink_id - - # 如果已有旧 sink,先清除 - if _sink_id is not None: - try: - logger.remove(_sink_id) - except ValueError: - pass # sink 已被其他逻辑移除 - _sink_id = None - - log_dir = Path(log_dir) - log_dir.mkdir(parents=True, exist_ok=True) - log_path = log_dir / "cache_service.log" - - _sink_id = logger.add( - str(log_path), - filter=_cache_module_filter, - format=fmt, - level=level, - rotation=rotation, - retention=retention, - encoding="utf-8", - enqueue=True, # 异步写入,不阻塞业务线程 - ) - - logger.info( - "Cache log sink configured: path={} level={} rotation={} retention={}", - log_path, - level, - rotation, - retention, - ) - return log_path - - -def remove_cache_log_sink() -> None: - global _sink_id - if _sink_id is not None: - try: - logger.remove(_sink_id) - except ValueError: - pass - _sink_id = None - - -def is_cache_log_sink_active() -> bool: - return _sink_id is not None diff --git a/telefuser/cache_mem/metadata.py b/telefuser/cache_mem/metadata.py deleted file mode 100644 index af3bbc8..0000000 --- a/telefuser/cache_mem/metadata.py +++ /dev/null @@ -1,268 +0,0 @@ -from __future__ import annotations - -import json -import threading -import time -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -from loguru import logger - -from .cache_types import IndexEntry -from .state.interfaces import CacheMetadataManager - - -class LocalCacheMetadataManager(CacheMetadataManager): - def __init__(self, metadata_cache_dir: str | Path) -> None: - self._default_cache_type = "approximate_cache" - self.metadata_cache_dir = Path(metadata_cache_dir) - self.metadata_cache_dir.mkdir(parents=True, exist_ok=True) - self._index_path = self.metadata_cache_dir / "prompt_index.json" - self._meta_path = self.metadata_cache_dir / "cache_meta.json" - self._lock = threading.RLock() - self._index: Dict[str, Dict[str, IndexEntry]] = self._load_index() - self._meta: Dict[str, Dict[str, object]] = self._load_meta() - - # --- CRUD 核心操作 --- - - def register_cache( - self, - cache_id: str, - prompt: str, - saved_steps: List[int], - size_mb: float, - num_frames: int, - cache_type: Optional[str] = None, - ) -> None: - steps = sorted(set(int(s) for s in saved_steps)) - # Normalize so None never collides with the string "None" after JSON round-trip. - normalized_cache_type = self._normalize_cache_type(cache_type) - with self._lock: - index = self._index.setdefault(normalized_cache_type, {}) - index[cache_id] = IndexEntry( - cache_id=cache_id, - prompt=prompt, - saved_steps=steps, - cache_type=normalized_cache_type, - ) - self._meta[cache_id] = { - "prompt": prompt, - "saved_steps": steps, - "size_mb": float(size_mb), - "num_frames": int(num_frames), - "access_count": int(self._meta.get(cache_id, {}).get("access_count", 0)), - "last_access_time": float(time.time()), - "cache_type": normalized_cache_type, - } - self._save_index() - self._save_meta() - - def remove_cache(self, cache_id: str) -> None: - with self._lock: - meta = self._meta.pop(cache_id, None) - cache_type = meta.get("cache_type") if meta else None - if cache_type: - self._index.get(str(cache_type), {}).pop(cache_id, None) - else: - # cache_type 未知时全表扫(备用降级,正常不走) - logger.debug( - "LocalCacheMetadataManager.remove_cache fallback scan (cache_type missing) cache_id={}", - cache_id, - ) - for mapping in self._index.values(): - mapping.pop(cache_id, None) - self._save_index() - self._save_meta() - - def lookup_prompt(self, prompt: str, cache_type: Optional[str] = None) -> Optional[IndexEntry]: - # 主键改为 cache_id 后,这里从 O(1) 转为 values() 迭代找 prompt 匹配。 - # dict.values() 按插入顺序迭代(Python 3.7+),所以同 prompt 多次 save 时 - # 默认返回最早插入的 entry,`purge_by_prompt` 外层循环调用可清空全部历史。 - def _scan(mapping: Dict[str, IndexEntry]) -> Optional[IndexEntry]: - for entry in mapping.values(): - if entry.prompt == prompt: - return entry - return None - - with self._lock: - if cache_type: - return _scan(self._index.get(self._normalize_cache_type(cache_type), {})) - # Default to text cache first, then scan others. - entry = _scan(self._index.get(self._default_cache_type, {})) - if entry is not None: - return entry - for mapping in self._index.values(): - entry = _scan(mapping) - if entry is not None: - return entry - return None - - def get_cache_meta(self, cache_id: str) -> Optional[dict]: - with self._lock: - meta = self._meta.get(cache_id) - if meta is None: - return None - return dict(meta) - - # --- 访问统计 & 淘汰 --- - - def record_access(self, cache_id: str) -> None: - normalized = self._normalize_cache_id(cache_id) - with self._lock: - meta = self._meta.get(normalized) - if meta is None: - return - meta["access_count"] = int(meta.get("access_count", 0)) + 1 - meta["last_access_time"] = float(time.time()) - self._save_meta() - - def plan_eviction(self, required_mb: float, limit_mb: float) -> List[Tuple[str, Dict[str, object]]]: - with self._lock: - current_mb = sum(float(v.get("size_mb", 0.0)) for v in self._meta.values()) - if current_mb + required_mb <= limit_mb: - return [] - need = current_mb + required_mb - limit_mb - items = sorted( - self._meta.items(), - key=lambda kv: float(kv[1].get("last_access_time", 0.0)), - ) - selected: List[Tuple[str, Dict[str, object]]] = [] - freed = 0.0 - for cache_id, meta in items: - selected.append((cache_id, meta)) - freed += float(meta.get("size_mb", 0.0)) - if freed >= need: - break - return selected - - # --- 审计日志 --- - - def record_hit_pair( - self, - request_prompt: str, - cache_id: str, - cached_prompt: str, - similarity: float, - task_type: str, - cache_type: str, - skip_step: int, - ) -> None: - payload = { - "timestamp": float(time.time()), - "request_prompt": str(request_prompt or ""), - "cache_id": str(cache_id), - "cached_prompt": str(cached_prompt or ""), - "similarity": float(similarity), - "task_type": str(task_type or ""), - "cache_type": str(cache_type or ""), - "skip_step": int(skip_step), - } - log_path = self.metadata_cache_dir / "hit_pairs.jsonl" - with log_path.open("a", encoding="utf-8") as f: - f.write(json.dumps(payload, ensure_ascii=True) + "\n") - - def record_similarity_scores( - self, - request_prompt: str, - task_type: str, - cache_type: str, - stage: str, - candidates: List[dict], - ) -> None: - payload = { - "timestamp": float(time.time()), - "request_prompt": str(request_prompt or ""), - "task_type": str(task_type or ""), - "cache_type": str(cache_type or ""), - "stage": str(stage or ""), - "candidates": candidates, - } - log_path = self.metadata_cache_dir / "similarity_scores.jsonl" - with log_path.open("a", encoding="utf-8") as f: - f.write(json.dumps(payload, ensure_ascii=True) + "\n") - - # --- 持久化(私有) --- - - def _load_index(self) -> Dict[str, Dict[str, IndexEntry]]: - if not self._index_path.exists(): - return {} - raw = self._read_json_object(self._index_path, "prompt index") - - # Schema: {cache_type: {cache_id: entry_dict}} - result: Dict[str, Dict[str, IndexEntry]] = {} - for cache_type, entries in raw.items(): - if not isinstance(entries, dict) or not entries: - continue - ct_str = str(cache_type) - mapping: Dict[str, IndexEntry] = {} - for cache_id, entry in entries.items(): - if not isinstance(entry, dict): - continue - mapping[str(cache_id)] = IndexEntry( - cache_id=str(cache_id), - prompt=str(entry.get("prompt", "")), - saved_steps=[int(x) for x in entry.get("saved_steps", [])], - cache_type=str(entry.get("cache_type") or ct_str or self._default_cache_type), - ) - if mapping: - result[ct_str] = mapping - return result - - def _load_meta(self) -> Dict[str, Dict[str, object]]: - if not self._meta_path.exists(): - return {} - raw = self._read_json_object(self._meta_path, "cache metadata") - return raw - - def _save_index(self) -> None: - # Schema: {cache_type: {cache_id: entry_dict}} - data: Dict[str, Dict[str, Dict[str, object]]] = {} - for cache_type, mapping in self._index.items(): - data[str(cache_type)] = { - cache_id: { - "prompt": entry.prompt, - "saved_steps": entry.saved_steps, - "cache_type": entry.cache_type or cache_type, - } - for cache_id, entry in mapping.items() - } - self._index_path.write_text(json.dumps(data, ensure_ascii=True)) - - def _save_meta(self) -> None: - self._meta_path.write_text(json.dumps(self._meta, ensure_ascii=True)) - - # --- 工具函数(私有) --- - - def _normalize_cache_id(self, cache_id: str) -> str: - return (cache_id or "").replace("-", "") - - def _normalize_cache_type(self, cache_type: Optional[str]) -> str: - cache_type = str(cache_type or "").strip() - return cache_type or self._default_cache_type - - def _read_json_object(self, path: Path, label: str) -> Dict[str, object]: - try: - raw = path.read_text() - except OSError as exc: - logger.exception( - "LocalCacheMetadataManager failed to read {} path={} err={}", - label, - path, - exc, - ) - raise RuntimeError(f"LocalCacheMetadataManager failed to read {label} path={path}: {exc}") from exc - try: - data = json.loads(raw) - except json.JSONDecodeError as exc: - logger.exception( - "LocalCacheMetadataManager {} is not valid JSON path={} err={}", - label, - path, - exc, - ) - raise ValueError(f"LocalCacheMetadataManager {label} is not valid JSON path={path}: {exc}") from exc - if not isinstance(data, dict): - raise ValueError( - f"LocalCacheMetadataManager {label} must be a JSON object path={path} got_type={type(data).__name__}" - ) - return data diff --git a/telefuser/cache_mem/src/__init__.py b/telefuser/cache_mem/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/telefuser/cache_mem/src/models/__init__.py b/telefuser/cache_mem/src/models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/telefuser/cache_mem/src/models/qwen3_vl_embedding.py b/telefuser/cache_mem/src/models/qwen3_vl_embedding.py deleted file mode 100644 index 2157507..0000000 --- a/telefuser/cache_mem/src/models/qwen3_vl_embedding.py +++ /dev/null @@ -1,346 +0,0 @@ -import logging -import unicodedata -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union - -import numpy as np -import torch -import torch.nn.functional as F -from PIL import Image -from qwen_vl_utils.vision_process import process_vision_info -from transformers.cache_utils import Cache -from transformers.modeling_outputs import ModelOutput -from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLConfig, Qwen3VLModel, Qwen3VLPreTrainedModel -from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor -from transformers.processing_utils import Unpack -from transformers.utils import TransformersKwargs -from transformers.utils.generic import check_model_inputs - -logger = logging.getLogger(__name__) - -# Constants for configuration -MAX_LENGTH = 8192 -IMAGE_BASE_FACTOR = 16 -IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2 -MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR -MAX_PIXELS = 1800 * IMAGE_FACTOR * IMAGE_FACTOR -FPS = 1 -MAX_FRAMES = 64 -FRAME_MAX_PIXELS = 768 * IMAGE_FACTOR * IMAGE_FACTOR -MAX_TOTAL_PIXELS = 10 * FRAME_MAX_PIXELS -PAD_TOKEN = "<|endoftext|>" - - -# Define output structure for embeddings -@dataclass -class Qwen3VLForEmbeddingOutput(ModelOutput): - last_hidden_state: Optional[torch.FloatTensor] = None - attention_mask: Optional[torch.Tensor] = None - - -# Define model class to compute embeddings -class Qwen3VLForEmbedding(Qwen3VLPreTrainedModel): - _checkpoint_conversion_mapping = {} - accepts_loss_kwargs = False - config: Qwen3VLConfig - - def __init__(self, config): - super().__init__(config) - self.model = Qwen3VLModel(config) - self.post_init() - - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - - # Extract video features from model - def get_video_features( - self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None - ): - return self.model.get_video_features(pixel_values_videos, video_grid_thw) - - # Extract image features from model - def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): - return self.model.get_image_features(pixel_values, image_grid_thw) - - # Make modules accessible through properties - @property - def language_model(self): - return self.model.language_model - - @property - def visual(self): - return self.model.visual - - # Forward pass through model with input parameters - # @check_model_inputs - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> Union[tuple, Qwen3VLForEmbeddingOutput]: - # Pass inputs through the model - outputs = self.model( - input_ids=input_ids, - pixel_values=pixel_values, - pixel_values_videos=pixel_values_videos, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - **kwargs, - ) - # Return the model output - return Qwen3VLForEmbeddingOutput( - last_hidden_state=outputs.last_hidden_state, - attention_mask=attention_mask, - ) - - -def sample_frames( - frames: List[Union[str, Image.Image]], - num_segments: int, - max_segments: int, -) -> List[Union[str, Image.Image]]: - """Uniformly sample the final number of frames across the full video.""" - if not frames: - raise ValueError("sample_frames requires at least one frame") - duration = len(frames) - target_segments = max(1, min(int(num_segments), int(max_segments))) - frame_id_array = np.linspace(0, duration - 1, target_segments, dtype=int) - frame_id_list = frame_id_array.tolist() - - # Create a list of sampled frames - sampled_frames = [] - for frame_idx in frame_id_list: - if frame_idx < 0 or frame_idx >= duration: - raise IndexError(f"sample_frames generated out-of-range index={frame_idx} duration={duration}") - sampled_frames.append(frames[frame_idx]) - return sampled_frames - - -# Define embedder class for processing inputs and generating embeddings -class Qwen3VLEmbedder: - def __init__( - self, - model_name_or_path: str, - max_length: int = MAX_LENGTH, - min_pixels: int = MIN_PIXELS, - max_pixels: int = MAX_PIXELS, - total_pixels: int = MAX_TOTAL_PIXELS, - fps: float = FPS, - num_frames: int = MAX_FRAMES, - max_frames: int = MAX_FRAMES, - device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - default_instruction: str = "Represent the user's input.", - **kwargs, - ): - self.max_length = max_length - self.min_pixels = min_pixels - self.max_pixels = max_pixels - self.total_pixels = total_pixels - self.fps = fps - self.num_frames = num_frames - self.max_frames = max_frames - - self.default_instruction = default_instruction - - self.model = Qwen3VLForEmbedding.from_pretrained(model_name_or_path, trust_remote_code=True, **kwargs).to( - device - ) - self.processor = Qwen3VLProcessor.from_pretrained(model_name_or_path, padding_side="right") - self.model.eval() - - @torch.no_grad() - def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: - outputs = self.model(**inputs) - return {"last_hidden_state": outputs.last_hidden_state, "attention_mask": inputs.get("attention_mask")} - - # Truncate token sequence to a specified max length - def _truncate_tokens(self, token_ids: List[int], max_length: int) -> List[int]: - if len(token_ids) <= max_length: - return token_ids - - special_token_ids = set(self.processor.tokenizer.all_special_ids) - num_special = sum(1 for token_idx in token_ids if token_idx in special_token_ids) - num_non_special_to_keep = max_length - num_special - - final_token_ids = [] - non_special_kept_count = 0 - # Ensure retention of special tokens while truncating the rest - for token_idx in token_ids: - if token_idx in special_token_ids: - final_token_ids.append(token_idx) - elif non_special_kept_count < num_non_special_to_keep: - final_token_ids.append(token_idx) - non_special_kept_count += 1 - return final_token_ids - - # Format input based on provided text, image, video, and instruction - def format_model_input( - self, - text: Optional[str] = None, - image: Optional[Union[str, Image.Image]] = None, - video: Optional[Union[str, List[Union[str, Image.Image]]]] = None, - instruction: Optional[str] = None, - fps: Optional[float] = None, - max_frames: Optional[int] = None, - ) -> List[Dict]: - # Ensure instruction ends with punctuation - if instruction: - instruction = instruction.strip() - if instruction and not unicodedata.category(instruction[-1]).startswith("P"): - instruction = instruction + "." - - # Initialize conversation with system prompts - content = [] - conversation = [ - {"role": "system", "content": [{"type": "text", "text": instruction or self.default_instruction}]}, - {"role": "user", "content": content}, - ] - - # Add text, image, or video content to conversation - if not text and not image and not video: - content.append({"type": "text", "text": "NULL"}) - return conversation - - if video: - video_content = None - video_kwargs = {"total_pixels": self.total_pixels} - if isinstance(video, list): - video_content = video - if self.num_frames is not None or self.max_frames is not None: - video_content = sample_frames(video_content, self.num_frames, self.max_frames) - video_content = [("file://" + ele if isinstance(ele, str) else ele) for ele in video_content] - elif isinstance(video, str): - video_content = video if video.startswith(("http://", "https://")) else "file://" + video - video_kwargs = { - "fps": fps or self.fps, - "max_frames": max_frames or self.max_frames, - } - else: - raise TypeError(f"Unrecognized video type: {type(video)}") - - # Add video input details to content - if video_content: - content.append({"type": "video", "video": video_content, **video_kwargs}) - - if image: - image_content = None - if isinstance(image, Image.Image): - image_content = image - elif isinstance(image, str): - image_content = image if image.startswith(("http", "oss")) else "file://" + image - else: - raise TypeError(f"Unrecognized image type: {type(image)}") - - # Add image input details to content - if image_content: - content.append( - { - "type": "image", - "image": image_content, - "min_pixels": self.min_pixels, - "max_pixels": self.max_pixels, - } - ) - - if text: - content.append({"type": "text", "text": text}) - - return conversation - - # Preprocess input conversations for model consumption - def _preprocess_inputs(self, conversations: List[List[Dict]]) -> Dict[str, torch.Tensor]: - text = self.processor.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False) - - try: - images, video_inputs, video_kwargs = process_vision_info( - conversations, image_patch_size=16, return_video_metadata=True, return_video_kwargs=True - ) - except Exception as exc: - logger.exception( - "Qwen3VLEmbedder failed to process multimodal inputs conversation_count=%s err=%s", - len(conversations), - exc, - ) - raise RuntimeError( - f"Qwen3VLEmbedder failed to process multimodal inputs conversation_count={len(conversations)} err={exc}" - ) from exc - - if video_inputs is not None: - videos, video_metadata = zip(*video_inputs) - videos = list(videos) - video_metadata = list(video_metadata) - else: - videos, video_metadata = None, None - - inputs = self.processor( - text=text, - images=images, - videos=videos, - video_metadata=video_metadata, - truncation=True, - max_length=self.max_length, - padding=True, - do_resize=False, - return_tensors="pt", - **video_kwargs, - ) - return inputs - - # Pool the last hidden state by attention mask for embeddings - @staticmethod - def _pooling_last(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - flipped_tensor = attention_mask.flip(dims=[1]) - last_one_positions = flipped_tensor.argmax(dim=1) - col = attention_mask.shape[1] - last_one_positions - 1 - row = torch.arange(hidden_state.shape[0], device=hidden_state.device) - return hidden_state[row, col] - - # Process inputs to generate normalized embeddings - def process(self, inputs: List[Dict[str, Any]], normalize: bool = True) -> tuple: - conversations = [ - self.format_model_input( - text=ele.get("text"), - image=ele.get("image"), - video=ele.get("video"), - instruction=ele.get("instruction"), - fps=ele.get("fps"), - max_frames=ele.get("max_frames"), - ) - for ele in inputs - ] - - processed_inputs = self._preprocess_inputs(conversations) - processed_inputs = {k: v.to(self.model.device) for k, v in processed_inputs.items()} - - outputs = self.forward(processed_inputs) - embeddings = self._pooling_last(outputs["last_hidden_state"], outputs["attention_mask"]) - - # Normalize the embeddings if specified - if normalize: - embeddings = F.normalize(embeddings, p=2, dim=-1) - - return embeddings diff --git a/telefuser/cache_mem/src/models/qwen3_vl_reranker.py b/telefuser/cache_mem/src/models/qwen3_vl_reranker.py deleted file mode 100644 index d067475..0000000 --- a/telefuser/cache_mem/src/models/qwen3_vl_reranker.py +++ /dev/null @@ -1,437 +0,0 @@ -import logging -import os -import unicodedata -from typing import Dict, List, Optional, Union -from urllib.parse import urlparse - -import numpy as np -import torch -from PIL import Image -from qwen_vl_utils import process_vision_info -from scipy import special -from transformers import AutoProcessor, Qwen3VLForConditionalGeneration - -logger = logging.getLogger(__name__) - -# Default configuration constants -MAX_LENGTH = 10240 -IMAGE_BASE_FACTOR = 16 -IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2 -MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR # 4 tokens -MAX_PIXELS = 1800 * IMAGE_FACTOR * IMAGE_FACTOR # 1800 tokens -FPS = 1 -MAX_FRAMES = 64 -FRAME_MAX_PIXELS = 768 * IMAGE_FACTOR * IMAGE_FACTOR -MAX_TOTAL_PIXELS = 10 * FRAME_MAX_PIXELS # 7680 tokens - - -def is_image_path(path: str) -> bool: - image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".svg"} - - if path.startswith(("http://", "https://")): - # Parse URL to remove query parameters - parsed_url = urlparse(path) - clean_path = parsed_url.path - else: - clean_path = path - - # Check file extension - _, ext = os.path.splitext(clean_path.lower()) - return ext in image_extensions - - -def is_video_input(video) -> bool: - if isinstance(video, str): - return True - - if isinstance(video, list) and len(video) > 0: - # Check first element to determine the type - first_elem = video[0] - - if isinstance(first_elem, Image.Image): - return True - - if isinstance(first_elem, str): - return is_image_path(first_elem) - - return False - - -def sample_frames(frames: List[Union[str, Image.Image]], max_segments: int) -> List[Union[str, Image.Image]]: - duration = len(frames) - if duration <= max_segments: - return frames - - frame_id_array = np.linspace(0, duration - 1, max_segments, dtype=int) - frame_id_list = frame_id_array.tolist() - sampled_frames = [frames[frame_idx] for frame_idx in frame_id_list] - return sampled_frames - - -class Qwen3VLReranker: - def __init__( - self, - model_name_or_path: str, - max_length: int = MAX_LENGTH, - min_pixels: int = MIN_PIXELS, - max_pixels: int = MAX_PIXELS, - total_pixels: int = MAX_TOTAL_PIXELS, - fps: float = FPS, - max_frames: int = MAX_FRAMES, - batch_size: int = 2, - device: Optional[Union[str, "torch.device"]] = None, - default_instruction: str = "Given a search query, retrieve relevant candidates that answer the query.", - **kwargs, - ): - if device is not None: - self.device = torch.device(device) - else: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.max_length = max_length - self.min_pixels = min_pixels - self.max_pixels = max_pixels - self.total_pixels = total_pixels - self.fps = fps - self.max_frames = max_frames - self.batch_size = max(1, int(batch_size or 1)) - self.default_instruction = default_instruction - - # Load the language model - lm = Qwen3VLForConditionalGeneration.from_pretrained(model_name_or_path, trust_remote_code=True, **kwargs).to( - self.device - ) - - self.model = lm.model - self.processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True, padding_side="left") - self.model.eval() - - # Initialize binary classification head for yes/no scoring - token_true_id = self.processor.tokenizer.get_vocab()["yes"] - token_false_id = self.processor.tokenizer.get_vocab()["no"] - self.score_linear = self.get_binary_linear(lm, token_true_id, token_false_id) - self.score_linear.eval() - self.score_linear.to(self.device).to(self.model.dtype) - - def get_binary_linear(self, model, token_yes: int, token_no: int) -> torch.nn.Linear: - lm_head_weights = model.lm_head.weight.data - - weight_yes = lm_head_weights[token_yes] - weight_no = lm_head_weights[token_no] - - D = weight_yes.size()[0] - linear_layer = torch.nn.Linear(D, 1, bias=False) - with torch.no_grad(): - linear_layer.weight[0] = weight_yes - weight_no - return linear_layer - - @torch.no_grad() - def compute_scores(self, inputs: Dict) -> List[float]: - batch_scores = self.model(**inputs).last_hidden_state[:, -1] - scores = self.score_linear(batch_scores) - scores = torch.sigmoid(scores).squeeze(-1).cpu().detach().tolist() - return scores - - def _move_inputs_to_device(self, tokenized_inputs: Dict) -> Dict: - if hasattr(tokenized_inputs, "to"): - return tokenized_inputs.to(self.model.device) - return { - key: value.to(self.model.device) if hasattr(value, "to") else value - for key, value in tokenized_inputs.items() - } - - def _normalize_score_list(self, scores: Union[List[float], tuple, float]) -> List[float]: - if isinstance(scores, list): - return [float(value) for value in scores] - if isinstance(scores, tuple): - return [float(value) for value in scores] - return [float(scores)] - - def _score_pairs_batched(self, batch_pairs: List[Dict]) -> List[float]: - tokenized_inputs = self.tokenize(batch_pairs) - tokenized_inputs = self._move_inputs_to_device(tokenized_inputs) - scores = self._normalize_score_list(self.compute_scores(tokenized_inputs)) - if len(scores) != len(batch_pairs): - raise ValueError(f"score size mismatch: expected {len(batch_pairs)} got {len(scores)}") - return scores - - def _score_pairs_serial(self, batch_pairs: List[Dict]) -> List[float]: - final_scores: List[float] = [] - for pair in batch_pairs: - final_scores.extend(self._score_pairs_batched([pair])) - return final_scores - - def truncate_tokens_optimized(self, tokens: List[str], max_length: int, special_tokens: List[str]) -> List[str]: - if len(tokens) <= max_length: - return tokens - - special_tokens_set = set(special_tokens) - - # Calculate budget: how many non-special tokens we can keep - num_special = sum(1 for token in tokens if token in special_tokens_set) - num_non_special_to_keep = max_length - num_special - - # Build final list according to budget - final_tokens = [] - non_special_kept_count = 0 - for token in tokens: - if token in special_tokens_set: - final_tokens.append(token) - elif non_special_kept_count < num_non_special_to_keep: - final_tokens.append(token) - non_special_kept_count += 1 - - return final_tokens - - def tokenize(self, pairs: List[Dict], **kwargs) -> Dict: - max_length = self.max_length - text = self.processor.apply_chat_template(pairs, tokenize=False, add_generation_prompt=True) - - try: - images, videos, video_kwargs = process_vision_info( - pairs, image_patch_size=16, return_video_kwargs=True, return_video_metadata=True - ) - except Exception as exc: - logger.exception( - "Qwen3VLReranker failed to process multimodal pairs pair_count=%s err=%s", - len(pairs), - exc, - ) - raise RuntimeError( - f"Qwen3VLReranker failed to process multimodal pairs pair_count={len(pairs)} err={exc}" - ) from exc - - if videos is not None: - videos, video_metadatas = zip(*videos) - videos, video_metadatas = list(videos), list(video_metadatas) - else: - video_metadatas = None - - inputs = self.processor( - text=text, - images=images, - videos=videos, - video_metadata=video_metadatas, - truncation=False, - padding=False, - do_resize=False, - **video_kwargs, - ) - - # Truncate input IDs while preserving special tokens - for i, ele in enumerate(inputs["input_ids"]): - inputs["input_ids"][i] = ( - self.truncate_tokens_optimized( - inputs["input_ids"][i][:-5], max_length, self.processor.tokenizer.all_special_ids - ) - + inputs["input_ids"][i][-5:] - ) - - # Apply padding - temp_inputs = self.processor.tokenizer.pad( - {"input_ids": inputs["input_ids"]}, padding=True, return_tensors="pt", max_length=self.max_length - ) - for key in temp_inputs: - inputs[key] = temp_inputs[key] - - return inputs - - def format_mm_content( - self, - text: Optional[Union[List[str], str]] = None, - image: Optional[Union[List[Union[str, Image.Image]], str, Image.Image]] = None, - video: Optional[ - Union[List[Union[str, List[Union[str, Image.Image]]]], str, List[Union[str, Image.Image]]] - ] = None, - prefix: str = "Query:", - fps: Optional[float] = None, - max_frames: Optional[int] = None, - ) -> List[Dict]: - content = [] - content.append({"type": "text", "text": prefix}) - - # Normalize text input to list - if text is None: - texts = [] - elif isinstance(text, str): - texts = [text] - else: - texts = text - - # Normalize image input to list - if image is None: - images = [] - elif not isinstance(image, list): - images = [image] - else: - images = image - - # Normalize video input to list - if video is None: - videos = [] - elif is_video_input(video): - videos = [video] - else: - # Assume it's a list of videos - videos = video - - if not texts and not images and not videos: - content.append({"type": "text", "text": "NULL"}) - return content - - # Process each video - for vid in videos: - video_content = None - video_kwargs = {"total_pixels": self.total_pixels} - - if isinstance(vid, list): - # Video as frame sequence - video_content = vid - if self.max_frames is not None: - video_content = sample_frames(video_content, self.max_frames) - video_content = [("file://" + ele if isinstance(ele, str) else ele) for ele in video_content] - elif isinstance(vid, str): - # Video as file path - video_content = vid if vid.startswith(("http://", "https://")) else "file://" + vid - video_kwargs = {"fps": fps or self.fps, "max_frames": max_frames or self.max_frames} - else: - raise TypeError(f"Unrecognized video type: {type(vid)}") - - # Add video input to content - if video_content: - content.append({"type": "video", "video": video_content, **video_kwargs}) - - # Process each image - for img in images: - image_content = None - - if isinstance(img, Image.Image): - image_content = img - elif isinstance(img, str): - image_content = img if img.startswith(("http://", "https://")) else "file://" + img - else: - raise TypeError(f"Unrecognized image type: {type(img)}") - - # Add image input to content - if image_content: - content.append( - { - "type": "image", - "image": image_content, - "min_pixels": self.min_pixels, - "max_pixels": self.max_pixels, - } - ) - - # Process each text - for txt in texts: - content.append({"type": "text", "text": txt}) - - return content - - def format_mm_instruction( - self, - query_text: Optional[Union[str, tuple]] = None, - query_image: Optional[Union[List[Union[str, Image.Image]], str, Image.Image]] = None, - query_video: Optional[ - Union[List[Union[str, List[Union[str, Image.Image]]]], str, List[Union[str, Image.Image]]] - ] = None, - doc_text: Optional[Union[List[str], str]] = None, - doc_image: Optional[Union[List[Union[str, Image.Image]], str, Image.Image]] = None, - doc_video: Optional[ - Union[List[Union[str, List[Union[str, Image.Image]]]], str, List[Union[str, Image.Image]]] - ] = None, - instruction: Optional[str] = None, - fps: Optional[float] = None, - max_frames: Optional[int] = None, - ) -> List[Dict]: - inputs = [] - inputs.append( - { - "role": "system", - "content": [ - { - "type": "text", - "text": ( - "Judge whether the Document meets the requirements based on the Query" - " and the Instruct provided. Note that the answer can only be" - ' "yes" or "no".' - ), - } - ], - } - ) - - # Handle query_text as tuple containing (instruction, text) - if isinstance(query_text, tuple): - instruct, query_text = query_text - else: - instruct = instruction - - contents = [] - contents.append({"type": "text", "text": ": " + (instruct or self.default_instruction)}) - - # Format query content - query_content = self.format_mm_content( - query_text, query_image, query_video, prefix=":", fps=fps, max_frames=max_frames - ) - contents.extend(query_content) - - # Format document content - doc_content = self.format_mm_content( - doc_text, doc_image, doc_video, prefix="\n:", fps=fps, max_frames=max_frames - ) - contents.extend(doc_content) - - inputs.append({"role": "user", "content": contents}) - - return inputs - - def process( - self, - inputs: Dict, - ) -> List[float]: - instruction = inputs.get("instruction", self.default_instruction) - - query = inputs.get("query", {}) - documents = inputs.get("documents", []) - - if not query or not documents: - return [] - - # Format each query-document pair - pairs = [ - self.format_mm_instruction( - query.get("text", None), - query.get("image", None), - query.get("video", None), - document.get("text", None), - document.get("image", None), - document.get("video", None), - instruction=instruction, - fps=inputs.get("fps", self.fps), - max_frames=inputs.get("max_frames", self.max_frames), - ) - for document in documents - ] - - batch_size = max(1, int(getattr(self, "batch_size", 1) or 1)) - final_scores: List[float] = [] - for start in range(0, len(pairs), batch_size): - batch_pairs = pairs[start : start + batch_size] - try: - final_scores.extend(self._score_pairs_batched(batch_pairs)) - except Exception as exc: - logger.exception( - "Qwen3VLReranker batch scoring failed start=%s size=%s err=%s", - start, - len(batch_pairs), - exc, - ) - raise RuntimeError( - "Qwen3VLReranker batch scoring failed " - f"start={start} size={len(batch_pairs)} " - f"err_type={type(exc).__name__} err={exc}" - ) from exc - - return final_scores diff --git a/telefuser/cache_mem/state/__init__.py b/telefuser/cache_mem/state/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/telefuser/cache_mem/state/interfaces.py b/telefuser/cache_mem/state/interfaces.py deleted file mode 100644 index 59e62cd..0000000 --- a/telefuser/cache_mem/state/interfaces.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import List, Optional - -from ..cache_types import IndexEntry - - -class CacheMetadataManager(ABC): - """缓存元数据管理器接口。""" - - @abstractmethod - def register_cache( - self, - cache_id: str, - prompt: str, - saved_steps: List[int], - size_mb: float, - num_frames: int, - cache_type: Optional[str] = None, - ) -> None: - pass - - @abstractmethod - def lookup_prompt(self, prompt: str, cache_type: Optional[str] = None) -> Optional[IndexEntry]: - pass - - @abstractmethod - def record_access(self, cache_id: str) -> None: - pass - - @abstractmethod - def plan_eviction(self, required_mb: float, limit_mb: float) -> List[tuple]: - pass - - @abstractmethod - def remove_cache(self, cache_id: str) -> None: - pass - - @abstractmethod - def record_hit_pair( - self, - request_prompt: str, - cache_id: str, - cached_prompt: str, - similarity: float, - task_type: str, - cache_type: str, - skip_step: int, - ) -> None: - pass - - @abstractmethod - def record_similarity_scores( - self, - request_prompt: str, - task_type: str, - cache_type: str, - stage: str, - candidates: List[dict], - ) -> None: - pass - - @abstractmethod - def get_cache_meta(self, cache_id: str) -> Optional[dict]: - """获取指定 cache_id 的元数据,用于排查一致性问题。""" - pass diff --git a/telefuser/cache_mem/storage/__init__.py b/telefuser/cache_mem/storage/__init__.py deleted file mode 100644 index 11e80de..0000000 --- a/telefuser/cache_mem/storage/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .fluxon import FluxonKVStore -from .interfaces import KVStore -from .local_file import LocalFileKVStore -from .memory import InMemoryKVStore - -__all__ = [ - "KVStore", - "InMemoryKVStore", - "LocalFileKVStore", - "FluxonKVStore", -] diff --git a/telefuser/cache_mem/storage/fluxon.py b/telefuser/cache_mem/storage/fluxon.py deleted file mode 100644 index ab392ba..0000000 --- a/telefuser/cache_mem/storage/fluxon.py +++ /dev/null @@ -1,24 +0,0 @@ -from __future__ import annotations - -from typing import Any, Optional - -from .interfaces import KVStore - - -class FluxonKVStore(KVStore): - """Fluxon KV backend stub — not available in MVP.""" - - def __init__(self, config_path: Optional[str] = None, store: Optional[Any] = None): - raise NotImplementedError("FluxonKV backend not available in MVP. Planned for v2.") - - def get(self, key: str) -> Optional[bytes]: - raise NotImplementedError("FluxonKV backend not available in MVP. Planned for v2.") - - def put(self, key: str, value: bytes) -> None: - raise NotImplementedError("FluxonKV backend not available in MVP. Planned for v2.") - - def remove(self, key: str) -> None: - raise NotImplementedError("FluxonKV backend not available in MVP. Planned for v2.") - - def list_keys(self) -> list[str]: - raise NotImplementedError("FluxonKV backend not available in MVP. Planned for v2.") diff --git a/telefuser/cache_mem/storage/interfaces.py b/telefuser/cache_mem/storage/interfaces.py deleted file mode 100644 index 9f3f2d0..0000000 --- a/telefuser/cache_mem/storage/interfaces.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Optional - - -class KVStore(ABC): - """键值存储接口。""" - - @abstractmethod - def get(self, key: str) -> Optional[bytes]: - pass - - @abstractmethod - def put(self, key: str, value: bytes) -> None: - pass - - @abstractmethod - def remove(self, key: str) -> None: - pass - - @abstractmethod - def list_keys(self) -> list[str]: - """列出当前存储的 key 列表。""" - pass diff --git a/telefuser/cache_mem/storage/local_file.py b/telefuser/cache_mem/storage/local_file.py deleted file mode 100644 index 52046b6..0000000 --- a/telefuser/cache_mem/storage/local_file.py +++ /dev/null @@ -1,112 +0,0 @@ -from __future__ import annotations - -import hashlib -import json -import os -import threading -from pathlib import Path -from typing import Dict, Optional - -from loguru import logger - -from .interfaces import KVStore - - -class LocalFileKVStore(KVStore): - """本地文件 KV 存储实现(按 key 持久化到磁盘)。 - - Thread-safe within a single process via an internal RLock guarding - the in-memory index dict and disk index file. For cross-process - safety, callers should additionally hold a FileLock on the cache - directory around multi-resource transactions. - """ - - def __init__(self, root_dir: str | Path) -> None: - self.root_dir = Path(root_dir) - self.root_dir.mkdir(parents=True, exist_ok=True) - self._index_path = self.root_dir / "kv_index.json" - self._lock = threading.RLock() - self._index: Dict[str, str] = self._load_index() - - def put(self, key: str, value: bytes) -> None: - with self._lock: - filename = self._index.get(key) - if not filename: - filename = self._hash_key(key) - self._index[key] = filename - file_path = self.root_dir / filename - tmp_path = self.root_dir / f"{filename}.tmp" - tmp_path.write_bytes(value) - os.replace(tmp_path, file_path) - self._save_index() - - def get(self, key: str) -> Optional[bytes]: - with self._lock: - filename = self._index.get(key) - if not filename: - return None - file_path = self.root_dir / filename - # Read bytes outside the lock — file content is immutable once - # written via put()'s atomic os.replace, so concurrent reads are safe. - if not file_path.exists(): - return None - return file_path.read_bytes() - - def remove(self, key: str) -> None: - with self._lock: - filename = self._index.pop(key, None) - if filename: - file_path = self.root_dir / filename - if file_path.exists(): - file_path.unlink() - self._save_index() - - def list_keys(self) -> list[str]: - with self._lock: - return list(self._index.keys()) - - def _hash_key(self, key: str) -> str: - digest = hashlib.sha256(key.encode("utf-8")).hexdigest() - return f"{digest}.bin" - - def _load_index(self) -> Dict[str, str]: - if not self._index_path.exists(): - return {} - try: - raw = self._index_path.read_text() - except OSError as exc: - logger.exception( - "LocalFileKVStore failed to read index path={} err={}", - self._index_path, - exc, - ) - raise RuntimeError(f"LocalFileKVStore failed to read index path={self._index_path}: {exc}") from exc - try: - data = json.loads(raw) - except json.JSONDecodeError as exc: - logger.exception( - "LocalFileKVStore index is not valid JSON path={} err={}", - self._index_path, - exc, - ) - raise ValueError(f"LocalFileKVStore index is not valid JSON path={self._index_path}: {exc}") from exc - if not isinstance(data, dict): - raise ValueError( - f"LocalFileKVStore index must be a JSON object path={self._index_path} got_type={type(data).__name__}" - ) - invalid_items = [ - (key, value) for key, value in data.items() if not isinstance(key, str) or not isinstance(value, str) - ] - if invalid_items: - key, value = invalid_items[0] - raise ValueError( - "LocalFileKVStore index contains non-string entry " - f"path={self._index_path} key_type={type(key).__name__} " - f"value_type={type(value).__name__}" - ) - return data - - def _save_index(self) -> None: - tmp = self._index_path.with_suffix(".json.tmp") - tmp.write_text(json.dumps(self._index, ensure_ascii=True)) - os.replace(tmp, self._index_path) diff --git a/telefuser/cache_mem/storage/memory.py b/telefuser/cache_mem/storage/memory.py deleted file mode 100644 index 5fbb568..0000000 --- a/telefuser/cache_mem/storage/memory.py +++ /dev/null @@ -1,24 +0,0 @@ -from __future__ import annotations - -from typing import Dict, Optional - -from .interfaces import KVStore - - -class InMemoryKVStore(KVStore): - """内存 KV 存储实现(简单字典)。""" - - def __init__(self) -> None: - self._store: Dict[str, bytes] = {} - - def get(self, key: str) -> Optional[bytes]: - return self._store.get(key) - - def put(self, key: str, value: bytes) -> None: - self._store[key] = value - - def remove(self, key: str) -> None: - self._store.pop(key, None) - - def list_keys(self) -> list[str]: - return list(self._store.keys()) diff --git a/telefuser/cache_mem/strategies.py b/telefuser/cache_mem/strategies.py deleted file mode 100644 index 0370733..0000000 --- a/telefuser/cache_mem/strategies.py +++ /dev/null @@ -1,819 +0,0 @@ -from __future__ import annotations - -import io -import uuid -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional - -import torch -from loguru import logger - -from .cache_types import CacheResult, VectorSearchResult -from .config import CacheConfig -from .encoders import Qwen3VLEncoder, Qwen3VLReranker -from .encoding.interfaces import PromptEncoder, VideoEncoder -from .state.interfaces import CacheMetadataManager -from .storage.interfaces import KVStore -from .vector_store.interfaces import VectorStore - - -class BaseCacheStrategy(ABC): - """缓存策略抽象基类。""" - - def __init__( - self, - config: CacheConfig, - kv_store: KVStore, - metadata_manager: CacheMetadataManager, - ): - self.config = config - self.kv_store = kv_store - self.metadata_manager = metadata_manager - - @abstractmethod - async def lookup(self, **kwargs) -> CacheResult: - pass - - async def apply(self, result: CacheResult) -> CacheResult: - return result - - @abstractmethod - async def save(self, **kwargs) -> None: - pass - - def _load_latent(self, cache_id: str, step: int) -> Optional[torch.Tensor]: - key = f"{cache_id}_step{int(step)}" - data = self.kv_store.get(key) - if data is None and "-" in (cache_id or ""): - normalized = self._normalize_cache_id(cache_id) - if normalized != cache_id: - key = f"{normalized}_step{int(step)}" - data = self.kv_store.get(key) - if data is None: - return None - try: - # weights_only=True blocks arbitrary code execution from untrusted - # KV bytes; we only persist tensors here so this is safe. - return torch.load(io.BytesIO(data), map_location="cpu", weights_only=True) - except Exception as exc: - logger.exception( - "Cache load failed cache_id={} step={} err={}", - cache_id, - int(step), - exc, - ) - raise RuntimeError( - f"Cache load failed cache_id={cache_id} step={int(step)} type={type(exc).__name__} err={exc}" - ) from exc - - def _save_latent(self, cache_id: str, step: int, latent: torch.Tensor) -> None: - key = f"{cache_id}_step{int(step)}" - buffer = io.BytesIO() - try: - torch.save(latent, buffer) - self.kv_store.put(key, buffer.getvalue()) - except Exception as exc: - logger.exception( - "Cache save failed cache_id={} step={} err={}", - cache_id, - int(step), - exc, - ) - raise RuntimeError( - f"Cache save failed cache_id={cache_id} step={int(step)} type={type(exc).__name__} err={exc}" - ) from exc - - def _latent_size_bytes(self, cache_id: str, step: int, latent: torch.Tensor) -> int: - nelement = getattr(latent, "nelement", None) - element_size = getattr(latent, "element_size", None) - if not callable(nelement) or not callable(element_size): - raise TypeError( - "Latent tensor does not expose size methods " - f"cache_id={cache_id} step={int(step)} type={type(latent).__name__}" - ) - return int(nelement()) * int(element_size()) - - def _generate_cache_id(self) -> str: - return uuid.uuid4().hex - - def _normalize_cache_id(self, cache_id: str) -> str: - return (cache_id or "").replace("-", "") - - def _normalize_search_results(self, results: List[VectorSearchResult]) -> None: - for r in results: - r.cache_id = self._normalize_cache_id(r.cache_id) - - def _candidate_text(self, result: VectorSearchResult) -> str: - text = result.prompt or "" - if not text and isinstance(result.payload, dict): - text = result.payload.get("prompt") or "" - return text - - -class VideoBasedApproximateCache(BaseCacheStrategy): - def __init__( - self, - config, - kv_store: KVStore, - vector_store: Optional[VectorStore], - metadata_manager: CacheMetadataManager, - *, - prompt_encoder: Optional[PromptEncoder] = None, - video_encoder: Optional["VideoEncoder"] = None, - reranker: Optional[object] = None, - ): - super().__init__(config, kv_store, metadata_manager) - self.vector_store = vector_store - - # Build text / video encoder - enable_video_embedding = bool(getattr(self.config, "video_embedding_enabled", False)) - text_model_path = getattr(self.config, "text_embedding_model_path", None) or None - use_text_embedding = bool(text_model_path) or enable_video_embedding - - def _build_prompt_encoder() -> Qwen3VLEncoder: - model_path = ( - text_model_path - or getattr( - self.config, - "video_embedding_model_path", - None, - ) - or "Qwen/Qwen3-VL-Embedding-2B" - ) - device_id = getattr(self.config, "text_embedding_device_id", None) - encoder = Qwen3VLEncoder( - model_path=model_path, - instruction=getattr( - self.config, - "text_embedding_instruction", - "Represent the user's input", - ), - max_frames=int(getattr(self.config, "video_embedding_max_frames", 16)), - fps=float(getattr(self.config, "video_embedding_fps", 1.0)), - device_id=device_id, - torch_dtype=getattr(self.config, "text_embedding_torch_dtype", None), - attn_implementation=getattr(self.config, "text_embedding_attn_impl", None), - ) - logger.info( - "VideoBasedApproximateCache prompt encoder enabled model_path={} device_id={}", - model_path, - device_id, - ) - return encoder - - def _build_video_encoder() -> Qwen3VLEncoder: - model_path = ( - getattr( - self.config, - "video_embedding_model_path", - None, - ) - or text_model_path - or "Qwen/Qwen3-VL-Embedding-2B" - ) - device_id = getattr(self.config, "video_embedding_device_id", None) - encoder = Qwen3VLEncoder( - model_path=model_path, - instruction=getattr( - self.config, - "video_embedding_instruction", - "Represent the user's input", - ), - max_frames=int(getattr(self.config, "video_embedding_max_frames", 16)), - fps=float(getattr(self.config, "video_embedding_fps", 1.0)), - device_id=device_id, - torch_dtype=getattr(self.config, "video_embedding_torch_dtype", None), - attn_implementation=getattr(self.config, "video_embedding_attn_impl", None), - ) - logger.info( - "VideoBasedApproximateCache video encoder enabled model_path={} device_id={}", - model_path, - device_id, - ) - return encoder - - self.prompt_encoder = prompt_encoder - self.video_encoder = video_encoder - - if use_text_embedding and self.prompt_encoder is None: - self.prompt_encoder = _build_prompt_encoder() - if enable_video_embedding and self.video_encoder is None: - # Qwen3VLEncoder exposes both encode(text) and encode_video(frames) - # on a single backend embedder. When text and video configs would - # load the identical model onto the identical device, reuse the - # prompt_encoder instance to save ~5GB GPU mem and one cold load. - video_model_path = ( - getattr(self.config, "video_embedding_model_path", None) - or getattr(self.config, "text_embedding_model_path", None) - or "Qwen/Qwen3-VL-Embedding-2B" - ) - video_device_id = getattr(self.config, "video_embedding_device_id", None) - if ( - self.prompt_encoder is not None - and getattr(self.prompt_encoder, "model_path", None) == video_model_path - and getattr(self.prompt_encoder, "device_id", None) == video_device_id - ): - self.video_encoder = self.prompt_encoder - logger.info( - "VideoBasedApproximateCache video_encoder shares prompt_encoder " - "instance (same model_path={} device_id={}, save ~5GB)", - video_model_path, - video_device_id, - ) - else: - self.video_encoder = _build_video_encoder() - - if use_text_embedding and self.prompt_encoder is None: - logger.warning( - "VideoBasedApproximateCache prompt encoder unavailable;" - " configure text_embedding_model_path or provide prompt_encoder" - ) - if enable_video_embedding and self.video_encoder is None: - logger.warning( - "VideoBasedApproximateCache video encoder unavailable;" - " configure video embedding or provide video_encoder" - ) - - # Build reranker - if reranker is not None: - self.reranker = reranker - elif getattr(self.config, "rerank_enabled", False): - self.reranker = Qwen3VLReranker( - model_path=getattr(self.config, "rerank_model_path", None) or "Qwen/Qwen3-VL-Reranker-2B", - device_id=getattr(self.config, "rerank_device_id", None), - batch_size=int(getattr(self.config, "rerank_batch_size", 2) or 2), - torch_dtype=getattr(self.config, "rerank_torch_dtype", None), - ) - backend_reranker = getattr(self.reranker, "_reranker", None) - actual_reranker_device = getattr(getattr(backend_reranker, "model", None), "device", None) - if actual_reranker_device is None: - actual_reranker_device = getattr(backend_reranker, "device", "unknown") - logger.debug( - "VideoBasedApproximateCache reranker enabled model_path={} device_id={} actual_device={}", - getattr(self.config, "rerank_model_path", ""), - getattr(self.config, "rerank_device_id", None), - actual_reranker_device, - ) - else: - self.reranker = None - - async def lookup(self, prompt: str, task_type: str) -> CacheResult: - prompt = prompt or "" - logger.debug(f"VideoBasedApproximateCache.lookup start task_type={task_type} prompt_len={len(prompt)}") - if not prompt: - logger.debug("VideoBasedApproximateCache.lookup miss: empty prompt") - return CacheResult(hit=False) - if self.vector_store is None: - logger.debug("VideoBasedApproximateCache.lookup miss: vector_store unavailable") - return CacheResult(hit=False) - - if self.prompt_encoder is None: - logger.warning("VideoBasedApproximateCache.lookup miss: prompt encoder unavailable") - return CacheResult(hit=False) - - query_vec = self.prompt_encoder.encode(prompt) - if not query_vec: - logger.debug("VideoBasedApproximateCache.lookup miss: prompt embedding unavailable") - return CacheResult(hit=False) - - hit_score = None - if getattr(self.config, "rerank_enabled", False): - top_k = int(getattr(self.config, "rerank_top_k", 1) or 1) - results = self._vector_search(query_vec, top_k=top_k) - self._normalize_search_results(results) - if not results: - logger.debug("VideoBasedApproximateCache.lookup miss: no vector result") - return CacheResult(hit=False) - try: - self.metadata_manager.record_similarity_scores( - request_prompt=prompt, - task_type=task_type, - cache_type="video_approximate_cache", - stage="vector_search", - candidates=[ - { - "cache_id": item.cache_id, - "similarity": float(item.similarity), - "prompt": item.prompt, - "saved_steps": item.saved_steps, - } - for item in results - ], - ) - except Exception as exc: - logger.exception( - "VideoBasedApproximateCache record_similarity_scores failed stage=vector_search err_type={} err={}", - type(exc).__name__, - exc, - ) - scores = self._rerank_scores(prompt, results, "VideoBasedApproximateCache") - if scores is None: - logger.debug("VideoBasedApproximateCache.lookup rerank skip: fallback to vector similarity") - result = results[0] - threshold = getattr(self.config, "video_similarity_threshold", 0.10) - if result.similarity < threshold: - logger.debug( - "VideoBasedApproximateCache.lookup miss: similarity below threshold " - f"sim={result.similarity:.4f} threshold={threshold:.4f}" - ) - return CacheResult(hit=False) - hit_score = result.similarity - skip_step = self._determine_skip_step(hit_score, result.saved_steps) - if skip_step <= 0: - logger.debug( - "VideoBasedApproximateCache.lookup miss: skip_step=0 " - f"sim={result.similarity:.4f} saved_steps={result.saved_steps}" - ) - return CacheResult(hit=False) - else: - if len(scores) != len(results): - logger.warning( - "VideoBasedApproximateCache.lookup rerank invalid scores size={}", - len(scores or []), - ) - return CacheResult(hit=False) - try: - self.metadata_manager.record_similarity_scores( - request_prompt=prompt, - task_type=task_type, - cache_type="video_approximate_cache", - stage="rerank", - candidates=[ - { - "cache_id": item.cache_id, - "similarity": float(item.similarity), - "rerank_score": float(scores[idx]), - "prompt": item.prompt, - "saved_steps": item.saved_steps, - } - for idx, item in enumerate(results) - ], - ) - except Exception as exc: - logger.exception( - "VideoBasedApproximateCache record_similarity_scores failed stage=rerank err_type={} err={}", - type(exc).__name__, - exc, - ) - best_idx = max(range(len(scores)), key=lambda idx: scores[idx]) - rerank_score = float(scores[best_idx]) - result = results[best_idx] - logger.debug( - "VideoBasedApproximateCache.lookup rerank select cache_id={} score={:.4f} sim={:.4f}", - result.cache_id, - rerank_score, - result.similarity, - ) - rerank_threshold = float(getattr(self.config, "rerank_score_threshold", 0.95) or 0.95) - if rerank_score <= rerank_threshold: - logger.debug( - "VideoBasedApproximateCache.lookup miss: rerank score below threshold " - f"score={rerank_score:.4f} threshold={rerank_threshold:.4f}" - ) - return CacheResult(hit=False) - hit_score = rerank_score - skip_step = self._determine_skip_step(hit_score, result.saved_steps) - if skip_step <= 0: - logger.debug( - "VideoBasedApproximateCache.lookup miss: skip_step=0 " - f"score={rerank_score:.4f} saved_steps={result.saved_steps}" - ) - return CacheResult(hit=False) - else: - results = self._vector_search(query_vec, top_k=1) - if not results: - logger.debug("VideoBasedApproximateCache.lookup miss: no vector result") - return CacheResult(hit=False) - try: - self.metadata_manager.record_similarity_scores( - request_prompt=prompt, - task_type=task_type, - cache_type="video_approximate_cache", - stage="vector_search", - candidates=[ - { - "cache_id": item.cache_id, - "similarity": float(item.similarity), - "prompt": item.prompt, - "saved_steps": item.saved_steps, - } - for item in results - ], - ) - except Exception as exc: - logger.exception( - "VideoBasedApproximateCache record_similarity_scores failed stage=vector_search err_type={} err={}", - type(exc).__name__, - exc, - ) - result = results[0] - - threshold = getattr(self.config, "video_similarity_threshold", 0.10) - if result.similarity < threshold: - logger.debug( - "VideoBasedApproximateCache.lookup miss: similarity below threshold " - f"sim={result.similarity:.4f} threshold={threshold:.4f}" - ) - return CacheResult(hit=False) - - hit_score = result.similarity - skip_step = self._determine_skip_step(hit_score, result.saved_steps) - if skip_step <= 0: - logger.debug( - "VideoBasedApproximateCache.lookup miss: skip_step=0 " - f"sim={result.similarity:.4f} saved_steps={result.saved_steps}" - ) - return CacheResult(hit=False) - - latent = self._load_latent(result.cache_id, skip_step) - if latent is None: - meta = None - try: - meta = self.metadata_manager.get_cache_meta(result.cache_id) - except Exception as exc: - logger.exception( - "VideoBasedApproximateCache lookup meta check failed cache_id={} err_type={} err={}", - result.cache_id, - type(exc).__name__, - exc, - ) - meta_hint = "" - if meta: - meta_hint = ( - f" meta_prompt={meta.get('prompt')} " - f"meta_steps={meta.get('saved_steps')} " - f"meta_type={meta.get('cache_type')}" - ) - logger.warning( - "VideoBasedApproximateCache.lookup miss: hit by threshold but KV missing " - f"cache_id={result.cache_id} step={skip_step} sim={result.similarity:.4f} " - f"meta_exists={bool(meta)}{meta_hint}" - ) - return CacheResult(hit=False) - self.metadata_manager.record_access(result.cache_id) - try: - self.metadata_manager.record_hit_pair( - request_prompt=prompt, - cache_id=result.cache_id, - cached_prompt=result.prompt, - similarity=float(hit_score if hit_score is not None else result.similarity), - task_type=task_type, - cache_type="video_approximate_cache", - skip_step=skip_step, - ) - except Exception as exc: - logger.exception( - "VideoBasedApproximateCache record_hit_pair failed cache_id={} err_type={} err={}", - result.cache_id, - type(exc).__name__, - exc, - ) - logger.debug( - "VideoBasedApproximateCache.lookup hit " - f"cache_id={result.cache_id} step={skip_step} sim={result.similarity:.4f}" - ) - return CacheResult( - hit=True, - skip_step=skip_step, - cache_type="video_approximate_cache", - similarity=result.similarity, - latent_state=latent, - cached_prompt=result.prompt, - ) - - async def save( - self, - prompt: str, - latent_states_dict: Dict[int, torch.Tensor], - num_frames: int, - task_type: str, - saved_steps: List[int], - embedding_video_frames: Optional[List[Any]] = None, - ) -> None: - prompt = prompt or "" - logger.debug( - "VideoBasedApproximateCache.save start " - f"task_type={task_type} prompt_len={len(prompt)} saved_steps={saved_steps}" - ) - if not prompt: - logger.debug("VideoBasedApproximateCache.save skip: empty prompt") - return - if not latent_states_dict or not saved_steps: - logger.debug("VideoBasedApproximateCache.save skip: no latent_states or saved_steps") - return - - cache_id = self._generate_cache_id() - requested_steps = sorted(set(int(s) for s in saved_steps)) - saved_steps = [] - total_bytes = 0 - collection = getattr(self.config, "video_vector_collection", "video") - vector_written = False - metadata_attempted = False - - try: - for step in requested_steps: - latent = latent_states_dict.get(step) - if latent is None: - continue - self._save_latent(cache_id, step, latent) - saved_steps.append(step) - total_bytes += self._latent_size_bytes(cache_id, step, latent) - except Exception as exc: - logger.exception( - "VideoBasedApproximateCache.save latent persistence failed cache_id={} err={}", - cache_id, - exc, - ) - if saved_steps: - try: - self._remove_saved_latents(cache_id, saved_steps) - except Exception as cleanup_exc: - raise RuntimeError( - "VideoBasedApproximateCache.save failed during latent persistence " - f"cache_id={cache_id} err={exc}; cleanup_err={cleanup_exc}" - ) from exc - raise RuntimeError( - f"VideoBasedApproximateCache.save failed during latent persistence cache_id={cache_id} err={exc}" - ) from exc - - if not saved_steps: - logger.debug("VideoBasedApproximateCache.save skip: no latent saved") - return - - size_mb = float(total_bytes) / (1024 * 1024) if total_bytes > 0 else 0.0 - logger.debug( - "VideoBasedApproximateCache.save stored " - f"cache_id={cache_id} steps={saved_steps} size_mb={size_mb:.4f} frames={num_frames}" - ) - - if self.vector_store is None: - logger.warning("VideoBasedApproximateCache.save skip: vector_store unavailable") - self._remove_saved_latents(cache_id, saved_steps) - return - - if not embedding_video_frames: - logger.debug("VideoBasedApproximateCache.save skip: no video frames provided") - self._remove_saved_latents(cache_id, saved_steps) - return - - if self.video_encoder is None: - logger.warning("VideoBasedApproximateCache.save skip: video encoder unavailable") - self._remove_saved_latents(cache_id, saved_steps) - return - - try: - frames = self._load_frames_for_embedding( - embedding_video_frames=embedding_video_frames, - ) - if not frames: - logger.debug("VideoBasedApproximateCache.save skip: sampled frames empty") - self._remove_saved_latents(cache_id, saved_steps) - return - logger.debug( - "VideoBasedApproximateCache.save frames decoded " - f"count={len(frames)} size={getattr(frames[0], 'size', None)}" - ) - video_vec = self.video_encoder.encode_video(frames, prompt=prompt) - if not video_vec: - logger.debug("VideoBasedApproximateCache.save skip: video embedding unavailable") - self._remove_saved_latents(cache_id, saved_steps) - return - - vector_dim = len(video_vec) - self.vector_store.ensure_collection(collection, vector_dim) - logger.debug(f"VideoBasedApproximateCache.save ensure collection={collection} dim={vector_dim}") - - payload = { - "prompt": prompt, - "saved_steps": saved_steps, - "task_type": task_type, - } - self.vector_store.upsert( - collection, - cache_id, - video_vec, - payload, - ) - vector_written = True - metadata_attempted = True - self.metadata_manager.register_cache( - cache_id, - prompt, - saved_steps, - size_mb, - num_frames, - cache_type="video_approximate_cache", - ) - except Exception as exc: - logger.exception( - "VideoBasedApproximateCache.save failed cache_id={} collection={} err={}", - cache_id, - collection, - exc, - ) - try: - self._rollback_cache_entry( - cache_id=cache_id, - saved_steps=saved_steps, - collection=collection, - remove_vector=vector_written, - remove_metadata=metadata_attempted, - ) - except Exception as rollback_exc: - raise RuntimeError( - "VideoBasedApproximateCache.save failed " - f"cache_id={cache_id} collection={collection} err={exc}; " - f"rollback_err={rollback_exc}" - ) from exc - raise RuntimeError( - f"VideoBasedApproximateCache.save failed cache_id={cache_id} collection={collection} err={exc}" - ) from exc - logger.debug(f"VideoBasedApproximateCache.vector_store upsert collection={collection} cache_id={cache_id}") - - def _remove_saved_latents(self, cache_id: str, saved_steps: List[int]) -> None: - errors: List[str] = [] - for step in saved_steps: - try: - self.kv_store.remove(f"{cache_id}_step{int(step)}") - except Exception as exc: - logger.exception( - "VideoBasedApproximateCache latent cleanup failed cache_id={} step={} err={}", - cache_id, - int(step), - exc, - ) - errors.append( - f"kv remove failed cache_id={cache_id} step={int(step)} type={type(exc).__name__} err={exc}" - ) - if errors: - raise RuntimeError("VideoBasedApproximateCache latent cleanup failed: " + "; ".join(errors)) - - def _rollback_cache_entry( - self, - cache_id: str, - saved_steps: List[int], - collection: str, - remove_vector: bool, - remove_metadata: bool, - ) -> None: - errors: List[str] = [] - if remove_vector and self.vector_store is not None: - try: - self.vector_store.delete(collection, [cache_id]) - except Exception as exc: - logger.exception( - "VideoBasedApproximateCache vector rollback failed collection={} cache_id={} err={}", - collection, - cache_id, - exc, - ) - errors.append( - "vector rollback failed " - f"collection={collection} cache_id={cache_id} " - f"type={type(exc).__name__} err={exc}" - ) - if remove_metadata: - try: - self.metadata_manager.remove_cache(cache_id) - except Exception as exc: - logger.exception( - "VideoBasedApproximateCache metadata rollback failed cache_id={} err={}", - cache_id, - exc, - ) - errors.append(f"metadata rollback failed cache_id={cache_id} type={type(exc).__name__} err={exc}") - try: - self._remove_saved_latents(cache_id, saved_steps) - except Exception as exc: - errors.append(str(exc)) - if errors: - raise RuntimeError(f"VideoBasedApproximateCache rollback failed cache_id={cache_id}: {'; '.join(errors)}") - - def _vector_search(self, query_vec: List[float], top_k: int = 1) -> List[VectorSearchResult]: - if self.vector_store is None: - return [] - collection = getattr(self.config, "video_vector_collection", "video") - top_k = max(1, int(top_k or 1)) - res = self.vector_store.search(collection, query_vec, limit=top_k) - if not res: - return [] - res.sort(key=lambda item: item.similarity, reverse=True) - return res[:top_k] - - def _load_frames_for_embedding( - self, - *, - embedding_video_frames: Optional[List[Any]], - ) -> List[Any]: - if embedding_video_frames: - return list(embedding_video_frames) - return [] - - def _sample_indices(self, total: int, max_frames: int) -> List[int]: - if total <= 0: - return [] - max_frames = max(1, int(max_frames or 1)) - if total <= max_frames: - return list(range(total)) - step = float(total) / float(max_frames) - return [min(int(i * step), total - 1) for i in range(max_frames)] - - def _determine_skip_step(self, similarity: float, saved_steps: List[int]) -> int: - steps = set(int(s) for s in saved_steps) - rerank_threshold = float(getattr(self.config, "rerank_score_threshold", 0.90) or 0.90) - if similarity > rerank_threshold and 5 in steps: - return 5 - return 0 - - def _build_rerank_documents( - self, - results: List[VectorSearchResult], - ) -> List[Dict[str, object]]: - documents: List[Dict[str, object]] = [] - for item in results: - text = self._candidate_text(item) - doc: Dict[str, object] = {} - if text: - doc["text"] = text - documents.append(doc) - return documents - - # def _build_rerank_documents( - # self, - # results: List[VectorSearchResult], - # ) -> List[Dict[str, object]]: - # documents: List[Dict[str, object]] = [] - # for item in results: - # text = self._candidate_text(item) - # doc: Dict[str, object] = {} - # if text: - # doc["text"] = text - # documents.append(doc) - # return documents - - def _rerank_scores( - self, - query: str, - results: List[VectorSearchResult], - source: str, - ) -> Optional[List[float]]: - reranker = getattr(self, "reranker", None) - if reranker is None: - logger.warning(f"{source} rerank skip: reranker unavailable") - return None - if not hasattr(reranker, "score_mm"): - logger.warning(f"{source} rerank skip: text reranker unavailable") - return None - documents = self._build_rerank_documents(results) - has_text_docs = any("text" in doc and doc["text"] for doc in documents) - if not has_text_docs: - logger.debug(f"{source} rerank skip: no text candidates available") - return None - - try: - logger.debug(f"{source} rerank mode=text candidates={len(results)}") - scores = reranker.score_mm({"text": query}, documents) - except Exception as exc: - logger.exception(f"{source} text rerank failed: {exc}") - raise RuntimeError(f"{source} text rerank failed err_type={type(exc).__name__} err={exc}") from exc - if not scores or len(scores) != len(results): - raise ValueError(f"{source} rerank invalid scores size={len(scores or [])} expected={len(results)}") - score_pairs = [] - for idx, item in enumerate(results): - try: - score_value = float(scores[idx]) - score_pairs.append(f"{item.cache_id}:{score_value:.4f}/{item.similarity:.4f}") - except (IndexError, TypeError, ValueError) as exc: - logger.exception( - "{} rerank score formatting failed cache_id={} idx={} err_type={} err={}", - source, - item.cache_id, - idx, - type(exc).__name__, - exc, - ) - raise RuntimeError( - f"{source} rerank score formatting failed " - f"cache_id={item.cache_id} idx={idx} " - f"err_type={type(exc).__name__} err={exc}" - ) from exc - logger.debug(f"{source} rerank scores={score_pairs}") - return [float(value) for value in scores] - - -# --------------------------------------------------------------------------- -# Strategy Registry -# --------------------------------------------------------------------------- - -_STRATEGY_REGISTRY: Dict[str, type] = {} - - -def register_strategy(name: str, cls: type) -> None: - _STRATEGY_REGISTRY[name] = cls - - -def get_strategy_class(name: str) -> Optional[type]: - return _STRATEGY_REGISTRY.get(name) - - -register_strategy("video_approximate", VideoBasedApproximateCache) diff --git a/telefuser/cache_mem/vector_store/__init__.py b/telefuser/cache_mem/vector_store/__init__.py deleted file mode 100644 index bf4f5d9..0000000 --- a/telefuser/cache_mem/vector_store/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .faiss import FAISSVectorStore -from .interfaces import VectorStore -from .qdrant import QdrantVectorStore - -__all__ = ["VectorStore", "QdrantVectorStore", "FAISSVectorStore"] diff --git a/telefuser/cache_mem/vector_store/faiss.py b/telefuser/cache_mem/vector_store/faiss.py deleted file mode 100644 index 94046f3..0000000 --- a/telefuser/cache_mem/vector_store/faiss.py +++ /dev/null @@ -1,298 +0,0 @@ -from __future__ import annotations - -import json -import threading -from pathlib import Path -from typing import Any, Dict, List, Optional - -from loguru import logger - -from ..cache_types import VectorSearchResult -from .interfaces import VectorStore - - -class FAISSVectorStore(VectorStore): - def __init__( - self, - index_dir: Path, - vector_dim: int, - index_type: str = "L2", - ) -> None: - self.index_dir = Path(index_dir) - self.vector_dim = vector_dim - self.index_type = index_type - self._lock = threading.RLock() - self._indices: Dict[str, Any] = {} - self._metadata: Dict[str, Dict[str, Any]] = {} - self.index_dir.mkdir(parents=True, exist_ok=True) - - def search( - self, - collection: str, - vector: List[float], - limit: int = 1, - score_threshold: Optional[float] = None, - ) -> List[VectorSearchResult]: - with self._lock: - index = self._load_index(collection) - if index is None: - return [] - import numpy as np - - meta = self._metadata.get(collection, {}) - id_map = meta.get("id_map", {}) - payload_map = meta.get("payload", {}) - - vec = np.asarray([vector], dtype="float32") - if vec.shape[1] != self.vector_dim: - raise ValueError( - "FAISSVectorStore.search vector dimension mismatch " - f"collection={collection} got={vec.shape[1]} expected={self.vector_dim}" - ) - if self.index_type.lower() == "cosine": - faiss = self._import_faiss() - faiss.normalize_L2(vec) - - distances, ids = index.search(vec, limit) - results: List[VectorSearchResult] = [] - for dist, idx in zip(distances[0], ids[0]): - if idx < 0: - continue - point_id = self._find_point_id(id_map, int(idx)) - if point_id is None: - continue - payload = payload_map.get(point_id, {}) - if self.index_type.lower() == "l2": - similarity = 1.0 / (1.0 + float(dist)) - else: - similarity = float(dist) - if score_threshold is not None and similarity < score_threshold: - continue - results.append( - VectorSearchResult( - cache_id=str(point_id), - similarity=similarity, - prompt=str(payload.get("prompt", "")), - saved_steps=list(payload.get("saved_steps", [])), - payload=payload, - ) - ) - return results - - def upsert( - self, - collection: str, - point_id: str, - vector: List[float], - payload: Dict[str, Any], - ) -> None: - with self._lock: - index = self._load_index(collection) - if index is None: - self.ensure_collection(collection, self.vector_dim) - index = self._load_index(collection) - if index is None: - raise RuntimeError( - f"FAISSVectorStore.upsert could not load or create collection collection={collection}" - ) - import numpy as np - - meta = self._metadata.setdefault(collection, {"id_map": {}, "payload": {}, "next_id": 1}) - id_map = meta["id_map"] - payload_map = meta["payload"] - vec = np.asarray([vector], dtype="float32") - if vec.shape[1] != self.vector_dim: - raise ValueError( - "FAISSVectorStore.upsert vector dimension mismatch " - f"collection={collection} got={vec.shape[1]} expected={self.vector_dim}" - ) - if self.index_type.lower() == "cosine": - faiss = self._import_faiss() - faiss.normalize_L2(vec) - - existing = id_map.get(point_id) - if existing is not None: - self._remove_ids(index, [int(existing)]) - else: - existing = int(meta.get("next_id", 1)) - meta["next_id"] = existing + 1 - id_map[point_id] = existing - - index.add_with_ids(vec, self._as_faiss_ids([existing])) - payload_map[point_id] = payload - self._save_index(collection, index) - - def delete(self, collection: str, point_ids: List[str]) -> None: - with self._lock: - index = self._load_index(collection) - if index is None: - return - meta = self._metadata.get(collection, {}) - id_map = meta.get("id_map", {}) - payload_map = meta.get("payload", {}) - to_remove = [] - for pid in point_ids: - idx = id_map.pop(pid, None) - if idx is not None: - to_remove.append(int(idx)) - payload_map.pop(pid, None) - if to_remove: - self._remove_ids(index, to_remove) - self._save_index(collection, index) - - def ensure_collection(self, collection: str, vector_dim: int) -> None: - with self._lock: - index = self._load_index(collection) - if index is not None: - return - faiss = self._import_faiss() - if self.index_type.lower() == "l2": - base = faiss.IndexFlatL2(vector_dim) - elif self.index_type.lower() in ("ip", "innerproduct"): - base = faiss.IndexFlatIP(vector_dim) - elif self.index_type.lower() == "cosine": - base = faiss.IndexFlatIP(vector_dim) - else: - raise ValueError(f"Unsupported index_type: {self.index_type}") - index = faiss.IndexIDMap2(base) - self._indices[collection] = index - self._metadata[collection] = {"id_map": {}, "payload": {}, "next_id": 1} - self._save_index(collection, index) - - def get_vector_size(self, collection: str) -> Optional[int]: - with self._lock: - index = self._load_index(collection) - if index is None: - return None - return int(index.d) - - def _load_index(self, collection: str) -> Optional[Any]: - if collection in self._indices: - return self._indices[collection] - index_path, meta_path = self._get_paths(collection) - if not index_path.exists(): - return None - faiss = self._import_faiss() - try: - index = faiss.read_index(str(index_path)) - except Exception as exc: - logger.exception( - "FAISSVectorStore failed to read index collection={} path={} err={}", - collection, - index_path, - exc, - ) - raise RuntimeError( - "FAISSVectorStore failed to read index " - f"collection={collection} path={index_path} " - f"err_type={type(exc).__name__} err={exc}" - ) from exc - self._indices[collection] = index - empty_meta = {"id_map": {}, "payload": {}, "next_id": 1} - if meta_path.exists(): - try: - raw = json.loads(meta_path.read_text()) - except OSError as exc: - logger.exception( - "FAISSVectorStore failed to read metadata collection={} path={} err={}", - collection, - meta_path, - exc, - ) - raise RuntimeError( - "FAISSVectorStore failed to read metadata " - f"collection={collection} path={meta_path} " - f"err_type={type(exc).__name__} err={exc}" - ) from exc - except json.JSONDecodeError as exc: - logger.exception( - "FAISSVectorStore metadata is not valid JSON collection={} path={} err={}", - collection, - meta_path, - exc, - ) - raise ValueError( - f"FAISSVectorStore metadata is not valid JSON collection={collection} path={meta_path}: {exc}" - ) from exc - if not isinstance(raw, dict): - raise ValueError( - "FAISSVectorStore metadata must be a JSON object " - f"collection={collection} path={meta_path} " - f"got_type={type(raw).__name__}" - ) - id_map = raw.get("id_map", {}) - if not isinstance(id_map, dict): - raise ValueError( - "FAISSVectorStore metadata.id_map must be a dict " - f"collection={collection} path={meta_path} " - f"got_type={type(id_map).__name__}" - ) - if len(id_map) != index.ntotal: - raise ValueError( - "FAISSVectorStore metadata does not match index " - f"collection={collection} path={meta_path} " - f"id_map={len(id_map)} ntotal={index.ntotal}" - ) - self._metadata[collection] = raw - else: - self._metadata[collection] = empty_meta - return index - - def _save_index(self, collection: str, index: Any) -> None: - import os - - index_path, meta_path = self._get_paths(collection) - faiss = self._import_faiss() - tmp_index = index_path.with_suffix(".faiss.tmp") - tmp_meta = meta_path.with_suffix(".json.tmp") - try: - faiss.write_index(index, str(tmp_index)) - except Exception as exc: - logger.exception( - "FAISSVectorStore._save_index failed writing index collection={} path={} err={}", - collection, - str(tmp_index), - exc, - ) - raise RuntimeError(f"FAISSVectorStore failed to write index for collection={collection}") from exc - meta = self._metadata.get(collection, {"id_map": {}, "payload": {}, "next_id": 1}) - try: - tmp_meta.write_text(json.dumps(meta, ensure_ascii=True)) - except Exception as exc: - logger.exception( - "FAISSVectorStore._save_index failed writing metadata collection={} path={} err={}", - collection, - str(tmp_meta), - exc, - ) - raise RuntimeError(f"FAISSVectorStore failed to write metadata for collection={collection}") from exc - os.replace(tmp_meta, meta_path) - os.replace(tmp_index, index_path) - - def _get_paths(self, collection: str) -> tuple[Path, Path]: - index_path = self.index_dir / f"{collection}.faiss" - meta_path = self.index_dir / f"{collection}.json" - return index_path, meta_path - - def _import_faiss(self): - try: - import faiss # type: ignore - except (ImportError, ModuleNotFoundError) as exc: - raise ImportError("faiss 未安装,无法使用 FAISSVectorStore。") from exc - return faiss - - def _as_faiss_ids(self, ids: List[int]): - import numpy as np - - return np.asarray(ids, dtype="int64") - - def _remove_ids(self, index: Any, ids: List[int]) -> None: - faiss = self._import_faiss() - selector = faiss.IDSelectorBatch(len(ids), self._as_faiss_ids(ids)) - index.remove_ids(selector) - - def _find_point_id(self, id_map: Dict[str, Any], idx: int) -> Optional[str]: - for key, value in id_map.items(): - if int(value) == idx: - return str(key) - return None diff --git a/telefuser/cache_mem/vector_store/interfaces.py b/telefuser/cache_mem/vector_store/interfaces.py deleted file mode 100644 index daa32ca..0000000 --- a/telefuser/cache_mem/vector_store/interfaces.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional - -from ..cache_types import VectorSearchResult - - -class VectorStore(ABC): - """向量存储接口。""" - - @abstractmethod - def search( - self, - collection: str, - vector: List[float], - limit: int = 1, - score_threshold: Optional[float] = None, - ) -> List[VectorSearchResult]: - pass - - @abstractmethod - def upsert( - self, - collection: str, - point_id: str, - vector: List[float], - payload: Dict[str, Any], - ) -> None: - pass - - @abstractmethod - def delete(self, collection: str, point_ids: List[str]) -> None: - pass - - @abstractmethod - def ensure_collection(self, collection: str, vector_dim: int) -> None: - pass - - @abstractmethod - def get_vector_size(self, collection: str) -> Optional[int]: - pass diff --git a/telefuser/cache_mem/vector_store/qdrant.py b/telefuser/cache_mem/vector_store/qdrant.py deleted file mode 100644 index f00cbf8..0000000 --- a/telefuser/cache_mem/vector_store/qdrant.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from ..cache_types import VectorSearchResult -from .interfaces import VectorStore - - -class QdrantVectorStore(VectorStore): - """Qdrant vector store stub — not available in MVP.""" - - def __init__( - self, - url: str = "", - api_key: Optional[str] = None, - prefer_grpc: bool = False, - timeout: int = 30, - ) -> None: - raise NotImplementedError("Qdrant backend not available in MVP. Planned for v2.") - - def search( - self, - collection: str, - vector: List[float], - limit: int = 1, - score_threshold: Optional[float] = None, - ) -> List[VectorSearchResult]: - raise NotImplementedError("Qdrant backend not available in MVP. Planned for v2.") - - def upsert( - self, - collection: str, - point_id: str, - vector: List[float], - payload: Dict[str, Any], - ) -> None: - raise NotImplementedError("Qdrant backend not available in MVP. Planned for v2.") - - def delete(self, collection: str, point_ids: List[str]) -> None: - raise NotImplementedError("Qdrant backend not available in MVP. Planned for v2.") - - def ensure_collection(self, collection: str, vector_dim: int) -> None: - raise NotImplementedError("Qdrant backend not available in MVP. Planned for v2.") - - def get_vector_size(self, collection: str) -> Optional[int]: - raise NotImplementedError("Qdrant backend not available in MVP. Planned for v2.") diff --git a/telefuser/pipelines/lingbot_world_fast/pipeline.py b/telefuser/pipelines/lingbot_world_fast/pipeline.py index 4e22735..f54b651 100644 --- a/telefuser/pipelines/lingbot_world_fast/pipeline.py +++ b/telefuser/pipelines/lingbot_world_fast/pipeline.py @@ -45,6 +45,9 @@ class LingBotWorldFastPipelineConfig: orig_height: int = 480 orig_width: int = 832 max_area: int = 480 * 832 + # 滚动 KV 窗口(单位:latent 帧;local_attn_size 含 sink)。-1 = 全长 KV(旧行为) + local_attn_size: int = 7 + sink_size: int = 3 class LingBotWorldFastPipeline(BasePipeline): @@ -103,7 +106,13 @@ def init(self, module_manager, config: LingBotWorldFastPipelineConfig) -> None: str(fast_path), torch_dtype=config.dit_torch_dtype, control_type=config.control_type, - config={"patch_size": (1, 2, 2), "text_len": 512, "control_type": config.control_type}, + config={ + "patch_size": (1, 2, 2), + "text_len": 512, + "control_type": config.control_type, + "local_attn_size": config.local_attn_size, + "sink_size": config.sink_size, + }, ).to(self.device) self.dit.eval().requires_grad_(False) @@ -415,7 +424,9 @@ def create_runtime( frame_num = (lat_f - 1) * 4 + 1 patch_area = self.dit.patch_size[1] * self.dit.patch_size[2] frame_tokens = (lat_h * lat_w) // patch_area - kv_size = frame_tokens * lat_f + # 滚动窗口:KV buffer 只开 local_attn_size 帧(含 sink);-1 = 全长(旧行为) + kv_window_frames = lat_f if self.config.local_attn_size == -1 else min(self.config.local_attn_size, lat_f) + kv_size = frame_tokens * kv_window_frames max_seq_len = session_config.chunk_size * frame_tokens max_attention_size = ( kv_size if session_config.max_attention_size is None else int(session_config.max_attention_size) @@ -480,8 +491,21 @@ def create_runtime( max_attention_size=max_attention_size, scheduler=FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=1, use_dynamic_shifting=False), generator=generator, + kv_local_attn_size=self.config.local_attn_size, + kv_sink_size=self.config.sink_size if self.config.local_attn_size != -1 else 0, ) runtime.timesteps = self.timesteps.select(runtime.scheduler, session_config.sample_shift) + if session_config.world_kv_binding is not None: + runtime.world_kv_binding = session_config.world_kv_binding + try: + runtime.world_kv_binding.on_runtime_created(runtime, session_config) + if runtime.world_kv_cached_latents: + logger.info( + f"world_kv: fast-forward {len(runtime.world_kv_cached_latents)} chunks (decode-only)" + ) + except Exception as exc: + logger.warning(f"world_kv on_runtime_created failed; falling back to cold run: {exc}") + runtime.world_kv_cached_latents = {} self._notify_progress(progress_callback, "runtime_created", width=width, height=height, latent_frames=lat_f) logger.info(f"LingBot runtime created: {width}x{height}, latent={lat_f}x{lat_h}x{lat_w}") return runtime @@ -504,34 +528,47 @@ def generate_next_chunk( if control_chunk is None and runtime.control_chunks is not None and idx < len(runtime.control_chunks): control_chunk = runtime.control_chunks[idx] - self._notify_progress(progress_callback, "denoising_chunk", index=idx) current_start = idx * runtime.chunk_size * runtime.frame_tokens - denoised = self.denoise_stage.denoise_chunk( - latent_chunk=latent_chunk, - condition_chunk=condition_chunk, - prompt_emb=runtime.prompt_emb, - timesteps=runtime.timesteps, - scheduler=runtime.scheduler, - control_chunk=control_chunk, - self_kv_cache=runtime.self_kv_cache, - crossattn_cache=runtime.crossattn_cache, - current_start=current_start, - max_attention_size=runtime.max_attention_size, - generator=runtime.generator, - ) + cached_latent = runtime.world_kv_cached_latents.pop(idx, None) if runtime.world_kv_cached_latents else None + if cached_latent is not None: + # world_kv fast-forward 命中:KV 已被 seed,latent 来自缓存骨架 → decode-only, + # 跳过 denoise 与 clean-KV rewrite(generator 抽取已由 binding 对齐烧掉)。 + self._notify_progress(progress_callback, "decoding_cached_chunk", index=idx) + denoised = cached_latent.to(device=self.device, dtype=self.torch_dtype) + else: + self._notify_progress(progress_callback, "denoising_chunk", index=idx) + denoised = self.denoise_stage.denoise_chunk( + latent_chunk=latent_chunk, + condition_chunk=condition_chunk, + prompt_emb=runtime.prompt_emb, + timesteps=runtime.timesteps, + scheduler=runtime.scheduler, + control_chunk=control_chunk, + self_kv_cache=runtime.self_kv_cache, + crossattn_cache=runtime.crossattn_cache, + current_start=current_start, + max_attention_size=runtime.max_attention_size, + generator=runtime.generator, + ) - self._notify_progress(progress_callback, "updating_cache", index=idx) - self.dit( - x=denoised.to(dtype=self.torch_dtype), - timestep=torch.zeros((1,), dtype=torch.float32, device=self.device), - context=runtime.prompt_emb, - y=condition_chunk, - control_tensor=control_chunk, - kv_cache=runtime.self_kv_cache, - crossattn_cache=runtime.crossattn_cache, - current_start=current_start, - max_attention_size=runtime.max_attention_size, - ) + self._notify_progress(progress_callback, "updating_cache", index=idx) + self.dit( + x=denoised.to(dtype=self.torch_dtype), + timestep=torch.zeros((1,), dtype=torch.float32, device=self.device), + context=runtime.prompt_emb, + y=condition_chunk, + control_tensor=control_chunk, + kv_cache=runtime.self_kv_cache, + crossattn_cache=runtime.crossattn_cache, + current_start=current_start, + max_attention_size=runtime.max_attention_size, + ) + + if runtime.world_kv_binding is not None: + try: + runtime.world_kv_binding.on_chunk_finalized(runtime, idx, denoised) + except Exception as exc: + logger.warning(f"world_kv on_chunk_finalized failed at chunk {idx}: {exc}") self._notify_progress(progress_callback, "decoding_chunk", index=idx, device=str(self.vae_device)) frames = self.decode_video_cached( diff --git a/telefuser/pipelines/lingbot_world_fast/session.py b/telefuser/pipelines/lingbot_world_fast/session.py index 9a468f4..1609df9 100644 --- a/telefuser/pipelines/lingbot_world_fast/session.py +++ b/telefuser/pipelines/lingbot_world_fast/session.py @@ -28,6 +28,8 @@ class LingBotWorldFastSessionConfig: poses: object | None = None intrinsics: object | None = None action: object | None = None + # cacheseek world_kv 跨请求 KV 复用(可选;None = 行为与原版完全一致) + world_kv_binding: object | None = None control_move_step: float = 0.18 control_yaw_step_degrees: float = 10.0 control_lateral_step: float = 0.12 @@ -58,6 +60,12 @@ class LingBotWorldFastRuntimeState: emitted_frames: int = 0 active: bool = True generator: torch.Generator | None = None + # KV 几何(latent 帧):binding 据此组装滚动窗口 ring。-1 = 全长 KV + kv_local_attn_size: int = -1 + kv_sink_size: int = 0 + # cacheseek world_kv:binding + fast-forward 命中的 decode-only latent(chunk_idx → x0) + world_kv_binding: object | None = None + world_kv_cached_latents: dict[int, torch.Tensor] = field(default_factory=dict) @dataclass diff --git a/telefuser/service/api/api_server.py b/telefuser/service/api/api_server.py index f659203..685d568 100644 --- a/telefuser/service/api/api_server.py +++ b/telefuser/service/api/api_server.py @@ -272,15 +272,19 @@ def initialize_services( cache_dir: Path, inference_service: PipelineService, cache_service: Any | None = None, + cache_adapter: Any | None = None, ) -> None: - """Initialize file and media services.""" + """Initialize file and media services. + """ self.file_service = FileService(cache_dir) self.inference_service = inference_service self.cache_service = cache_service + self.cache_adapter = cache_adapter self.media_service = MediaGenerationService( self.file_service, inference_service, cache_service=cache_service, + cache_adapter=cache_adapter, ) self.task_processor = AsyncTaskProcessor( task_manager=self.task_manager, diff --git a/telefuser/service/cache/__init__.py b/telefuser/service/cache/__init__.py index 1008fa3..be04e8a 100644 --- a/telefuser/service/cache/__init__.py +++ b/telefuser/service/cache/__init__.py @@ -1,4 +1,10 @@ -from .cache_factory import CacheServiceFactory -from .cache_service import CacheService +"""Deprecated cache package namespace. -__all__ = ["CacheService", "CacheServiceFactory"] +The latent cache implementation moved to the external `cacheseek` package. +TeleFuser service code imports Cacheseek directly and no longer supports the +legacy `telefuser.service.cache` facade. +""" + +_BACKEND = "cacheseek" + +__all__ = ["_BACKEND"] diff --git a/telefuser/service/cache/cache_factory.py b/telefuser/service/cache/cache_factory.py deleted file mode 100644 index 4342b91..0000000 --- a/telefuser/service/cache/cache_factory.py +++ /dev/null @@ -1,176 +0,0 @@ -import traceback -from dataclasses import fields -from pathlib import Path -from typing import Any, Optional - -from loguru import logger - -from telefuser.utils.utils import import_function_from_file - -from .cache_service import CacheService - -try: - from telefuser.cache_mem.log_monitor import setup_cache_log_sink -except Exception: - setup_cache_log_sink = None - -try: - from telefuser.cache_mem.config import CacheConfig, CacheMode - from telefuser.cache_mem.latent_cache import LatentCache -except Exception as exc: # optional dependency for cache service - _cache_dep_import_error = exc - _cache_dep_import_traceback = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)).rstrip() - LatentCache = Any - CacheConfig = Any - CacheMode = Any -else: - _cache_dep_import_error = None - _cache_dep_import_traceback = None - - -class CacheServiceFactory: - """Create CacheService with config parsing and dependency wiring.""" - - @staticmethod - def create_cache_service( - ppl_file: Optional[str], - enable_latent_cache: Optional[bool], - cache_mode: Optional[str] = None, - ) -> Optional[CacheService]: - try: - if CacheConfig is Any or CacheMode is Any or LatentCache is Any: - unavailable_symbols = ( - ", ".join( - symbol_name - for symbol_name, symbol_value in ( - ("CacheConfig", CacheConfig), - ("CacheMode", CacheMode), - ("LatentCache", LatentCache), - ) - if symbol_value is Any - ) - or "unknown" - ) - if _cache_dep_import_error is not None: - logger.warning( - "Cache config not available, cache service disabled. " - "unavailable_symbols={}, import_error_type={}, import_error={}, " - "traceback:\n{}", - unavailable_symbols, - type(_cache_dep_import_error).__name__, - _cache_dep_import_error, - _cache_dep_import_traceback, - ) - else: - logger.warning( - "Cache config not available, cache service disabled. unavailable_symbols={}", - unavailable_symbols, - ) - return None - - # ppl_file 必须提供且包含 build_latent_data,否则抛出错误、不初始化 cache_service - if ppl_file is None: - raise ValueError( - "enable_latent_cache is enabled but no ppl_file provided. " - "Please provide a pipeline file that contains the build_latent_data function." - ) - # 尝试从 ppl_file 读取 CACHE_CONFIG 覆盖项 - ppl_cache_config = None - ppl_cache_config_load_error = None - try: - ppl_cache_config = import_function_from_file(ppl_file, "CACHE_CONFIG") - logger.info(f"Found CACHE_CONFIG in {ppl_file}") - except AttributeError: - ppl_cache_config = None - except Exception as exc: - ppl_cache_config_load_error = exc - logger.warning(f"Failed to load CACHE_CONFIG from {ppl_file}: {exc}") - ppl_cache_config = None - - # 构建 app_cache_config(默认值 + ppl 覆盖) - cache_config_source = "CacheConfig" - if isinstance(ppl_cache_config, CacheConfig): - app_cache_config = ppl_cache_config - cache_config_source = "ppl CACHE_CONFIG" - elif isinstance(ppl_cache_config, dict): - valid_keys = {field.name for field in fields(CacheConfig)} - overrides = {k: v for k, v in ppl_cache_config.items() if k in valid_keys} - unknown_keys = sorted(set(ppl_cache_config.keys()) - valid_keys) - if unknown_keys: - logger.warning(f"Ignore unknown CACHE_CONFIG keys: {', '.join(unknown_keys)}") - app_cache_config = CacheConfig(**overrides) - cache_config_source = "ppl CACHE_CONFIG" - else: - app_cache_config = CacheConfig() - - # 兼容 cache_mode 为字符串的写法 - if isinstance(app_cache_config.cache_mode, str): - try: - app_cache_config.cache_mode = CacheMode(app_cache_config.cache_mode) - except ValueError: - logger.warning( - f"Invalid cache_mode '{app_cache_config.cache_mode}' in CACHE_CONFIG, " - "fallback to default READ_WRITE" - ) - app_cache_config.cache_mode = CacheConfig().cache_mode - - # 命令行传入的 enable_latent_cache 写入配置(调用方已保证为 True 才进入本函数) - if enable_latent_cache is not None: - app_cache_config.enable_latent_cache = enable_latent_cache - cache_config_source = "command line" - - if cache_mode is not None: - try: - app_cache_config.cache_mode = CacheMode(cache_mode) - cache_config_source = "command line" - except ValueError: - logger.warning(f"Invalid cache_mode '{cache_mode}', using {app_cache_config.cache_mode}") - - # 尽量提前初始化 cache 日志 sink,覆盖后续 build/latent cache 初始化错误 - if getattr(app_cache_config, "cache_log_enabled", False) and setup_cache_log_sink: - cache_log_dir = getattr(app_cache_config, "cache_log_dir", None) - if not cache_log_dir: - cache_log_dir = str(Path(app_cache_config.latent_cache_dir) / "logs") - setup_cache_log_sink( - log_dir=cache_log_dir, - level=getattr(app_cache_config, "cache_log_level", "DEBUG"), - rotation=getattr(app_cache_config, "cache_log_rotation", "100 MB"), - retention=getattr(app_cache_config, "cache_log_retention", "7 days"), - ) - if ppl_cache_config_load_error is not None: - logger.warning( - "CACHE_CONFIG load failed during cache init, using defaults. Original error: {}", - ppl_cache_config_load_error, - ) - - try: - build_latent_data_func = import_function_from_file(ppl_file, "build_latent_data") - logger.info(f"Found build_latent_data function in {ppl_file}") - except (ImportError, AttributeError) as e: - raise ValueError( - f"ppl_file must define 'build_latent_data' for cache service. " - f"Missing or invalid in {ppl_file}. Error: {e}" - ) from e - - # 初始化 LatentCache - latent_cache = LatentCache( - Path(app_cache_config.latent_cache_dir), - app_cache_config, - ) - - # 初始化 CacheService - cache_service = CacheService( - latent_cache=latent_cache, - build_latent_data_func=build_latent_data_func, - cache_mode=app_cache_config.cache_mode, - app_cache_config=app_cache_config, - ) - - mode_value = getattr(app_cache_config.cache_mode, "value", app_cache_config.cache_mode) - logger.info(f"Cache service enabled (mode: {mode_value}, source: {cache_config_source})") - return cache_service - except ValueError: - raise - except Exception as e: - logger.warning(f"Failed to initialize cache service: {e}") - return None diff --git a/telefuser/service/cache/cache_service.py b/telefuser/service/cache/cache_service.py deleted file mode 100644 index f8b8642..0000000 --- a/telefuser/service/cache/cache_service.py +++ /dev/null @@ -1,389 +0,0 @@ -import asyncio -import time -from queue import Empty, Queue -from threading import Event, Lock, Thread -from typing import Any, Optional - -import torch.multiprocessing as mp -from loguru import logger - -from telefuser.utils.profiler import ProfilingContext4Debug - -try: - from telefuser.cache_mem.config import CacheConfig, CacheMode - from telefuser.cache_mem.latent_cache import LatentCache -except Exception: # optional dependency for cache service - LatentCache = Any - CacheConfig = Any - CacheMode = Any - -try: - mp.set_start_method("spawn", force=True) -except RuntimeError: - pass - - -class CacheService: - """缓存服务(lookup + writeback)。 - - 默认策略:缓存层任何异常都不影响主链路,仅记录告警并降级。 - """ - - def __init__( - self, - latent_cache: Optional[LatentCache] = None, - build_latent_data_func: Optional[callable] = None, - cache_mode: Optional[Any] = None, - app_cache_config: Optional[Any] = None, - ) -> None: - self.latent_cache = latent_cache - self.app_cache_config = app_cache_config # CacheConfig from telefuser.cache_mem.config - self.build_latent_data_func = build_latent_data_func - self.cache_mode = cache_mode or (CacheMode.READ_WRITE if CacheMode is not Any else None) - - # 确定缓存模式。 - if self.app_cache_config and hasattr(self.app_cache_config, "cache_mode"): - self.cache_mode = self.app_cache_config.cache_mode - elif CacheConfig is not Any: - self.cache_mode = CacheConfig().cache_mode - - # 异步保存相关配置(按 CacheConfig 字段读取,保持文档接口一致)。 - cache_config = self.app_cache_config - if cache_config is None and CacheConfig is not Any: - try: - cache_config = CacheConfig() - except Exception: - cache_config = None - - self.save_async_enabled = bool(getattr(cache_config, "save_async_enabled", False)) - self.save_queue_size = int(getattr(cache_config, "save_queue_size", 0) or 0) - self.save_on_full = str(getattr(cache_config, "save_on_full", "drop") or "drop").lower() - self.save_queue_warn_threshold = int(getattr(cache_config, "save_queue_warn_threshold", 0) or 0) - self.vector_wait_warn_s = float(getattr(cache_config, "vector_wait_warn_s", 0) or 0) - self.vector_wait_poll_s = float(getattr(cache_config, "vector_wait_poll_s", 0) or 0) - self.vector_wait_timeout_s = float(getattr(cache_config, "vector_wait_timeout_s", 0) or 0) - self.flush_on_shutdown = bool(getattr(cache_config, "flush_on_shutdown", False)) - - self.save_queue: Optional[Queue] = None - self.save_worker: Optional[Thread] = None - self._save_stop_event = Event() - self.pending_vector_updates = 0 - self._pending_lock = Lock() - self.vector_update_idle = Event() - self.vector_update_idle.set() - - if self.save_async_enabled: - maxsize = max(1, self.save_queue_size) if self.save_queue_size else 0 - self.save_queue = Queue(maxsize=maxsize) - self.save_worker = Thread( - target=self._start_save_worker, - name="cache-save-worker", - daemon=True, - ) - self.save_worker.start() - - def _reserve_vector_update(self) -> None: - with self._pending_lock: - self.pending_vector_updates += 1 - self.vector_update_idle.clear() - - def _release_vector_update(self) -> None: - with self._pending_lock: - self.pending_vector_updates = max(0, self.pending_vector_updates - 1) - if self.pending_vector_updates == 0: - self.vector_update_idle.set() - - def set_build_latent_data_func(self, func: Optional[callable]) -> None: - """Set build_latent_data function imported from ppl file.""" - - self.build_latent_data_func = func - - def _start_save_worker(self) -> None: - """后台保存线程入口:循环消费队列并执行保存。""" - if self.save_queue is None: - logger.warning("cache-save-worker start skipped: save_queue is None") - return - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - logger.info("cache-save-worker started") - try: - while not self._save_stop_event.is_set(): - try: - task = self.save_queue.get(timeout=0.5) - except Empty: - continue - try: - if task is None: - continue - task_request, latent_payload = task - coro = self._save_latent_payload_impl(task_request, latent_payload, vector_update_reserved=True) - if asyncio.iscoroutine(coro): - loop.run_until_complete(coro) - except Exception as exc: - logger.exception(f"cache-save-worker save failed: {exc}") - finally: - self.save_queue.task_done() - finally: - loop.close() - logger.info("cache-save-worker stopped") - - async def _save_latent_payload_impl( - self, - task_request: Any, - latent_payload: Optional[dict], - vector_update_reserved: bool = False, - ) -> None: - """后台保存实现(后续步骤补充 pending_vector_updates 逻辑)。""" - if self.latent_cache is None or not latent_payload: - return - - latent_states_dict = latent_payload.get("latent_states_dict") - saved_steps = latent_payload.get("saved_steps") or [] - final_step = latent_payload.get("final_step") - num_frames = latent_payload.get("num_frames", 0) - embedding_video_frames = latent_payload.get("embedding_video_frames") - if latent_states_dict is None or final_step is None or not saved_steps: - return - - need_vector_update = bool(embedding_video_frames) - try: - with ProfilingContext4Debug("save_latent_payload"): - if need_vector_update: - if not vector_update_reserved: - self._reserve_vector_update() - try: - task_id = getattr(task_request, "task_id", None) - logger.info("cache-save-worker save start task_id={}", task_id) - except Exception: - pass - await self.latent_cache.save( - task_request, - latent_states_dict, - num_frames, - int(final_step), - list(saved_steps), - embedding_video_frames=embedding_video_frames, - ) - try: - task_id = getattr(task_request, "task_id", None) - logger.info("cache-save-worker save end task_id={}", task_id) - except Exception: - pass - except Exception as exc: - logger.exception(f"Cache writeback failed, ignored: {exc}") - finally: - if need_vector_update: - self._release_vector_update() - - async def _wait_vector_updates_done(self, task_request: Any = None) -> None: - """等待向量更新栅栏释放,避免 lookup 读到未完成的向量更新。""" - if not self.save_async_enabled: - return - if self.vector_update_idle.is_set(): - return - poll_s = self.vector_wait_poll_s if self.vector_wait_poll_s > 0 else 0.05 - warn_s = self.vector_wait_warn_s if self.vector_wait_warn_s > 0 else 0.0 - timeout_s = self.vector_wait_timeout_s if self.vector_wait_timeout_s > 0 else 0.0 - task_id = getattr(task_request, "task_id", None) - start = time.monotonic() - warned = False - logger.info( - "CacheService.build_latent_data wait vector_update_idle start task_id={} pending={}", - task_id, - self.pending_vector_updates, - ) - while not self.vector_update_idle.is_set(): - elapsed = time.monotonic() - start - if warn_s and not warned and elapsed >= warn_s: - logger.warning( - "CacheService.build_latent_data wait vector_update_idle exceeded {:.2f}s task_id={} pending={}", - warn_s, - task_id, - self.pending_vector_updates, - ) - warned = True - if timeout_s and elapsed >= timeout_s: - logger.warning( - "CacheService.build_latent_data wait vector_update_idle timeout {:.2f}s " - "task_id={} pending={}; continue with lookup", - timeout_s, - task_id, - self.pending_vector_updates, - ) - return - await asyncio.sleep(poll_s) - logger.info( - "CacheService.build_latent_data wait vector_update_idle end task_id={} elapsed={:.2f}s", - task_id, - time.monotonic() - start, - ) - - async def build_latent_data(self, task_request: Any, task_data: dict) -> Optional[dict]: - """构建 latent_data,用于传递给 pipeline。 - - 默认降级:缓存 lookup / build 任何异常都返回安全的 miss 结构。 - """ - with ProfilingContext4Debug("build_latent_data"): - cache_config = self.app_cache_config - if cache_config is None and CacheConfig is not Any: - try: - cache_config = CacheConfig() - except Exception: - cache_config = None - - cache_result = None - if self.latent_cache is not None and self.cache_mode in [CacheMode.READ_WRITE, CacheMode.READ_ONLY]: - try: - await self._wait_vector_updates_done(task_request) - cache_result = await self.latent_cache.lookup(task_request) - except Exception as exc: - logger.exception(f"Cache lookup failed, degrade to miss: {exc}") - cache_result = None - - if self.build_latent_data_func is not None: - try: - latent_data = self.build_latent_data_func(task_data=task_data, cache_result=cache_result) - if latent_data is not None: - return latent_data - except Exception as exc: - logger.exception(f"build_latent_data_func failed, fallback to default: {exc}") - - cached_latent = None - skip_step = 0 - if cache_result is not None and getattr(cache_result, "hit", False): - cached_latent = getattr(cache_result, "latent_state", None) - skip_step = int(getattr(cache_result, "skip_step", 0) or 0) - - saved_steps = [] - if cache_config is not None: - saved_steps = list(getattr(cache_config, "key_steps", []) or []) - saved_steps = [int(step) for step in saved_steps] - - return { - "hit": bool(cached_latent is not None and skip_step > 0), - "skip_step": skip_step, - "cached_latent": cached_latent, - "saved_steps": saved_steps, - } - - async def save_latent_payload(self, task_request: Any, latent_payload: Optional[dict]) -> None: - """保存 latent_payload 到缓存。 - - 默认降级:缓存 writeback 任何异常都记录告警并忽略。 - """ - - if self.cache_mode == CacheMode.READ_ONLY or self.latent_cache is None or not latent_payload: - return - - latent_states_dict = latent_payload.get("latent_states_dict") - saved_steps = latent_payload.get("saved_steps") or [] - final_step = latent_payload.get("final_step") - embedding_video_frames = latent_payload.get("embedding_video_frames") - need_vector_update = bool(embedding_video_frames) - if latent_states_dict is None or final_step is None or not saved_steps: - return - - if self.save_async_enabled: - if self.save_queue is None: - logger.warning("Cache save enqueue skipped: save_queue is not initialized") - return - if self.save_queue_warn_threshold > 0: - try: - if self.save_queue.qsize() >= self.save_queue_warn_threshold: - logger.warning( - "Cache save queue length warning: {}", - self.save_queue.qsize(), - ) - except Exception: - pass - if not self.save_queue.full(): - reserved_vector_update = False - try: - if need_vector_update: - self._reserve_vector_update() - reserved_vector_update = True - self.save_queue.put_nowait((task_request, latent_payload)) - try: - task_id = getattr(task_request, "task_id", None) - logger.info( - "Cache save enqueue task_id={} qsize={}", - task_id, - self.save_queue.qsize(), - ) - except Exception: - pass - except Exception as exc: - if reserved_vector_update: - self._release_vector_update() - logger.exception(f"Cache save enqueue failed, ignored: {exc}") - return - - policy = (self.save_on_full or "drop").lower() - if policy == "drop": - logger.warning("Cache save queue full: drop task") - return - if policy == "downgrade": - # TODO: implement latent-only downgrade/eviction when the async - # save queue is full. For now, keep the behavior explicit and - # avoid running a partial path that pretends to persist data. - logger.warning("Cache save queue full: downgrade policy is TODO; drop task") - return - if policy == "sync": - logger.warning("Cache save queue full: fallback to sync save") - else: - logger.warning("Cache save queue full: unknown policy={}, drop task", policy) - return - - # Sync fallback: route through _save_latent_payload_impl so the - # vector_update_idle barrier is respected (otherwise concurrent - # lookups can race the in-flight upsert and read stale state). - try: - await self._save_latent_payload_impl( - task_request, - latent_payload, - vector_update_reserved=False, - ) - except Exception as exc: - logger.exception(f"Cache writeback failed, ignored: {exc}") - - async def _save_latent_payload_downgrade(self, task_request: Any, latent_payload: Optional[dict]) -> None: - """TODO: save latent-only cache entries when full-queue downgrade is implemented.""" - del task_request, latent_payload - logger.warning("Cache save downgrade is TODO; task dropped") - - def shutdown(self) -> None: - """释放缓存服务资源(尽力而为)。""" - - if self.save_worker is not None: - try: - if self.flush_on_shutdown and self.save_queue is not None: - try: - self.save_queue.join() - except Exception as exc: - logger.exception(f"CacheService.flush failed: {exc}") - self._save_stop_event.set() - if self.save_queue is not None: - try: - self.save_queue.put_nowait(None) - except Exception: - pass - self.save_worker.join(timeout=5) - except Exception as exc: - logger.exception(f"CacheService.stop worker failed: {exc}") - - if self.latent_cache is not None and hasattr(self.latent_cache, "shutdown"): - try: - self.latent_cache.shutdown() - except Exception as exc: - logger.exception(f"CacheService.shutdown failed: {exc}") - self.latent_cache = None - self.save_worker = None - self.save_queue = None - - def __del__(self) -> None: - try: - self.shutdown() - except Exception: - pass - if hasattr(self, "pipeline") and self.pipeline is not None: - del self.pipeline diff --git a/telefuser/service/core/container.py b/telefuser/service/core/container.py index b50f549..82847a4 100644 --- a/telefuser/service/core/container.py +++ b/telefuser/service/core/container.py @@ -40,6 +40,7 @@ class ServiceContainer: stream_pipeline_service: StreamPipelineService | None = None media_service: MediaGenerationService | None = None cache_service: Any | None = None + cache_adapter: Any | None = None # cacheseek.adapters.telefuser.TeleFuserCacheAdapter _cache_dir: Path | None = field(default=None, repr=False) @classmethod @@ -112,28 +113,36 @@ def initialize_media_service(self) -> MediaGenerationService: file_service=self.file_service, inference_service=self.pipeline_service, cache_service=self.cache_service, + cache_adapter=self.cache_adapter, ) return self.media_service def initialize_cache_service(self, pipe_path: str) -> Any | None: - """Initialize optional latent cache service when enabled.""" + """Initialize optional latent cache service when enabled. + + """ if not getattr(self.config, "enable_latent_cache", False): return None - # Lazy import to avoid pulling cache_mem deps when disabled. - from ..cache.cache_factory import CacheServiceFactory + # Lazy import to avoid pulling cacheseek deps when disabled. + from cacheseek.adapters.telefuser.cache_factory import CacheServiceFactory try: - self.cache_service = CacheServiceFactory.create_cache_service( + result = CacheServiceFactory.create_cache_service( ppl_file=pipe_path, enable_latent_cache=True, ) except Exception as exc: logger.warning(f"CacheServiceFactory.create_cache_service failed: {exc}") + result = None + + if result is None: self.cache_service = None + self.cache_adapter = None + logger.warning("enable_latent_cache=True but cache service init returned None") + return None - if self.cache_service is None: - logger.warning("enable_latent_cache=True but cache_service is None") + self.cache_service, self.cache_adapter = result return self.cache_service def initialize_all( @@ -216,6 +225,7 @@ def get_api_app(self, enable_rate_limit: bool = True) -> FastAPI: self.file_service.cache_dir, self.pipeline_service, cache_service=self.cache_service, + cache_adapter=self.cache_adapter, # forward adapter to api_server ) if self.stream_pipeline_service: @@ -256,7 +266,7 @@ async def cleanup(self) -> None: except Exception as exc: logger.warning(f"cache service shutdown failed: {exc}") self.cache_service = None - + self.cache_adapter = None self.media_service = None diff --git a/telefuser/service/core/task_service.py b/telefuser/service/core/task_service.py index 4e62549..6b991f1 100644 --- a/telefuser/service/core/task_service.py +++ b/telefuser/service/core/task_service.py @@ -23,16 +23,12 @@ def _build_cache_task_request(task_data: dict) -> SimpleNamespace: """Build a minimal task_request stub for the cache layer. - - Splatting ``task_data`` directly would crash because ``TaskRequest`` is - ``extra="allow"`` and may contain keys that are not valid Python - identifiers. The cache layer only reads ``task_id`` / ``task`` / - ``prompt`` via ``getattr``, so we whitelist those. """ return SimpleNamespace( task_id=task_data.get("task_id"), task=task_data.get("task"), prompt=task_data.get("prompt") or "", + seed=task_data.get("seed"), ) @@ -47,10 +43,12 @@ def __init__( file_service: FileService, inference_service: PipelineService, cache_service: Any | None = None, + cache_adapter: Any | None = None, ) -> None: self.file_service = file_service self.inference_service = inference_service self.cache_service = cache_service + self.cache_adapter = cache_adapter async def generate_media_with_stop_event( self, message: TaskRequest, stop_event: threading.Event @@ -110,18 +108,37 @@ async def update_audio_path(audio_name: str, message: TaskRequest, task_data: di actual_save_path = self.file_service.get_output_path(message.output_path, media_type=media_type) task_data["output_path"] = str(actual_save_path) - # Cache lookup (best-effort: degrade silently on any failure) + # Best-effort: every step degrades silently to "no cache" on failure. + cache_active = self.cache_service is not None and self.cache_adapter is not None cache_task_request: SimpleNamespace | None = None - if self.cache_service is not None: + cache_query: Any | None = None + if cache_active: try: cache_task_request = _build_cache_task_request(task_data) - latent_data = await self.cache_service.build_latent_data(cache_task_request, task_data) + cache_query = self.cache_adapter.build_query(cache_task_request) # build_query + except Exception as exc: + logger.warning(f"[task_service] cache build_query failed, ignored: {exc}") + cache_query = None + + lookup_result = None + if cache_active and cache_query is not None: + try: + lookup_result = await self.cache_service.lookup(cache_query) # lookup + except Exception as exc: + logger.warning(f"[task_service] cache lookup failed, ignored: {exc}") + lookup_result = None + + if cache_active and lookup_result is not None: + try: + latent_data = self.cache_adapter.apply_resume( # apply_resume + lookup_result, engine_ctx=task_data + ) if latent_data is not None: task_data["latent_data"] = latent_data except Exception as exc: - logger.warning(f"[task_service] cache lookup failed, ignored: {exc}") + logger.warning(f"[task_service] cache apply_resume failed, ignored: {exc}") - result = await self.inference_service.run_task_with_stop_event( + result = await self.inference_service.run_task_with_stop_event( # inference task_data, stop_event, output_root=str(self.file_service.output_dir), @@ -134,15 +151,18 @@ async def update_audio_path(audio_name: str, message: TaskRequest, task_data: di raise RuntimeError("Task processing timeout") if result.get("status") == PipelineRunStatus.SUCCESS: - # Cache writeback (best-effort: degrade silently on any failure) - if self.cache_service is not None: + # Cache writeback (best-effort: degrade silently on any failure). + if cache_active and cache_query is not None: try: raw = result.get("raw") latent_payload = raw.get("latent_payload") if isinstance(raw, dict) else None if latent_payload: if cache_task_request is None: cache_task_request = _build_cache_task_request(task_data) - await self.cache_service.save_latent_payload(cache_task_request, latent_payload) + outputs = self.cache_adapter.on_response( # on_response + cache_task_request, latent_payload + ) + await self.cache_service.save(cache_query, outputs) # save except Exception as exc: logger.warning(f"[task_service] cache writeback failed, ignored: {exc}") diff --git a/tests/unit/cache_mem/__init__.py b/tests/unit/cache_mem/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/tests/unit/cache_mem/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/unit/cache_mem/test_concurrency.py b/tests/unit/cache_mem/test_concurrency.py deleted file mode 100644 index eb30190..0000000 --- a/tests/unit/cache_mem/test_concurrency.py +++ /dev/null @@ -1,355 +0,0 @@ -"""Concurrency tests for cache_mem stores guarded by ``threading.RLock``. - -These tests lock in the invariant that under heavy multi-threaded -concurrent access: - -* No registered/put/upsert entry is lost. -* Read paths never raise mid-mutation (e.g. no ``RuntimeError: dictionary - changed size during iteration``). -* On-disk artifacts (``kv_index.json`` / ``prompt_index.json`` / - ``cache_meta.json`` / ``.faiss`` + ``.json``) remain valid - and reload cleanly into a fresh instance. - -Pure CPU, no GPU. ``ThreadPoolExecutor`` is used to surface races on a -multi-core machine while still finishing in well under 2 s each. -""" - -from __future__ import annotations - -import os -import sys - -# faiss-cpu and torch each ship their own OpenMP runtime; on macOS loading -# both into the same process aborts inside libomp unless this is set. -if sys.platform == "darwin": - os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE") - -import json # noqa: E402 -import random # noqa: E402 -import threading # noqa: E402 -from concurrent.futures import ThreadPoolExecutor, wait # noqa: E402 -from pathlib import Path # noqa: E402 - -import pytest # noqa: E402 - -from telefuser.cache_mem.metadata import LocalCacheMetadataManager # noqa: E402 -from telefuser.cache_mem.storage.local_file import LocalFileKVStore # noqa: E402 - -NUM_OPS = 200 - - -def _run_with_barrier(fns: list) -> list: - """Run callables in parallel, releasing them simultaneously via a barrier. - - The pool must have at least ``len(fns)`` workers so every callable can - actually start and reach ``barrier.wait()``; otherwise the barrier - deadlocks because pool workers block on the barrier before yielding - back to pick up queued tasks. - - Returns the list of futures (already completed). Re-raises any worker - exception via ``future.result()``. - """ - n = len(fns) - barrier = threading.Barrier(n) - - def _wrapped(fn): - def _inner(): - barrier.wait() - return fn() - - return _inner - - with ThreadPoolExecutor(max_workers=n) as pool: - futures = [pool.submit(_wrapped(fn)) for fn in fns] - wait(futures) - return [f.result() for f in futures] - - -# --------------------------------------------------------------------------- -# LocalCacheMetadataManager -# --------------------------------------------------------------------------- - - -class TestMetadataManagerConcurrency: - """Concurrency invariants for ``LocalCacheMetadataManager``.""" - - def test_no_loss_register(self, tmp_path: Path) -> None: - """200 unique register_cache calls — every entry survives, on disk too.""" - meta_dir = tmp_path / "meta_no_loss" - mgr = LocalCacheMetadataManager(meta_dir) - - def make_fn(i: int): - def _fn() -> None: - mgr.register_cache( - cache_id=f"c{i}", - prompt=f"p{i}", - saved_steps=[i % 10], - size_mb=1.0, - num_frames=4, - ) - - return _fn - - _run_with_barrier([make_fn(i) for i in range(NUM_OPS)]) - - # In-memory: all 200 entries present. - assert len(mgr._meta) == NUM_OPS - for i in range(NUM_OPS): - entry = mgr.lookup_prompt(f"p{i}") - assert entry is not None, f"missing entry for p{i}" - assert entry.cache_id == f"c{i}" - - # On-disk: reload via a second manager. - mgr2 = LocalCacheMetadataManager(meta_dir) - assert len(mgr2._meta) == NUM_OPS - for i in range(NUM_OPS): - assert mgr2.lookup_prompt(f"p{i}") is not None, f"missing on reload p{i}" - - def test_atomic_access_counter(self, tmp_path: Path) -> None: - """200 record_access calls on the same cache_id — final count is exactly 200. - - This is the load-bearing test: Python ``+=`` on a dict value is not - atomic, so without the RLock this test fails. - """ - mgr = LocalCacheMetadataManager(tmp_path / "meta_counter") - mgr.register_cache("c0", "prompt0", saved_steps=[0], size_mb=1.0, num_frames=4) - - fns = [lambda: mgr.record_access("c0") for _ in range(NUM_OPS)] - _run_with_barrier(fns) - - meta = mgr.get_cache_meta("c0") - assert meta is not None - assert meta["access_count"] == NUM_OPS, f"expected {NUM_OPS}, got {meta['access_count']}" - - def test_mixed_read_write(self, tmp_path: Path) -> None: - """Half the threads register fresh ids, half iterate lookup_prompt. - - Reads must never raise (e.g. dictionary-changed-size errors). - """ - mgr = LocalCacheMetadataManager(tmp_path / "meta_mixed") - # Pre-register a base set so readers have non-empty maps to iterate. - base_count = 50 - for i in range(base_count): - mgr.register_cache( - cache_id=f"base{i}", - prompt=f"base_p{i}", - saved_steps=[i % 5], - size_mb=0.5, - num_frames=4, - ) - - def make_writer(i: int): - def _fn() -> None: - mgr.register_cache( - cache_id=f"new{i}", - prompt=f"new_p{i}", - saved_steps=[i % 5], - size_mb=0.5, - num_frames=4, - ) - - return _fn - - def make_reader(i: int): - def _fn() -> None: - # Look up an existing base entry; should always succeed. - target = f"base_p{i % base_count}" - entry = mgr.lookup_prompt(target) - assert entry is not None, f"reader could not find {target}" - - return _fn - - fns = [] - for i in range(NUM_OPS // 2): - fns.append(make_writer(i)) - fns.append(make_reader(i)) - _run_with_barrier(fns) - - # Writers' entries must all be there. - assert len(mgr._meta) == base_count + NUM_OPS // 2 - - -# --------------------------------------------------------------------------- -# LocalFileKVStore -# --------------------------------------------------------------------------- - - -class TestLocalFileKVStoreConcurrency: - """Concurrency invariants for ``LocalFileKVStore``.""" - - def test_no_loss_put(self, tmp_path: Path) -> None: - """200 unique put calls — all keys end up in the index and survive reload.""" - kv_dir = tmp_path / "kv_no_loss" - store = LocalFileKVStore(kv_dir) - - def make_fn(i: int): - def _fn() -> None: - store.put(f"k{i}", f"v{i}".encode("utf-8")) - - return _fn - - _run_with_barrier([make_fn(i) for i in range(NUM_OPS)]) - - expected_keys = {f"k{i}" for i in range(NUM_OPS)} - assert set(store.list_keys()) == expected_keys - - # kv_index.json must be valid JSON containing all keys. - index_path = kv_dir / "kv_index.json" - raw = json.loads(index_path.read_text()) - assert set(raw.keys()) == expected_keys - - # Reload via a second store on the same dir. - store2 = LocalFileKVStore(kv_dir) - assert set(store2.list_keys()) == expected_keys - # Spot-check some values round-trip correctly. - for i in (0, NUM_OPS // 2, NUM_OPS - 1): - assert store2.get(f"k{i}") == f"v{i}".encode("utf-8") - - def test_concurrent_overwrite_same_key(self, tmp_path: Path) -> None: - """N threads put the same key — final value is one of theirs (not torn).""" - kv_dir = tmp_path / "kv_overwrite" - store = LocalFileKVStore(kv_dir) - - values = [f"v_{i}".encode("utf-8") for i in range(NUM_OPS)] - - def make_fn(payload: bytes): - def _fn() -> None: - store.put("k", payload) - - return _fn - - _run_with_barrier([make_fn(v) for v in values]) - - final = store.get("k") - assert final in values, f"final value {final!r} is not one of the writes" - - # Index file must still be valid JSON containing exactly "k". - raw = json.loads((kv_dir / "kv_index.json").read_text()) - assert list(raw.keys()) == ["k"] - - # Reload yields the same single key. - store2 = LocalFileKVStore(kv_dir) - assert store2.list_keys() == ["k"] - assert store2.get("k") == final - - -# --------------------------------------------------------------------------- -# FAISSVectorStore (skipped if faiss unavailable) -# --------------------------------------------------------------------------- - - -pytest.importorskip("faiss") # noqa: E402 - - -# On macOS, torch and faiss-cpu each ship their own libomp; loaded into the -# same process they crash inside faiss.search regardless of KMP_DUPLICATE_LIB_OK. -# The RLock invariants we want to lock in are platform-independent, so the -# Linux CI run is what actually gates correctness — skipping locally on Darwin -# avoids a known-bad infra setup, not a real bug in our wrapper. -@pytest.mark.skipif( - sys.platform == "darwin", - reason="faiss-cpu + torch OpenMP collision on macOS aborts inside faiss.search", -) -class TestFAISSVectorStoreConcurrency: - """Concurrency invariants for ``FAISSVectorStore``.""" - - VECTOR_DIM = 8 - COLLECTION = "test_col" - - @staticmethod - def _random_vec(rng: random.Random, dim: int) -> list[float]: - # Random unit-ish vector; exact magnitude does not matter for L2 search. - return [rng.random() for _ in range(dim)] - - def test_no_loss_upsert(self, tmp_path: Path) -> None: - """200 unique upserts — every point survives and on-disk metadata is consistent.""" - from telefuser.cache_mem.vector_store.faiss import FAISSVectorStore - - index_dir = tmp_path / "faiss_no_loss" - store = FAISSVectorStore(index_dir=index_dir, vector_dim=self.VECTOR_DIM, index_type="L2") - store.ensure_collection(self.COLLECTION, self.VECTOR_DIM) - - # Pre-generate vectors so threads do no shared RNG work. - rng = random.Random(0) - points = [(f"pid{i}", self._random_vec(rng, self.VECTOR_DIM)) for i in range(NUM_OPS)] - - def make_fn(pid: str, vec: list[float]): - def _fn() -> None: - store.upsert(self.COLLECTION, pid, vec, payload={"prompt": pid, "saved_steps": [0]}) - - return _fn - - _run_with_barrier([make_fn(pid, vec) for pid, vec in points]) - - # Search returns at most NUM_OPS unique results. - query = self._random_vec(random.Random(1), self.VECTOR_DIM) - results = store.search(self.COLLECTION, query, limit=NUM_OPS) - assert len(results) <= NUM_OPS - # All returned cache_ids belong to the set we inserted. - inserted_ids = {pid for pid, _ in points} - for r in results: - assert r.cache_id in inserted_ids - - # On-disk: id_map length must equal index.ntotal — that is exactly the - # consistency invariant ``_load_index`` asserts on reload. - meta_path = index_dir / f"{self.COLLECTION}.json" - on_disk_meta = json.loads(meta_path.read_text()) - id_map = on_disk_meta.get("id_map", {}) - # Reload via a second store; the constructor + first _load_index call - # will raise if id_map length disagrees with index.ntotal. - store2 = FAISSVectorStore(index_dir=index_dir, vector_dim=self.VECTOR_DIM, index_type="L2") - # Touch the collection to trigger _load_index validation. - size = store2.get_vector_size(self.COLLECTION) - assert size == self.VECTOR_DIM - # Full upsert survival: every inserted id must remain in the on-disk map. - assert set(id_map.keys()) == inserted_ids - - def test_concurrent_upsert_and_search(self, tmp_path: Path) -> None: - """Half upsert, half search — search never raises mid-mutation.""" - from telefuser.cache_mem.vector_store.faiss import FAISSVectorStore - - index_dir = tmp_path / "faiss_mixed" - store = FAISSVectorStore(index_dir=index_dir, vector_dim=self.VECTOR_DIM, index_type="L2") - store.ensure_collection(self.COLLECTION, self.VECTOR_DIM) - - # Seed the collection so searches have something to find. - seed_rng = random.Random(42) - for i in range(20): - store.upsert( - self.COLLECTION, - f"seed{i}", - self._random_vec(seed_rng, self.VECTOR_DIM), - payload={"prompt": f"seed{i}", "saved_steps": [0]}, - ) - - rng = random.Random(7) - upsert_points = [(f"new{i}", self._random_vec(rng, self.VECTOR_DIM)) for i in range(NUM_OPS // 2)] - query_vecs = [self._random_vec(rng, self.VECTOR_DIM) for _ in range(NUM_OPS // 2)] - - def make_upsert(pid: str, vec: list[float]): - def _fn() -> None: - store.upsert(self.COLLECTION, pid, vec, payload={"prompt": pid, "saved_steps": [0]}) - - return _fn - - def make_search(vec: list[float]): - def _fn() -> list: - results = store.search(self.COLLECTION, vec, limit=5) - # Search must always return a list (possibly empty), never crash. - assert isinstance(results, list) - return results - - return _fn - - fns = [] - for (pid, vec), qvec in zip(upsert_points, query_vecs): - fns.append(make_upsert(pid, vec)) - fns.append(make_search(qvec)) - _run_with_barrier(fns) - - # All upserts succeeded. - meta_path = index_dir / f"{self.COLLECTION}.json" - on_disk_meta = json.loads(meta_path.read_text()) - id_map = on_disk_meta.get("id_map", {}) - for pid, _ in upsert_points: - assert pid in id_map, f"upsert for {pid} was lost" diff --git a/tests/unit/cache_mem/test_metadata.py b/tests/unit/cache_mem/test_metadata.py deleted file mode 100644 index 510a253..0000000 --- a/tests/unit/cache_mem/test_metadata.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Unit tests for LocalCacheMetadataManager. - -Tests CRUD operations, eviction planning, access tracking, and -audit logging. Pure CPU, no GPU required. -""" - -from __future__ import annotations - -import json -from pathlib import Path - -import pytest - -from telefuser.cache_mem.metadata import LocalCacheMetadataManager - - -@pytest.fixture() -def mgr(tmp_path: Path) -> LocalCacheMetadataManager: - return LocalCacheMetadataManager(tmp_path / "meta") - - -class TestRegisterAndLookup: - def test_register_and_lookup_by_prompt(self, mgr: LocalCacheMetadataManager): - mgr.register_cache("c1", "hello world", saved_steps=[3, 5], size_mb=1.0, num_frames=8) - entry = mgr.lookup_prompt("hello world") - assert entry is not None - assert entry.cache_id == "c1" - assert entry.saved_steps == [3, 5] - - def test_lookup_nonexistent_prompt(self, mgr: LocalCacheMetadataManager): - assert mgr.lookup_prompt("does not exist") is None - - def test_register_with_cache_type(self, mgr: LocalCacheMetadataManager): - mgr.register_cache("c2", "typed", saved_steps=[1], size_mb=0.5, num_frames=4, cache_type="video") - entry = mgr.lookup_prompt("typed", cache_type="video") - assert entry is not None - assert entry.cache_type == "video" - # Should not appear in default type - assert mgr.lookup_prompt("typed", cache_type="nonexistent") is None - - def test_duplicate_steps_are_deduplicated(self, mgr: LocalCacheMetadataManager): - mgr.register_cache("c3", "dup", saved_steps=[5, 3, 5, 3], size_mb=1.0, num_frames=4) - entry = mgr.lookup_prompt("dup") - assert entry is not None - assert entry.saved_steps == [3, 5] - - -class TestRemoveCache: - def test_remove_existing(self, mgr: LocalCacheMetadataManager): - mgr.register_cache("c1", "to delete", saved_steps=[1], size_mb=0.5, num_frames=4) - mgr.remove_cache("c1") - assert mgr.lookup_prompt("to delete") is None - assert mgr.get_cache_meta("c1") is None - - def test_remove_nonexistent_is_noop(self, mgr: LocalCacheMetadataManager): - mgr.remove_cache("nonexistent") # should not raise - - -class TestGetCacheMeta: - def test_returns_meta_dict(self, mgr: LocalCacheMetadataManager): - mgr.register_cache("c1", "meta test", saved_steps=[2], size_mb=1.5, num_frames=16) - meta = mgr.get_cache_meta("c1") - assert meta is not None - assert meta["prompt"] == "meta test" - assert meta["size_mb"] == 1.5 - assert meta["num_frames"] == 16 - - def test_returns_none_for_missing(self, mgr: LocalCacheMetadataManager): - assert mgr.get_cache_meta("missing") is None - - -class TestRecordAccess: - def test_increments_access_count(self, mgr: LocalCacheMetadataManager): - mgr.register_cache("c1", "access", saved_steps=[1], size_mb=0.5, num_frames=4) - mgr.record_access("c1") - mgr.record_access("c1") - meta = mgr.get_cache_meta("c1") - assert meta is not None - assert meta["access_count"] == 2 - - def test_noop_for_missing(self, mgr: LocalCacheMetadataManager): - mgr.record_access("missing") # should not raise - - -class TestPlanEviction: - def test_no_eviction_needed(self, mgr: LocalCacheMetadataManager): - mgr.register_cache("c1", "small", saved_steps=[1], size_mb=1.0, num_frames=4) - result = mgr.plan_eviction(required_mb=1.0, limit_mb=10.0) - assert result == [] - - def test_evicts_oldest_first(self, mgr: LocalCacheMetadataManager): - mgr.register_cache("old", "old", saved_steps=[1], size_mb=5.0, num_frames=4) - # Access "old" to set its timestamp, then register "new" which gets a newer timestamp - mgr.register_cache("new", "new", saved_steps=[1], size_mb=5.0, num_frames=4) - result = mgr.plan_eviction(required_mb=3.0, limit_mb=10.0) - assert len(result) > 0 - # Oldest entry (by last_access_time) should be evicted first - evicted_ids = [cid for cid, _ in result] - assert "old" in evicted_ids - - -class TestRecordHitPair: - def test_writes_jsonl(self, mgr: LocalCacheMetadataManager): - mgr.record_hit_pair( - request_prompt="new prompt", - cache_id="c1", - cached_prompt="old prompt", - similarity=0.95, - task_type="t2v", - cache_type="approximate", - skip_step=5, - ) - log_path = mgr.metadata_cache_dir / "hit_pairs.jsonl" - assert log_path.exists() - line = json.loads(log_path.read_text().strip()) - assert line["similarity"] == 0.95 - assert line["skip_step"] == 5 - - -class TestRecordSimilarityScores: - def test_writes_jsonl(self, mgr: LocalCacheMetadataManager): - mgr.record_similarity_scores( - request_prompt="query", - task_type="t2v", - cache_type="video", - stage="rerank", - candidates=[{"id": "c1", "score": 0.9}], - ) - log_path = mgr.metadata_cache_dir / "similarity_scores.jsonl" - assert log_path.exists() - line = json.loads(log_path.read_text().strip()) - assert line["stage"] == "rerank" - assert len(line["candidates"]) == 1 - - -class TestPersistence: - def test_survives_reload(self, tmp_path: Path): - meta_dir = tmp_path / "persist" - mgr1 = LocalCacheMetadataManager(meta_dir) - mgr1.register_cache("c1", "persist", saved_steps=[1, 3], size_mb=2.0, num_frames=8) - - # Create a new manager pointing to the same directory - mgr2 = LocalCacheMetadataManager(meta_dir) - entry = mgr2.lookup_prompt("persist") - assert entry is not None - assert entry.cache_id == "c1" - assert entry.saved_steps == [1, 3] - meta = mgr2.get_cache_meta("c1") - assert meta is not None - assert meta["size_mb"] == 2.0 diff --git a/tests/unit/cache_mem/test_storage.py b/tests/unit/cache_mem/test_storage.py deleted file mode 100644 index 40f4c00..0000000 --- a/tests/unit/cache_mem/test_storage.py +++ /dev/null @@ -1,105 +0,0 @@ -"""Unit tests for KV storage backends. - -Tests LocalFileKVStore and InMemoryKVStore. Pure CPU, no GPU required. -""" - -from __future__ import annotations - -from pathlib import Path - -import pytest - -from telefuser.cache_mem.storage.local_file import LocalFileKVStore -from telefuser.cache_mem.storage.memory import InMemoryKVStore - - -class TestInMemoryKVStore: - def test_put_and_get(self): - store = InMemoryKVStore() - store.put("k1", b"value1") - assert store.get("k1") == b"value1" - - def test_get_missing_returns_none(self): - store = InMemoryKVStore() - assert store.get("missing") is None - - def test_remove(self): - store = InMemoryKVStore() - store.put("k1", b"value1") - store.remove("k1") - assert store.get("k1") is None - - def test_remove_missing_is_noop(self): - store = InMemoryKVStore() - store.remove("nonexistent") # should not raise - - def test_list_keys(self): - store = InMemoryKVStore() - store.put("a", b"1") - store.put("b", b"2") - assert sorted(store.list_keys()) == ["a", "b"] - - def test_list_keys_empty(self): - store = InMemoryKVStore() - assert store.list_keys() == [] - - def test_overwrite(self): - store = InMemoryKVStore() - store.put("k1", b"old") - store.put("k1", b"new") - assert store.get("k1") == b"new" - - -class TestLocalFileKVStore: - @pytest.fixture() - def store(self, tmp_path: Path) -> LocalFileKVStore: - return LocalFileKVStore(tmp_path / "kv") - - def test_put_and_get(self, store: LocalFileKVStore): - store.put("k1", b"hello") - assert store.get("k1") == b"hello" - - def test_get_missing_returns_none(self, store: LocalFileKVStore): - assert store.get("missing") is None - - def test_remove(self, store: LocalFileKVStore): - store.put("k1", b"data") - store.remove("k1") - assert store.get("k1") is None - - def test_remove_missing_is_noop(self, store: LocalFileKVStore): - store.remove("nonexistent") # should not raise - - def test_list_keys(self, store: LocalFileKVStore): - store.put("x", b"1") - store.put("y", b"2") - assert sorted(store.list_keys()) == ["x", "y"] - - def test_overwrite(self, store: LocalFileKVStore): - store.put("k1", b"old") - store.put("k1", b"new") - assert store.get("k1") == b"new" - - def test_binary_data(self, store: LocalFileKVStore): - data = bytes(range(256)) - store.put("binary", data) - assert store.get("binary") == data - - def test_persistence_across_instances(self, tmp_path: Path): - kv_dir = tmp_path / "persist_kv" - s1 = LocalFileKVStore(kv_dir) - s1.put("persist_key", b"persist_value") - - s2 = LocalFileKVStore(kv_dir) - assert s2.get("persist_key") == b"persist_value" - - def test_remove_cleans_file(self, store: LocalFileKVStore): - store.put("to_del", b"data") - # Verify file exists - filename = store._index.get("to_del") - assert filename is not None - file_path = store.root_dir / filename - assert file_path.exists() - # Remove and verify file is gone - store.remove("to_del") - assert not file_path.exists() diff --git a/tests/unit/cache_mem/test_types_and_config.py b/tests/unit/cache_mem/test_types_and_config.py deleted file mode 100644 index c84b114..0000000 --- a/tests/unit/cache_mem/test_types_and_config.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Unit tests for cache data types and config. - -Tests CacheResult, IndexEntry, VectorSearchResult construction and -CacheConfig defaults / field parsing. Pure CPU, no GPU required. -""" - -from __future__ import annotations - -import torch - -from telefuser.cache_mem.cache_types import CacheResult, IndexEntry, VectorSearchResult -from telefuser.cache_mem.config import CacheConfig, CacheMode - - -class TestCacheResult: - def test_default_miss(self): - r = CacheResult(hit=False) - assert not r.hit - assert r.skip_step == 0 - assert r.cache_type == "none" - assert r.latent_state is None - - def test_hit_with_latent(self): - t = torch.randn(1, 16, 5, 32, 32) - r = CacheResult( - hit=True, - skip_step=5, - cache_type="approximate", - similarity=0.95, - latent_state=t, - cached_prompt="cached", - ) - assert r.hit - assert r.skip_step == 5 - assert r.latent_state is t - - -class TestIndexEntry: - def test_construction(self): - e = IndexEntry(cache_id="abc123", prompt="hello", saved_steps=[3, 5]) - assert e.cache_id == "abc123" - assert e.cache_type == "approximate_cache" - - def test_custom_cache_type(self): - e = IndexEntry(cache_id="x", prompt="p", saved_steps=[], cache_type="video") - assert e.cache_type == "video" - - -class TestVectorSearchResult: - def test_construction(self): - r = VectorSearchResult( - cache_id="v1", - similarity=0.88, - prompt="search query", - saved_steps=[1, 2], - payload={"extra": "data"}, - ) - assert r.similarity == 0.88 - assert r.payload["extra"] == "data" - - -class TestCacheConfig: - def test_defaults(self): - cfg = CacheConfig() - assert cfg.enable_latent_cache is False - assert cfg.cache_mode == CacheMode.READ_WRITE - assert cfg.kv_store_type == "local_file" - assert cfg.vector_store_type == "faiss" - assert cfg.vector_dim == 2048 - assert cfg.save_async_enabled is True - - def test_custom_values(self): - cfg = CacheConfig( - enable_latent_cache=True, - cache_mode=CacheMode.READ_ONLY, - kv_store_type="fluxon", - vector_store_type="qdrant", - vector_dim=1024, - video_similarity_threshold=0.25, - rerank_enabled=True, - ) - assert cfg.enable_latent_cache is True - assert cfg.cache_mode == CacheMode.READ_ONLY - assert cfg.kv_store_type == "fluxon" - assert cfg.vector_dim == 1024 - assert cfg.rerank_enabled is True - - -class TestCacheMode: - def test_enum_values(self): - assert CacheMode.READ_WRITE.value == "read_write" - assert CacheMode.READ_ONLY.value == "read_only" - assert CacheMode.WRITE_ONLY.value == "write_only" - - def test_from_string(self): - assert CacheMode("read_only") == CacheMode.READ_ONLY From 0b467a992a3d57328452444d4a63863091831fdf Mon Sep 17 00:00:00 2001 From: jinyx5 Date: Mon, 22 Jun 2026 09:56:46 +0800 Subject: [PATCH 2/4] =?UTF-8?q?fix(wan22=20t2v=20service):=20cacheseek.cor?= =?UTF-8?q?e.config=20=E2=86=92=20=E5=85=AC=E5=85=B1=20cacheseek=20import?= =?UTF-8?q?=20(arch-v2)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit arch-v2 退役了 cacheseek.core,CacheConfig 现从顶层 `cacheseek` 导出。cache 与 nocache 两个 wan22 T2V service 入口仍 import arch-v1 的 cacheseek.core.config, 导致 cacheseek approximate-reuse e2e 在服务启动期 ModuleNotFoundError 崩溃。 Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/wan_video/wan22_14b_text_to_video_service.py | 2 +- examples/wan_video/wan22_14b_text_to_video_service_nocache.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/wan_video/wan22_14b_text_to_video_service.py b/examples/wan_video/wan22_14b_text_to_video_service.py index ba8605e..70d6687 100644 --- a/examples/wan_video/wan22_14b_text_to_video_service.py +++ b/examples/wan_video/wan22_14b_text_to_video_service.py @@ -22,7 +22,7 @@ import os import torch -from cacheseek.core.config import CacheConfig +from cacheseek import CacheConfig from telefuser.core.config import AttentionConfig, AttnImplType, FeatureCacheConfig, WeightOffloadType from telefuser.core.module_manager import ModuleManager diff --git a/examples/wan_video/wan22_14b_text_to_video_service_nocache.py b/examples/wan_video/wan22_14b_text_to_video_service_nocache.py index 592ac80..f707f16 100644 --- a/examples/wan_video/wan22_14b_text_to_video_service_nocache.py +++ b/examples/wan_video/wan22_14b_text_to_video_service_nocache.py @@ -15,7 +15,7 @@ import os import torch -from cacheseek.core.config import CacheConfig +from cacheseek import CacheConfig from telefuser.core.config import AttentionConfig, AttnImplType, FeatureCacheConfig, WeightOffloadType from telefuser.core.module_manager import ModuleManager From 252e6f1d9c44bac6ab31b89fffffbc5f8bf84294 Mon Sep 17 00:00:00 2001 From: jader Date: Wed, 1 Jul 2026 13:36:18 +0000 Subject: [PATCH 3/4] docs: update latent cache documentation for CacheSeek integration --- docs/en/latent_cache.md | 266 ++++++++-------------------------------- docs/zh/latent_cache.md | 224 +++++++-------------------------- 2 files changed, 100 insertions(+), 390 deletions(-) diff --git a/docs/en/latent_cache.md b/docs/en/latent_cache.md index 9055ab8..3e4f1bd 100644 --- a/docs/en/latent_cache.md +++ b/docs/en/latent_cache.md @@ -3,7 +3,8 @@ Latent cache reuses **the intermediate latent from a previous inference** when an incoming prompt is similar enough to a prompt already served, so the first N denoising steps can be skipped. TeleFuser integrates this feature -through the external **CacheSeek** package. The legacy in-tree +through the external **CacheSeek** package: +. The legacy in-tree `telefuser/cache_mem/` implementation has been removed. ## Latent Cache vs. Feature Cache @@ -103,117 +104,24 @@ cache_service, cache_adapter = CacheServiceFactory.create_cache_service( ) ``` -`create_cache_service` is implemented by CacheSeek and does the following: +TeleFuser only relies on this factory's input/output contract: -1. Loads `CACHE_CONFIG` (a dict or a `CacheConfig` instance) from `ppl_file` - as the base configuration. -2. Overrides the final config with the CLI's `enable_latent_cache` / - `cache_mode`. -3. Initializes the cache log sink up front. -4. Builds the CacheSeek storage, vector store, metadata manager, and strategy. -5. Returns the framework-agnostic `CacheService` plus the TeleFuser adapter - used for `build_query`, `apply_resume`, and `on_response`. +- Pass the current pipeline file path as `ppl_file`. +- Pass `enable_latent_cache=True`. +- If the CLI set `--cache-mode`, pass that value as `cache_mode`; otherwise + pass `None`, leaving the pipeline file's `CACHE_CONFIG` or CacheSeek to + decide. +- Expect `(cache_service, cache_adapter)` back, where the adapter provides + `build_query`, `apply_resume`, and `on_response`. -Manual construction should use CacheSeek primitives directly when needed: - -```python -from pathlib import Path - -from cacheseek.service.config import CacheConfig -from cacheseek.service.lifecycle import CacheService - -config = CacheConfig( - enable_latent_cache=True, - latent_cache_dir="./latent_cache/wan22_t2v", - cache_strategy_type="video_approximate", - vector_dim=2048, -) -cache_service = CacheService.from_config(config) -``` - -The strategy class is selected by CacheSeek's TeleFuser factory from -`cache_strategy_type`. Custom strategies should be registered in CacheSeek, -not in TeleFuser. +If you need to construct the cache service outside TeleFuser service startup, +use the CacheSeek documentation and `cacheseek/service/config.py` as the +source of truth. TeleFuser no longer provides the old in-tree `LatentCache` +facade. --- -## VideoBasedApproximateCache - -The only production strategy implementation is -`VideoBasedApproximateCache`, which combines: - -- **Prompt encoding**: `Qwen3-VL-Embedding` encodes the prompt into a vector - that is written to the vector store. -- **Video encoding**: during save, several frames of the generated video are - encoded into the same vector space, used as the similarity basis for - future hits. -- **Optional rerank**: when `rerank_enabled` is on, `Qwen3-VL-Reranker` - performs cross-encoder reranking over the top-k candidates. -- **Shared backend**: when text and video embedding configs end up loading - the same model on the same device, the two automatically share a single - `Qwen3VLEncoder` instance, saving roughly 5 GB of GPU memory and one cold - load. - -### How VideoBasedApproximateCache Works - -#### Write Path - -When a request finishes, the pipeline returns `latent_payload` containing -the per-step latents plus video frames used for prompt similarity. The service -layer passes that payload through `TeleFuserCacheAdapter.on_response`, then -calls `CacheService.save(cache_query, outputs)`, which enqueues it onto the -CacheSeek async save worker: - -1. Writes each step's latent to the KV store under a key shaped like - `f"{cache_id}_step{step}"`. -2. Encodes the video frames with `Qwen3-VL-Embedding` and upserts the - vector into the vector store (default collection name `video`). -3. Registers `cache_id -> {prompt, saved_steps, size_mb, ...}` in metadata, - persisting `prompt_index.json` and `cache_meta.json`. - -If any step fails, all the latents / vectors / metadata that were already -written are rolled back cleanly to avoid an inconsistent state. - -#### Hit Path - -When a new request arrives, the service layer runs -`adapter.build_query -> cache_service.lookup -> adapter.apply_resume`: - -1. Waits on `vector_update_idle` to make sure the vector upsert from the - previous async save has been committed. -2. Calls CacheSeek lookup: encodes the new prompt, queries the top-k - approximate caches in the vector store, optionally reranks with - Qwen3-VL-Reranker, and compares against the threshold to decide on a - hit. -3. On a hit, loads the latent tensor for `skip_step` from the KV store and - wraps it into a `CacheResult`. - -The `latent_data` dict the pipeline receives includes `cached_latent`, -`skip_step`, and `saved_steps`. The pipeline restarts the denoise loop at -`skip_step` and snapshots this run's latents according to `saved_steps` — -that is how the cache keeps growing. - -### Cache Parameters - -The core parameters used by `VideoBasedApproximateCache`: - -| Parameter | Type | Description | -| ---------------------------- | ----- | --------------------------------------------------- | -| `key_steps` | list | Step list at which the pipeline is asked to snapshot | -| `video_similarity_threshold` | float | Lower bound for a vector-search hit | -| `rerank_enabled` | bool | Whether to rerank the top-k with Qwen3-VL-Reranker | -| `rerank_top_k` | int | Number of candidates fed into rerank | -| `rerank_score_threshold` | float | Lower bound for a hit when rerank is enabled | -| `video_embedding_max_frames` | int | Max frames sampled when encoding video | -| `video_vector_collection` | str | FAISS collection name (default `video`) | - -> Which step to restart from after a hit is decided by -> `_determine_skip_step`: in the current implementation, it skips to step 5 -> when `similarity` is above the rerank threshold and `5 ∈ saved_steps`, -> otherwise it is treated as a miss. Override this method in a subclass to -> customize the skip policy. - -### Using VideoBasedApproximateCache +## Pipeline Configuration Example Declare `CACHE_CONFIG` in the pipeline file and start TeleFuser with `--enable-latent-cache`: @@ -243,12 +151,27 @@ CACHE_CONFIG = dict( ) ``` -The pipeline file only needs to expose `run_with_file` plus a `CACHE_CONFIG` -dict for CacheSeek configuration: +The table below describes common fields in the Wan2.2 service example's +`CACHE_CONFIG` and the example defaults. Here "default" means the value the +example file passes to CacheSeek when no matching environment variable or CLI +override is set. For CacheSeek's complete field list and built-in defaults, +refer to the CacheSeek documentation and `cacheseek/service/config.py`. -- `run_with_file(pipeline, **task_data) -> dict`: feeds `latent_data` into - the pipeline and returns `latent_payload` as part of the result so the - service layer can write it back to the cache. +| Field | Example default | Description | +|---|---|---| +| `enable_latent_cache` | `True` | The example pipeline declares latent cache support; service startup still requires `--enable-latent-cache` to initialize CacheSeek. | +| `cache_mode` | `write_only` | Write-only by default, useful for warming or building cache first; override with `TELEFUSER_CACHE_MODE` or CLI `--cache-mode` to `read_write` / `read_only` / `write_only`. | +| `latent_cache_dir` | `./latent_cache/wan22_t2v` | Cache root directory; override with `TELEFUSER_LATENT_CACHE_DIR`. | +| `kv_store_type` | `local_file` | KV backend type; override with `TELEFUSER_KV_STORE_TYPE`. | +| `vector_store_type` | `faiss` | Vector-search backend type; override with `TELEFUSER_VECTOR_STORE_TYPE`. | +| `vector_dim` | `2048` | Vector dimension; must match the embedding model output dimension. | +| `key_steps` | `[5, 10, 15, 20, 25]` | Denoise steps at which the pipeline is asked to snapshot latents. | +| `video_embedding_enabled` | `True` | Enables video-frame embedding; the example save path backfills `embedding_video_frames`. | +| `video_embedding_model_path` | `""` | Video embedding model path; override with `QWEN3VL_EMBEDDING_PATH`. How an empty value is resolved is owned by CacheSeek. | +| `video_embedding_max_frames` | `16` | Maximum number of video frames sampled before cache writeback. | +| `text_embedding_device_id` / `video_embedding_device_id` | `1` / `1` | Logical GPU ids for embedding models; adjust for `CUDA_VISIBLE_DEVICES` and parallelism. | +| `video_vector_collection` | `video` | Video vector collection name. | +| `rerank_enabled` / `rerank_top_k` / `rerank_score_threshold` | `True` / `5` / `0.85` | Example rerank configuration; CacheSeek executes the actual hit policy. | --- @@ -273,109 +196,26 @@ telefuser serve examples/wan_video/wan22_14b_text_to_video_service.py \ --- -## CacheConfig Field Reference - -The authoritative definition lives in CacheSeek's `cacheseek/service/config.py`. -The table below lists the fields commonly used by the TeleFuser Wan2.2 service -example. - -### Basic - -| Field | Default | Description | -|---|---|---| -| `enable_latent_cache` | `False` | Master switch; toggled by CLI `--enable-latent-cache`. | -| `cache_mode` | `read_write` | One of `read_write` / `read_only` / `write_only`. | -| `latent_cache_dir` | `./latent_cache` | Root directory for storage, metadata, FAISS, and logs. | -| `max_cache_size_gb` | `10` | Soft eviction cap (LRU by `last_access_time`). | - -### Logging - -| Field | Default | -|---|---| -| `cache_log_enabled` | `True` | -| `cache_log_dir` | `{latent_cache_dir}/logs` | -| `cache_log_level` | `DEBUG` | -| `cache_log_rotation` | `100 MB` | -| `cache_log_retention` | `7 days` | - -### KV / Vector Backend - -| Field | Default | Description | -|---|---|---| -| `kv_store_type` | `local_file` | Or `fluxon` (stub). | -| `vector_store_type` | `faiss` | Or `qdrant` (stub). | -| `vector_dim` | `2048` | Must match the embedder output dim. | -| `faiss_index_dir` | `{latent_cache_dir}/faiss` | | -| `qdrant_url` / `qdrant_api_key` | `""` / `None` | Configure once Qdrant is wired up for real. | -| `cache_strategy_type` | `video_approximate` | Key in the strategy registry. | - -### Strategy and Embedding - -| Field | Default | Description | -|---|---|---| -| `key_steps` | `[5, 10, 15, 20, 25]` | Step list at which the pipeline is asked to snapshot. | -| `lookup_mode` | `video` | | -| `video_embedding_enabled` | `True` | | -| `video_embedding_model_path` | `Qwen/Qwen3-VL-Embedding-2B` | | -| `video_embedding_max_frames` | `16` | | -| `video_embedding_fps` | `1.0` | | -| `text_embedding_model_path` | `""` | Empty means reuse the video embedder. | -| `video_similarity_threshold` | `0.10` | Lower bound for a vector-search hit. | -| `rerank_enabled` | `True` | When on, rerank the top-k with Qwen3-VL-Reranker. | -| `rerank_top_k` | `5` | | -| `rerank_score_threshold` | `0.80` | Lower bound for a hit when rerank is enabled. | - -When the text and video embedding configurations end up loading the same -model on the same device, `VideoBasedApproximateCache` lets them share a -single `Qwen3VLEncoder` instance, saving roughly 5 GB of GPU memory and one -cold load. - -### Async Save - -| Field | Default | Description | -|---|---|---| -| `save_async_enabled` | `True` | Offload `save` onto the worker thread. | -| `save_queue_size` | `2` | `0` means unbounded. | -| `save_on_full` | `drop` | `drop` / `sync` / `downgrade` (downgrade is TODO). | -| `save_queue_warn_threshold` | `8` | Log a warning when queue depth exceeds this value. | -| `vector_wait_warn_s` | `2.0` | Log a warning when `lookup` waits on the vector barrier longer than this. | -| `vector_wait_timeout_s` | `120.0` | Give up the barrier after timeout and treat as miss. | -| `flush_on_shutdown` | `True` | `CacheService.shutdown` drains the queue first. | +## Startup and Runtime Behavior ### The Three Cache Modes | Mode | Effect | |---|---| -| `READ_WRITE` | Lookup hits, and writes are also persisted. The default. | -| `READ_ONLY` | Lookup hits, but the cache is not updated. Useful during canary rollouts. | -| `WRITE_ONLY` | Lookup always misses, only accumulating cache. Common when warming up a cache against a benchmark. | - ---- - -## Architecture Overview - -``` -┌────────────────────────────────────────────────────────┐ -│ LatentCache (facade) │ -│ │ -│ ├─ Strategy VideoBasedApproximateCache │ -│ │ ├─ prompt_encoder Qwen3-VL-Embedding │ -│ │ ├─ video_encoder Qwen3-VL-Embedding (shared) │ -│ │ └─ reranker Qwen3-VL-Reranker (optional) │ -│ │ │ -│ ├─ KVStore LocalFileKVStore | FluxonKVStore* │ -│ ├─ VectorStore FAISSVectorStore | QdrantStore* │ -│ └─ MetadataManager LocalCacheMetadataManager │ -└──────────▲─────────────────────────────────────────────┘ - │ via CacheService (async writeback wrapper) - │ - FastAPI request thread / pipeline -``` - -The Fluxon / Qdrant backends marked with `*` are still stubs (they raise -`NotImplementedError`); the production path only goes through -`LocalFileKVStore` + `FAISSVectorStore`. - -`CacheService` owns the background async writeback thread plus a barrier -called `vector_update_idle`, which prevents a `lookup` from reading a stale -index before the previous `save` finishes its vector upsert. +| `read_write` | Read existing cache entries; write new cache entries after each request completes. | +| `read_only` | Read existing cache entries; do not write new cache entries. | +| `write_only` | Do not use cache hits; only write cache entries after each request completes. | + +`cache_mode` can be set in the pipeline file's `CACHE_CONFIG` or overridden +with `telefuser serve --cache-mode`. The CLI only accepts the three values +above. When `--cache-mode` is omitted, TeleFuser does not override the +`CACHE_CONFIG` value. The Wan2.2 service example defaults to `write_only`. + +### Startup and Failure Semantics + +- Without `--enable-latent-cache`, CacheSeek is not loaded and latent cache is + not initialized. +- With `--enable-latent-cache`, a missing CacheSeek install or initialization + failure fails service startup immediately. +- Per-request failures in `build_query` / `lookup` / `apply_resume` / `save` + are logged as warnings and the request continues through the uncached path. diff --git a/docs/zh/latent_cache.md b/docs/zh/latent_cache.md index e2f54f8..633857c 100644 --- a/docs/zh/latent_cache.md +++ b/docs/zh/latent_cache.md @@ -2,7 +2,8 @@ Latent cache 用于在新到达的 prompt 和已经生成过的 prompt 足够相似 时**复用上一次推理的中间 latent**,跳过前若干步去噪。TeleFuser 通过外部 -**CacheSeek** 包接入该能力;旧的仓内 `telefuser/cache_mem/` 实现已移除。 +**CacheSeek** 包接入该能力:; +旧的仓内 `telefuser/cache_mem/` 实现已移除。 ## Latent Cache 与 Feature Cache 的区别 @@ -96,96 +97,22 @@ cache_service, cache_adapter = CacheServiceFactory.create_cache_service( ) ``` -`create_cache_service` 内部会: +TeleFuser 侧只依赖这个 factory 的输入输出契约: -1. 从 `ppl_file` 加载 `CACHE_CONFIG`(dict 或 `CacheConfig` 实例)作为默认配置基础。 -2. 用 CLI 的 `enable_latent_cache` / `cache_mode` 覆盖最终配置。 -3. 提前初始化 cache 日志 sink。 -4. 构造 CacheSeek 的存储、向量库、元数据管理器和策略。 -5. 返回框架无关的 `CacheService`,以及用于 `build_query`、`apply_resume` - 和 `on_response` 的 TeleFuser 适配器。 +- 传入当前 pipeline 文件路径 `ppl_file`。 +- 传入 `enable_latent_cache=True`。 +- 如果 CLI 设置了 `--cache-mode`,把该值作为 `cache_mode` 传给 CacheSeek; + 未设置时传 `None`,由 pipeline 文件中的 `CACHE_CONFIG` 或 CacheSeek 决定。 +- 期望返回 `(cache_service, cache_adapter)`,其中 adapter 提供 + `build_query`、`apply_resume` 和 `on_response`。 -需要直接构造时应使用 CacheSeek 原语: - -```python -from pathlib import Path - -from cacheseek.service.config import CacheConfig -from cacheseek.service.lifecycle import CacheService - -config = CacheConfig( - enable_latent_cache=True, - latent_cache_dir="./latent_cache/wan22_t2v", - cache_strategy_type="video_approximate", - vector_dim=2048, -) -cache_service = CacheService.from_config(config) -``` - -策略类由 CacheSeek 的 TeleFuser factory 根据 `cache_strategy_type` 选择。 -自定义策略应注册到 CacheSeek,而不是 TeleFuser。 +需要绕过 TeleFuser service 直接构造缓存服务时,请以 CacheSeek 文档和 +`cacheseek/service/config.py` 为准;TeleFuser 不再提供旧的仓内 +`LatentCache` 外观类。 --- -## VideoBasedApproximateCache - -线上唯一的策略实现是 `VideoBasedApproximateCache`,结合: - -- **Prompt 编码**:`Qwen3-VL-Embedding` 把 prompt 编码成向量,写入向量检索库。 -- **视频编码**:save 阶段对生成视频的若干帧编码至同一向量空间,作为命中时的相似度计算依据。 -- **可选 rerank**:开启 `rerank_enabled` 后用 `Qwen3-VL-Reranker` 在 top-k 上 - 做交叉编码精排。 -- **共享后端**:当 text 和 video 的 embedding 配置最终落到同一个 model + device - 时,自动让两者共享同一个 `Qwen3VLEncoder` 实例,节约近 5 GB 显存和一次冷加载。 - -### VideoBasedApproximateCache 工作原理 - -#### 写入路径 - -请求结束后,pipeline 返回 `latent_payload`(含按步存储的 latent + 用于 prompt -相似度的视频帧)。服务层先经过 `TeleFuserCacheAdapter.on_response` 打包, -再调用 `CacheService.save(cache_query, outputs)`,由 CacheSeek 异步保存 worker -处理: - -1. 将每个 step 的 latent 写到 KV,key 形如 `f"{cache_id}_step{step}"`。 -2. 通过 `Qwen3-VL-Embedding` 将视频帧编码成向量,upsert 至 向量检索库(默认 - collection 名 `video`)。 -3. 在 metadata 里登记 `cache_id -> {prompt, saved_steps, size_mb, ...}`, - 持久化 `prompt_index.json` 和 `cache_meta.json`。 - -任何一步失败,已写入的 latent / 向量 / metadata 都会回滚干净,避免状态不一致。 - -#### 命中路径 - -新请求到达后,服务层执行 -`adapter.build_query -> cache_service.lookup -> adapter.apply_resume`: - -1. 等待 `vector_update_idle`——确保上一笔异步 save 的向量 upsert 已落库。 -2. 调用 CacheSeek lookup:对新 prompt 编码,在向量检索库中查 top-k 近似 - 缓存;可选用 Qwen3-VL-Reranker 重排,跟阈值比对决定是否命中。 -3. 命中后从 KV 读出 `skip_step` 对应的 latent 张量,封装成 `CacheResult` 返回。 - -Pipeline 拿到的 `latent_data` 字典里包括 `cached_latent`、`skip_step`、 -`saved_steps`。Pipeline 于 `skip_step` 处重启去噪循环,并按 `saved_steps` -把当次的 latent 也快照下来——缓存就是这样越攒越多的。 - -### 缓存参数 - -`VideoBasedApproximateCache` 关心的核心参数: - -| 参数 | 类型 | 描述 | -| ---------------------------- | ----- | ---------------------------------- | -| `key_steps` | list | pipeline 被要求 snapshot 的 step 列表 | -| `video_similarity_threshold` | float | 向量搜索的命中下限 | -| `rerank_enabled` | bool | 是否启用 Qwen3-VL-Reranker 在 top-k 上重排 | -| `rerank_top_k` | int | 进入 rerank 的候选数量 | -| `rerank_score_threshold` | float | rerank 启用时的命中下限 | -| `video_embedding_max_frames` | int | 视频编码时最多采样的帧数 | -| `video_vector_collection` | str | FAISS collection 名(默认 `video`) | - -> 命中后到底从第几步重启由 `_determine_skip_step` 决定:当前实现里 `similarity`高于 rerank 阈值且 `5 ∈ saved_steps` 时跳过到第 5 步,否则视为 miss。需要自定义跳点策略时可在子类里覆盖此方法。 - -### 使用 VideoBasedApproximateCache +## Pipeline 配置示例 在 pipeline 文件里声明 `CACHE_CONFIG`,并在启动时传入 `--enable-latent-cache`: @@ -214,6 +141,27 @@ CACHE_CONFIG = dict( ) ``` +下表说明 Wan2.2 service 示例 `CACHE_CONFIG` 中的常用字段和示例默认值。 +这里的“默认”指示例文件在没有对应环境变量或 CLI 覆盖时传给 CacheSeek 的值; +CacheSeek 自身的全量字段和内置默认值以 CacheSeek 文档和 +`cacheseek/service/config.py` 为准。 + +| 字段 | 示例默认值 | 说明 | +|---|---|---| +| `enable_latent_cache` | `True` | 示例 pipeline 声明支持 latent cache;服务启动仍需传 `--enable-latent-cache` 才会初始化 CacheSeek。 | +| `cache_mode` | `write_only` | 默认只写缓存,适合先预热或生成缓存;可用 `TELEFUSER_CACHE_MODE` 或 CLI `--cache-mode` 覆盖为 `read_write` / `read_only` / `write_only`。 | +| `latent_cache_dir` | `./latent_cache/wan22_t2v` | 缓存根目录;可用 `TELEFUSER_LATENT_CACHE_DIR` 覆盖。 | +| `kv_store_type` | `local_file` | KV 后端类型;可用 `TELEFUSER_KV_STORE_TYPE` 覆盖。 | +| `vector_store_type` | `faiss` | 向量检索后端类型;可用 `TELEFUSER_VECTOR_STORE_TYPE` 覆盖。 | +| `vector_dim` | `2048` | 向量维度,需要与 embedding 模型输出维度一致。 | +| `key_steps` | `[5, 10, 15, 20, 25]` | Pipeline 被要求 snapshot 的 denoise step 列表。 | +| `video_embedding_enabled` | `True` | 启用视频帧 embedding;示例 save 路径会回填 `embedding_video_frames`。 | +| `video_embedding_model_path` | `""` | 视频 embedding 模型路径;可用 `QWEN3VL_EMBEDDING_PATH` 覆盖,空值如何解析由 CacheSeek 决定。 | +| `video_embedding_max_frames` | `16` | 写回缓存前最多采样的视频帧数。 | +| `text_embedding_device_id` / `video_embedding_device_id` | `1` / `1` | embedding 模型使用的逻辑 GPU id;需要按 `CUDA_VISIBLE_DEVICES` 和并行配置调整。 | +| `video_vector_collection` | `video` | 视频向量 collection 名称。 | +| `rerank_enabled` / `rerank_top_k` / `rerank_score_threshold` | `True` / `5` / `0.85` | rerank 示例配置;实际命中策略由 CacheSeek 执行。 | + --- @@ -238,102 +186,24 @@ telefuser serve examples/wan_video/wan22_14b_text_to_video_service.py \ --- -## CacheConfig 字段说明 - -权威定义在 CacheSeek 的 `cacheseek/service/config.py`。下表只列出 -TeleFuser Wan2.2 service 示例常用字段。 - -### 基础 - -| 字段 | 默认 | 说明 | -|---|---|---| -| `enable_latent_cache` | `False` | 总开关,CLI `--enable-latent-cache` 会翻它。 | -| `cache_mode` | `read_write` | `read_write` / `read_only` / `write_only`。 | -| `latent_cache_dir` | `./latent_cache` | 存储、metadata、FAISS、日志的根目录。 | -| `max_cache_size_gb` | `10` | 软淘汰上限(按 `last_access_time` 做 LRU)。 | - -### 日志 - -| 字段 | 默认 | -|---|---| -| `cache_log_enabled` | `True` | -| `cache_log_dir` | `{latent_cache_dir}/logs` | -| `cache_log_level` | `DEBUG` | -| `cache_log_rotation` | `100 MB` | -| `cache_log_retention` | `7 days` | - -### KV / Vector 后端 - -| 字段 | 默认 | 说明 | -|---|---|---| -| `kv_store_type` | `local_file` | 或 `fluxon`(stub)。 | -| `vector_store_type` | `faiss` | 或 `qdrant`(stub)。 | -| `vector_dim` | `2048` | 必须和 embedder 输出维度一致。 | -| `faiss_index_dir` | `{latent_cache_dir}/faiss` | | -| `qdrant_url` / `qdrant_api_key` | `""` / `None` | 等真正接 Qdrant 时再配。 | -| `cache_strategy_type` | `video_approximate` | 策略注册表里的 key。 | - -### 策略与 embedding - -| 字段 | 默认 | 说明 | -|---|---|---| -| `key_steps` | `[5, 10, 15, 20, 25]` | pipeline 被要求 snapshot 的 step 列表。 | -| `lookup_mode` | `video` | | -| `video_embedding_enabled` | `True` | | -| `video_embedding_model_path` | `Qwen/Qwen3-VL-Embedding-2B` | | -| `video_embedding_max_frames` | `16` | | -| `video_embedding_fps` | `1.0` | | -| `text_embedding_model_path` | `""` | 留空则复用 video embedder。 | -| `video_similarity_threshold` | `0.10` | 向量搜索的命中下限。 | -| `rerank_enabled` | `True` | 开了就用 Qwen3-VL-Reranker 在 top-k 上重排。 | -| `rerank_top_k` | `5` | | -| `rerank_score_threshold` | `0.80` | rerank 启用时的命中下限。 | - -当 text 和 video 的 embedding 配置最终落到同一个 model + device 时, -`VideoBasedApproximateCache` 会让两者共享同一个 `Qwen3VLEncoder` 实例, -省下大约 5 GB 显存和一次冷加载。 - -### 异步保存 - -| 字段 | 默认 | 说明 | -|---|---|---| -| `save_async_enabled` | `True` | 把 `save` 卸到 worker 线程。 | -| `save_queue_size` | `2` | `0` 表示不限。 | -| `save_on_full` | `drop` | `drop` / `sync` / `downgrade`(downgrade 是 TODO)。 | -| `save_queue_warn_threshold` | `8` | 队列深度超此值打 warning。 | -| `vector_wait_warn_s` | `2.0` | `lookup` 等向量栅栏超过此值打 warning。 | -| `vector_wait_timeout_s` | `120.0` | 等到 timeout 就放弃栅栏,按 miss 走。 | -| `flush_on_shutdown` | `True` | `CacheService.shutdown` 会先把队列里的任务放空。 | +## 启动与运行行为 ### Cache mode 三档 | 模式 | 效果 | |---|---| -| `READ_WRITE` | lookup 命中、写完也回写。常态。 | -| `READ_ONLY` | lookup 命中、但不更新缓存。在线灰度期间用得上。 | -| `WRITE_ONLY` | lookup 永远 miss、只攒缓存。对着 benchmark 跑一遍预热 cache 时常用。 | - ---- +| `read_write` | 读取已有缓存;请求完成后也写回新的缓存。 | +| `read_only` | 只读取已有缓存,不写回新的缓存。 | +| `write_only` | 不使用缓存命中结果,只在请求完成后写入缓存。 | -## 架构总览 +`cache_mode` 可以写在 pipeline 文件的 `CACHE_CONFIG` 里,也可以用 +`telefuser serve --cache-mode` 覆盖。CLI 只接受上表三种值;未传 +`--cache-mode` 时,TeleFuser 不覆盖 `CACHE_CONFIG` 中的配置。Wan2.2 +service 示例的默认值是 `write_only`。 -``` -┌────────────────────────────────────────────────────────┐ -│ LatentCache(外观类) │ -│ │ -│ ├─ Strategy VideoBasedApproximateCache │ -│ │ ├─ prompt_encoder Qwen3-VL-Embedding │ -│ │ ├─ video_encoder Qwen3-VL-Embedding(共享) │ -│ │ └─ reranker Qwen3-VL-Reranker(可选) │ -│ │ │ -│ ├─ KVStore LocalFileKVStore | FluxonKVStore* │ -│ ├─ VectorStore FAISSVectorStore | QdrantStore* │ -│ └─ MetadataManager LocalCacheMetadataManager │ -└──────────▲─────────────────────────────────────────────┘ - │ 通过 CacheService(异步写回包装) - │ - FastAPI 请求线程 / pipeline -``` +### 启动和失败语义 -`*` 标注的 Fluxon / Qdrant 后端目前是 stub(`NotImplementedError`), -线上路径只有 `LocalFileKVStore` + `FAISSVectorStore` 两个分支。 +- 未传 `--enable-latent-cache` 时,不会加载 CacheSeek,也不会初始化 latent cache。 +- 传入 `--enable-latent-cache` 后,如果 CacheSeek 未安装或初始化失败,服务启动直接失败。 +- 单次请求中的 `build_query` / `lookup` / `apply_resume` / `save` 失败会记录 warning, + 并按无缓存路径继续。 From 38379f73238c91c54b40186a22967178df774663 Mon Sep 17 00:00:00 2001 From: jader Date: Wed, 1 Jul 2026 13:59:53 +0000 Subject: [PATCH 4/4] style: apply ruff format --- telefuser/service/api/api_server.py | 3 +-- telefuser/service/core/task_service.py | 3 +-- tools/deploy/docker_monitor.py | 4 ++-- tools/deploy/show_stat.py | 32 +++++++++++++------------- tools/viewer/weight_viewer.py | 6 ++--- 5 files changed, 23 insertions(+), 25 deletions(-) diff --git a/telefuser/service/api/api_server.py b/telefuser/service/api/api_server.py index 685d568..29f382d 100644 --- a/telefuser/service/api/api_server.py +++ b/telefuser/service/api/api_server.py @@ -274,8 +274,7 @@ def initialize_services( cache_service: Any | None = None, cache_adapter: Any | None = None, ) -> None: - """Initialize file and media services. - """ + """Initialize file and media services.""" self.file_service = FileService(cache_dir) self.inference_service = inference_service self.cache_service = cache_service diff --git a/telefuser/service/core/task_service.py b/telefuser/service/core/task_service.py index 6b991f1..97b1527 100644 --- a/telefuser/service/core/task_service.py +++ b/telefuser/service/core/task_service.py @@ -22,8 +22,7 @@ def _build_cache_task_request(task_data: dict) -> SimpleNamespace: - """Build a minimal task_request stub for the cache layer. - """ + """Build a minimal task_request stub for the cache layer.""" return SimpleNamespace( task_id=task_data.get("task_id"), task=task_data.get("task"), diff --git a/tools/deploy/docker_monitor.py b/tools/deploy/docker_monitor.py index 251e5f5..a79119f 100644 --- a/tools/deploy/docker_monitor.py +++ b/tools/deploy/docker_monitor.py @@ -754,9 +754,9 @@ def validate_sliding_average_calculation(): all_correct = True for i, (comp, exp) in enumerate(zip(computed, test_case["expected"])): if abs(comp - exp) < 0.01: - print(f" 步骤 {i+1}: ✓ 通过") + print(f" 步骤 {i + 1}: ✓ 通过") else: - print(f" 步骤 {i+1}: ✗ 失败 (计算: {comp}, 预期: {exp})") + print(f" 步骤 {i + 1}: ✗ 失败 (计算: {comp}, 预期: {exp})") all_correct = False if all_correct: diff --git a/tools/deploy/show_stat.py b/tools/deploy/show_stat.py index 29c20b4..1ab23f7 100644 --- a/tools/deploy/show_stat.py +++ b/tools/deploy/show_stat.py @@ -515,7 +515,7 @@ def print_peak_analysis(self, sliding_avg: list[float]): # 显示峰值信息 print(f"\n{self.colors['CYAN']}前5个峰值:{self.colors['NC']}") for i, (idx, value) in enumerate(peaks[:5]): - print(f" 峰值{i+1}: 索引={idx}, 值={value:.2f}MB") + print(f" 峰值{i + 1}: 索引={idx}, 值={value:.2f}MB") if len(peaks) > 5: print(f" ... 还有 {len(peaks) - 5} 个峰值") @@ -730,7 +730,7 @@ def visualize_matplotlib(self, csv_file: str, window_size: int = 10, output_file fit_line, "r--", linewidth=2, - label=f'linear fit (R²={fit_result["r_squared"]:.3f})', + label=f"linear fit (R²={fit_result['r_squared']:.3f})", ) else: fit_line = [fit_result["slope"] * x + fit_result["intercept"] for x in x_fit] @@ -739,7 +739,7 @@ def visualize_matplotlib(self, csv_file: str, window_size: int = 10, output_file fit_line, "r--", linewidth=2, - label=f'linear fit (R²={fit_result["r_squared"]:.3f})', + label=f"linear fit (R²={fit_result['r_squared']:.3f})", ) # 分析内存泄漏风险 @@ -750,12 +750,12 @@ def visualize_matplotlib(self, csv_file: str, window_size: int = 10, output_file # 添加文本框显示分析结果 textstr = "\n".join( [ - f'fit func: {fit_result["equation"]}', - f'ratio: {fit_result["slope"]:.6f}', - f'R²: {fit_result["r_squared"]:.4f}', - f'grow rate: {leak_analysis["slope_per_minute"]:.3f} MB/minute', - f'risk: {leak_analysis["risk_level"]}', - f'eval: {leak_analysis["description"]}', + f"fit func: {fit_result['equation']}", + f"ratio: {fit_result['slope']:.6f}", + f"R²: {fit_result['r_squared']:.4f}", + f"grow rate: {leak_analysis['slope_per_minute']:.3f} MB/minute", + f"risk: {leak_analysis['risk_level']}", + f"eval: {leak_analysis['description']}", ] ) @@ -799,7 +799,7 @@ def visualize_matplotlib(self, csv_file: str, window_size: int = 10, output_file peak_fit_line, "m-", linewidth=2, - label=f'peak fit (R²={peak_fit_result["r_squared"]:.3f})', + label=f"peak fit (R²={peak_fit_result['r_squared']:.3f})", ) else: # 使用索引进行峰值拟合 @@ -814,7 +814,7 @@ def visualize_matplotlib(self, csv_file: str, window_size: int = 10, output_file peak_fit_line, "m-", linewidth=2, - label=f'peak fit (R²={peak_fit_result["r_squared"]:.3f})', + label=f"peak fit (R²={peak_fit_result['r_squared']:.3f})", ) # 添加峰值拟合分析文本框 @@ -827,11 +827,11 @@ def visualize_matplotlib(self, csv_file: str, window_size: int = 10, output_file peak_textstr = "\n".join( [ - f'peak fit: {peak_fit_result["equation"]}', - f'peak ratio: {peak_fit_result["slope"]:.6f}', - f'peak R²: {peak_fit_result["r_squared"]:.4f}', - f'peak grow rate: {peak_leak_analysis["slope_per_minute"]:.3f} MB/minute', - f'peak risk: {peak_leak_analysis["risk_level"]}', + f"peak fit: {peak_fit_result['equation']}", + f"peak ratio: {peak_fit_result['slope']:.6f}", + f"peak R²: {peak_fit_result['r_squared']:.4f}", + f"peak grow rate: {peak_leak_analysis['slope_per_minute']:.3f} MB/minute", + f"peak risk: {peak_leak_analysis['risk_level']}", ] ) diff --git a/tools/viewer/weight_viewer.py b/tools/viewer/weight_viewer.py index 53cca7e..abcffa0 100644 --- a/tools/viewer/weight_viewer.py +++ b/tools/viewer/weight_viewer.py @@ -205,11 +205,11 @@ def traverse_stats(node, prefix=""): def _format_number(self, num: int) -> str: """Format number display""" if num >= 1e9: - return f"{num/1e9:.2f}B" + return f"{num / 1e9:.2f}B" elif num >= 1e6: - return f"{num/1e6:.2f}M" + return f"{num / 1e6:.2f}M" elif num >= 1e3: - return f"{num/1e3:.2f}K" + return f"{num / 1e3:.2f}K" else: return str(num)