diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 891ac28455af..a3a2b3ecd658 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -22,7 +22,7 @@
 import safetensors.torch
 import torch
 
-from ..utils import get_logger, is_accelerate_available
+from ..utils import get_logger, is_accelerate_available, is_torchao_available
 from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
@@ -35,6 +35,54 @@
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
+def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
+    if not is_torchao_available():
+        return False
+    from torchao.utils import TorchAOBaseTensor
+
+    return isinstance(tensor, TorchAOBaseTensor)
+
+
+def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
+    """Get names of all internal tensor data attributes from a TorchAO tensor."""
+    cls = type(tensor)
+    names = list(getattr(cls, "tensor_data_names", []))
+    for attr_name in getattr(cls, "optional_tensor_data_names", []):
+        if getattr(tensor, attr_name, None) is not None:
+            names.append(attr_name)
+    return names
+
+
+def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
+    """Move a TorchAO parameter to the device of `source` via `swap_tensors`.
+
+    `param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
+    the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
+    original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
+    that any dict keyed by `id(param)` remains valid.
+
+    Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
+    """
+    torch.utils.swap_tensors(param, source)
+
+
+def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
+    """Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
+
+    Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is
+    **not** modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy
+    in `cpu_param_dict`).
+    """
+    for attr_name in _get_torchao_inner_tensor_names(source):
+        setattr(param, attr_name, getattr(source, attr_name))
+
+
+def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
+    """Record stream for all internal tensors of a TorchAO parameter."""
+    for attr_name in _get_torchao_inner_tensor_names(param):
+        getattr(param, attr_name).record_stream(stream)
+
+
 # fmt: off
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -157,9 +205,16 @@ def _pinned_memory_tensors(self):
             pinned_dict = None
 
     def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
-        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if _is_torchao_tensor(tensor):
+            _swap_torchao_tensor(tensor, moved)
+        else:
+            tensor.data = moved
         if self.record_stream:
-            tensor.data.record_stream(default_stream)
+            if _is_torchao_tensor(tensor):
+                _record_stream_torchao_tensor(tensor, default_stream)
+            else:
+                tensor.data.record_stream(default_stream)
 
     def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
         for group_module in self.modules:
@@ -245,18 +300,35 @@ def _offload_to_memory(self):
 
             for group_module in self.modules:
                 for param in group_module.parameters():
-                    param.data = self.cpu_param_dict[param]
+                    if _is_torchao_tensor(param):
+                        _restore_torchao_tensor(param, self.cpu_param_dict[param])
+                    else:
+                        param.data = self.cpu_param_dict[param]
             for param in self.parameters:
-                param.data = self.cpu_param_dict[param]
+                if _is_torchao_tensor(param):
+                    _restore_torchao_tensor(param, self.cpu_param_dict[param])
+                else:
+                    param.data = self.cpu_param_dict[param]
             for buffer in self.buffers:
-                buffer.data = self.cpu_param_dict[buffer]
+                if _is_torchao_tensor(buffer):
+                    _restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
+                else:
+                    buffer.data = self.cpu_param_dict[buffer]
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=False)
+                if _is_torchao_tensor(param):
+                    moved = param.data.to(self.offload_device, non_blocking=False)
+                    _swap_torchao_tensor(param, moved)
+                else:
+                    param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+                if _is_torchao_tensor(buffer):
+                    moved = buffer.data.to(self.offload_device, non_blocking=False)
+                    _swap_torchao_tensor(buffer, moved)
+                else:
+                    buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
 
     @torch.compiler.disable()
     def onload_(self):