
NucleusMoE-Image #13317

Open
sippycoder wants to merge 10 commits into huggingface:main from sippycoder:main

Conversation

@sippycoder

What does this PR do?

This PR introduces NucleusMoE-Image series into the diffusers library.

NucleusMoE-Image is a 17B-parameter model with 2B active parameters, trained with efficiency at its core. Our novel architecture highlights the scalability of the sparse MoE architecture for image generation. The technical report will be released very soon.

@sippycoder (Author)

cc: @sayakpaul @IlyasMoutawwakil

@sayakpaul sayakpaul requested review from dg845 and yiyixuxu March 24, 2026 04:08
logger = logging.get_logger(__name__)


# copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen
Collaborator

Suggested change
# copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen
# Copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen with qwen->nucleus

nit: the `# Copied from` mechanism supports renamings with the above syntax.

return self.norm(conditioning)


# copied from diffusers.models.transformers.transformer_qwenimage.QwenEmbedRope
@dg845 (Collaborator) Mar 25, 2026

Suggested change
# copied from diffusers.models.transformers.transformer_qwenimage.QwenEmbedRope
# Copied from diffusers.models.transformers.transformer_qwenimage.QwenEmbedRope with Qwen->NucleusMoE

See #13317 (comment). Alternatively, if NucleusMoEEmbedRope is changed (for example to remove txt_seq_lens as suggested in #13317 (comment)), the # Copied from statement should be removed.

Comment on lines +178 to +185
if txt_seq_lens is not None:
    deprecate(
        "txt_seq_lens",
        "0.39.0",
        "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
        "Please use `max_txt_seq_len` instead.",
        standard_warn=False,
    )
Collaborator

As this is a new model, can we remove the dependence on the deprecated txt_seq_lens argument?


def forward(
self,
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
Collaborator

Suggested change
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
video_fhw: tuple[int, int, int] | list[tuple[int, int, int]],

nit: fix the type annotation.

return out


@maybe_allow_in_graph
Collaborator

I think we can remove the maybe_allow_in_graph decorator as the NucleusMoE-Image transformer compile tests

RUN_SLOW=1 RUN_COMPILE=1 pytest tests/models/transformers/test_models_transformer_nucleusmoe_image.py::TestNucleusMoEImageTransformerCompile

have the same pass/fail pattern with and without it. (Currently, test_compile_on_different_shapes fails both with and without maybe_allow_in_graph; all other tests pass.)

attention_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor:
scale1, gate1, scale2, gate2 = self.img_mod(temb).unsqueeze(1).chunk(4, dim=-1)
scale1, scale2 = 1 + scale1, 1 + scale2
Collaborator

nit: I think it's clearer if we do the calculation inline, e.g. img_modulated = img_normed * (1 + scale1).
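To illustrate the inline form the comment suggests, here is a minimal numpy sketch; `img_normed` and `scale1` are hypothetical stand-ins for the tensors in the block above, not the actual model variables:

```python
import numpy as np

# Hypothetical stand-ins: normed hidden states (batch, dim) and a
# per-sample modulation scale that broadcasts over the feature dim.
img_normed = np.ones((2, 4))
scale1 = np.full((2, 1), 0.5)

# Inline form suggested in the review: apply the +1 shift at the point
# of use instead of pre-computing scale1 = 1 + scale1 earlier.
img_modulated = img_normed * (1 + scale1)
```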

Comment on lines +545 to +546
gate1 = gate1.clamp(min=-2.0, max=2.0)
gate2 = gate2.clamp(min=-2.0, max=2.0)
Collaborator

It seems weird to me that we first clamp the gates to [-2.0, 2.0] and then essentially clamp again by squashing with the tanh function below. Is this intended?
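The redundancy the reviewer points out can be seen numerically: tanh already bounds its output to (-1, 1), so a prior clamp to [-2, 2] only changes results for inputs with |x| > 2, and there only slightly. A small sketch (the `gate_value` helper is hypothetical, mirroring the questioned pattern):

```python
import math

def gate_value(x: float) -> float:
    """Clamp to [-2, 2], then squash with tanh, as in the questioned code."""
    clamped = max(-2.0, min(2.0, x))
    return math.tanh(clamped)

# tanh(2.0) is about 0.964, and tanh saturates toward 1 beyond that, so the
# clamp shifts large inputs by well under 0.04 in the output.
delta = math.tanh(10.0) - gate_value(10.0)
```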

Comment on lines +574 to +575
if hidden_states.dtype == torch.float16:
    hidden_states = hidden_states.clip(-65504, 65504)
Collaborator

Suggested change
if hidden_states.dtype == torch.float16:
    hidden_states = hidden_states.clip(-65504, 65504)
if hidden_states.dtype == torch.float16:
    fp16_finfo = torch.finfo(torch.float16)
    hidden_states = hidden_states.clip(fp16_finfo.min, fp16_finfo.max)
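The magic number 65504 is exactly the largest finite float16 value, which is why the suggestion swaps it for `finfo`. A numpy analogue (numpy's `finfo` mirrors `torch.finfo` for this purpose):

```python
import numpy as np

# float16 has a finite range of +/-65504; reading it from finfo makes the
# intent of the clip explicit instead of relying on a magic constant.
fp16 = np.finfo(np.float16)

x = np.array([1e5, -1e5, 3.0])
clipped = np.clip(x, float(fp16.min), float(fp16.max))
```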

dense_moe_strategy: str = "leave_first_three_and_last_block_dense",
num_experts: int = 128,
moe_intermediate_dim: int = 1344,
capacity_factors: List[float] = [8.0] * 24,
Collaborator

Suggested change
capacity_factors: List[float] = [8.0] * 24,
capacity_factors: float | list[float] = 8.0,

I think allowing capacity_factors to take float arguments as well makes the code a little cleaner. We would then expand float inputs to a list inside __init__:

if isinstance(capacity_factors, float):
    capacity_factors = [capacity_factors] * num_layers
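The expansion above can be sketched as a small standalone helper; `normalize_capacity_factors` is a hypothetical name for illustration, since in the suggestion the expansion would live directly in `__init__`:

```python
def normalize_capacity_factors(capacity_factors, num_layers):
    """Expand a scalar capacity factor to a per-layer list, validating
    list inputs against the layer count."""
    if isinstance(capacity_factors, float):
        capacity_factors = [capacity_factors] * num_layers
    if len(capacity_factors) != num_layers:
        raise ValueError("capacity_factors must have one entry per layer")
    return capacity_factors
```

This keeps the common case (`capacity_factors=8.0`) terse while still allowing per-layer tuning.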

def forward(
self,
hidden_states: torch.Tensor,
img_shapes: list[tuple[int, int, int]] | None = None,
Collaborator

Suggested change
img_shapes: list[tuple[int, int, int]] | None = None,
img_shapes: tuple[int, int, int] | list[tuple[int, int, int]],

I think allowing img_shapes to take tuple[int, int, int] arguments as well would be cleaner, similar to #13317 (comment). If I understand correctly, NucleusMoEEmbedRope only accepts batches with the same image shape, so this would make it easier to specify such shapes.
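The tuple-or-list acceptance could be handled with a one-step normalization at the top of `forward`; a hedged sketch (the helper name and `batch_size` expansion are illustrative assumptions, not the PR's code):

```python
def normalize_img_shapes(img_shapes, batch_size):
    """Accept a single (frames, height, width) tuple or a per-sample list
    of such tuples, returning a per-sample list."""
    if isinstance(img_shapes, tuple):
        img_shapes = [img_shapes] * batch_size
    return img_shapes
```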

"Please use `encoder_hidden_states_mask` instead.",
standard_warn=False,
)

Collaborator

I think we can remove the deprecated txt_seq_lens argument here as well. See #13317 (comment).

"""


def calculate_shift(
Collaborator

Suggested change
def calculate_shift(
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
def calculate_shift(

return mu


def retrieve_timesteps(
Collaborator

Suggested change
def retrieve_timesteps(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(

Comment on lines +177 to +178
self.default_sample_size = 128
self.return_index = -8
Collaborator

Should default_sample_size and return_index be configurable via __init__?

Comment on lines +265 to +266
prompt_embeds_mask=None,
negative_prompt_embeds_mask=None,
Collaborator

Suggested change
prompt_embeds_mask=None,
negative_prompt_embeds_mask=None,

nit: remove prompt_embeds_mask and negative_prompt_embeds_mask as they are not used in check_inputs.

return latents

@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
@dg845 (Collaborator) Mar 25, 2026

Could we refactor _pack_latents and _unpack_latents to take a patch_size argument instead of hardcoding the patch size to 2? This would make the code more robust.
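A `patch_size`-parameterized version of the pack/unpack pair could look like the following numpy sketch (the real diffusers helpers are torch static methods on the pipeline; this just illustrates the rearrangement with the hardcoded 2 replaced):

```python
import numpy as np

def pack_latents(latents, patch_size=2):
    """Rearrange (B, C, H, W) latents into (B, (H/p)*(W/p), C*p*p)
    patch tokens, with p = patch_size instead of a hardcoded 2."""
    b, c, h, w = latents.shape
    p = patch_size
    x = latents.reshape(b, c, h // p, p, w // p, p)
    x = x.transpose(0, 2, 4, 1, 3, 5)  # (B, H/p, W/p, C, p, p)
    return x.reshape(b, (h // p) * (w // p), c * p * p)

def unpack_latents(packed, height, width, patch_size=2):
    """Inverse of pack_latents: recover (B, C, H, W) from patch tokens."""
    b, _, d = packed.shape
    p = patch_size
    c = d // (p * p)
    x = packed.reshape(b, height // p, width // p, c, p, p)
    x = x.transpose(0, 3, 1, 4, 2, 5)  # (B, C, H/p, p, W/p, p)
    return x.reshape(b, c, height, width)
```

Round-tripping through both functions returns the original latents, so the refactor is behavior-preserving for `patch_size=2`.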

Comment on lines +336 to +337
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
Collaborator

Could we refactor this to use self.transformer.config.patch_size instead of hardcoding the patch size to 2? See also #13317 (comment).
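Generalizing the expression `2 * (int(dim) // (self.vae_scale_factor * 2))` over a `patch_size` parameter is mechanical; a hypothetical helper for illustration:

```python
def rounded_latent_dim(pixel_dim, vae_scale_factor, patch_size=2):
    """Map a pixel dimension to its latent-space size, rounded down to a
    multiple of patch_size (generalizing the hardcoded 2)."""
    return patch_size * (int(pixel_dim) // (vae_scale_factor * patch_size))
```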

latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents

def enable_vae_slicing(self):
@dg845 (Collaborator) Mar 25, 2026

We can remove the VAE slicing/tiling methods here as they are deprecated. Users can always call the corresponding methods on the VAE itself (e.g. pipe.vae.enable_tiling()) to enable/disable slicing/tiling.

self,
prompt: str | list[str] = None,
negative_prompt: str | list[str] = None,
true_cfg_scale: float = 4.0,
Collaborator

Suggested change
true_cfg_scale: float = 4.0,
guidance_scale: float = 4.0,

nit: rename to guidance_scale to follow the diffusers CFG naming conventions.

Comment on lines +551 to +553
latent_h = 2 * (int(height) // (self.vae_scale_factor * 2))
latent_w = 2 * (int(width) // (self.vae_scale_factor * 2))
img_shapes = [(1, latent_h // 2, latent_w // 2)] * (batch_size * num_images_per_prompt)
Collaborator

Similar to #13317 (comment), can we refactor this to use self.transformer.config.patch_size?


noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
Collaborator

Instead of hardcoding this at 1000, could we use self.scheduler.config.num_train_timesteps instead?

noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)

noise_pred = -noise_pred
Collaborator

Why do we need to negate noise_pred here?

def __init__(self):
super().__init__()
# Maps encoder_hidden_states.data_ptr() → (txt_key, txt_value)
self.kv_cache: dict[int, tuple] = {}
Collaborator

Suggested change
self.kv_cache: dict[int, tuple] = {}
self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}

nit: more specific type annotation

@dg845 (Collaborator) left a comment

Thanks for the PR! Left an initial review :). @yiyixuxu, could you also take a look at the text KV cache code in src/diffusers/hooks/text_kv_cache.py?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +380 to +391
self.experts = nn.ModuleList(
[
FeedForward(
dim=hidden_size,
dim_out=hidden_size,
inner_dim=moe_intermediate_dim,
activation_fn="swiglu",
bias=False,
)
for _ in range(num_experts)
]
)
Member

You would need the projections to be in a packed/contiguous format, (num_experts, dim_in, dim_out), for torch.grouped_mm support. @sayakpaul, is that possible? In Transformers we use the inline weight converter.

Member

Not at the moment because MoEs are still a bit of a special case in this part of world.
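For context on the packed layout discussed above, here is a numpy reference of what per-token expert projections look like with weights stored as (num_experts, dim_in, dim_out); the helper is a hypothetical sketch (a grouped-matmul kernel would replace the Python loop with one launch over expert-contiguous token groups), not diffusers code:

```python
import numpy as np

def grouped_expert_mm(tokens, expert_ids, packed_w):
    """Apply each token's assigned expert projection.

    tokens:     (num_tokens, dim_in)
    expert_ids: (num_tokens,) integer expert assignment per token
    packed_w:   (num_experts, dim_in, dim_out) contiguous expert weights
    """
    out = np.empty((tokens.shape[0], packed_w.shape[2]), dtype=tokens.dtype)
    for e in range(packed_w.shape[0]):
        mask = expert_ids == e
        if mask.any():
            # All tokens routed to expert e share one matmul.
            out[mask] = tokens[mask] @ packed_w[e]
    return out
```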
