Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,25 @@ def __init__(
else torch.cuda
)

@staticmethod
def _pin_memory(tensor):
    """Return a pinned version of ``tensor``, tolerating TorchAO tensors that cannot report pinning.

    Args:
        tensor: The tensor to pin. It is returned as-is when ``tensor.is_pinned()`` is true.

    Returns:
        ``tensor`` itself if it is already pinned, otherwise ``tensor.pin_memory()``. If the
        pinning calls raise ``NotImplementedError`` and ``_is_torchao_tensor(tensor)`` is true,
        the tensor is returned unchanged.

    Raises:
        NotImplementedError: Re-raised when the pinning calls fail for a tensor that is not a
            TorchAO tensor.
    """
    try:
        if tensor.is_pinned():
            return tensor
        return tensor.pin_memory()
    except NotImplementedError:
        if not _is_torchao_tensor(tensor):
            raise
        # Some legacy TorchAO tensor subclasses do not implement aten.is_pinned.
        return tensor

@staticmethod
def _to_cpu(tensor, low_cpu_mem_usage):
    """Copy ``tensor`` to the CPU and, unless low CPU memory usage is requested, pin it.

    Args:
        tensor: The tensor to move to the CPU.
        low_cpu_mem_usage: When truthy, the CPU copy is returned without being pinned.

    Returns:
        The CPU copy of ``tensor``; passed through ``ModuleGroup._pin_memory`` when
        ``low_cpu_mem_usage`` is falsy.
    """
    # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
    # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
    cpu_tensor = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
    if low_cpu_mem_usage:
        return cpu_tensor
    # Route through the shared helper rather than calling `.pin_memory()` directly, so tensors
    # that cannot be pinned are handled in one place.
    return ModuleGroup._pin_memory(cpu_tensor)

def _init_cpu_param_dict(self):
cpu_param_dict = {}
Expand All @@ -202,7 +215,7 @@ def _init_cpu_param_dict(self):
def _pinned_memory_tensors(self):
try:
pinned_dict = {
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
param: self._pin_memory(tensor)
for param, tensor in self.cpu_param_dict.items()
}
yield pinned_dict
Expand Down
24 changes: 24 additions & 0 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,30 @@ def test_sequential_cpu_offload(self):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)

def test_group_offloading_torchao_int8wo_v1(self):
    """Check that group offloading a version-1 int8 weight-only transformer reproduces the plain output."""
    quantization_config = TorchAoConfig(Int8WeightOnlyConfig(version=1))
    inputs = self.get_dummy_tensor_inputs(torch_device)

    # Reference run: the transformer is moved to the device without enabling group offload.
    reference_model = self.get_dummy_components(quantization_config)["transformer"].to(torch_device)
    with torch.no_grad():
        output_without_offloading = reference_model(**inputs)[0]
    del reference_model
    backend_empty_cache(torch_device)
    gc.collect()

    # Each entry is forwarded verbatim as keyword arguments to `enable_group_offload`.
    offload_variants = (
        {"offload_type": "leaf_level"},
        {"offload_type": "leaf_level", "use_stream": True},
        {"offload_type": "block_level", "num_blocks_per_group": 1, "use_stream": True},
    )
    for offload_kwargs in offload_variants:
        offloaded_model = self.get_dummy_components(quantization_config)["transformer"]
        offloaded_model.enable_group_offload(torch_device, **offload_kwargs)
        with torch.no_grad():
            output = offloaded_model(**inputs)[0]
        assert torch.allclose(output_without_offloading, output, atol=1e-3, rtol=1e-3)
        # Release the model before the next variant is built.
        del offloaded_model
        backend_empty_cache(torch_device)
        gc.collect()

@require_torchao_version_greater_or_equal("0.15.0")
def test_aobase_config(self):
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
Expand Down
Loading