diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index edc9f4f7..e5d94a54 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -15,6 +15,7 @@ import contextlib import functools import math +import os from typing import Optional, Callable, Tuple, Any, Dict import flax.linen as nn from flax import nnx @@ -605,6 +606,21 @@ def _ulysses_attention( "Ulysses attention requires the number of heads to be divisible by the context shard count, " f"got heads={num_heads} and context_shards={num_shards}." ) + + # EXPERIMENTAL: split the all-to-all into `num_chunks` head-groups so XLA's + # async-collective scheduler can overlap one chunk's attention compute with + # the next chunk's all-to-all. Gated on an env var so it stays opt-in. The + # math is identical to the single-shot path (heads are independent); requires + # async-collective LIBTPU flags to actually overlap, and the per-chunk head + # count must still be shardable across the context axis. + num_chunks = int(os.environ.get("ULYSSES_ATTENTION_CHUNKS", "1")) + if num_chunks > 1: + if num_heads % (num_shards * num_chunks) != 0: + raise ValueError( + "ULYSSES_ATTENTION_CHUNKS requires heads divisible by (context_shards * chunks), " + f"got heads={num_heads}, context_shards={num_shards}, chunks={num_chunks}." + ) + if not use_custom_kernel: block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash") @@ -721,7 +737,23 @@ def wrap_ulysses_attention(query, key, value): "Warning, batch dimension should be shardable among the devices in data and fsdp" f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" ) - x = wrap_ulysses_attention(query, key, value) + if num_chunks > 1: + # EXPERIMENTAL two-phase path: run the all-to-all + attention per head-group. + # Heads are independent, so this is numerically identical to the single-shot + # path; the goal is to let XLA overlap one group's compute with the next + # group's all-to-all (requires async-collective LIBTPU flags). + head_step = num_heads // num_chunks + chunk_outputs = [ + wrap_ulysses_attention( + query[:, i * head_step : (i + 1) * head_step], + key[:, i * head_step : (i + 1) * head_step], + value[:, i * head_step : (i + 1) * head_step], + ) + for i in range(num_chunks) + ] + x = jnp.concatenate(chunk_outputs, axis=1) + else: + x = wrap_ulysses_attention(query, key, value) x = x[:, :, :orig_q_seq_len, :] x = _reshape_heads_to_head_dim(x)