[pipelines] fix SD3 crash with pre-computed prompt_embeds and num_images_per_prompt #13755

Open
zxuhan wants to merge 1 commit into huggingface:main from zxuhan:fix/sd3-prompt-embeds-num-images-10712

@zxuhan zxuhan commented May 14, 2026

What does this PR do?

Fixes #10712.

StableDiffusion3Pipeline.__call__ raises a RuntimeError when the caller passes pre-computed prompt_embeds together with num_images_per_prompt > 1:

prompt_embeds, neg, pooled, neg_pooled = pipe.encode_prompt(
    prompt="...", do_classifier_free_guidance=True, num_images_per_prompt=2,
)
pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=neg,
     pooled_prompt_embeds=pooled, negative_pooled_prompt_embeds=neg_pooled,
     num_images_per_prompt=2)
# RuntimeError: size of tensor a (8) must match size of tensor b (4)

Root cause

encode_prompt relies on _get_clip_prompt_embeds and _get_t5_prompt_embeds to apply the num_images_per_prompt expansion. Those helpers only run when prompt_embeds is None, so user-supplied embeddings keep their original batch dimension. __call__ then computes latents = prepare_latents(batch_size * num_images_per_prompt, ...), and CFG doubles that to 2 * batch_size * num_images_per_prompt, while encoder_hidden_states stays at 2 * batch_size. The joint-attention block fails when it concatenates the two.
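
The batch arithmetic above can be checked with a small sketch. The helper below is hypothetical and only mirrors the bookkeeping described in this section, not the pipeline's actual code. It assumes that batch_size is taken from the first dimension of the supplied prompt_embeds, which is consistent with the 8 vs. 4 mismatch in the error message.

```python
def sd3_batch_sizes(embeds_batch, num_images_per_prompt, use_cfg=True):
    """Return (latent_batch, text_batch) as they would reach the transformer.

    Hypothetical helper for illustration only. Assumes batch_size is read
    from the first dimension of the user-supplied prompt_embeds.
    """
    batch_size = embeds_batch
    # Latents are prepared for batch_size * num_images_per_prompt samples.
    latent_batch = batch_size * num_images_per_prompt
    # User-supplied embeddings skip the expansion and keep batch_size rows.
    text_batch = batch_size
    if use_cfg:
        # Classifier-free guidance stacks the negative and positive branches,
        # doubling both batches.
        latent_batch *= 2
        text_batch *= 2
    return latent_batch, text_batch


# One prompt encoded with num_images_per_prompt=2 yields embeddings with 2 rows.
print(sd3_batch_sizes(embeds_batch=2, num_images_per_prompt=2))  # (8, 4)
```

With the expansion applied to the supplied embeddings as well, the text batch would be multiplied by num_images_per_prompt too, and the two sizes would agree.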

StableDiffusionXLPipeline is not affected because its encode_prompt applies the repeat(1, num_images_per_prompt, 1).view(...) expansion unconditionally after the if/else branch.
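
For readers unfamiliar with the repeat(...).view(...) pattern mentioned above, here is a minimal sketch of what it does to a tensor of sequence embeddings. The function name expand_for_num_images is hypothetical and the shapes are toy values. It is not the code added by this PR, only an illustration of the expansion on a (batch, seq_len, dim) tensor.

```python
import torch


def expand_for_num_images(embeds: torch.Tensor, num_images_per_prompt: int) -> torch.Tensor:
    """Duplicate each prompt's embeddings num_images_per_prompt times.

    Hypothetical helper; expects sequence embeddings shaped (batch, seq_len, dim).
    """
    batch, seq_len, dim = embeds.shape
    # Repeat along the sequence axis, then fold the copies into the batch axis
    # so that copies of the same prompt end up in adjacent rows.
    embeds = embeds.repeat(1, num_images_per_prompt, 1)
    return embeds.view(batch * num_images_per_prompt, seq_len, dim)


# Toy input: 2 prompts, 3 tokens each, embedding width 4.
x = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
y = expand_for_num_images(x, 2)
print(tuple(y.shape))  # (4, 3, 4)
```

Rows 0 and 1 of the result are copies of the first prompt and rows 2 and 3 are copies of the second, which matches the ordering of latents prepared for batch_size * num_images_per_prompt samples.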

Fix

Apply the same expansion to user-supplied prompt_embeds (and the matching pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds) at the end of encode_prompt. The change is propagated via make fix-copies to the img2img, inpaint, controlnet, and PAG variants:

  • pipeline_stable_diffusion_3_img2img.py
  • pipeline_stable_diffusion_3_inpaint.py
  • controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
  • controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
  • pag/pipeline_pag_sd_3.py
  • pag/pipeline_pag_sd_3_img2img.py

Adds test_pipeline_accepts_prompt_embeds_with_num_images_per_prompt, which feeds the output of encode_prompt(num_images_per_prompt=k) back into the pipeline. Without the fix the test reproduces the RuntimeError from the issue; with the fix it passes.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you read our philosophy doc (important for complex PRs)?
  • Was this discussed/approved via a GitHub issue or the forum? See StableDiffusion3 pipeline RuntimeError when using prompt_embeds #10712.
  • Did you make sure to update the documentation with your changes? The existing docstring for num_images_per_prompt already describes the intended behavior; this PR makes the implementation match it.
  • Did you write any new necessary tests?

Who can review?

@yiyixuxu @sayakpaul

…ith `num_images_per_prompt`

`StableDiffusion3Pipeline.encode_prompt` expands the encoded embeddings inside
`_get_clip_prompt_embeds` / `_get_t5_prompt_embeds`, but skips that path when
the user supplies `prompt_embeds` directly. The pipeline then multiplies
`batch_size` by `num_images_per_prompt` for `prepare_latents`, so the latent
batch and the transformer's `encoder_hidden_states` end up with mismatched
shapes and the call dies inside the joint-attention block (huggingface#10712).

Mirror SDXL's pattern by applying the same expansion to user-supplied
`prompt_embeds` (and the matching `pooled_prompt_embeds` / negatives) at the
end of `encode_prompt`. Propagated to the img2img, inpaint, controlnet, and
PAG variants via `make fix-copies`. Adds a regression test that feeds
`encode_prompt(num_images_per_prompt=k)` output back into the pipeline.