-
Notifications
You must be signed in to change notification settings - Fork 7k
Incorporate safetensors support to TorchAO #13719
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| # limitations under the License. | ||
|
|
||
| import hashlib | ||
| import json | ||
| import os | ||
| from contextlib import contextmanager, nullcontext | ||
| from dataclasses import dataclass, replace | ||
|
|
@@ -21,8 +22,9 @@ | |
|
|
||
| import safetensors.torch | ||
| import torch | ||
| from safetensors import safe_open | ||
|
|
||
| from ..utils import get_logger, is_accelerate_available, is_torchao_available | ||
| from ..utils import get_logger, is_accelerate_available, is_torchao_available, is_torchao_version | ||
| from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS | ||
| from .hooks import HookRegistry, ModelHook | ||
|
|
||
|
|
@@ -32,6 +34,15 @@ | |
| from accelerate.utils import send_to_device | ||
|
|
||
|
|
||
| if is_torchao_available(): | ||
| if is_torchao_version(">=", "0.15.0"): | ||
| from torchao.prototype.safetensors.safetensors_support import ( | ||
| flatten_tensor_state_dict, | ||
| unflatten_tensor_state_dict, | ||
| ) | ||
| from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao | ||
|
|
||
|
|
||
| logger = get_logger(__name__) # pylint: disable=invalid-name | ||
|
|
||
|
|
||
|
|
@@ -146,26 +157,28 @@ def __init__( | |
| self.offload_to_disk_path = offload_to_disk_path | ||
| self._is_offloaded_to_disk = False | ||
|
|
||
| all_tensors = [] | ||
| for module in self.modules: | ||
| all_tensors.extend(list(module.parameters())) | ||
| all_tensors.extend(list(module.buffers())) | ||
| all_tensors.extend(self.parameters) | ||
| all_tensors.extend(self.buffers) | ||
| all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates | ||
|
|
||
| self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} | ||
| self._torchao_disk_key_remap: dict[str, str] = {} | ||
|
|
||
| if self.offload_to_disk_path is not None: | ||
| # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well. | ||
| self.group_id = group_id if group_id is not None else str(id(self)) | ||
| short_hash = _compute_group_hash(self.group_id) | ||
| self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors") | ||
|
|
||
| all_tensors = [] | ||
| for module in self.modules: | ||
| all_tensors.extend(list(module.parameters())) | ||
| all_tensors.extend(list(module.buffers())) | ||
| all_tensors.extend(self.parameters) | ||
| all_tensors.extend(self.buffers) | ||
| all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates | ||
|
|
||
| self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} | ||
| self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()} | ||
| self.cpu_param_dict = {} | ||
| else: | ||
| self.cpu_param_dict = self._init_cpu_param_dict() | ||
|
|
||
| self._has_torchao_tensors = any(_is_torchao_tensor(tensor) for tensor in self.tensor_to_key) | ||
|
|
||
| self._torch_accelerator_module = ( | ||
| getattr(torch, torch.accelerator.current_accelerator().type) | ||
| if hasattr(torch, "accelerator") | ||
|
|
@@ -179,6 +192,26 @@ def _to_cpu(tensor, low_cpu_mem_usage): | |
| t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu() | ||
| return t if low_cpu_mem_usage else t.pin_memory() | ||
|
|
||
| @staticmethod | ||
| def _get_torchao_subset_metadata_for_unflatten(metadata): | ||
| tensor_names = metadata.get("tensor_names") | ||
| if tensor_names is None: | ||
| return None | ||
|
|
||
| try: | ||
| tensor_names = json.loads(tensor_names) | ||
| except (TypeError, json.JSONDecodeError): | ||
| return None | ||
|
|
||
| dotted_tensor_names = [name for name in tensor_names if "." in name] | ||
| if len(dotted_tensor_names) == 0: | ||
| return None | ||
|
|
||
| return { | ||
| "tensor_names": json.dumps(dotted_tensor_names), | ||
| **{name: metadata[name] for name in dotted_tensor_names if name in metadata}, | ||
| } | ||
|
|
||
| def _init_cpu_param_dict(self): | ||
| cpu_param_dict = {} | ||
| if self.stream is None: | ||
|
|
@@ -238,19 +271,79 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None) | |
| source = pinned_memory[buffer] if pinned_memory else buffer.data | ||
| self._transfer_tensor_to_device(buffer, source, default_stream) | ||
|
|
||
| def _check_disk_offload_torchao(self): | ||
| all_tensors = list(self.tensor_to_key.keys()) | ||
| has_torchao = any(_is_torchao_tensor(t) for t in all_tensors) | ||
| if has_torchao: | ||
| raise ValueError( | ||
| "Disk offloading is not supported for TorchAO quantized tensors because safetensors " | ||
| "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not " | ||
| "setting `offload_to_disk_path`." | ||
| def _get_disk_state_dict(self): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This diff seems complicated to me. I think the code would simplify a bit to try to aim for a pattern like:

    if has_torchao:
        handle_for_torchao_safetensors_group_offloading_with_disk()
    else:
        keep_the_existing_code()

Once this pattern is established, we could look into what needs to be factored out in utilities. If it helps, I wouldn't mind doing it in a separate PR. |
||
| tensors_to_save = { | ||
| key: ( | ||
| tensor.to(self.offload_device) if _is_torchao_tensor(tensor) else tensor.data.to(self.offload_device) | ||
| ) | ||
| for tensor, key in self.tensor_to_key.items() | ||
| } | ||
|
|
||
| metadata = {} | ||
| if self._has_torchao_tensors and is_torchao_version(">=", "0.15.0"): | ||
| tensors_for_flatten = {} | ||
| self._torchao_disk_key_remap = {} | ||
| for key, tensor in tensors_to_save.items(): | ||
| if _is_torchao_tensor(tensor) and "." not in key: | ||
| flattened_key = f"{key}.weight" | ||
| self._torchao_disk_key_remap[key] = flattened_key | ||
| tensors_for_flatten[flattened_key] = tensor | ||
| else: | ||
| tensors_for_flatten[key] = tensor | ||
|
|
||
| def _onload_from_disk(self): | ||
| self._check_disk_offload_torchao() | ||
| flattened_state_dict = flatten_tensor_state_dict(tensors_for_flatten) | ||
| if isinstance(flattened_state_dict, tuple): | ||
| tensors_to_save, metadata = flattened_state_dict | ||
|
|
||
| return tensors_to_save, metadata | ||
|
|
||
| def _load_disk_state_dict(self, device): | ||
| loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device) | ||
|
|
||
| if not self._has_torchao_tensors or not is_torchao_version(">=", "0.15.0"): | ||
| return loaded_tensors | ||
|
|
||
| with safe_open(self.safetensors_file_path, framework="pt") as f: | ||
| metadata = f.metadata() or {} | ||
|
|
||
| if is_metadata_torchao(metadata): | ||
| try: | ||
| reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict(loaded_tensors, metadata) | ||
| loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict} | ||
| except Exception as error: | ||
| logger.warning( | ||
| "Failed to unflatten TorchAO state dict metadata from disk; falling back to raw tensors." | ||
| ) | ||
| logger.debug(error) | ||
|
|
||
| subset_metadata = self._get_torchao_subset_metadata_for_unflatten(metadata) | ||
| if subset_metadata is not None: | ||
| try: | ||
| reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict( | ||
| loaded_tensors, subset_metadata | ||
| ) | ||
| loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict} | ||
| except Exception as subset_error: | ||
| logger.debug("Failed to unflatten subset of TorchAO metadata; using raw tensors for onload.") | ||
| logger.debug(subset_error) | ||
|
|
||
| # Support legacy in-memory tensor keys used by GroupOffloading when | ||
| # flattening introduced dot-based names to satisfy TorchAO's safetensors API. | ||
| for original_key, flattened_key in self._torchao_disk_key_remap.items(): | ||
| if original_key not in loaded_tensors and flattened_key in loaded_tensors: | ||
| loaded_tensors[original_key] = loaded_tensors.pop(flattened_key) | ||
|
|
||
| return loaded_tensors | ||
|
|
||
| def _release_onload_tensors(self): | ||
| for tensor_obj in self.tensor_to_key.keys(): | ||
| if _is_torchao_tensor(tensor_obj): | ||
| placeholder = tensor_obj.to(self.offload_device) | ||
| _swap_torchao_tensor(tensor_obj, placeholder) | ||
| else: | ||
| tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) | ||
|
|
||
| def _onload_from_disk(self): | ||
| if self.stream is not None: | ||
| # Wait for previous Host->Device transfer to complete | ||
| self.stream.synchronize() | ||
|
|
@@ -259,23 +352,22 @@ def _onload_from_disk(self): | |
| current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None | ||
|
|
||
| with context: | ||
| # Load to CPU (if using streams) or directly to target device, pin, and async copy to device | ||
| device = str(self.onload_device) if self.stream is None else "cpu" | ||
| loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device) | ||
| loaded_tensors = self._load_disk_state_dict(device=device) | ||
|
|
||
| if self.stream is not None: | ||
| for key, tensor_obj in self.key_to_tensor.items(): | ||
| pinned_tensor = loaded_tensors[key].pin_memory() | ||
| tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) | ||
| if self.record_stream: | ||
| tensor_obj.data.record_stream(current_stream) | ||
| pinned_memory = { | ||
| tensor_obj: loaded_tensors[self.tensor_to_key[tensor_obj]].pin_memory() | ||
| for tensor_obj in self.tensor_to_key | ||
| } | ||
| self._process_tensors_from_modules(pinned_memory, default_stream=current_stream) | ||
| else: | ||
| onload_device = ( | ||
| self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device | ||
| ) | ||
| loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) | ||
| for key, tensor_obj in self.key_to_tensor.items(): | ||
| tensor_obj.data = loaded_tensors[key] | ||
| for tensor_obj in self.tensor_to_key: | ||
| self._transfer_tensor_to_device( | ||
| tensor_obj, | ||
| loaded_tensors[self.tensor_to_key[tensor_obj]], | ||
| default_stream=None, | ||
| ) | ||
|
|
||
| def _onload_from_memory(self): | ||
| if self.stream is not None: | ||
|
|
@@ -293,24 +385,21 @@ def _onload_from_memory(self): | |
| self._process_tensors_from_modules(None) | ||
|
|
||
| def _offload_to_disk(self): | ||
| self._check_disk_offload_torchao() | ||
|
|
||
| # TODO: we can potentially optimize this code path by checking if the _all_ the desired | ||
| # safetensor files exist on the disk and if so, skip this step entirely, reducing IO | ||
| # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not | ||
| # we perform a write. | ||
| # Check if the file has been saved in this session or if it already exists on disk. | ||
| if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): | ||
| os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) | ||
| tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()} | ||
| safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) | ||
| tensors_to_save, metadata = self._get_disk_state_dict() | ||
| safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path, metadata=metadata) | ||
|
|
||
| # The group is now considered offloaded to disk for the rest of the session. | ||
| self._is_offloaded_to_disk = True | ||
|
|
||
| # We do this to free up the RAM which is still holding up the tensor data. | ||
| for tensor_obj in self.tensor_to_key.keys(): | ||
| tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) | ||
| self._release_onload_tensors() | ||
|
|
||
| def _offload_to_memory(self): | ||
| if self.stream is not None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -759,6 +759,17 @@ def save_pretrained( | |
|
|
||
| # Save the model | ||
| state_dict = model_to_save.state_dict() | ||
| quantization_metadata = {} | ||
| if safe_serialization and hf_quantizer is not None: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should it also include a check if the quantizer is of TorchAO type just to be explicit? |
||
| get_state_dict_and_metadata = getattr(hf_quantizer, "get_state_dict_and_metadata", None) | ||
| if callable(get_state_dict_and_metadata): | ||
| state_dict_and_metadata = get_state_dict_and_metadata(model_to_save) | ||
| else: | ||
| state_dict_and_metadata = model_to_save.state_dict() | ||
| if isinstance(state_dict_and_metadata, tuple): | ||
| state_dict, quantization_metadata = state_dict_and_metadata | ||
| else: | ||
| state_dict = state_dict_and_metadata | ||
|
Comment on lines
+767
to
+772
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Where do we get a |
||
|
|
||
| if use_flashpack: | ||
| if is_flashpack_available(): | ||
|
|
@@ -803,15 +814,21 @@ def save_pretrained( | |
| shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} | ||
| filepath = os.path.join(save_directory, filename) | ||
| if safe_serialization: | ||
| metadata = dict(state_dict_split.metadata) | ||
| metadata.update(quantization_metadata) | ||
| metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()} | ||
| # At some point we will need to deal better with save_function (used for TPU and other distributed | ||
| # joyfulness), but for now this enough. | ||
| safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) | ||
| safetensors.torch.save_file(shard, filepath, metadata=metadata) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we're saving all metadata, then I'd restrict it to torchao only. |
||
| else: | ||
| torch.save(shard, filepath) | ||
|
|
||
| if state_dict_split.is_sharded: | ||
| metadata = dict(state_dict_split.metadata) | ||
| metadata.update(quantization_metadata) | ||
| metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()} | ||
| index = { | ||
| "metadata": state_dict_split.metadata, | ||
| "metadata": metadata, | ||
| "weight_map": state_dict_split.tensor_to_filename, | ||
| } | ||
| save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME | ||
|
|
@@ -1367,11 +1384,27 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None | |
| else: | ||
| loaded_keys = list(state_dict.keys()) | ||
|
|
||
| checkpoint_files = resolved_model_file | ||
| if hf_quantizer is not None: | ||
| if hasattr(hf_quantizer, "set_metadata"): | ||
| hf_quantizer.set_metadata(checkpoint_files) | ||
| quantized_weight_names = [] | ||
| if hasattr(hf_quantizer, "get_weight_names"): | ||
| quantized_weight_names = hf_quantizer.get_weight_names() | ||
| if quantized_weight_names: | ||
| loaded_keys = list(quantized_weight_names) | ||
|
|
||
| if hf_quantizer is not None: | ||
| hf_quantizer.preprocess_model( | ||
| model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules | ||
| model=model, | ||
| device_map=device_map, | ||
| keep_in_fp32_modules=keep_in_fp32_modules, | ||
| checkpoint_files=checkpoint_files, | ||
| ) | ||
|
|
||
| if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO: | ||
| is_parallel_loading_enabled = False | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is that? |
||
|
|
||
| # Now that the model is loaded, we can determine the device_map | ||
| device_map = _determine_device_map( | ||
| model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to hold it outside of the `self.offload_to_disk_path`, no?