cogview3 model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Duplicate search: checked GitHub Issues/PRs for cogview3, affected class/output names, prompt-embed batch failures, size validation, attention backend behavior, pooled_projection_dim, stale checkpoint IDs, and fp16 black-image failures. Related but not full duplicates: PR #10211 fixed only the pipeline example checkpoint ID; issue #10343 covers CogView3 fp16 black images.
Local execution note: targeted reproductions were run with .venv. Full fast test collection did not reach target code because this environment's Torch build lacks torch._C._distributed_c10d.
Issue 1: Lazy export points to a non-existent output class
Affected code:
From diffusers/src/diffusers/pipelines/cogview3/__init__.py (line 15 in 0f1abc4):

_import_structure = {"pipeline_output": ["CogView3PlusPipelineOutput"]}

From diffusers/src/diffusers/pipelines/cogview3/pipeline_output.py (lines 9 to 20 in 0f1abc4):

@dataclass
class CogView3PipelineOutput(BaseOutput):
    """
    Output class for CogView3 pipelines.

    Args:
        images (`list[PIL.Image.Image]` or `np.ndarray`)
            list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: list[PIL.Image.Image] | np.ndarray
Problem:
cogview3.__init__ exports CogView3PlusPipelineOutput, but pipeline_output.py defines CogView3PipelineOutput. Importing the advertised lazy symbol fails.
Impact:
Public lazy imports are broken for the output type, and docs/export tooling can drift from the actual class.
Reproduction:
try:
from diffusers.pipelines.cogview3 import CogView3PlusPipelineOutput
except Exception as e:
print(type(e).__name__, e)
from diffusers.pipelines.cogview3.pipeline_output import CogView3PipelineOutput
print(CogView3PipelineOutput.__name__)
Relevant precedent:
Docs already autodoc the real class name at:
From diffusers/docs/source/en/api/pipelines/cogview3.md (lines 35 to 37 in 0f1abc4):

## CogView3PipelineOutput

[[autodoc]] pipelines.cogview3.pipeline_output.CogView3PipelineOutput
Suggested fix:
_import_structure = {"pipeline_output": ["CogView3PipelineOutput"]}
# in TYPE_CHECKING branch
from .pipeline_output import CogView3PipelineOutput
Issue 2: Precomputed prompt embeddings are not expanded or dtype/device-normalized
Affected code:
From diffusers/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py (lines 258 to 292 in 0f1abc4):

if prompt_embeds is None:
    prompt_embeds = self._get_t5_prompt_embeds(
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )

if do_classifier_free_guidance and negative_prompt is None:
    negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)

if do_classifier_free_guidance and negative_prompt_embeds is None:
    negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    elif batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds = self._get_t5_prompt_embeds(
        prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )

From diffusers/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py (lines 576 to 631 in 0f1abc4):

# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
    batch_size * num_images_per_prompt,
    latent_channels,
    height,
    width,
    prompt_embeds.dtype,
    device,
    generator,
    latents,
)

# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7. Prepare additional timestep conditions
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)

if self.do_classifier_free_guidance:
    original_size = torch.cat([original_size, original_size])
    target_size = torch.cat([target_size, target_size])
    crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])

original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)

# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

with self.progress_bar(total=num_inference_steps) as progress_bar:
    # for DPM-solver++
    old_pred_original_sample = None
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue

        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latent_model_input.shape[0])

        # predict noise model_output
        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            original_size=original_size,
            target_size=target_size,
            crop_coords=crops_coords_top_left,
            return_dict=False,
        )[0]
Problem:
When prompt_embeds or negative_prompt_embeds are passed directly, CogView3 does not repeat them for num_images_per_prompt and does not cast/move them to the execution dtype/device. The latent batch is expanded, but the prompt batch is not.
Impact:
Valid diffusers usage with precomputed embeddings fails with batch mismatch, device mismatch, or dtype mismatch.
Reproduction:
import torch
from diffusers import CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
def make_pipe(dtype=torch.float32):
transformer = CogView3PlusTransformer2DModel(
patch_size=2, in_channels=4, num_layers=1, attention_head_dim=4,
num_attention_heads=2, out_channels=4, text_embed_dim=8,
time_embed_dim=8, condition_dim=2, pos_embed_max_size=8, sample_size=2,
).to(dtype=dtype)
pipe = CogView3PlusPipeline(None, None, None, transformer, CogVideoXDDIMScheduler())
pipe.set_progress_bar_config(disable=True)
return pipe
try:
make_pipe()(prompt_embeds=torch.randn(1, 8, 8), num_images_per_prompt=2,
num_inference_steps=1, guidance_scale=1.0, height=16, width=16, output_type="latent")
except Exception as e:
print(type(e).__name__, e)
try:
make_pipe(torch.float16)(prompt_embeds=torch.randn(1, 8, 8, dtype=torch.float32),
num_inference_steps=1, guidance_scale=1.0,
height=16, width=16, output_type="latent")
except Exception as e:
print(type(e).__name__, e)
Relevant precedent:
CogView4 expands precomputed embeds:
From diffusers/src/diffusers/pipelines/cogview4/pipeline_cogview4.py (lines 269 to 296 in 0f1abc4):

if prompt_embeds is None:
    prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)

seq_len = prompt_embeds.size(1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

if do_classifier_free_guidance and negative_prompt_embeds is None:
    negative_prompt = negative_prompt or ""
    negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    elif batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)

    seq_len = negative_prompt_embeds.size(1)
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
Suggested fix:
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(...)
else:
dtype = dtype or self.transformer.dtype
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)
elif do_classifier_free_guidance and negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=prompt_embeds.dtype)
bs_embed, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
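Hypothetical post-fix behavior, under the same tiny-config assumptions as the reproduction (make_pipe as defined there): the first failing call should now run to completion and broadcast the batch-1 embeds.

# Sketch: with batch-1 precomputed embeds and num_images_per_prompt=2, the
# pipeline should return two latent images instead of raising.
out = make_pipe()(prompt_embeds=torch.randn(1, 8, 8), num_images_per_prompt=2,
                  num_inference_steps=1, guidance_scale=1.0, height=16, width=16,
                  output_type="latent")
print(out.images.shape)  # expected (2, 4, 2, 2): 16 / vae_scale_factor(8) = 2 latent pixels per side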
Issue 3: Height/width validation ignores transformer patch size
Affected code:
From diffusers/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py (lines 336 to 348 in 0f1abc4):

# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt,
    callback_on_step_end_tensor_inputs,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
Problem:
The pipeline only requires dimensions divisible by 8, but CogView3 patchifies VAE latents with patch_size=2. A 24x24 image passes check_inputs, then fails inside the transformer because latent size 3x3 is not divisible by patch size.
Impact:
Users get a late internal transformer error instead of a clear pipeline validation error.
Reproduction:
import torch
from diffusers import CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
transformer = CogView3PlusTransformer2DModel(
patch_size=2, in_channels=4, num_layers=1, attention_head_dim=4,
num_attention_heads=2, out_channels=4, text_embed_dim=8,
time_embed_dim=8, condition_dim=2, pos_embed_max_size=8, sample_size=2,
)
pipe = CogView3PlusPipeline(None, None, None, transformer, CogVideoXDDIMScheduler())
pipe.set_progress_bar_config(disable=True)
try:
pipe(prompt_embeds=torch.randn(1, 8, 8), num_inference_steps=1,
guidance_scale=1.0, height=24, width=24, output_type="latent")
except Exception as e:
print(type(e).__name__, e)
Relevant precedent:
CogView4 validates 16, matching vae_scale_factor * patch_size:
From diffusers/src/diffusers/pipelines/cogview4/pipeline_cogview4.py (lines 328 to 329 in 0f1abc4):

if height % 16 != 0 or width % 16 != 0:
    raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
Suggested fix:
divisibility = self.vae_scale_factor * self.transformer.config.patch_size
if height % divisibility != 0 or width % divisibility != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {divisibility} but are {height} and {width}."
)
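The arithmetic, for concreteness (assuming vae_scale_factor=8 and patch_size=2, so the required divisibility is 16):

height = 24
latent_size = height // 8     # 3 latent pixels per side after VAE downscaling
print(latent_size % 2)        # 1 -> patchify fails deep inside the transformer today
divisibility = 8 * 2          # what check_inputs should enforce
print(height % divisibility)  # 8 -> rejected up front with a clear error after the fix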
Issue 4: Attention backend selection is effectively ignored
Affected code:
From diffusers/src/diffusers/models/transformers/transformer_cogview3plus.py (lines 21 to 22 in 0f1abc4):

from ..attention import AttentionMixin, FeedForward
from ..attention_processor import Attention, CogVideoXAttnProcessor2_0

From diffusers/src/diffusers/models/transformers/transformer_cogview3plus.py (lines 58 to 68 in 0f1abc4):

self.attn1 = Attention(
    query_dim=dim,
    heads=num_attention_heads,
    dim_head=attention_head_dim,
    out_dim=dim,
    bias=True,
    qk_norm="layer_norm",
    elementwise_affine=False,
    eps=1e-6,
    processor=CogVideoXAttnProcessor2_0(),
)

From diffusers/src/diffusers/models/attention_processor.py (lines 2277 to 2331 in 0f1abc4):

class CogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
Problem:
CogView3Plus reuses CogVideoXAttnProcessor2_0, which calls F.scaled_dot_product_attention directly and has no _attention_backend. model.set_attention_backend(...) skips it, so backend dispatch support is a no-op for this transformer.
Impact:
Users cannot reliably select supported attention backends for CogView3Plus, and the implementation violates the review rule requiring model-local processors to route through dispatch_attention_fn.
Reproduction:
from diffusers import CogView3PlusTransformer2DModel
model = CogView3PlusTransformer2DModel(
patch_size=2, in_channels=4, num_layers=1, attention_head_dim=4,
num_attention_heads=2, out_channels=4, text_embed_dim=8,
time_embed_dim=8, condition_dim=2, pos_embed_max_size=8, sample_size=2,
)
name, processor = next(iter(model.attn_processors.items()))
print(name, processor.__class__.__name__, hasattr(processor, "_attention_backend"))
model.set_attention_backend("native")
print(hasattr(processor, "_attention_backend"), getattr(processor, "_attention_backend", None))
Relevant precedent:
QwenImage routes attention through dispatch_attention_fn:
From diffusers/src/diffusers/models/transformers/transformer_qwenimage.py (lines 562 to 569 in 0f1abc4):

joint_hidden_states = dispatch_attention_fn(
    joint_query,
    joint_key,
    joint_value,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=False,
    backend=self._attention_backend,
Suggested fix:
Implement a CogView3Plus-local attention processor that carries _attention_backend and _parallel_config and calls dispatch_attention_fn(...) instead of F.scaled_dot_product_attention(...); a minimal sketch follows. For the full rule-compliant refactor, also define a local attention module inheriting AttentionModuleMixin.
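A minimal sketch of such a processor, assuming the (batch, seq_len, heads, head_dim) layout and the backend/parallel_config keywords that dispatch_attention_fn takes elsewhere in the codebase. The class name CogView3PlusAttnProcessor is hypothetical, and the attention-mask preparation and RoPE branches of the original processor are omitted for brevity:

import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn

class CogView3PlusAttnProcessor:
    # Dispatch state that set_attention_backend(...) and parallelism hooks update.
    _attention_backend = None
    _parallel_config = None

    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask=None, **kwargs):
        text_seq_length = encoder_hidden_states.size(1)
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # (batch, seq, heads * head_dim) -> (batch, seq, heads, head_dim)
        query = attn.to_q(hidden_states).unflatten(2, (attn.heads, -1))
        key = attn.to_k(hidden_states).unflatten(2, (attn.heads, -1))
        value = attn.to_v(hidden_states).unflatten(2, (attn.heads, -1))

        # qk_norm="layer_norm" normalizes over head_dim (the last axis),
        # so it applies identically in this layout.
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Route through the dispatcher so set_attention_backend(...) takes effect.
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        hidden_states = hidden_states.flatten(2, 3)

        # Output projection + dropout, then split text/image tokens as before.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states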
Issue 5: Released transformer config contains an ignored key on load
Affected code:
From diffusers/src/diffusers/models/transformers/transformer_cogview3plus.py (lines 168 to 199 in 0f1abc4):

    patch_size: int = 2,
    in_channels: int = 16,
    num_layers: int = 30,
    attention_head_dim: int = 40,
    num_attention_heads: int = 64,
    out_channels: int = 16,
    text_embed_dim: int = 4096,
    time_embed_dim: int = 512,
    condition_dim: int = 256,
    pos_embed_max_size: int = 128,
    sample_size: int = 128,
):
    super().__init__()
    self.out_channels = out_channels
    self.inner_dim = num_attention_heads * attention_head_dim

    # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
    # Each of these are sincos embeddings of shape 2 * condition_dim
    self.pooled_projection_dim = 3 * 2 * condition_dim

    self.patch_embed = CogView3PlusPatchEmbed(
        in_channels=in_channels,
        hidden_size=self.inner_dim,
        patch_size=patch_size,
        text_hidden_size=text_embed_dim,
        pos_embed_max_size=pos_embed_max_size,
    )

    self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
        embedding_dim=time_embed_dim,
        condition_dim=condition_dim,
        pooled_projection_dim=self.pooled_projection_dim,
Problem:
The published THUDM/CogView3-Plus-3B transformer config includes pooled_projection_dim, but the model constructor does not accept it. Loading emits an unexpected-config warning even though the value matches the derived value.
Impact:
Users see a spurious load warning for the official checkpoint, and config compatibility remains noisy.
Reproduction:
from diffusers import CogView3PlusTransformer2DModel
config = {
"patch_size": 2, "in_channels": 4, "num_layers": 1,
"attention_head_dim": 4, "num_attention_heads": 2,
"out_channels": 4, "text_embed_dim": 8, "time_embed_dim": 8,
"condition_dim": 2, "pos_embed_max_size": 8, "sample_size": 2,
"pooled_projection_dim": 12,
}
model = CogView3PlusTransformer2DModel.from_config(config)
print(model.pooled_projection_dim)
Relevant precedent:
Other transformer configs expose pooled_projection_dim directly where it is serialized:
From diffusers/src/diffusers/models/transformers/hunyuan_transformer_2d.py (lines 264 to 294 in 0f1abc4):

    pooled_projection_dim: int = 1024,
    text_len: int = 77,
    text_len_t5: int = 256,
    use_style_cond_and_image_meta_size: bool = True,
):
    super().__init__()
    self.out_channels = in_channels * 2 if learn_sigma else in_channels
    self.num_heads = num_attention_heads
    self.inner_dim = num_attention_heads * attention_head_dim

    self.text_embedder = PixArtAlphaTextProjection(
        in_features=cross_attention_dim_t5,
        hidden_size=cross_attention_dim_t5 * 4,
        out_features=cross_attention_dim,
        act_fn="silu_fp32",
    )

    self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim))

    self.pos_embed = PatchEmbed(
        height=sample_size,
        width=sample_size,
        in_channels=in_channels,
        embed_dim=hidden_size,
        patch_size=patch_size,
        pos_embed_type=None,
    )

    self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
        hidden_size,
        pooled_projection_dim=pooled_projection_dim,
Suggested fix:
def __init__(..., pooled_projection_dim: int | None = None, ...):
...
self.pooled_projection_dim = pooled_projection_dim or 3 * 2 * condition_dim
self.register_to_config(pooled_projection_dim=self.pooled_projection_dim)
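Hypothetical post-fix behavior, using the tiny config from the reproduction above: the key is accepted by the constructor, so loading no longer reports it as unexpected and the value round-trips through the config.

model = CogView3PlusTransformer2DModel.from_config(config)  # config as defined in the reproduction
print(model.config.pooled_projection_dim)  # 12, now a first-class config key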
Issue 6: CogView3 tests/docs do not provide meaningful coverage
Affected code:
From diffusers/tests/pipelines/cogview3/test_cogview3plus.py (lines 135 to 138 in 0f1abc4):

self.assertEqual(generated_image.shape, (3, 16, 16))
expected_image = torch.randn(3, 16, 16)
max_diff = np.abs(generated_image - expected_image).max()
self.assertLessEqual(max_diff, 1e10)

From diffusers/tests/pipelines/cogview3/test_cogview3plus.py (lines 256 to 276 in 0f1abc4):

def test_cogview3plus(self):
    generator = torch.Generator("cpu").manual_seed(0)

    pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b", torch_dtype=torch.float16)
    pipe.enable_model_cpu_offload(device=torch_device)
    prompt = self.prompt

    images = pipe(
        prompt=prompt,
        height=1024,
        width=1024,
        generator=generator,
        num_inference_steps=2,
        output_type="np",
    )[0]

    image = images[0]
    expected_image = torch.randn(1, 1024, 1024, 3).numpy()

    max_diff = numpy_cosine_similarity_distance(image, expected_image)
    assert max_diff < 1e-3, f"Max diff is too high. got {image}"

From diffusers/docs/source/en/api/models/cogview3plus_transformer2d.md (lines 19 to 21 in 0f1abc4):

from diffusers import CogView3PlusTransformer2DModel

transformer = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
Problem:
Fast tests compare against torch.randn(...) with 1e10 tolerance, so they only check shape. The slow test uses non-existent model ID THUDM/CogView3Plus-3b, loads unsupported torch.float16, and compares images[0] to a random expected array with an incompatible leading batch dimension. The model docs also use the stale checkpoint ID.
Impact:
Fast and slow tests exist, but they do not catch output regressions. The slow test cannot validate the official checkpoint as written.
Reproduction:
from huggingface_hub import model_info
import numpy as np
from numpy.linalg import norm
for model_id in ["THUDM/CogView3Plus-3b", "THUDM/CogView3-Plus-3B"]:
try:
print(model_id, model_info(model_id).id)
except Exception as e:
print(model_id, type(e).__name__)
def numpy_cosine_similarity_distance(a, b):
similarity = np.dot(a, b) / (norm(a) * norm(b))
return 1.0 - similarity.mean()
try:
image = np.zeros((4, 4, 3), dtype=np.float32)
expected_image = np.random.randn(1, 4, 4, 3).astype(np.float32)
print(numpy_cosine_similarity_distance(image, expected_image))
except Exception as e:
print(type(e).__name__, e)
Relevant precedent:
PR fixing the pipeline example checkpoint ID:
#10211
Existing fp16 black-image duplicate:
#10343
Suggested fix:
Update docs/tests to THUDM/CogView3-Plus-3B, use torch.bfloat16 for slow inference, and replace random expectations with deterministic expected slices or stored expected statistics generated from the official checkpoint. Fast test_inference should assert a small image slice with a real tolerance rather than 1e10; a sketch follows.
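A minimal sketch of the slice-based assertion, assuming generated_image is the (3, 16, 16) tensor from the existing fast test. The expected values are placeholders that must be regenerated once from a seeded run of the tiny pipeline:

import numpy as np
import torch

def assert_image_slice(image: torch.Tensor, expected_slice: np.ndarray, atol: float = 1e-3) -> None:
    # Compare the bottom-right 3x3 corner of the first channel against stored values.
    image_slice = image[0, -3:, -3:].flatten().float().cpu().numpy()
    max_diff = np.abs(image_slice - expected_slice).max()
    assert max_diff <= atol, f"Max diff {max_diff} exceeds tolerance {atol}"

# Placeholder expectations: fill in from one deterministic seeded run.
expected_slice = np.array([0.0] * 9, dtype=np.float32)
# In test_inference: assert_image_slice(generated_image, expected_slice)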