Description
On TPU v6e-16 (multi-host), combining two XLA:TPU collective-overlap flag groups silently corrupts outputs to NaN for a sequence-parallel (Ulysses) + splash-attention workload (MaxDiffusion Wan2.2-I2V-A14B DiT). Each group alone is numerically correct; only together do they produce NaN, and that combination is also the only one that delivers a real ~5% speedup. The corruption appears tied to overlapping the Ulysses all-to-all/all-reduce collectives around the splash-attention Mosaic custom-call.
Environment
- TPU v6e-16 (4 hosts × 4 chips), multi-host via jax.distributed
- jax 0.10.2, flax 0.12.7, recent libtpu; bfloat16
- MaxDiffusion Wan2.2-I2V-A14B (two-expert DiT), attention=ulysses_custom (splash/flash), ici_context=4 × ici_tensor=4
Flags that trigger it (LIBTPU_INIT_ARGS)
--xla_tpu_enable_async_collective_fusion=true
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true
--xla_tpu_enable_async_collective_fusion_multiple_steps=true
--xla_tpu_overlap_compute_collective_tc=true
--xla_enable_async_all_reduce=true
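For bisection it helps to build LIBTPU_INIT_ARGS from named flag groups rather than editing one long string. A minimal sketch; the grouping below is an assumption (the issue names the two groups but does not say which group --xla_enable_async_all_reduce belongs to), and the helper name libtpu_init_args is hypothetical:

```python
import os

# Flag groups as named in the bisection. Assumption: --xla_enable_async_all_reduce is
# counted with the overlap group; the issue does not state which group it belongs to.
ASYNC_COLLECTIVE_FUSION = [
    "--xla_tpu_enable_async_collective_fusion=true",
    "--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true",
    "--xla_tpu_enable_async_collective_fusion_multiple_steps=true",
]
OVERLAP_COMPUTE_COLLECTIVE = [
    "--xla_tpu_overlap_compute_collective_tc=true",
    "--xla_enable_async_all_reduce=true",
]


def libtpu_init_args(*groups):
    """Hypothetical helper: join one or more flag groups into one LIBTPU_INIT_ARGS string."""
    return " ".join(flag for group in groups for flag in group)


# The NaN-producing combination. libtpu reads this variable when the TPU backend
# initializes, so set it before importing jax, with the same value on every host.
os.environ["LIBTPU_INIT_ARGS"] = libtpu_init_args(
    ASYNC_COLLECTIVE_FUSION, OVERLAP_COMPUTE_COLLECTIVE
)
print(os.environ["LIBTPU_INIT_ARGS"])
```

Passing a single group (e.g. only ASYNC_COLLECTIVE_FUSION) gives the "group only" rows of the bisection.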
Symptom
The DiT denoise output becomes NaN and diverges across processes. This is caught downstream when jax.device_put onto a replicated sharding runs multihost_utils.assert_equal:
AssertionError: ... not the same on each process. Expected: [[[[[nan nan nan ...
Bisection (warm DiT denoise; "correct" = downstream VAE decode succeeds)
| flags | warm DiT denoise time | correct? |
| --- | ---: | --- |
| baseline (none) | 269.7 s | ✅ |
| async-collective-fusion group only | 261.6 s | ✅ |
| overlap-compute-collective group only | 262.1 s | ✅ |
| megacore-fusion only | 260.9 s | ✅ |
| fusion + megacore | 260.6 s | ✅ |
| overlap + megacore | 262.0 s | ✅ |
| fusion + overlap | 247.9 s | ❌ NaN |
| enable_async_collective_fusion + overlap_compute_collective_tc (primaries only) | 260.7 s | ✅ |
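So the trigger is specifically async-collective-fusion × overlap-compute-collective together, and it also requires the aggressive sub-flags (fuse_all_reduce / multiple_steps / async_all_reduce) that overlap collectives across steps: with only the two primary flags the output is correct. The table does not isolate which of those sub-flags is required. The speedup appears only in the NaN-producing config: 247.9 s vs 260.6 s for the fastest correct configuration (~5%), or ~8% vs the 269.7 s baseline.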
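The ~5% figure is relative to the fastest numerically correct row, not to the baseline. A quick check using only the times from the bisection table (here "speedup" means the fractional reduction in wall time):

```python
# Warm DiT denoise times (seconds) copied from the bisection table.
WARM_DIT_S = {
    "baseline (none)": 269.7,
    "async-collective-fusion group only": 261.6,
    "overlap-compute-collective group only": 262.1,
    "megacore-fusion only": 260.9,
    "fusion + megacore": 260.6,
    "overlap + megacore": 262.0,
    "primaries only": 260.7,
}
NAN_CONFIG_S = 247.9  # fusion + overlap: the only row that produced NaN


def speedup(reference_s, candidate_s):
    """Fractional reduction in wall time of candidate relative to reference."""
    return 1.0 - candidate_s / reference_s


best_correct_s = min(WARM_DIT_S.values())
vs_best = speedup(best_correct_s, NAN_CONFIG_S)
vs_baseline = speedup(WARM_DIT_S["baseline (none)"], NAN_CONFIG_S)
print(f"vs fastest correct ({best_correct_s} s): {vs_best:.1%}")  # 4.9%
print(f"vs baseline: {vs_baseline:.1%}")                           # 8.1%
```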
What we ruled out (no minimal repro yet)
A standalone multi-host program on the same 4×4 v6e mesh, running a loop of matmul → psum (all-reduce) → all_to_all → all_to_all → tanh, does not reproduce the issue in either f32 or bf16: the output is bit-identical with and without the flags. So this is not a generic collective + overlap bug; it needs the workload's specific collective structure, most likely the interaction with the splash-attention Mosaic custom-call in the Ulysses attention path, which a pure-XLA matmul graph lacks.
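The standalone program itself is not included here, so the following is only a single-process sketch of that kind of loop, written with shard_map. Assumptions: the mesh axis names ("context", "tensor"), the sizes, the weight layout (one HIDDEN×HIDDEN block per tensor shard) and the step count are illustrative, not taken from the original program. It fakes 16 CPU devices so it runs anywhere; on the real v6e-16 slice you would drop that line, call jax.distributed.initialize() on each host, and gather results with multihost_utils.process_allgather instead of np.asarray. Because the two all_to_all calls are inverses, the result can be checked against a plain NumPy reference.

```python
import os

# Assumption: CPU dry run. Fake 16 host devices to stand in for the 4x4 v6e mesh;
# on the real slice, drop this and call jax.distributed.initialize() on every host.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=16"
)

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX

CONTEXT, TENSOR = 4, 4          # ici_context=4 x ici_tensor=4
SEQ, HIDDEN, STEPS = 64, 32, 4  # illustrative sizes, not the Wan2.2 shapes

devices = np.array(jax.devices()[: CONTEXT * TENSOR]).reshape(CONTEXT, TENSOR)
mesh = Mesh(devices, ("context", "tensor"))


def step(h, w):
    """matmul -> psum (all-reduce) -> all_to_all -> all_to_all -> tanh."""
    # h: local sequence shard (SEQ/CONTEXT, HIDDEN); w: this tensor shard's weight block.
    y = jax.lax.psum(h @ w, "tensor")
    # Ulysses-style swap: gather the sequence, split the hidden dim over "context" ...
    y = jax.lax.all_to_all(y, "context", split_axis=1, concat_axis=0, tiled=True)
    # ... then swap back to sequence sharding (the two all_to_alls are inverses).
    y = jax.lax.all_to_all(y, "context", split_axis=0, concat_axis=1, tiled=True)
    return jnp.tanh(y)


def loop(h, w):
    for _ in range(STEPS):  # unrolled fixed-length loop
        h = step(h, w)
    return h


run = jax.jit(shard_map(loop, mesh=mesh,
                        in_specs=(P("context", None), P("tensor", None)),
                        out_specs=P("context", None)))


def reference(x, w):
    """NumPy float64 reference: psum of per-shard matmuls == matmul with summed weights."""
    w_sum = w.astype(np.float64).reshape(TENSOR, HIDDEN, HIDDEN).sum(axis=0)
    h = x.astype(np.float64)
    for _ in range(STEPS):
        h = np.tanh(h @ w_sum)
    return h


rng = np.random.default_rng(0)
x = rng.standard_normal((SEQ, HIDDEN), dtype=np.float32)
w = (rng.standard_normal((TENSOR * HIDDEN, HIDDEN), dtype=np.float32)
     / np.sqrt(TENSOR * HIDDEN)).astype(np.float32)

out = np.asarray(run(jax.device_put(x, NamedSharding(mesh, P("context", None))),
                     jax.device_put(w, NamedSharding(mesh, P("tensor", None)))))
print("all finite:", bool(np.isfinite(out).all()))
print("max |out - reference|:", float(np.abs(out - reference(x, w)).max()))
```

The LIBTPU_INIT_ARGS flags only affect TPU, so on CPU this validates the sketch's numerics only. On the TPU slice, the comparison from the issue amounts to running it once without and once with the flags and checking np.array_equal on the two outputs.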
Ask
Is this a known issue with async-collective-fusion + overlap-compute-collective around Mosaic/splash-attention custom-calls on v6e? A self-contained minimal repro would need to include the splash-kernel + Ulysses all-to-all pattern; happy to provide HLO dumps or test candidate fixes/flags with guidance.
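If HLO is requested, one way to capture what the flags change is to save the optimized (post-scheduling) HLO on each process via jax.jit(...).lower(...).compile().as_text(); setting --xla_dump_to in XLA_FLAGS is the more complete alternative. A sketch; denoise_step_standin and dump_compiled_hlo are hypothetical names, and the stand-in function must be replaced by the real jitted DiT step:

```python
import os
import tempfile

import jax
import jax.numpy as jnp


def denoise_step_standin(x):
    # Hypothetical stand-in for the real DiT denoise step.
    return jnp.tanh(x @ x.T)


def dump_compiled_hlo(fn, *args, out_dir):
    """Write the optimized HLO text of fn(*args) to one file per process."""
    compiled = jax.jit(fn).lower(*args).compile()
    path = os.path.join(out_dir, f"hlo_process{jax.process_index()}.txt")
    with open(path, "w") as f:
        f.write(compiled.as_text())
    return path


out_dir = tempfile.mkdtemp()
path = dump_compiled_hlo(denoise_step_standin, jnp.ones((8, 8)), out_dir=out_dir)
print(path)
```

Dumping once with and once without the flags gives a before/after pair whose scheduling around the splash-attention custom-call can be diffed.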
Filing here (rather than jax-ml/jax) since this is MaxDiffusion's ulysses_custom + splash-attention path on v6e, and the combination delivers a real ~5% DiT speedup that is currently unusable. Questions: Is this a known incompatibility for the Wan2.2-I2V Ulysses path? Should MaxDiffusion guard/validate these flags, or is there a MaxDiffusion-side scheduling fix? Can this be escalated to the XLA:TPU/libtpu team with the model repro?