diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py
index df7fd16ea2..3592c629b8 100644
--- a/src/maxtext/layers/attention_mla.py
+++ b/src/maxtext/layers/attention_mla.py
@@ -232,15 +232,22 @@ def generate_mask(self, topk_indices, s):
     Returns:
       mask: [b, t, s] - `0.0` at topk_indices, `DEFAULT_MASK_VALUE` (large negative) elsewhere.
     """
-    # 1. Create a range [0, 1, ..., s-1]
-    # 2. Broadcast compare against [b, t, k] to get [b, t, k, s]
-    # 3. Use .any() to see if a s-index is present in any of the k slots
-    is_topk = (jnp.arange(s) == topk_indices[..., None]).any(axis=-2)
-    # 4. Use where to select between 0.0 and the mask value
-    # cast values to dtype
+    # Scatter instead of building the [b, t, k, s] broadcast comparison: the index
+    # grids below are [b, 1, 1] and [1, t, 1], so they broadcast against
+    # topk_indices [b, t, k] and address one (batch, query, key) cell per entry.
+    # NOTE(review): `.at[...]` follows NumPy indexing, so a negative index would
+    # wrap around rather than match nothing as `==` did - confirm topk_indices
+    # is always in [0, s).
+    b, t, _ = topk_indices.shape
+    batch_indices = jnp.arange(b)[:, None, None]
+    time_indices = jnp.arange(t)[None, :, None]
+
     val_true = jnp.array(0.0, dtype=self.dtype)
     val_false = jnp.array(DEFAULT_MASK_VALUE, dtype=self.dtype)
-    return jnp.where(is_topk, val_true, val_false)
+
+    mask = jnp.full((b, t, s), val_false, dtype=self.dtype)
+    mask = mask.at[batch_indices, time_indices, topk_indices].set(val_true)
+    return mask
 
   def __call__(
       self,
diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py
index 404e089210..594566e2f6 100644
--- a/tests/unit/attention_test.py
+++ b/tests/unit/attention_test.py
@@ -1802,6 +1802,33 @@ def test_indexer_loss_kl_divergence_zero(self):
 
     np.testing.assert_allclose(loss, 0.0, atol=1e-5)
 
+  def test_generate_mask_equivalence(self):
+    """Test that the scatter-based mask generation matches the old broadcast-based one."""
+    mla_config_args = self.config_arguments.copy()
+    mla_config_args["use_indexer"] = True
+    mla_config_args["attention"] = "dot_product"
+    _, mla = self.init_mla(mla_config_args, rope_type="default")
+
+    # Use a small shape that won't OOM the old method but is representative
+    b, t, k, s = 2, 128, 64, 1024
+
+    # randint samples with replacement, so repeated indices within a row are exercised too
+    key = jax.random.PRNGKey(0)
+    topk_indices = jax.random.randint(key, (b, t, k), 0, s)
+
+    # Old implementation (broadcast-based) reconstructed for comparison
+    def generate_mask_old(topk_indices, s):
+      is_topk = (jnp.arange(s) == topk_indices[..., None]).any(axis=-2)
+      val_true = jnp.array(0.0, dtype=mla.indexer.dtype)
+      val_false = jnp.array(DEFAULT_MASK_VALUE, dtype=mla.indexer.dtype)
+      return jnp.where(is_topk, val_true, val_false)
+
+    mask_old = generate_mask_old(topk_indices, s)
+    mask_new = mla.indexer.generate_mask(topk_indices, s)
+
+    # Both paths emit only the same two constants, so the masks must match exactly
+    np.testing.assert_array_equal(mask_new, mask_old)
+
   def test_indexer_gradients(self):
     # Test that gradients do NOT flow back to inputs
     bsz, seqlen = 2, 8