
consistency_models model/pipeline review #13643

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Issue 1: Random class labels ignore the supplied generator

Affected code:

def prepare_class_labels(self, batch_size, device, class_labels=None):
    if self.unet.config.num_class_embeds is not None:
        if isinstance(class_labels, list):
            class_labels = torch.tensor(class_labels, dtype=torch.int)
        elif isinstance(class_labels, int):
            assert batch_size == 1, "Batch size must be 1 if classes is an int"
            class_labels = torch.tensor([class_labels], dtype=torch.int)
        elif class_labels is None:
            # Randomly generate batch_size class labels
            # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils
            class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,))
        class_labels = class_labels.to(device)

Problem:
For class-conditional UNets, omitting class_labels makes the pipeline sample random labels with torch.randint(...), but prepare_class_labels does not receive or use the user-supplied generator. Seeded inference is therefore not fully controlled by generator.

Impact:
Two calls with identical explicit generators can produce different images when class_labels=None, which breaks the usual diffusers reproducibility contract.
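The underlying fix is straightforward because `torch.randint` already accepts a `generator` argument; a minimal standalone demonstration (plain torch, no pipeline involved):

```python
import torch

# torch.randint honors an explicit generator, so seeded class-label
# sampling needs no new helper utility.
g1 = torch.Generator(device="cpu").manual_seed(123)
g2 = torch.Generator(device="cpu").manual_seed(123)

labels_a = torch.randint(0, 10, size=(4,), generator=g1)
labels_b = torch.randint(0, 10, size=(4,), generator=g2)

assert torch.equal(labels_a, labels_b)  # identical seeds give identical labels
```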

Reproduction:

import torch
from diffusers import CMStochasticIterativeScheduler, ConsistencyModelPipeline, UNet2DModel

unet = UNet2DModel(
    sample_size=8,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(8,),
    down_block_types=("DownBlock2D",),
    up_block_types=("UpBlock2D",),
    norm_num_groups=1,
    num_class_embeds=10,
)
pipe = ConsistencyModelPipeline(
    unet=unet,
    scheduler=CMStochasticIterativeScheduler(num_train_timesteps=40, sigma_min=0.002, sigma_max=80.0),
).to("cpu")
pipe.set_progress_bar_config(disable=True)

latents = torch.zeros((1, 3, 8, 8))
torch.manual_seed(0)

out_a = pipe(
    latents=latents,
    generator=torch.Generator(device="cpu").manual_seed(123),
    class_labels=None,
    num_inference_steps=1,
    output_type="pt",
).images
out_b = pipe(
    latents=latents,
    generator=torch.Generator(device="cpu").manual_seed(123),
    class_labels=None,
    num_inference_steps=1,
    output_type="pt",
).images

print(torch.equal(out_a, out_b), (out_a - out_b).abs().max().item())

Relevant precedent:
randn_tensor handles explicit generators, including CPU generators that later move tensors to the execution device:

def randn_tensor(
    shape: tuple | list,
    generator: list["torch.Generator"] | "torch.Generator" | None = None,
    device: str | "torch.device" | None = None,
    dtype: "torch.dtype" | None = None,
    layout: "torch.layout" | None = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
    is always created on the CPU.
    """
    # device on which tensor is created defaults to device
    if isinstance(device, str):
        device = torch.device(device)
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        shape = (1,) + shape[1:]
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents

Suggested fix:

def prepare_class_labels(self, batch_size, device, generator=None, class_labels=None):
    if self.unet.config.num_class_embeds is None:
        return None

    if isinstance(class_labels, list):
        class_labels = torch.tensor(class_labels, dtype=torch.long)
    elif isinstance(class_labels, int):
        class_labels = torch.tensor([class_labels] * batch_size, dtype=torch.long)
    elif class_labels is None:
        if isinstance(generator, list):
            class_labels = torch.cat(
                [
                    torch.randint(
                        0,
                        self.unet.config.num_class_embeds,
                        size=(1,),
                        generator=g,
                        device=g.device,
                    ).cpu()
                    for g in generator
                ]
            )
        else:
            rand_device = generator.device if generator is not None else torch.device("cpu")
            class_labels = torch.randint(
                0,
                self.unet.config.num_class_embeds,
                size=(batch_size,),
                generator=generator,
                device=rand_device,
            )

    return class_labels.to(device=device, dtype=torch.long)

Also pass generator=generator from __call__.
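As a sanity check on the suggested fix, here is a standalone sketch of its `class_labels=None` branch, with a hypothetical `NUM_CLASS_EMBEDS` constant standing in for `self.unet.config.num_class_embeds`. Identical generators now yield identical labels, including the per-sample list-of-generators case:

```python
import torch

NUM_CLASS_EMBEDS = 10  # hypothetical stand-in for self.unet.config.num_class_embeds


def sample_class_labels(batch_size, generator=None):
    # Mirrors the class_labels=None branch of the suggested prepare_class_labels fix.
    if isinstance(generator, list):
        return torch.cat(
            [
                torch.randint(0, NUM_CLASS_EMBEDS, size=(1,), generator=g, device=g.device).cpu()
                for g in generator
            ]
        )
    rand_device = generator.device if generator is not None else torch.device("cpu")
    return torch.randint(0, NUM_CLASS_EMBEDS, size=(batch_size,), generator=generator, device=rand_device)


# Per-sample generators with the same seeds produce the same labels.
gens_a = [torch.Generator(device="cpu").manual_seed(i) for i in range(4)]
gens_b = [torch.Generator(device="cpu").manual_seed(i) for i in range(4)]
assert torch.equal(sample_class_labels(4, gens_a), sample_class_labels(4, gens_b))
```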

Duplicate check:
No matching existing issue or PR found for ConsistencyModelPipeline class_labels generator or consistency_models class_labels random.

Issue 2: Latent shape handling hardcodes RGB square images

Affected code:

if latents is not None:
    expected_shape = (batch_size, 3, img_size, img_size)
    if latents.shape != expected_shape:
        raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")

and, in __call__:

# 0. Prepare call parameters
img_size = self.unet.config.sample_size
device = self._execution_device

# 1. Check inputs
self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps)

# 2. Prepare image latents
# Sample image latents x_0 ~ N(0, sigma_0^2 * I)
sample = self.prepare_latents(
    batch_size=batch_size,
    num_channels=self.unet.config.in_channels,
    height=img_size,
    width=img_size,
    dtype=self.unet.dtype,
    device=device,
    generator=generator,
    latents=latents,
)
Problem:
check_inputs expects latent shape (batch_size, 3, img_size, img_size), even though prepare_latents uses self.unet.config.in_channels. The same img_size value is passed as both height and width, so tuple sample_size configs also fail before inference.

Impact:
Valid UNet2DModel configs with non-3 channel counts or non-square sample_size cannot be used through this pipeline, and valid user-supplied latents are rejected.

Reproduction:

import torch
from diffusers import CMStochasticIterativeScheduler, ConsistencyModelPipeline, UNet2DModel

scheduler = CMStochasticIterativeScheduler(num_train_timesteps=40, sigma_min=0.002, sigma_max=80.0)

one_channel_unet = UNet2DModel(
    sample_size=8,
    in_channels=1,
    out_channels=1,
    layers_per_block=1,
    block_out_channels=(8,),
    down_block_types=("DownBlock2D",),
    up_block_types=("UpBlock2D",),
    norm_num_groups=1,
)
pipe = ConsistencyModelPipeline(unet=one_channel_unet, scheduler=scheduler).to("cpu")
pipe.set_progress_bar_config(disable=True)

try:
    pipe(latents=torch.zeros((1, 1, 8, 8)), num_inference_steps=1, output_type="pt")
except Exception as e:
    print(type(e).__name__, e)

rect_unet = UNet2DModel(
    sample_size=(8, 10),
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(8,),
    down_block_types=("DownBlock2D",),
    up_block_types=("UpBlock2D",),
    norm_num_groups=1,
)
pipe = ConsistencyModelPipeline(unet=rect_unet, scheduler=scheduler).to("cpu")
pipe.set_progress_bar_config(disable=True)

try:
    pipe(num_inference_steps=1, output_type="pt")
except Exception as e:
    print(type(e).__name__, e)

Relevant precedent:
DDPMPipeline derives shape from both unet.config.in_channels and tuple sample_size:

# Sample gaussian noise to begin loop
if isinstance(self.unet.config.sample_size, int):
    image_shape = (
        batch_size,
        self.unet.config.in_channels,
        self.unet.config.sample_size,
        self.unet.config.sample_size,
    )
else:
    image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

Suggested fix:

sample_size = self.unet.config.sample_size
if isinstance(sample_size, int):
    height = width = sample_size
else:
    height, width = sample_size

self.check_inputs(num_inference_steps, timesteps, latents, batch_size, height, width, callback_steps)

sample = self.prepare_latents(
    batch_size=batch_size,
    num_channels=self.unet.config.in_channels,
    height=height,
    width=width,
    dtype=self.unet.dtype,
    device=device,
    generator=generator,
    latents=latents,
)

and update validation to:

expected_shape = (batch_size, self.unet.config.in_channels, height, width)
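A standalone sketch of the corrected validation (the helper name `check_latents_shape` is hypothetical; the real check lives inside check_inputs) accepts non-RGB and non-square latents while still rejecting genuine mismatches:

```python
import torch


def check_latents_shape(latents, batch_size, in_channels, height, width):
    # Hypothetical helper mirroring the corrected validation: the expected
    # shape is derived from the UNet config, not hardcoded to (3, H, H).
    expected_shape = (batch_size, in_channels, height, width)
    if latents.shape != expected_shape:
        raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")


# 1-channel, rectangular latents now pass validation.
check_latents_shape(torch.zeros((1, 1, 8, 10)), batch_size=1, in_channels=1, height=8, width=10)
```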

Duplicate check:
No matching existing issue or PR found for ConsistencyModelPipeline sample_size tuple or pipeline_consistency_models in_channels latents.

Coverage / Search Status

Reviewed: public lazy imports and top-level exports, pipeline config/loading surface, runtime dtype/device/offload path, scheduler interaction, class-conditioning path, docs, examples, deprecation status, and tests under tests/pipelines/consistency_models.

Fast tests exist at tests/pipelines/consistency_models/test_consistency_models.py. Slow tests also exist under ConsistencyModelPipelineSlowTests, so slow coverage is not missing. Current tests do not cover omitted class labels on class-conditional models, non-RGB latent shapes, or tuple sample_size.

I attempted .venv pytest collection for the target test file, but this environment's Torch build is missing torch._C._distributed_c10d, which breaks collection through shared test utilities. The standalone reproductions above were run with .venv.

Duplicate searches were run with gh search issues --include-prs for consistency_models, ConsistencyModelPipeline, pipeline_consistency_models, CMStochasticIterativeScheduler ConsistencyModelPipeline, and the specific class-label / latent-shape failure modes. Existing results included older consistency-model sampling/training items, but none matched the two findings above.
