[core] fix group offloading when using torchao #13276
base: main
Changes from all commits: 019a9de, 8797398, 1a959dc, 9b9e2e1, d2666a9, 6125a4f, 7006773
```diff
@@ -22,7 +22,7 @@
 import safetensors.torch
 import torch

-from ..utils import get_logger, is_accelerate_available
+from ..utils import get_logger, is_accelerate_available, is_torchao_available
 from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
```
```diff
@@ -35,6 +35,54 @@
 logger = get_logger(__name__)  # pylint: disable=invalid-name


+def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
+    if not is_torchao_available():
+        return False
+    from torchao.utils import TorchAOBaseTensor
+
+    return isinstance(tensor, TorchAOBaseTensor)
+
+
+def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
+    """Get names of all internal tensor data attributes from a TorchAO tensor."""
+    cls = type(tensor)
+    names = list(getattr(cls, "tensor_data_names", []))
+    for attr_name in getattr(cls, "optional_tensor_data_names", []):
+        if getattr(tensor, attr_name, None) is not None:
+            names.append(attr_name)
+    return names
+
+
+def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
+    """Move a TorchAO parameter to the device of `source` via `swap_tensors`.
+
+    `param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only
+    replaces the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`,
+    `.scale`) on the original device. `swap_tensors` swaps the full tensor contents in-place, preserving the
+    parameter's identity so that any dict keyed by `id(param)` remains valid.
+
+    Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
+    """
+    torch.utils.swap_tensors(param, source)
+
+
+def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
+    """Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
+
+    Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source`
+    is **not** modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned
+    CPU copy in `cpu_param_dict`).
+    """
+    for attr_name in _get_torchao_inner_tensor_names(source):
+        setattr(param, attr_name, getattr(source, attr_name))
+
+
+def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
+    """Record stream for all internal tensors of a TorchAO parameter."""
+    for attr_name in _get_torchao_inner_tensor_names(param):
+        getattr(param, attr_name).record_stream(stream)
+
+
 # fmt: off
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
```
@@ -157,9 +205,16 @@ def _pinned_memory_tensors(self): | |
| pinned_dict = None | ||
|
|
||
| def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream): | ||
| tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this mean the if you have a minimal repro, we might be able to fix I think
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel the proper way to do this is: parameter.data is not a recommended API. and
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But this does not seem like a fix? Your snippet mentions Also, how do I best implement it in the context of the error and the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
this is the proper way to move device for a tensor subclass instance I think. please ignore comments, that was copied from your original example. this runs on my side.
basically we should not be using we have to go through all linear modules in the model, and use swap_tensor to change device: |
||
| moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) | ||
| if _is_torchao_tensor(tensor): | ||
| _swap_torchao_tensor(tensor, moved) | ||
| else: | ||
| tensor.data = moved | ||
| if self.record_stream: | ||
| tensor.data.record_stream(default_stream) | ||
| if _is_torchao_tensor(tensor): | ||
| _record_stream_torchao_tensor(tensor, default_stream) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we could potentially implement the tensor.record_stream(default_stream) directly. also wondering if this would work if you just do this for nn.Parameter as well (parameter.record_stream(default_stream) instead of (parameter.data.record_stream(default_stream))?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggestion looks great. But I guess that will take some work on your end to ship. Maybe we can add a comment about it here and revisit when you land it?
Wouldn't mind refactoring it from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds good to check in another PR |
||
| else: | ||
| tensor.data.record_stream(default_stream) | ||
|
|
||
| def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None): | ||
| for group_module in self.modules: | ||
|
|
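The helpers above hinge on the convention that a TorchAO wrapper subclass declares its inner tensors via `tensor_data_names` / `optional_tensor_data_names` class attributes. A hedged sketch with a hypothetical `ToyQuantTensor` (not the real TorchAO class; `_make_wrapper_subclass` is a private PyTorch API used here only for illustration):

```python
import torch

# Hypothetical stand-in for a TorchAO wrapper subclass: it declares its inner
# tensors via the class attributes that _get_torchao_inner_tensor_names reads.
class ToyQuantTensor(torch.Tensor):
    tensor_data_names = ["qdata", "scale"]
    optional_tensor_data_names = ["zero_point"]

    @staticmethod
    def __new__(cls, qdata, scale, zero_point=None):
        t = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=torch.float32)
        t.qdata, t.scale, t.zero_point = qdata, scale, zero_point
        return t

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Ops are never dispatched in this sketch; only attributes are read.
        raise NotImplementedError(f"{func} is not supported by this toy class")

def inner_tensor_names(t):
    # Same logic as _get_torchao_inner_tensor_names: mandatory names plus any
    # optional attribute that is actually populated on this instance.
    names = list(getattr(type(t), "tensor_data_names", []))
    for name in getattr(type(t), "optional_tensor_data_names", []):
        if getattr(t, name, None) is not None:
            names.append(name)
    return names

w = ToyQuantTensor(torch.zeros(4, dtype=torch.uint8), torch.tensor(0.5))
assert inner_tensor_names(w) == ["qdata", "scale"]  # zero_point is unset

w2 = ToyQuantTensor(torch.zeros(4, dtype=torch.uint8), torch.tensor(0.5), torch.tensor(0))
assert inner_tensor_names(w2) == ["qdata", "scale", "zero_point"]
```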
@@ -245,18 +300,35 @@ def _offload_to_memory(self): | |
|
|
||
| for group_module in self.modules: | ||
| for param in group_module.parameters(): | ||
| param.data = self.cpu_param_dict[param] | ||
| if _is_torchao_tensor(param): | ||
| _restore_torchao_tensor(param, self.cpu_param_dict[param]) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. similarly for this one I'm wondering if it would make sense to implement some copy op in torchao tensor subclasses, also
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed that would be great! |
||
| else: | ||
| param.data = self.cpu_param_dict[param] | ||
| for param in self.parameters: | ||
| param.data = self.cpu_param_dict[param] | ||
| if _is_torchao_tensor(param): | ||
| _restore_torchao_tensor(param, self.cpu_param_dict[param]) | ||
| else: | ||
| param.data = self.cpu_param_dict[param] | ||
| for buffer in self.buffers: | ||
| buffer.data = self.cpu_param_dict[buffer] | ||
| if _is_torchao_tensor(buffer): | ||
| _restore_torchao_tensor(buffer, self.cpu_param_dict[buffer]) | ||
| else: | ||
| buffer.data = self.cpu_param_dict[buffer] | ||
| else: | ||
| for group_module in self.modules: | ||
| group_module.to(self.offload_device, non_blocking=False) | ||
| for param in self.parameters: | ||
| param.data = param.data.to(self.offload_device, non_blocking=False) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I remember hearing from Brian and Alban before that param.data is a private API and we should not rely on it, I think it also does not work with tensor subclasses |
||
| if _is_torchao_tensor(param): | ||
| moved = param.data.to(self.offload_device, non_blocking=False) | ||
| _swap_torchao_tensor(param, moved) | ||
| else: | ||
| param.data = param.data.to(self.offload_device, non_blocking=False) | ||
| for buffer in self.buffers: | ||
| buffer.data = buffer.data.to(self.offload_device, non_blocking=False) | ||
| if _is_torchao_tensor(buffer): | ||
| moved = buffer.data.to(self.offload_device, non_blocking=False) | ||
| _swap_torchao_tensor(buffer, moved) | ||
| else: | ||
| buffer.data = buffer.data.to(self.offload_device, non_blocking=False) | ||
|
|
||
| @torch.compiler.disable() | ||
| def onload_(self): | ||
|
|
||
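The offload path above restores from `cpu_param_dict` without mutating the cached copy. A hedged sketch of that `_restore_torchao_tensor` idea, using a plain `Holder` class as a stand-in for a TorchAO wrapper subclass (both `Holder` and `restore_from` are illustrative names):

```python
import torch

# The inner-tensor attribute *references* are copied from the cached source
# onto the live parameter one by one, so the cached source (e.g. the pinned
# CPU copy kept in cpu_param_dict) is never mutated.
class Holder:
    tensor_data_names = ["qdata", "scale"]

    def __init__(self, qdata, scale):
        self.qdata, self.scale = qdata, scale

def restore_from(param, source):
    for name in type(source).tensor_data_names:
        setattr(param, name, getattr(source, name))

cached = Holder(torch.ones(3), torch.tensor(0.5))  # stays on CPU, untouched
live = Holder(torch.zeros(3), torch.tensor(0.0))   # parameter being offloaded

restore_from(live, cached)

assert live.qdata is cached.qdata                # references were copied over
assert torch.equal(cached.qdata, torch.ones(3))  # the cache is unmodified
```

This mirrors why `_restore_torchao_tensor` uses `setattr` rather than `swap_tensors`: a swap would move contents *out of* the cached tensor, corrupting `cpu_param_dict`.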
Member: While most of the torchao tensor subclasses are developed on top of `TorchAOBaseTensor`, it's not a requirement to use it. Practically this should work for most of the use cases, but it's not 100% guaranteed. I feel ideally / long term, we can refactor all uses of `parameter.data` to just operate on the parameter itself (if it works).

Author: I would love that, but sadly that's not the case currently, as we cannot always control implementation details from external dependencies.
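The guarded, lazy `isinstance` check discussed above degrades gracefully when torchao is not installed. A sketch of the pattern (function names mirror the diff; the `find_spec`-based probe is an assumption about how `is_torchao_available` could be implemented, not diffusers' actual code):

```python
import importlib.util

def is_torchao_available() -> bool:
    # Probe for the package without importing it at module load time.
    return importlib.util.find_spec("torchao") is not None

def is_torchao_tensor(obj) -> bool:
    # Returns False when torchao is absent, so callers can fall back to the
    # plain `.data` path without ever importing torchao.
    if not is_torchao_available():
        return False
    from torchao.utils import TorchAOBaseTensor  # imported lazily, on demand
    return isinstance(obj, TorchAOBaseTensor)

# A non-tensor is never a TorchAO tensor, whether or not torchao is installed.
assert is_torchao_tensor(42) is False
```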