
Optimize RoPE arithmetic and simplify tensor reshaping #355

Merged
copybara-service[bot] merged 1 commit into main from wan-opt
Mar 12, 2026

@Perseus14
Collaborator

This PR makes two focused changes to improve memory bandwidth utilization and code maintainability. It applies the Rotary Position Embedding (RoPE) using native real arithmetic, which allows better hardware fusion, and it replaces a sequence of tensor collapses with a single explicit reshape.

Changes:

  • Replaced the float32 upcasting and jax.lax.complex multiplication with explicit real-number arithmetic (computing the 2D rotation directly), keeping the tensors in their native dtype (e.g., bfloat16).

  • Replaced three sequential jax.lax.collapse operations with a single, explicit hidden_states.reshape(batch_size, -1, num_frames, height, width).
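
For readers without the diff at hand, here is a minimal sketch of both changes. It is not the code from this PR: the function names, the `(batch, seq, head_dim)` layout, the interleaved real/imaginary pairing, and the 8-D patch layout are all assumptions made for illustration. It shows that the explicit 2D rotation matches the complex multiplication while staying in the input dtype, and that one reshape produces the same result as three `jax.lax.collapse` calls.

```python
import jax
import jax.numpy as jnp


def rope_complex(x, cos, sin):
  """Reference path: upcast to float32, rotate via complex multiplication."""
  x32 = x.astype(jnp.float32).reshape(*x.shape[:-1], -1, 2)
  xc = jax.lax.complex(x32[..., 0], x32[..., 1])
  rot = jax.lax.complex(cos.astype(jnp.float32), sin.astype(jnp.float32))
  out = xc * rot
  out = jnp.stack([jnp.real(out), jnp.imag(out)], axis=-1)
  return out.reshape(x.shape).astype(x.dtype)


def rope_real(x, cos, sin):
  """Same 2D rotation with real arithmetic, staying in x.dtype."""
  pairs = x.reshape(*x.shape[:-1], -1, 2)
  x_re, x_im = pairs[..., 0], pairs[..., 1]
  cos = cos.astype(x.dtype)
  sin = sin.astype(x.dtype)
  # (x_re + i*x_im) * (cos + i*sin), written out term by term.
  out_re = x_re * cos - x_im * sin
  out_im = x_re * sin + x_im * cos
  return jnp.stack([out_re, out_im], axis=-1).reshape(x.shape)


def merge_with_collapse(h):
  """Merge adjacent (grid, patch) axis pairs, starting from the last pair."""
  h = jax.lax.collapse(h, 6, 8)
  h = jax.lax.collapse(h, 4, 6)
  h = jax.lax.collapse(h, 2, 4)
  return h


def merge_with_reshape(h, batch_size, num_frames, height, width):
  """Single reshape; -1 infers the channel axis."""
  return h.reshape(batch_size, -1, num_frames, height, width)


key_x, key_h = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical shapes: (batch, seq, head_dim) in bfloat16.
x = jax.random.normal(key_x, (2, 16, 8), dtype=jnp.bfloat16)
inv_freq = 1.0 / (10000.0 ** (jnp.arange(4) / 4))
angles = jnp.outer(jnp.arange(16, dtype=jnp.float32), inv_freq)
cos, sin = jnp.cos(angles), jnp.sin(angles)  # (seq, head_dim // 2)
y = rope_real(x, cos, sin)
print(y.dtype, y.shape)

# Hypothetical layout: (batch, channels, f, p_t, h, p_h, w, p_w).
h = jax.random.normal(key_h, (2, 3, 4, 1, 5, 2, 6, 2))
out = merge_with_reshape(h, batch_size=2, num_frames=4, height=10, width=12)
print(out.shape)
```

In float32 the two RoPE paths agree to within rounding. In bfloat16 the real-arithmetic path rounds its intermediates in the lower precision, so small numerical differences from the float32 path are expected. The reshape is equivalent to the collapses only because each collapse merges adjacent axes in row-major order.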

@Perseus14 requested a review from entrpn as a code owner on March 12, 2026 08:19

@mbohlool self-requested a review on March 12, 2026 19:36
Collaborator

@mbohlool left a comment


LGTM

@copybara-service[bot] merged commit cbe451a into main on Mar 12, 2026
13 checks passed
2 participants