diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index a787a34bdc01..b678c54cb7d5 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -447,6 +447,13 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        # The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
+        # by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
+        # Track that here so we can apply the same expansion at the end and keep the
+        # batch dimension consistent with `prepare_latents` (see #10712).
+        prompt_embeds_was_provided = prompt_embeds is not None
+        negative_prompt_embeds_was_provided = negative_prompt_embeds is not None
+
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -543,6 +550,28 @@ def encode_prompt(
                 [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
             )
 
+        # Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
+        # what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
+        if prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
+            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
+            pooled_dim = pooled_prompt_embeds.shape[-1]
+            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)
+
+        if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, hidden_dim
+            )
+            pooled_dim = negative_pooled_prompt_embeds.shape[-1]
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
+                batch_size * num_images_per_prompt, pooled_dim
+            )
+
         if self.text_encoder is not None:
             if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index 96f53b16cbe8..d5d08da26ff8 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -469,6 +469,13 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        # The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
+        # by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
+        # Track that here so we can apply the same expansion at the end and keep the
+        # batch dimension consistent with `prepare_latents` (see #10712).
+        prompt_embeds_was_provided = prompt_embeds is not None
+        negative_prompt_embeds_was_provided = negative_prompt_embeds is not None
+
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -565,6 +572,28 @@ def encode_prompt(
                 [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
             )
 
+        # Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
+        # what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
+        if prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
+            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
+            pooled_dim = pooled_prompt_embeds.shape[-1]
+            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)
+
+        if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, hidden_dim
+            )
+            pooled_dim = negative_pooled_prompt_embeds.shape[-1]
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
+                batch_size * num_images_per_prompt, pooled_dim
+            )
+
         if self.text_encoder is not None:
             if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
index f0fbef29b699..a696a6f5583b 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
@@ -417,6 +417,13 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        # The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
+        # by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
+        # Track that here so we can apply the same expansion at the end and keep the
+        # batch dimension consistent with `prepare_latents` (see #10712).
+        prompt_embeds_was_provided = prompt_embeds is not None
+        negative_prompt_embeds_was_provided = negative_prompt_embeds is not None
+
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -513,6 +520,28 @@ def encode_prompt(
                 [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
             )
 
+        # Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
+        # what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
+        if prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
+            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
+            pooled_dim = pooled_prompt_embeds.shape[-1]
+            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)
+
+        if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, hidden_dim
+            )
+            pooled_dim = negative_pooled_prompt_embeds.shape[-1]
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
+                batch_size * num_images_per_prompt, pooled_dim
+            )
+
         if self.text_encoder is not None:
             if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
index 84b727dc0613..a77bfaaa134d 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
@@ -433,6 +433,13 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        # The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
+        # by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
+        # Track that here so we can apply the same expansion at the end and keep the
+        # batch dimension consistent with `prepare_latents` (see #10712).
+        prompt_embeds_was_provided = prompt_embeds is not None
+        negative_prompt_embeds_was_provided = negative_prompt_embeds is not None
+
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -529,6 +536,28 @@ def encode_prompt(
                 [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
             )
 
+        # Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
+        # what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
+        if prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
+            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
+            pooled_dim = pooled_prompt_embeds.shape[-1]
+            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)
+
+        if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, hidden_dim
+            )
+            pooled_dim = negative_pooled_prompt_embeds.shape[-1]
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
+                batch_size * num_images_per_prompt, pooled_dim
+            )
+
         if self.text_encoder is not None:
             if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 7764a79d7faf..1bcf051489a3 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -426,6 +426,13 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        # The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
+        # by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
+        # Track that here so we can apply the same expansion at the end and keep the
+        # batch dimension consistent with `prepare_latents` (see #10712).
+        prompt_embeds_was_provided = prompt_embeds is not None
+        negative_prompt_embeds_was_provided = negative_prompt_embeds is not None
+
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -522,6 +529,28 @@ def encode_prompt(
                 [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
             )
 
+        # Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
+        # what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
+        if prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
+            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
+            pooled_dim = pooled_prompt_embeds.shape[-1]
+            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)
+
+        if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, hidden_dim
+            )
+            pooled_dim = negative_pooled_prompt_embeds.shape[-1]
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
+                batch_size * num_images_per_prompt, pooled_dim
+            )
+
         if self.text_encoder is not None:
             if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index 7951b970cd0c..dc8a8496021f 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -452,6 +452,13 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        # The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
+        # by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
+        # Track that here so we can apply the same expansion at the end and keep the
+        # batch dimension consistent with `prepare_latents` (see #10712).
+        prompt_embeds_was_provided = prompt_embeds is not None
+        negative_prompt_embeds_was_provided = negative_prompt_embeds is not None
+
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -548,6 +555,28 @@ def encode_prompt(
                 [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
             )
 
+        # Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
+        # what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
+        if prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
+            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
+            pooled_dim = pooled_prompt_embeds.shape[-1]
+            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)
+
+        if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, hidden_dim
+            )
+            pooled_dim = negative_pooled_prompt_embeds.shape[-1]
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
+                batch_size * num_images_per_prompt, pooled_dim
+            )
+
         if self.text_encoder is not None:
             if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index d3594b868f89..1062b91f7d89 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -458,6 +458,13 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        # The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
+        # by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
+        # Track that here so we can apply the same expansion at the end and keep the
+        # batch dimension consistent with `prepare_latents` (see #10712).
+        prompt_embeds_was_provided = prompt_embeds is not None
+        negative_prompt_embeds_was_provided = negative_prompt_embeds is not None
+
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -554,6 +561,28 @@ def encode_prompt(
                 [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
             )
 
+        # Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
+        # what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
+        if prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
+            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
+            pooled_dim = pooled_prompt_embeds.shape[-1]
+            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)
+
+        if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
+            seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, hidden_dim
+            )
+            pooled_dim = negative_pooled_prompt_embeds.shape[-1]
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
+                batch_size * num_images_per_prompt, pooled_dim
+            )
+
         if self.text_encoder is not None:
             if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index 200c832d0941..a00e94757b3b 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -223,6 +223,44 @@ def test_skip_guidance_layers(self):
 
         self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
 
+    def test_pipeline_accepts_prompt_embeds_with_num_images_per_prompt(self):
+        # Regression test for https://github.com/huggingface/diffusers/issues/10712: pre-computed
+        # `prompt_embeds` produced by `encode_prompt(num_images_per_prompt=k)` would crash the
+        # pipeline when the matching `num_images_per_prompt=k` was passed to `__call__`, because
+        # only the prompt-encoding path applied the `num_images_per_prompt` expansion.
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        prompt = inputs.pop("prompt")
+        num_images_per_prompt = 2
+
+        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
+            pipe.encode_prompt(
+                prompt=prompt,
+                prompt_2=None,
+                prompt_3=None,
+                device=torch_device,
+                num_images_per_prompt=num_images_per_prompt,
+                do_classifier_free_guidance=True,
+            )
+        )
+
+        images = pipe(
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            num_images_per_prompt=num_images_per_prompt,
+            **inputs,
+        ).images
+
+        # 1 prompt * num_images_per_prompt (the pre-encoded batch dim) * num_images_per_prompt again,
+        # matching the SDXL behaviour of expanding once per `__call__`; the important assertion is
+        # that the pipeline no longer crashes.
+        self.assertEqual(images.shape[0], num_images_per_prompt * num_images_per_prompt)
+
 
 @slow
 @require_big_accelerator
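A minimal usage sketch of the end-to-end scenario this patch and the new regression test cover: pre-computing SD3 prompt embeddings with `encode_prompt(num_images_per_prompt=k)` and passing them back to `__call__` together with `num_images_per_prompt=k`. The checkpoint id, prompt, and dtype below are illustrative assumptions, not taken from this diff.

```python
import torch

from diffusers import StableDiffusion3Pipeline

# Illustrative checkpoint; any SD3 checkpoint exercises the same code path.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

num_images_per_prompt = 2

# Pre-compute the embeddings once, e.g. to cache them or to free the text encoders afterwards.
# `encode_prompt` already repeats the embeddings `num_images_per_prompt` times at this point.
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
    prompt="a photo of an astronaut riding a horse",
    prompt_2=None,
    prompt_3=None,
    num_images_per_prompt=num_images_per_prompt,
    do_classifier_free_guidance=True,
)

# Before this patch, passing `num_images_per_prompt` alongside pre-computed embeddings crashed
# with a batch-size mismatch, because only freshly-encoded prompts were expanded inside `__call__`.
# With the patch, the user-supplied embeddings are expanded once more, so this call returns
# num_images_per_prompt * num_images_per_prompt = 4 images, matching the new test's assertion.
images = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    num_images_per_prompt=num_images_per_prompt,
).images
```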