From 019a9deafb0dc7beab1958fdc69201c371d3a5d1 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 17 Mar 2026 10:40:03 +0530
Subject: [PATCH 1/4] fix group offloading when using torchao

---
 src/diffusers/hooks/group_offloading.py | 75 ++++++++++++++++++++++---
 1 file changed, 67 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 891ac28455af..04be39056656 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -22,7 +22,7 @@
 import safetensors.torch
 import torch
 
-from ..utils import get_logger, is_accelerate_available
+from ..utils import get_logger, is_accelerate_available, is_torchao_available
 from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
@@ -35,6 +35,41 @@
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
+def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
+    """Check if a tensor is a TorchAO quantized tensor subclass."""
+    if not is_torchao_available():
+        return False
+    from torchao.utils import TorchAOBaseTensor
+
+    return isinstance(tensor, TorchAOBaseTensor)
+
+
+def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
+    """Get names of all internal tensor data attributes from a TorchAO tensor."""
+    cls = type(tensor)
+    names = list(getattr(cls, "tensor_data_names", []))
+    for attr_name in getattr(cls, "optional_tensor_data_names", []):
+        if getattr(tensor, attr_name, None) is not None:
+            names.append(attr_name)
+    return names
+
+
+def _update_torchao_tensor_in_place(param: torch.Tensor, source: torch.Tensor) -> None:
+    """Update internal tensor data of a TorchAO parameter in-place from source.
+
+    Must operate on the parameter/buffer object directly (not ``param.data``) because ``_make_wrapper_subclass``
+    returns a fresh wrapper from ``.data`` each time, so attribute mutations on ``.data`` are lost.
+ """ + for attr_name in _get_torchao_inner_tensor_names(source): + setattr(param, attr_name, getattr(source, attr_name)) + + +def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None: + """Record stream for all internal tensors of a TorchAO parameter.""" + for attr_name in _get_torchao_inner_tensor_names(param): + getattr(param, attr_name).record_stream(stream) + + # fmt: off _GROUP_OFFLOADING = "group_offloading" _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" @@ -157,9 +192,16 @@ def _pinned_memory_tensors(self): pinned_dict = None def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream): - tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) + moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) + if _is_torchao_tensor(tensor): + _update_torchao_tensor_in_place(tensor, moved) + else: + tensor.data = moved if self.record_stream: - tensor.data.record_stream(default_stream) + if _is_torchao_tensor(tensor): + _record_stream_torchao_tensor(tensor, default_stream) + else: + tensor.data.record_stream(default_stream) def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None): for group_module in self.modules: @@ -245,18 +287,35 @@ def _offload_to_memory(self): for group_module in self.modules: for param in group_module.parameters(): - param.data = self.cpu_param_dict[param] + if _is_torchao_tensor(param): + _update_torchao_tensor_in_place(param, self.cpu_param_dict[param]) + else: + param.data = self.cpu_param_dict[param] for param in self.parameters: - param.data = self.cpu_param_dict[param] + if _is_torchao_tensor(param): + _update_torchao_tensor_in_place(param, self.cpu_param_dict[param]) + else: + param.data = self.cpu_param_dict[param] for buffer in self.buffers: - buffer.data = self.cpu_param_dict[buffer] + if _is_torchao_tensor(buffer): + _update_torchao_tensor_in_place(buffer, self.cpu_param_dict[buffer]) + else: + buffer.data = self.cpu_param_dict[buffer] else: for group_module in self.modules: group_module.to(self.offload_device, non_blocking=False) for param in self.parameters: - param.data = param.data.to(self.offload_device, non_blocking=False) + if _is_torchao_tensor(param): + moved = param.data.to(self.offload_device, non_blocking=False) + _update_torchao_tensor_in_place(param, moved) + else: + param.data = param.data.to(self.offload_device, non_blocking=False) for buffer in self.buffers: - buffer.data = buffer.data.to(self.offload_device, non_blocking=False) + if _is_torchao_tensor(buffer): + moved = buffer.data.to(self.offload_device, non_blocking=False) + _update_torchao_tensor_in_place(buffer, moved) + else: + buffer.data = buffer.data.to(self.offload_device, non_blocking=False) @torch.compiler.disable() def onload_(self): From 1a959dc26f75dfaa1b2d2ad28eff0d1adf6592ed Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 23 Mar 2026 10:56:16 +0530 Subject: [PATCH 2/4] switch to swap_tensors. 
---
 src/diffusers/hooks/group_offloading.py | 34 +++++++++++++++++--------
 1 file changed, 24 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 04be39056656..56d81c5257b2 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -54,11 +54,25 @@ def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
     return names
 
 
-def _update_torchao_tensor_in_place(param: torch.Tensor, source: torch.Tensor) -> None:
-    """Update internal tensor data of a TorchAO parameter in-place from source.
+def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
+    """Move a TorchAO parameter to the device of `source` via `swap_tensors`.
 
-    Must operate on the parameter/buffer object directly (not ``param.data``) because ``_make_wrapper_subclass``
-    returns a fresh wrapper from ``.data`` each time, so attribute mutations on ``.data`` are lost.
+    `param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
+    the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
+    original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
+    that any dict keyed by `id(param)` remains valid.
+
+    Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
+    """
+    torch.utils.swap_tensors(param, source)
+
+
+def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
+    """Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
+
+    Unlike `_swap_torchao_tensor`, this copies attribute references one-by-one via `setattr` so that `source` is **not**
+    modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
+    `cpu_param_dict`).
""" for attr_name in _get_torchao_inner_tensor_names(source): setattr(param, attr_name, getattr(source, attr_name)) @@ -194,7 +208,7 @@ def _pinned_memory_tensors(self): def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream): moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) if _is_torchao_tensor(tensor): - _update_torchao_tensor_in_place(tensor, moved) + _swap_torchao_tensor(tensor, moved) else: tensor.data = moved if self.record_stream: @@ -288,17 +302,17 @@ def _offload_to_memory(self): for group_module in self.modules: for param in group_module.parameters(): if _is_torchao_tensor(param): - _update_torchao_tensor_in_place(param, self.cpu_param_dict[param]) + _restore_torchao_tensor(param, self.cpu_param_dict[param]) else: param.data = self.cpu_param_dict[param] for param in self.parameters: if _is_torchao_tensor(param): - _update_torchao_tensor_in_place(param, self.cpu_param_dict[param]) + _restore_torchao_tensor(param, self.cpu_param_dict[param]) else: param.data = self.cpu_param_dict[param] for buffer in self.buffers: if _is_torchao_tensor(buffer): - _update_torchao_tensor_in_place(buffer, self.cpu_param_dict[buffer]) + _restore_torchao_tensor(buffer, self.cpu_param_dict[buffer]) else: buffer.data = self.cpu_param_dict[buffer] else: @@ -307,13 +321,13 @@ def _offload_to_memory(self): for param in self.parameters: if _is_torchao_tensor(param): moved = param.data.to(self.offload_device, non_blocking=False) - _update_torchao_tensor_in_place(param, moved) + _swap_torchao_tensor(param, moved) else: param.data = param.data.to(self.offload_device, non_blocking=False) for buffer in self.buffers: if _is_torchao_tensor(buffer): moved = buffer.data.to(self.offload_device, non_blocking=False) - _update_torchao_tensor_in_place(buffer, moved) + _swap_torchao_tensor(buffer, moved) else: buffer.data = buffer.data.to(self.offload_device, non_blocking=False) From 9b9e2e17a68425eeb06fd71296b8f7ee025a7748 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 23 Mar 2026 11:22:36 +0530 Subject: [PATCH 3/4] up --- src/diffusers/hooks/group_offloading.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 56d81c5257b2..a3a2b3ecd658 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -36,7 +36,6 @@ def _is_torchao_tensor(tensor: torch.Tensor) -> bool: - """Check if a tensor is a TorchAO quantized tensor subclass.""" if not is_torchao_available(): return False from torchao.utils import TorchAOBaseTensor From 7eaeb99fcd13515bfc272bbc41d60f86261595af Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 30 Mar 2026 11:24:40 +0530 Subject: [PATCH 4/4] address feedback. --- src/diffusers/hooks/group_offloading.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index a3a2b3ecd658..07a0548cdb31 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -172,6 +172,13 @@ def __init__( else torch.cuda ) + @staticmethod + def _to_cpu(tensor, low_cpu_mem_usage): + # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes + # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly. 
+        t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
+        return t if low_cpu_mem_usage else t.pin_memory()
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
@@ -179,17 +186,15 @@ def _init_cpu_param_dict(self):
 
         for module in self.modules:
             for param in module.parameters():
-                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+                cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
             for buffer in module.buffers():
-                cpu_param_dict[buffer] = (
-                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
-                )
+                cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
 
         for param in self.parameters:
-            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
 
         for buffer in self.buffers:
-            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+            cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
 
         return cpu_param_dict
 
@@ -319,13 +324,13 @@ def _offload_to_memory(self):
                 group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
                 if _is_torchao_tensor(param):
-                    moved = param.data.to(self.offload_device, non_blocking=False)
+                    moved = param.to(self.offload_device, non_blocking=False)
                     _swap_torchao_tensor(param, moved)
                 else:
                     param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
                 if _is_torchao_tensor(buffer):
-                    moved = buffer.data.to(self.offload_device, non_blocking=False)
+                    moved = buffer.to(self.offload_device, non_blocking=False)
                     _swap_torchao_tensor(buffer, moved)
                 else:
                     buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
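
For reviewers who want to exercise the code path this series fixes, a short
sketch (not part of the patches; assumes a CUDA machine with torchao
installed, and uses the public diffusers APIs `TorchAoConfig` and
`enable_group_offload`):

    import torch
    from diffusers import FluxTransformer2DModel, TorchAoConfig

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=TorchAoConfig("int8_weight_only"),
        torch_dtype=torch.bfloat16,
    )
    # With these patches, each onload/offload cycle moves the quantized
    # parameters' inner tensors (e.g. qdata/scale) together with their
    # outer wrappers instead of stranding them on the original device.
    transformer.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,
    )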