diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 10d3f0c245a1..df80d1a9c98f 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -172,12 +172,25 @@ def __init__(
             else torch.cuda
         )
 
+    @staticmethod
+    def _pin_memory(tensor):
+        try:
+            return tensor if tensor.is_pinned() else tensor.pin_memory()
+        except NotImplementedError:
+            if _is_torchao_tensor(tensor):
+                # Some legacy TorchAO tensor subclasses do not implement aten.is_pinned.
+                return tensor
+            raise
+
     @staticmethod
     def _to_cpu(tensor, low_cpu_mem_usage):
+        is_torchao_tensor = _is_torchao_tensor(tensor)
         # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
         # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
-        t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
-        return t if low_cpu_mem_usage else t.pin_memory()
+        t = tensor.cpu() if is_torchao_tensor else tensor.data.cpu()
+        if low_cpu_mem_usage:
+            return t
+        return ModuleGroup._pin_memory(t)
 
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
@@ -202,7 +215,7 @@ def _init_cpu_param_dict(self):
     def _pinned_memory_tensors(self):
         try:
             pinned_dict = {
-                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                param: self._pin_memory(tensor)
                 for param, tensor in self.cpu_param_dict.items()
             }
             yield pinned_dict
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index fbe832348d2e..7283103bfc87 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -522,6 +522,30 @@ def test_sequential_cpu_offload(self):
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)
 
+    def test_group_offloading_torchao_int8wo_v1(self):
+        quantization_config = TorchAoConfig(Int8WeightOnlyConfig(version=1))
+        inputs = self.get_dummy_tensor_inputs(torch_device)
+        transformer = self.get_dummy_components(quantization_config)["transformer"].to(torch_device)
+        with torch.no_grad():
+            output_without_offloading = transformer(**inputs)[0]
+        del transformer
+        backend_empty_cache(torch_device)
+        gc.collect()
+
+        for offload_kwargs in (
+            {"offload_type": "leaf_level"},
+            {"offload_type": "leaf_level", "use_stream": True},
+            {"offload_type": "block_level", "num_blocks_per_group": 1, "use_stream": True},
+        ):
+            transformer = self.get_dummy_components(quantization_config)["transformer"]
+            transformer.enable_group_offload(torch_device, **offload_kwargs)
+            with torch.no_grad():
+                output = transformer(**inputs)[0]
+            assert torch.allclose(output_without_offloading, output, atol=1e-3, rtol=1e-3)
+            del transformer
+            backend_empty_cache(torch_device)
+            gc.collect()
+
     @require_torchao_version_greater_or_equal("0.15.0")
     def test_aobase_config(self):
         quantization_config = TorchAoConfig(Int8WeightOnlyConfig())