From 2cc7e116ef9bb405ef4ed81db9a165784728aa2e Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 30 Jan 2026 08:16:16 +0100 Subject: [PATCH 01/25] LTX2 condition pipeline initial commit --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 8 +- src/diffusers/pipelines/ltx2/__init__.py | 2 + .../pipelines/ltx2/pipeline_ltx2_condition.py | 1332 +++++++++++++++++ 4 files changed, 1342 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 52ec30c536bd..c28c8011dc3d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -552,6 +552,7 @@ "LEditsPPPipelineStableDiffusionXL", "LongCatImageEditPipeline", "LongCatImagePipeline", + "LTX2ConditionPipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", "LTX2Pipeline", @@ -1284,6 +1285,7 @@ LEditsPPPipelineStableDiffusionXL, LongCatImageEditPipeline, LongCatImagePipeline, + LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 65378631a172..39af113bfade 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -291,7 +291,11 @@ "LTXLatentUpsamplePipeline", "LTXI2VLongMultiPromptPipeline", ] - _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"] + _import_structure["ltx2"] = [ + "LTX2Pipeline", + "LTX2ConditionPipelineLTX2ImageToVideoPipeline", + "LTX2LatentUpsamplePipeline", + ] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -742,7 +746,7 @@ LTXLatentUpsamplePipeline, LTXPipeline, ) - from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline + from .ltx2 import LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index 115e83e827a4..d6a408d5c546 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -25,6 +25,7 @@ _import_structure["connectors"] = ["LTX2TextConnectors"] _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"] _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] + _import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] _import_structure["vocoder"] = ["LTX2Vocoder"] @@ -40,6 +41,7 @@ from .connectors import LTX2TextConnectors from .latent_upsampler import LTX2LatentUpsamplerModel from .pipeline_ltx2 import LTX2Pipeline + from .pipeline_ltx2_condition import LTX2ConditionPipeline from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline from .vocoder import LTX2Vocoder diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py new file mode 100644 index 000000000000..f69353d3a253 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -0,0 +1,1332 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors +from .pipeline_output import LTX2PipelineOutput +from .vocoder import LTX2Vocoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2Pipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.utils import load_image + + >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> frame_rate = 24.0 + >>> video = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="np", + ... return_dict=False, + ... ) + >>> video = (video * 255).round().astype("uint8") + >>> video = torch.from_numpy(video) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +@dataclass +class LTXVideoCondition: + """ + Defines a single frame-conditioning item for LTX Video - a single frame or a sequence of frames. + + Attributes: + image (`PIL.Image.Image`): + The image to condition the video on. + video (`List[PIL.Image.Image]`): + The video to condition the video on. + frame_index (`int`): + The frame index at which the image or video will conditionally effect the video generation. + strength (`float`, defaults to `1.0`): + The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied. + """ + + image: Optional[PIL.Image.Image] = None + video: Optional[List[PIL.Image.Image]] = None + frame_index: int = 0 + strength: float = 1.0 + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Pipeline for video generation which allows image conditions to be inserted at arbitary parts of the video. + + Reference: https://github.com/Lightricks/LTX-Video + + TODO + """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + # TODO: check whether the MEL compression ratio logic here is corrct + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear") + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: Union[str, torch.device], + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. + """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state + def _create_noised_state( + latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents + def _pack_audio_latents( + latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None + ) -> torch.Tensor: + # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins + if patch_size is not None and patch_size_t is not None: + # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (a ndim=3 tnesor). + # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size. + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel) + # patch_size of M (all mel bins constitutes a single patch) and a patch_size_t of 1. + latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: Optional[int] = None, + patch_size_t: Optional[int] = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. + latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + def prepare_latents( + self, + image: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + noise_scale: float = 0.0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + mask_shape = (batch_size, 1, num_frames, height, width) + + if latents is not None: + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = self._create_noised_state(latents, noise_scale * (1 - conditioning_mask), generator) + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." + ) + return latents.to(device=device, dtype=dtype), conditioning_mask + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i], "argmax") + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") for img in image + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) + + # First condition is image latents and those should be kept clean. + conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) + conditioning_mask[:, :, 0] = 1.0 + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # Interpolation. + latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) + + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + return latents, conditioning_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + # TODO: confirm whether this logic is correct + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + noise_scale: float = 0.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + audio_latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `0.0`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + + if latents is None: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=prompt_embeds.dtype) + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + noise_scale, + torch.float32, + device, + generator, + latents, + ) + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." + ) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if self.guidance_rescale > 0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred_video = self._unpack_latents( + noise_pred_video, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + noise_pred_video = noise_pred_video[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) From 02c750b590a63e3f2b832894fe454d65b6d88019 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 3 Feb 2026 08:27:42 +0100 Subject: [PATCH 02/25] Fix pipeline import error --- src/diffusers/pipelines/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 39af113bfade..50454af74cd9 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -293,7 +293,8 @@ ] _import_structure["ltx2"] = [ "LTX2Pipeline", - "LTX2ConditionPipelineLTX2ImageToVideoPipeline", + "LTX2ConditionPipeline", + "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", ] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] From ed52c0d7cced0f85003dfe528383969af609323b Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 3 Feb 2026 09:48:38 +0100 Subject: [PATCH 03/25] Implement LTX-2-style general image conditioning --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 418 +++++++++++++----- 1 file changed, 311 insertions(+), 107 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index f69353d3a253..164d14569714 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -15,7 +15,7 @@ import copy import inspect from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image @@ -92,24 +92,22 @@ @dataclass -class LTXVideoCondition: +class LTX2VideoCondition: """ - Defines a single frame-conditioning item for LTX Video - a single frame or a sequence of frames. + Defines a single frame-conditioning item for LTX-2 Video - a single frame or a sequence of frames. Attributes: - image (`PIL.Image.Image`): - The image to condition the video on. - video (`List[PIL.Image.Image]`): - The video to condition the video on. - frame_index (`int`): - The frame index at which the image or video will conditionally effect the video generation. + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`): + The image (or video) to condition the video on. Accepts any type that can be handled by + VideoProcessor.preprocess_video. + index (`int`, defaults to `0`): + The index at which the image or video will conditionally affect the video generation. strength (`float`, defaults to `1.0`): The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied. """ - image: Optional[PIL.Image.Image] = None - video: Optional[List[PIL.Image.Image]] = None - frame_index: int = 0 + image: Union[PIL.Image.Image, List[PIL.Image.Image], np.ndarray, torch.Tensor] + index: int = 0 strength: float = 1.0 @@ -526,9 +524,13 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs def check_inputs( self, + conditions, + image, + video, + cond_index, + strength, prompt, height, width, @@ -580,6 +582,23 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." ) + if conditions is not None and (image is not None or video is not None): + raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.") + + if conditions is None: + if isinstance(image, list) and isinstance(cond_index, list) and len(image) != len(cond_index): + raise ValueError( + "If `conditions` is not provided, `image` and `cond_index` must be of the same length." + ) + elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength): + raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.") + elif isinstance(video, list) and isinstance(cond_index, list) and len(video) != len(cond_index): + raise ValueError( + "If `conditions` is not provided, `video` and `cond_index` must be of the same length." + ) + elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength): + raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.") + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: @@ -703,84 +722,248 @@ def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor latents_std = latents_std.to(latents.device, latents.dtype) return (latents * latents_std) + latents_mean + # Copied from diffusers.pipelines.ltx.pipeline_ltx_condition.LTXConditionPipeline.trim_conditioning_sequence + def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int) -> int: + """ + Trim a conditioning sequence to the allowed number of frames. + + Args: + start_frame (int): The target frame number of the first frame in the sequence. + sequence_num_frames (int): The number of frames in the sequence. + target_num_frames (int): The target number of frames in the generated video. + Returns: + int: updated sequence length + """ + scale_factor = self.vae_temporal_compression_ratio + num_frames = min(sequence_num_frames, target_num_frames - start_frame) + # Trim down to a multiple of temporal_scale_factor frames plus 1 + num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 + return num_frames + + def latent_idx_from_index(self, frame_idx: int, index_type: str = "latent") -> int: + if index_type == "latent": + latent_idx = frame_idx + else: + raise ValueError( + f"Got unsupported `index_type` {index_type}. Supported index types are `latent`." + ) + return latent_idx + + def preprocess_conditions( + self, + conditions: Optional[Union[LTX2VideoCondition, List[LTX2VideoCondition]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + device: Optional[torch.device] = None, + index_type: str = "latent", + ) -> Tuple[List[torch.Tensor], List[float], List[int]]: + """ + Preprocesses the condition images/videos to torch tensors. + + Args: + conditions (`LTX2VideoCondition` or `List[LTX2VideoCondition]`, *optional*, defaults to `None`): + A list of image/video condition instances. + height (`int`, *optional*, defaults to `512`): + The desired height in pixels. + width (`int`, *optional*, defaults to `768`): + The desired width in pixels. + num_frames (`int`, *optional*, defaults to `121`): + The desired number of frames in the generated video. + device (`torch.device`, *optional*, defaults to `None`): + The device on which to put the preprocessed image/video tensors. + + Returns: + `Tuple[List[torch.Tensor], List[float], List[int]]`: + Returns a 3-tuple of lists of length `len(conditions)` as follows: + 1. The first list is a list of preprocessed video tensors of shape [batch_size=1, num_channels, + num_frames, height, width]. + 2. The second list is a list of conditioning strengths. + 3. The third list is a list of indices in latent space to insert the corresponding condition. + """ + conditioning_frames, conditioning_strengths, conditioning_indices = [], [], [] + + if conditions is None: + conditions = [] + if isinstance(conditions, LTX2VideoCondition): + conditions = [conditions] + + frame_scale_factor = self.vae_temporal_compression_ratio + latent_num_frames = (num_frames - 1) // frame_scale_factor + 1 + for i, condition in enumerate(conditions): + if isinstance(condition.image, PIL.Image.Image): + # Single image, convert to List[PIL.Image.Image] + video_like_cond = [condition.image] + elif isinstance(condition.image, np.ndarray) and condition.image.ndim == 3: + # Image-like ndarray of shape (H, W, C), insert frame dim in first axis + video_like_cond = np.expand_dims(condition.image, axis=0) + elif isinstance(condition.image, torch.Tensor) and condition.image.ndim == 3: + # Image-like tensor of shape (C, H, W), insert frame dim in first dim + video_like_cond = condition.image.unsqueeze(0) + else: + # Treat all other as videos. Note that this means 4D ndarrays and tensors will be treated as videos of + # shape (F, H, W, C) and (F, C, H, W), respectively. + video_like_cond = condition.image + condition_pixels = self.video_processor.preprocess_video(video_like_cond, height, width) + + latent_start_idx = self.latent_idx_from_index(condition.index, index_type) + # Support negative latent indices (e.g. -1 for the last latent index) + if latent_start_idx < 0: + # latent_start_idx will be positive because latent_num_frames is positive + latent_start_idx = latent_start_idx % latent_num_frames + if latent_start_idx >= latent_num_frames: + logger.warning( + f"The starting latent index {latent_start_idx} of condition {i} is too big for the specified number" + f" of latent frames {latent_num_frames}. This condition will be skipped." + ) + continue + + cond_num_frames = condition_pixels.size(2) + start_idx = (latent_start_idx - 1) * frame_scale_factor + 1 + truncated_cond_frames = self.trim_conditioning_sequence(start_idx, cond_num_frames, num_frames) + condition_pixels = condition_pixels[:, :, :truncated_cond_frames] + + conditioning_frames.append(condition_pixels.to(dtype=self.vae.dtype, device=device)) + conditioning_strengths.append(condition.strength) + conditioning_indices.append(latent_start_idx) + + return conditioning_frames, conditioning_strengths, conditioning_indices + + def apply_visual_conditioning( + self, + latents: torch.Tensor, + conditioning_mask: torch.Tensor, + condition_latents: List[torch.Tensor], + condition_strengths: List[float], + condition_indices: List[int], + latent_height: int, + latent_width: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Applies visual conditioning frames to an initial latent. + + Args: + latents (`torch.Tensor`): + Initial packed (patchified) latents of shape [batch_size, patch_seq_len, hidden_dim]. + conditioning_mask (`torch.Tensor`, *optional*): + Initial packed (patchified) conditioning mask of shape [batch_size, patch_seq_len, 1] with + values in [0, 1] where 0 means that the denoising model output will be fully used and 1 means that the + condition will be fully used (with intermediate values specifying a blend of the denoised and latent + values). + + Returns: + `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: + Returns a 3-tuple of tensors where: + 1. The first element is the packed video latents (with unchanged shape [batch_size, patch_seq_len, + hidden_dim]) with the conditions applied + 2. The second element is the packed conditioning mask with conditioning strengths applied + 3. The third element holds the clean conditioning latents. + """ + # Latents-like tensor which holds the clean conditioning latents + clean_latents = torch.zeros_like(latents) + for cond, strength, latent_idx in zip(condition_latents, condition_strengths, condition_indices): + num_cond_tokens = cond.size(1) + start_token_idx = latent_idx * latent_height * latent_width + end_token_idx = start_token_idx + num_cond_tokens + + # Overwrite the portion of latents starting with start_token_idx with the condition + latents[:, start_token_idx:end_token_idx] = cond + conditioning_mask[:, start_token_idx:end_token_idx] = strength + clean_latents[:, start_token_idx:end_token_idx] = cond + + return latents, conditioning_mask, clean_latents + def prepare_latents( self, - image: Optional[torch.Tensor] = None, + conditions: Optional[Union[LTX2VideoCondition, List[LTX2VideoCondition]]] = None, batch_size: int = 1, num_channels_latents: int = 128, height: int = 512, - width: int = 704, - num_frames: int = 161, - noise_scale: float = 0.0, + width: int = 768, + num_frames: int = 121, + noise_scale: float = 1.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - height = height // self.vae_spatial_compression_ratio - width = width // self.vae_spatial_compression_ratio - num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 - shape = (batch_size, num_channels_latents, num_frames, height, width) - mask_shape = (batch_size, 1, num_frames, height, width) + shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width) + mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width) if latents is not None: - conditioning_mask = latents.new_zeros(mask_shape) - conditioning_mask[:, :, 0] = 1.0 + # Latents is either shape [B, F, C, H, W] or [B, seq_len, hidden_dim] if latents.ndim == 5: latents = self._normalize_latents( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) - latents = self._create_noised_state(latents, noise_scale * (1 - conditioning_mask), generator) - # latents are of shape [B, C, F, H, W], need to be packed - latents = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ) - conditioning_mask = self._pack_latents( - conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ).squeeze(-1) - if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: - raise ValueError( - f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." - ) - return latents.to(device=device, dtype=dtype), conditioning_mask - - if isinstance(generator, list): - if len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - init_latents = [ - retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i], "argmax") - for i in range(batch_size) - ] else: - init_latents = [ - retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") for img in image - ] + # NOTE: we set the initial latents to zeros rather a sample from the standard Gaussian prior because we + # will sample from the prior later once we have calculated the conditioning mask + latents = torch.zeros(shape, device=device, dtype=dtype) - init_latents = torch.cat(init_latents, dim=0).to(dtype) - init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) - init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) + if latents.ndim == 5: + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) # [B, seq_len, 1] + + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape[:2]: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape[:2] + (num_channels_latents,)}." + ) - # First condition is image latents and those should be kept clean. - conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) - conditioning_mask[:, :, 0] = 1.0 + if isinstance(generator, list): + logger.warning( + f"{self.__class__.__name__} does not support using a list of generators. The first generator in the" + f" list will be used for all (pseudo-)random operations." + ) + generator = generator[0] - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # Interpolation. - latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) + condition_frames, condition_strengths, condition_indices = self.preprocess_conditions( + conditions, height, width, num_frames, device=device + ) + # TODO: should we first concatenate all of the condition tensors and encode them all together? The advantage + # is that this would generally respect VAE settings like tiling, but a disadvantage is that this would by + # default take a lot of memory (for LTX 2, tiled encoding/decoding is almost always necessary). + condition_latents = [] + for condition_tensor in condition_frames: + condition_latent = retrieve_latents( + self.vae.encode(condition_tensor), generator=generator, sample_mode="argmax" + ) + condition_latent = self._normalize_latents( + condition_latent, self.vae.latents_mean, self.vae.latents_std + ).to(device=device, dtype=dtype) + condition_latent = self._pack_latents( + condition_latent, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + condition_latents.append(condition_latent) - conditioning_mask = self._pack_latents( - conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ).squeeze(-1) - latents = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + # NOTE: following the I2V pipeline, we return a conditioning mask. The original LTX 2 code uses a denoising + # mask, which is the inverse of the conditioning mask (`denoise_mask = 1 - conditioning_mask`) + latents, conditioning_mask, clean_latents = self.apply_visual_conditioning( + latents, + conditioning_mask, + condition_latents, + condition_strengths, + condition_indices, + latent_height=latent_height, + latent_width=latent_width, ) - return latents, conditioning_mask + # Sample from the standard Gaussian prior (or an intermediate Gaussian distribution if noise_scale < 1.0). + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + scaled_mask = (1.0 - conditioning_mask) * noise_scale + # Add noise to the `latents` so that it is at the noise level specified by `noise_scale`. + latents = noise * scaled_mask + latents * (1 - scaled_mask) + + return latents, conditioning_mask, clean_latents # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents def prepare_audio_latents( @@ -854,7 +1037,11 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - image: PipelineImageInput = None, + conditions: Union[LTX2VideoCondition, List[LTX2VideoCondition]] = None, + image: Union[PipelineImageInput, List[PipelineImageInput]] = None, + video: List[PipelineImageInput] = None, + cond_index: Union[int, List[int]] = 0, + strength: Union[float, List[float]] = 1.0, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, @@ -866,7 +1053,7 @@ def __call__( timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, - noise_scale: float = 0.0, + noise_scale: Optional[float] = None, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -888,8 +1075,14 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - image (`PipelineImageInput`): - The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + conditions (`List[LTXVideoCondition], *optional*`): + The list of frame-conditioning items for the video generation.If not provided, conditions will be + created using `image`, `video`, `frame_index` and `strength`. + image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): + The image or images to condition the video generation. If not provided, one has to pass `video` or + `conditions`. + video (`List[PipelineImageInput]`, *optional*): + The video to condition the video generation. If not provided, one has to pass `image` or `conditions`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -924,9 +1117,10 @@ def __call__( [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when using zero terminal SNR. - noise_scale (`float`, *optional*, defaults to `0.0`): + noise_scale (`float`, *optional*, defaults to `None`): The interpolation factor between random noise and denoised latents at each timestep. Applying noise to - the `latents` and `audio_latents` before continue denoising. + the `latents` and `audio_latents` before continue denoising. If not set, will be inferred from the + sigma schedule. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -988,6 +1182,11 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( + conditions=conditions, + image=image, + video=video, + cond_index=cond_index, + strength=strength, prompt=prompt, height=height, width=width, @@ -1012,6 +1211,30 @@ def __call__( else: batch_size = prompt_embeds.shape[0] + if conditions is not None: + if not isinstance(conditions, list): + conditions = [conditions] + elif video is not None: + if isinstance(video, list): + # Interpret as a list of conditions + conditions = [] + for v_cond, idx, strength in zip(video, cond_index, strength): + conditions.append(LTX2VideoCondition(image=v_cond, index=idx, strength=strength)) + else: + # Interpret as single condition (cond_index and strength are also assumed to not be lists) + conditions = [LTX2VideoCondition(image=video, index=cond_index, strength=strength)] + elif image is not None: + if isinstance(image, list): + conditions = [] + for i_cond, idx, strength in zip(video, cond_index, strength): + conditions.append(LTX2VideoCondition(image=i_cond, index=idx, strength=strength)) + else: + conditions = [LTX2VideoCondition(image=video, index=cond_index, strength=strength)] + + # Infer noise scale: first (largest) sigma value if using custom sigmas, else 1.0 + if noise_scale is None: + noise_scale = sigmas[0] if sigmas is not None else 1.0 + device = self._execution_device # 3. Prepare text embeddings @@ -1062,13 +1285,9 @@ def __call__( ) video_sequence_length = latent_num_frames * latent_height * latent_width - if latents is None: - image = self.video_processor.preprocess(image, height=height, width=width) - image = image.to(device=device, dtype=prompt_embeds.dtype) - num_channels_latents = self.transformer.config.in_channels - latents, conditioning_mask = self.prepare_latents( - image, + latents, conditioning_mask, clean_latents = self.prepare_latents( + conditions, batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -1182,7 +1401,7 @@ def __call__( audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) timestep = t.expand(latent_model_input.shape[0]) - video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask.squeeze(-1)) with self.transformer.cache_context("cond_uncond"): noise_pred_video, noise_pred_audio = self.transformer( @@ -1228,32 +1447,17 @@ def __call__( noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale ) - # compute the previous noisy sample x_t -> x_t-1 - noise_pred_video = self._unpack_latents( - noise_pred_video, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - latents = self._unpack_latents( - latents, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - - noise_pred_video = noise_pred_video[:, :, 1:] - noise_latents = latents[:, :, 1:] - pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] - - latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) - latents = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ) + # Apply the (packed) conditioning mask to the denoising model output, which will blend the conditions + # with the denoised output according to the conditioning strength (a strength of 1.0 means we fully + # overwrite the denoised output with the condition) + # NOTE: use only the first chunk of conditioning mask in case it is duplicated for CFG + bsz = noise_pred_video.size(0) + denoised_latents_with_cond = ( + noise_pred_video * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz] + ).to(noise_pred_video.dtype) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(denoised_latents_with_cond, t, latents, return_dict=False)[0] # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in # the step method (such as _step_index) From 5368d73f7eae441ae7bc37dbee36531c0851df92 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 4 Feb 2026 05:47:38 +0100 Subject: [PATCH 04/25] Blend denoising output and clean latents in sample space instead of velocity space --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 164d14569714..4ea5f17cb9c0 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -819,7 +819,7 @@ def preprocess_conditions( continue cond_num_frames = condition_pixels.size(2) - start_idx = (latent_start_idx - 1) * frame_scale_factor + 1 + start_idx = max((latent_start_idx - 1) * frame_scale_factor + 1, 0) truncated_cond_frames = self.trim_conditioning_sequence(start_idx, cond_num_frames, num_frames) condition_pixels = condition_pixels[:, :, :truncated_cond_frames] @@ -905,11 +905,11 @@ def prepare_latents( # will sample from the prior later once we have calculated the conditioning mask latents = torch.zeros(shape, device=device, dtype=dtype) + conditioning_mask = latents.new_zeros(mask_shape) if latents.ndim == 5: latents = self._pack_latents( latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) - conditioning_mask = latents.new_zeros(mask_shape) conditioning_mask = self._pack_latents( conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) # [B, seq_len, 1] @@ -1447,17 +1447,22 @@ def __call__( noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale ) - # Apply the (packed) conditioning mask to the denoising model output, which will blend the conditions - # with the denoised output according to the conditioning strength (a strength of 1.0 means we fully - # overwrite the denoised output with the condition) # NOTE: use only the first chunk of conditioning mask in case it is duplicated for CFG bsz = noise_pred_video.size(0) - denoised_latents_with_cond = ( - noise_pred_video * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz] + sigma = self.scheduler.sigmas[i] + # Convert the noise_pred_video velocity model prediction into a sample (x0) prediction + denoised_sample = latents - noise_pred_video * sigma + # Apply the (packed) conditioning mask to the denoised (x0) sample, which will blend the conditions + # with the denoised sample according to the conditioning strength (a strength of 1.0 means we fully + # overwrite the denoised sample with the condition) + denoised_sample_cond = ( + denoised_sample * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz] ).to(noise_pred_video.dtype) + # Convert the denoised (x0) sample back to a velocity for the scheduler + denoised_latents_cond = ((latents - denoised_sample_cond) / sigma).to(noise_pred_video.dtype) # Compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(denoised_latents_with_cond, t, latents, return_dict=False)[0] + latents = self.scheduler.step(denoised_latents_cond, t, latents, return_dict=False)[0] # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in # the step method (such as _step_index) From 5577e08433ae544e7d38c407b5418a2ddc8aa9e5 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 4 Feb 2026 06:48:36 +0100 Subject: [PATCH 05/25] make style and make quality --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 4ea5f17cb9c0..0a15b2f17789 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -744,9 +744,7 @@ def latent_idx_from_index(self, frame_idx: int, index_type: str = "latent") -> i if index_type == "latent": latent_idx = frame_idx else: - raise ValueError( - f"Got unsupported `index_type` {index_type}. Supported index types are `latent`." - ) + raise ValueError(f"Got unsupported `index_type` {index_type}. Supported index types are `latent`.") return latent_idx def preprocess_conditions( @@ -846,10 +844,9 @@ def apply_visual_conditioning( latents (`torch.Tensor`): Initial packed (patchified) latents of shape [batch_size, patch_seq_len, hidden_dim]. conditioning_mask (`torch.Tensor`, *optional*): - Initial packed (patchified) conditioning mask of shape [batch_size, patch_seq_len, 1] with - values in [0, 1] where 0 means that the denoising model output will be fully used and 1 means that the - condition will be fully used (with intermediate values specifying a blend of the denoised and latent - values). + Initial packed (patchified) conditioning mask of shape [batch_size, patch_seq_len, 1] with values in + [0, 1] where 0 means that the denoising model output will be fully used and 1 means that the condition + will be fully used (with intermediate values specifying a blend of the denoised and latent values). Returns: `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: From e0bd6a07f76b3c14a0b0f33e351cedeff8bdae51 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 4 Feb 2026 06:52:04 +0100 Subject: [PATCH 06/25] make fix-copies --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index a23f852616c0..3a86ec0da169 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2012,6 +2012,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTX2ConditionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTX2ImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 45051e18f5057487d7788e68b92e1abbf6d8f961 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 5 Feb 2026 02:05:31 +0100 Subject: [PATCH 07/25] Rename LTX2VideoCondition image to frames --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 0a15b2f17789..427ab9632a3e 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -97,7 +97,7 @@ class LTX2VideoCondition: Defines a single frame-conditioning item for LTX-2 Video - a single frame or a sequence of frames. Attributes: - image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`): + frames (`PIL.Image.Image` or `List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`): The image (or video) to condition the video on. Accepts any type that can be handled by VideoProcessor.preprocess_video. index (`int`, defaults to `0`): @@ -106,7 +106,7 @@ class LTX2VideoCondition: The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied. """ - image: Union[PIL.Image.Image, List[PIL.Image.Image], np.ndarray, torch.Tensor] + frames: Union[PIL.Image.Image, List[PIL.Image.Image], np.ndarray, torch.Tensor] index: int = 0 strength: float = 1.0 @@ -789,19 +789,19 @@ def preprocess_conditions( frame_scale_factor = self.vae_temporal_compression_ratio latent_num_frames = (num_frames - 1) // frame_scale_factor + 1 for i, condition in enumerate(conditions): - if isinstance(condition.image, PIL.Image.Image): + if isinstance(condition.frames, PIL.Image.Image): # Single image, convert to List[PIL.Image.Image] - video_like_cond = [condition.image] - elif isinstance(condition.image, np.ndarray) and condition.image.ndim == 3: + video_like_cond = [condition.frames] + elif isinstance(condition.frames, np.ndarray) and condition.frames.ndim == 3: # Image-like ndarray of shape (H, W, C), insert frame dim in first axis - video_like_cond = np.expand_dims(condition.image, axis=0) - elif isinstance(condition.image, torch.Tensor) and condition.image.ndim == 3: + video_like_cond = np.expand_dims(condition.frames, axis=0) + elif isinstance(condition.frames, torch.Tensor) and condition.frames.ndim == 3: # Image-like tensor of shape (C, H, W), insert frame dim in first dim - video_like_cond = condition.image.unsqueeze(0) + video_like_cond = condition.frames.unsqueeze(0) else: # Treat all other as videos. Note that this means 4D ndarrays and tensors will be treated as videos of # shape (F, H, W, C) and (F, C, H, W), respectively. - video_like_cond = condition.image + video_like_cond = condition.frames condition_pixels = self.video_processor.preprocess_video(video_like_cond, height, width) latent_start_idx = self.latent_idx_from_index(condition.index, index_type) @@ -1216,17 +1216,17 @@ def __call__( # Interpret as a list of conditions conditions = [] for v_cond, idx, strength in zip(video, cond_index, strength): - conditions.append(LTX2VideoCondition(image=v_cond, index=idx, strength=strength)) + conditions.append(LTX2VideoCondition(frames=v_cond, index=idx, strength=strength)) else: # Interpret as single condition (cond_index and strength are also assumed to not be lists) - conditions = [LTX2VideoCondition(image=video, index=cond_index, strength=strength)] + conditions = [LTX2VideoCondition(frames=video, index=cond_index, strength=strength)] elif image is not None: if isinstance(image, list): conditions = [] for i_cond, idx, strength in zip(video, cond_index, strength): - conditions.append(LTX2VideoCondition(image=i_cond, index=idx, strength=strength)) + conditions.append(LTX2VideoCondition(frames=i_cond, index=idx, strength=strength)) else: - conditions = [LTX2VideoCondition(image=video, index=cond_index, strength=strength)] + conditions = [LTX2VideoCondition(frames=video, index=cond_index, strength=strength)] # Infer noise scale: first (largest) sigma value if using custom sigmas, else 1.0 if noise_scale is None: From d39d89f0d6a266e56f7b7a7c37861236d14e4ac1 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 5 Feb 2026 02:16:01 +0100 Subject: [PATCH 08/25] Update LTX2ConditionPipeline example --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 427ab9632a3e..364b448302f3 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -50,22 +50,29 @@ Examples: ```py >>> import torch - >>> from diffusers import LTX2Pipeline + >>> from diffusers import LTX2ConditionPipeline >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition >>> from diffusers.utils import load_image - >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe = LTX2ConditionPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) >>> pipe.enable_model_cpu_offload() - >>> image = load_image( - ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + >>> first_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png" ... ) - >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." - >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + >>> last_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png" + ... ) + >>> first_cond = LTX2VideoCondition(frames=first_image, index=0, strength=1.0) + >>> last_cond = LTX2VideoCondition(frames=last_image, index=-1, strength=1.0) + >>> conditions = [first_cond, last_cond] + >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static" >>> frame_rate = 24.0 >>> video = pipe( - ... image=image, + ... conditions=conditions, ... prompt=prompt, ... negative_prompt=negative_prompt, ... width=768, From 2e824f561a8027e4c0159b19b146a093d9e52164 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 5 Feb 2026 06:30:42 +0100 Subject: [PATCH 09/25] Remove support for image and video in __call__ --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 61 +------------------ 1 file changed, 3 insertions(+), 58 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 364b448302f3..7ee3740f13a6 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -23,7 +23,6 @@ from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import PipelineImageInput from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video from ...models.transformers import LTX2VideoTransformer3DModel @@ -533,11 +532,6 @@ def encode_prompt( def check_inputs( self, - conditions, - image, - video, - cond_index, - strength, prompt, height, width, @@ -589,23 +583,6 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." ) - if conditions is not None and (image is not None or video is not None): - raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.") - - if conditions is None: - if isinstance(image, list) and isinstance(cond_index, list) and len(image) != len(cond_index): - raise ValueError( - "If `conditions` is not provided, `image` and `cond_index` must be of the same length." - ) - elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength): - raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.") - elif isinstance(video, list) and isinstance(cond_index, list) and len(video) != len(cond_index): - raise ValueError( - "If `conditions` is not provided, `video` and `cond_index` must be of the same length." - ) - elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength): - raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.") - @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: @@ -1042,10 +1019,6 @@ def interrupt(self): def __call__( self, conditions: Union[LTX2VideoCondition, List[LTX2VideoCondition]] = None, - image: Union[PipelineImageInput, List[PipelineImageInput]] = None, - video: List[PipelineImageInput] = None, - cond_index: Union[int, List[int]] = 0, - strength: Union[float, List[float]] = 1.0, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, @@ -1080,13 +1053,7 @@ def __call__( Args: conditions (`List[LTXVideoCondition], *optional*`): - The list of frame-conditioning items for the video generation.If not provided, conditions will be - created using `image`, `video`, `frame_index` and `strength`. - image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): - The image or images to condition the video generation. If not provided, one has to pass `video` or - `conditions`. - video (`List[PipelineImageInput]`, *optional*): - The video to condition the video generation. If not provided, one has to pass `image` or `conditions`. + The list of frame-conditioning items for the video generation. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -1186,11 +1153,6 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( - conditions=conditions, - image=image, - video=video, - cond_index=cond_index, - strength=strength, prompt=prompt, height=height, width=width, @@ -1215,25 +1177,8 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - if conditions is not None: - if not isinstance(conditions, list): - conditions = [conditions] - elif video is not None: - if isinstance(video, list): - # Interpret as a list of conditions - conditions = [] - for v_cond, idx, strength in zip(video, cond_index, strength): - conditions.append(LTX2VideoCondition(frames=v_cond, index=idx, strength=strength)) - else: - # Interpret as single condition (cond_index and strength are also assumed to not be lists) - conditions = [LTX2VideoCondition(frames=video, index=cond_index, strength=strength)] - elif image is not None: - if isinstance(image, list): - conditions = [] - for i_cond, idx, strength in zip(video, cond_index, strength): - conditions.append(LTX2VideoCondition(frames=i_cond, index=idx, strength=strength)) - else: - conditions = [LTX2VideoCondition(frames=video, index=cond_index, strength=strength)] + if conditions is not None and not isinstance(conditions, list): + conditions = [conditions] # Infer noise scale: first (largest) sigma value if using custom sigmas, else 1.0 if noise_scale is None: From 33e6ec1f85af8854502ee451bab3906891df788f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 5 Feb 2026 06:37:56 +0100 Subject: [PATCH 10/25] Put latent_idx_from_index logic inline --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 7ee3740f13a6..9ba86a141dd4 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -724,13 +724,6 @@ def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 return num_frames - def latent_idx_from_index(self, frame_idx: int, index_type: str = "latent") -> int: - if index_type == "latent": - latent_idx = frame_idx - else: - raise ValueError(f"Got unsupported `index_type` {index_type}. Supported index types are `latent`.") - return latent_idx - def preprocess_conditions( self, conditions: Optional[Union[LTX2VideoCondition, List[LTX2VideoCondition]]] = None, @@ -788,7 +781,8 @@ def preprocess_conditions( video_like_cond = condition.frames condition_pixels = self.video_processor.preprocess_video(video_like_cond, height, width) - latent_start_idx = self.latent_idx_from_index(condition.index, index_type) + # Interpret the index as a latent index, following the original LTX-2 code. + latent_start_idx = condition.index # Support negative latent indices (e.g. -1 for the last latent index) if latent_start_idx < 0: # latent_start_idx will be positive because latent_num_frames is positive From 98f74b2fe4b1d4a1440ce42ec0bc42b47e0d65a8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 5 Feb 2026 06:48:27 +0100 Subject: [PATCH 11/25] Improve comment on using the conditioning mask in denoising loop --- src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 9ba86a141dd4..432c365477fb 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -1395,9 +1395,9 @@ def __call__( sigma = self.scheduler.sigmas[i] # Convert the noise_pred_video velocity model prediction into a sample (x0) prediction denoised_sample = latents - noise_pred_video * sigma - # Apply the (packed) conditioning mask to the denoised (x0) sample, which will blend the conditions - # with the denoised sample according to the conditioning strength (a strength of 1.0 means we fully - # overwrite the denoised sample with the condition) + # Apply the (packed) conditioning mask to the denoised (x0) sample and clean conditioning. The + # conditioning mask contains conditioning strengths from 0 (always use denoised sample) to 1 (always + # use conditions), with intermediate values specifying how strongly to follow the conditions. denoised_sample_cond = ( denoised_sample * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz] ).to(noise_pred_video.dtype) From 83c8ae6b2968620702ccd1e20f03c3304937bae9 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Fri, 13 Feb 2026 19:15:08 -0800 Subject: [PATCH 12/25] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 432c365477fb..f886354e3330 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -23,7 +23,7 @@ from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video from ...models.transformers import LTX2VideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -232,7 +232,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg -class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): +class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): r""" Pipeline for video generation which allows image conditions to be inserted at arbitary parts of the video. From 1cdea99b8bd40231ee93466fa2c19119e6411b31 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 14 Feb 2026 04:27:05 +0100 Subject: [PATCH 13/25] make fix-copies --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index f886354e3330..58d65ba7647f 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -119,7 +119,7 @@ class LTX2VideoCondition: # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" ): if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) @@ -148,10 +148,10 @@ def calculate_shift( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, **kwargs, ): r""" @@ -166,15 +166,15 @@ def retrieve_timesteps( must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): + timesteps (`list[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): + sigmas (`list[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: @@ -306,7 +306,7 @@ def __init__( def _pack_text_embeds( text_hidden_states: torch.Tensor, sequence_lengths: torch.Tensor, - device: Union[str, torch.device], + device: str | torch.device, padding_side: str = "left", scale_factor: int = 8, eps: float = 1e-6, @@ -372,18 +372,18 @@ def _pack_text_embeds( # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds def _get_gemma_prompt_embeds( self, - prompt: Union[str, List[str]], + prompt: str | list[str], num_videos_per_prompt: int = 1, max_sequence_length: int = 1024, scale_factor: int = 8, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `list[str]`, *optional*): prompt to be encoded device: (`str` or `torch.device`): torch device to place the resulting embeddings on @@ -446,26 +446,26 @@ def _get_gemma_prompt_embeds( # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt def encode_prompt( self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - prompt_attention_mask: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, scale_factor: int = 8, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `list[str]`, *optional*): prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): + negative_prompt (`str` or `list[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). @@ -644,7 +644,7 @@ def _denormalize_latents( @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state def _create_noised_state( - latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None + latents: torch.Tensor, noise_scale: float | torch.Tensor, generator: torch.Generator | None = None ): noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) noised_latents = noise_scale * noise + (1 - noise_scale) * latents @@ -653,7 +653,7 @@ def _create_noised_state( @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents def _pack_audio_latents( - latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None ) -> torch.Tensor: # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins if patch_size is not None and patch_size_t is not None: @@ -678,8 +678,8 @@ def _unpack_audio_latents( latents: torch.Tensor, latent_length: int, num_mel_bins: int, - patch_size: Optional[int] = None, - patch_size_t: Optional[int] = None, + patch_size: int | None = None, + patch_size_t: int | None = None, ) -> torch.Tensor: # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], # where L is the latent audio length and M is the number of mel bins. @@ -948,10 +948,10 @@ def prepare_audio_latents( audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, noise_scale: float = 0.0, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.Tensor] = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, ) -> torch.Tensor: if latents is not None: if latents.ndim == 4: From 1c120c6ad9ba34bf4abfc7b085952c60f744b7cd Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 14 Feb 2026 04:50:35 +0100 Subject: [PATCH 14/25] Migrate to Python 3.9+ style type annotations without explicit typing imports --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 72 +++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 58d65ba7647f..b528c2a69e47 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -15,7 +15,7 @@ import copy import inspect from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable import numpy as np import PIL.Image @@ -112,7 +112,7 @@ class LTX2VideoCondition: The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied. """ - frames: Union[PIL.Image.Image, List[PIL.Image.Image], np.ndarray, torch.Tensor] + frames: PIL.Image.Image | list[PIL.Image.Image] | np.ndarray | torch.Tensor index: int = 0 strength: float = 1.0 @@ -251,7 +251,7 @@ def __init__( vae: AutoencoderKLLTX2Video, audio_vae: AutoencoderKLLTX2Audio, text_encoder: Gemma3ForConditionalGeneration, - tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + tokenizer: GemmaTokenizer | GemmaTokenizerFast, connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, vocoder: LTX2Vocoder, @@ -726,13 +726,13 @@ def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, def preprocess_conditions( self, - conditions: Optional[Union[LTX2VideoCondition, List[LTX2VideoCondition]]] = None, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, height: int = 512, width: int = 768, num_frames: int = 121, - device: Optional[torch.device] = None, + device: torch.device | None = None, index_type: str = "latent", - ) -> Tuple[List[torch.Tensor], List[float], List[int]]: + ) -> tuple[list[torch.Tensor], list[float], list[int]]: """ Preprocesses the condition images/videos to torch tensors. @@ -809,12 +809,12 @@ def apply_visual_conditioning( self, latents: torch.Tensor, conditioning_mask: torch.Tensor, - condition_latents: List[torch.Tensor], - condition_strengths: List[float], - condition_indices: List[int], + condition_latents: list[torch.Tensor], + condition_strengths: list[float], + condition_indices: list[int], latent_height: int, latent_width: int, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Applies visual conditioning frames to an initial latent. @@ -850,18 +850,18 @@ def apply_visual_conditioning( def prepare_latents( self, - conditions: Optional[Union[LTX2VideoCondition, List[LTX2VideoCondition]]] = None, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, batch_size: int = 1, num_channels_latents: int = 128, height: int = 512, width: int = 768, num_frames: int = 121, noise_scale: float = 1.0, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 @@ -1012,34 +1012,34 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - conditions: Union[LTX2VideoCondition, List[LTX2VideoCondition]] = None, - prompt: Union[str, List[str]] = None, - negative_prompt: Optional[Union[str, List[str]]] = None, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, height: int = 512, width: int = 768, num_frames: int = 121, frame_rate: float = 24.0, num_inference_steps: int = 40, - sigmas: Optional[List[float]] = None, - timesteps: List[int] = None, + sigmas: list[float] | None = None, + timesteps: list[float] | None = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, - noise_scale: Optional[float] = None, - num_videos_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - audio_latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_attention_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - decode_timestep: Union[float, List[float]] = 0.0, - decode_noise_scale: Optional[Union[float, List[float]]] = None, - output_type: Optional[str] = "pil", + noise_scale: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str = "pil", return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], max_sequence_length: int = 1024, ): r""" From ca931c641622676c4bc94e9bac4040cf04ff6618 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 16 Feb 2026 01:17:52 +0100 Subject: [PATCH 15/25] Forward kwargs from preprocess/postprocess_video to preprocess/postprocess resp. --- src/diffusers/video_processor.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index 34427686394d..0c51b4b38f23 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -25,9 +25,9 @@ class VideoProcessor(VaeImageProcessor): r"""Simple video processor.""" - def preprocess_video(self, video, height: int | None = None, width: int | None = None) -> torch.Tensor: + def preprocess_video(self, video, height: int | None = None, width: int | None = None, **kwargs) -> torch.Tensor: r""" - Preprocesses input video(s). + Preprocesses input video(s). Keyword arguments will be forwarded to `VaeImageProcessor.preprocess`. Args: video (`list[PIL.Image]`, `list[list[PIL.Image]]`, `torch.Tensor`, `np.array`, `list[torch.Tensor]`, `list[np.array]`): @@ -49,6 +49,10 @@ def preprocess_video(self, video, height: int | None = None, width: int | None = width (`int`, *optional*`, defaults to `None`): The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get the default width. + + Returns: + `torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`: + A 5D tensor holding the batched channels-first video(s). """ if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: warnings.warn( @@ -79,7 +83,7 @@ def preprocess_video(self, video, height: int | None = None, width: int | None = "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" ) - video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + video = torch.stack([self.preprocess(img, height=height, width=width, **kwargs) for img in video], dim=0) # move the number of channels before the number of frames. video = video.permute(0, 2, 1, 3, 4) @@ -87,10 +91,11 @@ def preprocess_video(self, video, height: int | None = None, width: int | None = return video def postprocess_video( - self, video: torch.Tensor, output_type: str = "np" + self, video: torch.Tensor, output_type: str = "np", **kwargs ) -> np.ndarray | torch.Tensor | list[PIL.Image.Image]: r""" - Converts a video tensor to a list of frames for export. + Converts a video tensor to a list of frames for export. Keyword arguments will be forwarded to + `VaeImageProcessor.postprocess`. Args: video (`torch.Tensor`): The video as a tensor. @@ -100,7 +105,7 @@ def postprocess_video( outputs = [] for batch_idx in range(batch_size): batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = self.postprocess(batch_vid, output_type) + batch_output = self.postprocess(batch_vid, output_type, **kwargs) outputs.append(batch_output) if output_type == "np": From df2ca6ed225e29b9dea93f78947739aedd2abfda Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 16 Feb 2026 01:33:14 +0100 Subject: [PATCH 16/25] Center crop LTX-2 conditions following original code --- src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index b528c2a69e47..7ac11edc9500 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -779,7 +779,7 @@ def preprocess_conditions( # Treat all other as videos. Note that this means 4D ndarrays and tensors will be treated as videos of # shape (F, H, W, C) and (F, C, H, W), respectively. video_like_cond = condition.frames - condition_pixels = self.video_processor.preprocess_video(video_like_cond, height, width) + condition_pixels = self.video_processor.preprocess_video(video_like_cond, height, width, resize_mode="crop") # Interpret the index as a latent index, following the original LTX-2 code. latent_start_idx = condition.index From 49ef4c5ba2753c98f01e41769489596b4bec756b Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 16 Feb 2026 01:48:27 +0100 Subject: [PATCH 17/25] Duplicate video and audio position ids if using CFG --- src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 7ac11edc9500..3b32a51fb476 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -1327,6 +1327,10 @@ def __call__( audio_coords = self.transformer.audio_rope.prepare_audio_coords( audio_latents.shape[0], audio_num_frames, audio_latents.device ) + # Duplicate the positional ids as well if using CFG + if self.do_classifier_free_guidance: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: From b4e7815306665afd83c59bbddbc1b7fa96ee9f9b Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 16 Feb 2026 01:49:18 +0100 Subject: [PATCH 18/25] make style and make quality --- src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 3b32a51fb476..27d5f4773348 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -779,7 +779,9 @@ def preprocess_conditions( # Treat all other as videos. Note that this means 4D ndarrays and tensors will be treated as videos of # shape (F, H, W, C) and (F, C, H, W), respectively. video_like_cond = condition.frames - condition_pixels = self.video_processor.preprocess_video(video_like_cond, height, width, resize_mode="crop") + condition_pixels = self.video_processor.preprocess_video( + video_like_cond, height, width, resize_mode="crop" + ) # Interpret the index as a latent index, following the original LTX-2 code. latent_start_idx = condition.index From 65597652a76ab2134b79d1771a680556c0926d87 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 16 Feb 2026 02:14:28 +0100 Subject: [PATCH 19/25] Remove unused index_type arg to preprocess_conditions --- src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 27d5f4773348..ae4d40f052a5 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -731,7 +731,6 @@ def preprocess_conditions( width: int = 768, num_frames: int = 121, device: torch.device | None = None, - index_type: str = "latent", ) -> tuple[list[torch.Tensor], list[float], list[int]]: """ Preprocesses the condition images/videos to torch tensors. From 47ebd9222261e6fee3e6db875edc30aff7c41cf6 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 4 Mar 2026 08:49:12 +0100 Subject: [PATCH 20/25] Add # Copied from for _normalize_latents --- src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index ae4d40f052a5..7a4228b2681c 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -621,6 +621,7 @@ def _unpack_latents( return latents @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2ImageToVideoPipeline._normalize_latents def _normalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 ) -> torch.Tensor: From 5a0cf67d8190ea3b712e1373868bc2bdcb443b0b Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 4 Mar 2026 08:54:18 +0100 Subject: [PATCH 21/25] Fix _normalize_latents # Copied from statement --- src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 7a4228b2681c..e4b2becd5b5e 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -621,7 +621,7 @@ def _unpack_latents( return latents @staticmethod - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2ImageToVideoPipeline._normalize_latents + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents def _normalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 ) -> torch.Tensor: From 2d42573ec45acf807a65254cdadc527a421a2b84 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 4 Mar 2026 10:34:07 +0100 Subject: [PATCH 22/25] Add LTX-2 condition pipeline docs --- docs/source/en/api/pipelines/ltx2.md | 179 +++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md index c77efa09f594..85b0f9691891 100644 --- a/docs/source/en/api/pipelines/ltx2.md +++ b/docs/source/en/api/pipelines/ltx2.md @@ -193,6 +193,179 @@ encode_video( ) ``` +## Condition Pipeline Generation + +You can use `LTX2ConditionPipeline` to specify image and/or video conditions at arbitrary latent indices. For example, we can specify both a first-frame and last-frame condition to perform first-last-frame-to-video (FLF2V) generation: + +```py +import torch +from diffusers import LTX2ConditionPipeline, LTX2LatentUpsamplePipeline +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition +from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES +from diffusers.pipelines.ltx2.export_utils import encode_video +from diffusers.utils import load_image + +device = "cuda" +width = 768 +height = 512 +random_seed = 42 +generator = torch.Generator(device).manual_seed(random_seed) +model_path = "rootonchair/LTX-2-19b-distilled" + +pipe = LTX2ConditionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16) +pipe.enable_sequential_cpu_offload(device=device) +pipe.vae.enable_tiling() + +prompt = ( + "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are " + "delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright " + "sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, " + "low-angle perspective." +) + +first_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png", +) +last_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png", +) +first_cond = LTX2VideoCondition(frames=first_image, index=0, strength=1.0) +last_cond = LTX2VideoCondition(frames=last_image, index=-1, strength=1.0) +conditions = [first_cond, last_cond] + +frame_rate = 24.0 +video_latent, audio_latent = pipe( + conditions=conditions, + prompt=prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=8, + sigmas=DISTILLED_SIGMA_VALUES, + guidance_scale=1.0, + generator=generator, + output_type="latent", + return_dict=False, +) + +latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + model_path, + subfolder="latent_upsampler", + torch_dtype=torch.bfloat16, +) +upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) +upsample_pipe.enable_model_cpu_offload(device=device) +upscaled_video_latent = upsample_pipe( + latents=video_latent, + output_type="latent", + return_dict=False, +)[0] + +video, audio = pipe( + latents=upscaled_video_latent, + audio_latents=audio_latent, + prompt=prompt, + width=width * 2, + height=height * 2, + num_inference_steps=3, + sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, + generator=generator, + guidance_scale=1.0, + output_type="np", + return_dict=False, +) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_distilled_flf2v.mp4", +) +``` + +You can use both image and video conditions: + +```py +import torch +from diffusers import LTX2ConditionPipeline +from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition +from diffusers.pipelines.ltx2.export_utils import encode_video +from diffusers.utils import load_image, load_video + +device = "cuda" +width = 768 +height = 512 +random_seed = 42 +generator = torch.Generator(device).manual_seed(random_seed) +model_path = "rootonchair/LTX-2-19b-distilled" + +pipe = LTX2ConditionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16) +pipe.enable_sequential_cpu_offload(device=device) +pipe.vae.enable_tiling() + +prompt = ( + "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is " + "divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features " + "dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered " + "clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, " + "with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The " + "landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the " + "solitude and beauty of a winter drive through a mountainous region." +) +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) + +cond_video = load_video( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4" +) +cond_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg" +) +video_cond = LTX2VideoCondition(frames=cond_video, index=0, strength=1.0) +image_cond = LTX2VideoCondition(frames=cond_image, index=8, strength=1.0) +conditions = [video_cond, image_cond] + +frame_rate = 24.0 +video, audio = pipe( + conditions=conditions, + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=40, + guidance_scale=4.0, + generator=generator, + output_type="np", + return_dict=False, +) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_cond_video.mp4", +) +``` + +Because the conditioning is done via latent frames, the 8 data space frames corresponding to the specified latent frame for an image condition will tend to be static. + ## LTX2Pipeline [[autodoc]] LTX2Pipeline @@ -205,6 +378,12 @@ encode_video( - all - __call__ +## LTX2ConditionPipeline + +[[autodoc]] LTX2ConditionPipeline + - all + - __call__ + ## LTX2LatentUpsamplePipeline [[autodoc]] LTX2LatentUpsamplePipeline From 1b8444099fd0bee3c657cddd4f51263584a214c3 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 5 Mar 2026 02:03:54 +0100 Subject: [PATCH 23/25] Remove TODOs --- src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index e4b2becd5b5e..1d3c6bdbc696 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -275,7 +275,6 @@ def __init__( self.vae_temporal_compression_ratio = ( self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 ) - # TODO: check whether the MEL compression ratio logic here is corrct self.audio_vae_mel_compression_ratio = ( self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 ) @@ -906,9 +905,6 @@ def prepare_latents( condition_frames, condition_strengths, condition_indices = self.preprocess_conditions( conditions, height, width, num_frames, device=device ) - # TODO: should we first concatenate all of the condition tensors and encode them all together? The advantage - # is that this would generally respect VAE settings like tiling, but a disadvantage is that this would by - # default take a lot of memory (for LTX 2, tiled encoding/decoding is almost always necessary). condition_latents = [] for condition_tensor in condition_frames: condition_latent = retrieve_latents( From 4b6168a3a34fec42408c1d80493ae0e367713333 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 5 Mar 2026 08:30:16 +0100 Subject: [PATCH 24/25] Support only unpacked latents (5D for video, 4D for audio) --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 79 ++++++++----------- 1 file changed, 34 insertions(+), 45 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 1d3c6bdbc696..4c591c433fc0 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -539,6 +539,8 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + latents=None, + audio_latents=None, ): if height % 32 != 0 or width % 32 != 0: raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") @@ -582,6 +584,19 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." ) + if latents is not None and latents.ndim != 5: + raise ValueError( + f"Only unpacked (5D) video latents of shape `[batch_size, latent_channels, latent_frames," + f" latent_height, latent_width] are supported, but got {latents.ndim} dims. If you have packed (3D)" + f" latents, please unpack them (e.g. using the `_unpack_latents` method)." + ) + if audio_latents is not None and audio_latents.ndim != 4: + raise ValueError( + f"Only unpacked (4D) audio latents of shape `[batch_size, num_channels, audio_length, mel_bins] are" + f" supported, but got {latents.ndim} dims. If you have packed (3D) latents, please unpack them (e.g." + f" using the `_unpack_audio_latents` method)." + ) + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: @@ -871,21 +886,19 @@ def prepare_latents( mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width) if latents is not None: - # Latents is either shape [B, F, C, H, W] or [B, seq_len, hidden_dim] - if latents.ndim == 5: - latents = self._normalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) + # Latents are expected to be unpacked (5D) with shape [B, F, C, H, W] + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) else: # NOTE: we set the initial latents to zeros rather a sample from the standard Gaussian prior because we # will sample from the prior later once we have calculated the conditioning mask latents = torch.zeros(shape, device=device, dtype=dtype) conditioning_mask = latents.new_zeros(mask_shape) - if latents.ndim == 5: - latents = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) conditioning_mask = self._pack_latents( conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) # [B, seq_len, 1] @@ -952,18 +965,12 @@ def prepare_audio_latents( latents: torch.Tensor | None = None, ) -> torch.Tensor: if latents is not None: - if latents.ndim == 4: - # latents are of shape [B, C, L, M], need to be packed - latents = self._pack_audio_latents(latents) - if latents.ndim != 3: - raise ValueError( - f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." - ) + # latents expected to be unpacked (4D) with shape [B, C, L, M] + latents = self._pack_audio_latents(latents) latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) - # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) @@ -1153,6 +1160,8 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, + latents=latents, + audio_latents=audio_latents, ) self._guidance_scale = guidance_scale @@ -1210,20 +1219,10 @@ def __call__( latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio if latents is not None: - if latents.ndim == 5: - logger.info( - "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." - ) - _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] - elif latents.ndim == 3: - logger.warning( - f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" - f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." - ) - else: - raise ValueError( - f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." - ) + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels @@ -1249,20 +1248,10 @@ def __call__( ) audio_num_frames = round(duration_s * audio_latents_per_second) if audio_latents is not None: - if audio_latents.ndim == 4: - logger.info( - "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." - ) - _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] - elif audio_latents.ndim == 3: - logger.warning( - f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" - f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." - ) - else: - raise ValueError( - f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." - ) + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_num_frames, mel_bins], `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio From 56e505770287967e9e099dc5ed458f0fd3cda1ae Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 5 Mar 2026 08:35:30 +0100 Subject: [PATCH 25/25] Remove # Copied from for prepare_audio_latents --- src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 4c591c433fc0..4c451330f439 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -951,7 +951,6 @@ def prepare_latents( return latents, conditioning_mask, clean_latents - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents def prepare_audio_latents( self, batch_size: int = 1,