diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8e8776d4a8c2..00ad1bc0d96e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -500,6 +500,8 @@ title: Stable Audio title: Audio - sections: + - local: api/pipelines/anima + title: Anima - local: api/pipelines/animatediff title: AnimateDiff - local: api/pipelines/aura_flow diff --git a/docs/source/en/api/pipelines/anima.md b/docs/source/en/api/pipelines/anima.md new file mode 100644 index 000000000000..b7db72632776 --- /dev/null +++ b/docs/source/en/api/pipelines/anima.md @@ -0,0 +1,26 @@ +# Anima + +Anima is a text-to-image model that reuses the [`CosmosTransformer3DModel`] with a Qwen3 text encoder, a T5-token text conditioner, and the [`AutoencoderKLQwenImage`] VAE. + +```python +import torch +from diffusers import ModularPipeline + +pipe = ModularPipeline.from_pretrained("mrfatso/anima-preview3-diffusers") +pipe.load_components(torch_dtype=torch.bfloat16) +pipe.to("cuda") + +image = pipe(prompt="masterpiece, best quality, 1girl, solo, city lights").images[0] +``` + +## AnimaModularPipeline + +[[autodoc]] AnimaModularPipeline + +## AnimaAutoBlocks + +[[autodoc]] AnimaAutoBlocks + +## AnimaTextConditioner + +[[autodoc]] AnimaTextConditioner diff --git a/scripts/convert_anima_to_diffusers.py b/scripts/convert_anima_to_diffusers.py new file mode 100644 index 000000000000..99005cd1e22a --- /dev/null +++ b/scripts/convert_anima_to_diffusers.py @@ -0,0 +1,314 @@ +""" +Convert Anima checkpoints to Diffusers format. + +Example: +```bash +python scripts/convert_anima_to_diffusers.py \ + --transformer_ckpt_path anima_model/anima-preview3-base.safetensors \ + --text_encoder_ckpt_path anima_model/qwen_3_06b_base.safetensors \ + --vae_ckpt_path anima_model/qwen_image_vae.safetensors \ + --qwen_tokenizer_path /home/user/Dev/ComfyUI/comfy/text_encoders/qwen25_tokenizer \ + --t5_tokenizer_path /home/user/Dev/ComfyUI/comfy/text_encoders/t5_tokenizer \ + --output_path anima_model/anima-preview3-diffusers \ + --save_pipeline +``` +""" + +import argparse +import pathlib +import sys +from typing import Any + +import torch +from accelerate import init_empty_weights +from convert_cosmos_to_diffusers import convert_transformer +from safetensors.torch import load_file +from transformers import AutoTokenizer, Qwen3Config, Qwen3Model, T5TokenizerFast + +from diffusers import ( + AnimaAutoBlocks, + AnimaTextConditioner, + AutoencoderKLQwenImage, + FlowMatchEulerDiscreteScheduler, +) + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +def rename_residual_key(key: str) -> str: + replacements = { + ".residual.0.": ".norm1.", + ".residual.2.": ".conv1.", + ".residual.3.": ".norm2.", + ".residual.6.": ".conv2.", + ".shortcut.": ".conv_shortcut.", + } + for old, new in replacements.items(): + key = key.replace(old, new) + return key + + +def rename_mid_key(key: str) -> str: + replacements = { + ".middle.0.": ".mid_block.resnets.0.", + ".middle.1.": ".mid_block.attentions.0.", + ".middle.2.": ".mid_block.resnets.1.", + } + for old, new in replacements.items(): + key = key.replace(old, new) + return rename_residual_key(key) + + +def rename_decoder_upsample_key(key: str) -> str: + prefix = "decoder.upsamples." 
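+    # Source checkpoint layout: four slots per up block; slots 3, 7 and 11
+    # hold the upsampler convs and all other slots hold resnets, hence the
+    # // 4 and % 4 arithmetic below.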
+ suffix = key.removeprefix(prefix) + index_str, rest = suffix.split(".", 1) + index = int(index_str) + + if index in (3, 7, 11): + block_index = (index - 3) // 4 + new_key = f"decoder.up_blocks.{block_index}.upsamplers.0.{rest}" + else: + block_index = index // 4 + resnet_index = index % 4 + new_key = f"decoder.up_blocks.{block_index}.resnets.{resnet_index}.{rest}" + + return rename_residual_key(new_key) + + +def convert_qwen_image_vae_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + converted_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("conv1."): + new_key = key.replace("conv1.", "quant_conv.", 1) + elif key.startswith("conv2."): + new_key = key.replace("conv2.", "post_quant_conv.", 1) + elif key.startswith("encoder.conv1."): + new_key = key.replace("encoder.conv1.", "encoder.conv_in.", 1) + elif key.startswith("decoder.conv1."): + new_key = key.replace("decoder.conv1.", "decoder.conv_in.", 1) + elif key.startswith("encoder.downsamples."): + new_key = rename_residual_key(key.replace("encoder.downsamples.", "encoder.down_blocks.", 1)) + elif key.startswith("decoder.upsamples."): + new_key = rename_decoder_upsample_key(key) + elif key.startswith("encoder.middle.") or key.startswith("decoder.middle."): + new_key = rename_mid_key(key) + elif key.startswith("encoder.head.0."): + new_key = key.replace("encoder.head.0.", "encoder.norm_out.", 1) + elif key.startswith("encoder.head.2."): + new_key = key.replace("encoder.head.2.", "encoder.conv_out.", 1) + elif key.startswith("decoder.head.0."): + new_key = key.replace("decoder.head.0.", "decoder.norm_out.", 1) + elif key.startswith("decoder.head.2."): + new_key = key.replace("decoder.head.2.", "decoder.conv_out.", 1) + else: + new_key = rename_residual_key(key) + + if new_key in converted_state_dict: + raise ValueError(f"Duplicate converted VAE key: {new_key}") + converted_state_dict[new_key] = value + + return converted_state_dict + + +def convert_qwen_image_vae(state_dict: dict[str, torch.Tensor]) -> AutoencoderKLQwenImage: + converted_state_dict = convert_qwen_image_vae_state_dict(state_dict) + with init_empty_weights(): + vae = AutoencoderKLQwenImage() + + expected_keys = set(vae.state_dict().keys()) + converted_keys = set(converted_state_dict.keys()) + missing_keys = expected_keys - converted_keys + unexpected_keys = converted_keys - expected_keys + if missing_keys or unexpected_keys: + if missing_keys: + print(f"ERROR: missing VAE keys ({len(missing_keys)}):", file=sys.stderr) + for key in sorted(missing_keys): + print(key, file=sys.stderr) + if unexpected_keys: + print(f"ERROR: unexpected VAE keys ({len(unexpected_keys)}):", file=sys.stderr) + for key in sorted(unexpected_keys): + print(key, file=sys.stderr) + sys.exit(1) + + vae.load_state_dict(converted_state_dict, strict=True, assign=True) + return vae + + +def infer_text_conditioner_config(state_dict: dict[str, torch.Tensor]) -> dict[str, Any]: + model_dim = state_dict["blocks.0.self_attn.q_proj.weight"].shape[0] + source_dim = state_dict["blocks.0.cross_attn.k_proj.weight"].shape[1] + target_vocab_size, target_dim = state_dict["embed.weight"].shape + attention_head_dim = state_dict["blocks.0.self_attn.q_norm.weight"].shape[0] + num_layers = 1 + max(int(key.split(".")[1]) for key in state_dict if key.startswith("blocks.")) + + return { + "source_dim": source_dim, + "target_dim": target_dim, + "model_dim": model_dim, + "num_layers": num_layers, + "num_attention_heads": model_dim // attention_head_dim, + "target_vocab_size": 
target_vocab_size, + } + + +def convert_text_conditioner(state_dict: dict[str, torch.Tensor]) -> AnimaTextConditioner: + config = infer_text_conditioner_config(state_dict) + with init_empty_weights(): + text_conditioner = AnimaTextConditioner(**config) + + expected_keys = set(text_conditioner.state_dict().keys()) + converted_keys = set(state_dict.keys()) + missing_keys = expected_keys - converted_keys + unexpected_keys = converted_keys - expected_keys + if missing_keys or unexpected_keys: + if missing_keys: + print(f"ERROR: missing text conditioner keys ({len(missing_keys)}):", file=sys.stderr) + for key in sorted(missing_keys): + print(key, file=sys.stderr) + if unexpected_keys: + print(f"ERROR: unexpected text conditioner keys ({len(unexpected_keys)}):", file=sys.stderr) + for key in sorted(unexpected_keys): + print(key, file=sys.stderr) + sys.exit(1) + + text_conditioner.load_state_dict(state_dict, strict=True, assign=True) + return text_conditioner + + +def infer_qwen3_config(state_dict: dict[str, torch.Tensor]) -> Qwen3Config: + vocab_size, hidden_size = state_dict["embed_tokens.weight"].shape + intermediate_size = state_dict["layers.0.mlp.gate_proj.weight"].shape[0] + num_hidden_layers = 1 + max(int(key.split(".")[1]) for key in state_dict if key.startswith("layers.")) + head_dim = state_dict["layers.0.self_attn.q_norm.weight"].shape[0] + num_attention_heads = state_dict["layers.0.self_attn.q_proj.weight"].shape[0] // head_dim + num_key_value_heads = state_dict["layers.0.self_attn.k_proj.weight"].shape[0] // head_dim + + return Qwen3Config( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + max_position_embeddings=32768, + rms_norm_eps=1e-6, + rope_theta=1000000.0, + head_dim=head_dim, + attention_bias=False, + tie_word_embeddings=False, + ) + + +def convert_text_encoder(state_dict: dict[str, torch.Tensor]) -> Qwen3Model: + state_dict = {key.removeprefix("model."): value for key, value in state_dict.items()} + config = infer_qwen3_config(state_dict) + with init_empty_weights(): + text_encoder = Qwen3Model(config) + + expected_keys = set(text_encoder.state_dict().keys()) + converted_keys = set(state_dict.keys()) + missing_keys = expected_keys - converted_keys + unexpected_keys = converted_keys - expected_keys + if missing_keys or unexpected_keys: + if missing_keys: + print(f"ERROR: missing Qwen3 keys ({len(missing_keys)}):", file=sys.stderr) + for key in sorted(missing_keys): + print(key, file=sys.stderr) + if unexpected_keys: + print(f"ERROR: unexpected Qwen3 keys ({len(unexpected_keys)}):", file=sys.stderr) + for key in sorted(unexpected_keys): + print(key, file=sys.stderr) + sys.exit(1) + + text_encoder.load_state_dict(state_dict, strict=True, assign=True) + return text_encoder + + +def split_anima_transformer_checkpoint( + state_dict: dict[str, torch.Tensor], +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + transformer_state_dict = {} + text_conditioner_state_dict = {} + adapter_prefix = "net.llm_adapter." 
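+    # Keys under `net.llm_adapter.` form the AnimaTextConditioner; everything
+    # else is the Cosmos DiT, converted separately via `convert_transformer`.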
+ + for key, value in state_dict.items(): + if key.startswith(adapter_prefix): + text_conditioner_state_dict[key.removeprefix(adapter_prefix)] = value + else: + transformer_state_dict[key] = value + + return transformer_state_dict, text_conditioner_state_dict + + +def save_pipeline(args, transformer, text_conditioner, text_encoder, vae): + tokenizer = AutoTokenizer.from_pretrained(args.qwen_tokenizer_path) + t5_tokenizer = T5TokenizerFast.from_pretrained(args.t5_tokenizer_path) + scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) + + pipe = AnimaAutoBlocks().init_pipeline() + pipe.update_components( + text_encoder=text_encoder, + tokenizer=tokenizer, + t5_tokenizer=t5_tokenizer, + text_conditioner=text_conditioner, + transformer=transformer, + vae=vae, + scheduler=scheduler, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size=args.max_shard_size) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--transformer_ckpt_path", type=str, required=True, help="Path to Anima DiT safetensors") + parser.add_argument("--text_encoder_ckpt_path", type=str, required=True, help="Path to Qwen3 text encoder") + parser.add_argument("--vae_ckpt_path", type=str, required=True, help="Path to Qwen-Image VAE safetensors") + parser.add_argument("--qwen_tokenizer_path", type=str, default=None) + parser.add_argument("--t5_tokenizer_path", type=str, default=None) + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--dtype", default="bf16", choices=list(DTYPE_MAPPING.keys())) + parser.add_argument("--max_shard_size", default="5GB") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + output_path = pathlib.Path(args.output_path) + dtype = DTYPE_MAPPING[args.dtype] + + raw_transformer_state_dict = load_file(args.transformer_ckpt_path, device="cpu") + transformer_state_dict, text_conditioner_state_dict = split_anima_transformer_checkpoint(raw_transformer_state_dict) + transformer = convert_transformer( + "Cosmos-2.0-Diffusion-2B-Text2Image", state_dict=transformer_state_dict, weights_only=True + ).to(dtype=dtype) + text_conditioner = convert_text_conditioner(text_conditioner_state_dict).to(dtype=dtype) + + text_encoder_state_dict = load_file(args.text_encoder_ckpt_path, device="cpu") + text_encoder = convert_text_encoder(text_encoder_state_dict).to(dtype=dtype) + + vae_state_dict = load_file(args.vae_ckpt_path, device="cpu") + vae = convert_qwen_image_vae(vae_state_dict).to(dtype=dtype) + + if args.save_pipeline: + if args.qwen_tokenizer_path is None or args.t5_tokenizer_path is None: + raise ValueError("`--qwen_tokenizer_path` and `--t5_tokenizer_path` are required with `--save_pipeline`.") + save_pipeline(args, transformer, text_conditioner, text_encoder, vae) + else: + output_path.mkdir(parents=True, exist_ok=True) + transformer.save_pretrained( + output_path / "transformer", safe_serialization=True, max_shard_size=args.max_shard_size + ) + text_conditioner.save_pretrained( + output_path / "text_conditioner", safe_serialization=True, max_shard_size=args.max_shard_size + ) + text_encoder.save_pretrained( + output_path / "text_encoder", safe_serialization=True, max_shard_size=args.max_shard_size + ) + vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size=args.max_shard_size) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1b1f6b3032b3..f8e0ba59f16e 100644 --- 
a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -190,6 +190,7 @@ [ "AceStepTransformer1DModel", "AllegroTransformer3DModel", + "AnimaTextConditioner", "AsymmetricAutoencoderKL", "AttentionBackendName", "AuraFlowTransformer2DModel", @@ -474,6 +475,8 @@ "QwenImageLayeredAutoBlocks", "QwenImageLayeredModularPipeline", "QwenImageModularPipeline", + "AnimaAutoBlocks", + "AnimaModularPipeline", "StableDiffusion3AutoBlocks", "StableDiffusion3ModularPipeline", "StableDiffusionXLAutoBlocks", @@ -1012,6 +1015,7 @@ from .models import ( AceStepTransformer1DModel, AllegroTransformer3DModel, + AnimaTextConditioner, AsymmetricAutoencoderKL, AttentionBackendName, AuraFlowTransformer2DModel, @@ -1245,6 +1249,8 @@ from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .modular_pipelines import ( + AnimaAutoBlocks, + AnimaModularPipeline, ErnieImageAutoBlocks, ErnieImageModularPipeline, Flux2AutoBlocks, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 488f77422dcd..33eeba673a98 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -63,6 +63,7 @@ def text_encoder_attn_modules(text_encoder): _import_structure["single_file"] = ["FromSingleFileMixin"] _import_structure["lora_pipeline"] = [ "AmusedLoraLoaderMixin", + "AnimaLoraLoaderMixin", "StableDiffusionLoraLoaderMixin", "SD3LoraLoaderMixin", "AuraFlowLoraLoaderMixin", @@ -116,6 +117,7 @@ def text_encoder_attn_modules(text_encoder): ) from .lora_pipeline import ( AmusedLoraLoaderMixin, + AnimaLoraLoaderMixin, AuraFlowLoraLoaderMixin, CogVideoXLoraLoaderMixin, CogView4LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 510a698e505f..bf516abc825f 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2317,6 +2317,52 @@ def get_alpha_scales(down_weight, alpha_key): return converted_state_dict +def _convert_non_diffusers_anima_lora_to_diffusers(state_dict): + rename_dict = { + "blocks.": "transformer_blocks.", + "adaln_modulation_self_attn.1": "norm1.linear_1", + "adaln_modulation_self_attn.2": "norm1.linear_2", + "adaln_modulation_cross_attn.1": "norm2.linear_1", + "adaln_modulation_cross_attn.2": "norm2.linear_2", + "adaln_modulation_mlp.1": "norm3.linear_1", + "adaln_modulation_mlp.2": "norm3.linear_2", + "self_attn.q_proj": "attn1.to_q", + "self_attn.k_proj": "attn1.to_k", + "self_attn.v_proj": "attn1.to_v", + "self_attn.output_proj": "attn1.to_out.0", + "cross_attn.q_proj": "attn2.to_q", + "cross_attn.k_proj": "attn2.to_k", + "cross_attn.v_proj": "attn2.to_v", + "cross_attn.output_proj": "attn2.to_out.0", + "mlp.layer1": "ff.net.0.proj", + "mlp.layer2": "ff.net.2", + "final_layer.adaln_modulation.1": "norm_out.linear_1", + "final_layer.adaln_modulation.2": "norm_out.linear_2", + "final_layer.linear": "proj_out", + "t_embedder.1": "time_embed.t_embedder", + "t_embedding_norm": "time_embed.norm", + "x_embedder.proj.1": "patch_embed.proj", + } + + converted_state_dict = {} + for key, value in state_dict.items(): + if not key.startswith("diffusion_model."): + converted_state_dict[key] = value + continue + + new_key = key.removeprefix("diffusion_model.") + if new_key.startswith("llm_adapter."): + new_key = f"text_conditioner.{new_key.removeprefix('llm_adapter.')}" + else: + for old_key, new_key_part in rename_dict.items(): + new_key = new_key.replace(old_key, new_key_part) + new_key = f"transformer.{new_key}" 
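+        # Keys are now prefixed with `transformer.` or `text_conditioner.`,
+        # which is how `AnimaLoraLoaderMixin.load_lora_weights` routes them.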
+ + converted_state_dict[new_key] = value + + return converted_state_dict + + def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict): converted_state_dict = {} diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 403e5a87db61..d25d39d61592 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -46,6 +46,7 @@ _convert_kohya_flux2_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, + _convert_non_diffusers_anima_lora_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, @@ -5615,6 +5616,208 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class AnimaLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`CosmosTransformer3DModel`] and [`AnimaTextConditioner`]. + """ + + _lora_loadable_modules = ["transformer", "text_conditioner"] + transformer_name = TRANSFORMER_NAME + text_conditioner_name = "text_conditioner" + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict) + if has_diffusion_model: + state_dict = _convert_non_diffusers_anima_lora_to_diffusers(state_dict) + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + adapter_name: str | None = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. 
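+
+        Example (the checkpoint path and adapter name are illustrative):
+
+        ```py
+        pipe.load_lora_weights("path/to/anima_lora.safetensors", adapter_name="anima-style")
+        ```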
+ """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") + + transformer_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")} + text_conditioner_state_dict = { + k: v for k, v in state_dict.items() if k.startswith(f"{self.text_conditioner_name}.") + } + + if transformer_state_dict: + self.load_lora_into_transformer( + transformer_state_dict, + transformer=self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + if text_conditioner_state_dict: + self.load_lora_into_text_conditioner( + text_conditioner_state_dict, + text_conditioner=self.text_conditioner, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def load_lora_into_text_conditioner( + cls, + state_dict, + text_conditioner, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + logger.info(f"Loading {cls.text_conditioner_name}.") + text_conditioner.load_lora_adapter( + state_dict, + prefix=cls.text_conditioner_name, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + def fuse_lora( + self, + components: list[str] = ["transformer", "text_conditioner"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: list[str] | None = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. 
+ """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: list[str] = ["transformer", "text_conditioner"], **kwargs): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. + """ + super().unfuse_lora(components=components, **kwargs) + + class Flux2LoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`]. diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index bb765c56d013..6afb3f672de0 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -80,6 +80,7 @@ _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.ace_step_transformer"] = ["AceStepTransformer1DModel"] + _import_structure["transformers.transformer_anima"] = ["AnimaTextConditioner"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] _import_structure["transformers.consisid_transformer_3d"] = ["ConsisIDTransformer3DModel"] @@ -213,6 +214,7 @@ from .transformers import ( AceStepTransformer1DModel, AllegroTransformer3DModel, + AnimaTextConditioner, AuraFlowTransformer2DModel, BriaFiboTransformer2DModel, BriaTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 5c64b5fc99fa..aac6941abd27 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -18,6 +18,7 @@ from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .transformer_allegro import AllegroTransformer3DModel + from .transformer_anima import AnimaTextConditioner from .transformer_bria import BriaTransformer2DModel from .transformer_bria_fibo import BriaFiboTransformer2DModel from .transformer_chroma import ChromaTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_anima.py b/src/diffusers/models/transformers/transformer_anima.py new file mode 100644 index 000000000000..1241f5e5eaad --- /dev/null +++ b/src/diffusers/models/transformers/transformer_anima.py @@ -0,0 +1,336 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...models.attention import AttentionModuleMixin +from ...models.attention_dispatch import dispatch_attention_fn +from ...models.modeling_utils import ModelMixin + + +def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states_1 = hidden_states[..., : hidden_states.shape[-1] // 2] + hidden_states_2 = hidden_states[..., hidden_states.shape[-1] // 2 :] + return torch.cat((-hidden_states_2, hidden_states_1), dim=-1) + + +def _apply_rotary_pos_emb( + hidden_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 +) -> torch.Tensor: + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return (hidden_states * cos) + (_rotate_half(hidden_states) * sin) + + +class AnimaRotaryEmbedding(nn.Module): + def __init__(self, head_dim: int, rope_theta: float = 10000.0): + super().__init__() + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float32) / head_dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = inv_freq_expanded.to(hidden_states.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = hidden_states.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=hidden_states.dtype), sin.to(dtype=hidden_states.dtype) + + +class AnimaTextConditionerAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __call__( + self, + attn: "AnimaTextConditionerAttention", + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + encoder_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + input_shape = hidden_states.shape[:-1] + encoder_input_shape = encoder_hidden_states.shape[:-1] + + query = attn.q_proj(hidden_states) + key = attn.k_proj(encoder_hidden_states) + value = attn.v_proj(encoder_hidden_states) + + query = query.view(*input_shape, attn.num_attention_heads, attn.attention_head_dim) + key = key.view(*encoder_input_shape, attn.num_attention_heads, attn.attention_head_dim) + value = value.view(*encoder_input_shape, attn.num_attention_heads, attn.attention_head_dim) + + query = attn.q_norm(query) + key = attn.k_norm(key) + + if position_embeddings is not None: + if encoder_position_embeddings is None: + raise ValueError("`encoder_position_embeddings` must be provided when using rotary embeddings.") + cos, sin = position_embeddings + query = _apply_rotary_pos_emb(query, cos, sin, unsqueeze_dim=2) + cos, sin = encoder_position_embeddings + key = _apply_rotary_pos_emb(key, cos, sin, unsqueeze_dim=2) + + 
hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3).contiguous() + hidden_states = attn.o_proj(hidden_states) + return hidden_states + + +class AnimaTextConditionerAttention(nn.Module, AttentionModuleMixin): + _default_processor_cls = AnimaTextConditionerAttnProcessor + _available_processors = [AnimaTextConditionerAttnProcessor] + _supports_qkv_fusion = False + + def __init__( + self, + query_dim: int, + context_dim: int, + num_attention_heads: int, + attention_head_dim: int, + processor: AnimaTextConditionerAttnProcessor | None = None, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.q_proj = nn.Linear(query_dim, inner_dim, bias=False) + self.q_norm = nn.RMSNorm(attention_head_dim, eps=1e-6) + self.k_proj = nn.Linear(context_dim, inner_dim, bias=False) + self.k_norm = nn.RMSNorm(attention_head_dim, eps=1e-6) + self.v_proj = nn.Linear(context_dim, inner_dim, bias=False) + self.o_proj = nn.Linear(inner_dim, query_dim, bias=False) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + encoder_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + return self.processor( + self, + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=position_embeddings, + encoder_position_embeddings=encoder_position_embeddings, + ) + + +class AnimaTextConditionerBlock(nn.Module): + def __init__( + self, + source_dim: int, + model_dim: int, + num_attention_heads: int = 16, + mlp_ratio: float = 4.0, + use_self_attention: bool = True, + use_layer_norm: bool = False, + ): + super().__init__() + self.use_self_attention = use_self_attention + norm_cls = nn.LayerNorm if use_layer_norm else nn.RMSNorm + norm_kwargs = {} if use_layer_norm else {"eps": 1e-6} + + if use_self_attention: + self.norm_self_attn = norm_cls(model_dim, **norm_kwargs) + self.self_attn = AnimaTextConditionerAttention( + query_dim=model_dim, + context_dim=model_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=model_dim // num_attention_heads, + ) + + self.norm_cross_attn = norm_cls(model_dim, **norm_kwargs) + self.cross_attn = AnimaTextConditionerAttention( + query_dim=model_dim, + context_dim=source_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=model_dim // num_attention_heads, + ) + self.norm_mlp = norm_cls(model_dim, **norm_kwargs) + self.mlp = nn.Sequential( + nn.Linear(model_dim, int(model_dim * mlp_ratio)), + nn.GELU(), + nn.Linear(int(model_dim * mlp_ratio), model_dim), + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + target_attention_mask: torch.Tensor | None = None, + source_attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + source_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + if self.use_self_attention: + norm_hidden_states = 
self.norm_self_attn(hidden_states) + attn_hidden_states = self.self_attn( + norm_hidden_states, + attention_mask=target_attention_mask, + position_embeddings=position_embeddings, + encoder_position_embeddings=position_embeddings, + ) + hidden_states = hidden_states + attn_hidden_states + + norm_hidden_states = self.norm_cross_attn(hidden_states) + attn_hidden_states = self.cross_attn( + norm_hidden_states, + attention_mask=source_attention_mask, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=position_embeddings, + encoder_position_embeddings=source_position_embeddings, + ) + hidden_states = hidden_states + attn_hidden_states + hidden_states = hidden_states + self.mlp(self.norm_mlp(hidden_states)) + return hidden_states + + +class AnimaTextConditioner(ModelMixin, ConfigMixin, PeftAdapterMixin): + r""" + Text conditioner used by Anima to map Qwen3 hidden states and T5 token ids to Cosmos text embeddings. + + Anima reuses the Cosmos Predict2 DiT. The only model-specific conditioning module is this LLM adapter, which + cross-attends from learned T5 token embeddings to Qwen3 text encoder hidden states before the diffusion loop. + `target_dim` is the conditioner output dimension and must match the transformer's `text_embed_dim`. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["AnimaTextConditionerBlock"] + + @register_to_config + def __init__( + self, + source_dim: int = 1024, + target_dim: int = 1024, + model_dim: int = 1024, + num_layers: int = 6, + num_attention_heads: int = 16, + mlp_ratio: float = 4.0, + target_vocab_size: int = 32128, + use_self_attention: bool = True, + use_layer_norm: bool = False, + min_sequence_length: int = 512, + ): + super().__init__() + self.embed = nn.Embedding(target_vocab_size, target_dim) + self.in_proj = nn.Linear(target_dim, model_dim) if model_dim != target_dim else nn.Identity() + self.rotary_emb = AnimaRotaryEmbedding(model_dim // num_attention_heads) + self.blocks = nn.ModuleList( + [ + AnimaTextConditionerBlock( + source_dim=source_dim, + model_dim=model_dim, + num_attention_heads=num_attention_heads, + mlp_ratio=mlp_ratio, + use_self_attention=use_self_attention, + use_layer_norm=use_layer_norm, + ) + for _ in range(num_layers) + ] + ) + self.out_proj = nn.Linear(model_dim, target_dim) + self.norm = nn.RMSNorm(target_dim, eps=1e-6) + self.gradient_checkpointing = False + + @staticmethod + def _prepare_attention_mask(attention_mask: torch.Tensor | None) -> torch.Tensor | None: + if attention_mask is None: + return None + attention_mask = attention_mask.to(torch.bool) + if attention_mask.ndim == 2: + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) + return attention_mask + + def forward( + self, + source_hidden_states: torch.Tensor, + target_input_ids: torch.Tensor, + target_attention_mask: torch.Tensor | None = None, + source_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + target_attention_mask = self._prepare_attention_mask(target_attention_mask) + source_attention_mask = self._prepare_attention_mask(source_attention_mask) + + hidden_states = self.embed(target_input_ids).to(dtype=source_hidden_states.dtype) + hidden_states = self.in_proj(hidden_states) + + position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) + source_position_ids = torch.arange(source_hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + source_position_embeddings = 
self.rotary_emb(hidden_states, source_position_ids) + + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + source_hidden_states, + target_attention_mask, + source_attention_mask, + position_embeddings, + source_position_embeddings, + ) + else: + hidden_states = block( + hidden_states, + source_hidden_states, + target_attention_mask=target_attention_mask, + source_attention_mask=source_attention_mask, + position_embeddings=position_embeddings, + source_position_embeddings=source_position_embeddings, + ) + + hidden_states = self.norm(self.out_proj(hidden_states)) + + if target_attention_mask is not None: + hidden_states = hidden_states * target_attention_mask.squeeze(1).squeeze(1).to(hidden_states).unsqueeze(-1) + + if hidden_states.shape[1] < self.config.min_sequence_length: + hidden_states = F.pad(hidden_states, (0, 0, 0, self.config.min_sequence_length - hidden_states.shape[1])) + + return hidden_states diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 0b2225c980b3..998f2c884d75 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -89,6 +89,10 @@ "QwenImageLayeredModularPipeline", "QwenImageLayeredAutoBlocks", ] + _import_structure["anima"] = [ + "AnimaAutoBlocks", + "AnimaModularPipeline", + ] _import_structure["ernie_image"] = [ "ErnieImageAutoBlocks", "ErnieImageModularPipeline", @@ -114,6 +118,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .anima import AnimaAutoBlocks, AnimaModularPipeline from .components_manager import ComponentsManager from .ernie_image import ErnieImageAutoBlocks, ErnieImageModularPipeline from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline diff --git a/src/diffusers/modular_pipelines/anima/__init__.py b/src/diffusers/modular_pipelines/anima/__init__.py new file mode 100644 index 000000000000..4772d906e03b --- /dev/null +++ b/src/diffusers/modular_pipelines/anima/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_anima"] = ["AnimaAutoBlocks"] + _import_structure["modular_pipeline"] = ["AnimaModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_anima import AnimaAutoBlocks + from .modular_pipeline import AnimaModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git 
a/src/diffusers/modular_pipelines/anima/before_denoise.py b/src/diffusers/modular_pipelines/anima/before_denoise.py new file mode 100644 index 000000000000..bff17fe07e95 --- /dev/null +++ b/src/diffusers/modular_pipelines/anima/before_denoise.py @@ -0,0 +1,284 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import numpy as np +import torch + +from ...models import CosmosTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import AnimaModularPipeline + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimaTextInputStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Input processing step that expands Anima prompt embeddings for the requested image batch." 
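+
+    # Mirrors the standard pipelines: each prompt embedding row is repeated
+    # `num_images_per_prompt` times so the denoiser sees one row per image.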
+ + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", CosmosTransformer3DModel)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_images_per_prompt"), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Conditioned prompt embeddings generated by the text encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Conditioned negative prompt embeddings generated by the text encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Prompt embeddings expanded to the final denoising batch.", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Negative prompt embeddings expanded to the final denoising batch.", + ), + OutputParam( + "batch_size", + type_hint=int, + description="Number of input prompts before `num_images_per_prompt` expansion.", + ), + OutputParam("dtype", type_hint=torch.dtype, description="Dtype used by the Anima denoiser."), + ] + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = components.transformer.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + return components, state + + +class AnimaPrepareLatentsStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Prepare noisy image latents and padding mask for Anima denoising." 
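+
+    # Prepared latents are (B, C, 1, H / vae_scale_factor, W / vae_scale_factor);
+    # the singleton frame axis is kept because the Cosmos DiT is a video model.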
+ + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", CosmosTransformer3DModel)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("latents"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of input prompts before `num_images_per_prompt` expansion.", + ), + InputParam("dtype", type_hint=torch.dtype, description="Dtype used by the Anima denoiser."), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("height", type_hint=int, description="Image height used for generation."), + OutputParam("width", type_hint=int, description="Image width used for generation."), + OutputParam("latents", type_hint=torch.Tensor, description="Noisy latents for the denoising process."), + OutputParam("padding_mask", type_hint=torch.Tensor, description="Cosmos padding mask for image latents."), + ] + + def check_inputs(self, components: AnimaModularPipeline, block_state): + divisor = components.vae_scale_factor * 2 + if block_state.height % divisor != 0 or block_state.width % divisor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {divisor} but are {block_state.height} and" + f" {block_state.width}." + ) + + @staticmethod + def prepare_latents( + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + vae_scale_factor: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | list[torch.Generator] | None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + latent_height = height // vae_scale_factor + latent_width = width // vae_scale_factor + shape = (batch_size, num_channels_latents, 1, latent_height, latent_width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + return randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + self.check_inputs(components, block_state) + + device = components._execution_device + block_state.latents = self.prepare_latents( + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_channels_latents=components.num_channels_latents, + height=block_state.height, + width=block_state.width, + vae_scale_factor=components.vae_scale_factor, + dtype=torch.float32, + device=device, + generator=block_state.generator, + latents=block_state.latents, + ) + block_state.padding_mask = block_state.latents.new_zeros( + 1, 1, block_state.height, block_state.width, dtype=block_state.dtype + ) + + self.set_block_state(state, block_state) + return components, state + + +class AnimaSetTimestepsStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Set the scheduler timesteps for Anima inference." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for the denoising loop."), + OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps."), + ] + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + sigmas = ( + np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + if block_state.sigmas is None + else block_state.sigmas + ) + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + device=device, + sigmas=sigmas, + ) + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/anima/decoders.py b/src/diffusers/modular_pipelines/anima/decoders.py new file mode 100644 index 000000000000..f1f4b475a4b8 --- /dev/null +++ b/src/diffusers/modular_pipelines/anima/decoders.py @@ -0,0 +1,120 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
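+
+# Decoding mirrors the Qwen-Image pipeline: latents are de-normalized with the
+# VAE's per-channel `latents_mean`/`latents_std` before calling `vae.decode`.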
+ +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLQwenImage +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import AnimaModularPipeline + + +class AnimaVaeDecoderStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Step that decodes Anima latents into image tensors." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKLQwenImage)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor, description="Denoised Anima latents."), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images", note="tensor output of the VAE decoder")] + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latents = block_state.latents.to(components.vae.dtype) + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + block_state.images = components.vae.decode(latents, return_dict=False)[0][:, :, 0] + + self.set_block_state(state, block_state) + return components, state + + +class AnimaProcessImagesOutputStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Postprocess decoded Anima image tensors." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("images", required=True, type_hint=torch.Tensor, description="Decoded Anima image tensors."), + InputParam.template("output_type"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "images", + type_hint=list[PIL.Image.Image] | np.ndarray | torch.Tensor, + description="Generated images.", + ) + ] + + @staticmethod + def check_inputs(output_type): + if output_type not in ["pil", "np", "pt"]: + raise ValueError(f"Invalid output_type: {output_type}") + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state.output_type) + + block_state.images = components.image_processor.postprocess( + image=block_state.images, + output_type=block_state.output_type, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/anima/denoise.py b/src/diffusers/modular_pipelines/anima/denoise.py new file mode 100644 index 000000000000..2c7f36484749 --- /dev/null +++ b/src/diffusers/modular_pipelines/anima/denoise.py @@ -0,0 +1,214 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import CosmosTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam +from .modular_pipeline import AnimaModularPipeline + + +class AnimaLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Step within the denoising loop that prepares Anima latent and timestep inputs." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor, description="Current Anima latents."), + InputParam("dtype", required=True, type_hint=torch.dtype, description="Dtype used by the Anima denoiser."), + ] + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = block_state.latents.to(block_state.dtype) + + timestep = t.expand(block_state.latents.shape[0]).to(block_state.dtype) + block_state.timestep = timestep / components.scheduler.config.num_train_timesteps + return components, block_state + + +class AnimaLoopDenoiser(ModularPipelineBlocks): + model_name = "anima" + + def __init__( + self, + guider_input_fields: dict[str, Any] | None = None, + ): + if guider_input_fields is None: + guider_input_fields = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")} + if not isinstance(guider_input_fields, dict): + raise ValueError(f"`guider_input_fields` must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", CosmosTransformer3DModel), + ] + + @property + def description(self) -> str: + return "Step within the denoising loop that predicts Anima noise with guidance." 
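+
+    # Classifier-free guidance is realized by iterating over the guider_state
+    # batches (conditional and, when enabled, unconditional) and letting the
+    # guider combine the per-batch noise predictions into one estimate.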
+ + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="Number of denoising steps.", + ), + InputParam( + "padding_mask", + required=True, + type_hint=torch.Tensor, + description="Cosmos padding mask for image latents.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description="The conditional model inputs for the Anima denoiser.", + ), + ] + + guider_input_names = [] + uncond_guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.append(value[0]) + uncond_guider_input_names.append(value[1]) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True)) + for name in uncond_guider_input_names: + inputs.append(InputParam(name=name)) + return inputs + + @torch.no_grad() + def __call__( + self, components: AnimaModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = { + key: getattr(guider_state_batch, key).to(block_state.dtype) + for key in self._guider_input_fields.keys() + } + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + timestep=block_state.timestep, + padding_mask=block_state.padding_mask, + return_dict=False, + **cond_kwargs, + )[0] + components.guider.cleanup_models(components.transformer) + + block_state.noise_pred = components.guider(guider_state)[0] + return components, block_state + + +class AnimaLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "anima" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step within the denoising loop that updates Anima latents." + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, t, block_state.latents, return_dict=False + )[0] + if block_state.latents.dtype != latents_dtype and torch.backends.mps.is_available(): + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class AnimaDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Pipeline block that iteratively denoises Anima latents over scheduler timesteps." 
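+
+ # Subclasses supply the per-step sub-blocks via `block_classes`; `loop_step` runs them
+ # in order for every scheduler timestep (see `AnimaDenoiseStep` below).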
+ + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam("timesteps", required=True, type_hint=torch.Tensor, description="Timesteps to denoise over."), + InputParam("num_inference_steps", required=True, type_hint=int, description="Number of denoising steps."), + ] + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + num_warmup_steps = len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + return components, state + + +class AnimaDenoiseStep(AnimaDenoiseLoopWrapper): + block_classes = [ + AnimaLoopBeforeDenoiser, + AnimaLoopDenoiser( + guider_input_fields={"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")} + ), + AnimaLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return "Denoise step that iteratively denoises image latents for Anima." diff --git a/src/diffusers/modular_pipelines/anima/encoders.py b/src/diffusers/modular_pipelines/anima/encoders.py new file mode 100644 index 000000000000..3462ee73cfa3 --- /dev/null +++ b/src/diffusers/modular_pipelines/anima/encoders.py @@ -0,0 +1,225 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers import Qwen2Tokenizer, Qwen3Model, T5TokenizerFast + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import AnimaTextConditioner +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import AnimaModularPipeline + + +class AnimaTextEncoderStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Text encoder step that maps prompts to Cosmos text conditioning for Anima." 
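+
+ # Anima conditions on two token streams: the Qwen tokenizer/Qwen3Model pair produces
+ # source hidden states, the T5 tokenizer supplies target token ids, and
+ # `AnimaTextConditioner` fuses both into the `prompt_embeds` consumed by the denoiser.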
+ + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3Model), + ComponentSpec("tokenizer", Qwen2Tokenizer), + ComponentSpec("t5_tokenizer", T5TokenizerFast), + ComponentSpec("text_conditioner", AnimaTextConditioner), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam.template("max_sequence_length"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Conditioned prompt embeddings generated by the Anima text conditioner.", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Conditioned negative prompt embeddings generated by the Anima text conditioner.", + ), + ] + + @staticmethod + def check_inputs(block_state): + if not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + if block_state.max_sequence_length is not None and block_state.max_sequence_length > 4096: + raise ValueError(f"`max_sequence_length` cannot be greater than 4096 but is {block_state.max_sequence_length}") + + @staticmethod + def _get_qwen_prompt_embeds( + components: AnimaModularPipeline, + prompt: str | list[str], + max_sequence_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = components.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + prompt_attention_mask = text_inputs.attention_mask.to(device) + if text_input_ids.shape[-1] == 0: + text_input_ids = text_input_ids.new_zeros((text_input_ids.shape[0], 1)) + prompt_attention_mask = prompt_attention_mask.new_zeros((prompt_attention_mask.shape[0], 1)) + + prompt_embeds = components.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=False, + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = prompt_embeds * prompt_attention_mask.to(prompt_embeds).unsqueeze(-1) + + return prompt_embeds, prompt_attention_mask + + @staticmethod + def _get_t5_prompt_ids( + components: AnimaModularPipeline, + prompt: str | list[str], + max_sequence_length: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = components.t5_tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + return text_inputs.input_ids.to(device), text_inputs.attention_mask.to(device) + + @classmethod + def encode_prompt( + cls, + components: AnimaModularPipeline, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + prepare_unconditional_embeds: bool = True, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + device = device 
or components._execution_device
+ dtype = dtype or components.text_conditioner.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ qwen_prompt_embeds, qwen_attention_mask = cls._get_qwen_prompt_embeds(
+ components=components,
+ prompt=prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ t5_input_ids, t5_attention_mask = cls._get_t5_prompt_ids(
+ components=components,
+ prompt=prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ prompt_embeds = components.text_conditioner(
+ source_hidden_states=qwen_prompt_embeds,
+ target_input_ids=t5_input_ids,
+ target_attention_mask=t5_attention_mask,
+ source_attention_mask=qwen_attention_mask,
+ )
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = None
+ if prepare_unconditional_embeds:
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ if batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_qwen_prompt_embeds, negative_qwen_attention_mask = cls._get_qwen_prompt_embeds(
+ components=components,
+ prompt=negative_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ negative_t5_input_ids, negative_t5_attention_mask = cls._get_t5_prompt_ids(
+ components=components,
+ prompt=negative_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ negative_prompt_embeds = components.text_conditioner(
+ source_hidden_states=negative_qwen_prompt_embeds,
+ target_input_ids=negative_t5_input_ids,
+ target_attention_mask=negative_t5_attention_mask,
+ source_attention_mask=negative_qwen_attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ @torch.no_grad()
+ def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ self.check_inputs(block_state)
+
+ block_state.prompt_embeds, block_state.negative_prompt_embeds = self.encode_prompt(
+ components=components,
+ prompt=block_state.prompt,
+ negative_prompt=block_state.negative_prompt,
+ prepare_unconditional_embeds=components.guider.num_conditions > 1,
+ max_sequence_length=block_state.max_sequence_length,
+ device=components._execution_device,
+ dtype=components.transformer.dtype,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state diff --git a/src/diffusers/modular_pipelines/anima/modular_blocks_anima.py b/src/diffusers/modular_pipelines/anima/modular_blocks_anima.py new file mode 100644 index 000000000000..af966428cc71 --- /dev/null +++ b/src/diffusers/modular_pipelines/anima/modular_blocks_anima.py @@ -0,0 +1,173 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import AnimaPrepareLatentsStep, AnimaSetTimestepsStep, AnimaTextInputStep +from .decoders import AnimaProcessImagesOutputStep, AnimaVaeDecoderStep +from .denoise import AnimaDenoiseStep +from .encoders import AnimaTextEncoderStep + + +# auto_docstring +class AnimaCoreDenoiseStep(SequentialPipelineBlocks): + """ + Denoise block that takes encoded Anima text conditions and runs the denoising process. + + Components: + transformer (`CosmosTransformer3DModel`) + scheduler (`FlowMatchEulerDiscreteScheduler`) + guider (`ClassifierFreeGuidance`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + Conditioned prompt embeddings generated by the text encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Conditioned negative prompt embeddings generated by the text encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + **denoiser_input_fields (`None`, *optional*): + The conditional model inputs for the Anima denoiser. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + block_classes = [ + AnimaTextInputStep, + AnimaPrepareLatentsStep, + AnimaSetTimestepsStep, + AnimaDenoiseStep, + ] + block_names = ["input", "prepare_latents", "set_timesteps", "denoise"] + + @property + def description(self) -> str: + return "Denoise block that takes encoded Anima text conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class AnimaDecodeStep(SequentialPipelineBlocks): + """ + Decode Anima latents into generated images. + + Components: + vae (`AutoencoderKLQwenImage`) + image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + Denoised Anima latents. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + block_classes = [AnimaVaeDecoderStep, AnimaProcessImagesOutputStep] + block_names = ["decode", "postprocess"] + + @property + def description(self) -> str: + return "Decode Anima latents into generated images." + + @property + def outputs(self): + return [OutputParam.template("images")] + + +# auto_docstring +class AnimaAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image generation using Anima. 
+
+ Supported workflows:
+ - `text2image`: requires `prompt`
+
+ Components:
+ text_encoder (`Qwen3Model`)
+ tokenizer (`Qwen2Tokenizer`)
+ t5_tokenizer (`T5TokenizerFast`)
+ text_conditioner (`AnimaTextConditioner`)
+ guider (`ClassifierFreeGuidance`)
+ transformer (`CosmosTransformer3DModel`)
+ scheduler (`FlowMatchEulerDiscreteScheduler`)
+ vae (`AutoencoderKLQwenImage`)
+ image_processor (`VaeImageProcessor`)
+
+ Inputs:
+ prompt (`str`):
+ The prompt or prompts to guide image generation.
+ negative_prompt (`str`, *optional*):
+ The prompt or prompts not to guide the image generation.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ Maximum sequence length for prompt encoding.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*):
+ The height in pixels of the generated image.
+ width (`int`, *optional*):
+ The width in pixels of the generated image.
+ latents (`Tensor`, *optional*):
+ Pre-generated noisy latents for image generation.
+ generator (`Generator`, *optional*):
+ Torch generator for deterministic generation.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps.
+ sigmas (`list`, *optional*):
+ Custom sigmas for the denoising process.
+ **denoiser_input_fields (`None`, *optional*):
+ The conditional model inputs for the Anima denoiser.
+ output_type (`str`, *optional*, defaults to pil):
+ Output format: 'pil', 'np', 'pt'.
+
+ Outputs:
+ images (`list`):
+ Generated images.
+ """
+
+ block_classes = [
+ AnimaTextEncoderStep,
+ AnimaCoreDenoiseStep,
+ AnimaDecodeStep,
+ ]
+ block_names = ["text_encoder", "denoise", "decode"]
+ _workflow_map = {"text2image": {"prompt": True}}
+
+ @property
+ def description(self) -> str:
+ return "Auto Modular pipeline for text-to-image generation using Anima."
+
+ @property
+ def outputs(self):
+ return [OutputParam.template("images")] diff --git a/src/diffusers/modular_pipelines/anima/modular_pipeline.py b/src/diffusers/modular_pipelines/anima/modular_pipeline.py new file mode 100644 index 000000000000..44fce4657c6f --- /dev/null +++ b/src/diffusers/modular_pipelines/anima/modular_pipeline.py @@ -0,0 +1,52 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +
+from ...loaders import AnimaLoraLoaderMixin
+from ..modular_pipeline import ModularPipeline
+
+
+class AnimaModularPipeline(ModularPipeline, AnimaLoraLoaderMixin):
+ """
+ A ModularPipeline for Anima.
+
+ > [!WARNING]
+ > This is an experimental feature and is likely to change in the future.
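+
+ Example:
+ ```py
+ # Usage sketch only: the checkpoint path below is a placeholder, and the
+ # loading pattern mirrors this PR's modular pipeline tests.
+ import torch
+
+ from diffusers import AnimaModularPipeline
+
+ pipe = AnimaModularPipeline.from_pretrained("path/to/anima-modular-checkpoint")
+ pipe.load_components(torch_dtype=torch.bfloat16)
+ pipe.to("cuda")
+
+ image = pipe(prompt="a photo of an astronaut riding a horse").images[0]
+ ```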
+ """ + + default_blocks_name = "AnimaAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels + return num_channels_latents diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 8cfe07059272..37e069bc05f8 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -130,6 +130,7 @@ def _helios_pyramid_map_fn(config_dict=None): ("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")), ("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")), ("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")), + ("anima", _create_default_map_fn("AnimaModularPipeline")), ("z-image", _create_default_map_fn("ZImageModularPipeline")), ("helios", _create_default_map_fn("HeliosModularPipeline")), ("helios-pyramid", _helios_pyramid_map_fn), diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 9bfb73c1999e..bc9856604928 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -435,6 +435,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AnimaTextConditioner(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AsymmetricAutoencoderKL(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index cfa1318783f3..d320e2374f50 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,36 @@ from ..utils import DummyObject, requires_backends +class AnimaAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AnimaModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ErnieImageAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git 
a/tests/modular_pipelines/anima/__init__.py b/tests/modular_pipelines/anima/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/tests/modular_pipelines/anima/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/modular_pipelines/anima/test_modular_pipeline_anima.py b/tests/modular_pipelines/anima/test_modular_pipeline_anima.py new file mode 100644 index 000000000000..38fc0d3649d6 --- /dev/null +++ b/tests/modular_pipelines/anima/test_modular_pipeline_anima.py @@ -0,0 +1,230 @@ +# Copyright 2026 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import torch +from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model, T5TokenizerFast + +from diffusers import ( + AnimaAutoBlocks, + AnimaModularPipeline, + AnimaTextConditioner, + AutoencoderKLQwenImage, + CosmosTransformer3DModel, + FlowMatchEulerDiscreteScheduler, +) + +from ...testing_utils import enable_full_determinism, require_peft_backend +from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin + + +enable_full_determinism() + + +ANIMA_TEXT2IMAGE_WORKFLOWS = { + "text2image": [ + ("text_encoder", "AnimaTextEncoderStep"), + ("denoise.input", "AnimaTextInputStep"), + ("denoise.prepare_latents", "AnimaPrepareLatentsStep"), + ("denoise.set_timesteps", "AnimaSetTimestepsStep"), + ("denoise.denoise", "AnimaDenoiseStep"), + ("decode.decode", "AnimaVaeDecoderStep"), + ("decode.postprocess", "AnimaProcessImagesOutputStep"), + ], +} + + +def get_dummy_components(): + torch.manual_seed(0) + transformer = CosmosTransformer3DModel( + in_channels=4, + out_channels=4, + num_attention_heads=2, + attention_head_dim=16, + num_layers=2, + mlp_ratio=2, + text_embed_dim=16, + adaln_lora_dim=4, + max_size=(4, 32, 32), + patch_size=(1, 2, 2), + rope_scale=(1.0, 4.0, 4.0), + concat_padding_mask=True, + extra_pos_embed_type=None, + ) + + torch.manual_seed(0) + vae = AutoencoderKLQwenImage( + base_dim=24, + z_dim=4, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + ) + + torch.manual_seed(0) + text_conditioner = AnimaTextConditioner( + source_dim=16, + target_dim=16, + model_dim=16, + num_layers=2, + num_attention_heads=4, + target_vocab_size=32128, + min_sequence_length=16, + ) + + torch.manual_seed(0) + text_encoder_config = Qwen3Config( + vocab_size=152064, + hidden_size=16, + intermediate_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + max_position_embeddings=128, + rms_norm_eps=1e-6, + rope_theta=1000000.0, + head_dim=4, + attention_bias=False, + ) + text_encoder = Qwen3Model(text_encoder_config).eval() + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + t5_tokenizer = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5") + scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) + + return { + "transformer": transformer, + "vae": vae, + 
"scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "t5_tokenizer": t5_tokenizer, + "text_conditioner": text_conditioner, + } + + +class AnimaTextConditionerFastTests(unittest.TestCase): + def test_conditioner_output_shape_and_padding(self): + conditioner = AnimaTextConditioner( + source_dim=16, + target_dim=16, + model_dim=16, + num_layers=2, + num_attention_heads=4, + target_vocab_size=128, + min_sequence_length=8, + ) + source_hidden_states = torch.randn(2, 5, 16) + target_input_ids = torch.randint(0, 128, (2, 4)) + source_attention_mask = torch.ones(2, 5) + target_attention_mask = torch.ones(2, 4) + target_attention_mask[1, -1] = 0 + + output = conditioner( + source_hidden_states=source_hidden_states, + target_input_ids=target_input_ids, + source_attention_mask=source_attention_mask, + target_attention_mask=target_attention_mask, + ) + + self.assertEqual(output.shape, (2, 8, 16)) + self.assertTrue(torch.allclose(output[1, 3], torch.zeros_like(output[1, 3]), atol=1e-5)) + self.assertTrue(torch.allclose(output[:, 4:], torch.zeros_like(output[:, 4:]), atol=1e-5)) + + +class TestAnimaModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin): + pipeline_class = AnimaModularPipeline + pipeline_blocks_class = AnimaAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-anima-modular-pipe" + params = frozenset(["prompt", "height", "width", "negative_prompt"]) + batch_params = frozenset(["prompt", "negative_prompt"]) + expected_workflow_blocks = ANIMA_TEXT2IMAGE_WORKFLOWS + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipe = self.pipeline_blocks_class().init_pipeline(components_manager=components_manager) + pipe.update_components(**get_dummy_components()) + pipe.to(dtype=torch_dtype) + pipe.set_progress_bar_config(disable=None) + return pipe + + def get_dummy_inputs(self, seed=0): + generator = torch.Generator(device="cpu").manual_seed(seed) + return { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + + def test_inference_empty_negative_prompt(self): + pipe = self.get_pipeline() + + inputs = self.get_dummy_inputs() + inputs["negative_prompt"] = "" + output = pipe(**inputs).images + + assert output.shape == (1, 3, 32, 32) + assert not torch.isnan(output).any() + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=5e-4) + + def test_save_load_components(self): + pipe = self.get_pipeline() + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=True) + pipe = self.pipeline_class.from_pretrained(tmpdir) + pipe.load_components() + + assert isinstance(pipe.text_conditioner, AnimaTextConditioner) + assert isinstance(pipe.transformer, CosmosTransformer3DModel) + + def test_lora_state_dict_conversion(self): + state_dict = { + "diffusion_model.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 32), + "diffusion_model.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(32, 2), + "diffusion_model.blocks.0.adaln_modulation_cross_attn.1.lora_A.weight": torch.randn(2, 32), + "diffusion_model.blocks.0.adaln_modulation_cross_attn.1.lora_B.weight": torch.randn(4, 2), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 16), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(16, 
2), + } + + converted_state_dict = self.pipeline_class.lora_state_dict(state_dict) + + assert "transformer.transformer_blocks.0.attn1.to_q.lora_A.weight" in converted_state_dict + assert "transformer.transformer_blocks.0.norm2.linear_1.lora_B.weight" in converted_state_dict + assert "text_conditioner.blocks.0.self_attn.q_proj.lora_A.weight" in converted_state_dict + + @require_peft_backend + def test_load_lora_weights(self): + pipe = self.get_pipeline() + state_dict = { + "diffusion_model.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 32), + "diffusion_model.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(32, 2), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 16), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(16, 2), + } + + pipe.load_lora_weights(state_dict, adapter_name="dummy") + + assert "dummy" in pipe.transformer.peft_config + assert "dummy" in pipe.text_conditioner.peft_config
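
The two LoRA tests above pin down the key mapping performed during conversion: `diffusion_model.*` keys land on the `transformer` component (`blocks` → `transformer_blocks`, `self_attn.q_proj` → `attn1.to_q`, `adaln_modulation_cross_attn.1` → `norm2.linear_1`), while `diffusion_model.llm_adapter.*` keys land on `text_conditioner` with only the prefix rewritten. Below is a minimal standalone sketch of such a renaming pass, consistent with those assertions; the helper name and the replacement table are illustrative assumptions, not the PR's `AnimaLoraLoaderMixin` implementation.

```python
import torch


# Hypothetical helper for illustration only; the real logic lives in
# AnimaLoraLoaderMixin.lora_state_dict.
def rename_anima_lora_key(key: str) -> str:
    # Adapter weights belong to the text_conditioner component and keep
    # their inner structure unchanged.
    if key.startswith("diffusion_model.llm_adapter."):
        return key.replace("diffusion_model.llm_adapter.", "text_conditioner.", 1)
    # Everything else belongs to the transformer component.
    key = key.replace("diffusion_model.", "transformer.", 1)
    key = key.replace(".blocks.", ".transformer_blocks.", 1)
    for old, new in {
        ".self_attn.q_proj.": ".attn1.to_q.",
        ".adaln_modulation_cross_attn.1.": ".norm2.linear_1.",
    }.items():
        key = key.replace(old, new)
    return key


state_dict = {
    "diffusion_model.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 32),
    "diffusion_model.blocks.0.adaln_modulation_cross_attn.1.lora_B.weight": torch.randn(4, 2),
    "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 16),
}
converted = {rename_anima_lora_key(k): v for k, v in state_dict.items()}
assert "transformer.transformer_blocks.0.attn1.to_q.lora_A.weight" in converted
assert "transformer.transformer_blocks.0.norm2.linear_1.lora_B.weight" in converted
assert "text_conditioner.blocks.0.self_attn.q_proj.lora_A.weight" in converted
```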