feat: experimental two-phase (head-chunked) Ulysses all-to-all #428

Draft
csgoogle wants to merge 1 commit into main from sagarchapara/ulysses-two-phase

@csgoogle csgoogle commented Jun 24, 2026

Collaborator

Add an opt-in ULYSSES_ATTENTION_CHUNKS env var to split the Ulysses all-to-all into per-head-group passes, so XLA's async-collective scheduler can overlap one group's attention compute with the next group's all-to-all. Defaults to 1 (current single-shot path, no behavior change). Numerically identical to single-shot since heads are independent.

Notes:

  • Requires async-collective LIBTPU flags to actually overlap.
  • Gain is largest when the all-to-all is a meaningful fraction of attention time (high context parallelism or shorter sequences). At WAN 2.2 720p (seq ~75600) attention is compute-bound, so the win is small (~3% in a microbenchmark); at seqlen ~24k we observe ~10% gains.
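
The claim that chunking is numerically identical because heads are independent can be checked with a small single-device sketch. This is plain Python, not the PR's JAX code, and it leaves out the all-to-all and sharding entirely; `attention`, `chunked_attention`, and `make_inputs` are illustrative names, not functions from this change. It only shows that running attention one head-group at a time and concatenating the results matches the single-shot output.

```python
import math
import random


def attention_head(q, k, v):
    """Softmax attention for one head; q, k, v are lists of per-position vectors."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append(
            [sum(w * vj[d] for w, vj in zip(weights, v)) / total
             for d in range(len(v[0]))]
        )
    return out


def attention(q, k, v):
    """Single-shot path: inputs are indexed [head][position][dim]."""
    return [attention_head(qh, kh, vh) for qh, kh, vh in zip(q, k, v)]


def chunked_attention(q, k, v, num_chunks):
    """Process heads in `num_chunks` contiguous groups and concatenate."""
    num_heads = len(q)
    if num_heads % num_chunks != 0:
        raise ValueError(f"heads={num_heads} not divisible by chunks={num_chunks}")
    group = num_heads // num_chunks
    out = []
    for start in range(0, num_heads, group):
        sl = slice(start, start + group)
        # In the real path, each group would get its own all-to-all here.
        out.extend(attention(q[sl], k[sl], v[sl]))
    return out


def make_inputs(num_heads, seq_len, head_dim, seed=0):
    """Deterministic random q, k, v for the check below."""
    rng = random.Random(seed)

    def tensor():
        return [[[rng.uniform(-1, 1) for _ in range(head_dim)]
                 for _ in range(seq_len)] for _ in range(num_heads)]

    return tensor(), tensor(), tensor()


q, k, v = make_inputs(num_heads=4, seq_len=3, head_dim=2)
assert chunked_attention(q, k, v, num_chunks=2) == attention(q, k, v)
```

Because each head goes through exactly the same arithmetic in both paths, the outputs compare equal here. Whether the same holds bit-for-bit on TPU depends on the compiled kernels, which this sketch does not model.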

google-cla Bot commented Jun 24, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Add an opt-in ULYSSES_ATTENTION_CHUNKS env var to split the Ulysses
all-to-all into per-head-group passes, so XLA's async-collective
scheduler can overlap one group's attention compute with the next
group's all-to-all. Defaults to 1 (current single-shot path, no
behavior change). Numerically identical to single-shot since heads
are independent.

Notes:
- Requires async-collective LIBTPU flags to actually overlap.
- Needs heads % (context_shards * chunks) == 0.
- Gain is largest when all-to-all is a meaningful fraction of attention
  time (high context-parallelism / shorter sequences); at WAN 2.2 720p
  (seq~75600) it is compute-bound so the win is small (~3% in microbench).
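
The divisibility note above can be made concrete with a short sketch of how the env var might be read and validated. The env var name, its default of 1, and the `heads % (context_shards * chunks) == 0` condition come from this PR; the helper names `resolve_num_chunks` and `head_groups`, the `env` parameter, and the head and shard counts in the usage lines are assumptions for illustration only.

```python
import os


def resolve_num_chunks(num_heads, context_shards, env=None):
    """Read ULYSSES_ATTENTION_CHUNKS and check it against the head count.

    `env` is a hypothetical hook for testing; it defaults to os.environ.
    """
    env = os.environ if env is None else env
    num_chunks = int(env.get("ULYSSES_ATTENTION_CHUNKS", "1"))
    if num_chunks < 1:
        raise ValueError(f"ULYSSES_ATTENTION_CHUNKS must be >= 1, got {num_chunks}.")
    # Each chunk's heads must still shard evenly across the context axis.
    if num_heads % (context_shards * num_chunks) != 0:
        raise ValueError(
            f"heads={num_heads} must be divisible by "
            f"context_shards * chunks = {context_shards * num_chunks}."
        )
    return num_chunks


def head_groups(num_heads, num_chunks):
    """Return (start, stop) head ranges, one per chunk."""
    group = num_heads // num_chunks
    return [(start, start + group) for start in range(0, num_heads, group)]


# Illustrative values only: 40 heads over 4 context shards, 2 chunks.
chunks = resolve_num_chunks(40, 4, env={"ULYSSES_ATTENTION_CHUNKS": "2"})
print(head_groups(40, chunks))
```

With these illustrative values each chunk holds 20 heads, which still divide evenly across the 4 context shards. A setting of 3 would be rejected because 40 is not divisible by 12.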
@csgoogle csgoogle force-pushed the sagarchapara/ulysses-two-phase branch from 7240f50 to 0d936f8 Compare June 24, 2026 13:59
@csgoogle csgoogle requested a review from Perseus14 June 24, 2026 14:05
# math is identical to the single-shot path (heads are independent); requires
# async-collective LIBTPU flags to actually overlap, and the per-chunk head
# count must still be shardable across the context axis.
num_chunks = int(os.environ.get("ULYSSES_ATTENTION_CHUNKS", "1"))

Let's move this to the config file so it can be used for any Ulysses-type kernel.

f"got heads={num_heads} and context_shards={num_shards}."
)

# EXPERIMENTAL: split the all-to-all into `num_chunks` head-groups so XLA's

Does this work with Ulysses + ring as well?

@Perseus14 Perseus14 requested a review from eltsai June 24, 2026 20:14