Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,16 @@ def generate_mask(self, topk_indices, s):
Returns:
mask: [b, t, s] - `0.0` at topk_indices, `DEFAULT_MASK_VALUE` (large negative) elsewhere.
"""
# 1. Create a range [0, 1, ..., s-1]
# 2. Broadcast compare against [b, t, k] to get [b, t, k, s]
# 3. Use .any() to see if a s-index is present in any of the k slots
is_topk = (jnp.arange(s) == topk_indices[..., None]).any(axis=-2)
# 4. Use where to select between 0.0 and the mask value
# cast values to dtype
b, t, _ = topk_indices.shape
batch_indices = jnp.arange(b)[:, None, None]
time_indices = jnp.arange(t)[None, :, None]

val_true = jnp.array(0.0, dtype=self.dtype)
val_false = jnp.array(DEFAULT_MASK_VALUE, dtype=self.dtype)
return jnp.where(is_topk, val_true, val_false)

mask = jnp.full((b, t, s), val_false, dtype=self.dtype)
mask = mask.at[batch_indices, time_indices, topk_indices].set(val_true)
return mask

def __call__(
self,
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,32 @@ def test_indexer_loss_kl_divergence_zero(self):

np.testing.assert_allclose(loss, 0.0, atol=1e-5)

def test_generate_mask_equivalence(self):
"""Test that the optimized scatter-based mask generation is equivalent to the old broadcast-based one."""
mla_config_args = self.config_arguments.copy()
mla_config_args["use_indexer"] = True
mla_config_args["attention"] = "dot_product"
_, mla = self.init_mla(mla_config_args, rope_type="default")

# Use a small shape that won't OOM the old method but is representative
b, t, k, s = 2, 128, 64, 1024

key = jax.random.PRNGKey(0)
topk_indices = jax.random.randint(key, (b, t, k), 0, s)

# Old implementation (broadcast-based) reconstructed for comparison
def generate_mask_old(topk_indices, s):
is_topk = (jnp.arange(s) == topk_indices[..., None]).any(axis=-2)
val_true = jnp.array(0.0, dtype=mla.indexer.dtype)
val_false = jnp.array(DEFAULT_MASK_VALUE, dtype=mla.indexer.dtype)
return jnp.where(is_topk, val_true, val_false)

mask_old = generate_mask_old(topk_indices, s)
mask_new = mla.indexer.generate_mask(topk_indices, s)

# Assert exact mathematical equivalence
np.testing.assert_allclose(mask_new, mask_old, atol=1e-5)

def test_indexer_gradients(self):
# Test that gradients do NOT flow back to inputs
bsz, seqlen = 2, 8
Expand Down
Loading