[WIP] [core] fix group offloading when using torchao#13276

Draft
sayakpaul wants to merge 6 commits into main from fix-torchao-groupoffloading

Conversation


@sayakpaul sayakpaul commented Mar 17, 2026

What does this PR do?

Fixes offloading when using TorchAO. This assumes that the underlying quantized tensor class implements pinning properly, though that's not something we can guarantee in TorchAO anyway.

The benefit: new model releases can use quantization schemes that are robustly implemented and tested in TorchAO. But quantization alone is rarely enough; many large models also need group offloading (which overlaps compute with data transfer).

Problem

Group offloading moves parameters between CPU and GPU by reassigning param.data:

param.data = source_tensor.to(device) 

This works for regular tensors but breaks for TorchAO quantized tensors.

TorchAO tensors are special instances that store their actual data in internal attributes (e.g., .qdata, .scale), not in the standard tensor storage. The .data assignment replaces the
outer wrapper storage but leaves these internal attributes on the original device, causing a device mismatch at compute time.

A further subtlety: accessing .data on a wrapper subclass parameter returns a new wrapper object each time, so mutating attributes on param.data doesn't persist either.
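A torch-free sketch of this subtlety, using a hypothetical `WrapperParam` class (not part of torch or torchao) to model how each `.data` access on a wrapper subclass yields a fresh wrapper object, so attribute writes on it are lost:

```python
# Hypothetical, torch-free model of a wrapper-subclass parameter.
# Like torch's _make_wrapper_subclass tensors, every `.data` access
# constructs a NEW wrapper, so mutating attributes on the result
# never reaches the parameter object you actually hold.
class WrapperParam:
    def __init__(self, qdata_device="cpu", scale_device="cpu"):
        self.qdata_device = qdata_device
        self.scale_device = scale_device

    @property
    def data(self):
        # Fresh wrapper built on every access (models torch's behavior).
        return WrapperParam(self.qdata_device, self.scale_device)

param = WrapperParam()
param.data.qdata_device = "cuda"  # mutates a throwaway object
print(param.qdata_device)         # still "cpu": the write was lost

# Writing the attribute on the parameter itself does persist:
param.qdata_device = "cuda"
print(param.qdata_device)         # "cuda"
```

This is why the fix below sets the internal attributes on the parameter object directly rather than going through `.data`.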

This PR

~~.data approach~~

For TorchAO tensors, instead of reassigning data, we update the internal tensor attributes directly on the parameter object itself:

# Before (broken for TorchAO tensors)
param.data = source_tensor.to(device)

# After
moved = source_tensor.to(device)
if _is_torchao_tensor(param):
    for attr in tensor_data_names:  # e.g. ["qdata", "scale"]
        setattr(param, attr, getattr(moved, attr))
else:
    param.data = moved

For TorchAO tensors, param.data = source_tensor.to(device) doesn't work because _make_wrapper_subclass tensors store their actual data in internal attributes (.qdata, .scale, etc.), and the .data setter only replaces the outer wrapper storage.

We use two strategies depending on the code path:

Onload — torch.utils.swap_tensors, which swaps the full tensor contents in-place:

moved = source_tensor.to(device)
if _is_torchao_tensor(param):
    torch.utils.swap_tensors(param, moved)
else:
    param.data = moved
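A minimal torch-free sketch of what the swap achieves. The `swap_contents` helper and `FakeQuantTensor` class below are stand-ins of my own (not real APIs): `torch.utils.swap_tensors` exchanges the full contents of two tensors in place, so every existing reference to the parameter sees the moved data.

```python
# Stand-in for torch.utils.swap_tensors: exchange the complete state
# (here, __dict__) of two objects in place, so every existing reference
# to `param` -- e.g. the one held by the module -- now sees moved data.
def swap_contents(a, b):
    a.__dict__, b.__dict__ = b.__dict__, a.__dict__

class FakeQuantTensor:
    def __init__(self, qdata_device, scale_device):
        self.qdata_device = qdata_device
        self.scale_device = scale_device

param = FakeQuantTensor("cpu", "cpu")    # held by the module
moved = FakeQuantTensor("cuda", "cuda")  # result of .to(device)

swap_contents(param, moved)
print(param.qdata_device, param.scale_device)  # cuda cuda
```

Because the exchange is in place, no attribute names need to be known; the whole internal state moves at once.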

Offload (with stream) — setattr to copy internal tensor references without mutating the cached CPU copy:

if _is_torchao_tensor(param):
    for attr in tensor_data_names:  # e.g. ["qdata", "scale"]
        setattr(param, attr, getattr(cpu_cached_copy, attr))
else:
    param.data = cpu_cached_copy

swap_tensors can't be used for the stream offload path because it's bidirectional — it would put CUDA data into the cached CPU copy, corrupting it for the next onload cycle.
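The corruption risk can be shown with a torch-free sketch (the `FakeQuantTensor` class and `swap_contents` helper are hypothetical stand-ins for a torchao tensor and `torch.utils.swap_tensors`):

```python
class FakeQuantTensor:
    def __init__(self, qdata_device):
        self.qdata_device = qdata_device

def swap_contents(a, b):
    # Models torch.utils.swap_tensors: a *bidirectional* exchange.
    a.__dict__, b.__dict__ = b.__dict__, a.__dict__

param = FakeQuantTensor("cuda")          # currently onloaded
cpu_cached_copy = FakeQuantTensor("cpu")

# Broken option: swapping pushes the CUDA state into the cache,
# corrupting it for the next onload cycle.
swap_contents(param, cpu_cached_copy)
print(cpu_cached_copy.qdata_device)      # "cuda" -- cache corrupted

# Option used here: copy attribute references one-way, leaving the
# cached CPU copy untouched.
param = FakeQuantTensor("cuda")
cpu_cached_copy = FakeQuantTensor("cpu")
for attr in ["qdata_device"]:            # stand-in for tensor_data_names
    setattr(param, attr, getattr(cpu_cached_copy, attr))
print(param.qdata_device)                # "cpu"
print(cpu_cached_copy.qdata_device)      # "cpu" -- cache intact
```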

Related issue: pytorch/ao#4088 (the problem reproduces with the nightlies as well).

Code to test: https://gist.github.com/sayakpaul/929678132809874c5dbf9c5215460d33#file-check_torchao_offload_compile-py (run with --quantize, --group-offload; and potentially with --full-compile).

Nice results (with quantization + group offloading + full compile):

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:32<00:00,  8.06s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.18s/it]

Needs the nightlies (of both PyTorch and TorchAO) for testing.

Important

While this PR implements the TorchAO-specific changes, I think we could refactor the group offloading utilities to rely on swap_tensors instead of .data, since the latter is considered a private API.


pinned_dict = None

def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
    tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)


Does this mean the to op is not implemented properly for torchao tensors?

If you have a minimal repro, we might be able to fix it, I think.

sayakpaul (Member, author) replied:


import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

linear = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
quantize_(linear, Int8WeightOnlyConfig(version=2))
p = linear.weight

# Move a copy to CUDA and assign via .data
cpu_copy = p.data.cpu()
cuda_copy = cpu_copy.to("cuda")
p.data = cuda_copy

print(f"p.qdata.device = {p.qdata.device}")  # cpu
print(f"cuda_copy.qdata.device = {cuda_copy.qdata.device}")  # cuda:0

# Forward fails: input on cuda, weight internals still on cpu
linear.bias.data = linear.bias.data.to("cuda")
x = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)
linear(x)  # RuntimeError: mat2 is on cpu


I feel the proper way to do this is:

import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

linear = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
quantize_(linear, Int8WeightOnlyConfig(version=2))
p = linear.weight

# Move a copy to CUDA and assign via .data
# cpu_copy = p.data.cpu()
cpu_copy = p.cpu()
cuda_copy = cpu_copy.to("cuda")
# p.data = cuda_copy
torch.utils.swap_tensors(linear.weight, cuda_copy)

print(f"p.qdata.device = {p.qdata.device}")  # cpu
print(f"cuda_copy.qdata.device = {cuda_copy.qdata.device}")  # cuda:0

# Forward fails: input on cuda, weight internals still on cpu
linear.bias.data = linear.bias.data.to("cuda")
x = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)
linear(x)  # RuntimeError: mat2 is on cpu

parameter.data is not a recommended API. And linear.weight is no longer an nn.Parameter after quantization; it's a different tensor subclass (nn.Parameter is itself a tensor subclass: https://github.com/pytorch/pytorch/blob/e9ebbd3bee0761eb9d93b53f4a80d3afa2cc46f8/torch/nn/parameter.py#L30).

sayakpaul (Member, author) replied:


But this does not seem like a fix? Your snippet mentions linear(x) also fails because of the runtime error. Elaborate?

Also, how do I best implement it in the context of the error and the diffusers code? I provided the minimal snippet for your convenience, but it doesn't serve the use case. We need to be able to fix it in the context of the use case.


> But this does not seem like a fix? Your snippet mentions linear(x) also fails because of the runtime error. Elaborate?

This is the proper way to move a tensor subclass instance across devices, I think. Please ignore the comments; those were copied from your original example. This runs on my side.

> Also, how do I best implement it in the context of the error and the diffusers code? I provided the minimal snippet for your convenience, but it doesn't serve the use case. We need to be able to fix it in the context of the use case.

Basically, we should not be using parameter.data = parameter.data.to("cuda") for quantized weights, but use swap_tensors instead.

We have to go through all linear modules in the model and use swap_tensors to change device:

for n, m in model.named_modules():
    if isinstance(m, nn.Linear):
        torch.utils.swap_tensors(m.weight, m.weight.to("cuda"))

for group_module in self.modules:
    group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
    param.data = param.data.to(self.offload_device, non_blocking=False)


I remember hearing from Brian and Alban before that param.data is a private API and we should not rely on it; I think it also does not work with tensor subclasses.

@sayakpaul sayakpaul requested a review from jerryzh168 March 23, 2026 05:30
if self.record_stream:
    tensor.data.record_stream(default_stream)
    if _is_torchao_tensor(tensor):
        _record_stream_torchao_tensor(tensor, default_stream)


We could potentially implement the record_stream op as a torch_function op in the torchao tensor subclasses as well, I think, so that you can do tensor.record_stream(default_stream) directly.

Also wondering if this would work if you just do this for nn.Parameter as well (parameter.record_stream(default_stream) instead of parameter.data.record_stream(default_stream))?

sayakpaul (Member, author) replied:


> tensor.record_stream(default_stream) directly.

Suggestion looks great. But I guess that will take some work on your end to ship. Maybe we can add a comment about it here and revisit when you land it?

> also wondering if this would work if you just do this for nn.Parameter as well (parameter.record_stream(default_stream) instead of parameter.data.record_stream(default_stream))?

Wouldn't mind refactoring it away from the tensor.data.record_stream(default_stream) pattern, but we couldn't find other solutions when we started working on this last year. Separate PR, perhaps?

for param in group_module.parameters():
    param.data = self.cpu_param_dict[param]
    if _is_torchao_tensor(param):
        _restore_torchao_tensor(param, self.cpu_param_dict[param])


Similarly, for this one I'm wondering if it would make sense to implement a copy op in the torchao tensor subclasses. Also, cpu_param_dict could store the torchao tensor subclass instances directly, instead of looking into implementation details.

sayakpaul (Member, author) replied:


Indeed that would be great!

logger = get_logger(__name__) # pylint: disable=invalid-name


def _is_torchao_tensor(tensor: torch.Tensor) -> bool:


While most of the torchao tensor subclasses are developed on top of TorchAOBaseTensor, it's not a requirement to use it. Practically this should work for most use cases, but it's not 100% guaranteed.

I feel that ideally / long term, we can refactor all uses of parameter.data to just operate on the parameter itself (if that works).

Copy link
Member Author


I would love that, but sadly that's not possible currently, as we cannot always control implementation details coming from external dependencies.

sayakpaul (Member, author):

@jerryzh168 thanks for sharing further thoughts here! I appreciate them.

While I agree with the comments on the long-term vision, I don't think we can assume or control underlying implementation details coming from external dependencies (such as TorchAO and other quantization backends).

Hence, being explicit about the control paths and keeping them separate feels like a fair compromise (compromise because of the increased cyclomatic complexity coming from the conditionals).

But as we go ahead and implement several aspects in TorchAO (record stream, copy, etc.), I think we can work together to reflect them within Diffusers.

WDYT?
