From 285bb41596614d80911ecbaff5d7215df12b14d6 Mon Sep 17 00:00:00 2001 From: jader Date: Tue, 23 Jun 2026 07:53:53 +0000 Subject: [PATCH 1/8] feat(pipeline/lingbot): add profiling context for performance monitoring in runtime and service --- .../pipelines/lingbot_world_fast/pipeline.py | 100 ++++++++++-------- .../pipelines/lingbot_world_fast/service.py | 12 ++- 2 files changed, 64 insertions(+), 48 deletions(-) diff --git a/telefuser/pipelines/lingbot_world_fast/pipeline.py b/telefuser/pipelines/lingbot_world_fast/pipeline.py index 4e22735..713157d 100644 --- a/telefuser/pipelines/lingbot_world_fast/pipeline.py +++ b/telefuser/pipelines/lingbot_world_fast/pipeline.py @@ -20,6 +20,7 @@ from telefuser.schedulers.unipc import FlowUniPCMultistepScheduler from telefuser.utils.logging import logger from telefuser.utils.model_weight import load_state_dict +from telefuser.utils.profiler import ProfilingContext4Debug from .control import ( build_action_control_chunk, @@ -370,6 +371,7 @@ def build_control_override( ) return control.control_tensor + @ProfilingContext4Debug("create_runtime") @torch.inference_mode() def create_runtime( self, @@ -497,55 +499,59 @@ def generate_next_chunk( runtime.active = False return [] - idx = runtime.current_chunk_index - latent_chunk = runtime.noise_chunks[idx] - condition_chunk = runtime.condition_chunks[idx] - control_chunk = control_override - if control_chunk is None and runtime.control_chunks is not None and idx < len(runtime.control_chunks): - control_chunk = runtime.control_chunks[idx] - - self._notify_progress(progress_callback, "denoising_chunk", index=idx) - current_start = idx * runtime.chunk_size * runtime.frame_tokens - denoised = self.denoise_stage.denoise_chunk( - latent_chunk=latent_chunk, - condition_chunk=condition_chunk, - prompt_emb=runtime.prompt_emb, - timesteps=runtime.timesteps, - scheduler=runtime.scheduler, - control_chunk=control_chunk, - self_kv_cache=runtime.self_kv_cache, - crossattn_cache=runtime.crossattn_cache, - current_start=current_start, - max_attention_size=runtime.max_attention_size, - generator=runtime.generator, - ) + with ProfilingContext4Debug("generate_next_chunk"): + idx = runtime.current_chunk_index + latent_chunk = runtime.noise_chunks[idx] + condition_chunk = runtime.condition_chunks[idx] + control_chunk = control_override + if control_chunk is None and runtime.control_chunks is not None and idx < len(runtime.control_chunks): + control_chunk = runtime.control_chunks[idx] + + self._notify_progress(progress_callback, "denoising_chunk", index=idx) + current_start = idx * runtime.chunk_size * runtime.frame_tokens + with ProfilingContext4Debug("denoise_chunk"): + denoised = self.denoise_stage.denoise_chunk( + latent_chunk=latent_chunk, + condition_chunk=condition_chunk, + prompt_emb=runtime.prompt_emb, + timesteps=runtime.timesteps, + scheduler=runtime.scheduler, + control_chunk=control_chunk, + self_kv_cache=runtime.self_kv_cache, + crossattn_cache=runtime.crossattn_cache, + current_start=current_start, + max_attention_size=runtime.max_attention_size, + generator=runtime.generator, + ) - self._notify_progress(progress_callback, "updating_cache", index=idx) - self.dit( - x=denoised.to(dtype=self.torch_dtype), - timestep=torch.zeros((1,), dtype=torch.float32, device=self.device), - context=runtime.prompt_emb, - y=condition_chunk, - control_tensor=control_chunk, - kv_cache=runtime.self_kv_cache, - crossattn_cache=runtime.crossattn_cache, - current_start=current_start, - max_attention_size=runtime.max_attention_size, - ) + self._notify_progress(progress_callback, "updating_cache", index=idx) + with ProfilingContext4Debug("kv_cache_update_forward"): + self.dit( + x=denoised.to(dtype=self.torch_dtype), + timestep=torch.zeros((1,), dtype=torch.float32, device=self.device), + context=runtime.prompt_emb, + y=condition_chunk, + control_tensor=control_chunk, + kv_cache=runtime.self_kv_cache, + crossattn_cache=runtime.crossattn_cache, + current_start=current_start, + max_attention_size=runtime.max_attention_size, + ) - self._notify_progress(progress_callback, "decoding_chunk", index=idx, device=str(self.vae_device)) - frames = self.decode_video_cached( - denoised, - is_first_clip=(idx == 0), - is_last_clip=(idx == len(runtime.noise_chunks) - 1), - ) - images = self.tensor2video(frames) - self._notify_progress(progress_callback, "chunk_decoded", index=idx, frames=len(images)) - runtime.current_chunk_index += 1 - runtime.emitted_frames += len(images) - if runtime.current_chunk_index >= len(runtime.noise_chunks): - runtime.active = False - return images + self._notify_progress(progress_callback, "decoding_chunk", index=idx, device=str(self.vae_device)) + with ProfilingContext4Debug("vae_decode"): + frames = self.decode_video_cached( + denoised, + is_first_clip=(idx == 0), + is_last_clip=(idx == len(runtime.noise_chunks) - 1), + ) + images = self.tensor2video(frames) + self._notify_progress(progress_callback, "chunk_decoded", index=idx, frames=len(images)) + runtime.current_chunk_index += 1 + runtime.emitted_frames += len(images) + if runtime.current_chunk_index >= len(runtime.noise_chunks): + runtime.active = False + return images @staticmethod def encode_frames_to_b64(frames: list[Image.Image], quality: int = 85) -> list[str]: diff --git a/telefuser/pipelines/lingbot_world_fast/service.py b/telefuser/pipelines/lingbot_world_fast/service.py index 79c4523..1ea3a72 100644 --- a/telefuser/pipelines/lingbot_world_fast/service.py +++ b/telefuser/pipelines/lingbot_world_fast/service.py @@ -7,12 +7,13 @@ import threading import time import uuid -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable import torch from PIL import Image, ImageDraw from telefuser.utils.logging import logger +from telefuser.utils.profiler import ProfilingContext4Debug from .pipeline import LingBotWorldFastPipeline from .session import ( @@ -351,6 +352,15 @@ def emit_status(stage: str, **data: object) -> None: payload.update(data) self._put_output(state, payload) + with ProfilingContext4Debug("workloop"): + self._run_worker_loop(session_id, state, emit_status) + + def _run_worker_loop( + self, + session_id: str, + state: LingBotWorldFastSessionState, + emit_status: Callable[..., None], + ) -> None: try: self._emit_preview_frame(state) emit_status("initializing_runtime") From a122ea355d75e37ad3a8d3f77965dd4f47460908 Mon Sep 17 00:00:00 2001 From: jader Date: Thu, 25 Jun 2026 10:35:29 +0000 Subject: [PATCH 2/8] feat(pipeline): add local attention and sink size configuration options --- .../pipelines/lingbot_world_fast/pipeline.py | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/telefuser/pipelines/lingbot_world_fast/pipeline.py b/telefuser/pipelines/lingbot_world_fast/pipeline.py index 713157d..a06d804 100644 --- a/telefuser/pipelines/lingbot_world_fast/pipeline.py +++ b/telefuser/pipelines/lingbot_world_fast/pipeline.py @@ -46,6 +46,8 @@ class LingBotWorldFastPipelineConfig: orig_height: int = 480 orig_width: int = 832 max_area: int = 480 * 832 + local_attn_size: int = -1 + sink_size: int = 0 class LingBotWorldFastPipeline(BasePipeline): @@ -104,13 +106,34 @@ def init(self, module_manager, config: LingBotWorldFastPipelineConfig) -> None: str(fast_path), torch_dtype=config.dit_torch_dtype, control_type=config.control_type, - config={"patch_size": (1, 2, 2), "text_len": 512, "control_type": config.control_type}, + config=self._build_dit_config(config), ).to(self.device) self.dit.eval().requires_grad_(False) self.denoise_stage = LingBotWorldFastDenoisingStage(self.dit, torch_dtype=self.torch_dtype) self.timesteps = LingBotWorldFastTimesteps() + @staticmethod + def _build_dit_config(config: LingBotWorldFastPipelineConfig) -> dict[str, object]: + return { + "patch_size": (1, 2, 2), + "text_len": 512, + "control_type": config.control_type, + "local_attn_size": int(config.local_attn_size), + "sink_size": int(config.sink_size), + } + + @staticmethod + def _resolve_self_kv_size( + *, + frame_tokens: int, + latent_frames: int, + config: LingBotWorldFastPipelineConfig, + ) -> int: + if int(config.local_attn_size) > -1: + return int(frame_tokens) * int(config.local_attn_size) + return int(frame_tokens) * int(latent_frames) + @torch.inference_mode() def encode_prompt(self, prompt: str) -> torch.Tensor: ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True) @@ -417,7 +440,11 @@ def create_runtime( frame_num = (lat_f - 1) * 4 + 1 patch_area = self.dit.patch_size[1] * self.dit.patch_size[2] frame_tokens = (lat_h * lat_w) // patch_area - kv_size = frame_tokens * lat_f + kv_size = self._resolve_self_kv_size( + frame_tokens=frame_tokens, + latent_frames=lat_f, + config=self.config, + ) max_seq_len = session_config.chunk_size * frame_tokens max_attention_size = ( kv_size if session_config.max_attention_size is None else int(session_config.max_attention_size) From aba8f1c39f2139f084da12e07f807bd6f6e043d1 Mon Sep 17 00:00:00 2001 From: jader Date: Wed, 20 May 2026 03:08:07 +0000 Subject: [PATCH 3/8] feat(utils/profiler): Expose torch.profiler flags by environment variables --- telefuser/utils/profiler.py | 34 ++++++++++- tests/unit/utils/test_profiler_flags.py | 78 +++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 3 deletions(-) create mode 100644 tests/unit/utils/test_profiler_flags.py diff --git a/telefuser/utils/profiler.py b/telefuser/utils/profiler.py index bbd8c81..d2160a2 100644 --- a/telefuser/utils/profiler.py +++ b/telefuser/utils/profiler.py @@ -15,6 +15,10 @@ - Kernel categorization - Chrome trace visualization - Enable: ENABLE_PROFILER_NAMES=stage_name1,stage_name2 + - Optional torch.profiler flags: + TELEFUSER_TORCH_PROFILER_RECORD_SHAPES=true|false + TELEFUSER_TORCH_PROFILER_PROFILE_MEMORY=true|false + TELEFUSER_TORCH_PROFILER_WITH_STACK=true|false Layer 3 (External tool): - ncu deep kernel analysis @@ -178,6 +182,32 @@ def _should_enable_profiler(name: str) -> bool: return name in enabled_set +def _env_bool(name: str, default: bool) -> bool: + """Read a boolean env var while preserving default behavior on unset/invalid values.""" + value = os.getenv(name) + if value is None or value.strip() == "": + return default + normalized = value.strip().lower() + if normalized in ("1", "true", "yes", "on"): + return True + if normalized in ("0", "false", "no", "off"): + return False + logger.warning(f"[Profiler] Invalid boolean value for {name}={value!r}; using default {default}.") + return default + + +def _get_torch_profiler_options() -> dict[str, bool]: + """Get configurable torch.profiler options. + + Defaults intentionally match the historical TeleFuser profiler behavior. + """ + return { + "record_shapes": _env_bool("TELEFUSER_TORCH_PROFILER_RECORD_SHAPES", True), + "profile_memory": _env_bool("TELEFUSER_TORCH_PROFILER_PROFILE_MEMORY", True), + "with_stack": _env_bool("TELEFUSER_TORCH_PROFILER_WITH_STACK", True), + } + + def _get_timing_report_path() -> str | None: """Get Layer 1 timing report output path from env.""" return os.getenv("TELEFUSER_TIMING_REPORT") @@ -694,9 +724,7 @@ def __enter__(self): self._profiler = torch.profiler.profile( activities=activities, - record_shapes=True, - profile_memory=True, - with_stack=True, + **_get_torch_profiler_options(), ) self._profiler.start() logger.info(f"{self._rank_info} [Profiler] Started torch.profiler for '{self.name}'") diff --git a/tests/unit/utils/test_profiler_flags.py b/tests/unit/utils/test_profiler_flags.py new file mode 100644 index 0000000..d6e4b64 --- /dev/null +++ b/tests/unit/utils/test_profiler_flags.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from telefuser.utils import profiler as profiler_module + + +class _FakePlatform: + device_type = "cpu" + + def synchronize(self) -> None: + pass + + +class _FakeTorchProfiler: + def __init__(self, **kwargs) -> None: + self.kwargs = kwargs + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def export_chrome_trace(self, path: str) -> None: + self.trace_path = path + + def key_averages(self): + return [] + + +def test_torch_profiler_options_follow_env(monkeypatch, tmp_path) -> None: + profile_kwargs: list[dict] = [] + + def fake_profile(**kwargs): + profile_kwargs.append(kwargs) + return _FakeTorchProfiler(**kwargs) + + monkeypatch.setenv("ENABLE_PROFILER_NAMES", "outer") + monkeypatch.setenv("TELEFUSER_PROFILER_OUTPUT_DIR", str(tmp_path)) + monkeypatch.setenv("TELEFUSER_TORCH_PROFILER_RECORD_SHAPES", "false") + monkeypatch.setenv("TELEFUSER_TORCH_PROFILER_PROFILE_MEMORY", "false") + monkeypatch.setenv("TELEFUSER_TORCH_PROFILER_WITH_STACK", "false") + monkeypatch.setattr(profiler_module, "current_platform", _FakePlatform()) + monkeypatch.setattr(profiler_module.torch.profiler, "profile", fake_profile) + monkeypatch.setattr(profiler_module, "reset_peak_memory_stats", lambda: None) + monkeypatch.setattr(profiler_module, "capture_memory_snapshot", lambda: None) + + context = profiler_module.ProfilingContext("outer") + context.__enter__() + assert context._profiler is not None + context._profiler.stop() + + assert profile_kwargs[0]["record_shapes"] is False + assert profile_kwargs[0]["profile_memory"] is False + assert profile_kwargs[0]["with_stack"] is False + + +def test_torch_profiler_options_preserve_existing_defaults(monkeypatch, tmp_path) -> None: + profile_kwargs: list[dict] = [] + + def fake_profile(**kwargs): + profile_kwargs.append(kwargs) + return _FakeTorchProfiler(**kwargs) + + monkeypatch.setenv("ENABLE_PROFILER_NAMES", "outer") + monkeypatch.setenv("TELEFUSER_PROFILER_OUTPUT_DIR", str(tmp_path)) + monkeypatch.setattr(profiler_module, "current_platform", _FakePlatform()) + monkeypatch.setattr(profiler_module.torch.profiler, "profile", fake_profile) + monkeypatch.setattr(profiler_module, "reset_peak_memory_stats", lambda: None) + monkeypatch.setattr(profiler_module, "capture_memory_snapshot", lambda: None) + + context = profiler_module.ProfilingContext("outer") + context.__enter__() + assert context._profiler is not None + context._profiler.stop() + + assert profile_kwargs[0]["record_shapes"] is True + assert profile_kwargs[0]["profile_memory"] is True + assert profile_kwargs[0]["with_stack"] is True From 27ec05ac8e6918a8ab0c23e1144e3a5392aa5222 Mon Sep 17 00:00:00 2001 From: jader Date: Mon, 29 Jun 2026 09:57:32 +0000 Subject: [PATCH 4/8] chore(gitignore): ignore uv lock file --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 2f78881..e0b3fbe 100755 --- a/.gitignore +++ b/.gitignore @@ -93,6 +93,9 @@ ENV/ env.bak/ venv.bak/ +# uv +uv.lock + # Spyder project settings .spyderproject .spyproject From cc2e5f0046ec53db084579379dd3274970d0f63e Mon Sep 17 00:00:00 2001 From: jader Date: Mon, 29 Jun 2026 09:57:43 +0000 Subject: [PATCH 5/8] fix(package): handle missing generated version --- telefuser/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/telefuser/__init__.py b/telefuser/__init__.py index 8dee4bf..3a72a53 100644 --- a/telefuser/__init__.py +++ b/telefuser/__init__.py @@ -1 +1,6 @@ -from ._version import __version__ +try: + from ._version import __version__ +except ModuleNotFoundError as exc: + if exc.name != "telefuser._version": + raise + __version__ = "0.0.0+unknown" From 2d9c91c93f32ca2cb9af4917aeef585a145c0586 Mon Sep 17 00:00:00 2001 From: jader Date: Mon, 29 Jun 2026 09:57:50 +0000 Subject: [PATCH 6/8] perf(ops): use native layernorm in eager mode --- telefuser/ops/normalization.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/telefuser/ops/normalization.py b/telefuser/ops/normalization.py index 4f00d78..38c0ef0 100644 --- a/telefuser/ops/normalization.py +++ b/telefuser/ops/normalization.py @@ -160,9 +160,8 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: return self.forward_native(x) if not self.elementwise_affine: return self.forward_native(x) - - layer_norm_fn = _get_triton_kernel("layer_norm_fn") - return layer_norm_fn(x, self.weight, self.bias, eps=self.eps) + # Triton kernel in eager mode currently bring performance degradation + return self.forward_native(x) def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation for compile compatibility.""" From 161d1d55e561f652b05aca2ef410d30c36d19826 Mon Sep 17 00:00:00 2001 From: jader Date: Mon, 29 Jun 2026 09:58:10 +0000 Subject: [PATCH 7/8] perf(models): use channels-last 3d for conv inputs --- telefuser/models/lingbot_world_fast_dit.py | 1 + telefuser/models/wan_video_vae.py | 1 + 2 files changed, 2 insertions(+) diff --git a/telefuser/models/lingbot_world_fast_dit.py b/telefuser/models/lingbot_world_fast_dit.py index cc0fe93..0838a9a 100644 --- a/telefuser/models/lingbot_world_fast_dit.py +++ b/telefuser/models/lingbot_world_fast_dit.py @@ -434,6 +434,7 @@ def _prepare_control_tokens(self, control_tensor: torch.Tensor | None) -> torch. return control_tokens + hidden def patchify(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int, int]]: + x = x.contiguous(memory_format=torch.channels_last_3d) x = self.patch_embedding(x) grid_size = x.shape[2:] return rearrange(x, "b c f h w -> b (f h w) c").contiguous(), grid_size diff --git a/telefuser/models/wan_video_vae.py b/telefuser/models/wan_video_vae.py index 24abf18..51e9782 100644 --- a/telefuser/models/wan_video_vae.py +++ b/telefuser/models/wan_video_vae.py @@ -93,6 +93,7 @@ def forward(self, x: torch.Tensor, cache_x: torch.Tensor | None = None) -> torch x = torch.cat([cache_x, x], dim=2) padding[4] -= cache_x.shape[2] x = F.pad(x, padding) + x = x.contiguous(memory_format=torch.channels_last_3d) return super().forward(x) From c6a4024db68bcc8beb54ff57dcc55f6be81cf753 Mon Sep 17 00:00:00 2001 From: jader Date: Mon, 29 Jun 2026 09:58:42 +0000 Subject: [PATCH 8/8] perf(lingbot): keep kv cache indices on host --- telefuser/models/lingbot_world_fast_dit.py | 20 ++++++++++++------- .../pipelines/lingbot_world_fast/denoising.py | 2 +- .../pipelines/lingbot_world_fast/pipeline.py | 6 +++--- .../pipelines/lingbot_world_fast/session.py | 2 +- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/telefuser/models/lingbot_world_fast_dit.py b/telefuser/models/lingbot_world_fast_dit.py index 0838a9a..c9685cf 100644 --- a/telefuser/models/lingbot_world_fast_dit.py +++ b/telefuser/models/lingbot_world_fast_dit.py @@ -17,6 +17,12 @@ from .wan_video_dit import apply_rotary_emb, precompute_freqs_cis_3d, sinusoidal_embedding_1d +def _cache_index_to_int(value: int | torch.Tensor) -> int: + if isinstance(value, int): + return value + return int(value.item()) + + class CausalSelfAttention(nn.Module): """Causal self-attention with rolling KV cache support.""" @@ -90,7 +96,7 @@ def forward( freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, grid_size: tuple[int, int, int], - kv_cache: dict[str, torch.Tensor], + kv_cache: dict[str, torch.Tensor | int], current_start: int, max_attention_size: int, ) -> torch.Tensor: @@ -111,8 +117,8 @@ def forward( cache_k = kv_cache["k"] cache_v = kv_cache["v"] kv_cache_size = cache_k.shape[1] - global_end = int(kv_cache["global_end_index"].item()) - local_end = int(kv_cache["local_end_index"].item()) + global_end = _cache_index_to_int(kv_cache["global_end_index"]) + local_end = _cache_index_to_int(kv_cache["local_end_index"]) if self.local_attn_size != -1 and current_end > global_end and num_new_tokens + local_end > kv_cache_size: evicted = num_new_tokens + local_end - kv_cache_size @@ -141,8 +147,8 @@ def forward( out = F.scaled_dot_product_attention(q, k_cache, v_cache, is_causal=False) out = out.permute(0, 2, 1, 3).contiguous() - kv_cache["global_end_index"].fill_(current_end) - kv_cache["local_end_index"].fill_(local_end) + kv_cache["global_end_index"] = current_end + kv_cache["local_end_index"] = local_end out = rearrange(out, "b s n d -> b s (n d)") return self.o(out) @@ -241,7 +247,7 @@ def forward( freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, grid_size: tuple[int, int, int], - kv_cache: dict[str, torch.Tensor], + kv_cache: dict[str, torch.Tensor | int], crossattn_cache: dict[str, torch.Tensor | bool], current_start: int, max_attention_size: int, @@ -458,7 +464,7 @@ def forward( context: torch.Tensor, y: torch.Tensor | None = None, control_tensor: torch.Tensor | None = None, - kv_cache: list[dict[str, torch.Tensor]] | None = None, + kv_cache: list[dict[str, torch.Tensor | int]] | None = None, crossattn_cache: list[dict[str, torch.Tensor | bool]] | None = None, current_start: int = 0, max_attention_size: int = 1_000_000, diff --git a/telefuser/pipelines/lingbot_world_fast/denoising.py b/telefuser/pipelines/lingbot_world_fast/denoising.py index dfc1d72..ba5188c 100644 --- a/telefuser/pipelines/lingbot_world_fast/denoising.py +++ b/telefuser/pipelines/lingbot_world_fast/denoising.py @@ -52,7 +52,7 @@ def denoise_chunk( timesteps: torch.Tensor, scheduler: FlowUniPCMultistepScheduler, control_chunk: torch.Tensor | None, - self_kv_cache: list[dict[str, torch.Tensor]], + self_kv_cache: list[dict[str, torch.Tensor | int]], crossattn_cache: list[dict[str, torch.Tensor | bool]], current_start: int, max_attention_size: int, diff --git a/telefuser/pipelines/lingbot_world_fast/pipeline.py b/telefuser/pipelines/lingbot_world_fast/pipeline.py index a06d804..d011c4c 100644 --- a/telefuser/pipelines/lingbot_world_fast/pipeline.py +++ b/telefuser/pipelines/lingbot_world_fast/pipeline.py @@ -209,15 +209,15 @@ def _init_self_kv_cache( kv_size: int, dtype: torch.dtype, device: str | torch.device, - ) -> list[dict[str, torch.Tensor]]: + ) -> list[dict[str, torch.Tensor | int]]: head_dim = self.dit.dim // self.dit.num_heads shape = (batch_size, kv_size, self.dit.num_heads, head_dim) return [ { "k": torch.zeros(shape, dtype=dtype, device=device), "v": torch.zeros(shape, dtype=dtype, device=device), - "global_end_index": torch.tensor([0], dtype=torch.long, device=device), - "local_end_index": torch.tensor([0], dtype=torch.long, device=device), + "global_end_index": 0, + "local_end_index": 0, } for _ in range(self.dit.num_layers) ] diff --git a/telefuser/pipelines/lingbot_world_fast/session.py b/telefuser/pipelines/lingbot_world_fast/session.py index 9a468f4..6432ea2 100644 --- a/telefuser/pipelines/lingbot_world_fast/session.py +++ b/telefuser/pipelines/lingbot_world_fast/session.py @@ -42,7 +42,7 @@ class LingBotWorldFastRuntimeState: condition_chunks: list[torch.Tensor] control_chunks: list[torch.Tensor] | None timesteps: torch.Tensor - self_kv_cache: list[dict[str, torch.Tensor]] + self_kv_cache: list[dict[str, torch.Tensor | int]] crossattn_cache: list[dict[str, torch.Tensor | bool]] latent_h: int latent_w: int