79 changes: 61 additions & 18 deletions src/diffusers/hooks/group_offloading.py
@@ -22,7 +22,7 @@
import safetensors.torch
import torch

from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ..utils import get_logger, is_accelerate_available, is_torchao_available, is_torchao_version
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook

@@ -31,6 +31,15 @@
from accelerate.hooks import AlignDevicesHook, CpuOffload
from accelerate.utils import send_to_device

if is_torchao_available() and is_torchao_version(">=", "0.15.0"):
from torchao.prototype.safetensors.safetensors_support import (
flatten_tensor_state_dict,
unflatten_tensor_state_dict,
)
else:
flatten_tensor_state_dict = None
unflatten_tensor_state_dict = None


logger = get_logger(__name__) # pylint: disable=invalid-name

@@ -162,8 +171,15 @@ def __init__(

self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
self._has_torchao = any(_is_torchao_tensor(t) for t in all_tensors)
if self._has_torchao and flatten_tensor_state_dict is None:
raise ImportError(
"Disk offloading of TorchAO-quantized tensors requires torchao>=0.15.0. Either "
"upgrade torchao or run group offloading without `offload_to_disk_path`."
)
self.cpu_param_dict = {}
else:
self._has_torchao = False
self.cpu_param_dict = self._init_cpu_param_dict()

self._torch_accelerator_module = (
@@ -238,18 +254,9 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None)
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, default_stream)

def _check_disk_offload_torchao(self):
all_tensors = list(self.tensor_to_key.keys())
has_torchao = any(_is_torchao_tensor(t) for t in all_tensors)
if has_torchao:
raise ValueError(
"Disk offloading is not supported for TorchAO quantized tensors because safetensors "
"cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
"setting `offload_to_disk_path`."
)

def _onload_from_disk(self):
self._check_disk_offload_torchao()
if self._has_torchao:
return self._onload_from_disk_torchao()

if self.stream is not None:
# Wait for previous Host->Device transfer to complete
Expand Down Expand Up @@ -277,6 +284,25 @@ def _onload_from_disk(self):
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]

def _onload_from_disk_torchao(self):
"""Synchronous direct-to-device load for groups containing TorchAO subclass tensors."""
with safetensors.safe_open(self.safetensors_file_path, framework="pt") as f:
metadata = f.metadata() or {}
flat = safetensors.torch.load_file(self.safetensors_file_path, device=str(self.onload_device))
unflattened, leftover = unflatten_tensor_state_dict(flat, metadata)
if leftover:
raise ValueError(
f"Group offload unflatten left unprocessed entries: {sorted(leftover.keys())[:10]}"
)
# Strip the `.weight` suffix we added before flatten on the save side.
for full_key, tensor in unflattened.items():
key = full_key[: -len(".weight")] if full_key.endswith(".weight") else full_key
tensor_obj = self.key_to_tensor[key]
if _is_torchao_tensor(tensor_obj):
_swap_torchao_tensor(tensor_obj, tensor)
else:
tensor_obj.data = tensor

def _onload_from_memory(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
@@ -293,24 +319,41 @@ def _onload_from_memory(self):
self._process_tensors_from_modules(None)

def _offload_to_disk(self):
self._check_disk_offload_torchao()

# TODO: we can potentially optimize this code path by checking if _all_ the desired
# safetensors files exist on disk and, if so, skipping this step entirely, reducing IO
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
# we perform a write.
# Check if the file has been saved in this session or if it already exists on disk.
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
if self._has_torchao:
# Append `.weight` so torchao's flatten helper sees the subclass boundary.
state = {
f"{key}.weight": (tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu())
for tensor, key in self.tensor_to_key.items()
}
flat, metadata = flatten_tensor_state_dict(state)
safetensors.torch.save_file(flat, self.safetensors_file_path, metadata=metadata)
else:
tensors_to_save = {
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
}
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)

# The group is now considered offloaded to disk for the rest of the session.
self._is_offloaded_to_disk = True

# We do this to free up the RAM which is still holding up the tensor data.
# Free GPU memory. Torchao subclasses need `_swap_torchao_tensor` because `.data =` is a no-op
# on their inner storages.
offload_device = (
torch.device(self.offload_device) if isinstance(self.offload_device, str) else self.offload_device
)
for tensor_obj in self.tensor_to_key.keys():
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
if _is_torchao_tensor(tensor_obj):
if tensor_obj.device != offload_device:
_swap_torchao_tensor(tensor_obj, tensor_obj.to(offload_device))
else:
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

def _offload_to_memory(self):
if self.stream is not None:
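
Taken together, the `group_offloading.py` changes make disk offloading usable with TorchAO subclass tensors instead of raising. A hedged usage sketch, assuming diffusers' `apply_group_offloading` helper and its `offload_to_disk_path` argument (both present in recent diffusers; the checkpoint id is illustrative, and torchao>=0.15.0 is required for the flatten/unflatten helpers):

```python
# Hedged sketch: block-level group offloading with disk spill on a
# torchao-quantized transformer.
import torch
from diffusers import AutoModel
from diffusers.hooks import apply_group_offloading
from torchao.quantization import Float8WeightOnlyConfig, quantize_

model = AutoModel.from_pretrained(
    "org/some-transformer", subfolder="transformer", torch_dtype=torch.bfloat16
)  # hypothetical checkpoint
quantize_(model, Float8WeightOnlyConfig())  # version-2 config, safetensors-friendly

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    offload_to_disk_path="/tmp/group-offload",  # previously raised for torchao tensors
)
```
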
2 changes: 2 additions & 0 deletions src/diffusers/models/model_loading_utils.py
@@ -357,6 +357,8 @@ def _load_shard_file(
disable_mmap=False,
):
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
if hf_quantizer is not None and getattr(hf_quantizer, "metadata", None):
state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, hf_quantizer.metadata)
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
13 changes: 11 additions & 2 deletions src/diffusers/models/modeling_utils.py
@@ -758,7 +758,14 @@ def save_pretrained(
model_to_save.save_config(save_directory)

# Save the model
state_dict = model_to_save.state_dict()
safetensors_metadata = {"format": "pt"}
if hf_quantizer is not None:
state_dict, quantizer_metadata = hf_quantizer.get_state_dict_and_metadata(
model_to_save, safe_serialization=safe_serialization
)
safetensors_metadata.update(quantizer_metadata)
else:
state_dict = model_to_save.state_dict()

if use_flashpack:
if is_flashpack_available():
@@ -805,7 +812,7 @@ def save_pretrained(
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
safetensors.torch.save_file(shard, filepath, metadata=safetensors_metadata)
else:
torch.save(shard, filepath)

@@ -1368,6 +1375,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None
loaded_keys = list(state_dict.keys())

if hf_quantizer is not None:
hf_quantizer.set_metadata(resolved_model_file)
loaded_keys = hf_quantizer.update_loaded_keys(loaded_keys)
hf_quantizer.preprocess_model(
model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
)
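
With the save- and load-time hooks wired through `save_pretrained` and `from_pretrained`, a torchao-quantized model can round-trip through safetensors rather than requiring pickle. A hedged end-to-end sketch (checkpoint id is illustrative; assumes `TorchAoConfig` accepts a torchao config object, as the quantizer code below expects, and torchao>=0.15.0):

```python
import torch
from diffusers import AutoModel, TorchAoConfig
from torchao.quantization import Float8WeightOnlyConfig

model = AutoModel.from_pretrained(
    "org/some-transformer",  # hypothetical checkpoint
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    quantization_config=TorchAoConfig(quant_type=Float8WeightOnlyConfig()),
)

# get_state_dict_and_metadata flattens the subclass tensors, and the flatten
# metadata is written into each shard's safetensors header.
model.save_pretrained("./quantized-ckpt", safe_serialization=True)

# On the way back in, set_metadata -> update_loaded_keys ->
# update_state_dict_with_metadata (per shard) reassemble the quantized tensors.
reloaded = AutoModel.from_pretrained("./quantized-ckpt", torch_dtype=torch.bfloat16)
```
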
20 changes: 20 additions & 0 deletions src/diffusers/quantizers/base.py
@@ -235,6 +235,26 @@ def _process_model_after_weight_loading(self, model, **kwargs): ...
@abstractmethod
def is_serializable(self): ...

def get_state_dict_and_metadata(
self, model: "ModelMixin", safe_serialization: bool = False
) -> tuple[dict[str, Any], dict[str, str]]:
"""Save-time hook: return `(state_dict, safetensors_metadata)`."""
return model.state_dict(), {}

def set_metadata(self, checkpoint_files: list[str]) -> None:
"""Load-time hook: read whatever per-shard safetensors metadata the quantizer needs."""
return None

def update_state_dict_with_metadata(
self, state_dict: dict[str, Any], metadata: dict[str, str]
) -> dict[str, Any]:
"""Load-time hook: transform a shard's state dict before per-parameter dispatch."""
return state_dict

def update_loaded_keys(self, loaded_keys: list[str]) -> list[str]:
"""Load-time hook: rewrite checkpoint key names to match the post-transform state dict."""
return loaded_keys

@property
@abstractmethod
def is_trainable(self): ...
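
These four hooks are deliberately no-ops in the base class, so existing backends are untouched; a quantizer opts in by overriding them. During `from_pretrained` the order is: `set_metadata` once with the resolved shard files, `update_loaded_keys` on the checkpoint key list, then `update_state_dict_with_metadata` once per shard (see `_load_shard_file` above). A schematic opt-in with a hypothetical backend; bodies are illustrative and the abstract members are elided:

```python
from typing import Any

from diffusers.quantizers.base import DiffusersQuantizer


class MyBackendQuantizer(DiffusersQuantizer):  # hypothetical backend
    """Sketch of the opt-in hook contract."""

    def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
        # Save time: flatten exotic tensor subclasses to plain tensors and record
        # reassembly info in the safetensors header metadata.
        return model.state_dict(), {"my_backend_format": "v1"}

    def set_metadata(self, checkpoint_files: list[str]) -> None:
        # Load time, once: collect whatever per-shard header metadata is needed.
        self.metadata = {"my_backend_format": "v1"}

    def update_loaded_keys(self, loaded_keys: list[str]) -> list[str]:
        # Load time, once: rename flattened checkpoint keys to the names the
        # post-transform state dict will contain, so key matching still works.
        return loaded_keys

    def update_state_dict_with_metadata(
        self, state_dict: dict[str, Any], metadata: dict[str, str]
    ) -> dict[str, Any]:
        # Load time, per shard: rebuild subclass tensors before parameter dispatch.
        return state_dict
```
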
143 changes: 136 additions & 7 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -72,6 +72,38 @@
if is_torchao_available():
from torchao.quantization import quantize_

if is_torchao_available() and is_torchao_version(">=", "0.15.0"):
from torchao.prototype.safetensors.safetensors_support import (
flatten_tensor_state_dict,
unflatten_tensor_state_dict,
)
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
else:
flatten_tensor_state_dict = None
unflatten_tensor_state_dict = None
is_metadata_torchao = None


_SAFETENSORS_SUPPORTED_CONFIGS = (
"Float8DynamicActivationFloat8WeightConfig",
"Float8WeightOnlyConfig",
"Int4WeightOnlyConfig",
"Int8DynamicActivationInt8WeightConfig",
"Int8DynamicActivationIntxWeightConfig",
"Int8WeightOnlyConfig",
"IntxWeightOnlyConfig",
)

# Longest first so `_weight_scale_and_zero` matches before `_weight_scale`.
_TORCHAO_FLAT_SUFFIXES = (
"_weight_scale_and_zero",
"_weight_per_tensor_scale",
"_weight_act_pre_scale",
"_weight_zero_point",
"_weight_qdata",
"_weight_scale",
)
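
Since the flat-key collapse in `update_loaded_keys` below is a pure string transform, its effect is easy to show in isolation. A small self-contained demo (the layer names are made up; the real method additionally de-duplicates, so all components of one layer collapse to a single `.weight` entry):

```python
# Worked example of the flat-key collapse performed by `update_loaded_keys`.
_TORCHAO_FLAT_SUFFIXES = (
    "_weight_scale_and_zero",
    "_weight_per_tensor_scale",
    "_weight_act_pre_scale",
    "_weight_zero_point",
    "_weight_qdata",
    "_weight_scale",
)


def collapse(key: str) -> str:
    for suffix in _TORCHAO_FLAT_SUFFIXES:
        if key.endswith(suffix):
            return key[: -len(suffix)] + "weight"
    return key


assert collapse("blocks.0.attn.to_q._weight_qdata") == "blocks.0.attn.to_q.weight"
assert collapse("blocks.0.attn.to_q._weight_scale") == "blocks.0.attn.to_q.weight"
assert collapse("blocks.0.norm.bias") == "blocks.0.norm.bias"  # untouched
```
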


def _update_torch_safe_globals():
safe_globals = [
@@ -337,20 +369,89 @@ def _process_model_before_weight_loading(
def _process_model_after_weight_loading(self, model: "ModelMixin"):
return model

def is_serializable(self, safe_serialization=None):
# TODO(aryan): needs to be tested
if safe_serialization:
logger.warning(
"torchao quantized model does not support safe serialization, please set `safe_serialization` to False."
)
return False
def set_metadata(self, checkpoint_files: list[str]) -> None:
"""Collect torchao's flatten metadata from each safetensors shard header."""
self.metadata = {}
self._pending_flat = {}
if not checkpoint_files:
return
from safetensors import safe_open

for checkpoint_file in checkpoint_files:
if not isinstance(checkpoint_file, str) or not checkpoint_file.endswith(".safetensors"):
continue
try:
with safe_open(checkpoint_file, framework="pt") as f:
shard_metadata = f.metadata() or {}
except Exception as e:
logger.debug(f"Could not read safetensors metadata from {checkpoint_file}: {e}")
continue
self.metadata.update(shard_metadata)

def update_loaded_keys(self, loaded_keys: list[str]) -> list[str]:
"""Collapse `<prefix>._weight_*` flat-suffix names back to `<prefix>.weight`."""
metadata = getattr(self, "metadata", None)
if not metadata or is_metadata_torchao is None or not is_metadata_torchao(metadata):
return loaded_keys
rewritten = []
seen = set()
for key in loaded_keys:
target = key
for suffix in _TORCHAO_FLAT_SUFFIXES:
if key.endswith(suffix):
target = key[: -len(suffix)] + "weight"
break
if target not in seen:
seen.add(target)
rewritten.append(target)
return rewritten

def update_state_dict_with_metadata(
self, state_dict: dict[str, "torch.Tensor"], metadata: dict[str, str]
) -> dict[str, "torch.Tensor"]:
"""Per-shard unflatten with a `_pending_flat` buffer for layers split across shards."""
if (
unflatten_tensor_state_dict is None
or is_metadata_torchao is None
or not metadata
or not is_metadata_torchao(metadata)
):
return state_dict

# Carve out dot-less keys (top-level plain params like PixArt's `scale_shift_table`).
# torchao's unflatten does `tensor_name.rsplit(".", 1)` on every entry in metadata's
# `tensor_names` list and crashes on names without a dot. Also strip them from a copy
# of the metadata so torchao's loop never touches them.
import json

try:
tensor_names = json.loads(metadata.get("tensor_names", "[]"))
except (ValueError, TypeError):
tensor_names = []
dotless = [n for n in tensor_names if "." not in n]
if dotless:
metadata = {k: v for k, v in metadata.items() if k not in dotless}
metadata["tensor_names"] = json.dumps([n for n in tensor_names if "." in n])

passthrough = {k: state_dict[k] for k in dotless if k in state_dict}
to_unflatten = {k: v for k, v in state_dict.items() if k not in passthrough}

pending = getattr(self, "_pending_flat", None) or {}
if pending:
to_unflatten = {**pending, **to_unflatten}
unflattened, leftover = unflatten_tensor_state_dict(to_unflatten, metadata)
self._pending_flat = leftover or {}
return {**passthrough, **unflattened}
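
The `_pending_flat` buffer matters when a quantized layer's components (qdata, scale, ...) straddle a shard boundary. A hedged, self-contained demonstration, assuming torchao>=0.15.0 and using `flatten_tensor_state_dict`/`unflatten_tensor_state_dict` exactly as above; the two-"shard" split is simulated by slicing one flat state dict, and `Float8WeightOnlyConfig` (one of the supported configs listed earlier) can be swapped for another if float8 is unavailable in your environment:

```python
import torch
from torchao.prototype.safetensors.safetensors_support import (
    flatten_tensor_state_dict,
    unflatten_tensor_state_dict,
)
from torchao.quantization import Float8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
quantize_(model, Float8WeightOnlyConfig())  # version-2 config in recent torchao

flat, metadata = flatten_tensor_state_dict(model.state_dict())
keys = sorted(flat)
# Put one flat component in the first "shard" and the rest in the second.
shards = [{k: flat[k] for k in keys[:1]}, {k: flat[k] for k in keys[1:]}]

pending = {}
for shard in shards:
    unflattened, leftover = unflatten_tensor_state_dict({**pending, **shard}, metadata)
    pending = leftover or {}  # incomplete layers wait for their remaining components
    print(sorted(unflattened), "| pending:", sorted(pending))
```
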

@property
def is_serializable(self):
_is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
"0.25.0"
)

if not _is_torchao_serializable:
logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
return False

if self.offload and self.quantization_config.modules_to_not_convert is None:
logger.warning(
@@ -361,6 +462,34 @@ def is_serializable(self, safe_serialization=None):

return _is_torchao_serializable

def get_state_dict_and_metadata(
self, model: "ModelMixin", safe_serialization: bool = False
) -> tuple[dict[str, "torch.Tensor"], dict[str, str]]:
"""Flatten torchao tensor subclasses; raises if config is unsupported or v1."""
if not safe_serialization or flatten_tensor_state_dict is None:
return model.state_dict(), {}

if not is_torchao_version(">=", "0.15.0"):
raise ValueError(
"Saving a torchao quantized model with `safe_serialization=True` requires torchao>=0.15.0. "
"Either upgrade torchao or pass `safe_serialization=False`."
)
quant_type = self.quantization_config.quant_type
cfg_name = quant_type.__class__.__name__
if cfg_name not in _SAFETENSORS_SUPPORTED_CONFIGS:
raise ValueError(
f"torchao config `{cfg_name}` does not support safetensors serialization yet. "
f"Supported configs: {_SAFETENSORS_SUPPORTED_CONFIGS}. "
"Pass `safe_serialization=False` to save with pickle instead."
)
if getattr(quant_type, "version", 2) != 2:
raise ValueError(
f"torchao config `{cfg_name}` was constructed with `version=1`, which produces the deprecated "
"`AffineQuantizedTensor` subclass that cannot be serialized to safetensors. Reconstruct it as "
f"`{cfg_name}(version=2, ...)` or pass `safe_serialization=False`."
)
return flatten_tensor_state_dict(model.state_dict())

_TRAINABLE_QUANTIZATION_CONFIGS = (
"Int8WeightOnlyConfig",
"Int8DynamicActivationInt8WeightConfig",