From a0823154f9df95ad14f9dd61f86ee83d3a26fc89 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Thu, 18 Jun 2026 12:10:13 +0000 Subject: [PATCH 01/14] feat: add image edit plus --- ...convert_joyimage_edit_plus_to_diffusers.py | 290 ++++++++ src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_joyimage.py | 2 + .../transformer_joyimage_edit_plus.py | 365 +++++++++ src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/joyimage/__init__.py | 7 +- .../joyimage/pipeline_joyimage_edit_plus.py | 697 ++++++++++++++++++ .../pipelines/joyimage/pipeline_output.py | 8 + 10 files changed, 1377 insertions(+), 5 deletions(-) create mode 100644 scripts/convert_joyimage_edit_plus_to_diffusers.py create mode 100644 src/diffusers/models/transformers/transformer_joyimage_edit_plus.py create mode 100644 src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py diff --git a/scripts/convert_joyimage_edit_plus_to_diffusers.py b/scripts/convert_joyimage_edit_plus_to_diffusers.py new file mode 100644 index 000000000000..f01adb03c747 --- /dev/null +++ b/scripts/convert_joyimage_edit_plus_to_diffusers.py @@ -0,0 +1,290 @@ +import argparse +from typing import Any, Dict, Tuple + +import torch +from accelerate import init_empty_weights +from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration + +from diffusers import ( + AutoencoderKLWan, + JoyImageEditPlusPipeline, +) +from diffusers.models.transformers.transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) + + +# VAE conversion reused from convert_joyimage_edit_to_diffusers.py (identical VAE) +def convert_vae(vae_ckpt_path): + old_state_dict = torch.load(vae_ckpt_path, weights_only=True) + new_state_dict = {} + + middle_key_mapping = { + 
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", + "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", + "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", + "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", + "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", + "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", + "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", + "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", + "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", + "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", + } + + attention_mapping = { + 
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", + "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", + "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", + "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", + "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", + "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", + "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", + "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", + "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", + "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", + } + + head_mapping = { + "encoder.head.0.gamma": "encoder.norm_out.gamma", + "encoder.head.2.bias": "encoder.conv_out.bias", + "encoder.head.2.weight": "encoder.conv_out.weight", + "decoder.head.0.gamma": "decoder.norm_out.gamma", + "decoder.head.2.bias": "decoder.conv_out.bias", + "decoder.head.2.weight": "decoder.conv_out.weight", + } + + quant_mapping = { + "conv1.weight": "quant_conv.weight", + "conv1.bias": "quant_conv.bias", + "conv2.weight": "post_quant_conv.weight", + "conv2.bias": "post_quant_conv.bias", + } + + for key, value in old_state_dict.items(): + if key in middle_key_mapping: + new_state_dict[middle_key_mapping[key]] = value + elif key in attention_mapping: + new_state_dict[attention_mapping[key]] = value + elif key in head_mapping: + new_state_dict[head_mapping[key]] = value + elif key in quant_mapping: + new_state_dict[quant_mapping[key]] = value + elif key == "encoder.conv1.weight": + new_state_dict["encoder.conv_in.weight"] = value + elif key == "encoder.conv1.bias": + new_state_dict["encoder.conv_in.bias"] = value + elif key == "decoder.conv1.weight": + new_state_dict["decoder.conv_in.weight"] = value + elif key == "decoder.conv1.bias": + 
new_state_dict["decoder.conv_in.bias"] = value + elif key.startswith("encoder.downsamples."): + new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") + if ".residual.0.gamma" in new_key: + new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") + elif ".residual.2.bias" in new_key: + new_key = new_key.replace(".residual.2.bias", ".conv1.bias") + elif ".residual.2.weight" in new_key: + new_key = new_key.replace(".residual.2.weight", ".conv1.weight") + elif ".residual.3.gamma" in new_key: + new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") + elif ".residual.6.bias" in new_key: + new_key = new_key.replace(".residual.6.bias", ".conv2.bias") + elif ".residual.6.weight" in new_key: + new_key = new_key.replace(".residual.6.weight", ".conv2.weight") + elif ".shortcut.bias" in new_key: + new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") + elif ".shortcut.weight" in new_key: + new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") + new_state_dict[new_key] = value + elif key.startswith("decoder.upsamples."): + parts = key.split(".") + block_idx = int(parts[2]) + + if "residual" in key: + if block_idx in [0, 1, 2]: + new_block_idx = 0 + resnet_idx = block_idx + elif block_idx in [4, 5, 6]: + new_block_idx = 1 + resnet_idx = block_idx - 4 + elif block_idx in [8, 9, 10]: + new_block_idx = 2 + resnet_idx = block_idx - 8 + elif block_idx in [12, 13, 14]: + new_block_idx = 3 + resnet_idx = block_idx - 12 + else: + new_state_dict[key] = value + continue + + if ".residual.0.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" + elif ".residual.2.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" + elif ".residual.2.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" + elif ".residual.3.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" + 
elif ".residual.6.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" + elif ".residual.6.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" + else: + new_key = key + new_state_dict[new_key] = value + + elif ".shortcut." in key: + if block_idx == 4: + new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") + new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_key = new_key.replace(".shortcut.", ".conv_shortcut.") + new_state_dict[new_key] = value + + elif ".resample." in key or ".time_conv." in key: + if block_idx == 3: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") + elif block_idx == 7: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") + elif block_idx == 11: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_state_dict[new_key] = value + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + + with init_empty_weights(): + vae = AutoencoderKLWan() + vae.load_state_dict(new_state_dict, strict=True, assign=True) + return vae + + +def get_transformer_config() -> Dict[str, Any]: + return { + "hidden_size": 4096, + "in_channels": 16, + "num_attention_heads": 32, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "rope_dim_list": [16, 56, 56], + "text_dim": 4096, + "rope_type": "rope", + "theta": 10000, + } + + +def convert_transformer(ckpt_path: str): + checkpoint = torch.load(ckpt_path, weights_only=True) + if "model" in checkpoint: + original_state_dict = checkpoint["model"] + else: + original_state_dict = checkpoint + + attn_suffixes = ( + 
"img_attn_qkv.", + "img_attn_q_norm.", + "img_attn_k_norm.", + "img_attn_proj.", + "txt_attn_qkv.", + "txt_attn_q_norm.", + "txt_attn_k_norm.", + "txt_attn_proj.", + ) + remapped = {} + for key, value in original_state_dict.items(): + new_key = key + if key.startswith("double_blocks."): + for suffix in attn_suffixes: + if "." + suffix in key and ".attn." + suffix not in key: + new_key = key.replace("." + suffix, ".attn." + suffix) + break + remapped[new_key] = value + + config = get_transformer_config() + with init_empty_weights(): + transformer = JoyImageEditPlusTransformer3DModel(**config) + transformer.load_state_dict(remapped, strict=True, assign=True) + return transformer + + +def get_args(): + parser = argparse.ArgumentParser(description="Convert JoyImage Edit Plus checkpoints to diffusers format") + parser.add_argument("--transformer_ckpt_path", type=str, default=None) + parser.add_argument("--vae_ckpt_path", type=str, default=None) + parser.add_argument("--text_encoder_path", type=str, default=None) + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--dtype", default="bf16", help="Torch dtype (fp32, fp16, bf16)") + parser.add_argument("--flow_shift", type=float, default=1.5) + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +if __name__ == "__main__": + args = get_args() + transformer = None + vae = None + dtype = DTYPE_MAPPING[args.dtype] + + if args.save_pipeline: + assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None + assert args.text_encoder_path is not None + + if args.transformer_ckpt_path is not None: + transformer = convert_transformer(args.transformer_ckpt_path) + transformer = transformer.to(dtype=dtype) + if not args.save_pipeline: + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if 
args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path) + vae = vae.to(dtype=dtype) + if not args.save_pipeline: + vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.save_pipeline: + processor = AutoProcessor.from_pretrained(args.text_encoder_path) + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained( + args.text_encoder_path, torch_dtype=torch.bfloat16 + ).to("cuda") + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path) + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.flow_shift) + transformer = transformer.to("cuda") + vae = vae.to("cuda") + pipe = JoyImageEditPlusPipeline( + processor=processor, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + ).to("cuda") + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + processor.save_pretrained(f"{args.output_path}/processor") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9ec449df0508..b3c62bb70cc1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -275,6 +275,7 @@ "I2VGenXLUNet", "Ideogram4Transformer2DModel", "JoyImageEditTransformer3DModel", + "JoyImageEditPlusTransformer3DModel", "Kandinsky3UNet", "Kandinsky5Transformer3DModel", "Krea2Transformer2DModel", @@ -624,6 +625,8 @@ "ImageTextPipelineOutput", "JoyImageEditPipeline", "JoyImageEditPipelineOutput", + "JoyImageEditPlusPipeline", + "JoyImageEditPlusPipelineOutput", "Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline", "Kandinsky5I2IPipeline", @@ -1137,6 +1140,7 @@ I2VGenXLUNet, Ideogram4Transformer2DModel, JoyImageEditTransformer3DModel, + JoyImageEditPlusTransformer3DModel, Kandinsky3UNet, Kandinsky5Transformer3DModel, Krea2Transformer2DModel, @@ -1461,6 +1465,8 @@ ImageTextPipelineOutput, JoyImageEditPipeline, JoyImageEditPipelineOutput, + JoyImageEditPlusPipeline, + 
JoyImageEditPlusPipelineOutput, Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, Kandinsky5I2IPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 3e56e49ce04e..30eec69dd02a 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -121,6 +121,7 @@ _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_ideogram4"] = ["Ideogram4Transformer2DModel"] _import_structure["transformers.transformer_joyimage"] = ["JoyImageEditTransformer3DModel"] + _import_structure["transformers.transformer_joyimage_edit_plus"] = ["JoyImageEditPlusTransformer3DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_krea2"] = ["Krea2Transformer2DModel"] _import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"] @@ -255,6 +256,7 @@ HunyuanVideoTransformer3DModel, Ideogram4Transformer2DModel, JoyImageEditTransformer3DModel, + JoyImageEditPlusTransformer3DModel, Kandinsky5Transformer3DModel, Krea2Transformer2DModel, LatteTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 4ba9703b5fc0..21f5cb853643 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -42,6 +42,7 @@ from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_ideogram4 import Ideogram4Transformer2DModel from .transformer_joyimage import JoyImageEditTransformer3DModel + from .transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_krea2 import Krea2Transformer2DModel from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer diff --git 
a/src/diffusers/models/transformers/transformer_joyimage.py b/src/diffusers/models/transformers/transformer_joyimage.py index b17ddb05f799..d30b0501e02f 100644 --- a/src/diffusers/models/transformers/transformer_joyimage.py +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -283,6 +283,7 @@ def forward( encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # modulation ( @@ -312,6 +313,7 @@ def forward( hidden_states=img_modulated, encoder_hidden_states=txt_modulated, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, ) hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1) diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py new file mode 100644 index 000000000000..abc8c2b4340a --- /dev/null +++ b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ -0,0 +1,365 @@ +# Copyright 2025 The JoyImage Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import math +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention import AttentionMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm +from .transformer_joyimage import ( + JoyImageAttention, + JoyImageModulate, + JoyImageTimeTextImageEmbedding, + JoyImageTransformerBlock, +) + + +logger = logging.get_logger(__name__) + + +def _apply_rotary_emb_batched( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + """RoPE that handles both batched [B, S, D] and unbatched [S, D] freqs.""" + cos, sin = freqs_cis[0].to(xq.device), freqs_cis[1].to(xq.device) + + if cos.ndim == 2: + # unbatched: [S, D] -> [1, S, 1, D] + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + elif cos.ndim == 3: + # batched: [B, S, D] -> [B, S, 1, D] + cos = cos.unsqueeze(2) + sin = sin.unsqueeze(2) + + def _rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq) + xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk) + return xq_out, xk_out + + +class JoyImageEditPlusAttnProcessor: + """Attention processor that supports batched RoPE embeddings for edit-plus multi-image input.""" + + _attention_backend = None + _parallel_config = None + + def __init__(self): + pass + + def __call__( + self, + attn: "JoyImageAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: 
torch.Tensor | None = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if encoder_hidden_states is None: + raise ValueError("JoyImageEditPlusAttnProcessor requires encoder_hidden_states") + + heads = attn.heads + + img_qkv = attn.img_attn_qkv(hidden_states) + img_query, img_key, img_value = img_qkv.chunk(3, dim=-1) + + txt_qkv = attn.txt_attn_qkv(encoder_hidden_states) + txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1) + + img_query = img_query.unflatten(-1, (heads, -1)) + img_key = img_key.unflatten(-1, (heads, -1)) + img_value = img_value.unflatten(-1, (heads, -1)) + + txt_query = txt_query.unflatten(-1, (heads, -1)) + txt_key = txt_key.unflatten(-1, (heads, -1)) + txt_value = txt_value.unflatten(-1, (heads, -1)) + + img_query = attn.img_attn_q_norm(img_query) + img_key = attn.img_attn_k_norm(img_key) + txt_query = attn.txt_attn_q_norm(txt_query) + txt_key = attn.txt_attn_k_norm(txt_key) + + if image_rotary_emb is not None: + vis_freqs, txt_freqs = image_rotary_emb + if vis_freqs is not None: + img_query, img_key = _apply_rotary_emb_batched(img_query, img_key, vis_freqs) + if txt_freqs is not None: + txt_query, txt_key = _apply_rotary_emb_batched(txt_query, txt_key, txt_freqs) + + joint_query = torch.cat([img_query, txt_query], dim=1) + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) + + joint_hidden_states = dispatch_attention_fn( + joint_query, + joint_key, + joint_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + img_attn_output = joint_hidden_states[:, : hidden_states.shape[1], :] + txt_attn_output = joint_hidden_states[:, hidden_states.shape[1] :, :] + + img_attn_output = attn.img_attn_proj(img_attn_output) + txt_attn_output = 
attn.txt_attn_proj(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class JoyImageEditPlusTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin): + """JoyImage Edit Plus Transformer for multi-image editing. + + Uses a patchify+padding approach where each reference image and the target noise are independently + patchified and concatenated into a flat patch sequence. Supports variable-resolution reference images. + + Input format: [B, max_patches, C, pt, ph, pw] (6D padded patches) + """ + + _skip_layerwise_casting_patterns = ["img_in", "condition_embedder", "norm"] + _no_split_modules = ["JoyImageTransformerBlock"] + _supports_gradient_checkpointing = True + _keep_in_fp32_modules = [ + "time_embedder", + "norm1", + "norm2", + "norm_out", + ] + _repeated_blocks = ["JoyImageTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: list = [1, 2, 2], + in_channels: int = 16, + out_channels: int | None = None, + hidden_size: int = 3072, + num_attention_heads: int = 24, + text_dim: int = 4096, + mlp_width_ratio: float = 4.0, + num_layers: int = 20, + rope_dim_list: list[int] = [16, 56, 56], + rope_type: str = "rope", + theta: int = 256, + ): + super().__init__() + + self.out_channels = out_channels or in_channels + self.patch_size = patch_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.rope_dim_list = rope_dim_list + self.rope_type = rope_type + self.theta = theta + + attention_head_dim = hidden_size // num_attention_heads + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})" + ) + + self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + self.condition_embedder = JoyImageTimeTextImageEmbedding( + dim=hidden_size, + time_freq_dim=256, + time_proj_dim=hidden_size * 6, + text_embed_dim=text_dim, + ) + + self.double_blocks = 
nn.ModuleList( + [ + JoyImageTransformerBlock( + dim=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = FP32LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size)) + + self.gradient_checkpointing = False + + # Set batched-RoPE-aware attention processor on all blocks + for block in self.double_blocks: + block.attn.set_processor(JoyImageEditPlusAttnProcessor()) + + def _get_rotary_pos_embed_for_range( + self, + start: Tuple[int, int, int], + stop: Tuple[int, int, int], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate 3D RoPE for a spatial range [start, stop).""" + head_dim = self.hidden_size // self.num_attention_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // 3] * 3 + + grids = [] + for i in range(3): + grids.append(torch.arange(start[i], stop[i], dtype=torch.float32)) + + mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0) + + cos_parts, sin_parts = [], [] + for i, dim in enumerate(rope_dim_list): + pos = mesh[i].reshape(-1) + freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) + angles = torch.outer(pos, freqs) + cos_parts.append(angles.cos().repeat_interleave(2, dim=1)) + sin_parts.append(angles.sin().repeat_interleave(2, dim=1)) + + return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor | None = None, + shape_list: List[List[Tuple[int, int, int]]] | None = None, + return_dict: bool = False, + ) -> Union[torch.Tensor, Tuple]: + """ + Args: + hidden_states: [B, max_patches, C, pt, ph, pw] - patchified latent input. + timestep: [B] - diffusion timestep. 
+ encoder_hidden_states: [B, L, D] - text encoder outputs. + encoder_hidden_states_mask: [B, L] - attention mask for text tokens. + shape_list: Per-sample list of (t, h, w) tuples for each component (target + references). + return_dict: Whether to return a dict or tuple. + """ + batch_size, max_num_patches, channels, pt, ph, pw = hidden_states.shape + device = hidden_states.device + + # Unwrap list inputs (SglangXvideo passes these as lists from CFG branches) + if not isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states[0] + if isinstance(encoder_hidden_states_mask, list): + encoder_hidden_states_mask = encoder_hidden_states_mask[0] + + # Resolve shape_list from forward context if not explicitly provided + if shape_list is None: + try: + from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context + + forward_batch = get_forward_context().forward_batch + if forward_batch is not None and forward_batch.vae_image_sizes is not None: + shape_list = [list(forward_batch.vae_image_sizes)] * batch_size + except (ImportError, AttributeError): + pass + if shape_list is None: + raise ValueError( + "shape_list must be provided either as an argument or via forward_batch.vae_image_sizes" + ) + + # 1. Condition embeddings + _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states) + if vec.shape[-1] > self.hidden_size: + vec = vec.unflatten(1, (6, -1)) + + # 2. Patchify via Conv3d: flatten (B, N) -> apply conv -> reshape back + x = hidden_states.reshape(batch_size * max_num_patches, channels, pt, ph, pw) + x = self.img_in(x) # (B*N, D, 1, 1, 1) + img = x.reshape(batch_size, max_num_patches, -1) + + # 3. 
Build per-component RoPE with temporal offsets + sample_cos_list, sample_sin_list = [], [] + + for i in range(batch_size): + s_cos_parts, s_sin_parts = [], [] + current_t_offset = 0 + + for thw in shape_list[i]: + t, h, w = thw + start = (current_t_offset, 0, 0) + stop = (current_t_offset + t, h, w) + cos_emb, sin_emb = self._get_rotary_pos_embed_for_range(start, stop) + s_cos_parts.append(cos_emb) + s_sin_parts.append(sin_emb) + current_t_offset += t + + s_cos = torch.cat(s_cos_parts, dim=0).to(device) + s_sin = torch.cat(s_sin_parts, dim=0).to(device) + + actual_len = s_cos.shape[0] + pad_len = max_num_patches - actual_len + if pad_len > 0: + s_cos = F.pad(s_cos, (0, 0, 0, pad_len), value=1.0) + s_sin = F.pad(s_sin, (0, 0, 0, pad_len), value=0.0) + + sample_cos_list.append(s_cos) + sample_sin_list.append(s_sin) + + vis_freqs = (torch.stack(sample_cos_list), torch.stack(sample_sin_list)) + + # 4. Build attention mask: [B, 1, 1, img_seq + txt_seq] + # img patches: only actual (non-padding) patches are valid; txt uses encoder_hidden_states_mask + attention_mask = None + if encoder_hidden_states_mask is not None: + img_mask = torch.zeros(batch_size, max_num_patches, device=device, dtype=encoder_hidden_states_mask.dtype) + for i in range(batch_size): + actual_len = sum(t * h * w for t, h, w in shape_list[i]) + img_mask[i, :actual_len] = 1.0 + full_mask = torch.cat([img_mask, encoder_hidden_states_mask], dim=1) + attention_mask = full_mask.unsqueeze(1).unsqueeze(1).bool() + + # 5. Run double blocks + for block in self.double_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + img, txt = self._gradient_checkpointing_func(block, img, txt, vec, (vis_freqs, None), attention_mask) + else: + img, txt = block( + hidden_states=img, + encoder_hidden_states=txt, + temb=vec, + image_rotary_emb=(vis_freqs, None), + attention_mask=attention_mask, + ) + + # 6. 
Output projection + reshape to 6D patches + img = self.proj_out(self.norm_out(img)) + img = img.reshape( + batch_size, max_num_patches, pt, ph, pw, self.out_channels + ).permute(0, 1, 5, 2, 3, 4) # -> [B, N, C, pt, ph, pw] + + if not return_dict: + return (img,) + return Transformer2DModelOutput(sample=img) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 234085456708..0e25c647299b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -345,7 +345,7 @@ "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", ] - _import_structure["joyimage"] = ["JoyImageEditPipeline", "JoyImageEditPipelineOutput"] + _import_structure["joyimage"] = ["JoyImageEditPipeline", "JoyImageEditPipelineOutput", "JoyImageEditPlusPipeline", "JoyImageEditPlusPipelineOutput"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -758,7 +758,7 @@ from .hunyuan_video1_5 import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline from .hunyuandit import HunyuanDiTPipeline from .ideogram4 import Ideogram4Pipeline, Ideogram4PromptEnhancerHead - from .joyimage import JoyImageEditPipeline, JoyImageEditPipelineOutput + from .joyimage import JoyImageEditPipeline, JoyImageEditPipelineOutput, JoyImageEditPlusPipeline, JoyImageEditPlusPipelineOutput from .kandinsky import ( KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, diff --git a/src/diffusers/pipelines/joyimage/__init__.py b/src/diffusers/pipelines/joyimage/__init__.py index 85b9246b22a6..a5faea9d9763 100644 --- a/src/diffusers/pipelines/joyimage/__init__.py +++ b/src/diffusers/pipelines/joyimage/__init__.py @@ -22,8 +22,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_joyimage_edit"] = ["JoyImageEditPipeline"] - - 
_import_structure["pipeline_output"] = ["JoyImageEditPipelineOutput"] + _import_structure["pipeline_joyimage_edit_plus"] = ["JoyImageEditPlusPipeline"] + _import_structure["pipeline_output"] = ["JoyImageEditPipelineOutput", "JoyImageEditPlusPipelineOutput"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -34,7 +34,8 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_joyimage_edit import JoyImageEditPipeline - from .pipeline_output import JoyImageEditPipelineOutput + from .pipeline_joyimage_edit_plus import JoyImageEditPlusPipeline + from .pipeline_output import JoyImageEditPipelineOutput, JoyImageEditPlusPipelineOutput else: import sys diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py new file mode 100644 index 000000000000..c938e8e8ab32 --- /dev/null +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -0,0 +1,697 @@ +# Copyright 2025 The JoyImage Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import math +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from einops import rearrange +from PIL import Image +from transformers import ( + Qwen2Tokenizer, + Qwen3VLForConditionalGeneration, + Qwen3VLProcessor, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLWan +from ...models.transformers.transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import JoyImageEditImageProcessor, find_best_bucket +from .pipeline_output import JoyImageEditPlusPipelineOutput + + +EXAMPLE_DOC_STRING = """ +Examples: + ```python + >>> import torch + >>> from diffusers import JoyImageEditPlusPipeline + >>> from diffusers.utils import load_image + + >>> model_id = "jdopensource/JoyAI-Image-Edit-Plus-Diffusers" + >>> pipe = JoyImageEditPlusPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> images = [ + ... load_image("dog.png"), + ... load_image("person.png"), + ... ] + >>> output = pipe( + ... images=images, + ... prompt="Let the person lovingly play with the dog.", + ... height=1024, + ... width=1024, + ... num_inference_steps=30, + ... guidance_scale=4.0, + ... generator=torch.manual_seed(42), + ... 
) + >>> output.images[0].save("output.png") + ``` +""" + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") + + if timesteps is not None: + if "timesteps" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom timesteps.") + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + if "sigmas" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom sigmas.") + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + return timesteps, num_inference_steps + + +class JoyImageEditPlusPipeline(DiffusionPipeline): + """Diffusion pipeline for multi-image editing using JoyImage Edit Plus. + + Supports multiple reference images with different resolutions. Each reference image is independently + VAE-encoded and patchified, then concatenated with the target noise patches for joint denoising. + + Model offloading order: text_encoder -> transformer -> vae. 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLWan, + text_encoder: Qwen3VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: JoyImageEditPlusTransformer3DModel, + processor: Qwen3VLProcessor, + text_token_max_length: int = 2048, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + processor=processor, + ) + + self.text_token_max_length = text_token_max_length + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.vae_image_processor = JoyImageEditImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, + ) + + self.prompt_template_encode = { + "multiple_images": ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "{}<|im_start|>assistant\n" + ), + } + self.prompt_template_encode_start_idx = { + "multiple_images": 34, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_last_decoder_hidden_states(self, forward_fn, **kwargs): + """ + Run ``forward_fn(**kwargs)`` while capturing the **pre-norm** output of the last decoder layer via a forward + hook. + + This model was trained on transformers 4.57, where ``Qwen3VLForConditionalGeneration``'s + ``@check_model_inputs`` decorator monkey-patched each decoder layer to collect ``hidden_states``. 
Because + ``Qwen3VLCausalLMOutputWithPast`` has no ``last_hidden_state`` field, ``tie_last_hidden_states`` had no effect + and ``hidden_states[-1]`` was the **pre-norm** output of the last decoder layer. + + Starting from https://github.com/huggingface/transformers/pull/42609 the CausalLM forward explicitly returns + ``hidden_states=outputs.hidden_states`` from the inner model. Combined with the subsequent + ``@check_model_inputs`` → ``@capture_outputs`` migration (transformers 5.x), ``hidden_states`` is now captured + at the ``Qwen3VLTextModel`` level where ``tie_last_hidden_states=True`` replaces ``hidden_states[-1]`` with the + **post-norm** ``last_hidden_state``. The CausalLM simply passes this through, so ``hidden_states[-1]`` becomes + post-norm – a ~10x scale difference (std ~2 vs ~21) that breaks inference. + + This helper bypasses both mechanisms by hooking the last decoder layer directly, returning the raw pre-norm + output regardless of the transformers version. + """ + captured = {} + + def _hook(_module, _input, output): + captured["hidden_states"] = output[0] if isinstance(output, tuple) else output + + handle = self.text_encoder.model.language_model.layers[-1].register_forward_hook(_hook) + try: + forward_fn(**kwargs) + finally: + handle.remove() + return captured["hidden_states"] + + def encode_prompt_multiple_images( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + images: Optional[List[Image.Image]] = None, + max_sequence_length: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode prompts with inline tokens via the Qwen3-VL processor.""" + device = device or self._execution_device + template = self.prompt_template_encode["multiple_images"] + drop_idx = self.prompt_template_encode_start_idx["multiple_images"] + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [p.replace("\n", "<|vision_start|><|image_pad|><|vision_end|>") for p in prompt] + prompt = [template.format(p) 
for p in prompt] + + inputs = self.processor( + text=prompt, + images=images, + padding=True, + return_tensors="pt", + ).to(device) + + last_hidden_states = self._get_last_decoder_hidden_states(self.text_encoder, **inputs) + + prompt_embeds = last_hidden_states[:, drop_idx:] + prompt_embeds_mask = inputs["attention_mask"][:, drop_idx:] + + if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length: + prompt_embeds = prompt_embeds[:, -max_sequence_length:, :] + prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:] + + return prompt_embeds, prompt_embeds_mask + + def _pad_sequence(self, x: torch.Tensor, target_length: int) -> torch.Tensor: + current_length = x.shape[1] + if current_length >= target_length: + return x[:, -target_length:] + padding_length = target_length - current_length + if x.ndim >= 3: + padding = torch.zeros( + (x.shape[0], padding_length, *x.shape[2:]), dtype=x.dtype, device=x.device + ) + else: + padding = torch.zeros((x.shape[0], padding_length), dtype=x.dtype, device=x.device) + return torch.cat([x, padding], dim=1) + + def normalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latent = (latent - latents_mean) / latents_std + else: + latent = latent * self.vae.config.scaling_factor + return latent + + def denormalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 
1).to(latent.device, latent.dtype) + ) + latent = latent * latents_std + latents_mean + else: + latent = latent / self.vae.config.scaling_factor + return latent + + def _resize_center_crop(self, img: Image.Image, target_size: Tuple[int, int]) -> Image.Image: + w, h = img.size + bh, bw = target_size + scale = max(bh / h, bw / w) + resize_h, resize_w = math.ceil(h * scale), math.ceil(w * scale) + img = img.resize((resize_w, resize_h), Image.LANCZOS) + left = (resize_w - bw) // 2 + top = (resize_h - bh) // 2 + img = img.crop((left, top, left + bw, top + bh)) + return img + + def _get_bucket_size(self, img: Image.Image) -> Tuple[int, int]: + return find_best_bucket(img.size[1], img.size[0], self.vae_image_processor.config.basesize) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + reference_images: Optional[List[List[Image.Image]]] = None, + enable_denormalization: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[Tuple[int, int, int]]]]: + """Prepare 6D padded latent tensor with target noise + reference image latents. 
+ + Returns: + padded_latents: [B, max_patches, C, pt, ph, pw] + target_mask: [B, max_patches] (True for target patches) + shape_list: per-sample list of (t, h, w) tuples for each component + """ + pt, ph, pw = self.transformer.config.patch_size + + all_patches = [] + all_target_masks = [] + all_shape_lists = [] + max_patches = 0 + + for i in range(batch_size): + sample_gen = generator[i] if isinstance(generator, list) else generator + + # Target noise + t_target = 1 + h_target = int(height) // self.vae_scale_factor_spatial + w_target = int(width) // self.vae_scale_factor_spatial + noise_shape = (num_channels_latents, t_target, h_target, w_target) + noise_block = randn_tensor(noise_shape, generator=sample_gen, device=device, dtype=dtype) + + sample_items = [noise_block] + + # Reference images + if reference_images is not None and reference_images[i]: + for ref_img_pil in reference_images[i]: + ref_h, ref_w = self._get_bucket_size(ref_img_pil) + ref_img_pil = self._resize_center_crop(ref_img_pil, (ref_h, ref_w)) + + ref_tensor = torch.from_numpy(np.array(ref_img_pil.convert("RGB"))).to(device=device, dtype=dtype) + ref_tensor = (ref_tensor / 127.5 - 1.0).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) + + with torch.autocast(device_type="cuda", dtype=torch.float32): + ref_latent = self.vae.encode(ref_tensor.float()).latent_dist.mode() + ref_latent = ref_latent.to(dtype) + ref_latent = self.normalize_latents(ref_latent) + ref_latent = ref_latent.squeeze(0) # [C, 1, H', W'] + sample_items.append(ref_latent) + + # Patchify each item and build shape_list + sample_patches = [] + sample_masks = [] + sample_shapes = [] + + for j, item in enumerate(sample_items): + c, t, h, w = item.shape + l_t, l_h, l_w = t // pt, h // ph, w // pw + sample_shapes.append((l_t, l_h, l_w)) + + patches = rearrange(item, "c (t pt) (h ph) (w pw) -> (t h w) c pt ph pw", pt=pt, ph=ph, pw=pw) + sample_patches.append(patches) + sample_masks.append(torch.full((patches.shape[0],), j == 0, device=device, 
dtype=torch.bool)) + + combined_patches = torch.cat(sample_patches, dim=0) + combined_masks = torch.cat(sample_masks, dim=0) + + all_patches.append(combined_patches) + all_target_masks.append(combined_masks) + all_shape_lists.append(sample_shapes) + max_patches = max(max_patches, combined_patches.shape[0]) + + # Pad to uniform size + padded_latents = torch.zeros( + (batch_size, max_patches, num_channels_latents, pt, ph, pw), device=device, dtype=dtype + ) + target_mask = torch.zeros((batch_size, max_patches), device=device, dtype=torch.bool) + + for i in range(batch_size): + n = all_patches[i].shape[0] + padded_latents[i, :n] = all_patches[i] + target_mask[i, :n] = all_target_masks[i] + + return padded_latents, target_mask, all_shape_lists + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def guidance_scale(self) -> float: + return self._guidance_scale + + @property + def do_classifier_free_guidance(self) -> bool: + return self._guidance_scale > 1 + + @property + def num_timesteps(self) -> int: + return self._num_timesteps + + @property + def interrupt(self) -> bool: + return self._interrupt + + # ------------------------------------------------------------------ + # Forward pass + # ------------------------------------------------------------------ + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + images: List[Image.Image] | List[List[Image.Image]] | None = None, + prompt: str | List[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, 
+ prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 4096, + enable_denormalization: bool = True, + ): + r""" + Generate an edited image from multiple reference images and a text prompt. + + Args: + images (`List[Image.Image]` or `List[List[Image.Image]]`): + Reference images for editing. Each image can have a different resolution. + If a flat list is provided, it's treated as one sample with multiple references. + prompt (`str` or `List[str]`): + Text prompt describing the desired edit. + height (`int`, *optional*): + Output height in pixels. If None, determined from the last reference image's bucket. + width (`int`, *optional*): + Output width in pixels. If None, determined from the last reference image's bucket. + num_inference_steps (`int`, defaults to 30): + Number of denoising steps. + guidance_scale (`float`, defaults to 4.0): + Classifier-free guidance scale. + negative_prompt (`str` or `List[str]`, *optional*): + Negative prompt for CFG. + generator (`torch.Generator`, *optional*): + RNG generator for reproducibility. + enable_denormalization (`bool`, defaults to True): + Whether to denormalize latents before VAE decoding. + + Examples: + + Returns: + [`JoyImageEditPlusPipelineOutput`] or `tuple`. 
+ """ + # Normalize images input to List[List[Image]] + if images is not None: + if isinstance(images[0], Image.Image): + images = [images] # single sample + + self._guidance_scale = guidance_scale + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Determine output resolution from last reference image if not specified + if height is None or width is None: + if images is not None and len(images[0]) > 0: + last_img = images[0][-1] + height, width = self._get_bucket_size(last_img) + else: + height = height or 1024 + width = width or 1024 + + device = self._execution_device + + # Pre-process images: bucket-resize each reference image (matching original pipeline) + if images is not None: + processed_images = [] + for sample_imgs in images: + processed_sample = [] + for img in sample_imgs: + ref_h, ref_w = self._get_bucket_size(img) + resize_img = self._resize_center_crop(img, (ref_h, ref_w)) + processed_sample.append(resize_img) + processed_images.append(processed_sample) + images = processed_images + + # Construct prompts with tokens + prompt = [prompt] if isinstance(prompt, str) else prompt + if images is not None: + formatted_prompts = [] + for i in range(batch_size): + num_refs = len(images[i]) if i < len(images) else 0 + image_tags = "".join(["\n" for _ in range(num_refs)]) + p = prompt[i] if i < len(prompt) else prompt[0] + formatted_prompts.append(f"<|im_start|>user\n{image_tags}{p}<|im_end|>\n") + else: + formatted_prompts = [f"<|im_start|>user\n{p}<|im_end|>\n" for p in prompt] + + # Flatten all images for the processor + flattened_images = None + if images is not None: + flattened_images = [img for sublist in images for img in sublist] + + # Encode prompt + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self.encode_prompt_multiple_images( + 
prompt=formatted_prompts, + images=flattened_images, + device=device, + max_sequence_length=max_sequence_length, + ) + + torch.save(prompt_embeds, "prompt_embeds.pt") + # Encode negative prompt for CFG + if self.do_classifier_free_guidance: + print(f"negative_prompt: {negative_prompt}") + if negative_prompt is None and negative_prompt_embeds is None: + neg_prompts = [] + for i in range(batch_size): + num_refs = len(images[i]) if images is not None and i < len(images) else 0 + image_tags = "".join(["\n" for _ in range(num_refs)]) + neg_prompts.append(f"<|im_start|>user\n{image_tags} <|im_end|>\n") + negative_prompt = neg_prompts + elif negative_prompt is not None and negative_prompt_embeds is None: + neg_list = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + neg_prompts = [] + for i in range(batch_size): + num_refs = len(images[i]) if images is not None and i < len(images) else 0 + image_tags = "".join(["\n" for _ in range(num_refs)]) + n = neg_list[i] if i < len(neg_list) else neg_list[0] + neg_prompts.append(f"<|im_start|>user\n{image_tags}{n}<|im_end|>\n") + negative_prompt = neg_prompts + + if negative_prompt_embeds is None: + neg_prompt_list = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt_multiple_images( + prompt=neg_prompt_list, + images=flattened_images, + device=device, + max_sequence_length=max_sequence_length, + ) + + # Pad and concatenate [negative, positive] + max_seq_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1]) + prompt_embeds = torch.cat([ + self._pad_sequence(negative_prompt_embeds, max_seq_len), + self._pad_sequence(prompt_embeds, max_seq_len), + ]) + if prompt_embeds_mask is not None and negative_prompt_embeds_mask is not None: + prompt_embeds_mask = torch.cat([ + self._pad_sequence(negative_prompt_embeds_mask, max_seq_len), + self._pad_sequence(prompt_embeds_mask, max_seq_len), + ]) + 
torch.save(prompt_embeds, 'prompt_embeds_2.pt') + + # Prepare timesteps — compute sigmas with single shift to match original scheduler + if timesteps is None and sigmas is None: + shift = getattr(self.scheduler.config, "shift", 1.0) + raw_sigmas = torch.linspace(1, 0, num_inference_steps + 1) + shifted_sigmas = shift * raw_sigmas / (1 + (shift - 1) * raw_sigmas) + sigmas = shifted_sigmas[:-1].tolist() + original_shift = self.scheduler.shift + self.scheduler.set_shift(1.0) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self.scheduler.set_shift(original_shift) + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # Prepare latents (patchified) + num_channels_latents = self.transformer.config.in_channels + padded_latents, target_mask, shape_list = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + reference_images=images, + enable_denormalization=enable_denormalization, + ) + torch.save(padded_latents, "padded_latents.pt") + torch.save(target_mask, "target_mask.pt") + # exit(0) + + # Zero out padding text tokens to prevent them from corrupting attention + # (original uses explicit attention masking; here we neutralize padding values) + if prompt_embeds_mask is not None: + prompt_embeds = prompt_embeds * prompt_embeds_mask.unsqueeze(-1) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + clean_reference_backup = padded_latents.clone() + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Restore reference patches + padded_latents[~target_mask] = clean_reference_backup[~target_mask] + + 
model_input = padded_latents + + # CFG expansion + if self.do_classifier_free_guidance: + model_input_cfg = torch.cat([model_input] * 2) + t_expand = t.repeat(model_input_cfg.shape[0]) + cfg_shape_list = shape_list * 2 + else: + model_input_cfg = model_input + t_expand = t.repeat(batch_size) + cfg_shape_list = shape_list + + # Transformer forward + noise_pred = self.transformer( + hidden_states=model_input_cfg, + timestep=t_expand, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + shape_list=cfg_shape_list, + return_dict=False, + )[0] + + # CFG combination with norm rescaling + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + comb_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + cond_norm = torch.norm(noise_pred_text, dim=2, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=2, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # Scheduler step + padded_latents = self.scheduler.step(noise_pred, t, padded_latents, return_dict=False)[0].to( + dtype=prompt_embeds.dtype + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + padded_latents = callback_outputs.pop("latents", padded_latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if progress_bar is not None: + progress_bar.update() + + # Post-processing: decode target latents + if output_type != "latent": + padded_latents[~target_mask] = clean_reference_backup[~target_mask] + pt, ph, pw = self.transformer.config.patch_size + + image_list = [] + for b_idx in range(batch_size): + l_t, l_h, l_w = shape_list[b_idx][0] + target_len = l_t * l_h * l_w + + target_patches = 
padded_latents[b_idx, :target_len] + video_latent = rearrange( + target_patches, + "(t h w) c pt ph pw -> 1 c (t pt) (h ph) (w pw)", + t=l_t, h=l_h, w=l_w, + ) + + video_latent = self.denormalize_latents(video_latent) + + with torch.autocast(device_type="cuda", dtype=torch.float32): + sample_image = self.vae.decode(video_latent.float(), return_dict=False)[0] + sample_image = (sample_image / 2 + 0.5).clamp(0, 1).squeeze(0).cpu().float() + image_list.append(sample_image) + + # Convert to output format + output_images = [] + for img_tensor in image_list: + # img_tensor: [C, T, H, W] -> [C, H, W] (T=1) + img_tensor = img_tensor[:, 0] + img_np = (img_tensor.permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8) + if output_type == "pil": + output_images.append(Image.fromarray(img_np)) + elif output_type == "np": + output_images.append(img_np) + else: + output_images.append(img_tensor) + + image = output_images + else: + image = padded_latents + + self.maybe_free_model_hooks() + + if not return_dict: + return image + + return JoyImageEditPlusPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py index 175dce3540d7..40d9d3aa100f 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_output.py +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -14,3 +14,11 @@ class JoyImageEditPipelineOutput(BaseOutput): """ images: Union[List[PIL.Image.Image], np.ndarray] + +@dataclass +class JoyImageEditPlusPipelineOutput(BaseOutput): + """ + Output class for JoyImage Edit Plus multi-image editing pipelines. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file From d6375a8b618bbc96078356bd1248efff202b7977 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Mon, 22 Jun 2026 05:31:11 +0000 Subject: [PATCH 02/14] refactor: remove debug code --- .../pipelines/joyimage/pipeline_joyimage_edit_plus.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index c938e8e8ab32..980939f427d6 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -506,10 +506,8 @@ def __call__( max_sequence_length=max_sequence_length, ) - torch.save(prompt_embeds, "prompt_embeds.pt") # Encode negative prompt for CFG if self.do_classifier_free_guidance: - print(f"negative_prompt: {negative_prompt}") if negative_prompt is None and negative_prompt_embeds is None: neg_prompts = [] for i in range(batch_size): @@ -547,7 +545,6 @@ def __call__( self._pad_sequence(negative_prompt_embeds_mask, max_seq_len), self._pad_sequence(prompt_embeds_mask, max_seq_len), ]) - torch.save(prompt_embeds, 'prompt_embeds_2.pt') # Prepare timesteps — compute sigmas with single shift to match original scheduler if timesteps is None and sigmas is None: @@ -579,9 +576,6 @@ def __call__( reference_images=images, enable_denormalization=enable_denormalization, ) - torch.save(padded_latents, "padded_latents.pt") - torch.save(target_mask, "target_mask.pt") - # exit(0) # Zero out padding text tokens to prevent them from corrupting attention # (original uses explicit attention masking; here we neutralize padding values) From 885186a66852d10a3f733aad8189d278348ea13f Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Tue, 23 Jun 2026 02:15:16 +0000 Subject: [PATCH 03/14] fix: address review issues for JoyImage Edit Plus - Remove einops dependency: replace rearrange with 
reshape/permute - Remove sglang-specific code from transformer forward - Remove unused import inspect from transformer - Fix hardcoded device_type="cuda" to use device.type - Simplify scheduler sigma math: delegate to retrieve_timesteps - Remove unused enable_denormalization parameter - Fix callback latents variable binding - Fix output_type="pt" to return stacked tensor - Set return_dict default to True in transformer forward - Add dummy objects for JoyImageEditPlus classes - Add transformer and pipeline test files --- .../transformer_joyimage_edit_plus.py | 19 +- .../joyimage/pipeline_joyimage_edit_plus.py | 50 ++-- src/diffusers/utils/dummy_pt_objects.py | 15 ++ .../dummy_torch_and_transformers_objects.py | 30 +++ ...t_models_transformer_joyimage_edit_plus.py | 114 +++++++++ .../joyimage/test_joyimage_edit_plus.py | 225 ++++++++++++++++++ 6 files changed, 403 insertions(+), 50 deletions(-) create mode 100644 tests/models/transformers/test_models_transformer_joyimage_edit_plus.py create mode 100644 tests/pipelines/joyimage/test_joyimage_edit_plus.py diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py index abc8c2b4340a..572c983ec453 100644 --- a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py +++ b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect import math from typing import List, Tuple, Union @@ -255,7 +254,7 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor | None = None, shape_list: List[List[Tuple[int, int, int]]] | None = None, - return_dict: bool = False, + return_dict: bool = True, ) -> Union[torch.Tensor, Tuple]: """ Args: @@ -269,22 +268,6 @@ def forward( batch_size, max_num_patches, channels, pt, ph, pw = hidden_states.shape device = hidden_states.device - # Unwrap list inputs (SglangXvideo passes these as lists from CFG branches) - if not isinstance(encoder_hidden_states, torch.Tensor): - encoder_hidden_states = encoder_hidden_states[0] - if isinstance(encoder_hidden_states_mask, list): - encoder_hidden_states_mask = encoder_hidden_states_mask[0] - - # Resolve shape_list from forward context if not explicitly provided - if shape_list is None: - try: - from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context - - forward_batch = get_forward_context().forward_batch - if forward_batch is not None and forward_batch.vae_image_sizes is not None: - shape_list = [list(forward_batch.vae_image_sizes)] * batch_size - except (ImportError, AttributeError): - pass if shape_list is None: raise ValueError( "shape_list must be provided either as an argument or via forward_batch.vae_image_sizes" diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index 980939f427d6..144650d46b05 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -18,7 +18,6 @@ import numpy as np import torch -from einops import rearrange from PIL import Image from transformers import ( Qwen2Tokenizer, @@ -282,7 +281,6 @@ def prepare_latents( device: torch.device, generator: Optional[Union[torch.Generator, List[torch.Generator]]], reference_images: 
Optional[List[List[Image.Image]]] = None, - enable_denormalization: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor, List[List[Tuple[int, int, int]]]]: """Prepare 6D padded latent tensor with target noise + reference image latents. @@ -319,7 +317,7 @@ def prepare_latents( ref_tensor = torch.from_numpy(np.array(ref_img_pil.convert("RGB"))).to(device=device, dtype=dtype) ref_tensor = (ref_tensor / 127.5 - 1.0).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) - with torch.autocast(device_type="cuda", dtype=torch.float32): + with torch.autocast(device_type=device.type, dtype=torch.float32): ref_latent = self.vae.encode(ref_tensor.float()).latent_dist.mode() ref_latent = ref_latent.to(dtype) ref_latent = self.normalize_latents(ref_latent) @@ -336,7 +334,8 @@ def prepare_latents( l_t, l_h, l_w = t // pt, h // ph, w // pw sample_shapes.append((l_t, l_h, l_w)) - patches = rearrange(item, "c (t pt) (h ph) (w pw) -> (t h w) c pt ph pw", pt=pt, ph=ph, pw=pw) + patches = item.reshape(c, l_t, pt, l_h, ph, l_w, pw) + patches = patches.permute(1, 3, 5, 0, 2, 4, 6).reshape(-1, c, pt, ph, pw) sample_patches.append(patches) sample_masks.append(torch.full((patches.shape[0],), j == 0, device=device, dtype=torch.bool)) @@ -411,7 +410,6 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 4096, - enable_denormalization: bool = True, ): r""" Generate an edited image from multiple reference images and a text prompt. @@ -434,8 +432,6 @@ def __call__( Negative prompt for CFG. generator (`torch.Generator`, *optional*): RNG generator for reproducibility. - enable_denormalization (`bool`, defaults to True): - Whether to denormalize latents before VAE decoding. 
Examples: @@ -546,22 +542,10 @@ def __call__( self._pad_sequence(prompt_embeds_mask, max_seq_len), ]) - # Prepare timesteps — compute sigmas with single shift to match original scheduler - if timesteps is None and sigmas is None: - shift = getattr(self.scheduler.config, "shift", 1.0) - raw_sigmas = torch.linspace(1, 0, num_inference_steps + 1) - shifted_sigmas = shift * raw_sigmas / (1 + (shift - 1) * raw_sigmas) - sigmas = shifted_sigmas[:-1].tolist() - original_shift = self.scheduler.shift - self.scheduler.set_shift(1.0) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - self.scheduler.set_shift(original_shift) - else: - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) + # Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # Prepare latents (patchified) num_channels_latents = self.transformer.config.in_channels @@ -574,7 +558,6 @@ def __call__( device=device, generator=generator, reference_images=images, - enable_denormalization=enable_denormalization, ) # Zero out padding text tokens to prevent them from corrupting attention @@ -631,6 +614,7 @@ def __call__( ) if callback_on_step_end is not None: + latents = padded_latents callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] @@ -653,15 +637,13 @@ def __call__( target_len = l_t * l_h * l_w target_patches = padded_latents[b_idx, :target_len] - video_latent = rearrange( - target_patches, - "(t h w) c pt ph pw -> 1 c (t pt) (h ph) (w pw)", - t=l_t, h=l_h, w=l_w, - ) + c_lat = target_patches.shape[1] + video_latent = target_patches.reshape(l_t, l_h, l_w, c_lat, pt, ph, pw) + video_latent = video_latent.permute(3, 0, 4, 1, 5, 2, 6).reshape(1, c_lat, l_t * pt, l_h * ph, l_w * pw) video_latent = self.denormalize_latents(video_latent) - with 
torch.autocast(device_type="cuda", dtype=torch.float32): + with torch.autocast(device_type=device.type, dtype=torch.float32): sample_image = self.vae.decode(video_latent.float(), return_dict=False)[0] sample_image = (sample_image / 2 + 0.5).clamp(0, 1).squeeze(0).cpu().float() image_list.append(sample_image) @@ -671,15 +653,19 @@ def __call__( for img_tensor in image_list: # img_tensor: [C, T, H, W] -> [C, H, W] (T=1) img_tensor = img_tensor[:, 0] - img_np = (img_tensor.permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8) if output_type == "pil": + img_np = (img_tensor.permute(1, 2, 0).cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8) output_images.append(Image.fromarray(img_np)) elif output_type == "np": + img_np = (img_tensor.permute(1, 2, 0).cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8) output_images.append(img_np) else: output_images.append(img_tensor) - image = output_images + if output_type == "pt": + image = torch.stack(output_images) + else: + image = output_images else: image = padded_latents diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8eb942e68075..06c5b1d425fe 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1500,6 +1500,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class JoyImageEditPlusTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Kandinsky3UNet(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 4d7710adcdd1..8955e52aae6f 100644 --- 
a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2222,6 +2222,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class JoyImageEditPlusPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class JoyImageEditPlusPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Kandinsky3Img2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_joyimage_edit_plus.py b/tests/models/transformers/test_models_transformer_joyimage_edit_plus.py new file mode 100644 index 000000000000..451dbfbbf0ca --- /dev/null +++ b/tests/models/transformers/test_models_transformer_joyimage_edit_plus.py @@ -0,0 +1,114 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from diffusers import JoyImageEditPlusTransformer3DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +class JoyImageEditPlusTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return JoyImageEditPlusTransformer3DModel + + @property + def output_shape(self) -> tuple[int, ...]: + return (2, 16, 1, 2, 2) + + @property + def input_shape(self) -> tuple[int, ...]: + return (2, 16, 1, 2, 2) + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def uses_custom_attn_processor(self) -> bool: + return True + + @property + def model_split_percents(self) -> list: + return [0.7, 0.6, 0.6] + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int]]: + return { + "patch_size": [1, 2, 2], + "in_channels": 16, + "hidden_size": 32, + "num_attention_heads": 2, + "text_dim": 16, + "num_layers": 2, + "rope_dim_list": [4, 6, 6], + "theta": 256, + } + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + batch_size = 1 + max_patches = 2 + hidden_states = randn_tensor( + (batch_size, max_patches, 16, 1, 2, 2), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor((batch_size, 12, 16), generator=self.generator, device=torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + shape_list = [[(1, 1, 1), (1, 1, 1)]] + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "shape_list": 
shape_list, + } + + +class TestJoyImageEditPlusTransformer(JoyImageEditPlusTransformerTesterConfig, ModelTesterMixin): + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + pytest.skip("Tolerance requirements too high for meaningful test") + + +class TestJoyImageEditPlusTransformerMemory(JoyImageEditPlusTransformerTesterConfig, MemoryTesterMixin): + pass + + +class TestJoyImageEditPlusTransformerTraining(JoyImageEditPlusTransformerTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = {"JoyImageEditPlusTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestJoyImageEditPlusTransformerAttention(JoyImageEditPlusTransformerTesterConfig, AttentionTesterMixin): + pass + + +class TestJoyImageEditPlusTransformerCompile(JoyImageEditPlusTransformerTesterConfig, TorchCompileTesterMixin): + pass diff --git a/tests/pipelines/joyimage/test_joyimage_edit_plus.py b/tests/pipelines/joyimage/test_joyimage_edit_plus.py new file mode 100644 index 000000000000..e41265d30128 --- /dev/null +++ b/tests/pipelines/joyimage/test_joyimage_edit_plus.py @@ -0,0 +1,225 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from PIL import Image +from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + JoyImageEditPlusPipeline, + JoyImageEditPlusTransformer3DModel, +) +from diffusers.hooks import apply_group_offloading + +from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class JoyImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = JoyImageEditPlusPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = frozenset(["prompt", "images"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def setUp(self): + super().setUp() + self._bucket_patcher = patch( + "diffusers.pipelines.joyimage.image_processor.find_best_bucket", + return_value=(32, 32), + ) + self._bucket_patcher.start() + + def tearDown(self): + self._bucket_patcher.stop() + super().tearDown() + + def get_dummy_components(self): + tiny_ckpt_id = "huangfeice/tiny-random-Qwen3VLForConditionalGeneration" + + torch.manual_seed(0) + transformer = JoyImageEditPlusTransformer3DModel( + patch_size=[1, 2, 2], + in_channels=16, + hidden_size=32, + num_attention_heads=2, + text_dim=16, + num_layers=1, + rope_dim_list=[4, 6, 6], + theta=256, + ) + + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + scheduler = 
FlowMatchEulerDiscreteScheduler() + + processor = Qwen3VLProcessor.from_pretrained(tiny_ckpt_id) + processor.image_processor.min_pixels = 4 * 28 * 28 + processor.image_processor.max_pixels = 4 * 28 * 28 + + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(tiny_ckpt_id) + text_encoder.resize_token_embeddings(len(processor.tokenizer)) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": processor.tokenizer, + "processor": processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "combine the two images", + "images": [Image.new("RGB", (32, 32)), Image.new("RGB", (32, 32))], + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 32, 32)) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + @unittest.skip("num_images_per_prompt not applicable: each prompt is bound to reference images") + def test_num_images_per_prompt(self): + pass + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=False) + def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, 
atol=1e-4, rtol=1e-4): + super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol) + + @require_torch_accelerator + def test_group_offloading_inference(self): + if not self.test_group_offloading: + return + + def create_pipe(): + torch.manual_seed(0) + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(torch_device) + return pipe(**inputs)[0] + + pipe = create_pipe().to(torch_device) + output_without_group_offloading = run_forward(pipe) + + pipe = create_pipe() + for component_name in ["transformer", "text_encoder"]: + component = getattr(pipe, component_name, None) + if component is None: + continue + if hasattr(component, "enable_group_offload"): + component.enable_group_offload( + torch.device(torch_device), offload_type="block_level", num_blocks_per_group=1 + ) + else: + apply_group_offloading( + component, + onload_device=torch.device(torch_device), + offload_type="block_level", + num_blocks_per_group=1, + ) + pipe.vae.to(torch_device) + output_with_block_level = run_forward(pipe) + + pipe = create_pipe() + pipe.transformer.enable_group_offload(torch.device(torch_device), offload_type="leaf_level") + pipe.text_encoder.to(torch_device) + pipe.vae.to(torch_device) + output_with_leaf_level = run_forward(pipe) + + if torch.is_tensor(output_without_group_offloading): + output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy() + output_with_block_level = output_with_block_level.detach().cpu().numpy() + output_with_leaf_level = output_with_leaf_level.detach().cpu().numpy() + + self.assertTrue(np.allclose(output_without_group_offloading, output_with_block_level, atol=1e-4)) + self.assertTrue(np.allclose(output_without_group_offloading, output_with_leaf_level, atol=1e-4)) + + @unittest.skip("Qwen3VLForConditionalGeneration does not 
support leaf-level group offloading") + def test_pipeline_level_group_offloading_inference(self): + pass + + @unittest.skip("Qwen3VLForConditionalGeneration does not support sequential CPU offloading") + def test_sequential_cpu_offload_forward_pass(self): + pass + + @unittest.skip("Qwen3VLForConditionalGeneration does not support sequential CPU offloading") + def test_sequential_offload_forward_pass_twice(self): + pass From aa2f5638b31463b1074b9c3e60cbf78332a989db Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Tue, 23 Jun 2026 02:30:41 +0000 Subject: [PATCH 04/14] fix: add missing newline at end of pipeline_output.py --- src/diffusers/pipelines/joyimage/pipeline_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py index 40d9d3aa100f..23cb24431462 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_output.py +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -21,4 +21,4 @@ class JoyImageEditPlusPipelineOutput(BaseOutput): Output class for JoyImage Edit Plus multi-image editing pipelines. 
""" - images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file + images: Union[List[PIL.Image.Image], np.ndarray] From 8a911e5614155a0e1c112ac217b5c4e19763b500 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Tue, 23 Jun 2026 02:39:50 +0000 Subject: [PATCH 05/14] fix: add missing newline at end of pipeline_output.py --- src/diffusers/pipelines/joyimage/pipeline_output.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py index 23cb24431462..30be7c248e33 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_output.py +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -22,3 +22,4 @@ class JoyImageEditPlusPipelineOutput(BaseOutput): """ images: Union[List[PIL.Image.Image], np.ndarray] + From 1344fd0275326ff909aa7aa260d84c042c0c40db Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Wed, 24 Jun 2026 10:16:46 +0000 Subject: [PATCH 06/14] doc: add joyimage-edit-plus doc --- docs/source/en/_toctree.yml | 4 ++ .../models/transformer_joyimage_edit_plus.md | 29 +++++++++ .../en/api/pipelines/joyimage_edit_plus.md | 61 +++++++++++++++++++ 3 files changed, 94 insertions(+) create mode 100644 docs/source/en/api/models/transformer_joyimage_edit_plus.md create mode 100644 docs/source/en/api/pipelines/joyimage_edit_plus.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 23e2c867b580..f3239722c64f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -355,6 +355,8 @@ title: Ideogram4Transformer2DModel - local: api/models/transformer_joyimage title: JoyImageEditTransformer3DModel + - local: api/models/transformer_joyimage_edit_plus + title: JoyImageEditPlusTransformer3DModel - local: api/models/krea2_transformer2d title: Krea2Transformer2DModel - local: api/models/latte_transformer3d @@ -555,6 +557,8 @@ title: InstructPix2Pix - local: api/pipelines/joyimage_edit title: JoyImage Edit + - local: 
api/pipelines/joyimage_edit_plus + title: JoyImage Edit Plus - local: api/pipelines/kandinsky title: Kandinsky 2.1 - local: api/pipelines/kandinsky_v22 diff --git a/docs/source/en/api/models/transformer_joyimage_edit_plus.md b/docs/source/en/api/models/transformer_joyimage_edit_plus.md new file mode 100644 index 000000000000..776c53eaf20c --- /dev/null +++ b/docs/source/en/api/models/transformer_joyimage_edit_plus.md @@ -0,0 +1,30 @@ + + +# JoyImageEditPlusTransformer3DModel + +The model can be loaded with the following code snippet. + +```python +import torch +from diffusers import JoyImageEditPlusTransformer3DModel + +transformer = JoyImageEditPlusTransformer3DModel.from_pretrained("jdopensource/JoyAI-Image-Edit-Plus-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## JoyImageEditPlusTransformer3DModel + +[[autodoc]] JoyImageEditPlusTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/joyimage_edit_plus.md b/docs/source/en/api/pipelines/joyimage_edit_plus.md new file mode 100644 index 000000000000..2ce8e2f29647 --- /dev/null +++ b/docs/source/en/api/pipelines/joyimage_edit_plus.md @@ -0,0 +1,61 @@ + + +# JoyAI-Image-Edit-Plus + +[JoyAI-Image](https://github.com/jd-opensource/JoyAI-Image) is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT). + +JoyAI-Image-Edit-Plus is a multi-image instruction-guided editing model that accepts **multiple reference images** and a text instruction to generate a new image that combines elements from the references according to the instruction. It supports 1–5 reference images per sample.
+ +| Model | Description | Download | +|:-----:|:-----------:|:--------:| +| JoyAI-Image-Edit-Plus | Multi-image instruction-guided editing with element composition from multiple references | [Hugging Face](https://huggingface.co/jdopensource/JoyAI-Image-Edit-Plus-Diffusers) | + +```python +import torch +from PIL import Image +from diffusers import JoyImageEditPlusPipeline + +pipeline = JoyImageEditPlusPipeline.from_pretrained( + "jdopensource/JoyAI-Image-Edit-Plus-Diffusers", torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +images = [ + Image.open("reference_0.png").convert("RGB"), + Image.open("reference_1.png").convert("RGB"), +] + +target_h, target_w = pipeline._get_bucket_size(images[-1]) + +output = pipeline( + images=images, + prompt="Combine the person from the second image with the scene from the first image.", + negative_prompt="low quality, blurry, deformed", + height=target_h, + width=target_w, + num_inference_steps=30, + guidance_scale=4.0, + generator=torch.Generator("cuda").manual_seed(42), +).images[0] +output.save("joyimage_edit_plus_output.png") +``` + +## JoyImageEditPlusPipeline + +[[autodoc]] JoyImageEditPlusPipeline + - all + - __call__ + +## JoyImageEditPlusPipelineOutput + +[[autodoc]] pipelines.joyimage.pipeline_output.JoyImageEditPlusPipelineOutput From 53d0b331fa6cc7a482b3e6bde50c9f20922db6b4 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Wed, 24 Jun 2026 11:27:30 +0000 Subject: [PATCH 07/14] refactor: update code format --- .../transformer_joyimage_edit_plus.py | 48 +++-- .../joyimage/pipeline_joyimage_edit_plus.py | 174 +++++++++++++----- 2 files changed, 159 insertions(+), 63 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py index 572c983ec453..125dc30cc726 100644 --- a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py +++ b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ 
-13,7 +13,6 @@ # limitations under the License. import math -from typing import List, Tuple, Union import torch import torch.nn as nn @@ -40,8 +39,8 @@ def _apply_rotary_emb_batched( xq: torch.Tensor, xk: torch.Tensor, - freqs_cis: Tuple[torch.Tensor, torch.Tensor], -) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis: tuple[torch.Tensor, torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor]: """RoPE that handles both batched [B, S, D] and unbatched [S, D] freqs.""" cos, sin = freqs_cis[0].to(xq.device), freqs_cis[1].to(xq.device) @@ -77,10 +76,10 @@ def __call__( attn: "JoyImageAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, - image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, attention_mask: torch.Tensor | None = None, **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: if encoder_hidden_states is None: raise ValueError("JoyImageEditPlusAttnProcessor requires encoder_hidden_states") @@ -140,12 +139,37 @@ def __call__( class JoyImageEditPlusTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin): - """JoyImage Edit Plus Transformer for multi-image editing. + r""" + JoyImage Edit Plus Transformer for multi-image editing. Uses a patchify+padding approach where each reference image and the target noise are independently patchified and concatenated into a flat patch sequence. Supports variable-resolution reference images. - Input format: [B, max_patches, C, pt, ph, pw] (6D padded patches) + Input format: `[B, max_patches, C, pt, ph, pw]` (6D padded patches). + + Args: + patch_size (`list`, defaults to `[1, 2, 2]`): + Patch size for patchifying the latent input along `(t, h, w)` dimensions. + in_channels (`int`, defaults to `16`): + The number of channels in the input latent. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. 
If not specified, it defaults to `in_channels`. + hidden_size (`int`, defaults to `3072`): + The dimensionality of the hidden representations. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads. + text_dim (`int`, defaults to `4096`): + The dimensionality of the text encoder output. + mlp_width_ratio (`float`, defaults to `4.0`): + The ratio of MLP hidden dimension to `hidden_size`. + num_layers (`int`, defaults to `20`): + The number of double-stream transformer blocks. + rope_dim_list (`list[int]`, defaults to `[16, 56, 56]`): + The dimensions for 3D rotary positional embeddings along `(t, h, w)`. + rope_type (`str`, defaults to `"rope"`): + The type of rotary positional embedding. + theta (`int`, defaults to `256`): + The base frequency for rotary embeddings. """ _skip_layerwise_casting_patterns = ["img_in", "condition_embedder", "norm"] @@ -222,9 +246,9 @@ def __init__( def _get_rotary_pos_embed_for_range( self, - start: Tuple[int, int, int], - stop: Tuple[int, int, int], - ) -> Tuple[torch.Tensor, torch.Tensor]: + start: tuple[int, int, int], + stop: tuple[int, int, int], + ) -> tuple[torch.Tensor, torch.Tensor]: """Generate 3D RoPE for a spatial range [start, stop).""" head_dim = self.hidden_size // self.num_attention_heads rope_dim_list = self.rope_dim_list @@ -253,9 +277,9 @@ def forward( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor | None = None, - shape_list: List[List[Tuple[int, int, int]]] | None = None, + shape_list: list[list[tuple[int, int, int]]] | None = None, return_dict: bool = True, - ) -> Union[torch.Tensor, Tuple]: + ) -> torch.Tensor | tuple: """ Args: hidden_states: [B, max_patches, C, pt, ph, pw] - patchified latent input. 
diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index 144650d46b05..d314219b45d4 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -14,7 +14,7 @@ import inspect import math -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable import numpy as np import torch @@ -30,13 +30,16 @@ from ...models import AutoencoderKLWan from ...models.transformers.transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import replace_example_docstring +from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .image_processor import JoyImageEditImageProcessor, find_best_bucket from .pipeline_output import JoyImageEditPlusPipelineOutput +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + EXAMPLE_DOC_STRING = """ Examples: ```python @@ -68,12 +71,35 @@ def retrieve_timesteps( scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, **kwargs, ): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. 
If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. + """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") @@ -97,12 +123,27 @@ def retrieve_timesteps( class JoyImageEditPlusPipeline(DiffusionPipeline): - """Diffusion pipeline for multi-image editing using JoyImage Edit Plus. + r""" + Diffusion pipeline for multi-image instruction-guided editing using JoyImage Edit Plus. Supports multiple reference images with different resolutions. Each reference image is independently VAE-encoded and patchified, then concatenated with the target noise patches for joint denoising. - Model offloading order: text_encoder -> transformer -> vae. + Args: + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`Qwen3VLForConditionalGeneration`]): + Multimodal text encoder for prompt encoding with inline image understanding. + tokenizer ([`Qwen2Tokenizer`]): + Tokenizer for text processing. 
+ transformer ([`JoyImageEditPlusTransformer3DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + processor ([`Qwen3VLProcessor`]): + Processor for multimodal inputs (text + images). + text_token_max_length (`int`, defaults to `2048`): + Maximum token length for text encoding. """ model_cpu_offload_seq = "text_encoder->transformer->vae" @@ -186,11 +227,11 @@ def _hook(_module, _input, output): def encode_prompt_multiple_images( self, - prompt: Union[str, List[str]], - device: Optional[torch.device] = None, - images: Optional[List[Image.Image]] = None, - max_sequence_length: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + prompt: str | list[str], + device: torch.device | None = None, + images: list[Image.Image] | None = None, + max_sequence_length: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """Encode prompts with inline tokens via the Qwen3-VL processor.""" device = device or self._execution_device template = self.prompt_template_encode["multiple_images"] @@ -257,7 +298,7 @@ def denormalize_latents(self, latent: torch.Tensor) -> torch.Tensor: latent = latent / self.vae.config.scaling_factor return latent - def _resize_center_crop(self, img: Image.Image, target_size: Tuple[int, int]) -> Image.Image: + def _resize_center_crop(self, img: Image.Image, target_size: tuple[int, int]) -> Image.Image: w, h = img.size bh, bw = target_size scale = max(bh / h, bw / w) @@ -268,7 +309,7 @@ def _resize_center_crop(self, img: Image.Image, target_size: Tuple[int, int]) -> img = img.crop((left, top, left + bw, top + bh)) return img - def _get_bucket_size(self, img: Image.Image) -> Tuple[int, int]: + def _get_bucket_size(self, img: Image.Image) -> tuple[int, int]: return find_best_bucket(img.size[1], img.size[0], self.vae_image_processor.config.basesize) def prepare_latents( @@ -279,9 +320,9 @@ def prepare_latents( width: int, dtype: torch.dtype, device: torch.device, - generator: 
Optional[Union[torch.Generator, List[torch.Generator]]], - reference_images: Optional[List[List[Image.Image]]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, List[List[Tuple[int, int, int]]]]: + generator: torch.Generator | list[torch.Generator] | None, + reference_images: list[list[Image.Image]] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, list[list[tuple[int, int, int]]]]: """Prepare 6D padded latent tensor with target noise + reference image latents. Returns: @@ -388,55 +429,86 @@ def interrupt(self) -> bool: @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - images: List[Image.Image] | List[List[Image.Image]] | None = None, - prompt: str | List[str] = None, + images: list[Image.Image] | list[list[Image.Image]] | None = None, + prompt: str | list[str] = None, height: int | None = None, width: int | None = None, num_inference_steps: int = 30, - timesteps: List[int] = None, - sigmas: List[float] = None, + timesteps: list[int] = None, + sigmas: list[float] = None, guidance_scale: float = 4.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_mask: Optional[torch.Tensor] = None, - output_type: Optional[str] = "pil", + negative_prompt: str | list[str] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", return_dict: bool = True, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, 
MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], max_sequence_length: int = 4096, ): r""" - Generate an edited image from multiple reference images and a text prompt. + Function invoked when calling the pipeline for generation. Args: - images (`List[Image.Image]` or `List[List[Image.Image]]`): - Reference images for editing. Each image can have a different resolution. - If a flat list is provided, it's treated as one sample with multiple references. - prompt (`str` or `List[str]`): - Text prompt describing the desired edit. + images (`list[Image.Image]` or `list[list[Image.Image]]`, *optional*): + Reference images for editing. Each image can have a different resolution. If a flat list is provided, + it is treated as one sample with multiple references. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. height (`int`, *optional*): - Output height in pixels. If None, determined from the last reference image's bucket. + The height in pixels of the generated image. If `None`, determined from the last reference image. width (`int`, *optional*): - Output width in pixels. If None, determined from the last reference image's bucket. - num_inference_steps (`int`, defaults to 30): - Number of denoising steps. - guidance_scale (`float`, defaults to 4.0): - Classifier-free guidance scale. - negative_prompt (`str` or `List[str]`, *optional*): - Negative prompt for CFG. - generator (`torch.Generator`, *optional*): - RNG generator for reproducibility. + The width in pixels of the generated image. If `None`, determined from the last reference image. + num_inference_steps (`int`, *optional*, defaults to `30`): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spacing is used. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Classifier-free guidance scale. Higher values encourage the model to generate images more aligned + with the `prompt` at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, a blank prompt is used + for classifier-free guidance. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents to be used as inputs for image generation. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for pre-generated negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"` + (`np.ndarray`), `"pt"` (`torch.Tensor`), or `"latent"` for raw latent output. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`JoyImageEditPlusPipelineOutput`] instead of a plain tuple. 
+ callback_on_step_end (`Callable`, *optional*): + A function called at the end of each denoising step with arguments: the pipeline, step index, + timestep, and a dict of callback tensor inputs. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*, defaults to `["latents"]`): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, *optional*, defaults to `4096`): + Maximum sequence length for the text encoder. Examples: Returns: - [`JoyImageEditPlusPipelineOutput`] or `tuple`. + [`JoyImageEditPlusPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`JoyImageEditPlusPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list of generated images. """ # Normalize images input to List[List[Image]] if images is not None: From 4f6cb761c1025f07aa9cfab6304f2600a0027759 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Thu, 25 Jun 2026 02:19:27 +0000 Subject: [PATCH 08/14] refactor: update code formate --- .../en/api/pipelines/joyimage_edit_plus.md | 2 +- .../transformer_joyimage_edit_plus.py | 242 +++++++++++++++--- .../joyimage/pipeline_joyimage_edit_plus.py | 28 +- 3 files changed, 218 insertions(+), 54 deletions(-) diff --git a/docs/source/en/api/pipelines/joyimage_edit_plus.md b/docs/source/en/api/pipelines/joyimage_edit_plus.md index 2ce8e2f29647..1e4574a1e86f 100644 --- a/docs/source/en/api/pipelines/joyimage_edit_plus.md +++ b/docs/source/en/api/pipelines/joyimage_edit_plus.md @@ -35,7 +35,7 @@ images = [ Image.open("reference_1.png").convert("RGB"), ] -target_h, target_w = pipeline._get_bucket_size(images[-1]) +target_h, target_w = pipeline.vae_image_processor.get_default_height_width(images[-1]) output = pipeline( images=images, diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py index 125dc30cc726..334309e8f989 100644 --- 
a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py +++ b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import math import torch @@ -20,20 +21,15 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from ..attention import AttentionMixin, FeedForward +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm -from .transformer_joyimage import ( - JoyImageAttention, - JoyImageModulate, - JoyImageTimeTextImageEmbedding, - JoyImageTransformerBlock, -) -logger = logging.get_logger(__name__) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name def _apply_rotary_emb_batched( @@ -41,17 +37,12 @@ def _apply_rotary_emb_batched( xk: torch.Tensor, freqs_cis: tuple[torch.Tensor, torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - """RoPE that handles both batched [B, S, D] and unbatched [S, D] freqs.""" + """RoPE for batched [B, S, D] freqs.""" cos, sin = freqs_cis[0].to(xq.device), freqs_cis[1].to(xq.device) - if cos.ndim == 2: - # unbatched: [S, D] -> [1, S, 1, D] - cos = cos.unsqueeze(0).unsqueeze(2) - sin = sin.unsqueeze(0).unsqueeze(2) - elif cos.ndim == 3: - # batched: [B, S, D] -> [B, S, 1, D] - cos = cos.unsqueeze(2) - sin = sin.unsqueeze(2) + # batched: [B, S, D] -> [B, S, 1, D] + cos = cos.unsqueeze(2) + sin = sin.unsqueeze(2) def _rotate_half(x): x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) @@ -62,6 +53,27 @@ def _rotate_half(x): return xq_out, xk_out +# Copied from 
diffusers.models.transformers.transformer_joyimage.JoyImageModulate with JoyImage->JoyImageEditPlus +class JoyImageEditPlusModulate(nn.Module): + """Wan-style learnable modulation table. + + Produces `factor` modulation vectors by adding the conditioning signal to a learnable parameter table. + """ + + def __init__(self, hidden_size: int, factor: int, dtype=None, device=None): + super().__init__() + self.factor = factor + self.modulate_table = nn.Parameter( + torch.zeros(1, factor, hidden_size, dtype=dtype, device=device) / hidden_size**0.5, + requires_grad=True, + ) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + if x.ndim != 3: + x = x.unsqueeze(1) + return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)] + + class JoyImageEditPlusAttnProcessor: """Attention processor that supports batched RoPE embeddings for edit-plus multi-image input.""" @@ -73,12 +85,11 @@ def __init__(self): def __call__( self, - attn: "JoyImageAttention", + attn: "JoyImageEditPlusAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, attention_mask: torch.Tensor | None = None, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: if encoder_hidden_states is None: raise ValueError("JoyImageEditPlusAttnProcessor requires encoder_hidden_states") @@ -138,6 +149,180 @@ def __call__( return img_attn_output, txt_attn_output +# Copied from diffusers.models.transformers.transformer_joyimage.JoyImageAttention with JoyImage->JoyImageEditPlus +class JoyImageEditPlusAttention(nn.Module, AttentionModuleMixin): + """Joint attention module for JoyImage Edit Plus double-stream blocks.""" + + _default_processor_cls = JoyImageEditPlusAttnProcessor + _available_processors = [JoyImageEditPlusAttnProcessor] + _supports_qkv_fusion = False + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + processor=None, + ): + 
super().__init__() + + self.heads = num_attention_heads + self.head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True) + self.img_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.img_attn_k_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.img_attn_proj = nn.Linear(inner_dim, dim, bias=True) + + self.txt_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True) + self.txt_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.txt_attn_k_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.txt_attn_proj = nn.Linear(inner_dim, dim, bias=True) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + kwargs = {} + if "attention_mask" in attn_parameters: + kwargs["attention_mask"] = attention_mask + return self.processor(self, hidden_states, encoder_hidden_states, image_rotary_emb, **kwargs) + + +# Copied from diffusers.models.transformers.transformer_joyimage.JoyImageTransformerBlock with JoyImage->JoyImageEditPlus +class JoyImageEditPlusTransformerBlock(nn.Module): + """Double-stream transformer block for JoyImage Edit Plus.""" + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: float = 4.0, + eps: float = 1e-6, + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + mlp_hidden_dim = int(dim * mlp_width_ratio) + + # image stream + self.img_mod = JoyImageEditPlusModulate(dim, factor=6) + self.img_norm1 = FP32LayerNorm(dim, 
elementwise_affine=False, eps=eps) + self.img_norm2 = FP32LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = FeedForward(dim, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") + + # text stream + self.txt_mod = JoyImageEditPlusModulate(dim, factor=6) + self.txt_norm1 = FP32LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_norm2 = FP32LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = FeedForward(dim, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") + + # joint attention + self.attn = JoyImageEditPlusAttention(dim, num_attention_heads, attention_head_dim, eps=eps) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # modulation + ( + img_mod1_shift, + img_mod1_scale, + img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = self.img_mod(temb) + ( + txt_mod1_shift, + txt_mod1_scale, + txt_mod1_gate, + txt_mod2_shift, + txt_mod2_scale, + txt_mod2_gate, + ) = self.txt_mod(temb) + + # --- attention --- + img_normed = self.img_norm1(hidden_states) + txt_normed = self.txt_norm1(encoder_hidden_states) + img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1) + txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1) + + img_attn, txt_attn = self.attn( + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + ) + + hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1) + + # --- FFN --- + img_ffn_normed = self.img_norm2(hidden_states) + txt_ffn_normed = self.txt_norm2(encoder_hidden_states) + img_ffn_input = 
img_ffn_normed * (1 + img_mod2_scale.unsqueeze(1)) + img_mod2_shift.unsqueeze(1) + txt_ffn_input = txt_ffn_normed * (1 + txt_mod2_scale.unsqueeze(1)) + txt_mod2_shift.unsqueeze(1) + img_ffn_output = self.img_mlp(img_ffn_input) + txt_ffn_output = self.txt_mlp(txt_ffn_input) + hidden_states = hidden_states + img_ffn_output * img_mod2_gate.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + txt_ffn_output * txt_mod2_gate.unsqueeze(1) + + return hidden_states, encoder_hidden_states + + +# Copied from diffusers.models.transformers.transformer_joyimage.JoyImageTimeTextImageEmbedding with JoyImage->JoyImageEditPlus +class JoyImageEditPlusTimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ): + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + + return temb, timestep_proj, encoder_hidden_states + + class JoyImageEditPlusTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin): r""" JoyImage Edit Plus Transformer for multi-image editing. 
@@ -173,7 +358,7 @@ class JoyImageEditPlusTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin """ _skip_layerwise_casting_patterns = ["img_in", "condition_embedder", "norm"] - _no_split_modules = ["JoyImageTransformerBlock"] + _no_split_modules = ["JoyImageEditPlusTransformerBlock"] _supports_gradient_checkpointing = True _keep_in_fp32_modules = [ "time_embedder", @@ -181,7 +366,7 @@ class JoyImageEditPlusTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin "norm2", "norm_out", ] - _repeated_blocks = ["JoyImageTransformerBlock"] + _repeated_blocks = ["JoyImageEditPlusTransformerBlock"] @register_to_config def __init__( @@ -216,7 +401,7 @@ def __init__( self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - self.condition_embedder = JoyImageTimeTextImageEmbedding( + self.condition_embedder = JoyImageEditPlusTimeTextImageEmbedding( dim=hidden_size, time_freq_dim=256, time_proj_dim=hidden_size * 6, @@ -225,7 +410,7 @@ def __init__( self.double_blocks = nn.ModuleList( [ - JoyImageTransformerBlock( + JoyImageEditPlusTransformerBlock( dim=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, @@ -277,7 +462,7 @@ def forward( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor | None = None, - shape_list: list[list[tuple[int, int, int]]] | None = None, + shape_list: list[list[tuple[int, int, int]]] = None, return_dict: bool = True, ) -> torch.Tensor | tuple: """ @@ -292,15 +477,9 @@ def forward( batch_size, max_num_patches, channels, pt, ph, pw = hidden_states.shape device = hidden_states.device - if shape_list is None: - raise ValueError( - "shape_list must be provided either as an argument or via forward_batch.vae_image_sizes" - ) - # 1. 
Condition embeddings _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states) - if vec.shape[-1] > self.hidden_size: - vec = vec.unflatten(1, (6, -1)) + vec = vec.unflatten(1, (6, -1)) # 2. Patchify via Conv3d: flatten (B, N) -> apply conv -> reshape back x = hidden_states.reshape(batch_size * max_num_patches, channels, pt, ph, pw) @@ -338,7 +517,6 @@ def forward( vis_freqs = (torch.stack(sample_cos_list), torch.stack(sample_sin_list)) # 4. Build attention mask: [B, 1, 1, img_seq + txt_seq] - # img patches: only actual (non-padding) patches are valid; txt uses encoder_hidden_states_mask attention_mask = None if encoder_hidden_states_mask is not None: img_mask = torch.zeros(batch_size, max_num_patches, device=device, dtype=encoder_hidden_states_mask.dtype) diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index d314219b45d4..222430ffb341 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -13,7 +13,6 @@ # limitations under the License. 
import inspect -import math from typing import Callable import numpy as np @@ -33,7 +32,7 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .image_processor import JoyImageEditImageProcessor, find_best_bucket +from .image_processor import JoyImageEditImageProcessor from .pipeline_output import JoyImageEditPlusPipelineOutput @@ -69,6 +68,7 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: int | None = None, @@ -298,20 +298,6 @@ def denormalize_latents(self, latent: torch.Tensor) -> torch.Tensor: latent = latent / self.vae.config.scaling_factor return latent - def _resize_center_crop(self, img: Image.Image, target_size: tuple[int, int]) -> Image.Image: - w, h = img.size - bh, bw = target_size - scale = max(bh / h, bw / w) - resize_h, resize_w = math.ceil(h * scale), math.ceil(w * scale) - img = img.resize((resize_w, resize_h), Image.LANCZOS) - left = (resize_w - bw) // 2 - top = (resize_h - bh) // 2 - img = img.crop((left, top, left + bw, top + bh)) - return img - - def _get_bucket_size(self, img: Image.Image) -> tuple[int, int]: - return find_best_bucket(img.size[1], img.size[0], self.vae_image_processor.config.basesize) - def prepare_latents( self, batch_size: int, @@ -352,8 +338,8 @@ def prepare_latents( # Reference images if reference_images is not None and reference_images[i]: for ref_img_pil in reference_images[i]: - ref_h, ref_w = self._get_bucket_size(ref_img_pil) - ref_img_pil = self._resize_center_crop(ref_img_pil, (ref_h, ref_w)) + ref_h, ref_w = self.vae_image_processor.get_default_height_width(ref_img_pil) + ref_img_pil = self.vae_image_processor.resize_center_crop(ref_img_pil, (ref_h, ref_w)) ref_tensor = torch.from_numpy(np.array(ref_img_pil.convert("RGB"))).to(device=device, dtype=dtype) ref_tensor = (ref_tensor / 127.5 - 1.0).permute(2, 
0, 1).unsqueeze(1).unsqueeze(0) @@ -529,7 +515,7 @@ def __call__( if height is None or width is None: if images is not None and len(images[0]) > 0: last_img = images[0][-1] - height, width = self._get_bucket_size(last_img) + height, width = self.vae_image_processor.get_default_height_width(last_img) else: height = height or 1024 width = width or 1024 @@ -542,8 +528,8 @@ def __call__( for sample_imgs in images: processed_sample = [] for img in sample_imgs: - ref_h, ref_w = self._get_bucket_size(img) - resize_img = self._resize_center_crop(img, (ref_h, ref_w)) + ref_h, ref_w = self.vae_image_processor.get_default_height_width(img) + resize_img = self.vae_image_processor.resize_center_crop(img, (ref_h, ref_w)) processed_sample.append(resize_img) processed_images.append(processed_sample) images = processed_images From d63e322e61d0edf637e15e9730eb988b08bb2a6f Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Thu, 2 Jul 2026 02:18:41 +0000 Subject: [PATCH 09/14] refactor: merge edit-plus conversion script into convert_joyimage_edit_to_diffusers.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit JoyImage Edit and Edit Plus share identical VAE and transformer weight layouts — only the target model class differs. Consolidate both into a single script with a --model_type flag (edit | edit_plus) instead of maintaining two nearly-duplicate files. 
--- ...convert_joyimage_edit_plus_to_diffusers.py | 290 ------------------ scripts/convert_joyimage_edit_to_diffusers.py | 189 ++++++------ 2 files changed, 91 insertions(+), 388 deletions(-) delete mode 100644 scripts/convert_joyimage_edit_plus_to_diffusers.py diff --git a/scripts/convert_joyimage_edit_plus_to_diffusers.py b/scripts/convert_joyimage_edit_plus_to_diffusers.py deleted file mode 100644 index f01adb03c747..000000000000 --- a/scripts/convert_joyimage_edit_plus_to_diffusers.py +++ /dev/null @@ -1,290 +0,0 @@ -import argparse -from typing import Any, Dict, Tuple - -import torch -from accelerate import init_empty_weights -from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration - -from diffusers import ( - AutoencoderKLWan, - JoyImageEditPlusPipeline, -) -from diffusers.models.transformers.transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel -from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( - FlowMatchEulerDiscreteScheduler, -) - - -# VAE conversion reused from convert_joyimage_edit_to_diffusers.py (identical VAE) -def convert_vae(vae_ckpt_path): - old_state_dict = torch.load(vae_ckpt_path, weights_only=True) - new_state_dict = {} - - middle_key_mapping = { - "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", - "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", - "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", - "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", - "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", - "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", - "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", - "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", - "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", - 
"encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", - "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", - "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", - "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", - "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", - "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", - "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", - "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", - "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", - "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", - "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", - "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", - "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", - "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", - "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", - } - - attention_mapping = { - "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", - "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", - "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", - "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", - "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", - "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", - "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", - "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", - "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", - 
"decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", - } - - head_mapping = { - "encoder.head.0.gamma": "encoder.norm_out.gamma", - "encoder.head.2.bias": "encoder.conv_out.bias", - "encoder.head.2.weight": "encoder.conv_out.weight", - "decoder.head.0.gamma": "decoder.norm_out.gamma", - "decoder.head.2.bias": "decoder.conv_out.bias", - "decoder.head.2.weight": "decoder.conv_out.weight", - } - - quant_mapping = { - "conv1.weight": "quant_conv.weight", - "conv1.bias": "quant_conv.bias", - "conv2.weight": "post_quant_conv.weight", - "conv2.bias": "post_quant_conv.bias", - } - - for key, value in old_state_dict.items(): - if key in middle_key_mapping: - new_state_dict[middle_key_mapping[key]] = value - elif key in attention_mapping: - new_state_dict[attention_mapping[key]] = value - elif key in head_mapping: - new_state_dict[head_mapping[key]] = value - elif key in quant_mapping: - new_state_dict[quant_mapping[key]] = value - elif key == "encoder.conv1.weight": - new_state_dict["encoder.conv_in.weight"] = value - elif key == "encoder.conv1.bias": - new_state_dict["encoder.conv_in.bias"] = value - elif key == "decoder.conv1.weight": - new_state_dict["decoder.conv_in.weight"] = value - elif key == "decoder.conv1.bias": - new_state_dict["decoder.conv_in.bias"] = value - elif key.startswith("encoder.downsamples."): - new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") - if ".residual.0.gamma" in new_key: - new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") - elif ".residual.2.bias" in new_key: - new_key = new_key.replace(".residual.2.bias", ".conv1.bias") - elif ".residual.2.weight" in new_key: - new_key = new_key.replace(".residual.2.weight", ".conv1.weight") - elif ".residual.3.gamma" in new_key: - new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") - elif ".residual.6.bias" in new_key: - new_key = new_key.replace(".residual.6.bias", ".conv2.bias") - elif ".residual.6.weight" in new_key: - new_key = 
new_key.replace(".residual.6.weight", ".conv2.weight") - elif ".shortcut.bias" in new_key: - new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") - elif ".shortcut.weight" in new_key: - new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") - new_state_dict[new_key] = value - elif key.startswith("decoder.upsamples."): - parts = key.split(".") - block_idx = int(parts[2]) - - if "residual" in key: - if block_idx in [0, 1, 2]: - new_block_idx = 0 - resnet_idx = block_idx - elif block_idx in [4, 5, 6]: - new_block_idx = 1 - resnet_idx = block_idx - 4 - elif block_idx in [8, 9, 10]: - new_block_idx = 2 - resnet_idx = block_idx - 8 - elif block_idx in [12, 13, 14]: - new_block_idx = 3 - resnet_idx = block_idx - 12 - else: - new_state_dict[key] = value - continue - - if ".residual.0.gamma" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" - elif ".residual.2.bias" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" - elif ".residual.2.weight" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" - elif ".residual.3.gamma" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" - elif ".residual.6.bias" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" - elif ".residual.6.weight" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" - else: - new_key = key - new_state_dict[new_key] = value - - elif ".shortcut." in key: - if block_idx == 4: - new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") - new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") - else: - new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - new_key = new_key.replace(".shortcut.", ".conv_shortcut.") - new_state_dict[new_key] = value - - elif ".resample." in key or ".time_conv." 
in key: - if block_idx == 3: - new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") - elif block_idx == 7: - new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") - elif block_idx == 11: - new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") - else: - new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - new_state_dict[new_key] = value - else: - new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - new_state_dict[new_key] = value - else: - new_state_dict[key] = value - - with init_empty_weights(): - vae = AutoencoderKLWan() - vae.load_state_dict(new_state_dict, strict=True, assign=True) - return vae - - -def get_transformer_config() -> Dict[str, Any]: - return { - "hidden_size": 4096, - "in_channels": 16, - "num_attention_heads": 32, - "num_layers": 40, - "out_channels": 16, - "patch_size": [1, 2, 2], - "rope_dim_list": [16, 56, 56], - "text_dim": 4096, - "rope_type": "rope", - "theta": 10000, - } - - -def convert_transformer(ckpt_path: str): - checkpoint = torch.load(ckpt_path, weights_only=True) - if "model" in checkpoint: - original_state_dict = checkpoint["model"] - else: - original_state_dict = checkpoint - - attn_suffixes = ( - "img_attn_qkv.", - "img_attn_q_norm.", - "img_attn_k_norm.", - "img_attn_proj.", - "txt_attn_qkv.", - "txt_attn_q_norm.", - "txt_attn_k_norm.", - "txt_attn_proj.", - ) - remapped = {} - for key, value in original_state_dict.items(): - new_key = key - if key.startswith("double_blocks."): - for suffix in attn_suffixes: - if "." + suffix in key and ".attn." + suffix not in key: - new_key = key.replace("." + suffix, ".attn." 
+ suffix) - break - remapped[new_key] = value - - config = get_transformer_config() - with init_empty_weights(): - transformer = JoyImageEditPlusTransformer3DModel(**config) - transformer.load_state_dict(remapped, strict=True, assign=True) - return transformer - - -def get_args(): - parser = argparse.ArgumentParser(description="Convert JoyImage Edit Plus checkpoints to diffusers format") - parser.add_argument("--transformer_ckpt_path", type=str, default=None) - parser.add_argument("--vae_ckpt_path", type=str, default=None) - parser.add_argument("--text_encoder_path", type=str, default=None) - parser.add_argument("--save_pipeline", action="store_true") - parser.add_argument("--output_path", type=str, required=True) - parser.add_argument("--dtype", default="bf16", help="Torch dtype (fp32, fp16, bf16)") - parser.add_argument("--flow_shift", type=float, default=1.5) - return parser.parse_args() - - -DTYPE_MAPPING = { - "fp32": torch.float32, - "fp16": torch.float16, - "bf16": torch.bfloat16, -} - -if __name__ == "__main__": - args = get_args() - transformer = None - vae = None - dtype = DTYPE_MAPPING[args.dtype] - - if args.save_pipeline: - assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None - assert args.text_encoder_path is not None - - if args.transformer_ckpt_path is not None: - transformer = convert_transformer(args.transformer_ckpt_path) - transformer = transformer.to(dtype=dtype) - if not args.save_pipeline: - transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") - - if args.vae_ckpt_path is not None: - vae = convert_vae(args.vae_ckpt_path) - vae = vae.to(dtype=dtype) - if not args.save_pipeline: - vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") - - if args.save_pipeline: - processor = AutoProcessor.from_pretrained(args.text_encoder_path) - text_encoder = Qwen3VLForConditionalGeneration.from_pretrained( - args.text_encoder_path, torch_dtype=torch.bfloat16 - 
).to("cuda") - tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path) - scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.flow_shift) - transformer = transformer.to("cuda") - vae = vae.to("cuda") - pipe = JoyImageEditPlusPipeline( - processor=processor, - transformer=transformer, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae=vae, - scheduler=scheduler, - ).to("cuda") - pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") - processor.save_pretrained(f"{args.output_path}/processor") diff --git a/scripts/convert_joyimage_edit_to_diffusers.py b/scripts/convert_joyimage_edit_to_diffusers.py index 3ad23de8f462..a25600932e59 100644 --- a/scripts/convert_joyimage_edit_to_diffusers.py +++ b/scripts/convert_joyimage_edit_to_diffusers.py @@ -1,3 +1,28 @@ +"""Convert JoyImage Edit / Edit Plus checkpoints to diffusers format. + +Supports both JoyImage-Edit (single-image editing) and JoyImage-Edit-Plus +(multi-image editing). The transformer weight layout is identical; only the +target model class and pipeline differ. 
+ +Usage: + # Convert JoyImage Edit (default) + python convert_joyimage_edit_to_diffusers.py \ + --transformer_ckpt_path /path/to/transformer.pt \ + --vae_ckpt_path /path/to/vae.pt \ + --text_encoder_path Qwen/Qwen3-VL-8B-Instruct \ + --output_path /path/to/output \ + --save_pipeline + + # Convert JoyImage Edit Plus + python convert_joyimage_edit_to_diffusers.py \ + --model_type edit_plus \ + --transformer_ckpt_path /path/to/transformer.pt \ + --vae_ckpt_path /path/to/vae.pt \ + --text_encoder_path Qwen/Qwen3-VL-8B-Instruct \ + --output_path /path/to/output \ + --save_pipeline +""" + import argparse from typing import Any, Dict, Tuple @@ -10,19 +35,18 @@ JoyImageEditPipeline, JoyImageEditTransformer3DModel, ) +from diffusers.models.transformers.transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel +from diffusers.pipelines.joyimage.pipeline_joyimage_edit_plus import JoyImageEditPlusPipeline from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( FlowMatchEulerDiscreteScheduler, ) -# This code is modified from convert_wan_to_diffusers.py to support input ckpt path def convert_vae(vae_ckpt_path): old_state_dict = torch.load(vae_ckpt_path, weights_only=True) new_state_dict = {} - # Create mappings for specific components middle_key_mapping = { - # Encoder middle block "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", @@ -35,7 +59,6 @@ def convert_vae(vae_ckpt_path): "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", - # Decoder middle block "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", "decoder.middle.0.residual.2.bias": 
"decoder.mid_block.resnets.0.conv1.bias", "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", @@ -50,15 +73,12 @@ def convert_vae(vae_ckpt_path): "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", } - # Create a mapping for attention blocks attention_mapping = { - # Encoder middle attention "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", - # Decoder middle attention "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", @@ -66,19 +86,15 @@ def convert_vae(vae_ckpt_path): "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", } - # Create a mapping for the head components head_mapping = { - # Encoder head "encoder.head.0.gamma": "encoder.norm_out.gamma", "encoder.head.2.bias": "encoder.conv_out.bias", "encoder.head.2.weight": "encoder.conv_out.weight", - # Decoder head "decoder.head.0.gamma": "decoder.norm_out.gamma", "decoder.head.2.bias": "decoder.conv_out.bias", "decoder.head.2.weight": "decoder.conv_out.weight", } - # Create a mapping for the quant components quant_mapping = { "conv1.weight": "quant_conv.weight", "conv1.bias": "quant_conv.bias", @@ -86,40 +102,25 @@ def convert_vae(vae_ckpt_path): "conv2.bias": "post_quant_conv.bias", } - # Process each key in the state dict for key, value in old_state_dict.items(): - # Handle middle block keys using the mapping if key in middle_key_mapping: - new_key = middle_key_mapping[key] - new_state_dict[new_key] = value - # Handle 
attention blocks using the mapping + new_state_dict[middle_key_mapping[key]] = value elif key in attention_mapping: - new_key = attention_mapping[key] - new_state_dict[new_key] = value - # Handle head keys using the mapping + new_state_dict[attention_mapping[key]] = value elif key in head_mapping: - new_key = head_mapping[key] - new_state_dict[new_key] = value - # Handle quant keys using the mapping + new_state_dict[head_mapping[key]] = value elif key in quant_mapping: - new_key = quant_mapping[key] - new_state_dict[new_key] = value - # Handle encoder conv1 + new_state_dict[quant_mapping[key]] = value elif key == "encoder.conv1.weight": new_state_dict["encoder.conv_in.weight"] = value elif key == "encoder.conv1.bias": new_state_dict["encoder.conv_in.bias"] = value - # Handle decoder conv1 elif key == "decoder.conv1.weight": new_state_dict["decoder.conv_in.weight"] = value elif key == "decoder.conv1.bias": new_state_dict["decoder.conv_in.bias"] = value - # Handle encoder downsamples elif key.startswith("encoder.downsamples."): - # Convert to down_blocks new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") - - # Convert residual block naming but keep the original structure if ".residual.0.gamma" in new_key: new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") elif ".residual.2.bias" in new_key: @@ -136,16 +137,11 @@ def convert_vae(vae_ckpt_path): new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") elif ".shortcut.weight" in new_key: new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") - new_state_dict[new_key] = value - - # Handle decoder upsamples elif key.startswith("decoder.upsamples."): - # Convert to up_blocks parts = key.split(".") block_idx = int(parts[2]) - # Group residual blocks if "residual" in key: if block_idx in [0, 1, 2]: new_block_idx = 0 @@ -160,11 +156,9 @@ def convert_vae(vae_ckpt_path): new_block_idx = 3 resnet_idx = block_idx - 12 else: - # Keep as is for other blocks 
new_state_dict[key] = value continue - # Convert residual block naming if ".residual.0.gamma" in key: new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" elif ".residual.2.bias" in key: @@ -179,10 +173,8 @@ def convert_vae(vae_ckpt_path): new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" else: new_key = key - new_state_dict[new_key] = value - # Handle shortcut connections elif ".shortcut." in key: if block_idx == 4: new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") @@ -190,35 +182,22 @@ def convert_vae(vae_ckpt_path): else: new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") new_key = new_key.replace(".shortcut.", ".conv_shortcut.") - new_state_dict[new_key] = value - # Handle upsamplers elif ".resample." in key or ".time_conv." in key: if block_idx == 3: - new_key = key.replace( - f"decoder.upsamples.{block_idx}", - "decoder.up_blocks.0.upsamplers.0", - ) + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") elif block_idx == 7: - new_key = key.replace( - f"decoder.upsamples.{block_idx}", - "decoder.up_blocks.1.upsamplers.0", - ) + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") elif block_idx == 11: - new_key = key.replace( - f"decoder.upsamples.{block_idx}", - "decoder.up_blocks.2.upsamplers.0", - ) + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") else: new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - new_state_dict[new_key] = value else: new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") new_state_dict[new_key] = value else: - # Keep other keys unchanged new_state_dict[key] = value with init_empty_weights(): @@ -227,32 +206,27 @@ def convert_vae(vae_ckpt_path): return vae -def get_transformer_config() -> Tuple[Dict[str, Any], ...]: - config = { - "diffusers_config": { - "hidden_size": 4096, - "in_channels": 16, - 
"num_attention_heads": 32, - "num_layers": 40, - "out_channels": 16, - "patch_size": [1, 2, 2], - "rope_dim_list": [16, 56, 56], - "text_dim": 4096, - "rope_type": "rope", - "theta": 10000, - }, - } - return config +TRANSFORMER_CONFIG = { + "hidden_size": 4096, + "in_channels": 16, + "num_attention_heads": 32, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "rope_dim_list": [16, 56, 56], + "text_dim": 4096, + "rope_type": "rope", + "theta": 10000, +} -def convert_transformer(ckpt_path: str): +def convert_transformer(ckpt_path: str, model_type: str = "edit"): checkpoint = torch.load(ckpt_path, weights_only=True) if "model" in checkpoint: original_state_dict = checkpoint["model"] else: original_state_dict = checkpoint - # Attention weights moved from block to block.attn submodule attn_suffixes = ( "img_attn_qkv.", "img_attn_q_norm.", @@ -268,21 +242,32 @@ def convert_transformer(ckpt_path: str): new_key = key if key.startswith("double_blocks."): for suffix in attn_suffixes: - # double_blocks.0.img_attn_qkv.weight -> double_blocks.0.attn.img_attn_qkv.weight if "." + suffix in key and ".attn." + suffix not in key: new_key = key.replace("." + suffix, ".attn." 
+ suffix) break remapped[new_key] = value - config = get_transformer_config() + transformer_cls = ( + JoyImageEditPlusTransformer3DModel if model_type == "edit_plus" + else JoyImageEditTransformer3DModel + ) with init_empty_weights(): - transformer = JoyImageEditTransformer3DModel(**config["diffusers_config"]) + transformer = transformer_cls(**TRANSFORMER_CONFIG) transformer.load_state_dict(remapped, strict=True, assign=True) return transformer def get_args(): - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + description="Convert JoyImage Edit / Edit Plus checkpoints to diffusers format" + ) + parser.add_argument( + "--model_type", + type=str, + choices=["edit", "edit_plus"], + default="edit", + help="Model type: 'edit' for JoyImage-Edit, 'edit_plus' for JoyImage-Edit-Plus", + ) parser.add_argument( "--transformer_ckpt_path", type=str, @@ -299,13 +284,7 @@ def get_args(): "--text_encoder_path", type=str, default=None, - help="Path to original llama checkpoint", - ) - parser.add_argument( - "--tokenizer_path", - type=str, - default=None, - help="Path to original llama tokenizer", + help="Path to Qwen3-VL text encoder (e.g. 
Qwen/Qwen3-VL-8B-Instruct)", ) parser.add_argument("--save_pipeline", action="store_true") parser.add_argument( @@ -314,8 +293,8 @@ def get_args(): required=True, help="Path where converted model should be saved", ) - parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") - parser.add_argument("--flow_shift", type=float, default=7.0) + parser.add_argument("--dtype", default="bf16", help="Torch dtype (fp32, fp16, bf16)") + parser.add_argument("--flow_shift", type=float, default=1.5) return parser.parse_args() @@ -324,6 +303,7 @@ def get_args(): "fp16": torch.float16, "bf16": torch.bfloat16, } + if __name__ == "__main__": args = get_args() transformer = None @@ -333,34 +313,47 @@ def get_args(): if args.save_pipeline: assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None assert args.text_encoder_path is not None - # assert args.tokenizer_path is not None + if args.transformer_ckpt_path is not None: - transformer = convert_transformer(args.transformer_ckpt_path) + transformer = convert_transformer(args.transformer_ckpt_path, model_type=args.model_type) transformer = transformer.to(dtype=dtype) if not args.save_pipeline: transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + if args.vae_ckpt_path is not None: vae = convert_vae(args.vae_ckpt_path) vae = vae.to(dtype=dtype) if not args.save_pipeline: vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + if args.save_pipeline: processor = AutoProcessor.from_pretrained(args.text_encoder_path) text_encoder = Qwen3VLForConditionalGeneration.from_pretrained( args.text_encoder_path, torch_dtype=torch.bfloat16 ).to("cuda") tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path) - flow_shift = 1.5 - scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=flow_shift) + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.flow_shift) 
transformer = transformer.to("cuda") vae = vae.to("cuda") - pipe = JoyImageEditPipeline( - processor=processor, - transformer=transformer, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae=vae, - scheduler=scheduler, - ).to("cuda") + + if args.model_type == "edit_plus": + pipe = JoyImageEditPlusPipeline( + processor=processor, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + ).to("cuda") + else: + pipe = JoyImageEditPipeline( + processor=processor, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + ).to("cuda") + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") processor.save_pretrained(f"{args.output_path}/processor") From f170222c6270b840c5f7c63d4d7a2f6e044c30b7 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Thu, 2 Jul 2026 06:14:20 +0000 Subject: [PATCH 10/14] revert: remove unnecessary attention_mask change from transformer_joyimage.py --- src/diffusers/models/transformers/transformer_joyimage.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_joyimage.py b/src/diffusers/models/transformers/transformer_joyimage.py index d30b0501e02f..b17ddb05f799 100644 --- a/src/diffusers/models/transformers/transformer_joyimage.py +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -283,7 +283,6 @@ def forward( encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, - attention_mask: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # modulation ( @@ -313,7 +312,6 @@ def forward( hidden_states=img_modulated, encoder_hidden_states=txt_modulated, image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, ) hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1) From 31421c121d797cc14eed35af12ea438a4acf4873 Mon Sep 17 00:00:00 2001 
From: "github-actions[bot]" Date: Fri, 3 Jul 2026 01:39:58 +0000 Subject: [PATCH 11/14] Apply style fixes --- docs/source/en/_toctree.yml | 4 +- scripts/convert_joyimage_edit_to_diffusers.py | 8 +-- src/diffusers/__init__.py | 4 +- src/diffusers/models/__init__.py | 2 +- .../transformer_joyimage_edit_plus.py | 10 ++-- src/diffusers/pipelines/__init__.py | 14 ++++- .../joyimage/pipeline_joyimage_edit_plus.py | 56 ++++++++++--------- .../pipelines/joyimage/pipeline_output.py | 2 +- 8 files changed, 56 insertions(+), 44 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b84bd0281a7c..0ae5262c39e3 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -353,10 +353,10 @@ title: HunyuanVideoTransformer3DModel - local: api/models/ideogram4_transformer2d title: Ideogram4Transformer2DModel - - local: api/models/transformer_joyimage - title: JoyImageEditTransformer3DModel - local: api/models/transformer_joyimage_edit_plus title: JoyImageEditPlusTransformer3DModel + - local: api/models/transformer_joyimage + title: JoyImageEditTransformer3DModel - local: api/models/krea2_transformer2d title: Krea2Transformer2DModel - local: api/models/latte_transformer3d diff --git a/scripts/convert_joyimage_edit_to_diffusers.py b/scripts/convert_joyimage_edit_to_diffusers.py index a25600932e59..3fe4a7b12cd3 100644 --- a/scripts/convert_joyimage_edit_to_diffusers.py +++ b/scripts/convert_joyimage_edit_to_diffusers.py @@ -24,7 +24,6 @@ """ import argparse -from typing import Any, Dict, Tuple import torch from accelerate import init_empty_weights @@ -248,8 +247,7 @@ def convert_transformer(ckpt_path: str, model_type: str = "edit"): remapped[new_key] = value transformer_cls = ( - JoyImageEditPlusTransformer3DModel if model_type == "edit_plus" - else JoyImageEditTransformer3DModel + JoyImageEditPlusTransformer3DModel if model_type == "edit_plus" else JoyImageEditTransformer3DModel ) with init_empty_weights(): transformer = 
transformer_cls(**TRANSFORMER_CONFIG) @@ -258,9 +256,7 @@ def convert_transformer(ckpt_path: str, model_type: str = "edit"): def get_args(): - parser = argparse.ArgumentParser( - description="Convert JoyImage Edit / Edit Plus checkpoints to diffusers format" - ) + parser = argparse.ArgumentParser(description="Convert JoyImage Edit / Edit Plus checkpoints to diffusers format") parser.add_argument( "--model_type", type=str, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1b4a0f3bc5c4..8e984d43d867 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -274,8 +274,8 @@ "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", "Ideogram4Transformer2DModel", - "JoyImageEditTransformer3DModel", "JoyImageEditPlusTransformer3DModel", + "JoyImageEditTransformer3DModel", "Kandinsky3UNet", "Kandinsky5Transformer3DModel", "Krea2Transformer2DModel", @@ -1145,8 +1145,8 @@ HunyuanVideoTransformer3DModel, I2VGenXLUNet, Ideogram4Transformer2DModel, - JoyImageEditTransformer3DModel, JoyImageEditPlusTransformer3DModel, + JoyImageEditTransformer3DModel, Kandinsky3UNet, Kandinsky5Transformer3DModel, Krea2Transformer2DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 6d7ee8a7bb74..1746db347d32 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -255,8 +255,8 @@ HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, Ideogram4Transformer2DModel, - JoyImageEditTransformer3DModel, JoyImageEditPlusTransformer3DModel, + JoyImageEditTransformer3DModel, Kandinsky5Transformer3DModel, Krea2Transformer2DModel, LatteTransformer3DModel, diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py index 334309e8f989..d45b58c38883 100644 --- a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py +++ 
b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ -327,8 +327,8 @@ class JoyImageEditPlusTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin r""" JoyImage Edit Plus Transformer for multi-image editing. - Uses a patchify+padding approach where each reference image and the target noise are independently - patchified and concatenated into a flat patch sequence. Supports variable-resolution reference images. + Uses a patchify+padding approach where each reference image and the target noise are independently patchified and + concatenated into a flat patch sequence. Supports variable-resolution reference images. Input format: `[B, max_patches, C, pt, ph, pw]` (6D padded patches). @@ -541,9 +541,9 @@ def forward( # 6. Output projection + reshape to 6D patches img = self.proj_out(self.norm_out(img)) - img = img.reshape( - batch_size, max_num_patches, pt, ph, pw, self.out_channels - ).permute(0, 1, 5, 2, 3, 4) # -> [B, N, C, pt, ph, pw] + img = img.reshape(batch_size, max_num_patches, pt, ph, pw, self.out_channels).permute( + 0, 1, 5, 2, 3, 4 + ) # -> [B, N, C, pt, ph, pw] if not return_dict: return (img,) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 5fbc684bf376..7b0c43727975 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -346,7 +346,12 @@ "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", ] - _import_structure["joyimage"] = ["JoyImageEditPipeline", "JoyImageEditPipelineOutput", "JoyImageEditPlusPipeline", "JoyImageEditPlusPipelineOutput"] + _import_structure["joyimage"] = [ + "JoyImageEditPipeline", + "JoyImageEditPipelineOutput", + "JoyImageEditPlusPipeline", + "JoyImageEditPlusPipelineOutput", + ] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -760,7 +765,12 @@ from .hunyuan_video1_5 
import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline from .hunyuandit import HunyuanDiTPipeline from .ideogram4 import Ideogram4Pipeline, Ideogram4PromptEnhancerHead - from .joyimage import JoyImageEditPipeline, JoyImageEditPipelineOutput, JoyImageEditPlusPipeline, JoyImageEditPlusPipelineOutput + from .joyimage import ( + JoyImageEditPipeline, + JoyImageEditPipelineOutput, + JoyImageEditPlusPipeline, + JoyImageEditPlusPipelineOutput, + ) from .kandinsky import ( KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index 222430ffb341..22713e975653 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -97,8 +97,8 @@ def retrieve_timesteps( `num_inference_steps` and `timesteps` must be `None`. Returns: - `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and - the second element is the number of inference steps. + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") @@ -126,8 +126,8 @@ class JoyImageEditPlusPipeline(DiffusionPipeline): r""" Diffusion pipeline for multi-image instruction-guided editing using JoyImage Edit Plus. - Supports multiple reference images with different resolutions. Each reference image is independently - VAE-encoded and patchified, then concatenated with the target noise patches for joint denoising. + Supports multiple reference images with different resolutions. Each reference image is independently VAE-encoded + and patchified, then concatenated with the target noise patches for joint denoising. 
Args: scheduler ([`FlowMatchEulerDiscreteScheduler`]): @@ -265,9 +265,7 @@ def _pad_sequence(self, x: torch.Tensor, target_length: int) -> torch.Tensor: return x[:, -target_length:] padding_length = target_length - current_length if x.ndim >= 3: - padding = torch.zeros( - (x.shape[0], padding_length, *x.shape[2:]), dtype=x.dtype, device=x.device - ) + padding = torch.zeros((x.shape[0], padding_length, *x.shape[2:]), dtype=x.dtype, device=x.device) else: padding = torch.zeros((x.shape[0], padding_length), dtype=x.dtype, device=x.device) return torch.cat([x, padding], dim=1) @@ -312,8 +310,7 @@ def prepare_latents( """Prepare 6D padded latent tensor with target noise + reference image latents. Returns: - padded_latents: [B, max_patches, C, pt, ph, pw] - target_mask: [B, max_patches] (True for target patches) + padded_latents: [B, max_patches, C, pt, ph, pw] target_mask: [B, max_patches] (True for target patches) shape_list: per-sample list of (t, h, w) tuples for each component """ pt, ph, pw = self.transformer.config.patch_size @@ -432,7 +429,10 @@ def __call__( negative_prompt_embeds_mask: torch.Tensor | None = None, output_type: str | None = "pil", return_dict: bool = True, - callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], max_sequence_length: int = 4096, ): @@ -458,11 +458,11 @@ def __call__( sigmas (`list[float]`, *optional*): Custom sigmas to use for the denoising process. guidance_scale (`float`, *optional*, defaults to `4.0`): - Classifier-free guidance scale. Higher values encourage the model to generate images more aligned - with the `prompt` at the expense of lower image quality. + Classifier-free guidance scale. 
Higher values encourage the model to generate images more aligned with + the `prompt` at the expense of lower image quality. negative_prompt (`str` or `list[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, a blank prompt is used - for classifier-free guidance. + The prompt or prompts not to guide the image generation. If not defined, a blank prompt is used for + classifier-free guidance. generator (`torch.Generator` or `list[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. @@ -482,8 +482,8 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`JoyImageEditPlusPipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): - A function called at the end of each denoising step with arguments: the pipeline, step index, - timestep, and a dict of callback tensor inputs. + A function called at the end of each denoising step with arguments: the pipeline, step index, timestep, + and a dict of callback tensor inputs. callback_on_step_end_tensor_inputs (`list[str]`, *optional*, defaults to `["latents"]`): The list of tensor inputs for the `callback_on_step_end` function. 
max_sequence_length (`int`, *optional*, defaults to `4096`): @@ -590,15 +590,19 @@ def __call__( # Pad and concatenate [negative, positive] max_seq_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1]) - prompt_embeds = torch.cat([ - self._pad_sequence(negative_prompt_embeds, max_seq_len), - self._pad_sequence(prompt_embeds, max_seq_len), - ]) + prompt_embeds = torch.cat( + [ + self._pad_sequence(negative_prompt_embeds, max_seq_len), + self._pad_sequence(prompt_embeds, max_seq_len), + ] + ) if prompt_embeds_mask is not None and negative_prompt_embeds_mask is not None: - prompt_embeds_mask = torch.cat([ - self._pad_sequence(negative_prompt_embeds_mask, max_seq_len), - self._pad_sequence(prompt_embeds_mask, max_seq_len), - ]) + prompt_embeds_mask = torch.cat( + [ + self._pad_sequence(negative_prompt_embeds_mask, max_seq_len), + self._pad_sequence(prompt_embeds_mask, max_seq_len), + ] + ) # Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( @@ -697,7 +701,9 @@ def __call__( target_patches = padded_latents[b_idx, :target_len] c_lat = target_patches.shape[1] video_latent = target_patches.reshape(l_t, l_h, l_w, c_lat, pt, ph, pw) - video_latent = video_latent.permute(3, 0, 4, 1, 5, 2, 6).reshape(1, c_lat, l_t * pt, l_h * ph, l_w * pw) + video_latent = video_latent.permute(3, 0, 4, 1, 5, 2, 6).reshape( + 1, c_lat, l_t * pt, l_h * ph, l_w * pw + ) video_latent = self.denormalize_latents(video_latent) diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py index 30be7c248e33..4ffb3e53a103 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_output.py +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -15,6 +15,7 @@ class JoyImageEditPipelineOutput(BaseOutput): images: Union[List[PIL.Image.Image], np.ndarray] + @dataclass class JoyImageEditPlusPipelineOutput(BaseOutput): """ @@ -22,4 +23,3 @@ class JoyImageEditPlusPipelineOutput(BaseOutput): """ images: 
Union[List[PIL.Image.Image], np.ndarray] - From 33f550d325c2d3d6db08a5f5ede9f048a17a9a08 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Fri, 3 Jul 2026 02:06:50 +0000 Subject: [PATCH 12/14] fix: remove stale Copied-from annotations that diverged from source --- .../models/transformers/transformer_joyimage_edit_plus.py | 2 -- src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py | 1 - 2 files changed, 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py index d45b58c38883..706204760dac 100644 --- a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py +++ b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ -149,7 +149,6 @@ def __call__( return img_attn_output, txt_attn_output -# Copied from diffusers.models.transformers.transformer_joyimage.JoyImageAttention with JoyImage->JoyImageEditPlus class JoyImageEditPlusAttention(nn.Module, AttentionModuleMixin): """Joint attention module for JoyImage Edit Plus double-stream blocks.""" @@ -199,7 +198,6 @@ def forward( return self.processor(self, hidden_states, encoder_hidden_states, image_rotary_emb, **kwargs) -# Copied from diffusers.models.transformers.transformer_joyimage.JoyImageTransformerBlock with JoyImage->JoyImageEditPlus class JoyImageEditPlusTransformerBlock(nn.Module): """Double-stream transformer block for JoyImage Edit Plus.""" diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index 22713e975653..6294933aa96c 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -68,7 +68,6 @@ """ -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: int | None = None, From 
ac7a4b68b1910759b69bc22d617f7ae3da20940d Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Fri, 3 Jul 2026 04:45:14 +0000 Subject: [PATCH 13/14] fix: address CI check failures for edit-plus PR - Add missing Returns section to JoyImageEditPlusTransformer3DModel.forward docstring - Fix alphabetical ordering of dummy classes in dummy_pt_objects.py --- .../models/transformers/transformer_joyimage_edit_plus.py | 4 ++++ src/diffusers/utils/dummy_pt_objects.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py index 706204760dac..f245e1bf094f 100644 --- a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py +++ b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ -471,6 +471,10 @@ def forward( encoder_hidden_states_mask: [B, L] - attention mask for text tokens. shape_list: Per-sample list of (t, h, w) tuples for each component (target + references). return_dict: Whether to return a dict or tuple. + + Returns: + If `return_dict` is True, an [`~models.modeling_outputs.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. 
""" batch_size, max_num_patches, channels, pt, ph, pw = hidden_states.shape device = hidden_states.device diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 36489e473fb5..9035efb3e6e2 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1485,7 +1485,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class JoyImageEditTransformer3DModel(metaclass=DummyObject): +class JoyImageEditPlusTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -1500,7 +1500,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class JoyImageEditPlusTransformer3DModel(metaclass=DummyObject): +class JoyImageEditTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From 3fa516a2201609cf0488b992dca96b148d512240 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Fri, 3 Jul 2026 06:43:34 +0000 Subject: [PATCH 14/14] fix: use VAE dtype instead of float32 cast in edit-plus pipeline Replace `torch.autocast(dtype=torch.float32)` + `.float()` with `.to(self.vae.dtype)` for both VAE encode and decode calls. The previous approach caused dtype mismatch (float32 input vs bfloat16 bias) on CPU where autocast does not automatically cast conv weights, breaking the CI `test_layerwise_casting_inference` test. 
--- .../pipelines/joyimage/pipeline_joyimage_edit_plus.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index 6294933aa96c..34d4a34caea0 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -340,8 +340,7 @@ def prepare_latents( ref_tensor = torch.from_numpy(np.array(ref_img_pil.convert("RGB"))).to(device=device, dtype=dtype) ref_tensor = (ref_tensor / 127.5 - 1.0).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) - with torch.autocast(device_type=device.type, dtype=torch.float32): - ref_latent = self.vae.encode(ref_tensor.float()).latent_dist.mode() + ref_latent = self.vae.encode(ref_tensor.to(self.vae.dtype)).latent_dist.mode() ref_latent = ref_latent.to(dtype) ref_latent = self.normalize_latents(ref_latent) ref_latent = ref_latent.squeeze(0) # [C, 1, H', W'] @@ -706,8 +705,7 @@ def __call__( video_latent = self.denormalize_latents(video_latent) - with torch.autocast(device_type=device.type, dtype=torch.float32): - sample_image = self.vae.decode(video_latent.float(), return_dict=False)[0] + sample_image = self.vae.decode(video_latent.to(self.vae.dtype), return_dict=False)[0] sample_image = (sample_image / 2 + 0.5).clamp(0, 1).squeeze(0).cpu().float() image_list.append(sample_image)