54 commits
98f6c8c
draft:add neuron as a legit backend
JingyaHuang Mar 18, 2026
c58b8b8
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Mar 18, 2026
3367409
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Mar 19, 2026
0c51734
Merge branch 'main' into add-neuron-backend
JingyaHuang Mar 25, 2026
a76953c
feat: neuron-specific changes in the pipeline
JingyaHuang Mar 26, 2026
2480388
tests: eager tests
JingyaHuang Mar 27, 2026
1469c04
draft: start with tp for flux2
JingyaHuang Apr 9, 2026
929ab72
fix: style
JingyaHuang Apr 9, 2026
52cac76
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Apr 9, 2026
30cb353
Merge branch 'huggingface:main' into support-neuron-tp
JingyaHuang Apr 9, 2026
28a5086
Merge branch 'add-neuron-backend' of github.com:JingyaHuang/diffusers…
JingyaHuang Apr 9, 2026
7fab0c4
Merge branch 'huggingface:main' into support-neuron-tp
JingyaHuang Apr 10, 2026
68689e5
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Apr 10, 2026
da79308
Merge branch 'main' into add-neuron-backend
JingyaHuang Apr 10, 2026
3bb9c7c
fix:apr_02 beta
JingyaHuang Apr 10, 2026
c4facab
Merge branch 'add-neuron-backend' of github.com:JingyaHuang/diffusers…
JingyaHuang Apr 10, 2026
dff1f32
feat:add wan
JingyaHuang Apr 10, 2026
1c930c4
Merge branch 'huggingface:main' into support-neuron-tp
JingyaHuang Apr 13, 2026
1eb5ff9
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Apr 13, 2026
cbe8f28
fix:pixart
JingyaHuang Apr 14, 2026
16b9606
fix: rewrite flux swiglu activation to avoid gather op in neuron IR
JingyaHuang Apr 15, 2026
7f13f68
test: pixart compile mode on neuron
JingyaHuang Apr 15, 2026
a46cb19
Merge branch 'main' into neuron-torch-comppile
JingyaHuang Apr 22, 2026
a354b88
cleanup & fix style
JingyaHuang May 11, 2026
931bb85
Merge branch 'neuron-torch-comppile' into support-neuron-tp
JingyaHuang May 11, 2026
9ab6dc3
Merge branch 'main' into support-neuron-tp
JingyaHuang May 11, 2026
48fb75b
Merge branch 'main' into support-neuron-tp
JingyaHuang Jun 22, 2026
c350f7b
merge: another change
JingyaHuang Jun 22, 2026
644477a
Merge branch 'main' into support-neuron-tp
JingyaHuang Jun 22, 2026
03cb725
review: cleanup+suggestions
JingyaHuang Jun 23, 2026
9da93ed
Merge branch 'support-neuron-tp' of github.com:JingyaHuang/diffusers …
JingyaHuang Jun 23, 2026
d44f772
fix: CIs style
JingyaHuang Jun 24, 2026
3fc043e
Merge branch 'main' into support-neuron-tp
JingyaHuang Jun 24, 2026
e6d20d8
tests: add test units for tp
JingyaHuang Jun 24, 2026
e76a2fc
Merge branch 'support-neuron-tp' of github.com:JingyaHuang/diffusers …
JingyaHuang Jun 24, 2026
034ba9e
fix: in case of text-encoder(s) on CPU
JingyaHuang Jun 24, 2026
4907524
review:cleanup+add test
JingyaHuang Jun 25, 2026
af2aed7
Merge branch 'support-neuron-tp' of github.com:JingyaHuang/diffusers …
JingyaHuang Jun 25, 2026
b9b048b
Merge branch 'main' into support-neuron-tp
JingyaHuang Jun 25, 2026
915eeb1
fix: style
JingyaHuang Jun 25, 2026
720dad2
Merge branch 'support-neuron-tp' of github.com:JingyaHuang/diffusers …
JingyaHuang Jun 25, 2026
89cf8b6
doc: remove it for now
JingyaHuang Jun 25, 2026
29cd9c3
Add from_single_file support for SkyReelsV2 and ChronoEdit transforme…
HaozheZhang6 Jun 25, 2026
eaab299
multi-GPU VAE Fix for Cosmos 3 (#13924)
atharvajoshi10 Jun 25, 2026
30a43d5
docs: fix repeated word typo in set_timesteps docstring (#13876)
ramkumar27072006 Jun 26, 2026
155802c
clean some stuff to simplify code.
sayakpaul Jun 26, 2026
f133732
Merge branch 'main' into JingyaHuang-support-neuron-tp
sayakpaul Jun 26, 2026
b3d8130
clean more to remove permutation related shenanigans.
sayakpaul Jun 27, 2026
7ea75f7
revert: put torch.chunk back
JingyaHuang Jul 1, 2026
c73cf09
Merge branch 'main' into support-neuron-tp
JingyaHuang Jul 1, 2026
eb58402
Merge branch 'main' into support-neuron-tp
JingyaHuang Jul 2, 2026
c3e123c
Update docs/source/en/training/distributed_inference.md
JingyaHuang Jul 2, 2026
dc33e26
Merge Sayak's TP simplification (PR #1): generic Packed{Col,Row}wiseP…
JingyaHuang Jul 2, 2026
491c537
Merge branch 'support-neuron-tp' of github.com:JingyaHuang/diffusers …
JingyaHuang Jul 2, 2026
6 changes: 6 additions & 0 deletions docs/source/en/api/parallel.md
Original file line number Diff line number Diff line change
@@ -22,3 +22,9 @@ Parallelism strategies help speed up diffusion transformers by distributing comp
[[autodoc]] ContextParallelConfig

[[autodoc]] hooks.apply_context_parallel

## TensorParallelConfig

[[autodoc]] TensorParallelConfig

[[autodoc]] hooks.apply_tensor_parallel
44 changes: 44 additions & 0 deletions docs/source/en/training/distributed_inference.md
@@ -431,3 +431,47 @@ pipeline = DiffusionPipeline.from_pretrained(
CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)
```

## Tensor parallelism

[Tensor parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) shards a model's weight matrices across devices. Each device holds a column-wise (`"colwise"`) or row-wise (`"rowwise"`) slice of each sharded layer and computes a partial result; an `AllReduce` or `AllGather` at the layer boundary then reconstructs the full output. Unlike context parallelism, which shards the input sequence, tensor parallelism reduces the per-device *weight* memory, which helps when a model does not fit on a single device.
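
As a rough illustration of the colwise/rowwise pairing described above, the sketch below simulates two ranks in a single process with plain Python lists. It does not use `torch.distributed` or any Diffusers API; the helper names (`matmul`, `transpose`, `linear`) and the tiny weights are made up for the example. It assumes the `torch.nn.Linear` weight layout `(out_features, in_features)`, so `"colwise"` slices weight rows and `"rowwise"` slices weight columns.

```python
# Single-process simulation of a two-layer tensor-parallel block (illustrative only).
# Weights follow the torch.nn.Linear convention: shape (out_features, in_features).


def matmul(a, b):
    """Multiply matrix a (n x k) by matrix b (k x m), both given as nested lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def linear(x, w):
    """Compute x @ w.T, i.e. a bias-free linear layer."""
    return matmul(x, transpose(w))


x = [[1, 2], [3, 4]]                    # (batch=2, in=2)
w1 = [[1, 0], [0, 1], [1, 1], [2, -1]]  # first layer:  (hidden=4, in=2)
w2 = [[1, 2, 3, 4], [0, 1, 0, 1]]       # second layer: (out=2, hidden=4)

# Unsharded reference output.
reference = linear(linear(x, w1), w2)

tp_degree = 2
shard = len(w1) // tp_degree
partials = []
for rank in range(tp_degree):
    rows = slice(rank * shard, (rank + 1) * shard)
    w1_local = w1[rows]                   # "colwise": this rank keeps a slice of the output features
    w2_local = [row[rows] for row in w2]  # "rowwise": this rank keeps the matching input features
    hidden_local = linear(x, w1_local)    # a slice of the hidden activations
    partials.append(linear(hidden_local, w2_local))  # a partial sum of the final output

# The AllReduce (sum) across ranks reconstructs the full output.
tp_output = [
    [sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
    for i in range(len(partials[0]))
]
assert tp_output == reference
```

Each simulated rank only ever touches half of `w1` and half of `w2`, which is where the per-device weight-memory saving comes from.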

Pass a [`TensorParallelConfig`] to [`~ModelMixin.enable_parallelism`]. `tp_degree` is the number of devices to shard across and must divide the model's number of attention heads. The model must define a `_tp_plan` (a flat mapping of module-name globs to a `"colwise"`/`"rowwise"` style).

```py
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, TensorParallelConfig

def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device

def main():
    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16
    ).to(device)

    pipeline.transformer.enable_parallelism(config=TensorParallelConfig(tp_degree=world_size))

    generator = torch.Generator().manual_seed(42)
    image = pipeline("a cat holding a sign that says hello", generator=generator).images[0]
    if dist.get_rank() == 0:
        image.save("output.png")
    if dist.is_initialized():
        dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

```shell
torchrun --nproc-per-node 2 above_script.py
```
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -314,6 +314,7 @@
"StableCascadeUNet",
"T2IAdapter",
"T5FilmDecoder",
"TensorParallelConfig",
"Transformer2DModel",
"TransformerTemporalModel",
"UNet1DModel",
@@ -1181,6 +1182,7 @@
StableAudioDiTModel,
T2IAdapter,
T5FilmDecoder,
TensorParallelConfig,
Transformer2DModel,
TransformerTemporalModel,
UNet1DModel,
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
@@ -27,4 +27,5 @@
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
from .tensor_parallel import apply_tensor_parallel
from .text_kv_cache import TextKVCacheConfig, apply_text_kv_cache
225 changes: 225 additions & 0 deletions src/diffusers/hooks/tensor_parallel.py
@@ -0,0 +1,225 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ..models._modeling_parallel import TensorParallelConfig
from ..utils import get_logger


logger = get_logger(__name__) # pylint: disable=invalid-name


class PackedColwiseParallel:
    """Column-wise sharding for fused projections with heterogeneous block structure.

    ``blocks`` is a list of proportional integers whose sum divides the weight's row count. For example, ``[1, 1]`` for
    a SwiGLU gate+linear projection (two equal halves) or ``[1, 1, 1, 3, 3]`` for a Q+K+V+gate+linear projection with
    ``mlp_ratio=3``. If ``blocks`` is ``None``, the Linear module must carry a ``_tp_packed_col_blocks`` attribute set
    during model ``__init__``.
    """

    def __init__(self, blocks: "list[int] | None" = None):
        self.blocks = blocks


class PackedRowwiseParallel:
    """Row-wise sharding for fused projections with heterogeneous block structure.

    ``blocks`` describes the input-column partition of the fused Linear (e.g. ``[1, 3]`` when the input concatenates an
    attention projection and an MLP projection with ``mlp_ratio=3``). If ``blocks`` is ``None``, the module must carry
    a ``_tp_packed_row_blocks`` attribute.
    """

    def __init__(self, blocks: "list[int] | None" = None):
        self.blocks = blocks


def _blocks_to_block_sizes(total_size: int, blocks: "list[int]") -> "list[int]":
    """Convert proportional block counts to absolute sizes.

    ``blocks`` is a list of positive integers interpreted as proportional weights. Their sum must divide ``total_size``
    evenly. Returns a list of absolute sizes that sum to ``total_size``.
    """
    total = sum(blocks)
    if total_size % total != 0:
        raise ValueError(
            f"Cannot split {total_size} into proportional blocks {blocks}: "
            f"sum({blocks})={total} does not divide {total_size}."
        )
    unit = total_size // total
    return [b * unit for b in blocks]
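
A quick worked example may help when reading the proportional-block convention. The snippet below is a standalone copy of the arithmetic in `_blocks_to_block_sizes` under the hypothetical name `blocks_to_block_sizes`. The row count of 72 is an assumed figure (a fused Q+K+V+gate+linear weight with 8 rows per unit), not a value taken from any real model.

```python
def blocks_to_block_sizes(total_size, blocks):
    """Standalone copy of the proportional-to-absolute conversion, for illustration."""
    total = sum(blocks)
    if total_size % total != 0:
        raise ValueError(f"sum({blocks})={total} does not divide {total_size}.")
    unit = total_size // total
    return [b * unit for b in blocks]


# Assumed fused weight with 72 rows and the [1, 1, 1, 3, 3] layout from the docstring:
# sum(blocks) = 9, so one unit is 72 // 9 = 8 rows.
sizes = blocks_to_block_sizes(72, [1, 1, 1, 3, 3])
print(sizes)  # [8, 8, 8, 24, 24]

# A row count that the block sum does not divide is rejected.
try:
    blocks_to_block_sizes(70, [1, 1, 1, 3, 3])
    rejected = False
except ValueError:
    rejected = True
```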


def _resolve_tp_plan(model: torch.nn.Module, tp_plan: dict) -> list:
    """Group a flat ``_tp_plan`` into per-block ``(submodule, {relative_path: style})`` plans.

    Each glob is split at its single ``*``; the prefix must resolve to a ``ModuleList`` and the suffix is the
    per-element key. Grouping by block lets the caller issue one ``parallelize_module`` call per block, which
    ``RowwiseParallel`` needs to attach its input redistribution at the block boundary.
    """
    grouped: dict[int, tuple] = {}
    order: list[int] = []

    for pattern, style in tp_plan.items():
        if pattern.count("*") > 1:
            raise ValueError(f"Wildcard '*' can only be used once in a `_tp_plan` key, got '{pattern}'.")

        if "*" in pattern:
            prefix, _, suffix = pattern.partition("*")
            container = model
            for atom in prefix.strip(".").split("."):
                container = getattr(container, atom)
            if not isinstance(container, torch.nn.ModuleList):
                raise ValueError(
                    f"`_tp_plan` wildcard '{pattern}' must expand over a `ModuleList`, but "
                    f"'{prefix.strip('.')}' resolved to '{container.__class__.__name__}'."
                )
            relative, blocks = suffix.strip("."), list(container)
        else:
            relative, blocks = pattern, [model]

        for block in blocks:
            key = id(block)
            if key not in grouped:
                grouped[key] = (block, {})
                order.append(key)
            grouped[key][1][relative] = style

    return [grouped[key] for key in order]
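
To make the wildcard handling concrete, here is a torch-free sketch of just the string splitting. The helper `split_plan_key` and the module paths in `plan` are hypothetical. Unlike `_resolve_tp_plan`, which groups per block instance by walking the `ModuleList`, this sketch groups by container path only, so it shows the key handling rather than the full resolution.

```python
def split_plan_key(pattern):
    """Split a plan key into (container_path, relative_path); container_path is None without a wildcard."""
    if pattern.count("*") > 1:
        raise ValueError(f"Wildcard '*' can only be used once, got '{pattern}'.")
    if "*" not in pattern:
        return None, pattern
    prefix, _, suffix = pattern.partition("*")
    return prefix.strip("."), suffix.strip(".")


# Hypothetical flat plan; these module paths are assumptions, not taken from a real model.
plan = {
    "transformer_blocks.*.attn.to_q": "colwise",
    "transformer_blocks.*.attn.to_out.0": "rowwise",
    "proj_out": "colwise",
}

# Group the relative paths under the container they expand over.
grouped = {}
for pattern, style in plan.items():
    container_path, relative = split_plan_key(pattern)
    grouped.setdefault(container_path, {})[relative] = style

for container_path, relative_plan in grouped.items():
    print(container_path, relative_plan)
```

In the real function, every element of the resolved `ModuleList` receives its own copy of the relative plan, and keys without a wildcard are attached to the model itself.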


def _styles(relative_plan: dict) -> dict:
    """Map a ``{relative_path: style}`` plan to ``parallelize_module`` style instances.

    Values may be plain strings (``"colwise"`` / ``"rowwise"``) or ``PackedColwiseParallel`` /
    ``PackedRowwiseParallel`` marker instances.
    """
    import torch.nn as nn
    from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

    def _make_packed_col(marker: PackedColwiseParallel) -> ColwiseParallel:
        _blocks = marker.blocks

        class _PackedColwiseImpl(ColwiseParallel):
            def _partition_linear_fn(self, name, module, device_mesh):
                blocks = _blocks if _blocks is not None else getattr(module, "_tp_packed_col_blocks")
                rank = device_mesh.get_local_rank()
                tp_size = device_mesh.size()
                for param_name, param in module.named_parameters():
                    if param_name == "weight":
                        full = distribute_tensor(
                            param, device_mesh, [Replicate()], src_data_rank=self.src_data_rank
                        ).to_local()
                        block_sizes = _blocks_to_block_sizes(full.shape[0], blocks)
                        parts, offset = [], 0
                        for bs in block_sizes:
                            chunk = bs // tp_size
                            parts.append(full[offset + rank * chunk : offset + (rank + 1) * chunk].contiguous())
                            offset += bs
                        local = torch.cat(parts, dim=0)
                        dist_param = nn.Parameter(
                            DTensor.from_local(local, device_mesh, [Shard(0)], run_check=False),
                            requires_grad=param.requires_grad,
                        )
                    else:
                        dist_param = nn.Parameter(
                            distribute_tensor(param, device_mesh, [Shard(0)], src_data_rank=self.src_data_rank),
                            requires_grad=param.requires_grad,
                        )
                    module.register_parameter(param_name, dist_param)

        return _PackedColwiseImpl()

    def _make_packed_row(marker: PackedRowwiseParallel) -> RowwiseParallel:
        _blocks = marker.blocks

        class _PackedRowwiseImpl(RowwiseParallel):
            def _partition_linear_fn(self, name, module, device_mesh):
                blocks = _blocks if _blocks is not None else getattr(module, "_tp_packed_row_blocks")
                rank = device_mesh.get_local_rank()
                tp_size = device_mesh.size()
                for param_name, param in module.named_parameters():
                    if param_name == "weight":
                        full = distribute_tensor(
                            param, device_mesh, [Replicate()], src_data_rank=self.src_data_rank
                        ).to_local()
                        block_sizes = _blocks_to_block_sizes(full.shape[1], blocks)
                        parts, offset = [], 0
                        for bs in block_sizes:
                            chunk = bs // tp_size
                            parts.append(full[:, offset + rank * chunk : offset + (rank + 1) * chunk].contiguous())
                            offset += bs
                        local = torch.cat(parts, dim=1)
                        dist_param = nn.Parameter(
                            DTensor.from_local(local, device_mesh, [Shard(1)], run_check=False),
                            requires_grad=param.requires_grad,
                        )
                    else:
                        dist_param = nn.Parameter(
                            distribute_tensor(param, device_mesh, [Replicate()], src_data_rank=self.src_data_rank),
                            requires_grad=param.requires_grad,
                        )
                    module.register_parameter(param_name, dist_param)

        return _PackedRowwiseImpl()

    resolved = {}
    for path, style in relative_plan.items():
        if style == "colwise":
            resolved[path] = ColwiseParallel()
        elif style == "rowwise":
            resolved[path] = RowwiseParallel()
        elif isinstance(style, PackedColwiseParallel):
            resolved[path] = _make_packed_col(style)
        elif isinstance(style, PackedRowwiseParallel):
            resolved[path] = _make_packed_row(style)
        else:
            raise ValueError(
                f"Unsupported tensor-parallel style '{style}' for '{path}'. "
                f"Expected 'colwise', 'rowwise', PackedColwiseParallel, or PackedRowwiseParallel."
            )
    return resolved
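
The per-block slicing in the packed styles is easier to see on row indices alone. The sketch below mirrors the offset arithmetic of the packed column-wise weight path for an assumed 8-row fused gate+linear weight (rows 0-3 gate, rows 4-7 linear) and compares it with a plain contiguous split. The helper name `packed_colwise_rows` is hypothetical, the per-block divisibility check is an addition of this sketch, and only the weight rows are modelled.

```python
def packed_colwise_rows(num_rows, blocks, rank, tp_size):
    """Row indices of a fused weight that `rank` keeps when every block is sharded separately."""
    total = sum(blocks)
    if num_rows % total != 0:
        raise ValueError(f"sum({blocks})={total} does not divide {num_rows}.")
    unit = num_rows // total
    rows, offset = [], 0
    for b in blocks:
        block_size = b * unit
        # Sketch-only guard: without it, leftover rows would be silently dropped.
        if block_size % tp_size != 0:
            raise ValueError(f"Block of {block_size} rows is not divisible by tp_size={tp_size}.")
        chunk = block_size // tp_size
        rows.extend(range(offset + rank * chunk, offset + (rank + 1) * chunk))
        offset += block_size
    return rows


tp_size = 2
# Per-block sharding: each rank gets a slice of the gate rows AND a slice of the linear rows.
packed = [packed_colwise_rows(8, [1, 1], rank, tp_size) for rank in range(tp_size)]
# Contiguous sharding: rank 0 would hold only gate rows, rank 1 only linear rows.
naive = [list(range(rank * 4, (rank + 1) * 4)) for rank in range(tp_size)]

print(packed)  # [[0, 1, 4, 5], [2, 3, 6, 7]]
print(naive)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With the contiguous split, a rank that later halves its local output into "gate" and "linear" parts would pair rows from the same original block, which is the mismatch the packed styles appear designed to avoid.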


def apply_tensor_parallel(
    model: torch.nn.Module,
    config: TensorParallelConfig,
    tp_plan: dict,
    *,
    backend: str = "default",

Member

Can this not be derived from torch_device?

) -> None:
    """Apply tensor parallel on a model from its flat ``_tp_plan``.

    ``backend="neuron"`` routes to the Neuron pre-shard path (works around the NRT consecutive-reduce-scatter bug and
    applies the Flux2 fused-weight permutations); ``"default"`` uses ``parallelize_module`` directly.
    """
    tp_mesh = config._mesh
    if tp_mesh is None:
        raise ValueError("`config._mesh` is None. Call `config.setup(rank, world_size, device)` before applying TP.")

    groups = _resolve_tp_plan(model, tp_plan)
    logger.debug(f"Applying tensor parallel (backend={backend}) over {len(groups)} module group(s) on mesh {tp_mesh}.")

    if backend == "neuron":
        from .tensor_parallel_neuron import _apply_tp_neuron

        _apply_tp_neuron(model, tp_mesh, groups)
        return

    from torch.distributed.tensor.parallel import parallelize_module

    for submodule, relative_plan in groups:
        parallelize_module(submodule, tp_mesh, _styles(relative_plan))