diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 49509cbf04b9..e54a836d01e4 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -53,28 +53,40 @@ def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]: return names -def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None: - """Move a TorchAO parameter to the device of `source` via `swap_tensors`. +def _swap_tensor_data(target: torch.Tensor, source: torch.Tensor) -> None: + """Replace `target`'s underlying tensor with `source` in-place. - `param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces - the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the - original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so - that any dict keyed by `id(param)` remains valid. + `target`'s Python identity (and therefore any dict keyed by `id(target)`) is preserved. - Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion. + For TorchAO wrapper-subclass tensors we must use `torch.utils.swap_tensors`: `target.data = source` only replaces + the outer wrapper and leaves the internal attributes (`.qdata`, `.scale`, ...) on the wrong device. See + https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548. + + For regular tensors we use the `target.data = source` assignment. `swap_tensors` cannot be used here because + `torch.compile` attaches weakrefs to parameters and `swap_tensors` refuses to operate on tensors with weakrefs. 
""" - torch.utils.swap_tensors(param, source) + if _is_torchao_tensor(target): + torch.utils.swap_tensors(target, source) + else: + target.data = source + +def _restore_from_cached_cpu(target: torch.Tensor, source: torch.Tensor) -> None: + """Make `target` reference `source`'s data without mutating `source`. -def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None: - """Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`. + Used for the stream-offload path where `source` is a long-lived cached CPU copy held in `cpu_param_dict`. + `swap_tensors` cannot be used here because it is bidirectional — swapping with the cached copy would put GPU data + into it and corrupt the cache for the next onload cycle. - Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not** - modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in - `cpu_param_dict`). + For TorchAO wrapper-subclass tensors, copy attribute references one-by-one via `setattr`. For regular tensors we + still rely on the `target.data = source` assignment: there is no public single-direction equivalent that works on a + leaf Parameter without explicitly entering a `no_grad` context. 
""" - for attr_name in _get_torchao_inner_tensor_names(source): - setattr(param, attr_name, getattr(source, attr_name)) + if _is_torchao_tensor(target): + for attr_name in _get_torchao_inner_tensor_names(source): + setattr(target, attr_name, getattr(source, attr_name)) + else: + target.data = source def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None: @@ -211,10 +223,7 @@ def _pinned_memory_tensors(self): def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream): moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) - if _is_torchao_tensor(tensor): - _swap_torchao_tensor(tensor, moved) - else: - tensor.data = moved + _swap_tensor_data(tensor, moved) if self.record_stream: if _is_torchao_tensor(tensor): _record_stream_torchao_tensor(tensor, default_stream) @@ -266,7 +275,8 @@ def _onload_from_disk(self): if self.stream is not None: for key, tensor_obj in self.key_to_tensor.items(): pinned_tensor = loaded_tensors[key].pin_memory() - tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) + moved = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) + _swap_tensor_data(tensor_obj, moved) if self.record_stream: tensor_obj.data.record_stream(current_stream) else: @@ -275,7 +285,7 @@ def _onload_from_disk(self): ) loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) for key, tensor_obj in self.key_to_tensor.items(): - tensor_obj.data = loaded_tensors[key] + _swap_tensor_data(tensor_obj, loaded_tensors[key]) def _onload_from_memory(self): if self.stream is not None: @@ -310,7 +320,7 @@ def _offload_to_disk(self): # We do this to free up the RAM which is still holding the up tensor data. 
for tensor_obj in self.tensor_to_key.keys(): - tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) + _swap_tensor_data(tensor_obj, torch.empty_like(tensor_obj.data, device=self.offload_device)) def _offload_to_memory(self): if self.stream is not None: @@ -319,35 +329,22 @@ def _offload_to_memory(self): for group_module in self.modules: for param in group_module.parameters(): - if _is_torchao_tensor(param): - _restore_torchao_tensor(param, self.cpu_param_dict[param]) - else: - param.data = self.cpu_param_dict[param] + _restore_from_cached_cpu(param, self.cpu_param_dict[param]) for param in self.parameters: - if _is_torchao_tensor(param): - _restore_torchao_tensor(param, self.cpu_param_dict[param]) - else: - param.data = self.cpu_param_dict[param] + _restore_from_cached_cpu(param, self.cpu_param_dict[param]) for buffer in self.buffers: - if _is_torchao_tensor(buffer): - _restore_torchao_tensor(buffer, self.cpu_param_dict[buffer]) - else: - buffer.data = self.cpu_param_dict[buffer] + _restore_from_cached_cpu(buffer, self.cpu_param_dict[buffer]) else: for group_module in self.modules: group_module.to(self.offload_device, non_blocking=False) for param in self.parameters: - if _is_torchao_tensor(param): - moved = param.to(self.offload_device, non_blocking=False) - _swap_torchao_tensor(param, moved) - else: - param.data = param.data.to(self.offload_device, non_blocking=False) + # For TorchAO, `.data` returns an incomplete wrapper without internal attributes; + # call `.to()` on the parameter itself in that case. 
+ source = param if _is_torchao_tensor(param) else param.data + _swap_tensor_data(param, source.to(self.offload_device, non_blocking=False)) for buffer in self.buffers: - if _is_torchao_tensor(buffer): - moved = buffer.to(self.offload_device, non_blocking=False) - _swap_torchao_tensor(buffer, moved) - else: - buffer.data = buffer.data.to(self.offload_device, non_blocking=False) + source = buffer if _is_torchao_tensor(buffer) else buffer.data + _swap_tensor_data(buffer, source.to(self.offload_device, non_blocking=False)) @torch.compiler.disable() def onload_(self):