85 changes: 41 additions & 44 deletions src/diffusers/hooks/group_offloading.py
@@ -53,28 +53,40 @@ def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
return names


def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
def _swap_tensor_data(target: torch.Tensor, source: torch.Tensor) -> None:
"""Replace `target`'s underlying tensor with `source` in-place.

`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
that any dict keyed by `id(param)` remains valid.
`target`'s Python identity (and therefore any dict keyed by `id(target)`) is preserved.

Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
For TorchAO wrapper-subclass tensors we must use `torch.utils.swap_tensors`: `target.data = source` only replaces
the outer wrapper and leaves the internal attributes (`.qdata`, `.scale`, ...) on the wrong device. See
https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548.

For regular tensors we use the `target.data = source` assignment. `swap_tensors` cannot be used here because
`torch.compile` attaches weakrefs to parameters and `swap_tensors` refuses to operate on tensors with weakrefs.
"""
torch.utils.swap_tensors(param, source)
if _is_torchao_tensor(target):
torch.utils.swap_tensors(target, source)
else:
target.data = source


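A standalone sketch of the identity guarantee both branches rely on, using plain CPU tensors (the TorchAO wrapper-subclass case needs torchao and is only described here): bookkeeping such as `cpu_param_dict` is keyed by the parameter object, so whichever branch runs must leave the Python object in place.

import torch

param = torch.zeros(4)
registry = {id(param): "pinned cpu copy"}  # stands in for cpu_param_dict-style bookkeeping

# Regular-tensor branch: rebinding .data keeps the same Python object alive.
param.data = torch.ones(4)
assert id(param) in registry

# Wrapper-subclass branch: swap_tensors exchanges the full contents (__class__,
# __dict__ and the underlying TensorImpl) in place; on a TorchAO tensor this is
# what carries .qdata / .scale along to the new device.
torch.utils.swap_tensors(param, torch.full((4,), 2.0))
assert id(param) in registry and bool((param == 2.0).all())
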
def _restore_from_cached_cpu(target: torch.Tensor, source: torch.Tensor) -> None:
"""Make `target` reference `source`'s data without mutating `source`.

def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
Used for the stream-offload path where `source` is a long-lived cached CPU copy held in `cpu_param_dict`.
`swap_tensors` cannot be used here because it is bidirectional — swapping with the cached copy would put GPU data
into it and corrupt the cache for the next onload cycle.

Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
`cpu_param_dict`).
For TorchAO wrapper-subclass tensors, copy attribute references one-by-one via `setattr`. For regular tensors we
still rely on the `target.data = source` assignment: there is no public single-direction equivalent that works on a
leaf Parameter without explicitly entering a `no_grad` context.
"""
for attr_name in _get_torchao_inner_tensor_names(source):
setattr(param, attr_name, getattr(source, attr_name))
if _is_torchao_tensor(target):
for attr_name in _get_torchao_inner_tensor_names(source):
setattr(target, attr_name, getattr(source, attr_name))
else:
target.data = source


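A similar plain-tensor sketch of why this restore has to be one-directional: the long-lived cached CPU copy must come out of the cycle untouched, which a `swap_tensors` call would not guarantee.

import torch

cached_cpu = torch.zeros(4)   # long-lived entry, as in cpu_param_dict
param = torch.ones(4)         # currently holds the onloaded ("GPU") values

# One-directional restore: param now references the cached data, cache unchanged.
param.data = cached_cpu
assert bool((cached_cpu == 0).all())

# A bidirectional swap would instead push the onloaded values into the cache,
# corrupting it for the next onload cycle:
gpu_like, cache = torch.ones(4), torch.zeros(4)
torch.utils.swap_tensors(gpu_like, cache)
assert bool((cache == 1).all())   # the cached side now holds the old values
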
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
@@ -211,10 +223,7 @@ def _pinned_memory_tensors(self):

def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if _is_torchao_tensor(tensor):
_swap_torchao_tensor(tensor, moved)
else:
tensor.data = moved
_swap_tensor_data(tensor, moved)
if self.record_stream:
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)
@@ -266,7 +275,8 @@ def _onload_from_disk(self):
if self.stream is not None:
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
moved = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
_swap_tensor_data(tensor_obj, moved)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
@@ -275,7 +285,7 @@
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
_swap_tensor_data(tensor_obj, loaded_tensors[key])

def _onload_from_memory(self):
if self.stream is not None:
@@ -310,7 +320,7 @@ def _offload_to_disk(self):

# We do this to free up the RAM which is still holding the tensor data.
for tensor_obj in self.tensor_to_key.keys():
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
_swap_tensor_data(tensor_obj, torch.empty_like(tensor_obj.data, device=self.offload_device))

def _offload_to_memory(self):
if self.stream is not None:
@@ -319,35 +329,22 @@ def _offload_to_memory(self):

for group_module in self.modules:
for param in group_module.parameters():
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
_restore_from_cached_cpu(param, self.cpu_param_dict[param])
for param in self.parameters:
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
_restore_from_cached_cpu(param, self.cpu_param_dict[param])
for buffer in self.buffers:
if _is_torchao_tensor(buffer):
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
else:
buffer.data = self.cpu_param_dict[buffer]
_restore_from_cached_cpu(buffer, self.cpu_param_dict[buffer])
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
if _is_torchao_tensor(param):
moved = param.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(param, moved)
else:
param.data = param.data.to(self.offload_device, non_blocking=False)
# For TorchAO `.data` returns an incomplete wrapper without internal attributes;
# call `.to()` on the parameter itself in that case.
source = param if _is_torchao_tensor(param) else param.data
_swap_tensor_data(param, source.to(self.offload_device, non_blocking=False))
for buffer in self.buffers:
if _is_torchao_tensor(buffer):
moved = buffer.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(buffer, moved)
else:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
source = buffer if _is_torchao_tensor(buffer) else buffer.data
_swap_tensor_data(buffer, source.to(self.offload_device, non_blocking=False))

@torch.compiler.disable()
def onload_(self):
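
For context, a rough usage sketch of the code path these hunks sit on, assuming a CUDA device; the `apply_group_offloading` keyword names follow the current diffusers docs, and the tiny stand-in model replaces a real (e.g. TorchAO-quantized) diffusers transformer, so treat it as illustrative rather than a test of this PR.

import torch
from diffusers.hooks import apply_group_offloading

class TinyModel(torch.nn.Module):
    # Stand-in with a ModuleList so block_level grouping has blocks to group.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(4))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = TinyModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,   # exercises the pinned-memory / cpu_param_dict restore path above
)

with torch.no_grad():
    out = model(torch.randn(1, 8, device="cuda"))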