Add optional Approximate Top-K configuration for MLA Indexer#4243
Open
JHCuc3m wants to merge 2 commits into
Codecov Report: ✅ All modified and coverable lines are covered by tests.
36913f5 to
9ef80f2
Compare
This commit adds an optional configuration parameter 'indexer_use_approx_top_k' to the MLA Indexer, allowing users to use JAX's TPU-optimized 'approx_max_k' instead of the default exact 'top_k'.

When enabled, this optimization avoids expensive bitonic sorts on large sequence dimensions, significantly reducing step time overhead for long-context TPU runs. Benchmarks show a 4x speedup on f32[1,1024,65536] tensors when tested with the approximate path enabled using a recall target of 0.95.

Changes:
- base.yml: Added `indexer_use_approx_top_k` and `indexer_approx_top_k_recall` configs (defaulting to false/0.95).
- types.py: Added the new parameters to the AttentionIndexer Pydantic schema to pass configuration validation.
- attention_mla.py: Integrated conditional routing to `jax.lax.approx_max_k` in the Indexer forward pass.
- attention_test.py: Added a compilation safety unit test and a recall correctness validation test (which verifies approx recall meets the target).

TAG=agy CONV=7fa653b2-d3d7-4a18-8df8-9020b6805e11
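The conditional routing described in this commit message can be sketched roughly as follows. This is a minimal illustration, not the code from `attention_mla.py`: the `IndexerTopKConfig` dataclass and the `select_top_k` helper are hypothetical names standing in for the Pydantic schema and the Indexer forward pass, and only `jax.lax.top_k` and `jax.lax.approx_max_k` are real JAX calls.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class IndexerTopKConfig:
  """Hypothetical stand-in for the two new config keys (not the real schema)."""
  indexer_use_approx_top_k: bool = False
  indexer_approx_top_k_recall: float = 0.95


def select_top_k(scores, k, cfg):
  """Return (values, indices) of the k largest scores along the last axis."""
  if cfg.indexer_use_approx_top_k:
    # Approximate path: recall_target sets the expected fraction of the
    # true top-k entries that the result contains.
    return jax.lax.approx_max_k(
        scores, k, recall_target=cfg.indexer_approx_top_k_recall)
  # Default path: exact selection.
  return jax.lax.top_k(scores, k)


# Small illustrative shape; real indexer inputs are far larger.
scores = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 256), dtype=jnp.float32)
exact_vals, exact_idx = select_top_k(scores, 16, IndexerTopKConfig())
approx_vals, approx_idx = select_top_k(
    scores, 16, IndexerTopKConfig(indexer_use_approx_top_k=True))
```

Both paths return the same `(values, indices)` pair shape, so downstream code does not need to know which one ran. Keeping the flag off by default leaves existing runs on the exact path.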
9ef80f2 to
53d7917
Compare
Collaborator
Author
The code quality checker failed because of pre-existing file content.
NuojCheng
reviewed
Jun 25, 2026
# Assert that the actual recall is close to or exceeds the target.
# We allow a small margin (e.g., 0.05) due to the approximate nature and sample size.
self.assertGreaterEqual(mean_recall, recall_target - 0.05)
Collaborator
Is 0.05 too large? What about making it 0.01?
NuojCheng
approved these changes
Jun 25, 2026
NuojCheng
left a comment
Collaborator
I think it is a very cool optimization! Thank you Jiahao
Updated recall assertion to require exact target match.
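For readers following the margin discussion, the recall check can be sketched as below. This is an assumption-laden illustration rather than the test in `attention_test.py`: the `mean_recall` helper is a hypothetical name, and the shape `[4, 16, 1024]`, K=64, and target 0.95 are the values quoted in the PR description.

```python
import jax
import numpy as np


def mean_recall(exact_idx, approx_idx):
  """Mean fraction of exact top-k indices that also appear in the approximate result."""
  k = exact_idx.shape[-1]
  exact = np.asarray(exact_idx).reshape(-1, k)
  approx = np.asarray(approx_idx).reshape(-1, k)
  # Compare index sets row by row; order within a row does not matter.
  per_row = [
      len(set(e.tolist()) & set(a.tolist())) / k
      for e, a in zip(exact, approx)
  ]
  return float(np.mean(per_row))


k, recall_target = 64, 0.95
scores = jax.random.normal(jax.random.PRNGKey(0), (4, 16, 1024))
_, exact_idx = jax.lax.top_k(scores, k)
_, approx_idx = jax.lax.approx_max_k(scores, k, recall_target=recall_target)
recall = mean_recall(exact_idx, approx_idx)
# A test would then assert `recall >= recall_target - margin`, where the
# margin (0.05, 0.01, or 0) is the tolerance debated in this review.
```

A tighter margin makes the test stricter but more sensitive to the random sample, which is the trade-off behind the reviewer's question.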
Description
This PR adds an optional configuration parameter `indexer_use_approx_top_k` to the Multi-head Latent Attention (MLA) Indexer, allowing users to enable JAX's TPU-optimized `approx_max_k` primitive instead of the default exact `top_k` selection.

Why is this change being made?
During performance investigations of DeepSeek-V3.2 in long-context mode (128K sequence length) with 128-way Context Parallelism (CP=128), the `top_k` selection was identified as a major bottleneck in the MLA Indexer forward pass:
- The `top_k` input has shape `[1, 1024, 131072]` on each device.
- `top_k` selection is implemented on TPU via an expensive Bitonic Sort along the 131K sequence dimension.

Benchmarks show a 4x speedup on `f32[1, 1024, 65536]` tensors when tested with the approximate path enabled using a recall target of 0.95.
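A comparison of this kind could be timed with a harness along the following lines. This is only a sketch: the `time_fn` helper and the value of `k` are assumptions (the PR text does not state the K used in its benchmark), and the small shape is chosen so the code also runs on CPU, where the relative timings will not match the TPU result quoted above.

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, x, iters=5):
  """Average wall-clock seconds per call, excluding the first (compiling) call."""
  fn(x)[0].block_until_ready()  # warm-up call triggers compilation
  start = time.perf_counter()
  for _ in range(iters):
    fn(x)[0].block_until_ready()
  return (time.perf_counter() - start) / iters


k = 64  # hypothetical value for illustration
exact = jax.jit(lambda x: jax.lax.top_k(x, k))
approx = jax.jit(lambda x: jax.lax.approx_max_k(x, k, recall_target=0.95))

# Small shape so the sketch runs anywhere; the PR measured f32[1, 1024, 65536] on TPU.
x = jax.random.normal(jax.random.PRNGKey(0), (1, 64, 4096), dtype=jnp.float32)
exact_s = time_fn(exact, x)
approx_s = time_fn(approx, x)
print(f"exact: {exact_s * 1e3:.2f} ms, approx: {approx_s * 1e3:.2f} ms")
```

Calling `block_until_ready()` matters because JAX dispatches work asynchronously; without it the loop would measure dispatch time rather than execution time.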
Why this is a good solution
JAX's `approx_max_k` employs block-based reduction optimized for TPU Matrix Units (MXU). It avoids full bitonic sorting, reducing complexity from $\mathcal{O}(N \log^2 N)$ to $\approx \mathcal{O}(N + K \log K)$ with a significantly smaller constant factor. Paper: https://arxiv.org/pdf/2206.14286.

Benchmarks show a ~4x speedup on `f32[1, 1024, 65536]` tensors when tested on TPU with the approximate path enabled using a recall target of 0.95.

Specific Implementation Details
- `types.py`: Added `indexer_use_approx_top_k` and `indexer_approx_top_k_recall` to the `AttentionIndexer` Pydantic class to pass configuration validation.
- `base.yml`: Added the default values (`indexer_use_approx_top_k: false`, `indexer_approx_top_k_recall: 0.95`).
- `attention_mla.py`: Updated `Indexer.__call__` to conditionally route the selection to `jax.lax.approx_max_k` when enabled.

Shortcomings & Future Improvements
- Enabling `indexer_use_approx_top_k` instead of using exact `top_k` might affect downstream model performance, although the effect is expected to be minimal when a high recall rate is used.

Tests
1. Regression Guard (Default Path)
We ran the attention unit test suite with the default configuration (`indexer_use_approx_top_k=false`) to ensure no regressions:

    pytest tests/unit/attention_test.py

2. Compilation & Tracing Safety
We added a new unit test, `test_indexer_with_approx_top_k`, to verify that the new path compiles and traces successfully in JAX:

    pytest tests/unit/attention_test.py -k test_indexer_with_approx_top_k

3. Mathematical Correctness & Recall Tracking
We added a correctness test, `test_approx_top_k_recall`, which generates random scores of shape `[4, 16, 1024]`, runs both exact and approximate top-K ($K=64$), and calculates the actual recall:

    pytest tests/unit/attention_test.py -k test_approx_top_k_recall -s

Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.