diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index f3d1f3389bb7..09f465cc84e5 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -22,7 +22,7 @@ import safetensors.torch import torch -from ..utils import get_logger, is_accelerate_available, is_torchao_available +from ..utils import get_logger, is_accelerate_available, is_optimum_quanto_available, is_torchao_available from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS from .hooks import HookRegistry, ModelHook @@ -83,6 +83,31 @@ def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None: getattr(param, attr_name).record_stream(stream) +def _is_quanto_tensor(tensor: torch.Tensor) -> bool: + if not is_optimum_quanto_available(): + return False + from optimum.quanto import QTensor + + return isinstance(tensor, QTensor) + + +def _get_quanto_inner_tensor_names(tensor: torch.Tensor) -> list[str]: + """Get names of all internal tensor data attributes from a quanto QTensor (e.g. `_data`, `_scale`).""" + return list(tensor.__tensor_flatten__()[0]) + + +def _restore_quanto_tensor(param: torch.Tensor, source: torch.Tensor) -> None: + """Restore internal tensor data of a quanto QTensor from `source` without mutating `source`.""" + for attr_name in _get_quanto_inner_tensor_names(source): + setattr(param, attr_name, getattr(source, attr_name)) + + +def _record_stream_quanto_tensor(param: torch.Tensor, stream) -> None: + """Record stream for all internal tensors of a quanto QTensor.""" + for attr_name in _get_quanto_inner_tensor_names(param): + getattr(param, attr_name).record_stream(stream) + + # fmt: off _GROUP_OFFLOADING = "group_offloading" _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" @@ -174,10 +199,15 @@ def __init__( @staticmethod def _to_cpu(tensor, low_cpu_mem_usage): - # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes - # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly. - t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu() - return t if low_cpu_mem_usage else t.pin_memory() + is_torchao_tensor = _is_torchao_tensor(tensor) + is_quanto_tensor = _is_quanto_tensor(tensor) + # For tensor subclasses (TorchAO / quanto), `.data` returns an incomplete wrapper without internal + # attributes (e.g. `.qdata`/`.scale`, `._data`/`._scale`), so we must call `.cpu()` on the tensor directly. + t = tensor.cpu() if (is_torchao_tensor or is_quanto_tensor) else tensor.data.cpu() + # Quanto tensors do not keep their subclass identity through `pin_memory()`, so skip pinning for them. + if low_cpu_mem_usage or is_quanto_tensor: + return t + return t.pin_memory() def _init_cpu_param_dict(self): cpu_param_dict = {} @@ -202,7 +232,7 @@ def _init_cpu_param_dict(self): def _pinned_memory_tensors(self): try: pinned_dict = { - param: tensor.pin_memory() if not tensor.is_pinned() else tensor + param: tensor if (_is_quanto_tensor(tensor) or tensor.is_pinned()) else tensor.pin_memory() for param, tensor in self.cpu_param_dict.items() } yield pinned_dict @@ -213,11 +243,15 @@ def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream): moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) if _is_torchao_tensor(tensor): _swap_torchao_tensor(tensor, moved) + elif _is_quanto_tensor(tensor): + torch.utils.swap_tensors(tensor, moved) else: tensor.data = moved if self.record_stream: if _is_torchao_tensor(tensor): _record_stream_torchao_tensor(tensor, default_stream) + elif _is_quanto_tensor(tensor): + _record_stream_quanto_tensor(tensor, default_stream) else: tensor.data.record_stream(default_stream) @@ -320,16 +354,22 @@ def _offload_to_memory(self): for param in group_module.parameters(): if _is_torchao_tensor(param): _restore_torchao_tensor(param, self.cpu_param_dict[param]) + elif _is_quanto_tensor(param): + _restore_quanto_tensor(param, self.cpu_param_dict[param]) else: param.data = self.cpu_param_dict[param] for param in self.parameters: if _is_torchao_tensor(param): _restore_torchao_tensor(param, self.cpu_param_dict[param]) + elif _is_quanto_tensor(param): + _restore_quanto_tensor(param, self.cpu_param_dict[param]) else: param.data = self.cpu_param_dict[param] for buffer in self.buffers: if _is_torchao_tensor(buffer): _restore_torchao_tensor(buffer, self.cpu_param_dict[buffer]) + elif _is_quanto_tensor(buffer): + _restore_quanto_tensor(buffer, self.cpu_param_dict[buffer]) else: buffer.data = self.cpu_param_dict[buffer] else: @@ -339,12 +379,16 @@ def _offload_to_memory(self): if _is_torchao_tensor(param): moved = param.to(self.offload_device, non_blocking=False) _swap_torchao_tensor(param, moved) + elif _is_quanto_tensor(param): + torch.utils.swap_tensors(param, param.to(self.offload_device, non_blocking=False)) else: param.data = param.data.to(self.offload_device, non_blocking=False) for buffer in self.buffers: if _is_torchao_tensor(buffer): moved = buffer.to(self.offload_device, non_blocking=False) _swap_torchao_tensor(buffer, moved) + elif _is_quanto_tensor(buffer): + torch.utils.swap_tensors(buffer, buffer.to(self.offload_device, non_blocking=False)) else: buffer.data = buffer.data.to(self.offload_device, non_blocking=False) diff --git a/tests/quantization/quanto/test_quanto.py b/tests/quantization/quanto/test_quanto.py index e3463f136f94..0d203a83bb44 100644 --- a/tests/quantization/quanto/test_quanto.py +++ b/tests/quantization/quanto/test_quanto.py @@ -273,6 +273,30 @@ def test_model_cpu_offload(self): pipe.enable_model_cpu_offload(device=torch_device) _ = pipe("a cat holding a sign that says hello", num_inference_steps=2) + def test_group_offloading(self): + inputs = self.get_dummy_inputs() + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()).to(torch_device) + with torch.no_grad(): + output_without_offloading = model(**inputs).sample + model.to("cpu") + del model + backend_empty_cache(torch_device) + gc.collect() + + for offload_kwargs in ( + {"offload_type": "leaf_level"}, + {"offload_type": "leaf_level", "use_stream": True}, + {"offload_type": "block_level", "num_blocks_per_group": 1, "use_stream": True}, + ): + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + model.enable_group_offload(torch_device, **offload_kwargs) + with torch.no_grad(): + output = model(**inputs).sample + assert torch.allclose(output_without_offloading, output, atol=1e-3, rtol=1e-3) + del model + backend_empty_cache(torch_device) + gc.collect() + def test_training(self): quantization_config = QuantoConfig(**self.get_dummy_init_kwargs()) quantized_model = self.model_cls.from_pretrained(