Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ use_indexer: false
indexer_head_dim: 128
indexer_n_heads: 64
indexer_topk: 2048
indexer_use_approx_top_k: false
indexer_approx_top_k_recall: 0.95
# Determines the training strategy for the indexer:
# - false (Dense Warm-up): Computes indexer loss over all tokens. Used with `trainable_parameters_mask` to freeze other model parameters.
# - true (Sparse Training): Computes indexer loss over top-k tokens only and detaches the indexer input for independent optimization.
Expand Down
4 changes: 4 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,10 @@ class AttentionIndexer(BaseModel):
description="Determines the training strategy for the indexer: Dense Warm-up or Sparse Training stage.",
)
indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.")
indexer_use_approx_top_k: bool = Field(
False, description="Whether to use approximate top-k selection for the indexer on TPU."
)
indexer_approx_top_k_recall: float = Field(0.95, description="Recall target for approximate top-k selection.")


class Llama4Attention(BaseModel):
Expand Down
9 changes: 8 additions & 1 deletion src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,14 @@ def __call__(
indexer_score += attention_mask

# TopK selection based on index score
_, topk_indices = jax.lax.top_k(indexer_score, k=self.indexer_topk) # topk_indices [b, t, k]
if self.config.indexer_use_approx_top_k:
_, topk_indices = jax.lax.approx_max_k(
indexer_score,
k=self.indexer_topk,
recall_target=self.config.indexer_approx_top_k_recall,
)
else:
_, topk_indices = jax.lax.top_k(indexer_score, k=self.indexer_topk) # topk_indices [b, t, k]

# Create Sparse Index Mask: 0 and large negatives
indexer_mask = self.generate_mask(topk_indices, k.shape[1]) # [b, t, s]
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,62 @@ def test_indexer_loss_kl_divergence_zero(self):

np.testing.assert_allclose(loss, 0.0, atol=1e-5)

def test_indexer_with_approx_top_k(self):
    """Smoke-test the MLA forward pass with the indexer in exact and approximate top-k modes.

    Each mode runs as its own sub-test; the only check is that the forward
    pass returns a non-None output.
    """
    for use_approx in (False, True):
        with self.subTest(indexer_use_approx_top_k=use_approx):
            mla_config_args = self.config_arguments.copy()
            mla_config_args.update(
                {
                    "use_indexer": True,
                    "indexer_use_approx_top_k": use_approx,
                    # Small k so the indexer actually runs instead of returning early.
                    "indexer_topk": 4,
                    "attention": "dot_product",
                }
            )

            cfg, mla = self.init_mla(mla_config_args, rope_type="default")
            lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(cfg, cfg.dtype)

            # The forward pass is what exercises the indexer's top-k selection.
            forward_kwargs = {
                "decoder_segment_ids": decoder_segment_ids,
                "inputs_positions": decoder_positions,
                "deterministic": True,
                "model_mode": MODEL_MODE_TRAIN,
            }
            out, _ = mla(lnx, lnx, **forward_kwargs)

            self.assertIsNotNone(out)

def test_approx_top_k_recall(self):
"""Verify that approx_max_k meets the specified recall target compared to exact top_k."""
jax_rng = jax.random.PRNGKey(0)

# We need a large enough N to make the approximation meaningful.
# Use shape [batch=4, queries=16, N=1024]
batch, queries, N = 4, 16, 1024
K = 64
recall_target = 0.95

# Generate random scores
scores = jax.random.normal(jax_rng, (batch, queries, N))

# 1. Run exact Top-K
_, true_indices = jax.lax.top_k(scores, k=K) # [batch, queries, K]

# 2. Run approx Top-K
_, approx_indices = jax.lax.approx_max_k(scores, k=K, recall_target=recall_target) # [batch, queries, K]

# 3. Calculate Recall
# Broadcast compare true_indices [B, Q, K, 1] and approx_indices [B, Q, 1, K]
matches = (true_indices[..., None] == approx_indices[..., None, :]).any(axis=-1) # [B, Q, K]
num_matches = matches.sum(axis=-1) # [B, Q]
actual_recalls = num_matches / K # [B, Q]
mean_recall = jnp.mean(actual_recalls)

print(f"\nApprox Top-K Recall Target: {recall_target}, Actual Mean Recall: {mean_recall:.4f}")

# Assert that the actual recall is equal or exceeds the target.
self.assertGreaterEqual(mean_recall, recall_target)

def test_indexer_gradients(self):
# Test that gradients do NOT flow back to inputs
bsz, seqlen = 2, 8
Expand Down
Loading