From 5f865d41363e54c09af539bf471130514045e8a9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 13 May 2026 21:58:34 +0900 Subject: [PATCH 1/2] start moving to swap_tensors. --- src/diffusers/hooks/group_offloading.py | 102 +++++++++++++----------- 1 file changed, 56 insertions(+), 46 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 49509cbf04b9..e893caee3b95 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -53,28 +53,53 @@ def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]: return names -def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None: - """Move a TorchAO parameter to the device of `source` via `swap_tensors`. - - `param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces - the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the - original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so - that any dict keyed by `id(param)` remains valid. - - Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion. +def _swap_tensor_data(target: torch.Tensor, source: torch.Tensor) -> None: + """Replace `target`'s underlying tensor with `source` in-place via `torch.utils.swap_tensors`. + + Drop-in for the private `target.data = source` assignment. `target`'s Python identity (and therefore any + dict keyed by `id(target)`) is preserved. After this call, `source` should be considered consumed: the swap + is bidirectional, so `source` ends up holding what `target` held before. + + For the cached-CPU restore path where `source` must remain unchanged, use `_restore_from_cached_cpu` instead. + + Two subtleties: + + 1. `swap_tensors` also swaps `__class__`. If `target` is `nn.Parameter` and `source` is a plain `Tensor`, + `target` would be demoted to `Tensor`, breaking `isinstance(p, nn.Parameter)` and any feature that + depends on Parameter-ness (e.g. LoRA injection). We wrap `source` in a Parameter first to keep types + aligned. This mirrors what `torch.__future__.set_swap_module_params_on_conversion(True)` does in + `nn.Module._apply`. + 2. TorchAO wrapper-subclass tensors store data in internal attributes (`.qdata`, `.scale`, ...) rather + than the standard storage; `target.data = source` only replaces the outer wrapper and leaves those + attributes on the wrong device. `swap_tensors` operates on the full wrapper, including those + attributes, so it works directly without the Parameter wrapping (TorchAO Parameters are themselves + wrapper subclasses; `.to()` returns the same subclass). See + https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548. """ - torch.utils.swap_tensors(param, source) + if _is_torchao_tensor(target): + torch.utils.swap_tensors(target, source) + return + if isinstance(target, torch.nn.Parameter) and not isinstance(source, torch.nn.Parameter): + source = torch.nn.Parameter(source, requires_grad=target.requires_grad) + torch.utils.swap_tensors(target, source) -def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None: - """Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`. +def _restore_from_cached_cpu(target: torch.Tensor, source: torch.Tensor) -> None: + """Make `target` reference `source`'s data without mutating `source`. - Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not** - modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in - `cpu_param_dict`). + Used for the stream-offload path where `source` is a long-lived cached CPU copy held in `cpu_param_dict`. + `swap_tensors` cannot be used here because it is bidirectional — swapping with the cached copy would put + GPU data into it and corrupt the cache for the next onload cycle. + + For TorchAO wrapper-subclass tensors, copy attribute references one-by-one via `setattr`. For regular + tensors we still rely on the `target.data = source` assignment: there is no public single-direction + equivalent that works on a leaf Parameter without explicitly entering a `no_grad` context. """ - for attr_name in _get_torchao_inner_tensor_names(source): - setattr(param, attr_name, getattr(source, attr_name)) + if _is_torchao_tensor(target): + for attr_name in _get_torchao_inner_tensor_names(source): + setattr(target, attr_name, getattr(source, attr_name)) + else: + target.data = source def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None: @@ -211,10 +236,7 @@ def _pinned_memory_tensors(self): def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream): moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) - if _is_torchao_tensor(tensor): - _swap_torchao_tensor(tensor, moved) - else: - tensor.data = moved + _swap_tensor_data(tensor, moved) if self.record_stream: if _is_torchao_tensor(tensor): _record_stream_torchao_tensor(tensor, default_stream) @@ -266,7 +288,8 @@ def _onload_from_disk(self): if self.stream is not None: for key, tensor_obj in self.key_to_tensor.items(): pinned_tensor = loaded_tensors[key].pin_memory() - tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) + moved = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) + _swap_tensor_data(tensor_obj, moved) if self.record_stream: tensor_obj.data.record_stream(current_stream) else: @@ -275,7 +298,7 @@ def _onload_from_disk(self): ) loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) for key, tensor_obj in self.key_to_tensor.items(): - tensor_obj.data = loaded_tensors[key] + _swap_tensor_data(tensor_obj, loaded_tensors[key]) def _onload_from_memory(self): if self.stream is not None: @@ -310,7 +333,7 @@ def _offload_to_disk(self): # We do this to free up the RAM which is still holding the up tensor data. for tensor_obj in self.tensor_to_key.keys(): - tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) + _swap_tensor_data(tensor_obj, torch.empty_like(tensor_obj.data, device=self.offload_device)) def _offload_to_memory(self): if self.stream is not None: @@ -319,35 +342,22 @@ def _offload_to_memory(self): for group_module in self.modules: for param in group_module.parameters(): - if _is_torchao_tensor(param): - _restore_torchao_tensor(param, self.cpu_param_dict[param]) - else: - param.data = self.cpu_param_dict[param] + _restore_from_cached_cpu(param, self.cpu_param_dict[param]) for param in self.parameters: - if _is_torchao_tensor(param): - _restore_torchao_tensor(param, self.cpu_param_dict[param]) - else: - param.data = self.cpu_param_dict[param] + _restore_from_cached_cpu(param, self.cpu_param_dict[param]) for buffer in self.buffers: - if _is_torchao_tensor(buffer): - _restore_torchao_tensor(buffer, self.cpu_param_dict[buffer]) - else: - buffer.data = self.cpu_param_dict[buffer] + _restore_from_cached_cpu(buffer, self.cpu_param_dict[buffer]) else: for group_module in self.modules: group_module.to(self.offload_device, non_blocking=False) for param in self.parameters: - if _is_torchao_tensor(param): - moved = param.to(self.offload_device, non_blocking=False) - _swap_torchao_tensor(param, moved) - else: - param.data = param.data.to(self.offload_device, non_blocking=False) + # For TorchAO `.data` returns an incomplete wrapper without internal attributes; + # call `.to()` on the parameter itself in that case. + source = param if _is_torchao_tensor(param) else param.data + _swap_tensor_data(param, source.to(self.offload_device, non_blocking=False)) for buffer in self.buffers: - if _is_torchao_tensor(buffer): - moved = buffer.to(self.offload_device, non_blocking=False) - _swap_torchao_tensor(buffer, moved) - else: - buffer.data = buffer.data.to(self.offload_device, non_blocking=False) + source = buffer if _is_torchao_tensor(buffer) else buffer.data + _swap_tensor_data(buffer, source.to(self.offload_device, non_blocking=False)) @torch.compiler.disable() def onload_(self): From ef34881839000dd8a8132417caf87f436ab7e8ca Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 14 May 2026 17:21:11 +0900 Subject: [PATCH 2/2] fix more. --- src/diffusers/hooks/group_offloading.py | 47 +++++++++---------------- 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index e893caee3b95..e54a836d01e4 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -54,46 +54,33 @@ def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]: def _swap_tensor_data(target: torch.Tensor, source: torch.Tensor) -> None: - """Replace `target`'s underlying tensor with `source` in-place via `torch.utils.swap_tensors`. - - Drop-in for the private `target.data = source` assignment. `target`'s Python identity (and therefore any - dict keyed by `id(target)`) is preserved. After this call, `source` should be considered consumed: the swap - is bidirectional, so `source` ends up holding what `target` held before. - - For the cached-CPU restore path where `source` must remain unchanged, use `_restore_from_cached_cpu` instead. - - Two subtleties: - - 1. `swap_tensors` also swaps `__class__`. If `target` is `nn.Parameter` and `source` is a plain `Tensor`, - `target` would be demoted to `Tensor`, breaking `isinstance(p, nn.Parameter)` and any feature that - depends on Parameter-ness (e.g. LoRA injection). We wrap `source` in a Parameter first to keep types - aligned. This mirrors what `torch.__future__.set_swap_module_params_on_conversion(True)` does in - `nn.Module._apply`. - 2. TorchAO wrapper-subclass tensors store data in internal attributes (`.qdata`, `.scale`, ...) rather - than the standard storage; `target.data = source` only replaces the outer wrapper and leaves those - attributes on the wrong device. `swap_tensors` operates on the full wrapper, including those - attributes, so it works directly without the Parameter wrapping (TorchAO Parameters are themselves - wrapper subclasses; `.to()` returns the same subclass). See - https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548. + """Replace `target`'s underlying tensor with `source` in-place. + + `target`'s Python identity (and therefore any dict keyed by `id(target)`) is preserved. + + For TorchAO wrapper-subclass tensors we must use `torch.utils.swap_tensors`: `target.data = source` only replaces + the outer wrapper and leaves the internal attributes (`.qdata`, `.scale`, ...) on the wrong device. See + https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548. + + For regular tensors we use the `target.data = source` assignment. `swap_tensors` cannot be used here because + `torch.compile` attaches weakrefs to parameters and `swap_tensors` refuses to operate on tensors with weakrefs. """ if _is_torchao_tensor(target): torch.utils.swap_tensors(target, source) - return - if isinstance(target, torch.nn.Parameter) and not isinstance(source, torch.nn.Parameter): - source = torch.nn.Parameter(source, requires_grad=target.requires_grad) - torch.utils.swap_tensors(target, source) + else: + target.data = source def _restore_from_cached_cpu(target: torch.Tensor, source: torch.Tensor) -> None: """Make `target` reference `source`'s data without mutating `source`. Used for the stream-offload path where `source` is a long-lived cached CPU copy held in `cpu_param_dict`. - `swap_tensors` cannot be used here because it is bidirectional — swapping with the cached copy would put - GPU data into it and corrupt the cache for the next onload cycle. + `swap_tensors` cannot be used here because it is bidirectional — swapping with the cached copy would put GPU data + into it and corrupt the cache for the next onload cycle. - For TorchAO wrapper-subclass tensors, copy attribute references one-by-one via `setattr`. For regular - tensors we still rely on the `target.data = source` assignment: there is no public single-direction - equivalent that works on a leaf Parameter without explicitly entering a `no_grad` context. + For TorchAO wrapper-subclass tensors, copy attribute references one-by-one via `setattr`. For regular tensors we + still rely on the `target.data = source` assignment: there is no public single-direction equivalent that works on a + leaf Parameter without explicitly entering a `no_grad` context. """ if _is_torchao_tensor(target): for attr_name in _get_torchao_inner_tensor_names(source):