Skip to content

[Neuron] Add tensor parallel support for Neuron backend#13718

Open
JingyaHuang wants to merge 54 commits into
huggingface:mainfrom
JingyaHuang:support-neuron-tp
Open

[Neuron] Add tensor parallel support for Neuron backend#13718
JingyaHuang wants to merge 54 commits into
huggingface:mainfrom
JingyaHuang:support-neuron-tp

Conversation

@JingyaHuang

@JingyaHuang JingyaHuang commented May 11, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Adds tensor-parallel (TP) inference for diffusers models on AWS Neuron (Trainium/Inferentia) device. Here as suggested we use Flux2 Klein as the starting point. But the TP support here is generic, easy to extend to other backend(cuda, tpu and more) and is exposed through the existing public API used for CP: model.enable_parallelism(config=TensorParallelConfig(...)).

Key changes:

  • A model-agnostic apply_tensor_parallel that shards from a flat _tp_plan.

Quick test — Flux2 TP on Neuron (For future release)

run with torchrun --nproc_per_node=8 flux2_tp8_neuron.py

import torch
  import torch.distributed as dist
  from torch.distributed.device_mesh import DeviceMesh
  import torch_neuronx  # noqa: F401 — registers torch.neuron

  from diffusers import Flux2KleinPipeline, TensorParallelConfig

  MODEL = "black-forest-labs/FLUX.2-klein-9B"
  PROMPT = "a golden retriever surfing a wave, photorealistic"

  dist.init_process_group(backend="neuron")
  device = torch.neuron.current_device()
  rank = dist.get_rank()
  tp_size = dist.get_world_size()
  tp_mesh = DeviceMesh("neuron", list(range(tp_size)))

  pipe = Flux2KleinPipeline.from_pretrained(MODEL, torch_dtype=torch.bfloat16)

  # Text encoder + VAE: replicated on every rank (no TP).
  pipe.text_encoder = pipe.text_encoder.to(device)
  pipe.vae = pipe.vae.to(device)

  # Transformer: shard across all ranks while still on CPU, then move to device.
  pipe.transformer.enable_parallelism(config=TensorParallelConfig(mesh=tp_mesh))
  pipe.transformer = pipe.transformer.to(device)
  torch.neuron.synchronize()

  image = pipe(
      prompt=PROMPT, height=1024, width=1024,
      num_inference_steps=4, guidance_scale=1.0,
  ).images[0]

  if rank == 0:
      image.save("flux2_tp8.png")
      print("Saved flux2_tp8.png")

  dist.destroy_process_group()

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions github-actions Bot added the tests label Jun 24, 2026
@JingyaHuang JingyaHuang marked this pull request as ready for review June 24, 2026 15:30

# flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
if grad_enabled or (_parallel_config is not None and _parallel_config._cp_world_size > 1):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With TP, context_parallel_config can be None, we set up _parallel_config._cp_world_size for it.

@JingyaHuang JingyaHuang requested a review from sayakpaul June 25, 2026 12:25
Comment thread docs/source/en/training/distributed_inference.md Outdated
HaozheZhang6 and others added 5 commits June 26, 2026 15:38
…rs (huggingface#13946)

SkyReels-V2 and ChronoEdit are both built on Wan, and their transformers have
the same keys as WanTransformer3DModel, so they reuse
convert_wan_transformer_to_diffusers (like WanVACE / WanAnimate). This lets the
community GGUF builds load directly.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
fix(cosmos3): pin VAE latent norm buffers to encode output device

Under sharded placement (device_map="balanced"), vae.encode() runs on the
VAE's own device while the mean/inv_std buffers were pinned to x.device,
causing a cross-device RuntimeError. Compute raw_mu first, then pin the
normalization buffers to its device so all tensors share one device.

Co-authored-by: Atharva Joshi <atjoshi@smc521ge-0036.ipp2a2.colossus.nvidia.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
…13876)

* docs: fix repeated word typo in set_timesteps docstring

Removed the duplicate word "schedule" from the docstring for the sigmas argument in EulerDiscreteScheduler.set_timesteps.

* Update scheduling_euler_discrete.py

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

@sayakpaul sayakpaul left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this.

There is lot of intrusive and model-specific changes which I think is a bit of an anti-pattern. I think it's also probably because of some of the fusion stuff that's happening inside Flux2.

More specifically, the intrusive pieces exist for one reason: Flux2 fuses projections into single Linears (SwiGLU gate+linear, and to_qkv_mlp_proj packing Q/K/V and MLP).

Contiguous column sharding is blind to that internal layout, so:

  • you must reorder rows so each rank gets paired slices -> the permuters, and
  • the local tensor width no longer factors as heads × head_dim or splits cleanly into qkv/mlp -> the runtime local_* recomputation.

I opened JingyaHuang#1 to simplify some of the stuff. LMK.

Furthermore, would the changes related to fusing be the same for Flux1, for example? I think gf the layers were unfused, parallelize_module + DTensor would handle head-splitting automatically and none of this would be needed.

Comment thread src/diffusers/hooks/tensor_parallel.py
Comment thread src/diffusers/hooks/tensor_parallel.py
Comment thread src/diffusers/hooks/tensor_parallel.py Outdated
config: TensorParallelConfig,
tp_plan: dict,
*,
backend: str = "default",

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this not be derived from torch_device?

return _get_projections(attn, hidden_states, encoder_hidden_states)


def _get_tp_degree(parallel_config) -> int:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like it should be present in _modeling_parallel.py?

Comment on lines +255 to +265
@property
def _cp_world_size(self) -> int:
"""Context-parallel world size, or 1 when context parallelism is not enabled.

Lets attention backends branch on context parallelism without dereferencing a possibly ``None``
``context_parallel_config`` (e.g. when only tensor parallelism is active).
"""
cp = self.context_parallel_config
if cp is None or cp._world_size is None:
return 1
return cp._world_size

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this needed?

Comment on lines +914 to +920
# On Neuron, run the index-heavy `_unpack_latents_with_ids` on CPU to avoid expensive
# device<->host syncs from the gather/scatter arithmetic, then move the result back.
latent_device = latents.device
on_neuron = get_device() == "neuron"
if on_neuron:
latents = latents.cpu()
latent_ids = latent_ids.cpu()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this not needed on CUDA?

sayakpaul and others added 7 commits June 26, 2026 17:07
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
…arallel

Adopts Sayak's changes from #1 that replace the
Flux2-specific _tp_fused_block_permuters (permute-then-slice) with generic
PackedColwiseParallel / PackedRowwiseParallel styles that slice fused
projections block-by-block. Also drops the now-unused
_tp_fused_block_permuters base-class default in modeling_utils.

Keeps torch.chunk in Flux2SwiGLU.forward (TorchAO compile regression fix),
overriding the half-slicing on Sayak's branch.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@sayakpaul

Copy link
Copy Markdown
Member

@JingyaHuang did my PR break any neuron-specific stuff?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation hooks models pipelines size/L PR with diff > 200 LOC tests utils

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants