Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,18 +1081,28 @@ def __init__(
)

def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]:
dtype = xq.dtype
reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
# 1. Extract cos and sin, keeping them in native bfloat16
cos = jnp.real(freqs_cis).astype(xq.dtype)
sin = jnp.imag(freqs_cis).astype(xq.dtype)

xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
# 2. Reshape the last dimension into pairs
xq_reshaped = xq.reshape(*xq.shape[:-1], -1, 2)
xk_reshaped = xk.reshape(*xk.shape[:-1], -1, 2)

xq_out_complex = xq_ * freqs_cis
xk_out_complex = xk_ * freqs_cis
# 3. Unbind the pairs
xq_0, xq_1 = xq_reshaped[..., 0], xq_reshaped[..., 1]
xk_0, xk_1 = xk_reshaped[..., 0], xk_reshaped[..., 1]

xq_out = jnp.stack([jnp.real(xq_out_complex), jnp.imag(xq_out_complex)], axis=-1).reshape(xq.shape).astype(dtype)
xk_out = jnp.stack([jnp.real(xk_out_complex), jnp.imag(xk_out_complex)], axis=-1).reshape(xk.shape).astype(dtype)
# 4. Pure real arithmetic (XLA will fuse these instantly into FMA instructions)
xq_out_0 = xq_0 * cos - xq_1 * sin
xq_out_1 = xq_0 * sin + xq_1 * cos

xk_out_0 = xk_0 * cos - xk_1 * sin
xk_out_1 = xk_0 * sin + xk_1 * cos

# 5. Stack and reshape back to original
xq_out = jnp.stack([xq_out_0, xq_out_1], axis=-1).reshape(xq.shape)
xk_out = jnp.stack([xk_out_0, xk_out_1], axis=-1).reshape(xk.shape)

return xq_out, xk_out

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,5 @@ def layer_forward(hidden_states):
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6))
hidden_states = jax.lax.collapse(hidden_states, 6, None)
hidden_states = jax.lax.collapse(hidden_states, 4, 6)
hidden_states = jax.lax.collapse(hidden_states, 2, 4)
hidden_states = hidden_states.reshape(batch_size, -1, num_frames, height, width)
return hidden_states
Loading