Skip to content

Add optional Approximate Top-K configuration for MLA Indexer#4243

Open
JHCuc3m wants to merge 2 commits into
mainfrom
zjiahao/DSA3.2-approx-top-k
Open

Add optional Approximate Top-K configuration for MLA Indexer#4243
JHCuc3m wants to merge 2 commits into
mainfrom
zjiahao/DSA3.2-approx-top-k

Conversation

@JHCuc3m

@JHCuc3m JHCuc3m commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR adds an optional configuration parameter indexer_use_approx_top_k to the Multi-head Latent Attention (MLA) Indexer, allowing users to enable JAX's TPU-optimized approx_max_k primitive instead of the default exact top_k selection.

Why is this change being made?

During performance investigations of DeepSeek-V3.2 in long-context mode (128K sequence length) with 128-way Context Parallelism (CP=128), the top-k was identified a major bottleneck in the MLA Indexer forward pass:

  • The indexer computes a local score matrix of shape [1, 1024, 131072] on each device.
  • The default exact top_k selection is implemented on TPU via an expensive Bitonic Sort along the 131K sequence dimension.
  • Sorting 131K elements per layer across 58 layers introduces massive step-time overhead.

Benchmarks show a 4x speedup on f32[1,1024,65536] tensors when tested with the approximate path enabled using a recall target of 0.95.

Why this is a good solution

JAX's approx_max_k employs block-based reduction optimized for TPU Matrix Units (MXU). It avoids full bitonic sorting, reducing complexity from $\mathcal{O}(N \log^2 N)$ to $\approx \mathcal{O}(N + K \log K)$ with a significantly smaller constant factor. Paper: https://arxiv.org/pdf/2206.14286.

Workload show a ~4x speedup on f32[1, 1024, 65536] tensors when tested on TPU with the approximate path enabled using a recall target of 0.95.

Specific Implementation Details

  1. Configuration Schema (types.py): Added indexer_use_approx_top_k and indexer_approx_top_k_recall to the AttentionIndexer Pydantic class to pass configuration validation.
  2. Default Config (base.yml): Exposed the parameters with safe defaults (indexer_use_approx_top_k: false, indexer_approx_top_k_recall: 0.95).
  3. Model Architecture (attention_mla.py): Updated Indexer.__call__ to conditionally route the selection to jax.lax.approx_max_k when enabled.

Shortcomings & Future Improvements

  • There is not systematic study on how the accuracy lost of using indexer_use_approx_top_k instead of top_k might affect downstream model performance, while it is expected to be minimal when a high recall rate is used.

Tests

1. Regression Guard (Default Path)

We ran the attention unit test suite with the default configuration (indexer_use_approx_top_k=false) to ensure no regressions:

  • Command: pytest tests/unit/attention_test.py
  • Result: PASSED (20 passed, 32 skipped).

2. Compilation & Tracing Safety

We added a new unit test, test_indexer_with_approx_top_k, to verify that the new path compiles and traces successfully in JAX:

  • Command: pytest tests/unit/attention_test.py -k test_indexer_with_approx_top_k
  • Result: PASSED.

3. Mathematical Correctness & Recall Tracking

We added a correctness test, test_approx_top_k_recall, which generates random scores of shape [4, 16, 1024], runs both exact and approximate top-K ($K=64$), and calculates the actual recall:

  • Command: pytest tests/unit/attention_test.py -k test_approx_top_k_recall -s
  • Result: PASSED (Achieved 100% recall on CPU).

Checklist

Before submitting this PR, please make sure (put X in square brackets):

@codecov

codecov Bot commented Jun 23, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@JHCuc3m JHCuc3m force-pushed the zjiahao/DSA3.2-approx-top-k branch 2 times, most recently from 36913f5 to 9ef80f2 Compare June 23, 2026 22:16
This commit adds an optional configuration parameter 'indexer_use_approx_top_k' to the MLA Indexer, allowing users to use JAX's TPU-optimized 'approx_max_k' instead of the default exact 'top_k'. When enabled, this optimization avoids expensive bitonic sorts on large sequence dimensions, significantly reducing step time overhead for long-context TPU runs.

Benchmarks show a 4x speedup on f32[1,1024,65536] tensors when tested with the approximate path enabled using a recall target of 0.95.

Changes:
- base.yml: Added `indexer_use_approx_top_k` and `indexer_approx_top_k_recall` configs (defaulting to false/0.95).
- types.py: Added the new parameters to the AttentionIndexer Pydantic schema to pass configuration validation.
- attention_mla.py: Integrated conditional routing to `jax.lax.approx_max_k` in the Indexer forward pass.
- attention_test.py: Added a compilation safety unit test and a recall correctness validation test (which verifies approx recall meets the target).

TAG=agy
CONV=7fa653b2-d3d7-4a18-8df8-9020b6805e11
@JHCuc3m JHCuc3m force-pushed the zjiahao/DSA3.2-approx-top-k branch from 9ef80f2 to 53d7917 Compare June 23, 2026 22:23
@JHCuc3m

JHCuc3m commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator Author

Code quality checker failed from existing file content

Comment thread tests/unit/attention_test.py Outdated

# Assert that the actual recall is close to or exceeds the target.
# We allow a small margin (e.g., 0.05) due to the approximate nature and sample size.
self.assertGreaterEqual(mean_recall, recall_target - 0.05)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is 0.05 too large? What about making it 0.01?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated!

@NuojCheng NuojCheng left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is a very cool optimization! Thank you Jiahao

Updated recall assertion to require exact target match.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants