From ab030901be7fb2055379e55b11a707f7e069a6ba Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 2 Jul 2026 13:54:41 +0000 Subject: [PATCH 1/2] Cosmos3 ModularPipeline initial commit --- docs/source/en/api/pipelines/cosmos3.md | 188 ++++++++ src/diffusers/__init__.py | 4 + src/diffusers/modular_pipelines/__init__.py | 5 + .../modular_pipelines/cosmos/__init__.py | 47 ++ .../cosmos/before_denoise.py | 311 +++++++++++++ .../modular_pipelines/cosmos/decoders.py | 100 +++++ .../modular_pipelines/cosmos/denoise.py | 259 +++++++++++ .../modular_pipelines/cosmos/encoders.py | 163 +++++++ .../cosmos/modular_blocks_cosmos3.py | 66 +++ .../cosmos/modular_pipeline.py | 134 ++++++ .../modular_pipelines/modular_pipeline.py | 1 + .../cosmos/test_cosmos3_modular_parity.py | 414 ++++++++++++++++++ 12 files changed, 1692 insertions(+) create mode 100644 src/diffusers/modular_pipelines/cosmos/__init__.py create mode 100644 src/diffusers/modular_pipelines/cosmos/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/cosmos/decoders.py create mode 100644 src/diffusers/modular_pipelines/cosmos/denoise.py create mode 100644 src/diffusers/modular_pipelines/cosmos/encoders.py create mode 100644 src/diffusers/modular_pipelines/cosmos/modular_blocks_cosmos3.py create mode 100644 src/diffusers/modular_pipelines/cosmos/modular_pipeline.py create mode 100644 tests/pipelines/cosmos/test_cosmos3_modular_parity.py diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index 1ac8f36457a4..845101f0cd2e 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -738,6 +738,194 @@ pipe = Cosmos3OmniPipeline.from_pretrained( - all - __call__ +## Cosmos3OmniModularPipeline + +Cosmos 3 is also available as a Modular Diffusers pipeline. 
The task-based [`Cosmos3OmniPipeline`] remains available; the modular pipeline coexists with it and covers the same modes (`text2image`, `text2video`, `image2video`, `video2video`, and action-conditioned generation, with optional sound when supported by the checkpoint). + +```python +import torch +from diffusers import Cosmos3OmniModularPipeline + +pipe = Cosmos3OmniModularPipeline.from_pretrained( + "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16 +) +pipe.load_components(torch_dtype=torch.bfloat16) + +result = pipe( + prompt='{"scene":"A robot arm in a kitchen"}', + num_frames=1, + height=720, + width=1280, +) + +# Same return payload as the task pipeline. +image = result.video[0] +``` + +You can also load through [`ModularPipeline`] and let the repository config select the blocks class: + +```python +import torch +from diffusers import ModularPipeline + +pipe = ModularPipeline.from_pretrained("nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16) +pipe.load_components(torch_dtype=torch.bfloat16) +result = pipe(prompt='{"scene":"A robot arm in a kitchen"}', num_frames=1, height=720, width=1280) +``` + +To inspect or customize a specific Cosmos modular workflow, use `available_workflows` and `get_workflow()`: + +```python +available = pipe.blocks.available_workflows +image2video_blocks = pipe.blocks.get_workflow("image2video") +``` + +### Modular examples for all existing workflows + +The modular pipeline supports the same call signatures as the task pipeline. The snippets below mirror every generation example shown above (`text2video`, `text2image`, `image2video`, `video2video`, `video2video_sound`, `text2video_sound`, and `action_policy`). 
+ +```python +import json +import torch +from diffusers import Cosmos3OmniModularPipeline, CosmosActionCondition +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from diffusers.utils import encode_video, export_to_video, load_image, load_video + +pipe = Cosmos3OmniModularPipeline.from_pretrained("nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16) +pipe.load_components(torch_dtype=torch.bfloat16) +pipe.to("cuda") +pipe.scheduler = UniPCMultistepScheduler.from_config( + pipe.scheduler.config, flow_shift=10.0, use_karras_sigmas=False +) + +# text2video +json_prompt = json.load(open("assets/example_t2v_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + num_frames=189, + height=720, + width=1280, + num_inference_steps=35, + guidance_scale=6.0, + fps=24.0, +) +export_to_video(result.video, "cosmos3_modular_t2v.mp4", fps=24, macro_block_size=1) + +# text2image +json_prompt = json.load(open("assets/example_t2i_prompt.json")) +result = pipe(prompt=json.dumps(json_prompt), num_frames=1, height=720, width=1280) +result.video[0].save("cosmos3_modular_t2i.jpg", format="JPEG", quality=85) + +# image2video +json_prompt = json.load(open("assets/example_i2v_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt_i2v.json")) +image = load_image("https://github.com/nvidia-cosmos/cosmos-dependencies/releases/download/assets/robot_153.jpg") +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + image=image, + num_frames=189, + height=720, + width=1280, + fps=24.0, +) +export_to_video(result.video, "cosmos3_modular_i2v.mp4", fps=24, macro_block_size=1) + +# video2video +json_prompt = json.load(open("assets/example_v2v_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt_i2v.json")) +video = load_video( + 
"https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/vision/robot_pouring.mp4" +) +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + video=video, + condition_frame_indexes_vision=[0, 1], + condition_video_keep="first", + num_frames=189, + height=720, + width=1280, + num_inference_steps=35, + guidance_scale=6.0, + fps=24.0, +) +export_to_video(result.video, "cosmos3_modular_v2v.mp4", fps=24, macro_block_size=1) + +# video2video_sound +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + video=video, + condition_frame_indexes_vision=[0, 1], + condition_video_keep="first", + num_frames=189, + height=720, + width=1280, + fps=24.0, + enable_sound=True, +) +encode_video( + result.video, + fps=24, + audio=result.sound, + audio_sample_rate=pipe.sound_tokenizer.config.sampling_rate, + output_path="cosmos3_modular_v2v_with_sound.mp4", +) + +# text2video_sound +json_prompt = json.load(open("assets/example_t2v_sound_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + num_frames=189, + height=720, + width=1280, + fps=24.0, + enable_sound=True, +) +encode_video( + result.video, + fps=24, + audio=result.sound, + audio_sample_rate=pipe.sound_tokenizer.config.sampling_rate, + output_path="cosmos3_modular_t2v_with_sound.mp4", +) + +# action_policy +prompt = "Put the pot to the left of the purple item." 
+action_video = load_video( + "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_20260501_0.mp4" +) +result = pipe( + prompt=prompt, + action=CosmosActionCondition( + mode="policy", + chunk_size=16, + domain_name="bridge_orig_lerobot", + resolution_tier=480, + video=action_video, + view_point="ego_view", + ), + fps=5, + num_inference_steps=30, + guidance_scale=1.0, + use_system_prompt=False, +) +export_to_video(result.video, "cosmos3_modular_action_policy.mp4", fps=5, macro_block_size=1) +if result.action is not None: + with open("cosmos3_modular_action_policy.json", "w") as f: + json.dump(result.action[0].tolist(), f) +``` + +[[autodoc]] Cosmos3OmniModularPipeline + +- all +- __call__ + ## CosmosActionCondition [[autodoc]] CosmosActionCondition diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 81b36e113df4..9e9edca987e4 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -476,6 +476,8 @@ [ "AnimaAutoBlocks", "AnimaModularPipeline", + "Cosmos3OmniBlocks", + "Cosmos3OmniModularPipeline", "ErnieImageAutoBlocks", "ErnieImageModularPipeline", "Flux2AutoBlocks", @@ -1323,6 +1325,8 @@ from .modular_pipelines import ( AnimaAutoBlocks, AnimaModularPipeline, + Cosmos3OmniBlocks, + Cosmos3OmniModularPipeline, ErnieImageAutoBlocks, ErnieImageModularPipeline, Flux2AutoBlocks, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 4b36994aef07..25db2ef3bee2 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -97,6 +97,10 @@ "AnimaAutoBlocks", "AnimaModularPipeline", ] + _import_structure["cosmos"] = [ + "Cosmos3OmniBlocks", + "Cosmos3OmniModularPipeline", + ] _import_structure["ernie_image"] = [ "ErnieImageAutoBlocks", "ErnieImageModularPipeline", @@ -124,6 +128,7 @@ else: from .anima import AnimaAutoBlocks, AnimaModularPipeline from .components_manager import 
ComponentsManager + from .cosmos import Cosmos3OmniBlocks, Cosmos3OmniModularPipeline from .ernie_image import ErnieImageAutoBlocks, ErnieImageModularPipeline from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline from .flux2 import ( diff --git a/src/diffusers/modular_pipelines/cosmos/__init__.py b/src/diffusers/modular_pipelines/cosmos/__init__.py new file mode 100644 index 000000000000..20150b299893 --- /dev/null +++ b/src/diffusers/modular_pipelines/cosmos/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_cosmos3"] = ["Cosmos3OmniBlocks"] + _import_structure["modular_pipeline"] = ["Cosmos3OmniModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_cosmos3 import Cosmos3OmniBlocks + from .modular_pipeline import Cosmos3OmniModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/cosmos/before_denoise.py 
b/src/diffusers/modular_pipelines/cosmos/before_denoise.py new file mode 100644 index 000000000000..d8c6bfd969d9 --- /dev/null +++ b/src/diffusers/modular_pipelines/cosmos/before_denoise.py @@ -0,0 +1,311 @@ +import copy +from typing import Any + +import torch + +from ...models.autoencoders.autoencoder_cosmos3_audio import Cosmos3AVAEAudioTokenizer +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...models.transformers.transformer_cosmos3 import Cosmos3OmniTransformer +from ...schedulers import UniPCMultistepScheduler +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Cosmos3OmniModularPipeline + + +logger = logging.get_logger(__name__) + + +class Cosmos3PrepareLatentsStep(ModularPipelineBlocks): + model_name = "cosmos3-omni" + + @property + def description(self) -> str: + return "Prepares vision/sound/action latents and conditioning masks." 
+ + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", Cosmos3OmniTransformer), + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec("sound_tokenizer", Cosmos3AVAEAudioTokenizer), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="image", default=None), + InputParam(name="video", default=None), + InputParam(name="condition_frame_indexes_vision", default=(0, 1)), + InputParam(name="condition_video_keep", default="first"), + InputParam(name="num_frames", required=True), + InputParam(name="height", required=True), + InputParam(name="width", required=True), + InputParam(name="fps", type_hint=float, default=24.0), + InputParam(name="latents", default=None), + InputParam(name="sound_latents", default=None), + InputParam(name="action_latents", default=None), + InputParam(name="generator", default=None), + InputParam(name="enable_sound", type_hint=bool, default=False), + InputParam(name="action", default=None), + InputParam(name="device", required=True), + InputParam(name="dtype", required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("latents"), + OutputParam("sound_latents"), + OutputParam("action_latents"), + OutputParam("fps_vision"), + OutputParam("fps_sound"), + OutputParam("vision_condition_mask"), + OutputParam("sound_condition_mask"), + OutputParam("action_condition_mask"), + OutputParam("action_domain_id"), + OutputParam("action_image_size"), + OutputParam("raw_action_dim_resolved"), + OutputParam("action_condition_frame_indexes"), + OutputParam("vision_condition_indexes_for_pack"), + OutputParam("has_image_condition"), + ] + + @torch.no_grad() + def __call__(self, components: Cosmos3OmniModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + ( + block_state.latents, + block_state.sound_latents, + block_state.action_latents, + block_state.fps_vision, + 
block_state.fps_sound, + block_state.vision_condition_mask, + block_state.sound_condition_mask, + block_state.action_condition_mask, + block_state.action_domain_id, + block_state.action_image_size, + block_state.raw_action_dim_resolved, + block_state.action_condition_frame_indexes, + ) = components.prepare_latents( + image=block_state.image, + video=block_state.video, + condition_frame_indexes_vision=block_state.condition_frame_indexes_vision, + condition_video_keep=block_state.condition_video_keep, + num_frames=block_state.num_frames, + height=block_state.height, + width=block_state.width, + fps=block_state.fps, + latents=block_state.latents, + sound_latents=block_state.sound_latents, + action_latents=block_state.action_latents, + generator=block_state.generator, + device=block_state.device, + dtype=block_state.dtype, + enable_sound=block_state.enable_sound, + action=block_state.action, + ) + + vision_condition_indexes = torch.nonzero( + block_state.vision_condition_mask[:, 0, 0] > 0, as_tuple=False + ).flatten() + block_state.vision_condition_indexes_for_pack = [int(idx.item()) for idx in vision_condition_indexes] + block_state.has_image_condition = bool(block_state.vision_condition_indexes_for_pack) + + self.set_block_state(state, block_state) + return components, state + + +class Cosmos3PackSequenceStep(ModularPipelineBlocks): + model_name = "cosmos3-omni" + + @property + def description(self) -> str: + return "Builds static packed cond/uncond sequence metadata before denoising." 
+ + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="cond_text_segment", required=True), + InputParam(name="uncond_text_segment", required=True), + InputParam(name="latents", required=True), + InputParam(name="sound_latents", default=None), + InputParam(name="action_latents", default=None), + InputParam(name="fps_vision", required=True), + InputParam(name="fps_sound", default=None), + InputParam(name="has_image_condition", required=True), + InputParam(name="vision_condition_indexes_for_pack", required=True), + InputParam(name="action_condition_frame_indexes", default=None), + InputParam(name="device", required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("cond_packed_static"), + OutputParam("uncond_packed_static"), + OutputParam("num_noisy_vision_tokens"), + OutputParam("sound_len"), + OutputParam("action_noisy_len"), + ] + + @torch.no_grad() + def __call__(self, components: Cosmos3OmniModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + cond_vision_segment = components._prepare_vision_segment( + input_vision_tokens=block_state.latents, + has_image_condition=block_state.has_image_condition, + mrope_offset=block_state.cond_text_segment["vision_start_temporal_offset"], + vision_fps=block_state.fps_vision, + curr=block_state.cond_text_segment["und_len"], + device=block_state.device, + condition_frame_indexes=block_state.vision_condition_indexes_for_pack, + ) + cond_sound_segment: dict[str, Any] = {} + if block_state.sound_latents is not None: + cond_sound_segment = components._prepare_sound_segment( + input_sound_tokens=block_state.sound_latents, + mrope_offset=block_state.cond_text_segment["vision_start_temporal_offset"], + sound_fps=block_state.fps_sound, + curr=block_state.cond_text_segment["und_len"] + cond_vision_segment["num_vision_tokens"], + device=block_state.device, + ) + cond_action_segment: dict[str, Any] = {} + if 
block_state.action_latents is not None: + cond_action_segment = components._prepare_action_segment( + input_action_tokens=block_state.action_latents, + condition_frame_indexes=block_state.action_condition_frame_indexes, + mrope_offset=block_state.cond_text_segment["vision_start_temporal_offset"], + action_fps=block_state.fps_vision, + curr=block_state.cond_text_segment["und_len"] + + cond_vision_segment["num_vision_tokens"] + + cond_sound_segment.get("sound_len", 0), + device=block_state.device, + ) + cond_mrope_segments = [ + block_state.cond_text_segment["text_mrope_ids"], + cond_vision_segment["vision_mrope_ids"], + ] + if cond_sound_segment: + cond_mrope_segments.append(cond_sound_segment["sound_mrope_ids"]) + if cond_action_segment: + cond_mrope_segments.append(cond_action_segment["action_mrope_ids"]) + block_state.cond_packed_static = { + **block_state.cond_text_segment, + **cond_vision_segment, + **cond_sound_segment, + **cond_action_segment, + "position_ids": torch.cat(cond_mrope_segments, dim=1), + "sequence_length": block_state.cond_text_segment["und_len"] + + cond_vision_segment["num_vision_tokens"] + + cond_sound_segment.get("sound_len", 0) + + cond_action_segment.get("action_len", 0), + } + + uncond_vision_segment = components._prepare_vision_segment( + input_vision_tokens=block_state.latents, + has_image_condition=block_state.has_image_condition, + mrope_offset=block_state.uncond_text_segment["vision_start_temporal_offset"], + vision_fps=block_state.fps_vision, + curr=block_state.uncond_text_segment["und_len"], + device=block_state.device, + condition_frame_indexes=block_state.vision_condition_indexes_for_pack, + ) + uncond_sound_segment: dict[str, Any] = {} + if block_state.sound_latents is not None: + uncond_sound_segment = components._prepare_sound_segment( + input_sound_tokens=block_state.sound_latents, + mrope_offset=block_state.uncond_text_segment["vision_start_temporal_offset"], + sound_fps=block_state.fps_sound, + 
curr=block_state.uncond_text_segment["und_len"] + uncond_vision_segment["num_vision_tokens"], + device=block_state.device, + ) + uncond_action_segment: dict[str, Any] = {} + if block_state.action_latents is not None: + uncond_action_segment = components._prepare_action_segment( + input_action_tokens=block_state.action_latents, + condition_frame_indexes=block_state.action_condition_frame_indexes, + mrope_offset=block_state.uncond_text_segment["vision_start_temporal_offset"], + action_fps=block_state.fps_vision, + curr=block_state.uncond_text_segment["und_len"] + + uncond_vision_segment["num_vision_tokens"] + + uncond_sound_segment.get("sound_len", 0), + device=block_state.device, + ) + uncond_mrope_segments = [ + block_state.uncond_text_segment["text_mrope_ids"], + uncond_vision_segment["vision_mrope_ids"], + ] + if uncond_sound_segment: + uncond_mrope_segments.append(uncond_sound_segment["sound_mrope_ids"]) + if uncond_action_segment: + uncond_mrope_segments.append(uncond_action_segment["action_mrope_ids"]) + block_state.uncond_packed_static = { + **block_state.uncond_text_segment, + **uncond_vision_segment, + **uncond_sound_segment, + **uncond_action_segment, + "position_ids": torch.cat(uncond_mrope_segments, dim=1), + "sequence_length": block_state.uncond_text_segment["und_len"] + + uncond_vision_segment["num_vision_tokens"] + + uncond_sound_segment.get("sound_len", 0) + + uncond_action_segment.get("action_len", 0), + } + + block_state.num_noisy_vision_tokens = cond_vision_segment["num_noisy_vision_tokens"] + block_state.sound_len = cond_sound_segment.get("sound_len") + block_state.action_noisy_len = cond_action_segment.get("num_noisy_action_tokens") + + self.set_block_state(state, block_state) + return components, state + + +class Cosmos3SetTimestepsStep(ModularPipelineBlocks): + model_name = "cosmos3-omni" + + @property + def description(self) -> str: + return "Initializes scheduler timesteps and modality schedulers." 
+ + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", UniPCMultistepScheduler)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps", required=True), + InputParam(name="device", required=True), + InputParam(name="sound_latents", default=None), + InputParam(name="action_latents", default=None), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps"), + OutputParam("sound_scheduler"), + OutputParam("action_scheduler"), + OutputParam("num_warmup_steps"), + ] + + @torch.no_grad() + def __call__(self, components: Cosmos3OmniModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + components.scheduler.set_timesteps(block_state.num_inference_steps, device=block_state.device) + + block_state.timesteps = components.scheduler.timesteps + block_state.sound_scheduler = ( + copy.deepcopy(components.scheduler) if block_state.sound_latents is not None else None + ) + block_state.action_scheduler = ( + copy.deepcopy(components.scheduler) if block_state.action_latents is not None else None + ) + block_state.num_warmup_steps = ( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/cosmos/decoders.py b/src/diffusers/modular_pipelines/cosmos/decoders.py new file mode 100644 index 000000000000..b6993012752f --- /dev/null +++ b/src/diffusers/modular_pipelines/cosmos/decoders.py @@ -0,0 +1,100 @@ +import torch + +from ...models.autoencoders.autoencoder_cosmos3_audio import Cosmos3AVAEAudioTokenizer +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...pipelines.cosmos.pipeline_cosmos3_omni import Cosmos3OmniPipelineOutput, CosmosSafetyChecker +from ...utils import logging +from ..modular_pipeline 
import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Cosmos3OmniModularPipeline + + +logger = logging.get_logger(__name__) + + +class Cosmos3DecodeStep(ModularPipelineBlocks): + model_name = "cosmos3-omni" + + @property + def description(self) -> str: + return "Decodes denoised latents into video/sound/action outputs." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec("sound_tokenizer", Cosmos3AVAEAudioTokenizer), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="latents", required=True), + InputParam(name="sound_latents", default=None), + InputParam(name="action_latents", default=None), + InputParam(name="action_mode", default=None), + InputParam(name="raw_action_dim_resolved", default=None), + InputParam.template("output_type", default="pil"), + InputParam(name="enable_safety_check", default=True), + InputParam(name="device", required=True), + InputParam(name="return_dict", default=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("videos"), + OutputParam("sound"), + OutputParam("action"), + OutputParam("result"), + ] + + @torch.no_grad() + def __call__(self, components: Cosmos3OmniModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + sound = components.decode_sound(block_state.sound_latents) if block_state.sound_latents is not None else None + action_output = None + if block_state.action_mode in {"inverse_dynamics", "policy"} and block_state.action_latents is not None: + action_output = block_state.action_latents + if block_state.raw_action_dim_resolved is not None: + action_output = action_output[:, : block_state.raw_action_dim_resolved] + action_output = [action_output.detach().cpu()] + + if block_state.output_type == "latent": + 
video = block_state.latents + else: + in_dtype = block_state.latents.dtype + vae_dtype = components.vae.dtype + mean = components._vae_latents_mean.to(device=block_state.latents.device, dtype=vae_dtype) + inv_std = components._vae_latents_inv_std.to(device=block_state.latents.device, dtype=vae_dtype) + z_raw = block_state.latents.to(vae_dtype) / inv_std.view(1, -1, 1, 1, 1) + mean.view(1, -1, 1, 1, 1) + decoded = components.vae.decode(z_raw).sample.to(in_dtype) + video = components.video_processor.postprocess_video(decoded, output_type=block_state.output_type)[0] + + if ( + block_state.enable_safety_check + and isinstance(components.safety_checker, CosmosSafetyChecker) + and block_state.output_type != "latent" + ): + video = components._apply_video_safety_check( + video, output_type=block_state.output_type, device=block_state.device + ) + + components.maybe_free_model_hooks() + + if not block_state.return_dict: + if block_state.action_mode is not None: + result = (video, sound, action_output) + else: + result = (video, sound) + else: + result = Cosmos3OmniPipelineOutput(video=video, sound=sound, action=action_output) + + block_state.videos = video + block_state.sound = sound + block_state.action = action_output + block_state.result = result + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/cosmos/denoise.py b/src/diffusers/modular_pipelines/cosmos/denoise.py new file mode 100644 index 000000000000..9da3a7e1b9b4 --- /dev/null +++ b/src/diffusers/modular_pipelines/cosmos/denoise.py @@ -0,0 +1,259 @@ +import torch + +from ...models.transformers.transformer_cosmos3 import Cosmos3OmniTransformer +from ...schedulers import UniPCMultistepScheduler +from ...utils import logging +from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam +from .modular_pipeline import Cosmos3OmniModularPipeline 
+ + +logger = logging.get_logger(__name__) + + +class Cosmos3LoopStep(ModularPipelineBlocks): + model_name = "cosmos3-omni" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ComponentSpec("transformer", Cosmos3OmniTransformer), + ] + + @property + def description(self) -> str: + return "Runs one Cosmos3 denoising iteration with optional sound/action streams." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="latents", required=True), + InputParam(name="sound_latents", default=None), + InputParam(name="action_latents", default=None), + InputParam(name="num_noisy_vision_tokens", required=True), + InputParam(name="sound_len", default=None), + InputParam(name="action_noisy_len", default=None), + InputParam(name="cond_packed_static", required=True), + InputParam(name="uncond_packed_static", required=True), + InputParam(name="vision_condition_mask", required=True), + InputParam(name="sound_condition_mask", default=None), + InputParam(name="action_condition_mask", default=None), + InputParam(name="action_domain_id", default=None), + InputParam(name="raw_action_dim_resolved", default=None), + InputParam(name="sound_scheduler", default=None), + InputParam(name="action_scheduler", default=None), + InputParam(name="guidance_scale", default=6.0), + InputParam(name="device", required=True), + InputParam(name="dtype", required=True), + ] + + @torch.no_grad() + def __call__(self, components: Cosmos3OmniModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + components._current_timestep = t + timestep = t.item() + + vision_tokens = block_state.latents.to(device=block_state.device, dtype=block_state.dtype) + sound_tokens = ( + block_state.sound_latents.to(device=block_state.device, dtype=block_state.dtype) + if block_state.sound_latents is not None + else None + ) + action_tokens = ( + block_state.action_latents.to(device=block_state.device, 
dtype=block_state.dtype) + if block_state.action_latents is not None + else None + ) + vision_timesteps = torch.full((block_state.num_noisy_vision_tokens,), timestep, device=block_state.device) + sound_timesteps = ( + torch.full((block_state.sound_len,), timestep, device=block_state.device) + if sound_tokens is not None + else None + ) + action_timesteps = ( + torch.full((block_state.action_noisy_len,), timestep, device=block_state.device) + if action_tokens is not None + else None + ) + + preds_vision, preds_sound, preds_action = components.transformer( + input_ids=block_state.cond_packed_static["input_ids"], + text_indexes=block_state.cond_packed_static["text_indexes"], + position_ids=block_state.cond_packed_static["position_ids"], + und_len=block_state.cond_packed_static["und_len"], + sequence_length=block_state.cond_packed_static["sequence_length"], + vision_tokens=[vision_tokens], + vision_token_shapes=block_state.cond_packed_static["vision_token_shapes"], + vision_sequence_indexes=block_state.cond_packed_static["vision_sequence_indexes"], + vision_mse_loss_indexes=block_state.cond_packed_static["vision_mse_loss_indexes"], + vision_timesteps=vision_timesteps, + vision_noisy_frame_indexes=block_state.cond_packed_static["vision_noisy_frame_indexes"], + sound_tokens=[sound_tokens] if sound_tokens is not None else None, + sound_token_shapes=block_state.cond_packed_static.get("sound_token_shapes"), + sound_sequence_indexes=block_state.cond_packed_static.get("sound_sequence_indexes"), + sound_mse_loss_indexes=block_state.cond_packed_static.get("sound_mse_loss_indexes"), + sound_timesteps=sound_timesteps, + sound_noisy_frame_indexes=block_state.cond_packed_static.get("sound_noisy_frame_indexes"), + action_tokens=[action_tokens] if action_tokens is not None else None, + action_token_shapes=block_state.cond_packed_static.get("action_token_shapes"), + action_sequence_indexes=block_state.cond_packed_static.get("action_sequence_indexes"), + 
action_mse_loss_indexes=block_state.cond_packed_static.get("action_mse_loss_indexes"), + action_timesteps=action_timesteps, + action_noisy_frame_indexes=block_state.cond_packed_static.get("action_noisy_frame_indexes"), + action_domain_ids=[block_state.action_domain_id] if block_state.action_domain_id is not None else None, + ) + cond_v_vision, cond_v_sound, cond_v_action = components._mask_velocity_predictions( + preds_vision, + preds_sound, + vision_condition_mask=[block_state.vision_condition_mask], + sound_condition_mask=[block_state.sound_condition_mask] + if block_state.sound_condition_mask is not None + else None, + preds_action=preds_action, + action_condition_mask=[block_state.action_condition_mask] + if block_state.action_condition_mask is not None + else None, + raw_action_dim=block_state.raw_action_dim_resolved, + ) + + uncond_v_vision = uncond_v_sound = uncond_v_action = None + if components.do_classifier_free_guidance: + preds_vision, preds_sound, preds_action = components.transformer( + input_ids=block_state.uncond_packed_static["input_ids"], + text_indexes=block_state.uncond_packed_static["text_indexes"], + position_ids=block_state.uncond_packed_static["position_ids"], + und_len=block_state.uncond_packed_static["und_len"], + sequence_length=block_state.uncond_packed_static["sequence_length"], + vision_tokens=[vision_tokens], + vision_token_shapes=block_state.uncond_packed_static["vision_token_shapes"], + vision_sequence_indexes=block_state.uncond_packed_static["vision_sequence_indexes"], + vision_mse_loss_indexes=block_state.uncond_packed_static["vision_mse_loss_indexes"], + vision_timesteps=vision_timesteps, + vision_noisy_frame_indexes=block_state.uncond_packed_static["vision_noisy_frame_indexes"], + sound_tokens=[sound_tokens] if sound_tokens is not None else None, + sound_token_shapes=block_state.uncond_packed_static.get("sound_token_shapes"), + sound_sequence_indexes=block_state.uncond_packed_static.get("sound_sequence_indexes"), + 
sound_mse_loss_indexes=block_state.uncond_packed_static.get("sound_mse_loss_indexes"), + sound_timesteps=sound_timesteps, + sound_noisy_frame_indexes=block_state.uncond_packed_static.get("sound_noisy_frame_indexes"), + action_tokens=[action_tokens] if action_tokens is not None else None, + action_token_shapes=block_state.uncond_packed_static.get("action_token_shapes"), + action_sequence_indexes=block_state.uncond_packed_static.get("action_sequence_indexes"), + action_mse_loss_indexes=block_state.uncond_packed_static.get("action_mse_loss_indexes"), + action_timesteps=action_timesteps, + action_noisy_frame_indexes=block_state.uncond_packed_static.get("action_noisy_frame_indexes"), + action_domain_ids=[block_state.action_domain_id] if block_state.action_domain_id is not None else None, + ) + uncond_v_vision, uncond_v_sound, uncond_v_action = components._mask_velocity_predictions( + preds_vision, + preds_sound, + vision_condition_mask=[block_state.vision_condition_mask], + sound_condition_mask=[block_state.sound_condition_mask] + if block_state.sound_condition_mask is not None + else None, + preds_action=preds_action, + action_condition_mask=[block_state.action_condition_mask] + if block_state.action_condition_mask is not None + else None, + raw_action_dim=block_state.raw_action_dim_resolved, + ) + + if components.do_classifier_free_guidance: + velocity_vision = uncond_v_vision + block_state.guidance_scale * (cond_v_vision - uncond_v_vision) + else: + velocity_vision = cond_v_vision + + block_state.latents = components.scheduler.step( + velocity_vision.unsqueeze(0), t, block_state.latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + + if block_state.sound_scheduler is not None and cond_v_sound is not None: + if components.do_classifier_free_guidance: + velocity_sound = uncond_v_sound + block_state.guidance_scale * (cond_v_sound - uncond_v_sound) + else: + velocity_sound = cond_v_sound + block_state.sound_latents = block_state.sound_scheduler.step( + 
velocity_sound.unsqueeze(0), t, block_state.sound_latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + + has_noisy_action = ( + block_state.action_condition_mask is not None + and block_state.action_condition_mask.sum() < block_state.action_condition_mask.numel() + ) + if block_state.action_scheduler is not None and has_noisy_action and cond_v_action is not None: + if components.do_classifier_free_guidance: + velocity_action = uncond_v_action + block_state.guidance_scale * (cond_v_action - uncond_v_action) + else: + velocity_action = cond_v_action + block_state.action_latents = block_state.action_scheduler.step( + velocity_action.unsqueeze(0), t, block_state.action_latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + if block_state.raw_action_dim_resolved is not None: + block_state.action_latents[:, block_state.raw_action_dim_resolved :] = 0 + + return components, block_state + + +class Cosmos3DenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "cosmos3-omni" + + @property + def description(self) -> str: + return "Iteratively denoises Cosmos3 latents over scheduler timesteps." 
+ + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ComponentSpec("transformer", Cosmos3OmniTransformer), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam.template("timesteps", required=True), + InputParam.template("num_inference_steps", required=True), + InputParam(name="num_warmup_steps", required=True), + InputParam(name="callback_on_step_end", default=None), + InputParam(name="callback_on_step_end_tensor_inputs", default=["latents"]), + ] + + @torch.no_grad() + def __call__(self, components: Cosmos3OmniModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + components._num_timesteps = len(block_state.timesteps) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + if components.interrupt: + continue + + components, block_state = self.loop_step(components, block_state, i=i, t=t) + + if block_state.callback_on_step_end is not None: + callback_kwargs = { + k: getattr(block_state, k) for k in block_state.callback_on_step_end_tensor_inputs + } + callback_outputs = block_state.callback_on_step_end(components, i, t, callback_kwargs) + if callback_outputs is not None and isinstance(callback_outputs, dict): + block_state.latents = callback_outputs.pop("latents", block_state.latents) + block_state.sound_latents = callback_outputs.pop("sound_latents", block_state.sound_latents) + block_state.action_latents = callback_outputs.pop("action_latents", block_state.action_latents) + + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + components._current_timestep = None + self.set_block_state(state, block_state) + return components, state + + +class Cosmos3DenoiseStep(Cosmos3DenoiseLoopWrapper): + block_classes = 
[Cosmos3LoopStep()] + block_names = ["denoise_step"] + + @property + def description(self) -> str: + return "Cosmos3 denoising loop for generation modes." diff --git a/src/diffusers/modular_pipelines/cosmos/encoders.py b/src/diffusers/modular_pipelines/cosmos/encoders.py new file mode 100644 index 000000000000..58d0e58bb6c2 --- /dev/null +++ b/src/diffusers/modular_pipelines/cosmos/encoders.py @@ -0,0 +1,163 @@ +import torch +from transformers import AutoTokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models.autoencoders.autoencoder_cosmos3_audio import Cosmos3AVAEAudioTokenizer +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...models.transformers.transformer_cosmos3 import Cosmos3OmniTransformer +from ...utils import logging +from ...pipelines.cosmos.pipeline_cosmos3_omni import ( + _ACTION_RESOLUTION_BINS, + CosmosActionCondition, + CosmosSafetyChecker, +) +from ...video_processor import VideoProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Cosmos3OmniModularPipeline + + +logger = logging.get_logger(__name__) + + +class Cosmos3TextEncoderStep(ModularPipelineBlocks): + model_name = "cosmos3-omni" + + @property + def description(self) -> str: + return "Validates inputs, tokenizes prompts, and packs text conditioning." 
+ + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", Cosmos3OmniTransformer), + ComponentSpec("text_tokenizer", AutoTokenizer), + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec("sound_tokenizer", Cosmos3AVAEAudioTokenizer), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="prompt", type_hint=str, required=True), + InputParam(name="negative_prompt", default=None), + InputParam(name="image", default=None), + InputParam(name="video", default=None), + InputParam(name="condition_frame_indexes_vision", default=(0, 1)), + InputParam(name="condition_video_keep", default="first"), + InputParam(name="num_frames", default=None), + InputParam(name="height", default=None), + InputParam(name="width", default=None), + InputParam(name="fps", type_hint=float, default=24.0), + InputParam(name="num_inference_steps", type_hint=int, default=35), + InputParam(name="guidance_scale", type_hint=float, default=6.0), + InputParam(name="enable_sound", type_hint=bool, default=False), + InputParam(name="action", type_hint=CosmosActionCondition, default=None), + InputParam(name="use_system_prompt", type_hint=bool, default=True), + InputParam(name="callback_on_step_end", default=None), + InputParam(name="callback_on_step_end_tensor_inputs", default=["latents"]), + InputParam(name="add_resolution_template", type_hint=bool, default=True), + InputParam(name="add_duration_template", type_hint=bool, default=True), + InputParam(name="enable_safety_check", type_hint=bool, default=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("action_mode"), + OutputParam("device"), + OutputParam("dtype"), + OutputParam("cond_text_segment"), + OutputParam("uncond_text_segment"), + ] + + @torch.no_grad() + def __call__(self, components: Cosmos3OmniModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if 
isinstance(block_state.callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + block_state.callback_on_step_end_tensor_inputs = block_state.callback_on_step_end.tensor_inputs + + if block_state.action is None: + if block_state.num_frames is None: + block_state.num_frames = 189 + if block_state.height is None: + block_state.height = 720 + if block_state.width is None: + block_state.width = 1280 + + components.check_inputs( + block_state.prompt, + block_state.negative_prompt, + block_state.image, + block_state.height, + block_state.width, + block_state.num_frames, + block_state.guidance_scale, + block_state.enable_sound, + block_state.callback_on_step_end_tensor_inputs, + block_state.action, + video=block_state.video, + condition_frame_indexes_vision=block_state.condition_frame_indexes_vision, + ) + + block_state.action_mode = block_state.action.mode if block_state.action is not None else None + if block_state.action is not None: + block_state.num_frames = block_state.action.chunk_size + 1 + conditioning_clip = ( + [block_state.action.image] if block_state.action.image is not None else block_state.action.video + ) + probe = components.video_processor.preprocess_video(conditioning_clip) + source_h, source_w = int(probe.shape[-2]), int(probe.shape[-1]) + resolution_key = str(block_state.action.resolution_tier) + block_state.height, block_state.width = VideoProcessor.classify_height_width_bin( + source_h, source_w, ratios=_ACTION_RESOLUTION_BINS[resolution_key] + ) + + components._current_timestep = None + components._interrupt = False + components._guidance_scale = block_state.guidance_scale + + if isinstance(block_state.prompt, list): + block_state.prompt = block_state.prompt[0] + if isinstance(block_state.negative_prompt, list): + block_state.negative_prompt = block_state.negative_prompt[0] + + block_state.device = components._get_execution_device() + block_state.dtype = components.transformer.dtype + + if block_state.enable_safety_check and 
getattr(components, "safety_checker", None) is None: + try: + components._ensure_safety_checker() + except ImportError: + pass + + if block_state.enable_safety_check and isinstance(components.safety_checker, CosmosSafetyChecker): + components.safety_checker.to(block_state.device) + try: + if not components.safety_checker.check_text_safety(block_state.prompt): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {block_state.prompt}. " + "Please ensure that the prompt abides by the NVIDIA Open Model License Agreement." + ) + finally: + components.safety_checker.to("cpu") + + cond_input_ids, uncond_input_ids = components.tokenize_prompt( + block_state.prompt, + block_state.negative_prompt, + num_frames=block_state.num_frames, + height=block_state.height, + width=block_state.width, + fps=block_state.fps, + use_system_prompt=block_state.use_system_prompt, + add_resolution_template=block_state.add_resolution_template, + add_duration_template=block_state.add_duration_template, + action_mode=block_state.action_mode, + action_view_point=block_state.action.view_point if block_state.action is not None else None, + ) + block_state.cond_text_segment = components._prepare_text_segment(cond_input_ids, device=block_state.device) + block_state.uncond_text_segment = components._prepare_text_segment(uncond_input_ids, device=block_state.device) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/cosmos/modular_blocks_cosmos3.py b/src/diffusers/modular_pipelines/cosmos/modular_blocks_cosmos3.py new file mode 100644 index 000000000000..303c0fda703d --- /dev/null +++ b/src/diffusers/modular_pipelines/cosmos/modular_blocks_cosmos3.py @@ -0,0 +1,66 @@ +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import Cosmos3PackSequenceStep, Cosmos3PrepareLatentsStep, Cosmos3SetTimestepsStep +from 
.decoders import Cosmos3DecodeStep +from .denoise import Cosmos3DenoiseStep +from .encoders import Cosmos3TextEncoderStep + + +logger = logging.get_logger(__name__) + + +# auto_docstring +class Cosmos3CoreDenoiseStep(SequentialPipelineBlocks): + model_name = "cosmos3-omni" + block_classes = [ + Cosmos3PrepareLatentsStep, + Cosmos3PackSequenceStep, + Cosmos3SetTimestepsStep, + Cosmos3DenoiseStep, + ] + block_names = ["prepare_latents", "pack_sequence", "set_timesteps", "denoise"] + + @property + def description(self): + return "Prepares modalities, packs sequences, initializes timesteps, and denoises." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + OutputParam("sound_latents"), + OutputParam("action_latents"), + ] + + +# auto_docstring +class Cosmos3OmniBlocks(SequentialPipelineBlocks): + model_name = "cosmos3-omni" + block_classes = [Cosmos3TextEncoderStep, Cosmos3CoreDenoiseStep, Cosmos3DecodeStep] + block_names = ["text_encoder", "denoise", "decode"] + _workflow_map = { + "text2image": {"prompt": True, "num_frames": 1}, + "text2video": {"prompt": True}, + "image2video": {"prompt": True, "image": True}, + "video2video": {"prompt": True, "video": True}, + "text2video_with_sound": {"prompt": True, "enable_sound": True}, + "image2video_with_sound": {"prompt": True, "image": True, "enable_sound": True}, + "video2video_with_sound": {"prompt": True, "video": True, "enable_sound": True}, + "action_policy": {"prompt": True, "action": True}, + "action_forward_dynamics": {"prompt": True, "action": True}, + "action_inverse_dynamics": {"prompt": True, "action": True}, + } + + @property + def description(self): + return "Modular pipeline blocks for Cosmos3 generation modes." 
+ + @property + def outputs(self): + return [ + OutputParam("result"), + OutputParam.template("videos"), + OutputParam("sound"), + OutputParam("action"), + ] diff --git a/src/diffusers/modular_pipelines/cosmos/modular_pipeline.py b/src/diffusers/modular_pipelines/cosmos/modular_pipeline.py new file mode 100644 index 000000000000..81200bd96457 --- /dev/null +++ b/src/diffusers/modular_pipelines/cosmos/modular_pipeline.py @@ -0,0 +1,134 @@ +import torch + +from ...pipelines.cosmos.pipeline_cosmos3_omni import Cosmos3OmniPipeline, CosmosSafetyChecker +from ...utils import logging +from ...video_processor import VideoProcessor +from ..modular_pipeline import ModularPipeline, PipelineState + + +logger = logging.get_logger(__name__) + + +class Cosmos3OmniModularPipeline(ModularPipeline): + """ + A ModularPipeline for Cosmos 3 omni generation. + """ + + default_blocks_name = "Cosmos3OmniBlocks" + _callback_tensor_inputs = ["latents"] + _exclude_from_cpu_offload = ["safety_checker"] + model_cpu_offload_seq = "transformer->vae->sound_tokenizer" + + def _ensure_runtime_attributes(self): + if getattr(self, "vae", None) is not None: + self._vae_latents_mean = torch.tensor(self.vae.config.latents_mean, dtype=self.vae.dtype) + self._vae_latents_inv_std = 1.0 / torch.tensor(self.vae.config.latents_std, dtype=self.vae.dtype) + self.vae_scale_factor_spatial = int(self.vae.config.scale_factor_spatial) + elif not hasattr(self, "vae_scale_factor_spatial"): + self.vae_scale_factor_spatial = 16 + + if getattr(self, "text_tokenizer", None) is not None: + self.llm_special_tokens = { + "start_of_generation": self.text_tokenizer.convert_tokens_to_ids("<|vision_start|>"), + "eos_token_id": self.text_tokenizer.eos_token_id, + } + + if getattr(self, "video_processor", None) is None: + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, resample="bilinear") + + self.duration_template = "The video is {duration:.1f} seconds long and is of {fps:.0f} FPS." 
+ self.image_resolution_template = "This image is of {height}x{width} resolution." + self.video_resolution_template = "This video is of {height}x{width} resolution." + self.inverse_duration_template = "The video is not {duration:.1f} seconds long and is not of {fps:.0f} FPS." + self.inverse_image_resolution_template = "This image is not of {height}x{width} resolution." + self.inverse_video_resolution_template = "This video is not of {height}x{width} resolution." + + if not hasattr(self, "safety_checker"): + self.safety_checker = None + if not hasattr(self, "_current_timestep"): + self._current_timestep = None + if not hasattr(self, "_interrupt"): + self._interrupt = False + if not hasattr(self, "_guidance_scale"): + self._guidance_scale = 1.0 + + def _ensure_safety_checker(self): + if getattr(self, "safety_checker", None) is None: + self.safety_checker = CosmosSafetyChecker() + + def __call__(self, state: PipelineState = None, output: str | list[str] | None = None, **kwargs): + self._ensure_runtime_attributes() + if output is None: + output = "result" + return super().__call__(state=state, output=output, **kwargs) + + def _get_execution_device(self): + return Cosmos3OmniPipeline._get_execution_device(self) + + def _encode_video(self, x): + return Cosmos3OmniPipeline._encode_video(self, x) + + def decode_sound(self, latent): + return Cosmos3OmniPipeline.decode_sound(self, latent) + + def _prepare_text_segment(self, input_ids, device): + return Cosmos3OmniPipeline._prepare_text_segment(self, input_ids, device) + + def _prepare_vision_segment(self, *args, **kwargs): + return Cosmos3OmniPipeline._prepare_vision_segment(self, *args, **kwargs) + + def _prepare_sound_segment(self, *args, **kwargs): + return Cosmos3OmniPipeline._prepare_sound_segment(self, *args, **kwargs) + + def _prepare_action_segment(self, *args, **kwargs): + return Cosmos3OmniPipeline._prepare_action_segment(self, *args, **kwargs) + + def _prepare_action_video_conditioning(self, *args, **kwargs): + 
return Cosmos3OmniPipeline._prepare_action_video_conditioning(self, *args, **kwargs) + + def _remove_action_video_padding_from_latent(self, *args, **kwargs): + return Cosmos3OmniPipeline._remove_action_video_padding_from_latent(self, *args, **kwargs) + + def prepare_latents(self, *args, **kwargs): + return Cosmos3OmniPipeline.prepare_latents(self, *args, **kwargs) + + def check_inputs(self, *args, **kwargs): + return Cosmos3OmniPipeline.check_inputs(self, *args, **kwargs) + + @staticmethod + def _build_action_json_prompt(*args, **kwargs): + return Cosmos3OmniPipeline._build_action_json_prompt(*args, **kwargs) + + def tokenize_prompt(self, *args, **kwargs): + return Cosmos3OmniPipeline.tokenize_prompt(self, *args, **kwargs) + + @staticmethod + def _mask_velocity_predictions(*args, **kwargs): + return Cosmos3OmniPipeline._mask_velocity_predictions(*args, **kwargs) + + def _apply_video_safety_check(self, *args, **kwargs): + return Cosmos3OmniPipeline._apply_video_safety_check(self, *args, **kwargs) + + def maybe_free_model_hooks(self): + for component in self.components.values(): + if hasattr(component, "_reset_stateful_cache"): + component._reset_stateful_cache() + + model_hooks = getattr(self._components_manager, "model_hooks", None) + if not model_hooks: + return + + for hook in model_hooks: + hook.offload() + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale != 1.0 diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 2d661028acf6..d43825860d8e 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -133,6 +133,7 @@ def _helios_pyramid_map_fn(config_dict=None): ("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")), ("anima", 
_create_default_map_fn("AnimaModularPipeline")), ("z-image", _create_default_map_fn("ZImageModularPipeline")), + ("cosmos3-omni", _create_default_map_fn("Cosmos3OmniModularPipeline")), ("helios", _create_default_map_fn("HeliosModularPipeline")), ("helios-pyramid", _helios_pyramid_map_fn), ("hunyuan-video-1.5", _create_default_map_fn("HunyuanVideo15ModularPipeline")), diff --git a/tests/pipelines/cosmos/test_cosmos3_modular_parity.py b/tests/pipelines/cosmos/test_cosmos3_modular_parity.py new file mode 100644 index 000000000000..2db277da4a9f --- /dev/null +++ b/tests/pipelines/cosmos/test_cosmos3_modular_parity.py @@ -0,0 +1,414 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import pytest +import torch +from PIL import Image + +from diffusers import AutoencoderKLWan, Cosmos3AVAEAudioTokenizer, Cosmos3OmniTransformer, UniPCMultistepScheduler +from diffusers.modular_pipelines.cosmos.modular_blocks_cosmos3 import Cosmos3OmniBlocks +from diffusers.modular_pipelines.cosmos.modular_pipeline import Cosmos3OmniModularPipeline +from diffusers.pipelines.cosmos.pipeline_cosmos3_omni import Cosmos3OmniPipeline, CosmosActionCondition + +from ...testing_utils import enable_full_determinism + + +enable_full_determinism() + + +class DummyChatTokenizer: + eos_token_id = 2 + _vision_start_id = 3 + + def convert_tokens_to_ids(self, token: str) -> int: + if token == "<|vision_start|>": + return self._vision_start_id + return 10 + + def apply_chat_template( + self, + conversations, + tokenize=True, + add_generation_prompt=True, + add_vision_id=False, + return_dict=True, + ): + text = " ".join(str(message.get("content", "")) for message in conversations) + if not text: + text = " " + + ids = [11] + for i, char in enumerate(text): + ids.append(12 + ((ord(char) + i) % 180)) + if add_generation_prompt: + ids.append(13) + + if return_dict: + return type("DummyBatchEncoding", (), {"input_ids": ids})() + return ids + + +class DummyCosmosSafetyChecker: + def to(self, *args, **kwargs): + return self + + def check_text_safety(self, prompt: str) -> bool: + return True + + def check_video_safety(self, frames_uint8: np.ndarray) -> np.ndarray: + return frames_uint8 + + +def _make_pil_video(seed: int, num_frames: int, height: int, width: int) -> list[Image.Image]: + rng = np.random.default_rng(seed) + frames = rng.integers(0, 255, size=(num_frames, height, width, 3), dtype=np.uint8) + return [Image.fromarray(frame) for frame in frames] + + +def _build_tiny_components(): + torch.manual_seed(0) + transformer = Cosmos3OmniTransformer( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=2, + num_attention_heads=4, + 
num_key_value_heads=2, + head_dim=8, + latent_channel=4, + latent_patch_size=2, + patch_latent_dim=16, + vocab_size=256, + rope_scaling={"mrope_section": [2, 1, 1]}, + action_gen=True, + action_dim=10, + sound_gen=True, + sound_dim=4, + sound_latent_fps=5.0, + ) + + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=8, + decoder_base_dim=8, + z_dim=4, + dim_mult=[1, 1], + num_res_blocks=1, + attn_scales=[], + temperal_downsample=[False], + in_channels=3, + out_channels=3, + scale_factor_temporal=4, + scale_factor_spatial=8, + latents_mean=[0.0, 0.0, 0.0, 0.0], + latents_std=[1.0, 1.0, 1.0, 1.0], + ) + + scheduler = UniPCMultistepScheduler( + num_train_timesteps=1000, + prediction_type="epsilon", + ) + + sound_tokenizer = Cosmos3AVAEAudioTokenizer( + sampling_rate=16, + hop_size=4, + input_channels=1, + stereo=True, + normalize_volume=False, + enc_dim=4, + enc_num_blocks=1, + enc_n_fft=8, + enc_hop_length=2, + enc_latent_dim=8, + enc_c_mults=(1,), + enc_strides=(2,), + vocoder_input_dim=4, + dec_dim=4, + dec_c_mults=(1, 2), + dec_strides=(2, 2), + dec_out_channels=2, + ) + + return { + "transformer": transformer, + "text_tokenizer": DummyChatTokenizer(), + "vae": vae, + "scheduler": scheduler, + "sound_tokenizer": sound_tokenizer, + "safety_checker": DummyCosmosSafetyChecker(), + } + + +def _make_task_pipe() -> Cosmos3OmniPipeline: + components = _build_tiny_components() + pipe = Cosmos3OmniPipeline(**components, enable_safety_checker=True) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + return pipe + + +def _make_modular_pipe() -> Cosmos3OmniModularPipeline: + components = _build_tiny_components() + safety_checker = components.pop("safety_checker") + pipe = Cosmos3OmniModularPipeline(blocks=Cosmos3OmniBlocks()) + pipe.update_components(**components) + pipe.safety_checker = safety_checker + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + return pipe + + +def _assert_close_outputs(task_out, modular_out, *, atol=0.0, rtol=0.0): + 
torch.testing.assert_close(task_out.video, modular_out.video, atol=atol, rtol=rtol) + + if task_out.sound is None or modular_out.sound is None: + assert task_out.sound is None and modular_out.sound is None + else: + torch.testing.assert_close(task_out.sound, modular_out.sound, atol=atol, rtol=rtol) + + if task_out.action is None or modular_out.action is None: + assert task_out.action is None and modular_out.action is None + else: + assert len(task_out.action) == len(modular_out.action) + for task_action, modular_action in zip(task_out.action, modular_out.action): + torch.testing.assert_close(task_action, modular_action, atol=atol, rtol=rtol) + + +def _build_case_kwargs(case_name: str) -> dict: + image = _make_pil_video(seed=1, num_frames=1, height=32, width=32)[0] + video = _make_pil_video(seed=2, num_frames=5, height=32, width=32) + action_video = _make_pil_video(seed=3, num_frames=5, height=32, width=32) + action_image = _make_pil_video(seed=4, num_frames=1, height=32, width=32)[0] + + common = { + "prompt": "A small robot performs a deterministic motion.", + "negative_prompt": "low quality", + "num_inference_steps": 2, + "guidance_scale": 2.0, + "fps": 5.0, + "output_type": "latent", + "enable_safety_check": False, + } + + if case_name == "text2image": + kwargs = {**common, "num_frames": 1, "height": 32, "width": 32} + elif case_name == "text2video": + kwargs = {**common, "num_frames": 5, "height": 32, "width": 32} + elif case_name == "image2video": + kwargs = {**common, "image": image, "num_frames": 5, "height": 32, "width": 32} + elif case_name == "video2video": + kwargs = { + **common, + "video": video, + "num_frames": 5, + "height": 32, + "width": 32, + "condition_frame_indexes_vision": [0, 1], + "condition_video_keep": "first", + } + elif case_name == "video2video_last": + kwargs = { + **common, + "video": video, + "num_frames": 5, + "height": 32, + "width": 32, + "condition_frame_indexes_vision": [0, 1], + "condition_video_keep": "last", + } + elif 
case_name == "text2video_sound": + kwargs = {**common, "num_frames": 5, "height": 32, "width": 32, "enable_sound": True} + elif case_name == "image2video_sound": + kwargs = {**common, "image": image, "num_frames": 5, "height": 32, "width": 32, "enable_sound": True} + elif case_name == "video2video_sound": + kwargs = { + **common, + "video": video, + "num_frames": 5, + "height": 32, + "width": 32, + "condition_frame_indexes_vision": [0, 1], + "condition_video_keep": "first", + "enable_sound": True, + } + elif case_name == "action_policy_image": + kwargs = { + **common, + "guidance_scale": 1.0, + "action": CosmosActionCondition( + mode="policy", + chunk_size=4, + domain_name="bridge_orig_lerobot", + resolution_tier=480, + image=action_image, + ), + } + elif case_name == "action_policy_video": + kwargs = { + **common, + "guidance_scale": 1.0, + "action": CosmosActionCondition( + mode="policy", + chunk_size=4, + domain_name="bridge_orig_lerobot", + resolution_tier=480, + video=action_video, + ), + } + elif case_name == "action_forward_video_bridge": + kwargs = { + **common, + "action": CosmosActionCondition( + mode="forward_dynamics", + chunk_size=4, + domain_name="bridge_orig_lerobot", + resolution_tier=480, + raw_actions=torch.linspace(-0.1, 0.1, steps=40, dtype=torch.float32).reshape(4, 10), + video=action_video, + ), + } + elif case_name == "action_inverse_video": + kwargs = { + **common, + "action": CosmosActionCondition( + mode="inverse_dynamics", + chunk_size=4, + domain_name="bridge_orig_lerobot", + resolution_tier=480, + video=action_video, + ), + } + elif case_name == "action_forward_image_av": + kwargs = { + **common, + "action": CosmosActionCondition( + mode="forward_dynamics", + chunk_size=4, + domain_name="av", + resolution_tier=480, + raw_actions=torch.linspace(-0.2, 0.2, steps=36, dtype=torch.float32).reshape(4, 9), + image=action_image, + ), + } + else: + raise ValueError(f"Unknown parity case: {case_name}") + + return kwargs + + +def 
_run_case(case_name: str): + task_pipe = _make_task_pipe() + modular_pipe = _make_modular_pipe() + kwargs = _build_case_kwargs(case_name) + + task_kwargs = dict(kwargs) + modular_kwargs = dict(kwargs) + task_kwargs["generator"] = torch.Generator(device="cpu").manual_seed(1234) + modular_kwargs["generator"] = torch.Generator(device="cpu").manual_seed(1234) + + task_out = task_pipe(**task_kwargs) + modular_out = modular_pipe(**modular_kwargs) + + if case_name in {"action_policy_image", "action_policy_video", "action_inverse_video"}: + assert task_out.action is not None, f"Task pipeline must return action outputs for {case_name}" + assert modular_out.action is not None, f"Modular pipeline must return action outputs for {case_name}" + assert len(task_out.action) > 0, f"Task pipeline returned empty action outputs for {case_name}" + assert len(modular_out.action) > 0, f"Modular pipeline returned empty action outputs for {case_name}" + + _assert_close_outputs(task_out, modular_out) + + +@pytest.mark.parametrize( + "case_name", + [ + "text2image", + "text2video", + "image2video", + "video2video", + "video2video_last", + "text2video_sound", + "image2video_sound", + "video2video_sound", + "action_policy_image", + "action_policy_video", + "action_forward_video_bridge", + "action_inverse_video", + "action_forward_image_av", + ], +) +def test_cosmos3_modular_parity_all_modes(case_name: str): + _run_case(case_name) + + +def test_cosmos3_modular_workflow_extraction(): + pipe = _make_modular_pipe() + expected = { + "text2image", + "text2video", + "image2video", + "video2video", + "text2video_with_sound", + "image2video_with_sound", + "video2video_with_sound", + "action_policy", + "action_forward_dynamics", + "action_inverse_dynamics", + } + assert set(pipe.blocks.available_workflows) == expected + + image2video_blocks = pipe.blocks.get_workflow("image2video") + assert list(image2video_blocks.sub_blocks.keys()) == [ + "text_encoder", + "denoise.prepare_latents", + 
"denoise.pack_sequence", + "denoise.set_timesteps", + "denoise.denoise", + "decode", + ] + + with pytest.raises(ValueError): + pipe.blocks.get_workflow("non_existent_workflow") + + +class Cosmos3ModularParitySmokeTests(unittest.TestCase): + def test_return_tuple_parity_for_video_and_sound(self): + task_pipe = _make_task_pipe() + modular_pipe = _make_modular_pipe() + + kwargs = { + "prompt": "A robot taps a table rhythmically.", + "negative_prompt": "", + "num_frames": 9, + "height": 32, + "width": 32, + "num_inference_steps": 2, + "guidance_scale": 2.0, + "fps": 5.0, + "enable_sound": True, + "output_type": "pt", + "return_dict": False, + "enable_safety_check": False, + } + task_kwargs = dict(kwargs) + modular_kwargs = dict(kwargs) + task_kwargs["generator"] = torch.Generator(device="cpu").manual_seed(7) + modular_kwargs["generator"] = torch.Generator(device="cpu").manual_seed(7) + + task_video, task_sound = task_pipe(**task_kwargs) + modular_video, modular_sound = modular_pipe(**modular_kwargs) + + torch.testing.assert_close(task_video, modular_video, atol=0.0, rtol=0.0) + torch.testing.assert_close(task_sound, modular_sound, atol=0.0, rtol=0.0) From 54bc6114779d69af89997a41e43ce5ecfab3ab70 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 2 Jul 2026 14:51:54 +0000 Subject: [PATCH 2/2] Fix from_pretrained for modular pipeline without modular_model_index.json --- .../modular_pipelines/modular_pipeline.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index d43825860d8e..2d2fbe71f445 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1866,15 +1866,19 @@ def from_pretrained( if modular_config_dict is not None: pipeline_class = _get_pipeline_class(cls, config=modular_config_dict) elif config_dict is not None: - from
diffusers.pipelines.auto_pipeline import _get_model - - logger.debug(" try to determine the modular pipeline class from model_index.json") - standard_pipeline_class = _get_pipeline_class(cls, config=config_dict) - model_name = _get_model(standard_pipeline_class.__name__) - map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline")) - pipeline_class_name = map_fn(config_dict) - diffusers_module = importlib.import_module("diffusers") - pipeline_class = getattr(diffusers_module, pipeline_class_name) + if cls is not ModularPipeline: + # Keep explicit modular subclasses on their own class when loading from model_index.json. + pipeline_class = cls + else: + from diffusers.pipelines.auto_pipeline import _get_model + + logger.debug(" try to determine the modular pipeline class from model_index.json") + standard_pipeline_class = _get_pipeline_class(cls, config=config_dict) + model_name = _get_model(standard_pipeline_class.__name__) + map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline")) + pipeline_class_name = map_fn(config_dict) + diffusers_module = importlib.import_module("diffusers") + pipeline_class = getattr(diffusers_module, pipeline_class_name) else: # there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components pipeline_class = cls