@@ -447,6 +447,13 @@ def encode_prompt(
else:
batch_size = prompt_embeds.shape[0]

# The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
# by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
# Track that here so we can apply the same expansion at the end and keep the
# batch dimension consistent with `prepare_latents` (see #10712).
prompt_embeds_was_provided = prompt_embeds is not None
negative_prompt_embeds_was_provided = negative_prompt_embeds is not None

if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -543,6 +550,28 @@ def encode_prompt(
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
)

# Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
# what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
if prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
pooled_dim = pooled_prompt_embeds.shape[-1]
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)

if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, hidden_dim
)
pooled_dim = negative_pooled_prompt_embeds.shape[-1]
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
batch_size * num_images_per_prompt, pooled_dim
)

if self.text_encoder is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
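For reference, here is a minimal standalone sketch (not part of the diff) of the repeat/view expansion the new branch performs. The tensor sizes are placeholders chosen for illustration, not the real SD3 embedding dimensions:

```python
import torch

# Placeholder sizes for illustration only; real SD3 embeddings are much larger.
batch_size, num_images_per_prompt, seq_len, hidden_dim, pooled_dim = 2, 3, 4, 8, 6

prompt_embeds = torch.randn(batch_size, seq_len, hidden_dim)
pooled_prompt_embeds = torch.randn(batch_size, pooled_dim)

# Same pattern as the diff: tile along the second axis, then fold the copies
# back into the batch axis so each prompt appears num_images_per_prompt times.
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)

pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)

print(prompt_embeds.shape)         # torch.Size([6, 4, 8])
print(pooled_prompt_embeds.shape)  # torch.Size([6, 6])
```

Each prompt's copies stay contiguous after the view, mirroring how freshly-encoded embeddings are laid out by the `_get_*_prompt_embeds` helpers.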
@@ -469,6 +469,13 @@ def encode_prompt(
else:
batch_size = prompt_embeds.shape[0]

# The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
# by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
# Track that here so we can apply the same expansion at the end and keep the
# batch dimension consistent with `prepare_latents` (see #10712).
prompt_embeds_was_provided = prompt_embeds is not None
negative_prompt_embeds_was_provided = negative_prompt_embeds is not None

if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -565,6 +572,28 @@ def encode_prompt(
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
)

# Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
# what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
if prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
pooled_dim = pooled_prompt_embeds.shape[-1]
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)

if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, hidden_dim
)
pooled_dim = negative_pooled_prompt_embeds.shape[-1]
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
batch_size * num_images_per_prompt, pooled_dim
)

if self.text_encoder is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
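To make the "consistent with `prepare_latents`" comment concrete, a tiny hedged sketch of the shape contract (the sizes are made up; only the leading dimension is the point):

```python
import torch

# Made-up sizes; only the leading (batch) dimension matters here.
batch_size, num_images_per_prompt = 2, 3
num_channels_latents, latent_height, latent_width = 16, 64, 64
seq_len, hidden_dim = 4, 8

# __call__ sizes the latent batch as batch_size * num_images_per_prompt ...
latents = torch.randn(batch_size * num_images_per_prompt, num_channels_latents, latent_height, latent_width)

# ... so prompt_embeds handed to the transformer must share that leading dimension,
# which is what the expansion above now guarantees for user-supplied embeddings.
prompt_embeds = torch.randn(batch_size * num_images_per_prompt, seq_len, hidden_dim)
assert latents.shape[0] == prompt_embeds.shape[0]
```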
29 changes: 29 additions & 0 deletions src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
@@ -417,6 +417,13 @@ def encode_prompt(
else:
batch_size = prompt_embeds.shape[0]

# The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
# by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
# Track that here so we can apply the same expansion at the end and keep the
# batch dimension consistent with `prepare_latents` (see #10712).
prompt_embeds_was_provided = prompt_embeds is not None
negative_prompt_embeds_was_provided = negative_prompt_embeds is not None

if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -513,6 +520,28 @@ def encode_prompt(
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
)

# Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
# what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
if prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
pooled_dim = pooled_prompt_embeds.shape[-1]
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)

if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, hidden_dim
)
pooled_dim = negative_pooled_prompt_embeds.shape[-1]
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
batch_size * num_images_per_prompt, pooled_dim
)

if self.text_encoder is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
29 changes: 29 additions & 0 deletions src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
@@ -433,6 +433,13 @@ def encode_prompt(
else:
batch_size = prompt_embeds.shape[0]

# The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
# by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
# Track that here so we can apply the same expansion at the end and keep the
# batch dimension consistent with `prepare_latents` (see #10712).
prompt_embeds_was_provided = prompt_embeds is not None
negative_prompt_embeds_was_provided = negative_prompt_embeds is not None

if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -529,6 +536,28 @@ def encode_prompt(
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
)

# Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
# what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
if prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
pooled_dim = pooled_prompt_embeds.shape[-1]
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)

if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, hidden_dim
)
pooled_dim = negative_pooled_prompt_embeds.shape[-1]
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
batch_size * num_images_per_prompt, pooled_dim
)

if self.text_encoder is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
@@ -426,6 +426,13 @@ def encode_prompt(
else:
batch_size = prompt_embeds.shape[0]

# The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
# by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
# Track that here so we can apply the same expansion at the end and keep the
# batch dimension consistent with `prepare_latents` (see #10712).
prompt_embeds_was_provided = prompt_embeds is not None
negative_prompt_embeds_was_provided = negative_prompt_embeds is not None

if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -522,6 +529,28 @@ def encode_prompt(
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
)

# Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
# what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
if prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
pooled_dim = pooled_prompt_embeds.shape[-1]
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)

if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, hidden_dim
)
pooled_dim = negative_pooled_prompt_embeds.shape[-1]
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
batch_size * num_images_per_prompt, pooled_dim
)

if self.text_encoder is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
@@ -452,6 +452,13 @@ def encode_prompt(
else:
batch_size = prompt_embeds.shape[0]

# The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
# by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
# Track that here so we can apply the same expansion at the end and keep the
# batch dimension consistent with `prepare_latents` (see #10712).
prompt_embeds_was_provided = prompt_embeds is not None
negative_prompt_embeds_was_provided = negative_prompt_embeds is not None

if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -548,6 +555,28 @@ def encode_prompt(
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
)

# Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
# what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
if prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
pooled_dim = pooled_prompt_embeds.shape[-1]
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)

if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, hidden_dim
)
pooled_dim = negative_pooled_prompt_embeds.shape[-1]
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
batch_size * num_images_per_prompt, pooled_dim
)

if self.text_encoder is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
@@ -458,6 +458,13 @@ def encode_prompt(
else:
batch_size = prompt_embeds.shape[0]

# The internal `_get_*_prompt_embeds` helpers expand the encoded embeddings
# by `num_images_per_prompt`, but user-supplied embeddings bypass that path.
# Track that here so we can apply the same expansion at the end and keep the
# batch dimension consistent with `prepare_latents` (see #10712).
prompt_embeds_was_provided = prompt_embeds is not None
negative_prompt_embeds_was_provided = negative_prompt_embeds is not None

if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -554,6 +561,28 @@ def encode_prompt(
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
)

# Apply `num_images_per_prompt` expansion to user-supplied embeddings to match
# what `_get_*_prompt_embeds` already does for freshly-encoded ones (#10712).
if prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = prompt_embeds.shape[-2], prompt_embeds.shape[-1]
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, hidden_dim)
pooled_dim = pooled_prompt_embeds.shape[-1]
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, pooled_dim)

if do_classifier_free_guidance and negative_prompt_embeds_was_provided and num_images_per_prompt > 1:
seq_len, hidden_dim = negative_prompt_embeds.shape[-2], negative_prompt_embeds.shape[-1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, hidden_dim
)
pooled_dim = negative_pooled_prompt_embeds.shape[-1]
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
batch_size * num_images_per_prompt, pooled_dim
)

if self.text_encoder is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
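Finally, a hedged end-to-end sketch of the user-facing path this PR fixes: embeddings precomputed once and passed back into the pipeline together with num_images_per_prompt > 1. The checkpoint id and exact keyword arguments follow the current StableDiffusion3Pipeline API as I understand it and are assumptions, not something the diff itself shows:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Assumed checkpoint id; substitute whichever SD3 weights you actually use.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Encode once (num_images_per_prompt=1), e.g. to cache or reuse the embeddings.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="a photo of an astronaut riding a horse",
    prompt_2=None,
    prompt_3=None,
    num_images_per_prompt=1,
)

# With the fix, the pipeline expands these user-supplied embeddings to match the
# latent batch, so asking for several images per prompt no longer shape-mismatches.
images = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    num_images_per_prompt=2,
    num_inference_steps=28,
).images
```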