From f13a80089e3819d8ec202cf2e168479883115b9e Mon Sep 17 00:00:00 2001 From: rmatif Date: Mon, 11 May 2026 17:40:00 +0200 Subject: [PATCH 1/6] Add Anima pipeline --- docs/source/en/_toctree.yml | 8 +- docs/source/en/api/pipelines/anima.md | 21 + scripts/convert_anima_to_diffusers.py | 313 ++++++++++ src/diffusers/__init__.py | 4 + src/diffusers/loaders/__init__.py | 2 + .../loaders/lora_conversion_utils.py | 46 ++ src/diffusers/loaders/lora_pipeline.py | 203 +++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/anima/__init__.py | 48 ++ .../pipelines/anima/modeling_anima.py | 295 +++++++++ .../pipelines/anima/pipeline_anima.py | 566 ++++++++++++++++++ tests/pipelines/anima/__init__.py | 1 + tests/pipelines/anima/test_anima.py | 201 +++++++ 13 files changed, 1707 insertions(+), 3 deletions(-) create mode 100644 docs/source/en/api/pipelines/anima.md create mode 100644 scripts/convert_anima_to_diffusers.py create mode 100644 src/diffusers/pipelines/anima/__init__.py create mode 100644 src/diffusers/pipelines/anima/modeling_anima.py create mode 100644 src/diffusers/pipelines/anima/pipeline_anima.py create mode 100644 tests/pipelines/anima/__init__.py create mode 100644 tests/pipelines/anima/test_anima.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8e8776d4a8c2..1947e11d799e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -499,9 +499,11 @@ - local: api/pipelines/stable_audio title: Stable Audio title: Audio - - sections: - - local: api/pipelines/animatediff - title: AnimateDiff + - sections: + - local: api/pipelines/anima + title: Anima + - local: api/pipelines/animatediff + title: AnimateDiff - local: api/pipelines/aura_flow title: AuraFlow - local: api/pipelines/bria_3_2 diff --git a/docs/source/en/api/pipelines/anima.md b/docs/source/en/api/pipelines/anima.md new file mode 100644 index 000000000000..fdb0a84d2967 --- /dev/null +++ b/docs/source/en/api/pipelines/anima.md @@ -0,0 +1,21 @@ +# Anima + +Anima is a text-to-image model that reuses the [`CosmosTransformer3DModel`] with a Qwen3 text encoder, a T5-token text conditioner, and the [`AutoencoderKLQwenImage`] VAE. + +```python +import torch +from diffusers import AnimaPipeline + +pipe = AnimaPipeline.from_pretrained("path/to/anima-diffusers", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +image = pipe("A cinematic portrait of a woman in a rain-soaked city street").images[0] +``` + +## AnimaPipeline + +[[autodoc]] AnimaPipeline + +## AnimaTextConditioner + +[[autodoc]] AnimaTextConditioner diff --git a/scripts/convert_anima_to_diffusers.py b/scripts/convert_anima_to_diffusers.py new file mode 100644 index 000000000000..26cd8c7f029a --- /dev/null +++ b/scripts/convert_anima_to_diffusers.py @@ -0,0 +1,313 @@ +""" +Convert Anima checkpoints to Diffusers format. 
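+
+The source checkpoint stores the Cosmos DiT and the Anima `llm_adapter` (text conditioner) in a
+single safetensors file. This script splits the two apart, remaps the Qwen-Image VAE weights to
+the Diffusers layout, and, when `--save_pipeline` is passed, assembles a complete `AnimaPipeline`.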
+ +Example: +```bash +python scripts/convert_anima_to_diffusers.py \ + --transformer_ckpt_path anima_model/anima-preview3-base.safetensors \ + --text_encoder_ckpt_path anima_model/qwen_3_06b_base.safetensors \ + --vae_ckpt_path anima_model/qwen_image_vae.safetensors \ + --qwen_tokenizer_path /home/user/Dev/ComfyUI/comfy/text_encoders/qwen25_tokenizer \ + --t5_tokenizer_path /home/user/Dev/ComfyUI/comfy/text_encoders/t5_tokenizer \ + --output_path anima_model/anima-preview3-diffusers \ + --save_pipeline +``` +""" + +import argparse +import pathlib +import sys +from typing import Any + +import torch +from accelerate import init_empty_weights +from convert_cosmos_to_diffusers import convert_transformer +from safetensors.torch import load_file +from transformers import AutoTokenizer, Qwen3Config, Qwen3Model, T5TokenizerFast + +from diffusers import ( + AnimaPipeline, + AnimaTextConditioner, + AutoencoderKLQwenImage, + FlowMatchEulerDiscreteScheduler, +) + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +def rename_residual_key(key: str) -> str: + replacements = { + ".residual.0.": ".norm1.", + ".residual.2.": ".conv1.", + ".residual.3.": ".norm2.", + ".residual.6.": ".conv2.", + ".shortcut.": ".conv_shortcut.", + } + for old, new in replacements.items(): + key = key.replace(old, new) + return key + + +def rename_mid_key(key: str) -> str: + replacements = { + ".middle.0.": ".mid_block.resnets.0.", + ".middle.1.": ".mid_block.attentions.0.", + ".middle.2.": ".mid_block.resnets.1.", + } + for old, new in replacements.items(): + key = key.replace(old, new) + return rename_residual_key(key) + + +def rename_decoder_upsample_key(key: str) -> str: + prefix = "decoder.upsamples." + suffix = key.removeprefix(prefix) + index_str, rest = suffix.split(".", 1) + index = int(index_str) + + if index in (3, 7, 11): + block_index = (index - 3) // 4 + new_key = f"decoder.up_blocks.{block_index}.upsamplers.0.{rest}" + else: + block_index = index // 4 + resnet_index = index % 4 + new_key = f"decoder.up_blocks.{block_index}.resnets.{resnet_index}.{rest}" + + return rename_residual_key(new_key) + + +def convert_qwen_image_vae_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + converted_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("conv1."): + new_key = key.replace("conv1.", "quant_conv.", 1) + elif key.startswith("conv2."): + new_key = key.replace("conv2.", "post_quant_conv.", 1) + elif key.startswith("encoder.conv1."): + new_key = key.replace("encoder.conv1.", "encoder.conv_in.", 1) + elif key.startswith("decoder.conv1."): + new_key = key.replace("decoder.conv1.", "decoder.conv_in.", 1) + elif key.startswith("encoder.downsamples."): + new_key = rename_residual_key(key.replace("encoder.downsamples.", "encoder.down_blocks.", 1)) + elif key.startswith("decoder.upsamples."): + new_key = rename_decoder_upsample_key(key) + elif key.startswith("encoder.middle.") or key.startswith("decoder.middle."): + new_key = rename_mid_key(key) + elif key.startswith("encoder.head.0."): + new_key = key.replace("encoder.head.0.", "encoder.norm_out.", 1) + elif key.startswith("encoder.head.2."): + new_key = key.replace("encoder.head.2.", "encoder.conv_out.", 1) + elif key.startswith("decoder.head.0."): + new_key = key.replace("decoder.head.0.", "decoder.norm_out.", 1) + elif key.startswith("decoder.head.2."): + new_key = key.replace("decoder.head.2.", "decoder.conv_out.", 1) + else: + new_key = rename_residual_key(key) + + if 
new_key in converted_state_dict: + raise ValueError(f"Duplicate converted VAE key: {new_key}") + converted_state_dict[new_key] = value + + return converted_state_dict + + +def convert_qwen_image_vae(state_dict: dict[str, torch.Tensor]) -> AutoencoderKLQwenImage: + converted_state_dict = convert_qwen_image_vae_state_dict(state_dict) + with init_empty_weights(): + vae = AutoencoderKLQwenImage() + + expected_keys = set(vae.state_dict().keys()) + converted_keys = set(converted_state_dict.keys()) + missing_keys = expected_keys - converted_keys + unexpected_keys = converted_keys - expected_keys + if missing_keys or unexpected_keys: + if missing_keys: + print(f"ERROR: missing VAE keys ({len(missing_keys)}):", file=sys.stderr) + for key in sorted(missing_keys): + print(key, file=sys.stderr) + if unexpected_keys: + print(f"ERROR: unexpected VAE keys ({len(unexpected_keys)}):", file=sys.stderr) + for key in sorted(unexpected_keys): + print(key, file=sys.stderr) + sys.exit(1) + + vae.load_state_dict(converted_state_dict, strict=True, assign=True) + return vae + + +def infer_text_conditioner_config(state_dict: dict[str, torch.Tensor]) -> dict[str, Any]: + model_dim = state_dict["blocks.0.self_attn.q_proj.weight"].shape[0] + source_dim = state_dict["blocks.0.cross_attn.k_proj.weight"].shape[1] + target_vocab_size, target_dim = state_dict["embed.weight"].shape + attention_head_dim = state_dict["blocks.0.self_attn.q_norm.weight"].shape[0] + num_layers = 1 + max(int(key.split(".")[1]) for key in state_dict if key.startswith("blocks.")) + + return { + "source_dim": source_dim, + "target_dim": target_dim, + "model_dim": model_dim, + "num_layers": num_layers, + "num_attention_heads": model_dim // attention_head_dim, + "target_vocab_size": target_vocab_size, + } + + +def convert_text_conditioner(state_dict: dict[str, torch.Tensor]) -> AnimaTextConditioner: + config = infer_text_conditioner_config(state_dict) + with init_empty_weights(): + text_conditioner = AnimaTextConditioner(**config) + + expected_keys = set(text_conditioner.state_dict().keys()) + converted_keys = set(state_dict.keys()) + missing_keys = expected_keys - converted_keys + unexpected_keys = converted_keys - expected_keys + if missing_keys or unexpected_keys: + if missing_keys: + print(f"ERROR: missing text conditioner keys ({len(missing_keys)}):", file=sys.stderr) + for key in sorted(missing_keys): + print(key, file=sys.stderr) + if unexpected_keys: + print(f"ERROR: unexpected text conditioner keys ({len(unexpected_keys)}):", file=sys.stderr) + for key in sorted(unexpected_keys): + print(key, file=sys.stderr) + sys.exit(1) + + text_conditioner.load_state_dict(state_dict, strict=True, assign=True) + return text_conditioner + + +def infer_qwen3_config(state_dict: dict[str, torch.Tensor]) -> Qwen3Config: + vocab_size, hidden_size = state_dict["embed_tokens.weight"].shape + intermediate_size = state_dict["layers.0.mlp.gate_proj.weight"].shape[0] + num_hidden_layers = 1 + max(int(key.split(".")[1]) for key in state_dict if key.startswith("layers.")) + head_dim = state_dict["layers.0.self_attn.q_norm.weight"].shape[0] + num_attention_heads = state_dict["layers.0.self_attn.q_proj.weight"].shape[0] // head_dim + num_key_value_heads = state_dict["layers.0.self_attn.k_proj.weight"].shape[0] // head_dim + + return Qwen3Config( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + 
max_position_embeddings=32768, + rms_norm_eps=1e-6, + rope_theta=1000000.0, + head_dim=head_dim, + attention_bias=False, + tie_word_embeddings=False, + ) + + +def convert_text_encoder(state_dict: dict[str, torch.Tensor]) -> Qwen3Model: + state_dict = {key.removeprefix("model."): value for key, value in state_dict.items()} + config = infer_qwen3_config(state_dict) + with init_empty_weights(): + text_encoder = Qwen3Model(config) + + expected_keys = set(text_encoder.state_dict().keys()) + converted_keys = set(state_dict.keys()) + missing_keys = expected_keys - converted_keys + unexpected_keys = converted_keys - expected_keys + if missing_keys or unexpected_keys: + if missing_keys: + print(f"ERROR: missing Qwen3 keys ({len(missing_keys)}):", file=sys.stderr) + for key in sorted(missing_keys): + print(key, file=sys.stderr) + if unexpected_keys: + print(f"ERROR: unexpected Qwen3 keys ({len(unexpected_keys)}):", file=sys.stderr) + for key in sorted(unexpected_keys): + print(key, file=sys.stderr) + sys.exit(1) + + text_encoder.load_state_dict(state_dict, strict=True, assign=True) + return text_encoder + + +def split_anima_transformer_checkpoint( + state_dict: dict[str, torch.Tensor], +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + transformer_state_dict = {} + text_conditioner_state_dict = {} + adapter_prefix = "net.llm_adapter." + + for key, value in state_dict.items(): + if key.startswith(adapter_prefix): + text_conditioner_state_dict[key.removeprefix(adapter_prefix)] = value + else: + transformer_state_dict[key] = value + + return transformer_state_dict, text_conditioner_state_dict + + +def save_pipeline(args, transformer, text_conditioner, text_encoder, vae): + tokenizer = AutoTokenizer.from_pretrained(args.qwen_tokenizer_path) + t5_tokenizer = T5TokenizerFast.from_pretrained(args.t5_tokenizer_path) + scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) + + pipe = AnimaPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + t5_tokenizer=t5_tokenizer, + text_conditioner=text_conditioner, + transformer=transformer, + vae=vae, + scheduler=scheduler, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size=args.max_shard_size) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--transformer_ckpt_path", type=str, required=True, help="Path to Anima DiT safetensors") + parser.add_argument("--text_encoder_ckpt_path", type=str, required=True, help="Path to Qwen3 text encoder") + parser.add_argument("--vae_ckpt_path", type=str, required=True, help="Path to Qwen-Image VAE safetensors") + parser.add_argument("--qwen_tokenizer_path", type=str, default=None) + parser.add_argument("--t5_tokenizer_path", type=str, default=None) + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--dtype", default="bf16", choices=list(DTYPE_MAPPING.keys())) + parser.add_argument("--max_shard_size", default="5GB") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + output_path = pathlib.Path(args.output_path) + dtype = DTYPE_MAPPING[args.dtype] + + raw_transformer_state_dict = load_file(args.transformer_ckpt_path, device="cpu") + transformer_state_dict, text_conditioner_state_dict = split_anima_transformer_checkpoint(raw_transformer_state_dict) + transformer = convert_transformer( + "Cosmos-2.0-Diffusion-2B-Text2Image", state_dict=transformer_state_dict, weights_only=True + ).to(dtype=dtype) + text_conditioner = 
convert_text_conditioner(text_conditioner_state_dict).to(dtype=dtype) + + text_encoder_state_dict = load_file(args.text_encoder_ckpt_path, device="cpu") + text_encoder = convert_text_encoder(text_encoder_state_dict).to(dtype=dtype) + + vae_state_dict = load_file(args.vae_ckpt_path, device="cpu") + vae = convert_qwen_image_vae(vae_state_dict).to(dtype=dtype) + + if args.save_pipeline: + if args.qwen_tokenizer_path is None or args.t5_tokenizer_path is None: + raise ValueError("`--qwen_tokenizer_path` and `--t5_tokenizer_path` are required with `--save_pipeline`.") + save_pipeline(args, transformer, text_conditioner, text_encoder, vae) + else: + output_path.mkdir(parents=True, exist_ok=True) + transformer.save_pretrained( + output_path / "transformer", safe_serialization=True, max_shard_size=args.max_shard_size + ) + text_conditioner.save_pretrained( + output_path / "text_conditioner", safe_serialization=True, max_shard_size=args.max_shard_size + ) + text_encoder.save_pretrained( + output_path / "text_encoder", safe_serialization=True, max_shard_size=args.max_shard_size + ) + vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size=args.max_shard_size) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1b1f6b3032b3..a4b1c72f1cca 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -502,6 +502,8 @@ "AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline", + "AnimaPipeline", + "AnimaTextConditioner", "AnimateDiffControlNetPipeline", "AnimateDiffPAGPipeline", "AnimateDiffPipeline", @@ -1301,6 +1303,7 @@ AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline, + AnimaPipeline, AnimateDiffControlNetPipeline, AnimateDiffPAGPipeline, AnimateDiffPipeline, @@ -1308,6 +1311,7 @@ AnimateDiffSparseControlNetPipeline, AnimateDiffVideoToVideoControlNetPipeline, AnimateDiffVideoToVideoPipeline, + AnimaTextConditioner, AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 488f77422dcd..33eeba673a98 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -63,6 +63,7 @@ def text_encoder_attn_modules(text_encoder): _import_structure["single_file"] = ["FromSingleFileMixin"] _import_structure["lora_pipeline"] = [ "AmusedLoraLoaderMixin", + "AnimaLoraLoaderMixin", "StableDiffusionLoraLoaderMixin", "SD3LoraLoaderMixin", "AuraFlowLoraLoaderMixin", @@ -116,6 +117,7 @@ def text_encoder_attn_modules(text_encoder): ) from .lora_pipeline import ( AmusedLoraLoaderMixin, + AnimaLoraLoaderMixin, AuraFlowLoraLoaderMixin, CogVideoXLoraLoaderMixin, CogView4LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 510a698e505f..bf516abc825f 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2317,6 +2317,52 @@ def get_alpha_scales(down_weight, alpha_key): return converted_state_dict +def _convert_non_diffusers_anima_lora_to_diffusers(state_dict): + rename_dict = { + "blocks.": "transformer_blocks.", + "adaln_modulation_self_attn.1": "norm1.linear_1", + "adaln_modulation_self_attn.2": "norm1.linear_2", + "adaln_modulation_cross_attn.1": "norm2.linear_1", + "adaln_modulation_cross_attn.2": "norm2.linear_2", + "adaln_modulation_mlp.1": "norm3.linear_1", + "adaln_modulation_mlp.2": "norm3.linear_2", + "self_attn.q_proj": "attn1.to_q", + "self_attn.k_proj": "attn1.to_k", + 
"self_attn.v_proj": "attn1.to_v", + "self_attn.output_proj": "attn1.to_out.0", + "cross_attn.q_proj": "attn2.to_q", + "cross_attn.k_proj": "attn2.to_k", + "cross_attn.v_proj": "attn2.to_v", + "cross_attn.output_proj": "attn2.to_out.0", + "mlp.layer1": "ff.net.0.proj", + "mlp.layer2": "ff.net.2", + "final_layer.adaln_modulation.1": "norm_out.linear_1", + "final_layer.adaln_modulation.2": "norm_out.linear_2", + "final_layer.linear": "proj_out", + "t_embedder.1": "time_embed.t_embedder", + "t_embedding_norm": "time_embed.norm", + "x_embedder.proj.1": "patch_embed.proj", + } + + converted_state_dict = {} + for key, value in state_dict.items(): + if not key.startswith("diffusion_model."): + converted_state_dict[key] = value + continue + + new_key = key.removeprefix("diffusion_model.") + if new_key.startswith("llm_adapter."): + new_key = f"text_conditioner.{new_key.removeprefix('llm_adapter.')}" + else: + for old_key, new_key_part in rename_dict.items(): + new_key = new_key.replace(old_key, new_key_part) + new_key = f"transformer.{new_key}" + + converted_state_dict[new_key] = value + + return converted_state_dict + + def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict): converted_state_dict = {} diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 403e5a87db61..d25d39d61592 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -46,6 +46,7 @@ _convert_kohya_flux2_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, + _convert_non_diffusers_anima_lora_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, @@ -5615,6 +5616,208 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class AnimaLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`CosmosTransformer3DModel`] and [`AnimaTextConditioner`]. + """ + + _lora_loadable_modules = ["transformer", "text_conditioner"] + transformer_name = TRANSFORMER_NAME + text_conditioner_name = "text_conditioner" + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. 
+ """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict) + if has_diffusion_model: + state_dict = _convert_non_diffusers_anima_lora_to_diffusers(state_dict) + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + adapter_name: str | None = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `'lora'` substring.") + + transformer_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")} + text_conditioner_state_dict = { + k: v for k, v in state_dict.items() if k.startswith(f"{self.text_conditioner_name}.") + } + + if transformer_state_dict: + self.load_lora_into_transformer( + transformer_state_dict, + transformer=self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + if text_conditioner_state_dict: + self.load_lora_into_text_conditioner( + text_conditioner_state_dict, + text_conditioner=self.text_conditioner, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def load_lora_into_text_conditioner( + cls, + state_dict, + text_conditioner, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + logger.info(f"Loading {cls.text_conditioner_name}.") + text_conditioner.load_lora_adapter( + state_dict, + prefix=cls.text_conditioner_name, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + def fuse_lora( + self, + components: list[str] = ["transformer", "text_conditioner"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: list[str] | None = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: list[str] = ["transformer", "text_conditioner"], **kwargs): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. + """ + super().unfuse_lora(components=components, **kwargs) + + class Flux2LoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`]. 
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index f0fc7585bf31..c03d588ac152 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -100,6 +100,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["anima"] = ["AnimaPipeline", "AnimaTextConditioner"] _import_structure["deprecated"].extend( [ "AmusedImg2ImgPipeline", @@ -588,6 +589,7 @@ AceStepPipeline, ) from .allegro import AllegroPipeline + from .anima import AnimaPipeline, AnimaTextConditioner from .animatediff import ( AnimateDiffControlNetPipeline, AnimateDiffPipeline, diff --git a/src/diffusers/pipelines/anima/__init__.py b/src/diffusers/pipelines/anima/__init__.py new file mode 100644 index 000000000000..bca1117ca7c6 --- /dev/null +++ b/src/diffusers/pipelines/anima/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_anima"] = ["AnimaTextConditioner"] + _import_structure["pipeline_anima"] = ["AnimaPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modeling_anima import AnimaTextConditioner + from .pipeline_anima import AnimaPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/anima/modeling_anima.py b/src/diffusers/pipelines/anima/modeling_anima.py new file mode 100644 index 000000000000..748a2d030162 --- /dev/null +++ b/src/diffusers/pipelines/anima/modeling_anima.py @@ -0,0 +1,295 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
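+
+# Anima's "LLM adapter": learned T5 token embeddings cross-attend to Qwen3 hidden states to
+# produce text embeddings in the layout expected by the Cosmos Predict2 DiT.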
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...models.modeling_utils import ModelMixin + + +def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states_1 = hidden_states[..., : hidden_states.shape[-1] // 2] + hidden_states_2 = hidden_states[..., hidden_states.shape[-1] // 2 :] + return torch.cat((-hidden_states_2, hidden_states_1), dim=-1) + + +def _apply_rotary_pos_emb( + hidden_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 +) -> torch.Tensor: + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return (hidden_states * cos) + (_rotate_half(hidden_states) * sin) + + +class AnimaRotaryEmbedding(nn.Module): + def __init__(self, head_dim: int, rope_theta: float = 10000.0): + super().__init__() + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float32) / head_dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = inv_freq_expanded.to(hidden_states.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = hidden_states.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=hidden_states.dtype), sin.to(dtype=hidden_states.dtype) + + +class AnimaTextConditionerAttention(nn.Module): + def __init__( + self, + query_dim: int, + context_dim: int, + num_attention_heads: int, + attention_head_dim: int, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.q_proj = nn.Linear(query_dim, inner_dim, bias=False) + self.q_norm = nn.RMSNorm(attention_head_dim, eps=1e-6) + self.k_proj = nn.Linear(context_dim, inner_dim, bias=False) + self.k_norm = nn.RMSNorm(attention_head_dim, eps=1e-6) + self.v_proj = nn.Linear(context_dim, inner_dim, bias=False) + self.o_proj = nn.Linear(inner_dim, query_dim, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + encoder_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + input_shape = hidden_states.shape[:-1] + encoder_input_shape = encoder_hidden_states.shape[:-1] + + query = self.q_proj(hidden_states) + key = self.k_proj(encoder_hidden_states) + value = self.v_proj(encoder_hidden_states) + + query = query.view(*input_shape, self.num_attention_heads, self.attention_head_dim).transpose(1, 2) + key = key.view(*encoder_input_shape, self.num_attention_heads, self.attention_head_dim).transpose(1, 2) + value = value.view(*encoder_input_shape, self.num_attention_heads, 
self.attention_head_dim).transpose(1, 2) + + query = self.q_norm(query) + key = self.k_norm(key) + + if position_embeddings is not None: + if encoder_position_embeddings is None: + raise ValueError("`encoder_position_embeddings` must be provided when using rotary embeddings.") + cos, sin = position_embeddings + query = _apply_rotary_pos_emb(query, cos, sin) + cos, sin = encoder_position_embeddings + key = _apply_rotary_pos_emb(key, cos, sin) + + hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(*input_shape, -1).contiguous() + hidden_states = self.o_proj(hidden_states) + return hidden_states + + +class AnimaTextConditionerBlock(nn.Module): + def __init__( + self, + source_dim: int, + model_dim: int, + num_attention_heads: int = 16, + mlp_ratio: float = 4.0, + use_self_attention: bool = True, + use_layer_norm: bool = False, + ): + super().__init__() + self.use_self_attention = use_self_attention + norm_cls = nn.LayerNorm if use_layer_norm else nn.RMSNorm + norm_kwargs = {} if use_layer_norm else {"eps": 1e-6} + + if use_self_attention: + self.norm_self_attn = norm_cls(model_dim, **norm_kwargs) + self.self_attn = AnimaTextConditionerAttention( + query_dim=model_dim, + context_dim=model_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=model_dim // num_attention_heads, + ) + + self.norm_cross_attn = norm_cls(model_dim, **norm_kwargs) + self.cross_attn = AnimaTextConditionerAttention( + query_dim=model_dim, + context_dim=source_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=model_dim // num_attention_heads, + ) + self.norm_mlp = norm_cls(model_dim, **norm_kwargs) + self.mlp = nn.Sequential( + nn.Linear(model_dim, int(model_dim * mlp_ratio)), + nn.GELU(), + nn.Linear(int(model_dim * mlp_ratio), model_dim), + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + target_attention_mask: torch.Tensor | None = None, + source_attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + source_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + if self.use_self_attention: + norm_hidden_states = self.norm_self_attn(hidden_states) + attn_hidden_states = self.self_attn( + norm_hidden_states, + attention_mask=target_attention_mask, + position_embeddings=position_embeddings, + encoder_position_embeddings=position_embeddings, + ) + hidden_states = hidden_states + attn_hidden_states + + norm_hidden_states = self.norm_cross_attn(hidden_states) + attn_hidden_states = self.cross_attn( + norm_hidden_states, + attention_mask=source_attention_mask, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=position_embeddings, + encoder_position_embeddings=source_position_embeddings, + ) + hidden_states = hidden_states + attn_hidden_states + hidden_states = hidden_states + self.mlp(self.norm_mlp(hidden_states)) + return hidden_states + + +class AnimaTextConditioner(ModelMixin, ConfigMixin, PeftAdapterMixin): + r""" + Text conditioner used by Anima to map Qwen3 hidden states and T5 token ids to Cosmos text embeddings. + + Anima reuses the Cosmos Predict2 DiT. The only model-specific conditioning module is this LLM adapter, which + cross-attends from learned T5 token embeddings to Qwen3 text encoder hidden states before the diffusion loop. 
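+
+    Example (a minimal sketch using the default configuration; the output sequence is padded up
+    to `min_sequence_length`):
+
+    ```python
+    >>> import torch
+    >>> conditioner = AnimaTextConditioner()
+    >>> source_hidden_states = torch.randn(1, 32, 1024)  # Qwen3 last hidden states
+    >>> target_input_ids = torch.randint(0, 32128, (1, 16))  # T5 token ids
+    >>> embeds = conditioner(source_hidden_states=source_hidden_states, target_input_ids=target_input_ids)
+    >>> embeds.shape
+    torch.Size([1, 512, 1024])
+    ```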
+ """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["AnimaTextConditionerBlock"] + + @register_to_config + def __init__( + self, + source_dim: int = 1024, + target_dim: int = 1024, + model_dim: int = 1024, + num_layers: int = 6, + num_attention_heads: int = 16, + mlp_ratio: float = 4.0, + target_vocab_size: int = 32128, + use_self_attention: bool = True, + use_layer_norm: bool = False, + min_sequence_length: int = 512, + ): + super().__init__() + self.embed = nn.Embedding(target_vocab_size, target_dim) + self.in_proj = nn.Linear(target_dim, model_dim) if model_dim != target_dim else nn.Identity() + self.rotary_emb = AnimaRotaryEmbedding(model_dim // num_attention_heads) + self.blocks = nn.ModuleList( + [ + AnimaTextConditionerBlock( + source_dim=source_dim, + model_dim=model_dim, + num_attention_heads=num_attention_heads, + mlp_ratio=mlp_ratio, + use_self_attention=use_self_attention, + use_layer_norm=use_layer_norm, + ) + for _ in range(num_layers) + ] + ) + self.out_proj = nn.Linear(model_dim, target_dim) + self.norm = nn.RMSNorm(target_dim, eps=1e-6) + self.gradient_checkpointing = False + + @staticmethod + def _prepare_attention_mask(attention_mask: torch.Tensor | None) -> torch.Tensor | None: + if attention_mask is None: + return None + attention_mask = attention_mask.to(torch.bool) + if attention_mask.ndim == 2: + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) + return attention_mask + + def forward( + self, + source_hidden_states: torch.Tensor, + target_input_ids: torch.Tensor, + target_attention_mask: torch.Tensor | None = None, + source_attention_mask: torch.Tensor | None = None, + target_token_weights: torch.Tensor | None = None, + ) -> torch.Tensor: + target_attention_mask = self._prepare_attention_mask(target_attention_mask) + source_attention_mask = self._prepare_attention_mask(source_attention_mask) + + hidden_states = self.embed(target_input_ids).to(dtype=source_hidden_states.dtype) + hidden_states = self.in_proj(hidden_states) + + position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) + source_position_ids = torch.arange(source_hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + source_position_embeddings = self.rotary_emb(hidden_states, source_position_ids) + + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + source_hidden_states, + target_attention_mask, + source_attention_mask, + position_embeddings, + source_position_embeddings, + ) + else: + hidden_states = block( + hidden_states, + source_hidden_states, + target_attention_mask=target_attention_mask, + source_attention_mask=source_attention_mask, + position_embeddings=position_embeddings, + source_position_embeddings=source_position_embeddings, + ) + + hidden_states = self.norm(self.out_proj(hidden_states)) + + if target_token_weights is not None: + hidden_states = hidden_states * target_token_weights.to(hidden_states).unsqueeze(-1) + if target_attention_mask is not None: + hidden_states = hidden_states * target_attention_mask.squeeze(1).squeeze(1).to(hidden_states).unsqueeze(-1) + + if hidden_states.shape[1] < self.config.min_sequence_length: + hidden_states = F.pad(hidden_states, (0, 0, 0, self.config.min_sequence_length - hidden_states.shape[1])) + + return hidden_states diff --git a/src/diffusers/pipelines/anima/pipeline_anima.py 
b/src/diffusers/pipelines/anima/pipeline_anima.py new file mode 100644 index 000000000000..f269308c2dec --- /dev/null +++ b/src/diffusers/pipelines/anima/pipeline_anima.py @@ -0,0 +1,566 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import PreTrainedModel, PreTrainedTokenizer, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...loaders import AnimaLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, CosmosTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modeling_anima import AnimaTextConditioner + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import AnimaPipeline + + >>> pipe = AnimaPipeline.from_pretrained("path/to/anima-diffusers", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> image = pipe("A cinematic portrait of a woman in a rain-soaked city street").images[0] + >>> image.save("anima.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimaPipeline(DiffusionPipeline, AnimaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Anima. + + Anima uses a Qwen3 text encoder, a T5-token LLM adapter, the Cosmos Predict2 DiT, and the Qwen-Image VAE. + Supports loading LoRA weights with [`~loaders.AnimaLoraLoaderMixin.load_lora_weights`]. + + Args: + text_encoder (`~transformers.PreTrainedModel`): + Qwen3 text encoder used to produce source hidden states for the Anima text conditioner. + tokenizer (`~transformers.PreTrainedTokenizer`): + Qwen tokenizer paired with `text_encoder`. + t5_tokenizer (`~transformers.T5TokenizerFast`): + T5 tokenizer used to produce target token ids for the Anima text conditioner. + text_conditioner ([`AnimaTextConditioner`]): + Adapter that maps Qwen3 hidden states and T5 token ids to Cosmos text embeddings. + transformer ([`CosmosTransformer3DModel`]): + Cosmos Predict2 transformer used to denoise image latents. + vae ([`AutoencoderKLQwenImage`]): + Qwen-Image VAE used to decode latents into images. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + Flow-matching scheduler used for denoising. + """ + + model_cpu_offload_seq = "text_encoder->text_conditioner->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + text_encoder: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + t5_tokenizer: T5TokenizerFast, + text_conditioner: AnimaTextConditioner, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLQwenImage, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + t5_tokenizer=t5_tokenizer, + text_conditioner=text_conditioner, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = 128 + + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str], + max_sequence_length: int, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + prompt_attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=False, + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = prompt_embeds * prompt_attention_mask.to(prompt_embeds).unsqueeze(-1) + + return prompt_embeds, prompt_attention_mask + + def _get_t5_prompt_ids( + self, + prompt: str | list[str], + max_sequence_length: int, + device: torch.device | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + prompt = [prompt] if 
isinstance(prompt, str) else prompt + + text_inputs = self.t5_tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + return text_inputs.input_ids.to(device), text_inputs.attention_mask.to(device) + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + device = device or self._execution_device + dtype = dtype or self.text_conditioner.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] + + if prompt_embeds is None: + qwen_prompt_embeds, qwen_attention_mask = self._get_qwen_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + t5_input_ids, t5_attention_mask = self._get_t5_prompt_ids( + prompt=prompt, max_sequence_length=max_sequence_length, device=device + ) + prompt_embeds = self.text_conditioner( + source_hidden_states=qwen_prompt_embeds, + target_input_ids=t5_input_ids, + target_attention_mask=t5_attention_mask, + source_attention_mask=qwen_attention_mask, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_qwen_prompt_embeds, negative_qwen_attention_mask = self._get_qwen_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + negative_t5_input_ids, negative_t5_attention_mask = self._get_t5_prompt_ids( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device + ) + negative_prompt_embeds = self.text_conditioner( + source_hidden_states=negative_qwen_prompt_embeds, + target_input_ids=negative_t5_input_ids, + target_attention_mask=negative_t5_attention_mask, + source_attention_mask=negative_qwen_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + negative_prompt=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and" + f" {width}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if max_sequence_length is not None and max_sequence_length > 4096: + raise ValueError(f"`max_sequence_length` cannot be greater than 4096 but is {max_sequence_length}") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + num_frames: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | list[torch.Generator] | None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor + 1 + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + return randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, pass `prompt_embeds`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide image generation. Used when `guidance_scale > 1`. + height (`int`, *optional*, defaults to `1024`): + Height in pixels of the generated image. + width (`int`, *optional*, defaults to `1024`): + Width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps. + sigmas (`list[float]`, *optional*): + Custom sigma schedule to use for schedulers that support `sigmas`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Classifier-free guidance scale. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + Random generator for deterministic generation. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated conditioned prompt embeddings. 
+ negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated conditioned negative prompt embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + Output format, one of `"pil"`, `"np"`, `"pt"`, or `"latent"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return an [`~pipelines.ImagePipelineOutput`]. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + Function called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list`, *optional*): + Tensor inputs available to `callback_on_step_end`. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length used by both text tokenizers. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + Generated images if `return_dict` is `True`; otherwise a tuple whose first item is the images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + self.check_inputs( + prompt, + height, + width, + prompt_embeds, + negative_prompt, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + num_frames = 1 + do_classifier_free_guidance = guidance_scale > 1.0 + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + dtype=self.transformer.dtype, + ) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas) + self.scheduler.set_begin_index(0) + + transformer_dtype = self.transformer.dtype + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(transformer_dtype) + timestep = timestep / self.scheduler.config.num_train_timesteps + latent_model_input = latents.to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + negative_noise_pred = self.transformer( + 
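+                        # second forward pass on the same latents with the negative prompt
+                        # embeddings; combined with the positive prediction below for CFG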
hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = negative_noise_pred + self.guidance_scale * (noise_pred - negative_noise_pred) + + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if latents.dtype != latents_dtype and torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents[:, :, 0] + else: + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/tests/pipelines/anima/__init__.py b/tests/pipelines/anima/__init__.py new file mode 100644 index 000000000000..b51b917e8d98 --- /dev/null +++ b/tests/pipelines/anima/__init__.py @@ -0,0 +1 @@ +# Empty init for Anima pipeline tests. diff --git a/tests/pipelines/anima/test_anima.py b/tests/pipelines/anima/test_anima.py new file mode 100644 index 000000000000..c3c3d5c46484 --- /dev/null +++ b/tests/pipelines/anima/test_anima.py @@ -0,0 +1,201 @@ +# Copyright 2026 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
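For reference, the callback contract enforced by the denoising loop above can be exercised with a minimal sketch like the following. The loop calls `callback_on_step_end(self, i, t, callback_kwargs)` and pops any returned entries back into `latents` and the prompt embeddings, so a callback can observe or rewrite those tensors mid-denoise; `pipe` below is assumed to be an already-loaded `AnimaPipeline`.

```python
# Minimal sketch of a step-end callback matching the loop above: it receives
# (pipeline, step_index, timestep, callback_kwargs) and returns a dict whose
# entries overwrite the corresponding loop tensors ("latents" here).
def log_latents(pipeline, step_index, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    print(f"step {step_index}: t={float(timestep):.1f}, latent std={latents.std().item():.4f}")
    return {"latents": latents}

# Assumed usage with an already-loaded `pipe`:
# image = pipe(
#     "a corgi wearing sunglasses",
#     callback_on_step_end=log_latents,
#     callback_on_step_end_tensor_inputs=["latents"],
# ).images[0]
```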
+ +import tempfile +import unittest + +import torch +from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model, T5TokenizerFast + +from diffusers import ( + AnimaPipeline, + AnimaTextConditioner, + AutoencoderKLQwenImage, + CosmosTransformer3DModel, + FlowMatchEulerDiscreteScheduler, +) + +from ...testing_utils import enable_full_determinism, require_peft_backend + + +enable_full_determinism() + + +class AnimaTextConditionerFastTests(unittest.TestCase): + def test_conditioner_output_shape_and_padding(self): + conditioner = AnimaTextConditioner( + source_dim=16, + target_dim=16, + model_dim=16, + num_layers=2, + num_attention_heads=4, + target_vocab_size=128, + min_sequence_length=8, + ) + source_hidden_states = torch.randn(2, 5, 16) + target_input_ids = torch.randint(0, 128, (2, 4)) + source_attention_mask = torch.ones(2, 5) + target_attention_mask = torch.ones(2, 4) + target_attention_mask[1, -1] = 0 + + output = conditioner( + source_hidden_states=source_hidden_states, + target_input_ids=target_input_ids, + source_attention_mask=source_attention_mask, + target_attention_mask=target_attention_mask, + ) + + self.assertEqual(output.shape, (2, 8, 16)) + self.assertTrue(torch.allclose(output[1, 3], torch.zeros_like(output[1, 3]), atol=1e-5)) + self.assertTrue(torch.allclose(output[:, 4:], torch.zeros_like(output[:, 4:]), atol=1e-5)) + + +class AnimaPipelineFastTests(unittest.TestCase): + pipeline_class = AnimaPipeline + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = CosmosTransformer3DModel( + in_channels=4, + out_channels=4, + num_attention_heads=2, + attention_head_dim=16, + num_layers=2, + mlp_ratio=2, + text_embed_dim=16, + adaln_lora_dim=4, + max_size=(4, 32, 32), + patch_size=(1, 2, 2), + rope_scale=(1.0, 4.0, 4.0), + concat_padding_mask=True, + extra_pos_embed_type=None, + ) + + torch.manual_seed(0) + vae = AutoencoderKLQwenImage( + base_dim=24, + z_dim=4, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + ) + + torch.manual_seed(0) + text_conditioner = AnimaTextConditioner( + source_dim=16, + target_dim=16, + model_dim=16, + num_layers=2, + num_attention_heads=4, + target_vocab_size=32128, + min_sequence_length=16, + ) + + torch.manual_seed(0) + text_encoder_config = Qwen3Config( + vocab_size=152064, + hidden_size=16, + intermediate_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + max_position_embeddings=128, + rms_norm_eps=1e-6, + rope_theta=1000000.0, + head_dim=4, + attention_bias=False, + ) + text_encoder = Qwen3Model(text_encoder_config).eval() + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + t5_tokenizer = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5") + scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) + + return { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "t5_tokenizer": t5_tokenizer, + "text_conditioner": text_conditioner, + } + + def get_dummy_inputs(self, seed=0): + generator = torch.Generator(device="cpu").manual_seed(seed) + return { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + + def test_inference(self): + components = self.get_dummy_components() + pipe = 
self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs()).images + + self.assertEqual(output.shape, (1, 3, 32, 32)) + self.assertFalse(torch.isnan(output).any()) + + def test_save_load_optional_components(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=True) + pipe = self.pipeline_class.from_pretrained(tmpdir) + + self.assertIsInstance(pipe.text_conditioner, AnimaTextConditioner) + self.assertIsInstance(pipe.transformer, CosmosTransformer3DModel) + + def test_lora_state_dict_conversion(self): + state_dict = { + "diffusion_model.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 32), + "diffusion_model.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(32, 2), + "diffusion_model.blocks.0.adaln_modulation_cross_attn.1.lora_A.weight": torch.randn(2, 32), + "diffusion_model.blocks.0.adaln_modulation_cross_attn.1.lora_B.weight": torch.randn(4, 2), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 16), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(16, 2), + } + + converted_state_dict = self.pipeline_class.lora_state_dict(state_dict) + + self.assertIn("transformer.transformer_blocks.0.attn1.to_q.lora_A.weight", converted_state_dict) + self.assertIn("transformer.transformer_blocks.0.norm2.linear_1.lora_B.weight", converted_state_dict) + self.assertIn("text_conditioner.blocks.0.self_attn.q_proj.lora_A.weight", converted_state_dict) + + @require_peft_backend + def test_load_lora_weights(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + state_dict = { + "diffusion_model.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 32), + "diffusion_model.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(32, 2), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 16), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(16, 2), + } + + pipe.load_lora_weights(state_dict, adapter_name="dummy") + + self.assertIn("dummy", pipe.transformer.peft_config) + self.assertIn("dummy", pipe.text_conditioner.peft_config) From 68425646d8aee7a390e7aa0b7ebc4b14567fd04f Mon Sep 17 00:00:00 2001 From: rmatif Date: Tue, 12 May 2026 13:38:25 +0200 Subject: [PATCH 2/6] Fix empty Anima negative prompts --- src/diffusers/pipelines/anima/pipeline_anima.py | 3 +++ tests/pipelines/anima/test_anima.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/src/diffusers/pipelines/anima/pipeline_anima.py b/src/diffusers/pipelines/anima/pipeline_anima.py index f269308c2dec..7ae777322607 100644 --- a/src/diffusers/pipelines/anima/pipeline_anima.py +++ b/src/diffusers/pipelines/anima/pipeline_anima.py @@ -164,6 +164,9 @@ def _get_qwen_prompt_embeds( ) text_input_ids = text_inputs.input_ids.to(device) prompt_attention_mask = text_inputs.attention_mask.to(device) + if text_input_ids.shape[-1] == 0: + text_input_ids = text_input_ids.new_zeros((text_input_ids.shape[0], 1)) + prompt_attention_mask = prompt_attention_mask.new_zeros((prompt_attention_mask.shape[0], 1)) prompt_embeds = self.text_encoder( input_ids=text_input_ids, diff --git a/tests/pipelines/anima/test_anima.py b/tests/pipelines/anima/test_anima.py index c3c3d5c46484..fdf8211672e6 100644 --- a/tests/pipelines/anima/test_anima.py +++ 
b/tests/pipelines/anima/test_anima.py @@ -157,6 +157,18 @@ def test_inference(self): self.assertEqual(output.shape, (1, 3, 32, 32)) self.assertFalse(torch.isnan(output).any()) + def test_inference_empty_negative_prompt(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs() + inputs["negative_prompt"] = "" + output = pipe(**inputs).images + + self.assertEqual(output.shape, (1, 3, 32, 32)) + self.assertFalse(torch.isnan(output).any()) + def test_save_load_optional_components(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) From aece3f393298f2bedba0efdd2f987d8a7fb52e60 Mon Sep 17 00:00:00 2001 From: rmatif Date: Wed, 13 May 2026 12:05:06 +0200 Subject: [PATCH 3/6] Fix Anima registration --- docs/source/en/_toctree.yml | 10 +++---- src/diffusers/__init__.py | 2 +- .../dummy_torch_and_transformers_objects.py | 30 +++++++++++++++++++ 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 1947e11d799e..00ad1bc0d96e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -499,11 +499,11 @@ - local: api/pipelines/stable_audio title: Stable Audio title: Audio - - sections: - - local: api/pipelines/anima - title: Anima - - local: api/pipelines/animatediff - title: AnimateDiff + - sections: + - local: api/pipelines/anima + title: Anima + - local: api/pipelines/animatediff + title: AnimateDiff - local: api/pipelines/aura_flow title: AuraFlow - local: api/pipelines/bria_3_2 diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a4b1c72f1cca..8a9499c7d077 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1304,6 +1304,7 @@ AmusedInpaintPipeline, AmusedPipeline, AnimaPipeline, + AnimaTextConditioner, AnimateDiffControlNetPipeline, AnimateDiffPAGPipeline, AnimateDiffPipeline, @@ -1311,7 +1312,6 @@ AnimateDiffSparseControlNetPipeline, AnimateDiffVideoToVideoControlNetPipeline, AnimateDiffVideoToVideoPipeline, - AnimaTextConditioner, AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index cfa1318783f3..b8b05fce8818 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -812,6 +812,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AnimaPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AnimaTextConditioner(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AnimateDiffControlNetPipeline(metaclass=DummyObject): _backends = ["torch", 
"transformers"] From 922c516d995528565c7a2a8e2363de23d79846bd Mon Sep 17 00:00:00 2001 From: rmatif Date: Wed, 13 May 2026 12:33:06 +0200 Subject: [PATCH 4/6] Clean up Anima conditioner --- .../pipelines/anima/modeling_anima.py | 103 ++++++++++++------ .../pipelines/anima/pipeline_anima.py | 23 ++++ 2 files changed, 95 insertions(+), 31 deletions(-) diff --git a/src/diffusers/pipelines/anima/modeling_anima.py b/src/diffusers/pipelines/anima/modeling_anima.py index 748a2d030162..1241f5e5eaad 100644 --- a/src/diffusers/pipelines/anima/modeling_anima.py +++ b/src/diffusers/pipelines/anima/modeling_anima.py @@ -18,6 +18,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin +from ...models.attention import AttentionModuleMixin +from ...models.attention_dispatch import dispatch_attention_fn from ...models.modeling_utils import ModelMixin @@ -60,13 +62,69 @@ def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor) -> tu return cos.to(dtype=hidden_states.dtype), sin.to(dtype=hidden_states.dtype) -class AnimaTextConditionerAttention(nn.Module): +class AnimaTextConditionerAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __call__( + self, + attn: "AnimaTextConditionerAttention", + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + encoder_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + input_shape = hidden_states.shape[:-1] + encoder_input_shape = encoder_hidden_states.shape[:-1] + + query = attn.q_proj(hidden_states) + key = attn.k_proj(encoder_hidden_states) + value = attn.v_proj(encoder_hidden_states) + + query = query.view(*input_shape, attn.num_attention_heads, attn.attention_head_dim) + key = key.view(*encoder_input_shape, attn.num_attention_heads, attn.attention_head_dim) + value = value.view(*encoder_input_shape, attn.num_attention_heads, attn.attention_head_dim) + + query = attn.q_norm(query) + key = attn.k_norm(key) + + if position_embeddings is not None: + if encoder_position_embeddings is None: + raise ValueError("`encoder_position_embeddings` must be provided when using rotary embeddings.") + cos, sin = position_embeddings + query = _apply_rotary_pos_emb(query, cos, sin, unsqueeze_dim=2) + cos, sin = encoder_position_embeddings + key = _apply_rotary_pos_emb(key, cos, sin, unsqueeze_dim=2) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3).contiguous() + hidden_states = attn.o_proj(hidden_states) + return hidden_states + + +class AnimaTextConditionerAttention(nn.Module, AttentionModuleMixin): + _default_processor_cls = AnimaTextConditionerAttnProcessor + _available_processors = [AnimaTextConditionerAttnProcessor] + _supports_qkv_fusion = False + def __init__( self, query_dim: int, context_dim: int, num_attention_heads: int, attention_head_dim: int, + processor: AnimaTextConditionerAttnProcessor | None = None, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -80,6 +138,10 @@ def __init__( self.v_proj = nn.Linear(context_dim, inner_dim, bias=False) self.o_proj = 
nn.Linear(inner_dim, query_dim, bias=False) + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + def forward( self, hidden_states: torch.Tensor, @@ -88,33 +150,14 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, encoder_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: - encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states - input_shape = hidden_states.shape[:-1] - encoder_input_shape = encoder_hidden_states.shape[:-1] - - query = self.q_proj(hidden_states) - key = self.k_proj(encoder_hidden_states) - value = self.v_proj(encoder_hidden_states) - - query = query.view(*input_shape, self.num_attention_heads, self.attention_head_dim).transpose(1, 2) - key = key.view(*encoder_input_shape, self.num_attention_heads, self.attention_head_dim).transpose(1, 2) - value = value.view(*encoder_input_shape, self.num_attention_heads, self.attention_head_dim).transpose(1, 2) - - query = self.q_norm(query) - key = self.k_norm(key) - - if position_embeddings is not None: - if encoder_position_embeddings is None: - raise ValueError("`encoder_position_embeddings` must be provided when using rotary embeddings.") - cos, sin = position_embeddings - query = _apply_rotary_pos_emb(query, cos, sin) - cos, sin = encoder_position_embeddings - key = _apply_rotary_pos_emb(key, cos, sin) - - hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) - hidden_states = hidden_states.transpose(1, 2).reshape(*input_shape, -1).contiguous() - hidden_states = self.o_proj(hidden_states) - return hidden_states + return self.processor( + self, + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=position_embeddings, + encoder_position_embeddings=encoder_position_embeddings, + ) class AnimaTextConditionerBlock(nn.Module): @@ -193,6 +236,7 @@ class AnimaTextConditioner(ModelMixin, ConfigMixin, PeftAdapterMixin): Anima reuses the Cosmos Predict2 DiT. The only model-specific conditioning module is this LLM adapter, which cross-attends from learned T5 token embeddings to Qwen3 text encoder hidden states before the diffusion loop. + `target_dim` is the conditioner output dimension and must match the transformer's `text_embed_dim`. 
""" _supports_gradient_checkpointing = True @@ -248,7 +292,6 @@ def forward( target_input_ids: torch.Tensor, target_attention_mask: torch.Tensor | None = None, source_attention_mask: torch.Tensor | None = None, - target_token_weights: torch.Tensor | None = None, ) -> torch.Tensor: target_attention_mask = self._prepare_attention_mask(target_attention_mask) source_attention_mask = self._prepare_attention_mask(source_attention_mask) @@ -284,8 +327,6 @@ def forward( hidden_states = self.norm(self.out_proj(hidden_states)) - if target_token_weights is not None: - hidden_states = hidden_states * target_token_weights.to(hidden_states).unsqueeze(-1) if target_attention_mask is not None: hidden_states = hidden_states * target_attention_mask.squeeze(1).squeeze(1).to(hidden_states).unsqueeze(-1) diff --git a/src/diffusers/pipelines/anima/pipeline_anima.py b/src/diffusers/pipelines/anima/pipeline_anima.py index 7ae777322607..33625192b78c 100644 --- a/src/diffusers/pipelines/anima/pipeline_anima.py +++ b/src/diffusers/pipelines/anima/pipeline_anima.py @@ -63,6 +63,29 @@ def retrieve_timesteps( sigmas: list[float] | None = None, **kwargs, ): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") if timesteps is not None: From f3bb40371d77f5b1da469b656b519cab183547a8 Mon Sep 17 00:00:00 2001 From: rmatif Date: Wed, 13 May 2026 21:05:29 +0200 Subject: [PATCH 5/6] Refactor Anima to modular --- docs/source/en/api/pipelines/anima.md | 15 +- scripts/convert_anima_to_diffusers.py | 5 +- src/diffusers/__init__.py | 10 +- src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_anima.py} | 0 src/diffusers/modular_pipelines/__init__.py | 5 + .../anima/__init__.py | 9 +- .../modular_pipelines/anima/before_denoise.py | 284 +++++++++ .../modular_pipelines/anima/decoders.py | 120 ++++ .../modular_pipelines/anima/denoise.py | 214 +++++++ .../modular_pipelines/anima/encoders.py | 225 +++++++ .../anima/modular_blocks_anima.py | 173 +++++ .../anima/modular_pipeline.py | 52 ++ .../modular_pipelines/modular_pipeline.py | 1 + src/diffusers/pipelines/__init__.py | 2 - .../pipelines/anima/pipeline_anima.py | 592 ------------------ src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 60 +- tests/modular_pipelines/anima/__init__.py | 1 + .../anima/test_modular_pipeline_anima.py | 230 +++++++ tests/pipelines/anima/__init__.py | 1 - tests/pipelines/anima/test_anima.py | 213 ------- 23 files changed, 1376 insertions(+), 854 deletions(-) rename src/diffusers/{pipelines/anima/modeling_anima.py => models/transformers/transformer_anima.py} (100%) rename src/diffusers/{pipelines => modular_pipelines}/anima/__init__.py (82%) create mode 100644 src/diffusers/modular_pipelines/anima/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/anima/decoders.py create mode 100644 src/diffusers/modular_pipelines/anima/denoise.py create mode 100644 src/diffusers/modular_pipelines/anima/encoders.py create mode 100644 src/diffusers/modular_pipelines/anima/modular_blocks_anima.py create mode 100644 src/diffusers/modular_pipelines/anima/modular_pipeline.py delete mode 100644 src/diffusers/pipelines/anima/pipeline_anima.py create mode 100644 tests/modular_pipelines/anima/__init__.py create mode 100644 tests/modular_pipelines/anima/test_modular_pipeline_anima.py delete mode 100644 tests/pipelines/anima/__init__.py delete mode 100644 tests/pipelines/anima/test_anima.py diff --git a/docs/source/en/api/pipelines/anima.md b/docs/source/en/api/pipelines/anima.md index fdb0a84d2967..5d21235cb9d8 100644 --- a/docs/source/en/api/pipelines/anima.md +++ b/docs/source/en/api/pipelines/anima.md @@ -4,17 +4,22 @@ Anima is a text-to-image model that reuses the [`CosmosTransformer3DModel`] with ```python import torch -from diffusers import AnimaPipeline +from diffusers import AnimaAutoBlocks -pipe = AnimaPipeline.from_pretrained("path/to/anima-diffusers", torch_dtype=torch.bfloat16) +pipe = AnimaAutoBlocks().init_pipeline("path/to/anima-diffusers") +pipe.load_components(torch_dtype=torch.bfloat16) pipe.to("cuda") -image = pipe("A cinematic portrait of a woman in a rain-soaked city street").images[0] +image = pipe(prompt="masterpiece, best quality, 1girl, solo, city lights").images[0] ``` -## AnimaPipeline +## AnimaModularPipeline -[[autodoc]] AnimaPipeline +[[autodoc]] AnimaModularPipeline + +## AnimaAutoBlocks + +[[autodoc]] AnimaAutoBlocks ## AnimaTextConditioner diff --git a/scripts/convert_anima_to_diffusers.py b/scripts/convert_anima_to_diffusers.py index 26cd8c7f029a..99005cd1e22a 100644 --- a/scripts/convert_anima_to_diffusers.py +++ b/scripts/convert_anima_to_diffusers.py @@ -26,7 
+26,7 @@ from transformers import AutoTokenizer, Qwen3Config, Qwen3Model, T5TokenizerFast from diffusers import ( - AnimaPipeline, + AnimaAutoBlocks, AnimaTextConditioner, AutoencoderKLQwenImage, FlowMatchEulerDiscreteScheduler, @@ -251,7 +251,8 @@ def save_pipeline(args, transformer, text_conditioner, text_encoder, vae): t5_tokenizer = T5TokenizerFast.from_pretrained(args.t5_tokenizer_path) scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) - pipe = AnimaPipeline( + pipe = AnimaAutoBlocks().init_pipeline() + pipe.update_components( text_encoder=text_encoder, tokenizer=tokenizer, t5_tokenizer=t5_tokenizer, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8a9499c7d077..f8e0ba59f16e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -190,6 +190,7 @@ [ "AceStepTransformer1DModel", "AllegroTransformer3DModel", + "AnimaTextConditioner", "AsymmetricAutoencoderKL", "AttentionBackendName", "AuraFlowTransformer2DModel", @@ -474,6 +475,8 @@ "QwenImageLayeredAutoBlocks", "QwenImageLayeredModularPipeline", "QwenImageModularPipeline", + "AnimaAutoBlocks", + "AnimaModularPipeline", "StableDiffusion3AutoBlocks", "StableDiffusion3ModularPipeline", "StableDiffusionXLAutoBlocks", @@ -502,8 +505,6 @@ "AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline", - "AnimaPipeline", - "AnimaTextConditioner", "AnimateDiffControlNetPipeline", "AnimateDiffPAGPipeline", "AnimateDiffPipeline", @@ -1014,6 +1015,7 @@ from .models import ( AceStepTransformer1DModel, AllegroTransformer3DModel, + AnimaTextConditioner, AsymmetricAutoencoderKL, AttentionBackendName, AuraFlowTransformer2DModel, @@ -1247,6 +1249,8 @@ from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .modular_pipelines import ( + AnimaAutoBlocks, + AnimaModularPipeline, ErnieImageAutoBlocks, ErnieImageModularPipeline, Flux2AutoBlocks, @@ -1303,8 +1307,6 @@ AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline, - AnimaPipeline, - AnimaTextConditioner, AnimateDiffControlNetPipeline, AnimateDiffPAGPipeline, AnimateDiffPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index bb765c56d013..6afb3f672de0 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -80,6 +80,7 @@ _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.ace_step_transformer"] = ["AceStepTransformer1DModel"] + _import_structure["transformers.transformer_anima"] = ["AnimaTextConditioner"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] _import_structure["transformers.consisid_transformer_3d"] = ["ConsisIDTransformer3DModel"] @@ -213,6 +214,7 @@ from .transformers import ( AceStepTransformer1DModel, AllegroTransformer3DModel, + AnimaTextConditioner, AuraFlowTransformer2DModel, BriaFiboTransformer2DModel, BriaTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 5c64b5fc99fa..aac6941abd27 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -18,6 +18,7 @@ from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .transformer_allegro import AllegroTransformer3DModel + from .transformer_anima import AnimaTextConditioner from 
.transformer_bria import BriaTransformer2DModel from .transformer_bria_fibo import BriaFiboTransformer2DModel from .transformer_chroma import ChromaTransformer2DModel diff --git a/src/diffusers/pipelines/anima/modeling_anima.py b/src/diffusers/models/transformers/transformer_anima.py similarity index 100% rename from src/diffusers/pipelines/anima/modeling_anima.py rename to src/diffusers/models/transformers/transformer_anima.py diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 0b2225c980b3..998f2c884d75 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -89,6 +89,10 @@ "QwenImageLayeredModularPipeline", "QwenImageLayeredAutoBlocks", ] + _import_structure["anima"] = [ + "AnimaAutoBlocks", + "AnimaModularPipeline", + ] _import_structure["ernie_image"] = [ "ErnieImageAutoBlocks", "ErnieImageModularPipeline", @@ -114,6 +118,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .anima import AnimaAutoBlocks, AnimaModularPipeline from .components_manager import ComponentsManager from .ernie_image import ErnieImageAutoBlocks, ErnieImageModularPipeline from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline diff --git a/src/diffusers/pipelines/anima/__init__.py b/src/diffusers/modular_pipelines/anima/__init__.py similarity index 82% rename from src/diffusers/pipelines/anima/__init__.py rename to src/diffusers/modular_pipelines/anima/__init__.py index bca1117ca7c6..4772d906e03b 100644 --- a/src/diffusers/pipelines/anima/__init__.py +++ b/src/diffusers/modular_pipelines/anima/__init__.py @@ -21,9 +21,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["modeling_anima"] = ["AnimaTextConditioner"] - _import_structure["pipeline_anima"] = ["AnimaPipeline"] - + _import_structure["modular_blocks_anima"] = ["AnimaAutoBlocks"] + _import_structure["modular_pipeline"] = ["AnimaModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -32,8 +31,8 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .modeling_anima import AnimaTextConditioner - from .pipeline_anima import AnimaPipeline + from .modular_blocks_anima import AnimaAutoBlocks + from .modular_pipeline import AnimaModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/anima/before_denoise.py b/src/diffusers/modular_pipelines/anima/before_denoise.py new file mode 100644 index 000000000000..bff17fe07e95 --- /dev/null +++ b/src/diffusers/modular_pipelines/anima/before_denoise.py @@ -0,0 +1,284 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
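Assuming the lazy `_import_structure` registrations above are applied, the relocated classes resolve from both the top-level package and their submodules; a quick import sanity sketch:

```python
# Both import paths should resolve to the same objects once the lazy
# registrations above are in place (sanity sketch, not part of the patch).
from diffusers import AnimaAutoBlocks, AnimaModularPipeline, AnimaTextConditioner
from diffusers.modular_pipelines import AnimaAutoBlocks as LazyBlocks

assert AnimaAutoBlocks is LazyBlocks
```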
+ +import inspect + +import numpy as np +import torch + +from ...models import CosmosTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import AnimaModularPipeline + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimaTextInputStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Input processing step that expands Anima prompt embeddings for the requested image batch." 
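As a worked example of the `retrieve_timesteps` helper above (a sketch assuming the helper is in scope from this module): passing custom `sigmas` overrides the scheduler's own spacing, and the returned count reflects the schedule actually set.

```python
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

# Drive the helper with the same flow-match scheduler the Anima blocks
# declare; the custom sigmas define the schedule directly.
scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0)
steps = 4
sigmas = np.linspace(1.0, 1 / steps, steps)
timesteps, steps = retrieve_timesteps(scheduler, sigmas=sigmas)
print(len(timesteps), steps)  # both 4
```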
+ + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", CosmosTransformer3DModel)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_images_per_prompt"), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Conditioned prompt embeddings generated by the text encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Conditioned negative prompt embeddings generated by the text encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Prompt embeddings expanded to the final denoising batch.", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Negative prompt embeddings expanded to the final denoising batch.", + ), + OutputParam( + "batch_size", + type_hint=int, + description="Number of input prompts before `num_images_per_prompt` expansion.", + ), + OutputParam("dtype", type_hint=torch.dtype, description="Dtype used by the Anima denoiser."), + ] + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = components.transformer.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + return components, state + + +class AnimaPrepareLatentsStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Prepare noisy image latents and padding mask for Anima denoising." 
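The repeat/view dance in `AnimaTextInputStep` above grows the batch while keeping per-prompt ordering; a small tensor sketch makes the resulting layout explicit.

```python
import torch

# Sketch of the expansion in AnimaTextInputStep: repeating along the sequence
# axis and re-viewing yields num_images_per_prompt contiguous copies per prompt.
batch_size, seq_len, dim = 2, 8, 16
num_images_per_prompt = 3

prompt_embeds = torch.randn(batch_size, seq_len, dim)
expanded = prompt_embeds.repeat(1, num_images_per_prompt, 1)
expanded = expanded.view(batch_size * num_images_per_prompt, seq_len, -1)

assert expanded.shape == (6, 8, 16)
assert torch.equal(expanded[0], expanded[2])       # copies of prompt 0
assert torch.equal(expanded[3], prompt_embeds[1])  # first copy of prompt 1
```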
+ + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", CosmosTransformer3DModel)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("latents"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of input prompts before `num_images_per_prompt` expansion.", + ), + InputParam("dtype", type_hint=torch.dtype, description="Dtype used by the Anima denoiser."), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("height", type_hint=int, description="Image height used for generation."), + OutputParam("width", type_hint=int, description="Image width used for generation."), + OutputParam("latents", type_hint=torch.Tensor, description="Noisy latents for the denoising process."), + OutputParam("padding_mask", type_hint=torch.Tensor, description="Cosmos padding mask for image latents."), + ] + + def check_inputs(self, components: AnimaModularPipeline, block_state): + divisor = components.vae_scale_factor * 2 + if block_state.height % divisor != 0 or block_state.width % divisor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {divisor} but are {block_state.height} and" + f" {block_state.width}." + ) + + @staticmethod + def prepare_latents( + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + vae_scale_factor: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | list[torch.Generator] | None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + latent_height = height // vae_scale_factor + latent_width = width // vae_scale_factor + shape = (batch_size, num_channels_latents, 1, latent_height, latent_width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + return randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + self.check_inputs(components, block_state) + + device = components._execution_device + block_state.latents = self.prepare_latents( + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_channels_latents=components.num_channels_latents, + height=block_state.height, + width=block_state.width, + vae_scale_factor=components.vae_scale_factor, + dtype=torch.float32, + device=device, + generator=block_state.generator, + latents=block_state.latents, + ) + block_state.padding_mask = block_state.latents.new_zeros( + 1, 1, block_state.height, block_state.width, dtype=block_state.dtype + ) + + self.set_block_state(state, block_state) + return components, state + + +class AnimaSetTimestepsStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Set the scheduler timesteps for Anima inference." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for the denoising loop."), + OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps."), + ] + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + sigmas = ( + np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + if block_state.sigmas is None + else block_state.sigmas + ) + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + device=device, + sigmas=sigmas, + ) + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/anima/decoders.py b/src/diffusers/modular_pipelines/anima/decoders.py new file mode 100644 index 000000000000..f1f4b475a4b8 --- /dev/null +++ b/src/diffusers/modular_pipelines/anima/decoders.py @@ -0,0 +1,120 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
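A worked shape example for the latent preparation above, with assumed full-size values (the fast tests use smaller ones): given a spatial `vae_scale_factor` of 8 and a single latent frame, a 1024x1024 request produces `(batch, C, 1, 128, 128)` latents plus a pixel-space padding mask.

```python
import torch
from diffusers.utils.torch_utils import randn_tensor

# Assumed values for illustration: 16 latent channels, vae_scale_factor 8.
batch_size, num_channels_latents, vae_scale_factor = 2, 16, 8
height = width = 1024

shape = (batch_size, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor)
generator = torch.Generator("cpu").manual_seed(0)
latents = randn_tensor(shape, generator=generator, dtype=torch.float32)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=torch.bfloat16)

assert latents.shape == (2, 16, 1, 128, 128)
assert padding_mask.shape == (1, 1, 1024, 1024)
```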
+ +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLQwenImage +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import AnimaModularPipeline + + +class AnimaVaeDecoderStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Step that decodes Anima latents into image tensors." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKLQwenImage)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor, description="Denoised Anima latents."), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images", note="tensor output of the VAE decoder")] + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latents = block_state.latents.to(components.vae.dtype) + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + block_state.images = components.vae.decode(latents, return_dict=False)[0][:, :, 0] + + self.set_block_state(state, block_state) + return components, state + + +class AnimaProcessImagesOutputStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Postprocess decoded Anima image tensors." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("images", required=True, type_hint=torch.Tensor, description="Decoded Anima image tensors."), + InputParam.template("output_type"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "images", + type_hint=list[PIL.Image.Image] | np.ndarray | torch.Tensor, + description="Generated images.", + ) + ] + + @staticmethod + def check_inputs(output_type): + if output_type not in ["pil", "np", "pt"]: + raise ValueError(f"Invalid output_type: {output_type}") + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state.output_type) + + block_state.images = components.image_processor.postprocess( + image=block_state.images, + output_type=block_state.output_type, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/anima/denoise.py b/src/diffusers/modular_pipelines/anima/denoise.py new file mode 100644 index 000000000000..2c7f36484749 --- /dev/null +++ b/src/diffusers/modular_pipelines/anima/denoise.py @@ -0,0 +1,214 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import CosmosTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam +from .modular_pipeline import AnimaModularPipeline + + +class AnimaLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Step within the denoising loop that prepares Anima latent and timestep inputs." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor, description="Current Anima latents."), + InputParam("dtype", required=True, type_hint=torch.dtype, description="Dtype used by the Anima denoiser."), + ] + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = block_state.latents.to(block_state.dtype) + + timestep = t.expand(block_state.latents.shape[0]).to(block_state.dtype) + block_state.timestep = timestep / components.scheduler.config.num_train_timesteps + return components, block_state + + +class AnimaLoopDenoiser(ModularPipelineBlocks): + model_name = "anima" + + def __init__( + self, + guider_input_fields: dict[str, Any] | None = None, + ): + if guider_input_fields is None: + guider_input_fields = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")} + if not isinstance(guider_input_fields, dict): + raise ValueError(f"`guider_input_fields` must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", CosmosTransformer3DModel), + ] + + @property + def description(self) -> str: + return "Step within the denoising loop that predicts Anima noise with guidance." 
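The before-denoiser step above converts raw scheduler timesteps into the normalized flow-matching time the transformer consumes; a one-line sketch (1000 is an assumed `num_train_timesteps` config value):

```python
import torch

num_train_timesteps = 1000  # assumed scheduler config value
t = torch.tensor(875.0)     # one raw scheduler timestep
timestep = t.expand(4) / num_train_timesteps  # broadcast over the latent batch
print(timestep)  # tensor([0.8750, 0.8750, 0.8750, 0.8750])
```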
+ + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="Number of denoising steps.", + ), + InputParam( + "padding_mask", + required=True, + type_hint=torch.Tensor, + description="Cosmos padding mask for image latents.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description="The conditional model inputs for the Anima denoiser.", + ), + ] + + guider_input_names = [] + uncond_guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.append(value[0]) + uncond_guider_input_names.append(value[1]) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True)) + for name in uncond_guider_input_names: + inputs.append(InputParam(name=name)) + return inputs + + @torch.no_grad() + def __call__( + self, components: AnimaModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = { + key: getattr(guider_state_batch, key).to(block_state.dtype) + for key in self._guider_input_fields.keys() + } + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + timestep=block_state.timestep, + padding_mask=block_state.padding_mask, + return_dict=False, + **cond_kwargs, + )[0] + components.guider.cleanup_models(components.transformer) + + block_state.noise_pred = components.guider(guider_state)[0] + return components, block_state + + +class AnimaLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "anima" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step within the denoising loop that updates Anima latents." + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, t, block_state.latents, return_dict=False + )[0] + if block_state.latents.dtype != latents_dtype and torch.backends.mps.is_available(): + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class AnimaDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Pipeline block that iteratively denoises Anima latents over scheduler timesteps." 
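For the default guider declared above, the combination performed after the conditional and unconditional transformer calls reduces to the familiar classifier-free-guidance update, the same formula the pre-modular pipeline loop used; a sketch with dummy predictions:

```python
import torch

guidance_scale = 4.0  # matches the ClassifierFreeGuidance default above
noise_pred_cond = torch.randn(1, 16, 1, 4, 4)
noise_pred_uncond = torch.randn(1, 16, 1, 4, 4)

# uncond + scale * (cond - uncond)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
```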
+ + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam("timesteps", required=True, type_hint=torch.Tensor, description="Timesteps to denoise over."), + InputParam("num_inference_steps", required=True, type_hint=int, description="Number of denoising steps."), + ] + + @torch.no_grad() + def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + num_warmup_steps = len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + return components, state + + +class AnimaDenoiseStep(AnimaDenoiseLoopWrapper): + block_classes = [ + AnimaLoopBeforeDenoiser, + AnimaLoopDenoiser( + guider_input_fields={"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")} + ), + AnimaLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return "Denoise step that iteratively denoises image latents for Anima." diff --git a/src/diffusers/modular_pipelines/anima/encoders.py b/src/diffusers/modular_pipelines/anima/encoders.py new file mode 100644 index 000000000000..3462ee73cfa3 --- /dev/null +++ b/src/diffusers/modular_pipelines/anima/encoders.py @@ -0,0 +1,225 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers import Qwen2Tokenizer, Qwen3Model, T5TokenizerFast + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import AnimaTextConditioner +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import AnimaModularPipeline + + +class AnimaTextEncoderStep(ModularPipelineBlocks): + model_name = "anima" + + @property + def description(self) -> str: + return "Text encoder step that maps prompts to Cosmos text conditioning for Anima." 
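The encoder step above funnels two tokenizations into one conditioning tensor. The call it ultimately makes can be sketched with the same dummy dimensions the fast tests use, where Qwen3 hidden states act as the cross-attention source and T5 token ids as the target sequence:

```python
import torch
from diffusers import AnimaTextConditioner

# Dummy dimensions mirroring the fast tests earlier in this patch series.
conditioner = AnimaTextConditioner(
    source_dim=16, target_dim=16, model_dim=16,
    num_layers=2, num_attention_heads=4,
    target_vocab_size=128, min_sequence_length=8,
)
source_hidden_states = torch.randn(2, 5, 16)      # stand-in Qwen3 last_hidden_state
target_input_ids = torch.randint(0, 128, (2, 4))  # stand-in T5 token ids

embeds = conditioner(
    source_hidden_states=source_hidden_states,
    target_input_ids=target_input_ids,
    source_attention_mask=torch.ones(2, 5),
    target_attention_mask=torch.ones(2, 4),
)
print(embeds.shape)  # torch.Size([2, 8, 16]): padded to min_sequence_length
```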
+ + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3Model), + ComponentSpec("tokenizer", Qwen2Tokenizer), + ComponentSpec("t5_tokenizer", T5TokenizerFast), + ComponentSpec("text_conditioner", AnimaTextConditioner), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam.template("max_sequence_length"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Conditioned prompt embeddings generated by the Anima text conditioner.", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Conditioned negative prompt embeddings generated by the Anima text conditioner.", + ), + ] + + @staticmethod + def check_inputs(block_state): + if not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + if block_state.max_sequence_length is not None and block_state.max_sequence_length > 4096: + raise ValueError(f"`max_sequence_length` cannot be greater than 4096 but is {block_state.max_sequence_length}") + + @staticmethod + def _get_qwen_prompt_embeds( + components: AnimaModularPipeline, + prompt: str | list[str], + max_sequence_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = components.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + prompt_attention_mask = text_inputs.attention_mask.to(device) + if text_input_ids.shape[-1] == 0: + text_input_ids = text_input_ids.new_zeros((text_input_ids.shape[0], 1)) + prompt_attention_mask = prompt_attention_mask.new_zeros((prompt_attention_mask.shape[0], 1)) + + prompt_embeds = components.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=False, + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = prompt_embeds * prompt_attention_mask.to(prompt_embeds).unsqueeze(-1) + + return prompt_embeds, prompt_attention_mask + + @staticmethod + def _get_t5_prompt_ids( + components: AnimaModularPipeline, + prompt: str | list[str], + max_sequence_length: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = components.t5_tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + return text_inputs.input_ids.to(device), text_inputs.attention_mask.to(device) + + @classmethod + def encode_prompt( + cls, + components: AnimaModularPipeline, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + prepare_unconditional_embeds: bool = True, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + device = device 
or components._execution_device
+        dtype = dtype or components.text_conditioner.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        qwen_prompt_embeds, qwen_attention_mask = cls._get_qwen_prompt_embeds(
+            components=components,
+            prompt=prompt,
+            max_sequence_length=max_sequence_length,
+            device=device,
+            dtype=dtype,
+        )
+        t5_input_ids, t5_attention_mask = cls._get_t5_prompt_ids(
+            components=components,
+            prompt=prompt,
+            max_sequence_length=max_sequence_length,
+            device=device,
+        )
+        prompt_embeds = components.text_conditioner(
+            source_hidden_states=qwen_prompt_embeds,
+            target_input_ids=t5_input_ids,
+            target_attention_mask=t5_attention_mask,
+            source_attention_mask=qwen_attention_mask,
+        )
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        negative_prompt_embeds = None
+        if prepare_unconditional_embeds:
+            negative_prompt = negative_prompt if negative_prompt is not None else ""
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            if batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+
+            negative_qwen_prompt_embeds, negative_qwen_attention_mask = cls._get_qwen_prompt_embeds(
+                components=components,
+                prompt=negative_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+            negative_t5_input_ids, negative_t5_attention_mask = cls._get_t5_prompt_ids(
+                components=components,
+                prompt=negative_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+            )
+            negative_prompt_embeds = components.text_conditioner(
+                source_hidden_states=negative_qwen_prompt_embeds,
+                target_input_ids=negative_t5_input_ids,
+                target_attention_mask=negative_t5_attention_mask,
+                source_attention_mask=negative_qwen_attention_mask,
+            )
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+        return prompt_embeds, negative_prompt_embeds
+
+    @torch.no_grad()
+    def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        self.check_inputs(block_state)
+
+        block_state.prompt_embeds, block_state.negative_prompt_embeds = self.encode_prompt(
+            components=components,
+            prompt=block_state.prompt,
+            negative_prompt=block_state.negative_prompt,
+            prepare_unconditional_embeds=components.guider.num_conditions > 1,
+            max_sequence_length=block_state.max_sequence_length,
+            device=components._execution_device,
+            dtype=components.transformer.dtype,
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
diff --git a/src/diffusers/modular_pipelines/anima/modular_blocks_anima.py b/src/diffusers/modular_pipelines/anima/modular_blocks_anima.py
new file mode 100644
index 000000000000..af966428cc71
--- /dev/null
+++ b/src/diffusers/modular_pipelines/anima/modular_blocks_anima.py
@@ -0,0 +1,173 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import AnimaPrepareLatentsStep, AnimaSetTimestepsStep, AnimaTextInputStep +from .decoders import AnimaProcessImagesOutputStep, AnimaVaeDecoderStep +from .denoise import AnimaDenoiseStep +from .encoders import AnimaTextEncoderStep + + +# auto_docstring +class AnimaCoreDenoiseStep(SequentialPipelineBlocks): + """ + Denoise block that takes encoded Anima text conditions and runs the denoising process. + + Components: + transformer (`CosmosTransformer3DModel`) + scheduler (`FlowMatchEulerDiscreteScheduler`) + guider (`ClassifierFreeGuidance`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + Conditioned prompt embeddings generated by the text encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Conditioned negative prompt embeddings generated by the text encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + **denoiser_input_fields (`None`, *optional*): + The conditional model inputs for the Anima denoiser. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + block_classes = [ + AnimaTextInputStep, + AnimaPrepareLatentsStep, + AnimaSetTimestepsStep, + AnimaDenoiseStep, + ] + block_names = ["input", "prepare_latents", "set_timesteps", "denoise"] + + @property + def description(self) -> str: + return "Denoise block that takes encoded Anima text conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class AnimaDecodeStep(SequentialPipelineBlocks): + """ + Decode Anima latents into generated images. + + Components: + vae (`AutoencoderKLQwenImage`) + image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + Denoised Anima latents. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + block_classes = [AnimaVaeDecoderStep, AnimaProcessImagesOutputStep] + block_names = ["decode", "postprocess"] + + @property + def description(self) -> str: + return "Decode Anima latents into generated images." + + @property + def outputs(self): + return [OutputParam.template("images")] + + +# auto_docstring +class AnimaAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image generation using Anima. 
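+    The blocks run in sequence: prompt encoding, core denoising, and VAE decoding.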
+
+    Supported workflows:
+    - `text2image`: requires `prompt`
+
+    Components:
+        text_encoder (`Qwen3Model`)
+        tokenizer (`Qwen2Tokenizer`)
+        t5_tokenizer (`T5TokenizerFast`)
+        text_conditioner (`AnimaTextConditioner`)
+        guider (`ClassifierFreeGuidance`)
+        transformer (`CosmosTransformer3DModel`)
+        scheduler (`FlowMatchEulerDiscreteScheduler`)
+        vae (`AutoencoderKLQwenImage`)
+        image_processor (`VaeImageProcessor`)
+
+    Inputs:
+        prompt (`str`):
+            The prompt or prompts to guide image generation.
+        negative_prompt (`str`, *optional*):
+            The prompt or prompts not to guide the image generation.
+        max_sequence_length (`int`, *optional*, defaults to 512):
+            Maximum sequence length for prompt encoding.
+        num_images_per_prompt (`int`, *optional*, defaults to 1):
+            The number of images to generate per prompt.
+        height (`int`, *optional*):
+            The height in pixels of the generated image.
+        width (`int`, *optional*):
+            The width in pixels of the generated image.
+        latents (`Tensor`, *optional*):
+            Pre-generated noisy latents for image generation.
+        generator (`Generator`, *optional*):
+            Torch generator for deterministic generation.
+        num_inference_steps (`int`, *optional*, defaults to 50):
+            The number of denoising steps.
+        sigmas (`list`, *optional*):
+            Custom sigmas for the denoising process.
+        **denoiser_input_fields (`None`, *optional*):
+            The conditional model inputs for the Anima denoiser.
+        output_type (`str`, *optional*, defaults to pil):
+            Output format: 'pil', 'np', 'pt'.
+
+    Outputs:
+        images (`list`):
+            Generated images.
+    """
+
+    block_classes = [
+        AnimaTextEncoderStep,
+        AnimaCoreDenoiseStep,
+        AnimaDecodeStep,
+    ]
+    block_names = ["text_encoder", "denoise", "decode"]
+    _workflow_map = {"text2image": {"prompt": True}}
+
+    @property
+    def description(self) -> str:
+        return "Auto Modular pipeline for text-to-image generation using Anima."
+
+    @property
+    def outputs(self):
+        return [OutputParam.template("images")]
diff --git a/src/diffusers/modular_pipelines/anima/modular_pipeline.py b/src/diffusers/modular_pipelines/anima/modular_pipeline.py
new file mode 100644
index 000000000000..44fce4657c6f
--- /dev/null
+++ b/src/diffusers/modular_pipelines/anima/modular_pipeline.py
@@ -0,0 +1,52 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...loaders import AnimaLoraLoaderMixin
+from ..modular_pipeline import ModularPipeline
+
+
+class AnimaModularPipeline(ModularPipeline, AnimaLoraLoaderMixin):
+    """
+    A ModularPipeline for Anima.
+
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.
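+
+    Example (a minimal sketch; the hub id is the one used in the Anima docs, and the final
+    call assumes the standard modular `output="images"` convention):
+
+    ```python
+    >>> import torch
+    >>> from diffusers import ModularPipeline
+
+    >>> pipe = ModularPipeline.from_pretrained("mrfatso/anima-preview3-diffusers")
+    >>> pipe.load_components(torch_dtype=torch.bfloat16)
+    >>> pipe.to("cuda")
+    >>> image = pipe(prompt="A cinematic portrait of a woman in a rain-soaked city street", output="images")[0]
+    ```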
+ """ + + default_blocks_name = "AnimaAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels + return num_channels_latents diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 8cfe07059272..37e069bc05f8 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -130,6 +130,7 @@ def _helios_pyramid_map_fn(config_dict=None): ("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")), ("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")), ("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")), + ("anima", _create_default_map_fn("AnimaModularPipeline")), ("z-image", _create_default_map_fn("ZImageModularPipeline")), ("helios", _create_default_map_fn("HeliosModularPipeline")), ("helios-pyramid", _helios_pyramid_map_fn), diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c03d588ac152..f0fc7585bf31 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -100,7 +100,6 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["anima"] = ["AnimaPipeline", "AnimaTextConditioner"] _import_structure["deprecated"].extend( [ "AmusedImg2ImgPipeline", @@ -589,7 +588,6 @@ AceStepPipeline, ) from .allegro import AllegroPipeline - from .anima import AnimaPipeline, AnimaTextConditioner from .animatediff import ( AnimateDiffControlNetPipeline, AnimateDiffPipeline, diff --git a/src/diffusers/pipelines/anima/pipeline_anima.py b/src/diffusers/pipelines/anima/pipeline_anima.py deleted file mode 100644 index 33625192b78c..000000000000 --- a/src/diffusers/pipelines/anima/pipeline_anima.py +++ /dev/null @@ -1,592 +0,0 @@ -# Copyright 2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import inspect -from typing import Callable - -import numpy as np -import torch -from transformers import PreTrainedModel, PreTrainedTokenizer, T5TokenizerFast - -from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import VaeImageProcessor -from ...loaders import AnimaLoraLoaderMixin -from ...models import AutoencoderKLQwenImage, CosmosTransformer3DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .modeling_anima import AnimaTextConditioner - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```python - >>> import torch - >>> from diffusers import AnimaPipeline - - >>> pipe = AnimaPipeline.from_pretrained("path/to/anima-diffusers", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") - >>> image = pipe("A cinematic portrait of a woman in a rain-soaked city street").images[0] - >>> image.save("anima.png") - ``` -""" - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: int | None = None, - device: str | torch.device | None = None, - timesteps: list[int] | None = None, - sigmas: list[float] | None = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`list[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`list[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." 
- ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class AnimaPipeline(DiffusionPipeline, AnimaLoraLoaderMixin): - r""" - Pipeline for text-to-image generation using Anima. - - Anima uses a Qwen3 text encoder, a T5-token LLM adapter, the Cosmos Predict2 DiT, and the Qwen-Image VAE. - Supports loading LoRA weights with [`~loaders.AnimaLoraLoaderMixin.load_lora_weights`]. - - Args: - text_encoder (`~transformers.PreTrainedModel`): - Qwen3 text encoder used to produce source hidden states for the Anima text conditioner. - tokenizer (`~transformers.PreTrainedTokenizer`): - Qwen tokenizer paired with `text_encoder`. - t5_tokenizer (`~transformers.T5TokenizerFast`): - T5 tokenizer used to produce target token ids for the Anima text conditioner. - text_conditioner ([`AnimaTextConditioner`]): - Adapter that maps Qwen3 hidden states and T5 token ids to Cosmos text embeddings. - transformer ([`CosmosTransformer3DModel`]): - Cosmos Predict2 transformer used to denoise image latents. - vae ([`AutoencoderKLQwenImage`]): - Qwen-Image VAE used to decode latents into images. - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - Flow-matching scheduler used for denoising. 
- """ - - model_cpu_offload_seq = "text_encoder->text_conditioner->transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - - def __init__( - self, - text_encoder: PreTrainedModel, - tokenizer: PreTrainedTokenizer, - t5_tokenizer: T5TokenizerFast, - text_conditioner: AnimaTextConditioner, - transformer: CosmosTransformer3DModel, - vae: AutoencoderKLQwenImage, - scheduler: FlowMatchEulerDiscreteScheduler, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - t5_tokenizer=t5_tokenizer, - text_conditioner=text_conditioner, - transformer=transformer, - scheduler=scheduler, - ) - - self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = 128 - - def _get_qwen_prompt_embeds( - self, - prompt: str | list[str], - max_sequence_length: int, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - prompt = [prompt] if isinstance(prompt, str) else prompt - - text_inputs = self.tokenizer( - prompt, - padding="longest", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids.to(device) - prompt_attention_mask = text_inputs.attention_mask.to(device) - if text_input_ids.shape[-1] == 0: - text_input_ids = text_input_ids.new_zeros((text_input_ids.shape[0], 1)) - prompt_attention_mask = prompt_attention_mask.new_zeros((prompt_attention_mask.shape[0], 1)) - - prompt_embeds = self.text_encoder( - input_ids=text_input_ids, - attention_mask=prompt_attention_mask, - output_hidden_states=False, - ).last_hidden_state - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_embeds = prompt_embeds * prompt_attention_mask.to(prompt_embeds).unsqueeze(-1) - - return prompt_embeds, prompt_attention_mask - - def _get_t5_prompt_ids( - self, - prompt: str | list[str], - max_sequence_length: int, - device: torch.device | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - device = device or self._execution_device - prompt = [prompt] if isinstance(prompt, str) else prompt - - text_inputs = self.t5_tokenizer( - prompt, - padding="longest", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", - ) - return text_inputs.input_ids.to(device), text_inputs.attention_mask.to(device) - - def encode_prompt( - self, - prompt: str | list[str], - negative_prompt: str | list[str] | None = None, - do_classifier_free_guidance: bool = True, - num_images_per_prompt: int = 1, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - max_sequence_length: int = 512, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - device = device or self._execution_device - dtype = dtype or self.text_conditioner.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] - - if prompt_embeds is None: - qwen_prompt_embeds, qwen_attention_mask = self._get_qwen_prompt_embeds( - prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype - ) - t5_input_ids, t5_attention_mask = self._get_t5_prompt_ids( - prompt=prompt, 
max_sequence_length=max_sequence_length, device=device - ) - prompt_embeds = self.text_conditioner( - source_hidden_states=qwen_prompt_embeds, - target_input_ids=t5_input_ids, - target_attention_mask=t5_attention_mask, - source_attention_mask=qwen_attention_mask, - ) - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt if negative_prompt is not None else "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - negative_qwen_prompt_embeds, negative_qwen_attention_mask = self._get_qwen_prompt_embeds( - prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype - ) - negative_t5_input_ids, negative_t5_attention_mask = self._get_t5_prompt_ids( - prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device - ) - negative_prompt_embeds = self.text_conditioner( - source_hidden_states=negative_qwen_prompt_embeds, - target_input_ids=negative_t5_input_ids, - target_attention_mask=negative_t5_attention_mask, - source_attention_mask=negative_qwen_attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = negative_prompt_embeds.shape - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds, negative_prompt_embeds - - def check_inputs( - self, - prompt, - height, - width, - prompt_embeds=None, - negative_prompt=None, - negative_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, - ): - if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: - raise ValueError( - f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and" - f" {width}." - ) - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " - f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - if prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
- ) - if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if max_sequence_length is not None and max_sequence_length > 4096: - raise ValueError(f"`max_sequence_length` cannot be greater than 4096 but is {max_sequence_length}") - - def prepare_latents( - self, - batch_size: int, - num_channels_latents: int, - height: int, - width: int, - num_frames: int, - dtype: torch.dtype, - device: torch.device, - generator: torch.Generator | list[torch.Generator] | None, - latents: torch.Tensor | None = None, - ) -> torch.Tensor: - if latents is not None: - return latents.to(device=device, dtype=dtype) - - num_latent_frames = (num_frames - 1) // self.vae_scale_factor + 1 - latent_height = height // self.vae_scale_factor - latent_width = width // self.vae_scale_factor - shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - return randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def current_timestep(self): - return self._current_timestep - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: str | list[str] = None, - negative_prompt: str | list[str] | None = None, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 4.0, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - output_type: str | None = "pil", - return_dict: bool = True, - callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `list[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, pass `prompt_embeds`. - negative_prompt (`str` or `list[str]`, *optional*): - The prompt or prompts not to guide image generation. Used when `guidance_scale > 1`. - height (`int`, *optional*, defaults to `1024`): - Height in pixels of the generated image. - width (`int`, *optional*, defaults to `1024`): - Width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - Number of denoising steps. - sigmas (`list[float]`, *optional*): - Custom sigma schedule to use for schedulers that support `sigmas`. 
- guidance_scale (`float`, *optional*, defaults to 4.0): - Classifier-free guidance scale. - num_images_per_prompt (`int`, *optional*, defaults to 1): - Number of images to generate per prompt. - generator (`torch.Generator` or `list[torch.Generator]`, *optional*): - Random generator for deterministic generation. - latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents. - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated conditioned prompt embeddings. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated conditioned negative prompt embeddings. - output_type (`str`, *optional*, defaults to `"pil"`): - Output format, one of `"pil"`, `"np"`, `"pt"`, or `"latent"`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return an [`~pipelines.ImagePipelineOutput`]. - callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): - Function called at the end of each denoising step. - callback_on_step_end_tensor_inputs (`list`, *optional*): - Tensor inputs available to `callback_on_step_end`. - max_sequence_length (`int`, *optional*, defaults to 512): - Maximum sequence length used by both text tokenizers. - - Examples: - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: - Generated images if `return_dict` is `True`; otherwise a tuple whose first item is the images. - """ - - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): - callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - - self.check_inputs( - prompt, - height, - width, - prompt_embeds, - negative_prompt, - negative_prompt_embeds, - callback_on_step_end_tensor_inputs, - max_sequence_length, - ) - - self._guidance_scale = guidance_scale - self._current_timestep = None - self._interrupt = False - - device = self._execution_device - num_frames = 1 - do_classifier_free_guidance = guidance_scale > 1.0 - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - num_images_per_prompt=num_images_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=max_sequence_length, - device=device, - dtype=self.transformer.dtype, - ) - - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas) - self.scheduler.set_begin_index(0) - - transformer_dtype = self.transformer.dtype - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - num_frames, - torch.float32, - device, - generator, - latents, - ) - - padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) - - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - 
self._current_timestep = t - timestep = t.expand(latents.shape[0]).to(transformer_dtype) - timestep = timestep / self.scheduler.config.num_train_timesteps - latent_model_input = latents.to(transformer_dtype) - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - padding_mask=padding_mask, - return_dict=False, - )[0] - - if do_classifier_free_guidance: - negative_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - padding_mask=padding_mask, - return_dict=False, - )[0] - noise_pred = negative_noise_pred + self.guidance_scale * (noise_pred - negative_noise_pred) - - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - if latents.dtype != latents_dtype and torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - self._current_timestep = None - - if output_type == "latent": - image = latents[:, :, 0] - else: - latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) - latents = latents / latents_std + latents_mean - image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] - image = self.image_processor.postprocess(image, output_type=output_type) - - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 9bfb73c1999e..bc9856604928 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -435,6 +435,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AnimaTextConditioner(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AsymmetricAutoencoderKL(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index b8b05fce8818..d320e2374f50 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,36 @@ from ..utils import DummyObject, requires_backends +class AnimaAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, 
**kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AnimaModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ErnieImageAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -812,36 +842,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class AnimaPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class AnimaTextConditioner(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class AnimateDiffControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/anima/__init__.py b/tests/modular_pipelines/anima/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/tests/modular_pipelines/anima/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/modular_pipelines/anima/test_modular_pipeline_anima.py b/tests/modular_pipelines/anima/test_modular_pipeline_anima.py new file mode 100644 index 000000000000..38fc0d3649d6 --- /dev/null +++ b/tests/modular_pipelines/anima/test_modular_pipeline_anima.py @@ -0,0 +1,230 @@ +# Copyright 2026 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
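+
+# Descriptive note: these fast tests wire up tiny, randomly initialized Anima components, so
+# they verify the modular workflow plumbing (block order, output shapes, LoRA key conversion)
+# rather than generation quality.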
+ +import tempfile +import unittest + +import torch +from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model, T5TokenizerFast + +from diffusers import ( + AnimaAutoBlocks, + AnimaModularPipeline, + AnimaTextConditioner, + AutoencoderKLQwenImage, + CosmosTransformer3DModel, + FlowMatchEulerDiscreteScheduler, +) + +from ...testing_utils import enable_full_determinism, require_peft_backend +from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin + + +enable_full_determinism() + + +ANIMA_TEXT2IMAGE_WORKFLOWS = { + "text2image": [ + ("text_encoder", "AnimaTextEncoderStep"), + ("denoise.input", "AnimaTextInputStep"), + ("denoise.prepare_latents", "AnimaPrepareLatentsStep"), + ("denoise.set_timesteps", "AnimaSetTimestepsStep"), + ("denoise.denoise", "AnimaDenoiseStep"), + ("decode.decode", "AnimaVaeDecoderStep"), + ("decode.postprocess", "AnimaProcessImagesOutputStep"), + ], +} + + +def get_dummy_components(): + torch.manual_seed(0) + transformer = CosmosTransformer3DModel( + in_channels=4, + out_channels=4, + num_attention_heads=2, + attention_head_dim=16, + num_layers=2, + mlp_ratio=2, + text_embed_dim=16, + adaln_lora_dim=4, + max_size=(4, 32, 32), + patch_size=(1, 2, 2), + rope_scale=(1.0, 4.0, 4.0), + concat_padding_mask=True, + extra_pos_embed_type=None, + ) + + torch.manual_seed(0) + vae = AutoencoderKLQwenImage( + base_dim=24, + z_dim=4, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + ) + + torch.manual_seed(0) + text_conditioner = AnimaTextConditioner( + source_dim=16, + target_dim=16, + model_dim=16, + num_layers=2, + num_attention_heads=4, + target_vocab_size=32128, + min_sequence_length=16, + ) + + torch.manual_seed(0) + text_encoder_config = Qwen3Config( + vocab_size=152064, + hidden_size=16, + intermediate_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + max_position_embeddings=128, + rms_norm_eps=1e-6, + rope_theta=1000000.0, + head_dim=4, + attention_bias=False, + ) + text_encoder = Qwen3Model(text_encoder_config).eval() + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + t5_tokenizer = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5") + scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) + + return { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "t5_tokenizer": t5_tokenizer, + "text_conditioner": text_conditioner, + } + + +class AnimaTextConditionerFastTests(unittest.TestCase): + def test_conditioner_output_shape_and_padding(self): + conditioner = AnimaTextConditioner( + source_dim=16, + target_dim=16, + model_dim=16, + num_layers=2, + num_attention_heads=4, + target_vocab_size=128, + min_sequence_length=8, + ) + source_hidden_states = torch.randn(2, 5, 16) + target_input_ids = torch.randint(0, 128, (2, 4)) + source_attention_mask = torch.ones(2, 5) + target_attention_mask = torch.ones(2, 4) + target_attention_mask[1, -1] = 0 + + output = conditioner( + source_hidden_states=source_hidden_states, + target_input_ids=target_input_ids, + source_attention_mask=source_attention_mask, + target_attention_mask=target_attention_mask, + ) + + self.assertEqual(output.shape, (2, 8, 16)) + self.assertTrue(torch.allclose(output[1, 3], torch.zeros_like(output[1, 3]), atol=1e-5)) + self.assertTrue(torch.allclose(output[:, 4:], torch.zeros_like(output[:, 4:]), 
atol=1e-5)) + + +class TestAnimaModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin): + pipeline_class = AnimaModularPipeline + pipeline_blocks_class = AnimaAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-anima-modular-pipe" + params = frozenset(["prompt", "height", "width", "negative_prompt"]) + batch_params = frozenset(["prompt", "negative_prompt"]) + expected_workflow_blocks = ANIMA_TEXT2IMAGE_WORKFLOWS + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipe = self.pipeline_blocks_class().init_pipeline(components_manager=components_manager) + pipe.update_components(**get_dummy_components()) + pipe.to(dtype=torch_dtype) + pipe.set_progress_bar_config(disable=None) + return pipe + + def get_dummy_inputs(self, seed=0): + generator = torch.Generator(device="cpu").manual_seed(seed) + return { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + + def test_inference_empty_negative_prompt(self): + pipe = self.get_pipeline() + + inputs = self.get_dummy_inputs() + inputs["negative_prompt"] = "" + output = pipe(**inputs).images + + assert output.shape == (1, 3, 32, 32) + assert not torch.isnan(output).any() + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=5e-4) + + def test_save_load_components(self): + pipe = self.get_pipeline() + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=True) + pipe = self.pipeline_class.from_pretrained(tmpdir) + pipe.load_components() + + assert isinstance(pipe.text_conditioner, AnimaTextConditioner) + assert isinstance(pipe.transformer, CosmosTransformer3DModel) + + def test_lora_state_dict_conversion(self): + state_dict = { + "diffusion_model.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 32), + "diffusion_model.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(32, 2), + "diffusion_model.blocks.0.adaln_modulation_cross_attn.1.lora_A.weight": torch.randn(2, 32), + "diffusion_model.blocks.0.adaln_modulation_cross_attn.1.lora_B.weight": torch.randn(4, 2), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 16), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(16, 2), + } + + converted_state_dict = self.pipeline_class.lora_state_dict(state_dict) + + assert "transformer.transformer_blocks.0.attn1.to_q.lora_A.weight" in converted_state_dict + assert "transformer.transformer_blocks.0.norm2.linear_1.lora_B.weight" in converted_state_dict + assert "text_conditioner.blocks.0.self_attn.q_proj.lora_A.weight" in converted_state_dict + + @require_peft_backend + def test_load_lora_weights(self): + pipe = self.get_pipeline() + state_dict = { + "diffusion_model.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 32), + "diffusion_model.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(32, 2), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 16), + "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(16, 2), + } + + pipe.load_lora_weights(state_dict, adapter_name="dummy") + + assert "dummy" in pipe.transformer.peft_config + assert "dummy" in pipe.text_conditioner.peft_config diff --git a/tests/pipelines/anima/__init__.py b/tests/pipelines/anima/__init__.py deleted file mode 100644 
index b51b917e8d98..000000000000 --- a/tests/pipelines/anima/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Empty init for Anima pipeline tests. diff --git a/tests/pipelines/anima/test_anima.py b/tests/pipelines/anima/test_anima.py deleted file mode 100644 index fdf8211672e6..000000000000 --- a/tests/pipelines/anima/test_anima.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright 2026 The HuggingFace Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tempfile -import unittest - -import torch -from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model, T5TokenizerFast - -from diffusers import ( - AnimaPipeline, - AnimaTextConditioner, - AutoencoderKLQwenImage, - CosmosTransformer3DModel, - FlowMatchEulerDiscreteScheduler, -) - -from ...testing_utils import enable_full_determinism, require_peft_backend - - -enable_full_determinism() - - -class AnimaTextConditionerFastTests(unittest.TestCase): - def test_conditioner_output_shape_and_padding(self): - conditioner = AnimaTextConditioner( - source_dim=16, - target_dim=16, - model_dim=16, - num_layers=2, - num_attention_heads=4, - target_vocab_size=128, - min_sequence_length=8, - ) - source_hidden_states = torch.randn(2, 5, 16) - target_input_ids = torch.randint(0, 128, (2, 4)) - source_attention_mask = torch.ones(2, 5) - target_attention_mask = torch.ones(2, 4) - target_attention_mask[1, -1] = 0 - - output = conditioner( - source_hidden_states=source_hidden_states, - target_input_ids=target_input_ids, - source_attention_mask=source_attention_mask, - target_attention_mask=target_attention_mask, - ) - - self.assertEqual(output.shape, (2, 8, 16)) - self.assertTrue(torch.allclose(output[1, 3], torch.zeros_like(output[1, 3]), atol=1e-5)) - self.assertTrue(torch.allclose(output[:, 4:], torch.zeros_like(output[:, 4:]), atol=1e-5)) - - -class AnimaPipelineFastTests(unittest.TestCase): - pipeline_class = AnimaPipeline - - def get_dummy_components(self): - torch.manual_seed(0) - transformer = CosmosTransformer3DModel( - in_channels=4, - out_channels=4, - num_attention_heads=2, - attention_head_dim=16, - num_layers=2, - mlp_ratio=2, - text_embed_dim=16, - adaln_lora_dim=4, - max_size=(4, 32, 32), - patch_size=(1, 2, 2), - rope_scale=(1.0, 4.0, 4.0), - concat_padding_mask=True, - extra_pos_embed_type=None, - ) - - torch.manual_seed(0) - vae = AutoencoderKLQwenImage( - base_dim=24, - z_dim=4, - dim_mult=[1, 2, 4], - num_res_blocks=1, - temperal_downsample=[False, True], - latents_mean=[0.0] * 4, - latents_std=[1.0] * 4, - ) - - torch.manual_seed(0) - text_conditioner = AnimaTextConditioner( - source_dim=16, - target_dim=16, - model_dim=16, - num_layers=2, - num_attention_heads=4, - target_vocab_size=32128, - min_sequence_length=16, - ) - - torch.manual_seed(0) - text_encoder_config = Qwen3Config( - vocab_size=152064, - hidden_size=16, - intermediate_size=32, - num_hidden_layers=2, - num_attention_heads=4, - num_key_value_heads=2, - max_position_embeddings=128, - rms_norm_eps=1e-6, - rope_theta=1000000.0, - head_dim=4, - attention_bias=False, - ) - 
text_encoder = Qwen3Model(text_encoder_config).eval() - tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") - t5_tokenizer = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5") - scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) - - return { - "transformer": transformer, - "vae": vae, - "scheduler": scheduler, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "t5_tokenizer": t5_tokenizer, - "text_conditioner": text_conditioner, - } - - def get_dummy_inputs(self, seed=0): - generator = torch.Generator(device="cpu").manual_seed(seed) - return { - "prompt": "dance monkey", - "negative_prompt": "bad quality", - "generator": generator, - "num_inference_steps": 2, - "guidance_scale": 3.0, - "height": 32, - "width": 32, - "max_sequence_length": 16, - "output_type": "pt", - } - - def test_inference(self): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.set_progress_bar_config(disable=None) - - output = pipe(**self.get_dummy_inputs()).images - - self.assertEqual(output.shape, (1, 3, 32, 32)) - self.assertFalse(torch.isnan(output).any()) - - def test_inference_empty_negative_prompt(self): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs() - inputs["negative_prompt"] = "" - output = pipe(**inputs).images - - self.assertEqual(output.shape, (1, 3, 32, 32)) - self.assertFalse(torch.isnan(output).any()) - - def test_save_load_optional_components(self): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - - with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir, safe_serialization=True) - pipe = self.pipeline_class.from_pretrained(tmpdir) - - self.assertIsInstance(pipe.text_conditioner, AnimaTextConditioner) - self.assertIsInstance(pipe.transformer, CosmosTransformer3DModel) - - def test_lora_state_dict_conversion(self): - state_dict = { - "diffusion_model.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 32), - "diffusion_model.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(32, 2), - "diffusion_model.blocks.0.adaln_modulation_cross_attn.1.lora_A.weight": torch.randn(2, 32), - "diffusion_model.blocks.0.adaln_modulation_cross_attn.1.lora_B.weight": torch.randn(4, 2), - "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 16), - "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(16, 2), - } - - converted_state_dict = self.pipeline_class.lora_state_dict(state_dict) - - self.assertIn("transformer.transformer_blocks.0.attn1.to_q.lora_A.weight", converted_state_dict) - self.assertIn("transformer.transformer_blocks.0.norm2.linear_1.lora_B.weight", converted_state_dict) - self.assertIn("text_conditioner.blocks.0.self_attn.q_proj.lora_A.weight", converted_state_dict) - - @require_peft_backend - def test_load_lora_weights(self): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - state_dict = { - "diffusion_model.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 32), - "diffusion_model.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(32, 2), - "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_A.weight": torch.randn(2, 16), - "diffusion_model.llm_adapter.blocks.0.self_attn.q_proj.lora_B.weight": torch.randn(16, 2), - } - - pipe.load_lora_weights(state_dict, adapter_name="dummy") 
- - self.assertIn("dummy", pipe.transformer.peft_config) - self.assertIn("dummy", pipe.text_conditioner.peft_config) From 507f374f1692679ca1595dc4f8acd6971ec318c8 Mon Sep 17 00:00:00 2001 From: rmatif Date: Thu, 14 May 2026 10:04:17 +0200 Subject: [PATCH 6/6] Use modular loader in Anima docs --- docs/source/en/api/pipelines/anima.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/anima.md b/docs/source/en/api/pipelines/anima.md index 5d21235cb9d8..b7db72632776 100644 --- a/docs/source/en/api/pipelines/anima.md +++ b/docs/source/en/api/pipelines/anima.md @@ -4,9 +4,9 @@ Anima is a text-to-image model that reuses the [`CosmosTransformer3DModel`] with ```python import torch -from diffusers import AnimaAutoBlocks +from diffusers import ModularPipeline -pipe = AnimaAutoBlocks().init_pipeline("path/to/anima-diffusers") +pipe = ModularPipeline.from_pretrained("mrfatso/anima-preview3-diffusers") pipe.load_components(torch_dtype=torch.bfloat16) pipe.to("cuda")