diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 49509cbf04b9..58674cc7f2a1 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -22,7 +22,7 @@ import safetensors.torch import torch -from ..utils import get_logger, is_accelerate_available, is_torchao_available +from ..utils import get_logger, is_accelerate_available, is_torchao_available, is_torchao_version from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS from .hooks import HookRegistry, ModelHook @@ -31,6 +31,15 @@ from accelerate.hooks import AlignDevicesHook, CpuOffload from accelerate.utils import send_to_device +if is_torchao_available() and is_torchao_version(">=", "0.15.0"): + from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, + ) +else: + flatten_tensor_state_dict = None + unflatten_tensor_state_dict = None + logger = get_logger(__name__) # pylint: disable=invalid-name @@ -162,8 +171,15 @@ def __init__( self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()} + self._has_torchao = any(_is_torchao_tensor(t) for t in all_tensors) + if self._has_torchao and flatten_tensor_state_dict is None: + raise ImportError( + "Disk offloading of TorchAO-quantized tensors requires torchao>=0.15.0. Either " + "upgrade torchao or run group offloading without `offload_to_disk_path`." + ) self.cpu_param_dict = {} else: + self._has_torchao = False self.cpu_param_dict = self._init_cpu_param_dict() self._torch_accelerator_module = ( @@ -238,18 +254,9 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None) source = pinned_memory[buffer] if pinned_memory else buffer.data self._transfer_tensor_to_device(buffer, source, default_stream) - def _check_disk_offload_torchao(self): - all_tensors = list(self.tensor_to_key.keys()) - has_torchao = any(_is_torchao_tensor(t) for t in all_tensors) - if has_torchao: - raise ValueError( - "Disk offloading is not supported for TorchAO quantized tensors because safetensors " - "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not " - "setting `offload_to_disk_path`." - ) - def _onload_from_disk(self): - self._check_disk_offload_torchao() + if self._has_torchao: + return self._onload_from_disk_torchao() if self.stream is not None: # Wait for previous Host->Device transfer to complete @@ -277,6 +284,25 @@ def _onload_from_disk(self): for key, tensor_obj in self.key_to_tensor.items(): tensor_obj.data = loaded_tensors[key] + def _onload_from_disk_torchao(self): + """Synchronous direct-to-device load for groups containing TorchAO subclass tensors.""" + with safetensors.safe_open(self.safetensors_file_path, framework="pt") as f: + metadata = f.metadata() or {} + flat = safetensors.torch.load_file(self.safetensors_file_path, device=str(self.onload_device)) + unflattened, leftover = unflatten_tensor_state_dict(flat, metadata) + if leftover: + raise ValueError( + f"Group offload unflatten left unprocessed entries: {sorted(leftover.keys())[:10]}" + ) + # Strip the `.weight` suffix we added before flatten on the save side. + for full_key, tensor in unflattened.items(): + key = full_key[: -len(".weight")] if full_key.endswith(".weight") else full_key + tensor_obj = self.key_to_tensor[key] + if _is_torchao_tensor(tensor_obj): + _swap_torchao_tensor(tensor_obj, tensor) + else: + tensor_obj.data = tensor + def _onload_from_memory(self): if self.stream is not None: # Wait for previous Host->Device transfer to complete @@ -293,8 +319,6 @@ def _onload_from_memory(self): self._process_tensors_from_modules(None) def _offload_to_disk(self): - self._check_disk_offload_torchao() - # TODO: we can potentially optimize this code path by checking if the _all_ the desired # safetensor files exist on the disk and if so, skip this step entirely, reducing IO # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not @@ -302,15 +326,34 @@ def _offload_to_disk(self): # Check if the file has been saved in this session or if it already exists on disk. if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) - tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()} - safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) + if self._has_torchao: + # Append `.weight` so torchao's flatten helper sees the subclass boundary. + state = { + f"{key}.weight": (tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()) + for tensor, key in self.tensor_to_key.items() + } + flat, metadata = flatten_tensor_state_dict(state) + safetensors.torch.save_file(flat, self.safetensors_file_path, metadata=metadata) + else: + tensors_to_save = { + key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items() + } + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) # The group is now considered offloaded to disk for the rest of the session. self._is_offloaded_to_disk = True - # We do this to free up the RAM which is still holding the up tensor data. + # Free GPU memory. Torchao subclasses need `_swap_torchao_tensor` because `.data =` is a no-op + # on their inner storages. + offload_device = ( + torch.device(self.offload_device) if isinstance(self.offload_device, str) else self.offload_device + ) for tensor_obj in self.tensor_to_key.keys(): - tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) + if _is_torchao_tensor(tensor_obj): + if tensor_obj.device != offload_device: + _swap_torchao_tensor(tensor_obj, tensor_obj.to(offload_device)) + else: + tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) def _offload_to_memory(self): if self.stream is not None: diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 04642ad5d401..52cf96d24197 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -357,6 +357,8 @@ def _load_shard_file( disable_mmap=False, ): state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap) + if hf_quantizer is not None and getattr(hf_quantizer, "metadata", None): + state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, hf_quantizer.metadata) mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0423b7287193..1e3a52292b02 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -758,7 +758,14 @@ def save_pretrained( model_to_save.save_config(save_directory) # Save the model - state_dict = model_to_save.state_dict() + safetensors_metadata = {"format": "pt"} + if hf_quantizer is not None: + state_dict, quantizer_metadata = hf_quantizer.get_state_dict_and_metadata( + model_to_save, safe_serialization=safe_serialization + ) + safetensors_metadata.update(quantizer_metadata) + else: + state_dict = model_to_save.state_dict() if use_flashpack: if is_flashpack_available(): @@ -805,7 +812,7 @@ def save_pretrained( if safe_serialization: # At some point we will need to deal better with save_function (used for TPU and other distributed # joyfulness), but for now this enough. - safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) + safetensors.torch.save_file(shard, filepath, metadata=safetensors_metadata) else: torch.save(shard, filepath) @@ -1368,6 +1375,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None loaded_keys = list(state_dict.keys()) if hf_quantizer is not None: + hf_quantizer.set_metadata(resolved_model_file) + loaded_keys = hf_quantizer.update_loaded_keys(loaded_keys) hf_quantizer.preprocess_model( model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules ) diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py index b0988284b648..5884477601e8 100644 --- a/src/diffusers/quantizers/base.py +++ b/src/diffusers/quantizers/base.py @@ -235,6 +235,26 @@ def _process_model_after_weight_loading(self, model, **kwargs): ... @abstractmethod def is_serializable(self): ... + def get_state_dict_and_metadata( + self, model: "ModelMixin", safe_serialization: bool = False + ) -> tuple[dict[str, Any], dict[str, str]]: + """Save-time hook: return `(state_dict, safetensors_metadata)`.""" + return model.state_dict(), {} + + def set_metadata(self, checkpoint_files: list[str]) -> None: + """Load-time hook: read whatever per-shard safetensors metadata the quantizer needs.""" + return None + + def update_state_dict_with_metadata( + self, state_dict: dict[str, Any], metadata: dict[str, str] + ) -> dict[str, Any]: + """Load-time hook: transform a shard's state dict before per-parameter dispatch.""" + return state_dict + + def update_loaded_keys(self, loaded_keys: list[str]) -> list[str]: + """Load-time hook: rewrite checkpoint key names to match the post-transform state dict.""" + return loaded_keys + @property @abstractmethod def is_trainable(self): ... diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 3a20dca88ecf..c7cfc1e297af 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -72,6 +72,38 @@ if is_torchao_available(): from torchao.quantization import quantize_ +if is_torchao_available() and is_torchao_version(">=", "0.15.0"): + from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, + ) + from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao +else: + flatten_tensor_state_dict = None + unflatten_tensor_state_dict = None + is_metadata_torchao = None + + +_SAFETENSORS_SUPPORTED_CONFIGS = ( + "Float8DynamicActivationFloat8WeightConfig", + "Float8WeightOnlyConfig", + "Int4WeightOnlyConfig", + "Int8DynamicActivationInt8WeightConfig", + "Int8DynamicActivationIntxWeightConfig", + "Int8WeightOnlyConfig", + "IntxWeightOnlyConfig", +) + +# Longest first so `_weight_scale_and_zero` matches before `_weight_scale`. +_TORCHAO_FLAT_SUFFIXES = ( + "_weight_scale_and_zero", + "_weight_per_tensor_scale", + "_weight_act_pre_scale", + "_weight_zero_point", + "_weight_qdata", + "_weight_scale", +) + def _update_torch_safe_globals(): safe_globals = [ @@ -337,20 +369,89 @@ def _process_model_before_weight_loading( def _process_model_after_weight_loading(self, model: "ModelMixin"): return model - def is_serializable(self, safe_serialization=None): - # TODO(aryan): needs to be tested - if safe_serialization: - logger.warning( - "torchao quantized model does not support safe serialization, please set `safe_serialization` to False." - ) - return False + def set_metadata(self, checkpoint_files: list[str]) -> None: + """Collect torchao's flatten metadata from each safetensors shard header.""" + self.metadata = {} + self._pending_flat = {} + if not checkpoint_files: + return + from safetensors import safe_open + + for checkpoint_file in checkpoint_files: + if not isinstance(checkpoint_file, str) or not checkpoint_file.endswith(".safetensors"): + continue + try: + with safe_open(checkpoint_file, framework="pt") as f: + shard_metadata = f.metadata() or {} + except Exception as e: + logger.debug(f"Could not read safetensors metadata from {checkpoint_file}: {e}") + continue + self.metadata.update(shard_metadata) + + def update_loaded_keys(self, loaded_keys: list[str]) -> list[str]: + """Collapse `._weight_*` flat-suffix names back to `.weight`.""" + metadata = getattr(self, "metadata", None) + if not metadata or is_metadata_torchao is None or not is_metadata_torchao(metadata): + return loaded_keys + rewritten = [] + seen = set() + for key in loaded_keys: + target = key + for suffix in _TORCHAO_FLAT_SUFFIXES: + if key.endswith(suffix): + target = key[: -len(suffix)] + "weight" + break + if target not in seen: + seen.add(target) + rewritten.append(target) + return rewritten + + def update_state_dict_with_metadata( + self, state_dict: dict[str, "torch.Tensor"], metadata: dict[str, str] + ) -> dict[str, "torch.Tensor"]: + """Per-shard unflatten with a `_pending_flat` buffer for layers split across shards.""" + if ( + unflatten_tensor_state_dict is None + or is_metadata_torchao is None + or not metadata + or not is_metadata_torchao(metadata) + ): + return state_dict + + # Carve out dot-less keys (top-level plain params like PixArt's `scale_shift_table`). + # torchao's unflatten does `tensor_name.rsplit(".", 1)` on every entry in metadata's + # `tensor_names` list and crashes on names without a dot. Also strip them from a copy + # of the metadata so torchao's loop never touches them. + import json + + try: + tensor_names = json.loads(metadata.get("tensor_names", "[]")) + except (ValueError, TypeError): + tensor_names = [] + dotless = [n for n in tensor_names if "." not in n] + if dotless: + metadata = {k: v for k, v in metadata.items() if k not in dotless} + metadata["tensor_names"] = json.dumps([n for n in tensor_names if "." in n]) + + passthrough = {k: state_dict[k] for k in dotless if k in state_dict} + to_unflatten = {k: v for k, v in state_dict.items() if k not in passthrough} + + pending = getattr(self, "_pending_flat", None) or {} + if pending: + to_unflatten = {**pending, **to_unflatten} + unflattened, leftover = unflatten_tensor_state_dict(to_unflatten, metadata) + self._pending_flat = leftover or {} + return {**passthrough, **unflattened} + @property + def is_serializable(self): _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse( "0.25.0" ) if not _is_torchao_serializable: logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ") + return False if self.offload and self.quantization_config.modules_to_not_convert is None: logger.warning( @@ -361,6 +462,34 @@ def is_serializable(self, safe_serialization=None): return _is_torchao_serializable + def get_state_dict_and_metadata( + self, model: "ModelMixin", safe_serialization: bool = False + ) -> tuple[dict[str, "torch.Tensor"], dict[str, str]]: + """Flatten torchao tensor subclasses; raises if config is unsupported or v1.""" + if not safe_serialization or flatten_tensor_state_dict is None: + return model.state_dict(), {} + + if not is_torchao_version(">=", "0.15.0"): + raise ValueError( + "Saving a torchao quantized model with `safe_serialization=True` requires torchao>=0.15.0. " + "Either upgrade torchao or pass `safe_serialization=False`." + ) + quant_type = self.quantization_config.quant_type + cfg_name = quant_type.__class__.__name__ + if cfg_name not in _SAFETENSORS_SUPPORTED_CONFIGS: + raise ValueError( + f"torchao config `{cfg_name}` does not support safetensors serialization yet. " + f"Supported configs: {_SAFETENSORS_SUPPORTED_CONFIGS}. " + "Pass `safe_serialization=False` to save with pickle instead." + ) + if getattr(quant_type, "version", 2) != 2: + raise ValueError( + f"torchao config `{cfg_name}` was constructed with `version=1`, which produces the deprecated " + "`AffineQuantizedTensor` subclass that cannot be serialized to safetensors. Reconstruct it as " + f"`{cfg_name}(version=2, ...)` or pass `safe_serialization=False`." + ) + return flatten_tensor_state_dict(model.state_dict()) + _TRAINABLE_QUANTIZATION_CONFIGS = ( "Int8WeightOnlyConfig", "Int8DynamicActivationInt8WeightConfig", diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 8a811cfc1c73..d4f7bca8e8d3 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -640,6 +640,86 @@ def test_aobase_config(self): self._test_original_model_expected_slice(quant_type, expected_slice) self._check_serialization_expected_slice(quant_type, expected_slice, device) + def _check_safetensors_save(self, quant_type, device): + import os + + import safetensors + from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao + + from diffusers.utils.constants import SAFETENSORS_WEIGHTS_NAME + + quantized_model = self.get_dummy_model(quant_type, device) + inputs = self.get_dummy_tensor_inputs(torch_device) + original_output = quantized_model(**inputs)[0] + original_slice = original_output.flatten()[-9:].detach().float().cpu().numpy() + + with tempfile.TemporaryDirectory() as tmp_dir: + quantized_model.save_pretrained(tmp_dir, safe_serialization=True) + + shard_path = os.path.join(tmp_dir, SAFETENSORS_WEIGHTS_NAME) + self.assertTrue(os.path.exists(shard_path), "expected a safetensors shard to be written") + + # Safetensors header must carry torchao's flatten metadata alongside the format key. + with safetensors.safe_open(shard_path, framework="pt") as f: + metadata = f.metadata() or {} + self.assertEqual(metadata.get("format"), "pt") + self.assertGreater(len(metadata), 1, "expected torchao metadata alongside the format key") + self.assertTrue(is_metadata_torchao(metadata), f"safetensors header is not torchao-shaped: {metadata}") + + # Round-trip through from_pretrained: this exercises set_metadata + update_state_dict_with_metadata. + loaded = FluxTransformer2DModel.from_pretrained(tmp_dir, torch_dtype=torch.bfloat16).to(device=torch_device) + + loaded_output = loaded(**inputs)[0] + loaded_slice = loaded_output.flatten()[-9:].detach().float().cpu().numpy() + self.assertIsInstance(loaded.transformer_blocks[0].ff.net[2].weight, TorchAOBaseTensor) + self.assertTrue( + numpy_cosine_similarity_distance(original_slice, loaded_slice) < 1e-3, + f"reloaded outputs diverge from original (slice diff): {original_slice} vs {loaded_slice}", + ) + + def test_safetensors_save_int_a16w8(self): + # torchao's safetensors flatten helper only supports the version=2 tensor subclasses + # (Int8Tensor / Int4Tensor / …). version=1 produces the deprecated AffineQuantizedTensor + # which `flatten_tensor_state_dict` does not handle. + self._check_safetensors_save(Int8WeightOnlyConfig(version=2), torch_device) + + def test_group_offload_to_disk_int_a16w8(self): + """Group offload-to-disk for a torchao-quantized model: round-trip and output parity.""" + import os + + from diffusers.hooks import apply_group_offloading + + quantized = self.get_dummy_model(Int8WeightOnlyConfig(version=2), torch_device) + inputs = self.get_dummy_tensor_inputs(torch_device) + baseline = quantized(**inputs)[0].flatten()[-9:].detach().float().cpu().numpy() + + offloaded = self.get_dummy_model(Int8WeightOnlyConfig(version=2), torch_device) + with tempfile.TemporaryDirectory() as tmp_dir: + apply_group_offloading( + offloaded, + onload_device=torch_device, + offload_device="cpu", + offload_type="block_level", + num_blocks_per_group=1, + offload_to_disk_path=tmp_dir, + ) + offloaded_out = offloaded(**inputs)[0].flatten()[-9:].detach().float().cpu().numpy() + # The disk cache should now exist with at least one group_*.safetensors shard. + shards = [f for f in os.listdir(tmp_dir) if f.startswith("group_") and f.endswith(".safetensors")] + self.assertGreater(len(shards), 0, "expected at least one group shard on disk") + # After the forward pass each group has been offloaded back to CPU; the torchao subclass + # weight should not be lingering on the accelerator (this is the whole point of disk offload). + self.assertEqual( + offloaded.transformer_blocks[0].ff.net[2].weight.device.type, + "cpu", + "torchao subclass weight should be on cpu after group offload-to-disk", + ) + + self.assertTrue( + numpy_cosine_similarity_distance(baseline, offloaded_out) < 1e-3, + f"group-offloaded output diverges from baseline: {baseline} vs {offloaded_out}", + ) + @require_torchao_version_greater_or_equal("0.15.0") class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):