From 8d32e219b98809a6f3b3c6b991d25e2c2ec8e9ca Mon Sep 17 00:00:00 2001 From: Jiahao Chen Zhou Date: Wed, 24 Jun 2026 20:55:13 +0000 Subject: [PATCH] Optimize MLA generate_mask with scatter update Replace high-complexity broadcasted comparison in `generate_mask` with a JAX scatter-based update to prevent OOM and reduce overhead. - Why: The previous implementation relied on a broadcasted comparison between the full sequence and the selected top-k indices. At scale (large context lengths), this created a massive logical intermediate tensor, causing immediate OOM crashes during local CPU testing. On TPU, even though the XLA compiler successfully fused the operation to avoid physical memory (HBM) OOM, the hardware remained heavily bottlenecked by the extreme number of element-wise ALU comparisons executed in registers. - How: Initializes the mask with `DEFAULT_MASK_VALUE` using `jnp.full`, then uses advanced indexing to scatter-write `0.0` at the selected `topk_indices`. This fundamentally changes the algorithm to perform direct writes instead of comparing all elements, reducing complexity and instruction count. - Verification: - Added `test_generate_mask_equivalence` to `attention_test.py` to verify mathematical equivalence (Passed). - Ran existing unit tests (Passed). TAG=agy CONV=f93063bc-d96c-46f1-9562-20d2a5bf3241 --- src/maxtext/layers/attention_mla.py | 15 ++++++++------- tests/unit/attention_test.py | 26 ++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index df7fd16ea2..3592c629b8 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -232,15 +232,16 @@ def generate_mask(self, topk_indices, s): Returns: mask: [b, t, s] - `0.0` at topk_indices, `DEFAULT_MASK_VALUE` (large negative) elsewhere. """ - # 1. Create a range [0, 1, ..., s-1] - # 2. Broadcast compare against [b, t, k] to get [b, t, k, s] - # 3. Use .any() to see if a s-index is present in any of the k slots - is_topk = (jnp.arange(s) == topk_indices[..., None]).any(axis=-2) - # 4. Use where to select between 0.0 and the mask value - # cast values to dtype + b, t, _ = topk_indices.shape + batch_indices = jnp.arange(b)[:, None, None] + time_indices = jnp.arange(t)[None, :, None] + val_true = jnp.array(0.0, dtype=self.dtype) val_false = jnp.array(DEFAULT_MASK_VALUE, dtype=self.dtype) - return jnp.where(is_topk, val_true, val_false) + + mask = jnp.full((b, t, s), val_false, dtype=self.dtype) + mask = mask.at[batch_indices, time_indices, topk_indices].set(val_true) + return mask def __call__( self, diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 404e089210..594566e2f6 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -1802,6 +1802,32 @@ def test_indexer_loss_kl_divergence_zero(self): np.testing.assert_allclose(loss, 0.0, atol=1e-5) + def test_generate_mask_equivalence(self): + """Test that the optimized scatter-based mask generation is equivalent to the old broadcast-based one.""" + mla_config_args = self.config_arguments.copy() + mla_config_args["use_indexer"] = True + mla_config_args["attention"] = "dot_product" + _, mla = self.init_mla(mla_config_args, rope_type="default") + + # Use a small shape that won't OOM the old method but is representative + b, t, k, s = 2, 128, 64, 1024 + + key = jax.random.PRNGKey(0) + topk_indices = jax.random.randint(key, (b, t, k), 0, s) + + # Old implementation (broadcast-based) reconstructed for comparison + def generate_mask_old(topk_indices, s): + is_topk = (jnp.arange(s) == topk_indices[..., None]).any(axis=-2) + val_true = jnp.array(0.0, dtype=mla.indexer.dtype) + val_false = jnp.array(DEFAULT_MASK_VALUE, dtype=mla.indexer.dtype) + return jnp.where(is_topk, val_true, val_false) + + mask_old = generate_mask_old(topk_indices, s) + mask_new = mla.indexer.generate_mask(topk_indices, s) + + # Assert exact mathematical equivalence + np.testing.assert_allclose(mask_new, mask_old, atol=1e-5) + def test_indexer_gradients(self): # Test that gradients do NOT flow back to inputs bsz, seqlen = 2, 8