From d135a1626d8d75247353d2ee5729e1303da5165a Mon Sep 17 00:00:00 2001
From: Ting Sun
Date: Fri, 3 Jul 2026 04:45:07 +0700
Subject: [PATCH 1/2] Fix torchao group offloading with use_stream=True

---
Reviewer notes (text between the "---" line and the first "diff --git" line
is commentary and is not part of the commit message):
  * `_to_cpu` now returns TorchAO tensors without calling `.pin_memory()`.
  * `_pinned_memory_tensors` leaves TorchAO tensors as they are and pins
    only the remaining tensors that are not already pinned.
  * The new test compares the transformer output without offloading to the
    output with leaf-level offloading (with and without `use_stream`) and
    block-level offloading with `use_stream=True`.

 src/diffusers/hooks/group_offloading.py    |  9 ++++++---
 tests/quantization/torchao/test_torchao.py | 23 ++++++++++++++++++++++
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 10d3f0c245a1..bc05aacb5d17 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -174,10 +174,13 @@ def __init__(
 
     @staticmethod
     def _to_cpu(tensor, low_cpu_mem_usage):
+        is_torchao_tensor = _is_torchao_tensor(tensor)
         # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
         # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
-        t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
-        return t if low_cpu_mem_usage else t.pin_memory()
+        t = tensor.cpu() if is_torchao_tensor else tensor.data.cpu()
+        if low_cpu_mem_usage or is_torchao_tensor:
+            return t
+        return t.pin_memory()
 
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
@@ -202,7 +205,7 @@ def _init_cpu_param_dict(self):
     def _pinned_memory_tensors(self):
         try:
             pinned_dict = {
-                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                param: tensor if (_is_torchao_tensor(tensor) or tensor.is_pinned()) else tensor.pin_memory()
                 for param, tensor in self.cpu_param_dict.items()
             }
             yield pinned_dict
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index fbe832348d2e..c3f8254dadfe 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -522,6 +522,29 @@ def test_sequential_cpu_offload(self):
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)
 
+    def test_group_offloading(self):
+        inputs = self.get_dummy_tensor_inputs(torch_device)
+        transformer = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()))["transformer"].to(torch_device)
+        with torch.no_grad():
+            output_without_offloading = transformer(**inputs)[0]
+        del transformer
+        backend_empty_cache(torch_device)
+        gc.collect()
+
+        for offload_kwargs in (
+            {"offload_type": "leaf_level"},
+            {"offload_type": "leaf_level", "use_stream": True},
+            {"offload_type": "block_level", "num_blocks_per_group": 1, "use_stream": True},
+        ):
+            transformer = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()))["transformer"]
+            transformer.enable_group_offload(torch_device, **offload_kwargs)
+            with torch.no_grad():
+                output = transformer(**inputs)[0]
+            assert torch.allclose(output_without_offloading, output, atol=1e-3, rtol=1e-3)
+            del transformer
+            backend_empty_cache(torch_device)
+            gc.collect()
+
     @require_torchao_version_greater_or_equal("0.15.0")
     def test_aobase_config(self):
         quantization_config = TorchAoConfig(Int8WeightOnlyConfig())

From 2075bc3512edb0cd9e7ab363708610400d08a429 Mon Sep 17 00:00:00 2001
From: Ting Sun
Date: Fri, 3 Jul 2026 17:06:27 +0800
Subject: [PATCH 2/2] Fix TorchAO v1 group offloading with use_stream=True

---
Reviewer notes (text between the "---" line and the first "diff --git" line
is commentary and is not part of the commit message):
  * Adds the static helper `_pin_memory`: it returns the tensor unchanged
    when `is_pinned()` is true and otherwise returns `pin_memory()`. If
    either call raises NotImplementedError, a TorchAO tensor is returned
    unpinned; for any other tensor the error is re-raised.
  * `_to_cpu` and `_pinned_memory_tensors` now route pinning through that
    helper in place of the unconditional TorchAO skip added in patch 1/2.
  * NOTE(review): the test is renamed to
    `test_group_offloading_torchao_int8wo_v1` and both transformer
    constructions now use `Int8WeightOnlyConfig(version=1)`, so the
    default-version config tested in patch 1/2 is no longer exercised by
    this test.

 src/diffusers/hooks/group_offloading.py    | 16 +++++++++++++---
 tests/quantization/torchao/test_torchao.py |  7 ++++---
 2 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index bc05aacb5d17..df80d1a9c98f 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -172,15 +172,25 @@ def __init__(
             else torch.cuda
         )
 
+    @staticmethod
+    def _pin_memory(tensor):
+        try:
+            return tensor if tensor.is_pinned() else tensor.pin_memory()
+        except NotImplementedError:
+            if _is_torchao_tensor(tensor):
+                # Some legacy TorchAO tensor subclasses do not implement aten.is_pinned.
+                return tensor
+            raise
+
     @staticmethod
     def _to_cpu(tensor, low_cpu_mem_usage):
         is_torchao_tensor = _is_torchao_tensor(tensor)
         # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
         # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
         t = tensor.cpu() if is_torchao_tensor else tensor.data.cpu()
-        if low_cpu_mem_usage or is_torchao_tensor:
+        if low_cpu_mem_usage:
             return t
-        return t.pin_memory()
+        return ModuleGroup._pin_memory(t)
 
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
@@ -205,7 +215,7 @@ def _init_cpu_param_dict(self):
     def _pinned_memory_tensors(self):
         try:
             pinned_dict = {
-                param: tensor if (_is_torchao_tensor(tensor) or tensor.is_pinned()) else tensor.pin_memory()
+                param: self._pin_memory(tensor)
                 for param, tensor in self.cpu_param_dict.items()
             }
             yield pinned_dict
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index c3f8254dadfe..7283103bfc87 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -522,9 +522,10 @@ def test_sequential_cpu_offload(self):
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)
 
-    def test_group_offloading(self):
+    def test_group_offloading_torchao_int8wo_v1(self):
+        quantization_config = TorchAoConfig(Int8WeightOnlyConfig(version=1))
         inputs = self.get_dummy_tensor_inputs(torch_device)
-        transformer = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()))["transformer"].to(torch_device)
+        transformer = self.get_dummy_components(quantization_config)["transformer"].to(torch_device)
         with torch.no_grad():
             output_without_offloading = transformer(**inputs)[0]
         del transformer
@@ -536,7 +537,7 @@ def test_group_offloading(self):
             {"offload_type": "leaf_level", "use_stream": True},
             {"offload_type": "block_level", "num_blocks_per_group": 1, "use_stream": True},
         ):
-            transformer = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()))["transformer"]
+            transformer = self.get_dummy_components(quantization_config)["transformer"]
             transformer.enable_group_offload(torch_device, **offload_kwargs)
             with torch.no_grad():
                 output = transformer(**inputs)[0]