Skip to content
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ ENV/
env.bak/
venv.bak/

# uv
uv.lock

# Spyder project settings
.spyderproject
.spyproject
Expand Down
7 changes: 6 additions & 1 deletion telefuser/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
from ._version import __version__
try:
from ._version import __version__
except ModuleNotFoundError as exc:
if exc.name != "telefuser._version":
raise
__version__ = "0.0.0+unknown"
21 changes: 14 additions & 7 deletions telefuser/models/lingbot_world_fast_dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
from .wan_video_dit import apply_rotary_emb, precompute_freqs_cis_3d, sinusoidal_embedding_1d


def _cache_index_to_int(value: int | torch.Tensor) -> int:
if isinstance(value, int):
return value
return int(value.item())


class CausalSelfAttention(nn.Module):
"""Causal self-attention with rolling KV cache support."""

Expand Down Expand Up @@ -90,7 +96,7 @@ def forward(
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
grid_size: tuple[int, int, int],
kv_cache: dict[str, torch.Tensor],
kv_cache: dict[str, torch.Tensor | int],
current_start: int,
max_attention_size: int,
) -> torch.Tensor:
Expand All @@ -111,8 +117,8 @@ def forward(
cache_k = kv_cache["k"]
cache_v = kv_cache["v"]
kv_cache_size = cache_k.shape[1]
global_end = int(kv_cache["global_end_index"].item())
local_end = int(kv_cache["local_end_index"].item())
global_end = _cache_index_to_int(kv_cache["global_end_index"])
local_end = _cache_index_to_int(kv_cache["local_end_index"])

if self.local_attn_size != -1 and current_end > global_end and num_new_tokens + local_end > kv_cache_size:
evicted = num_new_tokens + local_end - kv_cache_size
Expand Down Expand Up @@ -141,8 +147,8 @@ def forward(
out = F.scaled_dot_product_attention(q, k_cache, v_cache, is_causal=False)
out = out.permute(0, 2, 1, 3).contiguous()

kv_cache["global_end_index"].fill_(current_end)
kv_cache["local_end_index"].fill_(local_end)
kv_cache["global_end_index"] = current_end
kv_cache["local_end_index"] = local_end

out = rearrange(out, "b s n d -> b s (n d)")
return self.o(out)
Expand Down Expand Up @@ -241,7 +247,7 @@ def forward(
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
grid_size: tuple[int, int, int],
kv_cache: dict[str, torch.Tensor],
kv_cache: dict[str, torch.Tensor | int],
crossattn_cache: dict[str, torch.Tensor | bool],
current_start: int,
max_attention_size: int,
Expand Down Expand Up @@ -434,6 +440,7 @@ def _prepare_control_tokens(self, control_tensor: torch.Tensor | None) -> torch.
return control_tokens + hidden

def patchify(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int, int]]:
x = x.contiguous(memory_format=torch.channels_last_3d)
x = self.patch_embedding(x)
grid_size = x.shape[2:]
return rearrange(x, "b c f h w -> b (f h w) c").contiguous(), grid_size
Expand All @@ -457,7 +464,7 @@ def forward(
context: torch.Tensor,
y: torch.Tensor | None = None,
control_tensor: torch.Tensor | None = None,
kv_cache: list[dict[str, torch.Tensor]] | None = None,
kv_cache: list[dict[str, torch.Tensor | int]] | None = None,
crossattn_cache: list[dict[str, torch.Tensor | bool]] | None = None,
current_start: int = 0,
max_attention_size: int = 1_000_000,
Expand Down
1 change: 1 addition & 0 deletions telefuser/models/wan_video_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def forward(self, x: torch.Tensor, cache_x: torch.Tensor | None = None) -> torch
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
x = x.contiguous(memory_format=torch.channels_last_3d)
return super().forward(x)


Expand Down
5 changes: 2 additions & 3 deletions telefuser/ops/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,8 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)
if not self.elementwise_affine:
return self.forward_native(x)

layer_norm_fn = _get_triton_kernel("layer_norm_fn")
return layer_norm_fn(x, self.weight, self.bias, eps=self.eps)
# Triton kernel in eager mode currently bring performance degradation
return self.forward_native(x)

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation for compile compatibility."""
Expand Down
2 changes: 1 addition & 1 deletion telefuser/pipelines/lingbot_world_fast/denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def denoise_chunk(
timesteps: torch.Tensor,
scheduler: FlowUniPCMultistepScheduler,
control_chunk: torch.Tensor | None,
self_kv_cache: list[dict[str, torch.Tensor]],
self_kv_cache: list[dict[str, torch.Tensor | int]],
crossattn_cache: list[dict[str, torch.Tensor | bool]],
current_start: int,
max_attention_size: int,
Expand Down
137 changes: 85 additions & 52 deletions telefuser/pipelines/lingbot_world_fast/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from telefuser.schedulers.unipc import FlowUniPCMultistepScheduler
from telefuser.utils.logging import logger
from telefuser.utils.model_weight import load_state_dict
from telefuser.utils.profiler import ProfilingContext4Debug

from .control import (
build_action_control_chunk,
Expand All @@ -45,6 +46,8 @@ class LingBotWorldFastPipelineConfig:
orig_height: int = 480
orig_width: int = 832
max_area: int = 480 * 832
local_attn_size: int = -1
sink_size: int = 0


class LingBotWorldFastPipeline(BasePipeline):
Expand Down Expand Up @@ -103,13 +106,34 @@ def init(self, module_manager, config: LingBotWorldFastPipelineConfig) -> None:
str(fast_path),
torch_dtype=config.dit_torch_dtype,
control_type=config.control_type,
config={"patch_size": (1, 2, 2), "text_len": 512, "control_type": config.control_type},
config=self._build_dit_config(config),
).to(self.device)
self.dit.eval().requires_grad_(False)

self.denoise_stage = LingBotWorldFastDenoisingStage(self.dit, torch_dtype=self.torch_dtype)
self.timesteps = LingBotWorldFastTimesteps()

@staticmethod
def _build_dit_config(config: LingBotWorldFastPipelineConfig) -> dict[str, object]:
return {
"patch_size": (1, 2, 2),
"text_len": 512,
"control_type": config.control_type,
"local_attn_size": int(config.local_attn_size),
"sink_size": int(config.sink_size),
}

@staticmethod
def _resolve_self_kv_size(
*,
frame_tokens: int,
latent_frames: int,
config: LingBotWorldFastPipelineConfig,
) -> int:
if int(config.local_attn_size) > -1:
return int(frame_tokens) * int(config.local_attn_size)
return int(frame_tokens) * int(latent_frames)

@torch.inference_mode()
def encode_prompt(self, prompt: str) -> torch.Tensor:
ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
Expand Down Expand Up @@ -185,15 +209,15 @@ def _init_self_kv_cache(
kv_size: int,
dtype: torch.dtype,
device: str | torch.device,
) -> list[dict[str, torch.Tensor]]:
) -> list[dict[str, torch.Tensor | int]]:
head_dim = self.dit.dim // self.dit.num_heads
shape = (batch_size, kv_size, self.dit.num_heads, head_dim)
return [
{
"k": torch.zeros(shape, dtype=dtype, device=device),
"v": torch.zeros(shape, dtype=dtype, device=device),
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
"local_end_index": torch.tensor([0], dtype=torch.long, device=device),
"global_end_index": 0,
"local_end_index": 0,
}
for _ in range(self.dit.num_layers)
]
Expand Down Expand Up @@ -370,6 +394,7 @@ def build_control_override(
)
return control.control_tensor

@ProfilingContext4Debug("create_runtime")
@torch.inference_mode()
def create_runtime(
self,
Expand Down Expand Up @@ -415,7 +440,11 @@ def create_runtime(
frame_num = (lat_f - 1) * 4 + 1
patch_area = self.dit.patch_size[1] * self.dit.patch_size[2]
frame_tokens = (lat_h * lat_w) // patch_area
kv_size = frame_tokens * lat_f
kv_size = self._resolve_self_kv_size(
frame_tokens=frame_tokens,
latent_frames=lat_f,
config=self.config,
)
max_seq_len = session_config.chunk_size * frame_tokens
max_attention_size = (
kv_size if session_config.max_attention_size is None else int(session_config.max_attention_size)
Expand Down Expand Up @@ -497,55 +526,59 @@ def generate_next_chunk(
runtime.active = False
return []

idx = runtime.current_chunk_index
latent_chunk = runtime.noise_chunks[idx]
condition_chunk = runtime.condition_chunks[idx]
control_chunk = control_override
if control_chunk is None and runtime.control_chunks is not None and idx < len(runtime.control_chunks):
control_chunk = runtime.control_chunks[idx]

self._notify_progress(progress_callback, "denoising_chunk", index=idx)
current_start = idx * runtime.chunk_size * runtime.frame_tokens
denoised = self.denoise_stage.denoise_chunk(
latent_chunk=latent_chunk,
condition_chunk=condition_chunk,
prompt_emb=runtime.prompt_emb,
timesteps=runtime.timesteps,
scheduler=runtime.scheduler,
control_chunk=control_chunk,
self_kv_cache=runtime.self_kv_cache,
crossattn_cache=runtime.crossattn_cache,
current_start=current_start,
max_attention_size=runtime.max_attention_size,
generator=runtime.generator,
)
with ProfilingContext4Debug("generate_next_chunk"):
idx = runtime.current_chunk_index
latent_chunk = runtime.noise_chunks[idx]
condition_chunk = runtime.condition_chunks[idx]
control_chunk = control_override
if control_chunk is None and runtime.control_chunks is not None and idx < len(runtime.control_chunks):
control_chunk = runtime.control_chunks[idx]

self._notify_progress(progress_callback, "denoising_chunk", index=idx)
current_start = idx * runtime.chunk_size * runtime.frame_tokens
with ProfilingContext4Debug("denoise_chunk"):
denoised = self.denoise_stage.denoise_chunk(
latent_chunk=latent_chunk,
condition_chunk=condition_chunk,
prompt_emb=runtime.prompt_emb,
timesteps=runtime.timesteps,
scheduler=runtime.scheduler,
control_chunk=control_chunk,
self_kv_cache=runtime.self_kv_cache,
crossattn_cache=runtime.crossattn_cache,
current_start=current_start,
max_attention_size=runtime.max_attention_size,
generator=runtime.generator,
)

self._notify_progress(progress_callback, "updating_cache", index=idx)
self.dit(
x=denoised.to(dtype=self.torch_dtype),
timestep=torch.zeros((1,), dtype=torch.float32, device=self.device),
context=runtime.prompt_emb,
y=condition_chunk,
control_tensor=control_chunk,
kv_cache=runtime.self_kv_cache,
crossattn_cache=runtime.crossattn_cache,
current_start=current_start,
max_attention_size=runtime.max_attention_size,
)
self._notify_progress(progress_callback, "updating_cache", index=idx)
with ProfilingContext4Debug("kv_cache_update_forward"):
self.dit(
x=denoised.to(dtype=self.torch_dtype),
timestep=torch.zeros((1,), dtype=torch.float32, device=self.device),
context=runtime.prompt_emb,
y=condition_chunk,
control_tensor=control_chunk,
kv_cache=runtime.self_kv_cache,
crossattn_cache=runtime.crossattn_cache,
current_start=current_start,
max_attention_size=runtime.max_attention_size,
)

self._notify_progress(progress_callback, "decoding_chunk", index=idx, device=str(self.vae_device))
frames = self.decode_video_cached(
denoised,
is_first_clip=(idx == 0),
is_last_clip=(idx == len(runtime.noise_chunks) - 1),
)
images = self.tensor2video(frames)
self._notify_progress(progress_callback, "chunk_decoded", index=idx, frames=len(images))
runtime.current_chunk_index += 1
runtime.emitted_frames += len(images)
if runtime.current_chunk_index >= len(runtime.noise_chunks):
runtime.active = False
return images
self._notify_progress(progress_callback, "decoding_chunk", index=idx, device=str(self.vae_device))
with ProfilingContext4Debug("vae_decode"):
frames = self.decode_video_cached(
denoised,
is_first_clip=(idx == 0),
is_last_clip=(idx == len(runtime.noise_chunks) - 1),
)
images = self.tensor2video(frames)
self._notify_progress(progress_callback, "chunk_decoded", index=idx, frames=len(images))
runtime.current_chunk_index += 1
runtime.emitted_frames += len(images)
if runtime.current_chunk_index >= len(runtime.noise_chunks):
runtime.active = False
return images

@staticmethod
def encode_frames_to_b64(frames: list[Image.Image], quality: int = 85) -> list[str]:
Expand Down
12 changes: 11 additions & 1 deletion telefuser/pipelines/lingbot_world_fast/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import threading
import time
import uuid
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable

import torch
from PIL import Image, ImageDraw

from telefuser.utils.logging import logger
from telefuser.utils.profiler import ProfilingContext4Debug

from .pipeline import LingBotWorldFastPipeline
from .session import (
Expand Down Expand Up @@ -351,6 +352,15 @@ def emit_status(stage: str, **data: object) -> None:
payload.update(data)
self._put_output(state, payload)

with ProfilingContext4Debug("workloop"):
self._run_worker_loop(session_id, state, emit_status)

def _run_worker_loop(
self,
session_id: str,
state: LingBotWorldFastSessionState,
emit_status: Callable[..., None],
) -> None:
try:
self._emit_preview_frame(state)
emit_status("initializing_runtime")
Expand Down
2 changes: 1 addition & 1 deletion telefuser/pipelines/lingbot_world_fast/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class LingBotWorldFastRuntimeState:
condition_chunks: list[torch.Tensor]
control_chunks: list[torch.Tensor] | None
timesteps: torch.Tensor
self_kv_cache: list[dict[str, torch.Tensor]]
self_kv_cache: list[dict[str, torch.Tensor | int]]
crossattn_cache: list[dict[str, torch.Tensor | bool]]
latent_h: int
latent_w: int
Expand Down
Loading
Loading