Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 130 additions & 41 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import hashlib
import json
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, replace
Expand All @@ -21,8 +22,9 @@

import safetensors.torch
import torch
from safetensors import safe_open

from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ..utils import get_logger, is_accelerate_available, is_torchao_available, is_torchao_version
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook

Expand All @@ -32,6 +34,15 @@
from accelerate.utils import send_to_device


if is_torchao_available():
if is_torchao_version(">=", "0.15.0"):
from torchao.prototype.safetensors.safetensors_support import (
flatten_tensor_state_dict,
unflatten_tensor_state_dict,
)
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao


logger = get_logger(__name__) # pylint: disable=invalid-name


Expand Down Expand Up @@ -146,26 +157,28 @@ def __init__(
self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False

all_tensors = []
for module in self.modules:
all_tensors.extend(list(module.parameters()))
all_tensors.extend(list(module.buffers()))
all_tensors.extend(self.parameters)
all_tensors.extend(self.buffers)
all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates

self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
self._torchao_disk_key_remap: dict[str, str] = {}
Comment on lines +160 to +169
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to hold it outside of the self.offload_to_disk_path no?


if self.offload_to_disk_path is not None:
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
self.group_id = group_id if group_id is not None else str(id(self))
short_hash = _compute_group_hash(self.group_id)
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")

all_tensors = []
for module in self.modules:
all_tensors.extend(list(module.parameters()))
all_tensors.extend(list(module.buffers()))
all_tensors.extend(self.parameters)
all_tensors.extend(self.buffers)
all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates

self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
self.cpu_param_dict = {}
else:
self.cpu_param_dict = self._init_cpu_param_dict()

self._has_torchao_tensors = any(_is_torchao_tensor(tensor) for tensor in self.tensor_to_key)

self._torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
Expand All @@ -179,6 +192,26 @@ def _to_cpu(tensor, low_cpu_mem_usage):
t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
return t if low_cpu_mem_usage else t.pin_memory()

@staticmethod
def _get_torchao_subset_metadata_for_unflatten(metadata):
tensor_names = metadata.get("tensor_names")
if tensor_names is None:
return None

try:
tensor_names = json.loads(tensor_names)
except (TypeError, json.JSONDecodeError):
return None

dotted_tensor_names = [name for name in tensor_names if "." in name]
if len(dotted_tensor_names) == 0:
return None

return {
"tensor_names": json.dumps(dotted_tensor_names),
**{name: metadata[name] for name in dotted_tensor_names if name in metadata},
}

def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
Expand Down Expand Up @@ -238,19 +271,79 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None)
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, default_stream)

def _check_disk_offload_torchao(self):
all_tensors = list(self.tensor_to_key.keys())
has_torchao = any(_is_torchao_tensor(t) for t in all_tensors)
if has_torchao:
raise ValueError(
"Disk offloading is not supported for TorchAO quantized tensors because safetensors "
"cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
"setting `offload_to_disk_path`."
def _get_disk_state_dict(self):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This diff seems complicated to me. I think the code would simplify a bit to try to aim for a pattern like:

if has_torchao:
    handle_for_torchao_safetensors_group_offloading_with_disk()
else:
    keep_the_existing_code

Once this pattern is established, we could look into what needs to be factored out in utilities.

If it helps, I wouldn't mind doing it in a separate PR.

tensors_to_save = {
key: (
tensor.to(self.offload_device) if _is_torchao_tensor(tensor) else tensor.data.to(self.offload_device)
)
for tensor, key in self.tensor_to_key.items()
}

metadata = {}
if self._has_torchao_tensors and is_torchao_version(">=", "0.15.0"):
tensors_for_flatten = {}
self._torchao_disk_key_remap = {}
for key, tensor in tensors_to_save.items():
if _is_torchao_tensor(tensor) and "." not in key:
flattened_key = f"{key}.weight"
self._torchao_disk_key_remap[key] = flattened_key
tensors_for_flatten[flattened_key] = tensor
else:
tensors_for_flatten[key] = tensor

def _onload_from_disk(self):
self._check_disk_offload_torchao()
flattened_state_dict = flatten_tensor_state_dict(tensors_for_flatten)
if isinstance(flattened_state_dict, tuple):
tensors_to_save, metadata = flattened_state_dict

return tensors_to_save, metadata

def _load_disk_state_dict(self, device):
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)

if not self._has_torchao_tensors or not is_torchao_version(">=", "0.15.0"):
return loaded_tensors

with safe_open(self.safetensors_file_path, framework="pt") as f:
metadata = f.metadata() or {}

if is_metadata_torchao(metadata):
try:
reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict(loaded_tensors, metadata)
loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict}
except Exception as error:
logger.warning(
"Failed to unflatten TorchAO state dict metadata from disk; falling back to raw tensors."
)
logger.debug(error)

subset_metadata = self._get_torchao_subset_metadata_for_unflatten(metadata)
if subset_metadata is not None:
try:
reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict(
loaded_tensors, subset_metadata
)
loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict}
except Exception as subset_error:
logger.debug("Failed to unflatten subset of TorchAO metadata; using raw tensors for onload.")
logger.debug(subset_error)

# Support legacy in-memory tensor keys used by GroupOffloading when
# flattening introduced dot-based names to satisfy TorchAO's safetensors API.
for original_key, flattened_key in self._torchao_disk_key_remap.items():
if original_key not in loaded_tensors and flattened_key in loaded_tensors:
loaded_tensors[original_key] = loaded_tensors.pop(flattened_key)

return loaded_tensors

def _release_onload_tensors(self):
    """Free RAM/VRAM held by onloaded tensors by rebinding each tracked tensor
    to a placeholder on the offload device.
    """
    target_device = self.offload_device
    for tensor in self.tensor_to_key:
        if _is_torchao_tensor(tensor):
            # TorchAO subclass tensors cannot have `.data` reassigned directly;
            # swap in a copy moved to the offload device instead.
            _swap_torchao_tensor(tensor, tensor.to(target_device))
        else:
            tensor.data = torch.empty_like(tensor.data, device=target_device)

def _onload_from_disk(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
Expand All @@ -259,23 +352,22 @@ def _onload_from_disk(self):
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None

with context:
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
device = str(self.onload_device) if self.stream is None else "cpu"
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
loaded_tensors = self._load_disk_state_dict(device=device)

if self.stream is not None:
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
pinned_memory = {
tensor_obj: loaded_tensors[self.tensor_to_key[tensor_obj]].pin_memory()
for tensor_obj in self.tensor_to_key
}
self._process_tensors_from_modules(pinned_memory, default_stream=current_stream)
else:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
for tensor_obj in self.tensor_to_key:
self._transfer_tensor_to_device(
tensor_obj,
loaded_tensors[self.tensor_to_key[tensor_obj]],
default_stream=None,
)

def _onload_from_memory(self):
if self.stream is not None:
Expand All @@ -293,24 +385,21 @@ def _onload_from_memory(self):
self._process_tensors_from_modules(None)

def _offload_to_disk(self):
self._check_disk_offload_torchao()

# TODO: we can potentially optimize this code path by checking if the _all_ the desired
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
# we perform a write.
# Check if the file has been saved in this session or if it already exists on disk.
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
tensors_to_save, metadata = self._get_disk_state_dict()
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path, metadata=metadata)

# The group is now considered offloaded to disk for the rest of the session.
self._is_offloaded_to_disk = True

# We do this to free up the RAM which is still holding the tensor data.
for tensor_obj in self.tensor_to_key.keys():
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
self._release_onload_tensors()

def _offload_to_memory(self):
if self.stream is not None:
Expand Down
3 changes: 3 additions & 0 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,9 @@ def _load_shard_file(
disable_mmap=False,
):
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
if hf_quantizer is not None and hasattr(hf_quantizer, "get_reconstructed_state_dict"):
state_dict = hf_quantizer.get_reconstructed_state_dict(state_dict)

mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
Expand Down
39 changes: 36 additions & 3 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,17 @@ def save_pretrained(

# Save the model
state_dict = model_to_save.state_dict()
quantization_metadata = {}
if safe_serialization and hf_quantizer is not None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it also include a check if the quantizer is of TorchAO type just to be explicit?

get_state_dict_and_metadata = getattr(hf_quantizer, "get_state_dict_and_metadata", None)
if callable(get_state_dict_and_metadata):
state_dict_and_metadata = get_state_dict_and_metadata(model_to_save)
else:
state_dict_and_metadata = model_to_save.state_dict()
if isinstance(state_dict_and_metadata, tuple):
state_dict, quantization_metadata = state_dict_and_metadata
else:
state_dict = state_dict_and_metadata
Comment on lines +767 to +772
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do we get a model_to_save.state_dict() that is of type tuple?


if use_flashpack:
if is_flashpack_available():
Expand Down Expand Up @@ -803,15 +814,21 @@ def save_pretrained(
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
metadata = dict(state_dict_split.metadata)
metadata.update(quantization_metadata)
metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()}
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
safetensors.torch.save_file(shard, filepath, metadata=metadata)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're saving all metadata, then I'd restrict it to torchao only.

else:
torch.save(shard, filepath)

if state_dict_split.is_sharded:
metadata = dict(state_dict_split.metadata)
metadata.update(quantization_metadata)
metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()}
index = {
"metadata": state_dict_split.metadata,
"metadata": metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
Expand Down Expand Up @@ -1367,11 +1384,27 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None
else:
loaded_keys = list(state_dict.keys())

checkpoint_files = resolved_model_file
if hf_quantizer is not None:
if hasattr(hf_quantizer, "set_metadata"):
hf_quantizer.set_metadata(checkpoint_files)
quantized_weight_names = []
if hasattr(hf_quantizer, "get_weight_names"):
quantized_weight_names = hf_quantizer.get_weight_names()
if quantized_weight_names:
loaded_keys = list(quantized_weight_names)

if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
model=model,
device_map=device_map,
keep_in_fp32_modules=keep_in_fp32_modules,
checkpoint_files=checkpoint_files,
)

if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
is_parallel_loading_enabled = False
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is that?


# Now that the model is loaded, we can determine the device_map
device_map = _determine_device_map(
model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
Expand Down
Loading
Loading