Optimize MLA generate_mask with scatter update#4261

Draft
JHCuc3m wants to merge 1 commit into main from zjiahao/DSA3.2-scatter-mask-gen

@JHCuc3m JHCuc3m commented Jun 24, 2026

Collaborator

Description

This PR optimizes the generate_mask function in the Multi-head Latent Attention (MLA) sparse indexer by replacing a high-complexity broadcasted comparison with a JAX scatter-based update.

  • The Problem: The previous implementation, (jnp.arange(s) == topk_indices[..., None]).any(axis=-2), had an algorithmic complexity of $O(b \cdot t \cdot k \cdot s)$. At a 128K context window, this theoretically creates a virtual intermediate tensor of ~274 billion elements (~256 GB). On TPU, the XLA compiler fused the loop and so avoided physical HBM allocation (reducing HBM traffic to ~142 MB), but the hardware remained bottlenecked by the ~274 billion element-wise ALU comparisons executed in registers, taking ~153 ms per step.

  • The Solution: The optimized version initializes the mask to DEFAULT_MASK_VALUE, which costs $O(b \cdot t \cdot s)$, and then uses indexing to scatter-write 0.0 at the selected topk_indices, which costs $O(b \cdot t \cdot k)$, where $k \ll s$ (2,048 vs 131,072).

  • Implementation Details: Initializes the mask using jnp.full and performs the scatter update using broadcasted batch and time indices (batch_indices = jnp.arange(b)[:, None, None], time_indices = jnp.arange(t)[None, :, None]) via mask.at[batch_indices, time_indices, topk_indices].set(0.0).
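
For reference, the two approaches described above can be sketched side by side. This is a minimal illustration, not the code in this PR: the function names, the value assigned to DEFAULT_MASK_VALUE, and the use of jnp.where to turn the boolean comparison into a float mask are assumptions made for the example.

```python
import jax.numpy as jnp

# Assumption: a large negative float32 stand-in for DEFAULT_MASK_VALUE.
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.float32).max)


def generate_mask_broadcast(topk_indices, s):
  """Old approach: compare every sequence position against every top-k index."""
  # [s] == [b, t, k, 1] -> [b, t, k, s], reduced over k -> [b, t, s].
  is_topk = (jnp.arange(s) == topk_indices[..., None]).any(axis=-2)
  # Assumption: selected positions become 0.0, all others the mask value.
  return jnp.where(is_topk, 0.0, DEFAULT_MASK_VALUE).astype(jnp.float32)


def generate_mask_scatter(topk_indices, s):
  """New approach: fill the mask once, then scatter-write 0.0 at the top-k indices."""
  b, t, _ = topk_indices.shape
  mask = jnp.full((b, t, s), DEFAULT_MASK_VALUE, dtype=jnp.float32)
  batch_indices = jnp.arange(b)[:, None, None]  # [b, 1, 1]
  time_indices = jnp.arange(t)[None, :, None]   # [1, t, 1]
  # The three index arrays broadcast to [b, t, k], one write per selected key.
  return mask.at[batch_indices, time_indices, topk_indices].set(0.0)


# Tiny example: b=1, t=2, k=2, s=6.
topk_indices = jnp.array([[[1, 4], [0, 5]]])
old_mask = generate_mask_broadcast(topk_indices, s=6)
new_mask = generate_mask_scatter(topk_indices, s=6)
```

In this toy case the first query row is unmasked at positions 1 and 4 and the second at positions 0 and 5. The broadcast version evaluates b·t·k·s comparisons, while the scatter version touches b·t·s elements for the fill plus b·t·k for the writes.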

Tests

1. Mathematical Equivalence Test

Added test_generate_mask_equivalence to tests/unit/attention_test.py to verify that the new scatter-based implementation produces identical results to the old broadcast-based implementation.

  • Result: Passed (outputs match within atol=1e-5).
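
An equivalence check of this kind might look like the sketch below. It is not the body of test_generate_mask_equivalence: the helper names, shapes, seed, and mask constant are hypothetical, and jax.lax.top_k over random scores stands in for the indexer's real top-k selection.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumption: stand-in for DEFAULT_MASK_VALUE.
MASK_VALUE = -0.7 * float(jnp.finfo(jnp.float32).max)


def broadcast_mask(topk_indices, s):
  # Reference: broadcasted comparison, O(b * t * k * s).
  hit = (jnp.arange(s) == topk_indices[..., None]).any(axis=-2)
  return jnp.where(hit, 0.0, MASK_VALUE).astype(jnp.float32)


def scatter_mask(topk_indices, s):
  # Candidate: fill then scatter, O(b * t * s + b * t * k).
  b, t, _ = topk_indices.shape
  mask = jnp.full((b, t, s), MASK_VALUE, dtype=jnp.float32)
  bi = jnp.arange(b)[:, None, None]
  ti = jnp.arange(t)[None, :, None]
  return mask.at[bi, ti, topk_indices].set(0.0)


def check_equivalence(seed=0, b=2, t=3, s=64, k=8):
  # Random scores give k distinct indices per (batch, time) row.
  scores = jax.random.normal(jax.random.PRNGKey(seed), (b, t, s))
  _, topk_indices = jax.lax.top_k(scores, k)
  expected = broadcast_mask(topk_indices, s)
  actual = scatter_mask(topk_indices, s)
  np.testing.assert_allclose(np.asarray(actual), np.asarray(expected), atol=1e-5)
  # Each row should have exactly k unmasked positions.
  return np.asarray((actual == 0.0).sum(axis=-1))


zeros_per_row = check_equivalence()
```

Counting the unmasked positions per row is an extra sanity check beyond the elementwise comparison: with distinct top-k indices, every row should contain exactly k zeros.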

2. Regression Testing

Ran the existing unit test suite to ensure no regressions in attention or model functionality:

  • tests/unit/attention_test.py (Passed)
  • tests/unit/model_test.py (Passed)

Checklist

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable. (Validation run zj-scatter-mask-val is currently running on the cluster.)
  • I have made or will make corresponding changes to the doc if needed.

codecov Bot commented Jun 24, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Replace high-complexity broadcasted comparison in `generate_mask`
with a JAX scatter-based update to prevent OOM and reduce overhead.

- Why: The previous implementation relied on a broadcasted comparison between
  the full sequence and the selected top-k indices. At scale (large context
  lengths), this created a massive logical intermediate tensor, causing immediate
  OOM crashes during local CPU testing. On TPU, even though the XLA compiler successfully
  fused the operation to avoid physical memory (HBM) OOM, the hardware remained heavily
  bottlenecked by the extreme number of element-wise ALU comparisons executed in registers.
- How: Initializes the mask with `DEFAULT_MASK_VALUE` using `jnp.full`, then
  uses advanced indexing to scatter-write `0.0` at the selected `topk_indices`.
  This fundamentally changes the algorithm to perform direct writes instead of
  comparing all elements, reducing complexity and instruction count.
- Verification:
  - Added `test_generate_mask_equivalence` to `attention_test.py` to verify
    mathematical equivalence (Passed).
  - Ran existing unit tests (Passed).

TAG=agy
CONV=f93063bc-d96c-46f1-9562-20d2a5bf3241
@JHCuc3m JHCuc3m force-pushed the zjiahao/DSA3.2-scatter-mask-gen branch from c6af97c to 8d32e21 Compare June 24, 2026 21:51