Conversation
NucleusImage - text kv caching
| logger = logging.get_logger(__name__) | ||
|
|
||
|
|
||
| # copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen |
There was a problem hiding this comment.
| # copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen | |
| # Copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen with qwen->nucleus |
nit: the `# Copied from` mechanism supports renaming with the above syntax.
| return self.norm(conditioning) | ||
|
|
||
|
|
||
| # copied from diffusers.models.transformers.transformer_qwenimage.QwenEmbedRope |
There was a problem hiding this comment.
| # copied from diffusers.models.transformers.transformer_qwenimage.QwenEmbedRope | |
| # Copied from diffusers.models.transformers.transformer_qwenimage.QwenEmbedRope with Qwen->NucleusMoE |
See #13317 (comment). Alternatively, if NucleusMoEEmbedRope is changed (for example to remove txt_seq_lens as suggested in #13317 (comment)), the # Copied from statement should be removed.
| if txt_seq_lens is not None: | ||
| deprecate( | ||
| "txt_seq_lens", | ||
| "0.39.0", | ||
| "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " | ||
| "Please use `max_txt_seq_len` instead.", | ||
| standard_warn=False, | ||
| ) |
There was a problem hiding this comment.
As this is a new model, can we remove the dependence on the deprecated txt_seq_lens argument?
|
|
||
| def forward( | ||
| self, | ||
| video_fhw: tuple[int, int, int, list[tuple[int, int, int]]], |
There was a problem hiding this comment.
| video_fhw: tuple[int, int, int, list[tuple[int, int, int]]], | |
| video_fhw: tuple[int, int, int] | list[tuple[int, int, int]], |
nit: fix type annotation
| return out | ||
|
|
||
|
|
||
| @maybe_allow_in_graph |
There was a problem hiding this comment.
I think we can remove the maybe_allow_in_graph decorator, as the NucleusMoE-Image transformer compile tests
RUN_SLOW=1 RUN_COMPILE=1 pytest tests/models/transformers/test_models_transformer_nucleusmoe_image.py::TestNucleusMoEImageTransformerCompile have the same pass/fail pattern with and without it. (Currently, test_compile_on_different_shapes fails both with and without maybe_allow_in_graph; all other tests pass.)
| attention_kwargs: dict[str, Any] | None = None, | ||
| ) -> torch.Tensor: | ||
| scale1, gate1, scale2, gate2 = self.img_mod(temb).unsqueeze(1).chunk(4, dim=-1) | ||
| scale1, scale2 = 1 + scale1, 1 + scale2 |
There was a problem hiding this comment.
nit: I think it's more clear if we do the calculation inline as e.g. img_modulated = img_normed * (1 + scale1).
| gate1 = gate1.clamp(min=-2.0, max=2.0) | ||
| gate2 = gate2.clamp(min=-2.0, max=2.0) |
There was a problem hiding this comment.
It seems weird to me that we first clamp the gates to [-2.0, 2.0] and then essentially clamp again by squashing with the tanh function below. Is this intended?
| if hidden_states.dtype == torch.float16: | ||
| hidden_states = hidden_states.clip(-65504, 65504) |
There was a problem hiding this comment.
| if hidden_states.dtype == torch.float16: | |
| hidden_states = hidden_states.clip(-65504, 65504) | |
| if hidden_states.dtype == torch.float16: | |
| fp16_finfo = torch.finfo(torch.float16) | |
| hidden_states = hidden_states.clip(fp16_finfo.min, fp16_finfo.max) |
| dense_moe_strategy: str = "leave_first_three_and_last_block_dense", | ||
| num_experts: int = 128, | ||
| moe_intermediate_dim: int = 1344, | ||
| capacity_factors: List[float] = [8.0] * 24, |
There was a problem hiding this comment.
| capacity_factors: List[float] = [8.0] * 24, | |
| capacity_factors: float | list[float] = 8.0, |
I think allowing capacity_factors to take float arguments as well makes the code a little cleaner. We would then expand float inputs to a list inside __init__:
if isinstance(capacity_factors, float):
capacity_factors = [capacity_factors] * num_layers
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| img_shapes: list[tuple[int, int, int]] | None = None, |
There was a problem hiding this comment.
| img_shapes: list[tuple[int, int, int]] | None = None, | |
| img_shapes: tuple[int, int, int] | list[tuple[int, int, int]], |
I think allowing img_shapes to take tuple[int, int, int] arguments as well would be cleaner, similar to #13317 (comment). If I understand correctly, NucleusMoEEmbedRope only accepts batches with the same image shape, so this would make it easier to specify such shapes.
| "Please use `encoder_hidden_states_mask` instead.", | ||
| standard_warn=False, | ||
| ) | ||
|
|
There was a problem hiding this comment.
I think we can remove the deprecated txt_seq_lens argument here as well. See #13317 (comment).
| """ | ||
|
|
||
|
|
||
| def calculate_shift( |
There was a problem hiding this comment.
| def calculate_shift( | |
| # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift | |
| def calculate_shift( |
| return mu | ||
|
|
||
|
|
||
| def retrieve_timesteps( |
There was a problem hiding this comment.
| def retrieve_timesteps( | |
| # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps | |
| def retrieve_timesteps( |
| self.default_sample_size = 128 | ||
| self.return_index = -8 |
There was a problem hiding this comment.
Should default_sample_size and return_index be configurable via __init__?
| prompt_embeds_mask=None, | ||
| negative_prompt_embeds_mask=None, |
There was a problem hiding this comment.
| prompt_embeds_mask=None, | |
| negative_prompt_embeds_mask=None, |
nit: remove prompt_embeds_mask and negative_prompt_embeds_mask as they are not used in check_inputs.
| return latents | ||
|
|
||
| @staticmethod | ||
| def _unpack_latents(latents, height, width, vae_scale_factor): |
There was a problem hiding this comment.
Could we refactor _pack_latents and _unpack_latents to take a patch_size argument instead of hardcoding the patch size to 2? This would make the code more robust.
| height = 2 * (int(height) // (self.vae_scale_factor * 2)) | ||
| width = 2 * (int(width) // (self.vae_scale_factor * 2)) |
There was a problem hiding this comment.
Could we refactor this to use self.transformer.config.patch_size instead of hardcoding the patch size to 2? See also #13317 (comment).
| latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) | ||
| return latents | ||
|
|
||
| def enable_vae_slicing(self): |
There was a problem hiding this comment.
We can remove the VAE slicing/tiling methods here as they are deprecated. Users can always call the corresponding methods on the VAE itself (e.g. pipe.vae.enable_tiling()) to enable/disable slicing/tiling.
| self, | ||
| prompt: str | list[str] = None, | ||
| negative_prompt: str | list[str] = None, | ||
| true_cfg_scale: float = 4.0, |
There was a problem hiding this comment.
| true_cfg_scale: float = 4.0, | |
| guidance_scale: float = 4.0, |
nit: rename to guidance_scale to follow the diffusers CFG naming conventions.
| latent_h = 2 * (int(height) // (self.vae_scale_factor * 2)) | ||
| latent_w = 2 * (int(width) // (self.vae_scale_factor * 2)) | ||
| img_shapes = [(1, latent_h // 2, latent_w // 2)] * (batch_size * num_images_per_prompt) |
There was a problem hiding this comment.
Similar to #13317 (comment), can we refactor this to use self.transformer.config.patch_size?
|
|
||
| noise_pred = self.transformer( | ||
| hidden_states=latents, | ||
| timestep=timestep / 1000, |
There was a problem hiding this comment.
Instead of hardcoding this at 1000, could we use self.scheduler.config.num_train_timesteps instead?
| noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) | ||
| noise_pred = comb_pred * (cond_norm / noise_norm) | ||
|
|
||
| noise_pred = -noise_pred |
There was a problem hiding this comment.
Why do we need to negate noise_pred here?
| def __init__(self): | ||
| super().__init__() | ||
| # Maps encoder_hidden_states.data_ptr() → (txt_key, txt_value) | ||
| self.kv_cache: dict[int, tuple] = {} |
There was a problem hiding this comment.
| self.kv_cache: dict[int, tuple] = {} | |
| self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} |
nit: more specific type annotation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
| self.experts = nn.ModuleList( | ||
| [ | ||
| FeedForward( | ||
| dim=hidden_size, | ||
| dim_out=hidden_size, | ||
| inner_dim=moe_intermediate_dim, | ||
| activation_fn="swiglu", | ||
| bias=False, | ||
| ) | ||
| for _ in range(num_experts) | ||
| ] | ||
| ) |
There was a problem hiding this comment.
You would need the projections to be in packed/contiguous format — (num_experts, dim_in, dim_out) — for torch.grouped_mm support. @sayakpaul, is that possible? In Transformers we use the inline weight converter.
There was a problem hiding this comment.
Not at the moment, because MoEs are still a bit of a special case in this part of the world.
What does this PR do?
This PR introduces NucleusMoE-Image series into the diffusers library.
NucleusMoE-Image is a 2B-active, 17B-parameter model trained with efficiency at its core. Our novel architecture highlights the scalability of the sparse MoE architecture for image generation. The technical report will be released very soon.