Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 80 additions & 8 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import safetensors.torch
import torch

from ..utils import get_logger, is_accelerate_available
from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook

Expand All @@ -35,6 +35,54 @@
logger = get_logger(__name__) # pylint: disable=invalid-name


def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
    """Return ``True`` if `tensor` is an instance of a TorchAO tensor subclass.

    NOTE(review): most — but not all — torchao tensor subclasses are built on top of
    ``TorchAOBaseTensor``; subclassing it is not a hard requirement, so this check
    covers the common cases but is not 100% guaranteed to catch every torchao tensor.
    """
    if not is_torchao_available():
        return False
    # Imported lazily so this module never requires torchao at import time.
    from torchao.utils import TorchAOBaseTensor

    return isinstance(tensor, TorchAOBaseTensor)


def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
cls = type(tensor)
names = list(getattr(cls, "tensor_data_names", []))
for attr_name in getattr(cls, "optional_tensor_data_names", []):
if getattr(tensor, attr_name, None) is not None:
names.append(attr_name)
return names


def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.

`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
that any dict keyed by `id(param)` remains valid.

Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
"""
torch.utils.swap_tensors(param, source)


def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
    """Point `param`'s internal TorchAO tensors at those of `source`, leaving `source` intact.

    In contrast to `_swap_torchao_tensor`, the inner tensors are re-assigned attribute
    by attribute via ``setattr``, so `source` itself is never mutated. Use this when
    `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy
    held in `cpu_param_dict`).
    """
    inner_names = _get_torchao_inner_tensor_names(source)
    for name in inner_names:
        setattr(param, name, getattr(source, name))


def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
    """Call ``record_stream(stream)`` on every internal tensor of a TorchAO parameter."""
    for name in _get_torchao_inner_tensor_names(param):
        inner_tensor = getattr(param, name)
        inner_tensor.record_stream(stream)


# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
Expand Down Expand Up @@ -157,9 +205,16 @@ def _pinned_memory_tensors(self):
pinned_dict = None

def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean the to op is not implemented properly for torchao tensors?

if you have a minimal repro, we might be able to fix I think

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

linear = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
quantize_(linear, Int8WeightOnlyConfig(version=2))
p = linear.weight

# Move a copy to CUDA and assign via .data
cpu_copy = p.data.cpu()
cuda_copy = cpu_copy.to("cuda")
p.data = cuda_copy

print(f"p.qdata.device = {p.qdata.device}")  # cpu
print(f"cuda_copy.qdata.device = {cuda_copy.qdata.device}")  # cuda:0

# Forward fails: input on cuda, weight internals still on cpu
linear.bias.data = linear.bias.data.to("cuda")
x = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)
linear(x)  # RuntimeError: mat2 is on cpu

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel the proper way to do this is:

# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

linear = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
quantize_(linear, Int8WeightOnlyConfig(version=2))
p = linear.weight

# Move a copy to CUDA and assign via .data
# cpu_copy = p.data.cpu()
cpu_copy = p.cpu()
cuda_copy = cpu_copy.to("cuda")
# p.data = cuda_copy
torch.utils.swap_tensors(linear.weight, cuda_copy)

print(f"p.qdata.device = {p.qdata.device}")  # cpu
print(f"cuda_copy.qdata.device = {cuda_copy.qdata.device}")  # cuda:0

# Forward fails: input on cuda, weight internals still on cpu
linear.bias.data = linear.bias.data.to("cuda")
x = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)
linear(x)  # RuntimeError: mat2 is on cpu

parameter.data is not a recommended API. and linear.weight is also no longer a nn.Parameter after quantization, it's a different tensor subclass (nn.Parameter is also a tensor subclass: https://github.com/pytorch/pytorch/blob/e9ebbd3bee0761eb9d93b53f4a80d3afa2cc46f8/torch/nn/parameter.py#L30).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this does not seem like a fix? Your snippet mentions linear(x) also fails because of the runtime error. Elaborate?

Also, how do I best implement it in the context of the error and the diffusers code? I provided the minimal snippet for your convenience, but it doesn't serve the use case. We need to be able to fix it in the context of the use case.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this does not seem like a fix? Your snippet mentions linear(x) also fails because of the runtime error. Elaborate?

this is the proper way to move device for a tensor subclass instance I think. please ignore comments, that was copied from your original example. this runs on my side.

Also, how do I best implement it in the context of the error and the diffusers code? I provided the minimal snippet for your convenience, but it doesn't serve the use case. We need to be able to fix it in the context of the use case.

basically we should not be using parameter.data = parameter.data.to("cuda") for quantized weights, but use swap_tensors instead.

We have to go through all linear modules in the model and use `swap_tensors` to change the device:

for n, m in model.named_modules():
    if isinstance(m, nn.Linear):
        torch.utils.swap_tensors(m.weight, m.weight.to("cuda"))

moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if _is_torchao_tensor(tensor):
_swap_torchao_tensor(tensor, moved)
else:
tensor.data = moved
if self.record_stream:
tensor.data.record_stream(default_stream)
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could potentially implement the record_stream op as a torch_function op in the torchao tensor subclasses as well I think so that you can do:

tensor.record_stream(default_stream) directly.

Also wondering if this would work if you just do this for nn.Parameter as well (i.e. `parameter.record_stream(default_stream)` instead of `parameter.data.record_stream(default_stream)`)?

Copy link
Member Author

@sayakpaul sayakpaul Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tensor.record_stream(default_stream) directly.

Suggestion looks great. But I guess that will take some work on your end to ship. Maybe we can add a comment about it here and revisit when you land it?

also wondering if this would work if you just do this for nn.Parameter as well (parameter.record_stream(default_stream) instead of (parameter.data.record_stream(default_stream))?

Wouldn't mind refactoring it from tensor.data.record_stream(default_stream) pattern but we couldn't find out other solutions when we started working on it last year. Separate PR perhaps?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good to check in another PR

else:
tensor.data.record_stream(default_stream)

def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
for group_module in self.modules:
Expand Down Expand Up @@ -245,18 +300,35 @@ def _offload_to_memory(self):

for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly for this one I'm wondering if it would make sense to implement some copy op in torchao tensor subclasses, also cpu_param_dict can store the torchao tensor subclass instances directly as well, instead of looking into implementation details

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed that would be great!

else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
if _is_torchao_tensor(buffer):
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
else:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=False)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember hearing from Brian and Alban before that param.data is a private API and we should not rely on it, I think it also does not work with tensor subclasses

if _is_torchao_tensor(param):
moved = param.data.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(param, moved)
else:
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(buffer):
moved = buffer.data.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(buffer, moved)
else:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)

@torch.compiler.disable()
def onload_(self):
Expand Down
Loading