
allegro model/pipeline review #13647

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Issue 1: VAE decode/encode fail unless tiling is enabled

Affected code:

def _encode(self, x: torch.Tensor) -> torch.Tensor:
    # TODO(aryan)
    # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
    if self.use_tiling:
        return self.tiled_encode(x)
    raise NotImplementedError("Encoding without tiling has not been implemented yet.")

@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
    r"""
    Encode a batch of videos into latents.

    Args:
        x (`torch.Tensor`):
            Input batch of videos.
        return_dict (`bool`, defaults to `True`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
        The latent representations of the encoded videos. If `return_dict` is True, a
        [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
    """
    if self.use_slicing and x.shape[0] > 1:
        encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
        h = torch.cat(encoded_slices)
    else:
        h = self._encode(x)
    posterior = DiagonalGaussianDistribution(h)
    if not return_dict:
        return (posterior,)
    return AutoencoderKLOutput(latent_dist=posterior)

def _decode(self, z: torch.Tensor) -> torch.Tensor:
    # TODO(aryan): refactor tiling implementation
    # if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
    if self.use_tiling:
        return self.tiled_decode(z)
    raise NotImplementedError("Decoding without tiling has not been implemented yet.")

if not output_type == "latent":
    latents = latents.to(self.vae.dtype)
    video = self.decode_latents(latents)
    video = video[:, :, :num_frames, :height, :width]
    video = self.video_processor.postprocess_video(video=video, output_type=output_type)

Problem:
AutoencoderKLAllegro.encode() and .decode() raise NotImplementedError unless vae.enable_tiling() has been called. Default pipeline inference therefore also fails whenever output_type != "latent" unless the user has enabled tiling first. This additionally breaks save/load parity, because use_tiling is runtime state rather than serialized config.

Impact:
Basic public VAE APIs and default pipeline inference are fragile. Docs include a quantized AllegroPipeline example that does not enable tiling, so that path can fail at decode time.

Reproduction:

import torch
from diffusers import AutoencoderKLAllegro

vae = AutoencoderKLAllegro(
    block_out_channels=(8, 8, 8, 8),
    latent_channels=4,
    layers_per_block=1,
    norm_num_groups=2,
)
vae.decode(torch.randn(1, 4, 2, 2, 2))
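
Until the non-tiled paths land, the only workaround is to enable tiling first. The sketch below reuses the tiny config above and also demonstrates the save/load parity problem, since use_tiling is not serialized (hedged expectations in the comments):

import tempfile

import torch
from diffusers import AutoencoderKLAllegro

vae = AutoencoderKLAllegro(
    block_out_channels=(8, 8, 8, 8),
    latent_channels=4,
    layers_per_block=1,
    norm_num_groups=2,
)
vae.enable_tiling()                     # workaround: routes decode through tiled_decode
vae.decode(torch.randn(1, 4, 2, 2, 2))  # no longer raises NotImplementedError

# The flag is runtime state, not config, so it does not survive save/load:
with tempfile.TemporaryDirectory() as tmpdir:
    vae.save_pretrained(tmpdir)
    reloaded = AutoencoderKLAllegro.from_pretrained(tmpdir)
print(reloaded.use_tiling)  # False -> decode on the reloaded VAE raises again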

Relevant precedent:
Related mitigation, not a duplicate full fix: #10212

Suggested fix:

def _encode(self, x):
    if self.use_tiling:
        return self.tiled_encode(x)
    # Non-tiled path: the 3D encoder consumes (B, C, T, H, W); the 2D
    # quant_conv runs per frame, so frames are folded into the batch
    # dimension around it.
    batch_size = x.shape[0]
    h = self.encoder(x)
    h = h.permute(0, 2, 1, 3, 4).flatten(0, 1)  # (B, C, T, H, W) -> (B*T, C, H, W)
    h = self.quant_conv(h)
    return h.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)

def _decode(self, z):
    if self.use_tiling:
        return self.tiled_decode(z)
    # Same frame-folding trick for the 2D post_quant_conv before the
    # 3D decoder reconstructs the video.
    batch_size = z.shape[0]
    z = z.permute(0, 2, 1, 3, 4).flatten(0, 1)
    z = self.post_quant_conv(z)
    z = z.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
    return self.decoder(z)

Issue 2: num_videos_per_prompt is silently ignored

Affected code:


) = self.encode_prompt(
    prompt,
    do_classifier_free_guidance,
    negative_prompt=negative_prompt,
    num_videos_per_prompt=num_videos_per_prompt,
    device=device,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    clean_caption=clean_caption,
    max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
if prompt_embeds.ndim == 3:
    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d

# 4. Prepare timesteps
if XLA_AVAILABLE:
    timestep_device = "cpu"
else:
    timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler, num_inference_steps, timestep_device, timesteps
)
self.scheduler.set_timesteps(num_inference_steps, device=device)

# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
    batch_size * num_videos_per_prompt,

Problem:
__call__ accepts and documents num_videos_per_prompt, but line 824 overwrites any user value with 1.

Impact:
Users requesting multiple videos per prompt always receive one video, with no warning or error.

Reproduction:

import torch
from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler

transformer = AllegroTransformer3DModel(
    num_attention_heads=2, attention_head_dim=12, in_channels=4, out_channels=4,
    num_layers=1, cross_attention_dim=24, sample_width=8, sample_height=8,
    sample_frames=8, caption_channels=24,
)
vae = AutoencoderKLAllegro(block_out_channels=(8, 8, 8, 8), latent_channels=4, layers_per_block=1, norm_num_groups=2)
pipe = AllegroPipeline(None, None, vae, transformer, DDIMScheduler())
pipe.set_progress_bar_config(disable=True)

out = pipe(
    prompt_embeds=torch.randn(1, 16, 24),
    prompt_attention_mask=torch.ones(1, 16, dtype=torch.long),
    guidance_scale=1.0,
    num_videos_per_prompt=3,
    num_inference_steps=1,
    height=16, width=16, num_frames=8,
    output_type="latent",
).frames
print(out.shape)  # torch.Size([1, 4, 2, 2, 2]), expected batch 3

Relevant precedent:
HunyuanVideoPipeline preserves num_videos_per_prompt through prompt encoding and latent preparation.

Suggested fix:

# Remove this line from __call__:
num_videos_per_prompt = 1
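
For reference, the expansion other pipelines perform inside encode_prompt looks like the following sketch. Variable names are illustrative; the repeat/view pattern is the standard diffusers idiom, not the Allegro source:

import torch

num_videos_per_prompt = 3
prompt_embeds = torch.randn(1, 16, 24)  # (batch, seq_len, dim), as in the reproduction above

# Tile each prompt's embeddings along the sequence axis, then fold the
# copies into the batch axis so downstream latents line up.
batch_size, seq_len, dim = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, dim)
print(prompt_embeds.shape)  # torch.Size([3, 16, 24])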

Issue 3: Custom timestep scheduler state is overwritten

Affected code:

timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler, num_inference_steps, timestep_device, timesteps
)
self.scheduler.set_timesteps(num_inference_steps, device=device)

Problem:
retrieve_timesteps() correctly applies custom timesteps, but the next line calls self.scheduler.set_timesteps(num_inference_steps, device=device) again, replacing the scheduler's internal state with the default schedule while the denoising loop still iterates over the custom local timesteps.

Impact:
Schedulers whose step() relies on scheduler.timesteps (for example to look up sigmas or step indices) can operate on a mismatched schedule whenever custom timesteps are supplied.

Reproduction:

import torch
from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, EulerDiscreteScheduler

transformer = AllegroTransformer3DModel(
    num_attention_heads=2, attention_head_dim=12, in_channels=4, out_channels=4,
    num_layers=1, cross_attention_dim=24, sample_width=8, sample_height=8,
    sample_frames=8, caption_channels=24,
)
vae = AutoencoderKLAllegro(block_out_channels=(8, 8, 8, 8), latent_channels=4, layers_per_block=1, norm_num_groups=2)
pipe = AllegroPipeline(None, None, vae, transformer, EulerDiscreteScheduler(num_train_timesteps=1000))
pipe.set_progress_bar_config(disable=True)

seen = []
def cb(pipe, i, t, kwargs):
    seen.append((int(t), [int(x) for x in pipe.scheduler.timesteps[:2]]))
    return kwargs

pipe(
    prompt_embeds=torch.randn(1, 16, 24),
    prompt_attention_mask=torch.ones(1, 16, dtype=torch.long),
    guidance_scale=1.0,
    timesteps=[999, 500],
    height=16, width=16, num_frames=8,
    output_type="latent",
    callback_on_step_end=cb,
)
print(seen)  # [(999, [999, 0]), (500, [999, 0])]

Relevant precedent:
QwenImage and Hunyuan pipelines call retrieve_timesteps() once and do not reset the scheduler afterward.

Suggested fix:

timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler, num_inference_steps, timestep_device, timesteps
)
# Delete the second set_timesteps call.
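
A standalone illustration of why the second call is harmful. This assumes the module-level retrieve_timesteps helper in pipeline_allegro and a scheduler that accepts custom timesteps, matching the reproduction above:

import torch
from diffusers import EulerDiscreteScheduler
from diffusers.pipelines.allegro.pipeline_allegro import retrieve_timesteps

scheduler = EulerDiscreteScheduler(num_train_timesteps=1000)

# retrieve_timesteps pushes the custom schedule into the scheduler...
timesteps, num_inference_steps = retrieve_timesteps(scheduler, None, "cpu", [999, 500])
print([int(t) for t in scheduler.timesteps])  # [999, 500]

# ...but a follow-up set_timesteps call replaces it with the default schedule,
# while the local `timesteps` variable still holds [999, 500].
scheduler.set_timesteps(num_inference_steps, device="cpu")
print([int(t) for t in scheduler.timesteps])  # [999, 0], per the reproduction output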

Issue 4: Allegro attention ignores the attention backend dispatcher

Affected code:

class AllegroAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        temb: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Apply RoPE if needed
        if image_rotary_emb is not None and not attn.is_cross_attention:
            from .embeddings import apply_rotary_emb_allegro

            query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
            key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

self.attn1 = Attention(
    query_dim=dim,
    heads=num_attention_heads,
    dim_head=attention_head_dim,
    dropout=dropout,
    bias=attention_bias,
    cross_attention_dim=None,
    processor=AllegroAttnProcessor2_0(),
)

# 2. Cross Attention
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = Attention(
    query_dim=dim,
    cross_attention_dim=cross_attention_dim,
    heads=num_attention_heads,
    dim_head=attention_head_dim,
    dropout=dropout,
    bias=attention_bias,
    processor=AllegroAttnProcessor2_0(),
Problem:
AllegroAttnProcessor2_0 calls F.scaled_dot_product_attention directly and has no _attention_backend attribute. model.set_attention_backend() therefore cannot route Allegro attention through dispatch_attention_fn.

Impact:
Allegro cannot reliably use newer diffusers attention backends, backend-specific validation, or context-parallel-compatible attention paths.

Reproduction:

from diffusers import AllegroTransformer3DModel
from diffusers.models.attention_processor import Attention

model = AllegroTransformer3DModel(
    num_attention_heads=2, attention_head_dim=12, in_channels=4, out_channels=4,
    num_layers=1, cross_attention_dim=24, sample_width=8, sample_height=8,
    sample_frames=8, caption_channels=24,
)
model.set_attention_backend("native")
print([(n, hasattr(m.processor, "_attention_backend")) for n, m in model.named_modules() if isinstance(m, Attention)])
# [('transformer_blocks.0.attn1', False), ('transformer_blocks.0.attn2', False)]

Relevant precedent:
The current review rules point to Flux/Wan/Qwen-style processors that call dispatch_attention_fn.

Suggested fix:
Move Allegro attention to the modern model-local attention pattern: add _attention_backend and _parallel_config attributes on the processor and replace the direct SDPA call with dispatch_attention_fn(..., backend=self._attention_backend, parallel_config=self._parallel_config).
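
A minimal sketch of that pattern, assuming diffusers' dispatch_attention_fn helper and the (batch, seq_len, heads, head_dim) layout used by Flux/Wan-style processors. This is an illustration of the dispatch convention, not a verified Allegro port:

import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn


class AllegroAttnProcessorSketch:
    # set_attention_backend() expects to find and write these attributes
    # on the processor.
    _attention_backend = None
    _parallel_config = None

    def run_attention(self, query, key, value, attention_mask=None):
        # query/key/value: (batch, seq_len, heads, head_dim). The dispatcher
        # routes to SDPA, flash-attn, sage, etc. based on the configured
        # backend instead of hardcoding F.scaled_dot_product_attention.
        return dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )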

Issue 5: AllegroPipelineOutput is not exported from the Allegro package

Affected code:

_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_allegro"] = ["AllegroPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_allegro import AllegroPipeline

@dataclass
class AllegroPipelineOutput(BaseOutput):
    r"""
    Output class for Allegro pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]):
            List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of
            shape `(batch_size, num_frames, channels, height, width)`.
    """

    frames: torch.Tensor | np.ndarray | list[list[PIL.Image.Image]]

Problem:
The output class exists and is documented, but from diffusers.pipelines.allegro import AllegroPipelineOutput fails because pipeline_output is absent from the lazy import structure.

Impact:
Public import behavior is inconsistent with many pipeline packages that expose their output dataclasses.

Reproduction:

from diffusers.pipelines.allegro import AllegroPipelineOutput

Relevant precedent:
src/diffusers/pipelines/qwenimage/__init__.py exports QwenImagePipelineOutput through _import_structure["pipeline_output"].

Suggested fix:

_import_structure = {"pipeline_output": ["AllegroPipelineOutput"]}

# In the TYPE_CHECKING / slow import branch:
from .pipeline_output import AllegroPipelineOutput
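
With both entries in place, the public surface matches the QwenImage precedent and the reproduction import succeeds:

# Expected to resolve after the fix:
from diffusers.pipelines.allegro import AllegroPipeline, AllegroPipelineOutput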

Issue 6: Test coverage does not exercise meaningful Allegro VAE decode behavior

Affected code:

@unittest.skip("Decoding without tiling is not yet implemented")
def test_save_load_local(self):
    pass

@unittest.skip("Decoding without tiling is not yet implemented")
def test_save_load_optional_components(self):
    pass

@unittest.skip("Decoding without tiling is not yet implemented")
def test_pipeline_with_accelerator_device_map(self):
    pass

def test_inference(self):
    device = "cpu"
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)
    inputs = self.get_dummy_inputs(device)
    video = pipe(**inputs).frames
    generated_video = video[0]

    self.assertEqual(generated_video.shape, (8, 3, 16, 16))
    expected_video = torch.randn(8, 3, 16, 16)
    max_diff = np.abs(generated_video - expected_video).max()
    self.assertLessEqual(max_diff, 1e10)

# TODO(aryan)
@unittest.skip("Decoding without tiling is not yet implemented.")
def test_vae_tiling(self, expected_diff_max: float = 0.2):
    generator_device = "cpu"
    components = self.get_dummy_components()

    pipe = self.pipeline_class(**components)
    pipe.to("cpu")
    pipe.set_progress_bar_config(disable=None)

    # Without tiling
    inputs = self.get_dummy_inputs(generator_device)
    inputs["height"] = inputs["width"] = 128
    output_without_tiling = pipe(**inputs)[0]

    # With tiling
    pipe.vae.enable_tiling(
        tile_sample_min_height=96,
        tile_sample_min_width=96,
        tile_overlap_factor_height=1 / 12,
        tile_overlap_factor_width=1 / 12,
    )
    inputs = self.get_dummy_inputs(generator_device)
    inputs["height"] = inputs["width"] = 128
    output_with_tiling = pipe(**inputs)[0]

    self.assertLess(
        (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
        expected_diff_max,
        "VAE tiling should not affect the inference results",
    )

@require_hf_hub_version_greater("0.26.5")
@require_transformers_version_greater("4.47.1")
def test_save_load_dduf(self):
    # reimplement because it needs `enable_tiling()` on the loaded pipe.
    from huggingface_hub import export_folder_as_dduf

    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs(device="cpu")
    inputs.pop("generator")
    inputs["generator"] = torch.manual_seed(0)

    pipeline_out = pipe(**inputs)[0].cpu()

    with tempfile.TemporaryDirectory() as tmpdir:
        dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf")
        pipe.save_pretrained(tmpdir, safe_serialization=True)
        export_folder_as_dduf(dduf_filename, folder_path=tmpdir)
        loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device)

        loaded_pipe.vae.enable_tiling()
        inputs["generator"] = torch.manual_seed(0)
        loaded_pipeline_out = loaded_pipe(**inputs)[0].cpu()

    assert np.allclose(pipeline_out, loaded_pipeline_out)

def test_allegro(self):
    generator = torch.Generator("cpu").manual_seed(0)

    pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16)
    pipe.enable_model_cpu_offload(device=torch_device)
    prompt = self.prompt

    videos = pipe(
        prompt=prompt,
        height=720,
        width=1280,
        num_frames=88,
        generator=generator,
        num_inference_steps=2,
        output_type="pt",
    ).frames

    video = videos[0]
    expected_video = torch.randn(1, 88, 720, 1280, 3).numpy()

    max_diff = numpy_cosine_similarity_distance(video, expected_video)
    assert max_diff < 1e-3, f"Max diff is too high. got {video}"

Problem:
There is no standalone AutoencoderKLAllegro model test. Several important pipeline tests are skipped because non-tiled decoding is unimplemented, and the fast/slow inference tests compare against random tensors rather than fixed expected outputs.

Impact:
The current tests can pass while decode is unimplemented or returns meaningless output, so regressions in the VAE and pipeline output quality are not caught.

Reproduction:

from pathlib import Path

print(Path("tests/models/autoencoders/test_models_autoencoder_kl_allegro.py").exists())
text = Path("tests/pipelines/allegro/test_allegro.py").read_text()
print("Decoding without tiling is not yet implemented" in text)
print("expected_video = torch.randn" in text)

Relevant precedent:
Other video VAEs such as CogVideoX/Wan have model-level VAE coverage and pipeline tests with deterministic expected slices.

Suggested fix:
Add a dedicated AutoencoderKLAllegro fast test using sample_size=16 and tiny channels, covering encode, decode, save/load, slicing, and tiling. Replace the random expected tensors in pipeline tests with deterministic slices from known-good outputs.
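
A hedged skeleton of such a fast test, usable once Issue 1's non-tiled paths exist. The config mirrors the tiny setups in the reproductions above; the expected shapes assume the 8x spatial / 4x temporal compression implied by those reproductions, and the latent_dist/sample accessors follow the standard AutoencoderKL API:

import torch
from diffusers import AutoencoderKLAllegro


def test_encode_decode_roundtrip():
    torch.manual_seed(0)
    vae = AutoencoderKLAllegro(
        block_out_channels=(8, 8, 8, 8),
        latent_channels=4,
        layers_per_block=1,
        norm_num_groups=2,
        sample_size=16,
    )
    # (batch, channels, frames, height, width)
    x = torch.randn(1, 3, 8, 16, 16)
    z = vae.encode(x).latent_dist.mode()
    assert z.shape == (1, 4, 2, 2, 2)  # 8x spatial, 4x temporal compression
    video = vae.decode(z).sample
    assert video.shape == x.shape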

Duplicate-search status: searched GitHub Issues and PRs for Allegro, pipeline_allegro.py, autoencoder_kl_allegro.py, transformer_allegro.py, AutoencoderKLAllegro Decoding without tiling, Allegro num_videos_per_prompt, Allegro timesteps scheduler.set_timesteps, AllegroPipelineOutput, and AllegroAttnProcessor2_0 attention backend. I found related historical PRs, especially #10212, but no open duplicate for the actionable issues above.

Test execution note: targeted .venv snippets were run successfully. A direct pytest collection of tests/pipelines/allegro/test_allegro.py::AllegroPipelineFastTests::test_inference failed in this Windows environment because the installed PyTorch lacks torch._C._distributed_c10d, matching the known class of issue in #12409.
