From 76dae1b6e0d7867a5da294d727b64ef54ba1510c Mon Sep 17 00:00:00 2001 From: Jiahao Chen Zhou Date: Fri, 26 Jun 2026 21:45:24 +0000 Subject: [PATCH] Add optional Approximate Top-K configuration for MLA Indexer --- src/maxtext/configs/base.yml | 2 ++ src/maxtext/configs/types.py | 4 +++ src/maxtext/layers/attention_mla.py | 9 ++++- tests/unit/attention_test.py | 56 +++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 1 deletion(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 1541170fee..12061d2c91 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -397,6 +397,8 @@ use_indexer: false indexer_head_dim: 128 indexer_n_heads: 64 indexer_topk: 2048 +indexer_use_approx_top_k: false +indexer_approx_top_k_recall: 0.95 # Determines the training strategy for the indexer: # - false (Dense Warm-up): Computes indexer loss over all tokens. Used with `trainable_parameters_mask` to freeze other model parameters. # - true (Sparse Training): Computes indexer loss over top-k tokens only and detaches the indexer input for independent optimization. diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 3efd2e2f55..97aec80717 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -668,6 +668,10 @@ class AttentionIndexer(BaseModel): description="Determines the training strategy for the indexer: Dense Warm-up or Sparse Training stage.", ) indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.") + indexer_use_approx_top_k: bool = Field( + False, description="Whether to use approximate top-k selection for the indexer on TPU." + ) + indexer_approx_top_k_recall: float = Field(0.95, description="Recall target for approximate top-k selection.") class Llama4Attention(BaseModel): diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index df7fd16ea2..d5f79a1a4b 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -366,7 +366,14 @@ def __call__( indexer_score += attention_mask # TopK selection based on index score - _, topk_indices = jax.lax.top_k(indexer_score, k=self.indexer_topk) # topk_indices [b, t, k] + if self.config.indexer_use_approx_top_k: + _, topk_indices = jax.lax.approx_max_k( + indexer_score, + k=self.indexer_topk, + recall_target=self.config.indexer_approx_top_k_recall, + ) + else: + _, topk_indices = jax.lax.top_k(indexer_score, k=self.indexer_topk) # topk_indices [b, t, k] # Create Sparse Index Mask: 0 and large negatives indexer_mask = self.generate_mask(topk_indices, k.shape[1]) # [b, t, s] diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 3fa8391833..261e13ca28 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -1802,6 +1802,62 @@ def test_indexer_loss_kl_divergence_zero(self): np.testing.assert_allclose(loss, 0.0, atol=1e-5) + def test_indexer_with_approx_top_k(self): + """Verify indexer runs with both approx and exact top-k.""" + for use_approx in [False, True]: + with self.subTest(indexer_use_approx_top_k=use_approx): + mla_config_args = self.config_arguments.copy() + mla_config_args["use_indexer"] = True + mla_config_args["indexer_use_approx_top_k"] = use_approx + mla_config_args["indexer_topk"] = 4 # Force indexer to run instead of returning early + mla_config_args["attention"] = "dot_product" + + cfg, mla = self.init_mla(mla_config_args, rope_type="default") + + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(cfg, cfg.dtype) + + # Run forward pass which triggers indexer + out, _ = mla( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + self.assertIsNotNone(out) + + def test_approx_top_k_recall(self): + """Verify that approx_max_k meets the specified recall target compared to exact top_k.""" + jax_rng = jax.random.PRNGKey(0) + + # We need a large enough N to make the approximation meaningful. + # Use shape [batch=4, queries=16, N=1024] + batch, queries, N = 4, 16, 1024 + K = 64 + recall_target = 0.95 + + # Generate random scores + scores = jax.random.normal(jax_rng, (batch, queries, N)) + + # 1. Run exact Top-K + _, true_indices = jax.lax.top_k(scores, k=K) # [batch, queries, K] + + # 2. Run approx Top-K + _, approx_indices = jax.lax.approx_max_k(scores, k=K, recall_target=recall_target) # [batch, queries, K] + + # 3. Calculate Recall + # Broadcast compare true_indices [B, Q, K, 1] and approx_indices [B, Q, 1, K] + matches = (true_indices[..., None] == approx_indices[..., None, :]).any(axis=-1) # [B, Q, K] + num_matches = matches.sum(axis=-1) # [B, Q] + actual_recalls = num_matches / K # [B, Q] + mean_recall = jnp.mean(actual_recalls) + + print(f"\nApprox Top-K Recall Target: {recall_target}, Actual Mean Recall: {mean_recall:.4f}") + + # Assert that the actual recall is equal or exceeds the target. + self.assertGreaterEqual(mean_recall, recall_target) + def test_indexer_gradients(self): # Test that gradients do NOT flow back to inputs bsz, seqlen = 2, 8