[WIP] [core] fix group offloading when using torchao #13276
Conversation
```python
pinned_dict = None
```

```python
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
    tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
```
Does this mean the `to` op is not implemented properly for torchao tensors?
If you have a minimal repro, we might be able to fix it, I think.
```python
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

linear = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
quantize_(linear, Int8WeightOnlyConfig(version=2))
p = linear.weight

# Move a copy to CUDA and assign via .data
cpu_copy = p.data.cpu()
cuda_copy = cpu_copy.to("cuda")
p.data = cuda_copy
print(f"p.qdata.device = {p.qdata.device}")  # cpu
print(f"cuda_copy.qdata.device = {cuda_copy.qdata.device}")  # cuda:0

# Forward fails: input on cuda, weight internals still on cpu
linear.bias.data = linear.bias.data.to("cuda")
x = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)
linear(x)  # RuntimeError: mat2 is on cpu
```
I feel the proper way to do this is:
```python
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

linear = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
quantize_(linear, Int8WeightOnlyConfig(version=2))
p = linear.weight

# Move a copy to CUDA and assign via .data
# cpu_copy = p.data.cpu()
cpu_copy = p.cpu()
cuda_copy = cpu_copy.to("cuda")
# p.data = cuda_copy
torch.utils.swap_tensors(linear.weight, cuda_copy)
print(f"p.qdata.device = {p.qdata.device}")  # cpu
print(f"cuda_copy.qdata.device = {cuda_copy.qdata.device}")  # cuda:0

# Forward fails: input on cuda, weight internals still on cpu
linear.bias.data = linear.bias.data.to("cuda")
x = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)
linear(x)  # RuntimeError: mat2 is on cpu
```
`parameter.data` is not a recommended API, and `linear.weight` is also no longer an `nn.Parameter` after quantization; it's a different tensor subclass (`nn.Parameter` is itself a tensor subclass: https://github.com/pytorch/pytorch/blob/e9ebbd3bee0761eb9d93b53f4a80d3afa2cc46f8/torch/nn/parameter.py#L30).
But this does not seem like a fix? Your snippet mentions that `linear(x)` also fails because of the runtime error. Can you elaborate?
Also, how do I best implement it in the context of the error and the diffusers code? I provided the minimal snippet for your convenience, but it doesn't serve the use case. We need to be able to fix it in the context of the use case.
> But this does not seem like a fix? Your snippet mentions linear(x) also fails because of the runtime error. Elaborate?

This is the proper way to move a tensor subclass instance to a different device, I think. Please ignore the comments; those were copied from your original example. This runs on my side.

> Also, how do I best implement it in the context of the error and the diffusers code? I provided the minimal snippet for your convenience, but it doesn't serve the use case. We need to be able to fix it in the context of the use case.

Basically, we should not be using `parameter.data = parameter.data.to("cuda")` for quantized weights, but use `swap_tensors` instead.
We have to go through all linear modules in the model and use `swap_tensors` to change the device:

```python
for n, m in model.named_modules():
    if isinstance(m, nn.Linear):
        torch.utils.swap_tensors(m.weight, m.weight.to("cuda"))
```
```python
for group_module in self.modules:
    group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
    param.data = param.data.to(self.offload_device, non_blocking=False)
```
I remember hearing from Brian and Alban before that `param.data` is a private API and we should not rely on it; I think it also does not work with tensor subclasses.
```python
if self.record_stream:
    tensor.data.record_stream(default_stream)
if _is_torchao_tensor(tensor):
    _record_stream_torchao_tensor(tensor, default_stream)
```
We could potentially implement the `record_stream` op as a `__torch_function__` op in the torchao tensor subclasses as well, I think, so that you can do `tensor.record_stream(default_stream)` directly.

Also wondering if this would work if you just do this for `nn.Parameter` as well (`parameter.record_stream(default_stream)` instead of `parameter.data.record_stream(default_stream)`)?
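To illustrate the shape of the first suggestion: the real proposal is a `__torch_function__` override inside torchao, but this toy sketch (hypothetical `RecordStreamTensor`, with a string standing in for a real CUDA stream) overrides the method directly so it stays runnable without CUDA:

```python
import torch

class RecordStreamTensor(torch.Tensor):
    """Toy stand-in for a torchao subclass that handles record_stream itself.

    The real suggestion is a __torch_function__ override in torchao; a plain
    Python method override is used here only to keep the sketch CPU-runnable.
    """
    recorded = []  # streams seen, for demonstration only

    def record_stream(self, stream):
        # A real subclass would forward record_stream to its internal
        # tensors (.qdata, .scale, ...) instead of logging the stream.
        type(self).recorded.append(stream)

t = torch.zeros(2).as_subclass(RecordStreamTensor)
t.record_stream("fake-stream")  # a real torch.cuda.Stream in practice
print(RecordStreamTensor.recorded)
```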
> tensor.record_stream(default_stream) directly.

The suggestion looks great. But I guess that will take some work on your end to ship. Maybe we can add a comment about it here and revisit when you land it?

> also wondering if this would work if you just do this for nn.Parameter as well (parameter.record_stream(default_stream) instead of parameter.data.record_stream(default_stream))?

Wouldn't mind refactoring it away from the `tensor.data.record_stream(default_stream)` pattern, but we couldn't find other solutions when we started working on it last year. Separate PR, perhaps?
```python
for param in group_module.parameters():
    param.data = self.cpu_param_dict[param]
    if _is_torchao_tensor(param):
        _restore_torchao_tensor(param, self.cpu_param_dict[param])
```
Similarly, for this one I'm wondering if it would make sense to implement some copy op in the torchao tensor subclasses. Also, `cpu_param_dict` can store the torchao tensor subclass instances directly, instead of looking into implementation details.
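For context, a toy version of the cache-and-restore pattern with plain tensors (hypothetical `linear` model; the real code caches pinned CPU copies, and torchao subclasses would additionally need their internal tensors restored):

```python
import torch

# Toy sketch of the cpu_param_dict restore pattern with plain tensors.
linear = torch.nn.Linear(4, 4)
cpu_param_dict = {p: p.data for p in linear.parameters()}  # cached CPU tensors

for p in linear.parameters():
    p.data = p.data.clone()  # simulate onload replacing .data with new storage
for p in linear.parameters():
    p.data = cpu_param_dict[p]  # offload: point back at the cached copy

# Compare storage pointers: .data itself returns a fresh object per access,
# so identity checks on p.data would not work here.
restored = all(
    p.data.data_ptr() == cpu_param_dict[p].data_ptr()
    for p in linear.parameters()
)
print(restored)  # True
```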
Indeed that would be great!
```python
logger = get_logger(__name__)  # pylint: disable=invalid-name


def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
```
While most of the torchao tensor subclasses are developed on top of `TorchAOBaseTensor`, it's not a requirement to use it. Practically, this should work for most of the use cases, but it's not 100% guaranteed.

I feel that ideally / long term, we can refactor all uses of `parameter.data` to just operate on the parameter itself (if it works).
I would love that, but sadly that's not the case currently, as we cannot always control implementation details coming from external dependencies.
@jerryzh168 thanks for sharing further thoughts here! I appreciate them. While I agree with the comments on the long-term vision, I don't think we can assume or control underlying implementation details coming from external dependencies (such as TorchAO and other quantization backends). Hence, being explicit about the control paths and keeping them separate feels like a fair compromise (compromise because of the increased cyclomatic complexity coming from the conditionals). But as we go ahead and implement several aspects in TorchAO (record stream, copy, etc.), I think we can work together to reflect them within Diffusers. WDYT?
What does this PR do?
Fixes group offloading when using TorchAO. This assumes that the underlying quantization tensor class implements pinning properly; that's not something we can enforce from TorchAO's side, anyway.
The benefit: many new model releases can use quantization schemes robustly implemented and tested in TorchAO. But quantization alone rarely helps; we need offloading too, and many large models need group offloading (overlapping compute with data transfer).
Problem
Group offloading moves parameters between CPU and GPU by reassigning `param.data`. This works for regular tensors but breaks for TorchAO quantized tensors.

TorchAO tensors are special instances that store their actual data in internal attributes (e.g., `.qdata`, `.scale`), not in the standard tensor storage. The `.data` assignment replaces the outer wrapper storage but leaves these internal attributes on the original device, causing a device mismatch at compute time.
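How slippery `.data` is can be checked even with a plain parameter: every access builds a fresh tensor object, so nothing set on one access survives to the next:

```python
import torch

p = torch.nn.Parameter(torch.zeros(2))
d1, d2 = p.data, p.data
print(d1 is d2)  # False: every .data access builds a fresh tensor object
d1.marker = "set on one view"  # Python attribute set on that fresh object...
print(hasattr(p.data, "marker"))  # False: ...is gone on the next access
```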
A further subtlety: accessing `.data` on a wrapper subclass parameter returns a new wrapper object each time, so mutating attributes on `param.data` doesn't persist either.

This PR
~~`.data` approach: for TorchAO tensors, instead of reassigning `data`, we update the internal tensor attributes directly on the parameter object itself.~~

For TorchAO tensors, `param.data = source_tensor.to(device)` doesn't work because `_make_wrapper_subclass` tensors store their actual data in internal attributes (`.qdata`, `.scale`, etc.), and the `.data` setter only replaces the outer wrapper storage.

We use two strategies depending on the code path:

- Onload: `torch.utils.swap_tensors`, which swaps the full tensor contents in place.
- Offload (with stream): `setattr` to copy internal tensor references without mutating the cached CPU copy. `swap_tensors` can't be used for the stream offload path because it's bidirectional: it would put CUDA data into the cached CPU copy, corrupting it for the next onload cycle.

Related issue: pytorch/ao#4088.
Happens with nightlies as well.
Code to test: https://gist.github.com/sayakpaul/929678132809874c5dbf9c5215460d33#file-check_torchao_offload_compile-py (run with `--quantize` and `--group-offload`, and potentially with `--full-compile`).

Nice results (with quantization + group offloading + full compile):
Needs the nightlies (of both PyTorch and TorchAO) for testing.
Important

While this PR executes the TorchAO-specific changes, I think we could refactor the group offloading-related utilities to rely on `swap_tensors` instead of `.data`, as the latter is considered a private API.