From 62e0805b155412819bc3238161bdf02ca67cad44 Mon Sep 17 00:00:00 2001
From: Rishabh Manoj
Date: Thu, 12 Mar 2026 08:13:08 +0000
Subject: [PATCH] Minor Optimizations to WAN RoPE

Replace the complex-number RoPE rotation with pure real arithmetic and
collapse the three chained jax.lax.collapse calls into a single reshape.
Note: the rotation is now computed in the activation dtype (e.g. bfloat16)
rather than float32, trading a small amount of precision for speed.

---
 src/maxdiffusion/models/attention_flax.py         | 28 +++++++++++++------
 .../wan/transformers/transformer_wan.py           |  4 +--
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index b583171d..9b738c18 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -1081,18 +1081,28 @@ def __init__(
     )
 
   def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]:
-    dtype = xq.dtype
-    reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
-    reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
+    # 1. Extract cos and sin and cast them to the activation dtype (e.g. bfloat16)
+    cos = jnp.real(freqs_cis).astype(xq.dtype)
+    sin = jnp.imag(freqs_cis).astype(xq.dtype)
 
-    xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
-    xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
+    # 2. Reshape the last dimension into pairs
+    xq_reshaped = xq.reshape(*xq.shape[:-1], -1, 2)
+    xk_reshaped = xk.reshape(*xk.shape[:-1], -1, 2)
 
-    xq_out_complex = xq_ * freqs_cis
-    xk_out_complex = xk_ * freqs_cis
+    # 3. Unbind the pairs
+    xq_0, xq_1 = xq_reshaped[..., 0], xq_reshaped[..., 1]
+    xk_0, xk_1 = xk_reshaped[..., 0], xk_reshaped[..., 1]
 
-    xq_out = jnp.stack([jnp.real(xq_out_complex), jnp.imag(xq_out_complex)], axis=-1).reshape(xq.shape).astype(dtype)
-    xk_out = jnp.stack([jnp.real(xk_out_complex), jnp.imag(xk_out_complex)], axis=-1).reshape(xk.shape).astype(dtype)
+    # 4. Pure real arithmetic (XLA fuses these into FMA instructions)
+    xq_out_0 = xq_0 * cos - xq_1 * sin
+    xq_out_1 = xq_0 * sin + xq_1 * cos
+
+    xk_out_0 = xk_0 * cos - xk_1 * sin
+    xk_out_1 = xk_0 * sin + xk_1 * cos
+
+    # 5. Stack and reshape back to original
+    xq_out = jnp.stack([xq_out_0, xq_out_1], axis=-1).reshape(xq.shape)
+    xk_out = jnp.stack([xk_out_0, xk_out_1], axis=-1).reshape(xk.shape)
 
     return xq_out, xk_out
 
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
index 887cb0d0..e701ab92 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -684,7 +684,5 @@ def layer_forward(hidden_states):
         batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
     )
     hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6))
-    hidden_states = jax.lax.collapse(hidden_states, 6, None)
-    hidden_states = jax.lax.collapse(hidden_states, 4, 6)
-    hidden_states = jax.lax.collapse(hidden_states, 2, 4)
+    hidden_states = hidden_states.reshape(batch_size, -1, num_frames, height, width)
     return hidden_states