diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 49509cbf04b9..5f0a50610804 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -13,6 +13,7 @@ # limitations under the License. import hashlib +import json import os from contextlib import contextmanager, nullcontext from dataclasses import dataclass, replace @@ -21,8 +22,9 @@ import safetensors.torch import torch +from safetensors import safe_open -from ..utils import get_logger, is_accelerate_available, is_torchao_available +from ..utils import get_logger, is_accelerate_available, is_torchao_available, is_torchao_version from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS from .hooks import HookRegistry, ModelHook @@ -32,6 +34,15 @@ from accelerate.utils import send_to_device +if is_torchao_available(): + if is_torchao_version(">=", "0.15.0"): + from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, + ) + from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao + + logger = get_logger(__name__) # pylint: disable=invalid-name @@ -146,26 +157,28 @@ def __init__( self.offload_to_disk_path = offload_to_disk_path self._is_offloaded_to_disk = False + all_tensors = [] + for module in self.modules: + all_tensors.extend(list(module.parameters())) + all_tensors.extend(list(module.buffers())) + all_tensors.extend(self.parameters) + all_tensors.extend(self.buffers) + all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates + + self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} + self._torchao_disk_key_remap: dict[str, str] = {} + if self.offload_to_disk_path is not None: # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well. self.group_id = group_id if group_id is not None else str(id(self)) short_hash = _compute_group_hash(self.group_id) self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors") - - all_tensors = [] - for module in self.modules: - all_tensors.extend(list(module.parameters())) - all_tensors.extend(list(module.buffers())) - all_tensors.extend(self.parameters) - all_tensors.extend(self.buffers) - all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates - - self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} - self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()} self.cpu_param_dict = {} else: self.cpu_param_dict = self._init_cpu_param_dict() + self._has_torchao_tensors = any(_is_torchao_tensor(tensor) for tensor in self.tensor_to_key) + self._torch_accelerator_module = ( getattr(torch, torch.accelerator.current_accelerator().type) if hasattr(torch, "accelerator") @@ -179,6 +192,26 @@ def _to_cpu(tensor, low_cpu_mem_usage): t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu() return t if low_cpu_mem_usage else t.pin_memory() + @staticmethod + def _get_torchao_subset_metadata_for_unflatten(metadata): + tensor_names = metadata.get("tensor_names") + if tensor_names is None: + return None + + try: + tensor_names = json.loads(tensor_names) + except (TypeError, json.JSONDecodeError): + return None + + dotted_tensor_names = [name for name in tensor_names if "." 
in name] + if len(dotted_tensor_names) == 0: + return None + + return { + "tensor_names": json.dumps(dotted_tensor_names), + **{name: metadata[name] for name in dotted_tensor_names if name in metadata}, + } + def _init_cpu_param_dict(self): cpu_param_dict = {} if self.stream is None: @@ -238,19 +271,79 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None) source = pinned_memory[buffer] if pinned_memory else buffer.data self._transfer_tensor_to_device(buffer, source, default_stream) - def _check_disk_offload_torchao(self): - all_tensors = list(self.tensor_to_key.keys()) - has_torchao = any(_is_torchao_tensor(t) for t in all_tensors) - if has_torchao: - raise ValueError( - "Disk offloading is not supported for TorchAO quantized tensors because safetensors " - "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not " - "setting `offload_to_disk_path`." + def _get_disk_state_dict(self): + tensors_to_save = { + key: ( + tensor.to(self.offload_device) if _is_torchao_tensor(tensor) else tensor.data.to(self.offload_device) ) + for tensor, key in self.tensor_to_key.items() + } + + metadata = {} + if self._has_torchao_tensors and is_torchao_version(">=", "0.15.0"): + tensors_for_flatten = {} + self._torchao_disk_key_remap = {} + for key, tensor in tensors_to_save.items(): + if _is_torchao_tensor(tensor) and "." not in key: + flattened_key = f"{key}.weight" + self._torchao_disk_key_remap[key] = flattened_key + tensors_for_flatten[flattened_key] = tensor + else: + tensors_for_flatten[key] = tensor - def _onload_from_disk(self): - self._check_disk_offload_torchao() + flattened_state_dict = flatten_tensor_state_dict(tensors_for_flatten) + if isinstance(flattened_state_dict, tuple): + tensors_to_save, metadata = flattened_state_dict + + return tensors_to_save, metadata + + def _load_disk_state_dict(self, device): + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device) + if not self._has_torchao_tensors or not is_torchao_version(">=", "0.15.0"): + return loaded_tensors + + with safe_open(self.safetensors_file_path, framework="pt") as f: + metadata = f.metadata() or {} + + if is_metadata_torchao(metadata): + try: + reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict(loaded_tensors, metadata) + loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict} + except Exception as error: + logger.warning( + "Failed to unflatten TorchAO state dict metadata from disk; falling back to raw tensors." + ) + logger.debug(error) + + subset_metadata = self._get_torchao_subset_metadata_for_unflatten(metadata) + if subset_metadata is not None: + try: + reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict( + loaded_tensors, subset_metadata + ) + loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict} + except Exception as subset_error: + logger.debug("Failed to unflatten subset of TorchAO metadata; using raw tensors for onload.") + logger.debug(subset_error) + + # Support legacy in-memory tensor keys used by GroupOffloading when + # flattening introduced dot-based names to satisfy TorchAO's safetensors API. 
+ for original_key, flattened_key in self._torchao_disk_key_remap.items(): + if original_key not in loaded_tensors and flattened_key in loaded_tensors: + loaded_tensors[original_key] = loaded_tensors.pop(flattened_key) + + return loaded_tensors + + def _release_onload_tensors(self): + for tensor_obj in self.tensor_to_key.keys(): + if _is_torchao_tensor(tensor_obj): + placeholder = tensor_obj.to(self.offload_device) + _swap_torchao_tensor(tensor_obj, placeholder) + else: + tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) + + def _onload_from_disk(self): if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() @@ -259,23 +352,22 @@ def _onload_from_disk(self): current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None with context: - # Load to CPU (if using streams) or directly to target device, pin, and async copy to device device = str(self.onload_device) if self.stream is None else "cpu" - loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device) + loaded_tensors = self._load_disk_state_dict(device=device) if self.stream is not None: - for key, tensor_obj in self.key_to_tensor.items(): - pinned_tensor = loaded_tensors[key].pin_memory() - tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - tensor_obj.data.record_stream(current_stream) + pinned_memory = { + tensor_obj: loaded_tensors[self.tensor_to_key[tensor_obj]].pin_memory() + for tensor_obj in self.tensor_to_key + } + self._process_tensors_from_modules(pinned_memory, default_stream=current_stream) else: - onload_device = ( - self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device - ) - loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) - for key, tensor_obj in self.key_to_tensor.items(): - tensor_obj.data = loaded_tensors[key] + for tensor_obj in self.tensor_to_key: + self._transfer_tensor_to_device( + tensor_obj, + loaded_tensors[self.tensor_to_key[tensor_obj]], + default_stream=None, + ) def _onload_from_memory(self): if self.stream is not None: @@ -293,8 +385,6 @@ def _onload_from_memory(self): self._process_tensors_from_modules(None) def _offload_to_disk(self): - self._check_disk_offload_torchao() - # TODO: we can potentially optimize this code path by checking if the _all_ the desired # safetensor files exist on the disk and if so, skip this step entirely, reducing IO # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not @@ -302,15 +392,14 @@ def _offload_to_disk(self): # Check if the file has been saved in this session or if it already exists on disk. if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) - tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()} - safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) + tensors_to_save, metadata = self._get_disk_state_dict() + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path, metadata=metadata) # The group is now considered offloaded to disk for the rest of the session. self._is_offloaded_to_disk = True # We do this to free up the RAM which is still holding the up tensor data. 
- for tensor_obj in self.tensor_to_key.keys(): - tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) + self._release_onload_tensors() def _offload_to_memory(self): if self.stream is not None: diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 04642ad5d401..beeee1b498b0 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -357,6 +357,9 @@ def _load_shard_file( disable_mmap=False, ): state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap) + if hf_quantizer is not None and hasattr(hf_quantizer, "get_reconstructed_state_dict"): + state_dict = hf_quantizer.get_reconstructed_state_dict(state_dict) + mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0423b7287193..30bf01da51b2 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -759,6 +759,17 @@ def save_pretrained( # Save the model state_dict = model_to_save.state_dict() + quantization_metadata = {} + if safe_serialization and hf_quantizer is not None: + get_state_dict_and_metadata = getattr(hf_quantizer, "get_state_dict_and_metadata", None) + if callable(get_state_dict_and_metadata): + state_dict_and_metadata = get_state_dict_and_metadata(model_to_save) + else: + state_dict_and_metadata = model_to_save.state_dict() + if isinstance(state_dict_and_metadata, tuple): + state_dict, quantization_metadata = state_dict_and_metadata + else: + state_dict = state_dict_and_metadata if use_flashpack: if is_flashpack_available(): @@ -803,15 +814,21 @@ def save_pretrained( shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} filepath = os.path.join(save_directory, filename) if safe_serialization: + metadata = dict(state_dict_split.metadata) + metadata.update(quantization_metadata) + metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()} # At some point we will need to deal better with save_function (used for TPU and other distributed # joyfulness), but for now this enough. 
- safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) + safetensors.torch.save_file(shard, filepath, metadata=metadata) else: torch.save(shard, filepath) if state_dict_split.is_sharded: + metadata = dict(state_dict_split.metadata) + metadata.update(quantization_metadata) + metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()} index = { - "metadata": state_dict_split.metadata, + "metadata": metadata, "weight_map": state_dict_split.tensor_to_filename, } save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME @@ -1367,11 +1384,27 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None else: loaded_keys = list(state_dict.keys()) + checkpoint_files = resolved_model_file + if hf_quantizer is not None: + if hasattr(hf_quantizer, "set_metadata"): + hf_quantizer.set_metadata(checkpoint_files) + quantized_weight_names = [] + if hasattr(hf_quantizer, "get_weight_names"): + quantized_weight_names = hf_quantizer.get_weight_names() + if quantized_weight_names: + loaded_keys = list(quantized_weight_names) + if hf_quantizer is not None: hf_quantizer.preprocess_model( - model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + model=model, + device_map=device_map, + keep_in_fp32_modules=keep_in_fp32_modules, + checkpoint_files=checkpoint_files, ) + if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO: + is_parallel_loading_enabled = False + # Now that the model is loaded, we can determine the device_map device_map = _determine_device_map( model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 3a20dca88ecf..b7027d4c268e 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -18,6 +18,7 @@ """ import importlib +import json import re import types from typing import TYPE_CHECKING, Any @@ -26,6 +27,7 @@ from ...utils import ( get_module_from_name, + is_safetensors_available, is_torch_available, is_torch_version, is_torchao_available, @@ -41,6 +43,9 @@ if TYPE_CHECKING: from ...models.modeling_utils import ModelMixin +if is_safetensors_available(): + from safetensors import safe_open + if is_torch_available(): import torch @@ -72,6 +77,13 @@ if is_torchao_available(): from torchao.quantization import quantize_ + if is_torchao_version(">=", "0.15.0"): + from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, + ) + from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao + def _update_torch_safe_globals(): safe_globals = [ @@ -154,6 +166,11 @@ class TorchAoHfQuantizer(DiffusersQuantizer): def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) + self._metadata = {} + self._pending_flattened_state_dict = {} + self._loaded_weight_names = set() + self._expected_weight_names = set() + def validate_environment(self, *args, **kwargs): if not is_torchao_available(): raise ImportError( @@ -236,6 +253,76 @@ def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | max_memory = {key: val * 0.9 for key, val in max_memory.items()} return max_memory + def get_state_dict_and_metadata(self, model): + """ + We flatten the state dict of tensor subclasses so that it is 
compatible with the safetensors format. + """ + if not is_torchao_available() or not is_torchao_version(">=", "0.15.0"): + return model.state_dict(), {} + return flatten_tensor_state_dict(model.state_dict()) + + def set_metadata(self, checkpoint_files: list[str]): + if not is_torchao_version(">=", "0.15.0"): + self._metadata = {} + return + + if self.metadata is None: + self.metadata = {} + self._pending_flattened_state_dict = {} + self._loaded_weight_names = set() + self._expected_weight_names = set() + + if len(checkpoint_files) == 0: + return + + if not checkpoint_files[0].endswith(".safetensors"): + self._metadata = {} + return + + metadata = {} + for checkpoint in checkpoint_files: + with safe_open(checkpoint, framework="pt") as f: + metadata.update(f.metadata() or {}) + + self._metadata = metadata if is_metadata_torchao(metadata) else {} + if is_metadata_torchao(self._metadata): + try: + self._expected_weight_names = set(json.loads(self._metadata["tensor_names"])) + except (TypeError, json.JSONDecodeError, UnicodeDecodeError): + self._metadata = {} + self._expected_weight_names = set() + + @property + def metadata(self): + return self._metadata + + @metadata.setter + def metadata(self, value: dict): + self._metadata = value + + def get_reconstructed_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + if not self._metadata or not is_torchao_version(">=", "0.15.0") or not is_metadata_torchao(self._metadata): + return state_dict + + merged_state_dict = {**self._pending_flattened_state_dict, **state_dict} + reconstructed_state_dict, self._pending_flattened_state_dict = unflatten_tensor_state_dict( + merged_state_dict, self._metadata + ) + + self._loaded_weight_names.update(reconstructed_state_dict.keys()) + return reconstructed_state_dict + + def get_weight_conversions(self): + return [] + + def get_weight_names(self): + return self._expected_weight_names if self._expected_weight_names else set() + + def get_weight_reconstruction_pending_keys(self): + if not self._expected_weight_names: + return [] + return sorted(self._expected_weight_names - self._loaded_weight_names) + def check_if_quantized_param( self, model: "ModelMixin", @@ -337,11 +424,12 @@ def _process_model_before_weight_loading( def _process_model_after_weight_loading(self, model: "ModelMixin"): return model - def is_serializable(self, safe_serialization=None): - # TODO(aryan): needs to be tested - if safe_serialization: + @property + def is_serializable(self): + if not is_torchao_version(">=", "0.15.0"): logger.warning( - "torchao quantized model does not support safe serialization, please set `safe_serialization` to False." + "TorchAO quantized model is not serializable with safe serialization without safetensors support " + "from the installed torchao version." ) return False diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 8a811cfc1c73..58913a0fc29f 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -14,6 +14,7 @@ # limitations under the License. 
import gc +import os import tempfile import unittest from typing import List @@ -589,13 +590,32 @@ def _test_original_model_expected_slice(self, quant_type, expected_slice): self.assertTrue(isinstance(weight, TorchAOBaseTensor)) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) - def _check_serialization_expected_slice(self, quant_type, expected_slice, device): + def _check_serialization_expected_slice( + self, quant_type, expected_slice, device, safe_serialization=False, max_shard_size=None, assert_sharded=False + ): + if safe_serialization and getattr(quant_type, "version", None) != 2: + self.skipTest("TorchAO safe serialization tests require quantization config version=2.") + quantized_model = self.get_dummy_model(quant_type, device) + save_kwargs = {"safe_serialization": safe_serialization} + if max_shard_size is not None: + save_kwargs["max_shard_size"] = max_shard_size + with tempfile.TemporaryDirectory() as tmp_dir: - quantized_model.save_pretrained(tmp_dir, safe_serialization=False) + quantized_model.save_pretrained(tmp_dir, **save_kwargs) + if assert_sharded: + shard_files = [f for f in os.listdir(tmp_dir) if f.endswith(".safetensors")] + if max_shard_size is not None: + self.assertTrue(len(shard_files) > 1, "Expected a sharded safe-serialization checkpoint.") + self.assertTrue( + any("index" in f and f.endswith(".json") for f in os.listdir(tmp_dir)), + "Expected an index file for sharded safe checkpoint.", + ) loaded_quantized_model = FluxTransformer2DModel.from_pretrained( - tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False + tmp_dir, + torch_dtype=torch.bfloat16, + use_safetensors=safe_serialization, ).to(device=torch_device) inputs = self.get_dummy_tensor_inputs(torch_device) @@ -605,6 +625,55 @@ def _check_serialization_expected_slice(self, quant_type, expected_slice, device self.assertTrue(isinstance(loaded_quantized_model.proj_out.weight, TorchAOBaseTensor)) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) + def test_int_a8w8_safe_cpu(self): + quant_type = Int8DynamicActivationInt8WeightConfig(version=2) + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + device = "cpu" + self._check_serialization_expected_slice(quant_type, expected_slice, device, safe_serialization=True) + + def test_int_a8w8_safe(self): + quant_type = Int8DynamicActivationInt8WeightConfig(version=2) + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + device = torch_device + self._check_serialization_expected_slice(quant_type, expected_slice, device, safe_serialization=True) + + def test_group_offload_to_disk(self): + quant_type = Int8DynamicActivationInt8WeightConfig(version=2) + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + + quantized_model = self.get_dummy_model(quant_type, torch_device) + + with tempfile.TemporaryDirectory() as offload_to_disk_path: + quantized_model.enable_group_offload( + onload_device=torch_device, + offload_type="leaf_level", + offload_to_disk_path=offload_to_disk_path, + ) + + inputs = self.get_dummy_tensor_inputs(torch_device) + output = quantized_model(**inputs)[0] + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + + self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) + + output = quantized_model(**inputs)[0] + output_slice_2 = 
output.flatten()[-9:].detach().float().cpu().numpy() + + self.assertTrue(numpy_cosine_similarity_distance(output_slice_2, expected_slice) < 1e-3) + + def test_int_a8w8_safe_sharded(self): + quant_type = Int8DynamicActivationInt8WeightConfig(version=2) + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + device = torch_device + self._check_serialization_expected_slice( + quant_type, + expected_slice, + device, + safe_serialization=True, + max_shard_size="16KB", + assert_sharded=True, + ) + def test_int_a8w8_accelerator(self): quant_type = Int8DynamicActivationInt8WeightConfig() expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
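
A minimal end-to-end sketch of the user-facing flow these changes target, assuming torchao>=0.15.0 and a version=2 TorchAO quantization config; the checkpoint id, save path, and offload directory are placeholders, and passing an AOBaseConfig instance to `TorchAoConfig` is assumed to work the same way the tests above exercise it through their helper.

# Hedged sketch, not part of the patch itself: shows safe serialization and
# disk-based group offloading of a TorchAO-quantized diffusers model.
# Assumptions: torchao>=0.15.0, version=2 quant config, placeholder paths/ids.
import torch

from diffusers import FluxTransformer2DModel, TorchAoConfig
from torchao.quantization import Int8DynamicActivationInt8WeightConfig

quantization_config = TorchAoConfig(Int8DynamicActivationInt8WeightConfig(version=2))
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder model id
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# With this patch the TorchAO tensor subclasses are flattened via torchao's
# safetensors helpers and their layout is stored in the safetensors metadata,
# so safe serialization no longer has to be disabled for quantized models.
transformer.save_pretrained("./flux-transformer-int8", safe_serialization=True)
reloaded = FluxTransformer2DModel.from_pretrained(
    "./flux-transformer-int8", torch_dtype=torch.bfloat16, use_safetensors=True
)

# Disk-based group offloading previously raised for TorchAO tensors; it now
# round-trips them through the same flatten/unflatten path on offload/onload.
reloaded.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk_path="./offload",
)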