NaN miscompile: async-collective-fusion + overlap-compute-collective corrupt Wan2.2-I2V ulysses+splash attention on v6e (blocks ~5% DiT speedup) #429

Description

@ThomasNing

On TPU v6e-16 (multi-host), combining two XLA:TPU collective-overlap flag groups silently corrupts outputs to NaN for a sequence-parallel (Ulysses) + splash-attention workload (MaxDiffusion Wan2.2-I2V-A14B DiT). Each group alone is numerically correct; only the combination produces NaNs, and that same combination is the one that delivers a real ~5% speedup. The corruption appears tied to overlapping the Ulysses all-to-all/all-reduce collectives around the splash-attention Mosaic custom-call.

Environment

  • TPU v6e-16 (4 hosts × 4 chips), multi-host via jax.distributed
  • jax 0.10.2, flax 0.12.7, recent libtpu; bfloat16
  • MaxDiffusion Wan2.2-I2V-A14B (two-expert DiT), attention=ulysses_custom (splash/flash), ici_context=4 × ici_tensor=4

Flags that trigger it (LIBTPU_INIT_ARGS)

--xla_tpu_enable_async_collective_fusion=true
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true
--xla_tpu_enable_async_collective_fusion_multiple_steps=true
--xla_tpu_overlap_compute_collective_tc=true
--xla_enable_async_all_reduce=true
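For reproducing the bisection below, a small helper that composes these flags into LIBTPU_INIT_ARGS may be handy. This is a sketch, not MaxDiffusion code: `build_libtpu_init_args` is a hypothetical name, and the split into "primary" and "sub" flags follows the bisection table ("primaries only" = the two flags without the aggressive sub-flags). The variable has to be set on every host before JAX initializes the TPU backend.

```python
import os

# The two primary flags; together WITHOUT the sub-flags they did not NaN.
PRIMARY_FLAGS = [
    "--xla_tpu_enable_async_collective_fusion=true",
    "--xla_tpu_overlap_compute_collective_tc=true",
]
# The aggressive sub-flags needed (together with both primaries) to hit the NaN.
SUB_FLAGS = [
    "--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true",
    "--xla_tpu_enable_async_collective_fusion_multiple_steps=true",
    "--xla_enable_async_all_reduce=true",
]


def build_libtpu_init_args(include_sub_flags=True, existing=""):
    """Merge the flags into an existing LIBTPU_INIT_ARGS string (hypothetical helper).

    A flag already present in `existing` is replaced rather than duplicated,
    so the value appended here is the only one libtpu sees.
    """
    tokens = existing.split()
    wanted = PRIMARY_FLAGS + (SUB_FLAGS if include_sub_flags else [])
    for flag in wanted:
        name = flag.split("=", 1)[0]
        tokens = [t for t in tokens if t.split("=", 1)[0] != name]
        tokens.append(flag)
    return " ".join(tokens)


if __name__ == "__main__":
    # Set before importing jax (libtpu reads it when the TPU backend initializes).
    os.environ["LIBTPU_INIT_ARGS"] = build_libtpu_init_args(
        include_sub_flags=True, existing=os.environ.get("LIBTPU_INIT_ARGS", "")
    )
    print(os.environ["LIBTPU_INIT_ARGS"])
```

Passing `include_sub_flags=False` gives the "primaries only" row of the bisection.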

Symptom

The DiT denoise output becomes NaN (and differs across processes). This is caught downstream by the jax.device_put → multihost_utils.assert_equal check when the result is placed on a replicated sharding:

AssertionError: ... not the same on each process. Expected: [[[[[nan nan nan ...
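Because the NaN only surfaces at the later device_put, a finiteness check directly on the DiT output can localize it earlier. A minimal sketch, assuming the denoised latents are available as a jax.Array; `nonfinite_counts` and `assert_finite` are hypothetical helpers, not MaxDiffusion APIs. The counts are reduced to replicated scalars inside jit, so `device_get` should work on each host.

```python
import jax
import jax.numpy as jnp


@jax.jit
def nonfinite_counts(x):
    """Count NaN and +/-Inf entries; the sums come back as replicated scalars."""
    return {"nan": jnp.sum(jnp.isnan(x)), "inf": jnp.sum(jnp.isinf(x))}


def assert_finite(x, where):
    """Raise on this process if x contains any NaN/Inf (hypothetical helper)."""
    counts = jax.device_get(nonfinite_counts(x))
    n_nan, n_inf = int(counts["nan"]), int(counts["inf"])
    if n_nan or n_inf:
        raise FloatingPointError(
            f"{where}: {n_nan} NaN / {n_inf} Inf on process {jax.process_index()}"
        )
```

Calling e.g. `assert_finite(latents, "after DiT denoise")` per expert or per step would show whether the NaN originates in the high-noise or low-noise expert.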

Bisection (warm DiT denoise; "correct" = downstream VAE decode succeeds)

| Flags | Warm DiT denoise | Correct? |
|---|---|---|
| baseline (none) | 269.7 s | yes |
| async-collective-fusion group only | 261.6 s | yes |
| overlap-compute-collective group only | 262.1 s | yes |
| megacore-fusion only | 260.9 s | yes |
| fusion + megacore | 260.6 s | yes |
| overlap + megacore | 262.0 s | yes |
| fusion + overlap | 247.9 s | NaN |
| enable_async_collective_fusion + overlap_compute_collective_tc (primaries only) | 260.7 s | yes |

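So the failure is specific to async-collective-fusion × overlap-compute-collective together, and it requires the aggressive sub-flags (fuse_all_reduce / multiple_steps / async_all_reduce) that overlap collectives across steps; the two primary flags alone are correct. The speedup appears only in the NaN-producing config: 247.9 s is ~5% faster than the fastest correct configuration (260.6 s) and ~8% faster than baseline (269.7 s).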
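The speedup arithmetic from the bisection table, as a quick check. Treating every row except fusion + overlap as correct follows the description (each group alone and the primaries alone are correct); the short row labels are abbreviations of the table's.

```python
# Warm DiT denoise times (seconds) from the bisection table; True = VAE decode succeeded.
RUNS = {
    "baseline (none)": (269.7, True),
    "async-collective-fusion group only": (261.6, True),
    "overlap-compute-collective group only": (262.1, True),
    "megacore-fusion only": (260.9, True),
    "fusion + megacore": (260.6, True),
    "overlap + megacore": (262.0, True),
    "fusion + overlap": (247.9, False),
    "primaries only": (260.7, True),
}


def speedup(t_ref, t_new):
    """Fractional reduction in wall time relative to t_ref."""
    return (t_ref - t_new) / t_ref


nan_time = RUNS["fusion + overlap"][0]
best_correct = min(t for t, ok in RUNS.values() if ok)
vs_baseline = speedup(RUNS["baseline (none)"][0], nan_time)
vs_best = speedup(best_correct, nan_time)
print(f"vs baseline: {vs_baseline:.1%}, vs fastest correct ({best_correct} s): {vs_best:.1%}")
# prints: vs baseline: 8.1%, vs fastest correct (260.6 s): 4.9%
```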

What we ruled out (no minimal repro yet)

A standalone multi-host program on the same 4×4 v6e mesh (a loop of matmul → psum (all-reduce) → all_to_all → all_to_all → tanh) does not reproduce the issue in either f32 or bf16: the output is bit-identical with and without the flags. So this is not a generic collective+overlap bug; it needs the workload's specific collective structure, most likely the interaction with the splash-attention Mosaic custom-call in the Ulysses attention path, which a pure-XLA matmul graph lacks.
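A sketch of a standalone program of the shape described above, as a starting point for a minimal repro. It is a reconstruction, not the exact program that was run: the axis names `context`/`tensor` mirror ici_context × ici_tensor, while the row-parallel matmul layout, the shapes, and the helper names (`make_mesh`, `init_inputs`, `make_step`, `run_repro`) are assumptions. Per the report this pattern alone does not NaN; the natural extension is to call the splash-attention kernel between the two all_to_alls.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX
    from jax.experimental.shard_map import shard_map


def make_mesh(n_ctx, n_tp):
    """2-D mesh with axes mirroring ici_context x ici_tensor (4 x 4 on v6e-16)."""
    devices = np.asarray(jax.devices()[: n_ctx * n_tp]).reshape(n_ctx, n_tp)
    return Mesh(devices, ("context", "tensor"))


def init_inputs(mesh, seq, hidden, dtype, seed):
    """Create x (seq- and hidden-sharded) and w (row-sharded) directly on the mesh."""
    shardings = (NamedSharding(mesh, P("context", "tensor")),
                 NamedSharding(mesh, P("tensor", None)))

    def init():
        kx, kw = jax.random.split(jax.random.PRNGKey(seed))
        x = jax.random.normal(kx, (seq, hidden), dtype)
        w = jax.random.normal(kw, (hidden, hidden), dtype) * (hidden ** -0.5)
        return x, w

    return jax.jit(init, out_shardings=shardings)()


def make_step(mesh):
    def local_step(x, w):
        # x: [seq/ctx, hidden/tp], w: [hidden/tp, hidden] -> partial sums per tensor shard
        y = jax.lax.psum(x @ w, "tensor")  # all-reduce over the tensor axis
        # Ulysses-style round trip: seq-sharded -> hidden-sharded -> seq-sharded
        y = jax.lax.all_to_all(y, "context", split_axis=1, concat_axis=0, tiled=True)
        y = jax.lax.all_to_all(y, "context", split_axis=0, concat_axis=1, tiled=True)
        return jnp.tanh(y)

    return jax.jit(shard_map(
        local_step, mesh=mesh,
        in_specs=(P("context", "tensor"), P("tensor", None)),
        out_specs=P("context", None)))


def run_repro(mesh, steps=8, seq=256, hidden=512, dtype=jnp.float32, seed=0):
    """Loop matmul -> psum -> all_to_all -> all_to_all -> tanh."""
    x, w = init_inputs(mesh, seq, hidden, dtype, seed)
    step = make_step(mesh)
    for _ in range(steps):
        x = step(x, w)  # jit reshards the output back to the input layout
    return x
```

On the v6e-16 slice one would call `jax.distributed.initialize()` first, use `make_mesh(4, 4)`, and run twice in separate processes (with and without LIBTPU_INIT_ARGS), gathering the result with `jax.experimental.multihost_utils.process_allgather` for a bitwise comparison.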

Ask

Is this a known issue with async-collective-fusion + overlap-compute-collective around Mosaic/splash-attention custom-calls on v6e? A self-contained minimal repro would need to include the splash-kernel + Ulysses all-to-all pattern; happy to provide HLO dumps or test candidate fixes/flags with guidance.
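For the HLO dumps, the standard XLA options `--xla_dump_to=<dir>` and `--xla_dump_hlo_as_text` (via XLA_FLAGS) dump every module. For a quicker, targeted look, the post-optimization HLO of a single jitted function can be pulled from JAX and scanned for the collectives and custom-calls in question. A sketch; the helper names and the default op list are illustrative, and exact op spellings (especially for async collectives) vary by backend and pass pipeline. Mosaic kernels such as splash attention appear as custom-call ops.

```python
import re

import jax
import jax.numpy as jnp


def optimized_hlo(fn, *args):
    """Post-optimization HLO text for fn(*args), compiled for the current backend."""
    return jax.jit(fn).lower(*args).compile().as_text()


def count_ops(hlo_text, ops=("all-reduce-start", "all-to-all", "custom-call")):
    """Count instructions whose opcode (the token right before '(') is in `ops`."""
    return {op: len(re.findall(rf"\b{re.escape(op)}\(", hlo_text)) for op in ops}
```

Diffing `count_ops` (and the schedule around the custom-call) between the correct and NaN configurations may narrow down which overlapped collective is involved.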


Filing here (rather than jax-ml/jax) since this is MaxDiffusion's ulysses_custom + splash-attention path on v6e, and the combination delivers a real ~5% DiT speedup that's currently unusable. Questions: is this a known incompatibility for the Wan2.2-I2V Ulysses path? Should MaxDiffusion guard/validate these flags, or is there a MaxDiffusion-side scheduling fix — and can this be escalated to the XLA:TPU/libtpu team with the model repro?
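If MaxDiffusion were to guard against this, one possible shape is a startup check on LIBTPU_INIT_ARGS that warns when both primary flags are on together with any aggressive sub-flag. A conservative sketch based only on the bisection above (which sub-flag is strictly necessary is not known); all names are hypothetical.

```python
import os
import warnings

_PRIMARIES = ("--xla_tpu_enable_async_collective_fusion",
              "--xla_tpu_overlap_compute_collective_tc")
_SUB_FLAGS = ("--xla_tpu_enable_async_collective_fusion_fuse_all_reduce",
              "--xla_tpu_enable_async_collective_fusion_multiple_steps",
              "--xla_enable_async_all_reduce")


def _enabled(args, name):
    """True if boolean flag `name` is on in a flag string (the last occurrence wins)."""
    value = False
    for tok in args.split():
        key, _, val = tok.partition("=")
        if key == name:
            value = val.lower() in ("", "true", "1")  # bare "--flag" means true
    return value


def flags_hit_issue_429(args):
    """Conservative: both primaries on plus at least one aggressive sub-flag."""
    return (all(_enabled(args, f) for f in _PRIMARIES)
            and any(_enabled(args, f) for f in _SUB_FLAGS))


def warn_if_unsafe(args=None):
    """Warn (and return True) if LIBTPU_INIT_ARGS matches the NaN-producing combination."""
    args = os.environ.get("LIBTPU_INIT_ARGS", "") if args is None else args
    if flags_hit_issue_429(args):
        warnings.warn(
            "LIBTPU_INIT_ARGS combines async-collective-fusion and "
            "overlap-compute-collective with aggressive sub-flags; this produced NaNs "
            "in the Wan2.2-I2V ulysses_custom path on v6e (#429).")
        return True
    return False
```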
