diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 81b36e113df4..d5870aafd808 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -64,6 +64,8 @@
     ],
 }
 
+_import_structure["quantizers.quantization_config"].append("NunchakuLiteQuantizationConfig")
+
 try:
     if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
         raise OptionalDependencyNotAvailable()
@@ -968,6 +970,7 @@
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .configuration_utils import ConfigMixin
     from .quantizers import PipelineQuantizationConfig
+    from .quantizers.quantization_config import NunchakuLiteQuantizationConfig
 
     try:
         if not is_bitsandbytes_available():
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py
index a10bf0cdcb3f..03f2347792ea 100644
--- a/src/diffusers/quantizers/auto.py
+++ b/src/diffusers/quantizers/auto.py
@@ -22,10 +22,12 @@
 from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
 from .gguf import GGUFQuantizer
 from .modelopt import NVIDIAModelOptQuantizer
+from .nunchaku import NunchakuLiteQuantizer
 from .quantization_config import (
     AutoRoundConfig,
     BitsAndBytesConfig,
     GGUFQuantizationConfig,
+    NunchakuLiteQuantizationConfig,
     NVIDIAModelOptConfig,
     QuantizationConfigMixin,
     QuantizationMethod,
@@ -44,6 +46,7 @@
     "torchao": TorchAoHfQuantizer,
     "modelopt": NVIDIAModelOptQuantizer,
     "auto-round": AutoRoundQuantizer,
+    "nunchaku_lite": NunchakuLiteQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -54,6 +57,7 @@
     "torchao": TorchAoConfig,
     "modelopt": NVIDIAModelOptConfig,
     "auto-round": AutoRoundConfig,
+    "nunchaku_lite": NunchakuLiteQuantizationConfig,
 }
 
 
diff --git a/src/diffusers/quantizers/nunchaku/__init__.py b/src/diffusers/quantizers/nunchaku/__init__.py
new file mode 100644
index 000000000000..a8b9aa70a781
--- /dev/null
+++ b/src/diffusers/quantizers/nunchaku/__init__.py
@@ -0,0 +1,2 @@
+from .nunchaku_quantizer import NunchakuLiteQuantizer
+
diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
new file mode 100644
index 000000000000..d1ae0fdd3531
--- /dev/null
+++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from ..base import DiffusersQuantizer
+from .utils import (
+    check_strict_state_dict_match,
+    replace_with_nunchaku_linear,
+)
+
+
+if TYPE_CHECKING:
+    from ...models.modeling_utils import ModelMixin
+
+
+from ...utils import is_kernels_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class NunchakuLiteQuantizer(DiffusersQuantizer):
+    """Quantizer that loads pre-quantized Nunchaku Lite checkpoints.
+
+    Before weights are loaded, every module named in the config targets is swapped for a Nunchaku linear layer.
+    """
+
+    def __init__(self, quantization_config, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+        self.compute_dtype = quantization_config.compute_dtype
+        self.pre_quantized = quantization_config.pre_quantized
+
+    def validate_environment(self, *args, **kwargs):
+        if not is_kernels_available():
+            raise ImportError(
+                "Loading Nunchaku checkpoints requires the Hugging Face `kernels` package. "
+                "Install it with `pip install kernels`."
+            )
+
+    def update_torch_dtype(self, torch_dtype):
+        if torch_dtype is None:
+            torch_dtype = self.compute_dtype
+        return torch_dtype
+
+    def _process_model_before_weight_loading(
+        self,
+        model: "ModelMixin",
+        state_dict: dict[str, Any] | None = None,
+        metadata: dict[str, str] | None = None,
+        **kwargs,
+    ):
+        """Patch the target modules and, if a state dict is given, require its keys to match the patched model."""
+        quantization_config = self.quantization_config.to_dict()
+        num_replaced = replace_with_nunchaku_linear(model, quantization_config, self.compute_dtype)
+
+        if state_dict is not None:
+            check_strict_state_dict_match(model, state_dict)
+        logger.info(f"Applied Nunchaku quantization config with {num_replaced} targets.")
+
+    def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
+        return model
+
+    @property
+    def is_serializable(self):
+        return False
+
+    @property
+    def is_trainable(self) -> bool:
+        return False
+
+    @property
+    def is_compileable(self) -> bool:
+        return True
diff --git a/src/diffusers/quantizers/nunchaku/utils.py b/src/diffusers/quantizers/nunchaku/utils.py
new file mode 100644
index 000000000000..1cae5cc5f479
--- /dev/null
+++ b/src/diffusers/quantizers/nunchaku/utils.py
@@ -0,0 +1,382 @@
+from __future__ import annotations
+
+import math
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+
+_HF_KERNEL_REPO = "rootonchair/nunchaku-lite-kernels"
+_HF_KERNEL_VERSION = 1
+
+
+_ops = None
+
+
+def _get_ops():
+    """Load the kernel ops through `kernels.get_kernel` on first use and cache them in the module-level `_ops`."""
+    global _ops
+    if _ops is None:
+        from kernels import get_kernel
+
+        _ops = get_kernel(_HF_KERNEL_REPO, version=_HF_KERNEL_VERSION, trust_remote_code=True).ops
+    return _ops
+
+
+def _gemm_w4a4(
+    act: torch.Tensor,
+    wgt: torch.Tensor,
+    out: torch.Tensor,
+    ascales: torch.Tensor,
+    wscales: torch.Tensor,
+    lora_act_in: torch.Tensor,
+    lora_up: torch.Tensor,
+    bias: torch.Tensor | None,
+    act_unsigned: bool,
+    lora_scales: list[float],
+    fp4: bool,
+    alpha: torch.Tensor | None,
+    wcscales: torch.Tensor | None,
+) -> None:
+    """Call the kernel's `gemm_w4a4` op, passing `None`, `False` or `0` for the arguments not exposed here."""
+    _get_ops().gemm_w4a4(
+        act,
+        wgt,
+        out,
+        None,
+        ascales,
+        wscales,
+        None,
+        None,
+        lora_act_in,
+        lora_up,
+        None,
+        None,
+        None,
+        None,
+        None,
+        bias,
+        None,
+        None,
+        None,
+        act_unsigned,
+        lora_scales,
+        False,
+        fp4,
+        alpha,
+        wcscales,
+        None,
+        None,
+        None,
+        0,
+    )
+
+
+def replace_with_nunchaku_linear(
+    model: nn.Module, quantization_config: dict[str, Any], compute_dtype: torch.dtype
+) -> int:
+    """Swap the configured target modules for Nunchaku linear layers and return how many were replaced.
+
+    Raises `ValueError` when no module was replaced.
+    """
+    num_replaced = 0
+    svdq_config = quantization_config.get("svdq_w4a4")
+    awq_config = quantization_config.get("awq_w4a16")
+
+    if svdq_config is not None:
+        num_replaced += _replace_quantize_targets(model, "svdq_w4a4", svdq_config, compute_dtype)
+    if awq_config is not None:
+        num_replaced += _replace_quantize_targets(model, "awq_w4a16", awq_config, compute_dtype)
+    if num_replaced == 0:
+        raise ValueError(
+            "Nunchaku compact quantization config must include `svdq_w4a4.targets` or `awq_w4a16.targets`."
+        )
+
+    return num_replaced
+
+
+class SVDQW4A4Linear(nn.Module):
+    """Linear layer holding SVDQ W4A4 weights; `forward` runs the kernel's activation-quantize and GEMM ops."""
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        rank: int = 32,
+        bias: bool = True,
+        precision: str = "int4",
+        group_size: int = 64,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: str | torch.device | None = None,
+        act_unsigned: bool = False,
+    ):
+        super().__init__()
+        if device is None:
+            device = torch.device("cpu")
+
+        if precision not in {"int4", "nvfp4"}:
+            raise ValueError(f"Invalid Nunchaku SVDQ precision: {precision!r}.")
+        if group_size <= 0:
+            raise ValueError(f"Nunchaku SVDQ group_size must be positive, got {group_size}.")
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.rank = rank
+        self.precision = precision
+        self.group_size = group_size
+        self.torch_dtype = torch_dtype
+        self.act_unsigned = act_unsigned
+
+        self.qweight = nn.Parameter(
+            torch.empty(out_features, in_features // 2, dtype=torch.int8, device=device), requires_grad=False
+        )
+        self.bias = (
+            nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=False)
+            if bias
+            else None
+        )
+        self.wscales = nn.Parameter(
+            torch.empty(
+                in_features // group_size,
+                out_features,
+                dtype=torch_dtype if precision == "int4" else torch.float8_e4m3fn,
+                device=device,
+            ),
+            requires_grad=False,
+        )
+        self.smooth_factor = nn.Parameter(
+            torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False
+        )
+        self.smooth_factor_orig = nn.Parameter(
+            torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False
+        )
+        self.proj_down = nn.Parameter(torch.empty(in_features, rank, dtype=torch_dtype, device=device), requires_grad=False)
+        self.proj_up = nn.Parameter(torch.empty(out_features, rank, dtype=torch_dtype, device=device), requires_grad=False)
+
+        if precision == "nvfp4":
+            self.wcscales = nn.Parameter(torch.ones(out_features, dtype=torch_dtype, device=device), requires_grad=False)
+            self.wtscale = nn.Parameter(torch.ones(1, dtype=torch_dtype, device=device), requires_grad=False)
+        else:
+            self.wcscales = None
+            self.wtscale = None
+
+    def forward(self, x: torch.Tensor, output: torch.Tensor | None = None) -> torch.Tensor:
+        # Reject mismatched inputs up front, as AWQW4A16Linear.forward does, before buffers are sized from them.
+        if x.shape[-1] != self.in_features:
+            raise ValueError(
+                f"SVDQW4A4Linear expected input last dimension {self.in_features}, got shape {tuple(x.shape)}."
+            )
+
+        original_shape = x.shape
+        channels = x.shape[-1]
+        x = x.reshape(-1, channels)
+        rows = x.shape[0]
+        if output is None:
+            output = torch.empty(rows, self.out_features, dtype=self.torch_dtype, device=x.device)
+
+        pad_size = 256
+        batch_size_pad = math.ceil(x.shape[0] / pad_size) * pad_size
+        quantized_x = torch.empty(batch_size_pad, channels // 2, dtype=torch.uint8, device=x.device)
+        if self.precision == "nvfp4":
+            ascales = torch.empty(channels // 16, batch_size_pad, dtype=torch.float8_e4m3fn, device=x.device)
+        else:
+            ascales = torch.empty(channels // 64, batch_size_pad, dtype=x.dtype, device=x.device)
+        lora_act = torch.empty(batch_size_pad, self.rank, dtype=torch.float32, device=x.device)
+
+        _get_ops().quantize_w4a4_act_fuse_lora(
+            x,
+            quantized_x,
+            ascales,
+            self.proj_down,
+            lora_act,
+            self.smooth_factor,
+            False,
+            self.precision == "nvfp4",
+        )
+        lora_scales = [1.0] * math.ceil(self.rank / 16)
+        _gemm_w4a4(
+            quantized_x,
+            self.qweight,
+            output,
+            ascales,
+            self.wscales,
+            lora_act,
+            self.proj_up,
+            self.bias,
+            self.act_unsigned,
+            lora_scales,
+            self.precision == "nvfp4",
+            self.wtscale,
+            self.wcscales,
+        )
+        return output.reshape(*original_shape[:-1], self.out_features)
+
+
+class AWQW4A16Linear(nn.Module):
+    """Linear layer holding AWQ W4A16 weights; `forward` dispatches to the kernel's GEMM or chunked GEMV op."""
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        group_size: int = 64,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: str | torch.device | None = None,
+    ):
+        super().__init__()
+        if device is None:
+            device = torch.device("cpu")
+        if group_size != 64:
+            raise ValueError(f"Nunchaku AWQ W4A16 currently supports group_size=64 only, got {group_size}.")
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.group_size = group_size
+
+        self.qweight = nn.Parameter(
+            torch.empty(out_features // 4, in_features // 2, dtype=torch.int32, device=device), requires_grad=False
+        )
+        self.bias = (
+            nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=False)
+            if bias
+            else None
+        )
+        self.wscales = nn.Parameter(
+            torch.empty(in_features // group_size, out_features, dtype=torch_dtype, device=device), requires_grad=False
+        )
+        self.wzeros = nn.Parameter(
+            torch.empty(in_features // group_size, out_features, dtype=torch_dtype, device=device), requires_grad=False
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.shape[-1] != self.in_features:
+            raise ValueError(
+                f"AWQW4A16Linear expected input last dimension {self.in_features}, got shape {tuple(x.shape)}."
+            )
+
+        output_shape = (*x.shape[:-1], self.out_features)
+        x_flat = x.reshape(-1, self.in_features).contiguous()
+        if x_flat.shape[0] == 0:
+            output = x.new_empty(output_shape)
+        elif self._use_gemm(x_flat.shape[0]):
+            output = _get_ops().awq_gemm_w4a16_g64_int32(x_flat, self.qweight, self.wscales, self.wzeros).reshape(
+                output_shape
+            )
+        else:
+            output = self._forward_gemv_chunks(x_flat, _get_ops().gemv_awq).reshape(output_shape)
+
+        if self.bias is not None:
+            output = output + self.bias.view([1] * (output.ndim - 1) + [-1])
+        return output
+
+    def _use_gemm(self, rows: int) -> bool:
+        return rows >= 16 and self.in_features % 64 == 0 and self.out_features % 128 == 0
+
+    def _forward_gemv_chunks(self, x_flat: torch.Tensor, gemv) -> torch.Tensor:
+        outputs = []
+        for start in range(0, x_flat.shape[0], 8):
+            chunk = x_flat[start : start + 8]
+            outputs.append(
+                gemv(
+                    chunk,
+                    self.qweight,
+                    self.wscales,
+                    self.wzeros,
+                    chunk.shape[0],
+                    self.out_features,
+                    self.in_features,
+                    64,
+                )
+            )
+        return torch.cat(outputs, dim=0)
+
+
+def _replace_quantize_targets(model: nn.Module, op: str, raw: Any, compute_dtype: torch.dtype) -> int:
+    """Replace each module named in `raw["targets"]` with the Nunchaku layer for `op` and return the count."""
+    if op not in ("svdq_w4a4", "awq_w4a16"):
+        # Fail before touching the model; otherwise `replacement` would be unbound inside the loop.
+        raise ValueError(f"Unsupported Nunchaku quantization op: {op!r}.")
+    precision = raw["precision"]
+    group_size = raw["group_size"]
+    targets = raw["targets"]
+    rank = raw["rank"] if op == "svdq_w4a4" else 0
+
+    for target in targets:
+        try:
+            module = model.get_submodule(target)
+        except AttributeError as exc:
+            raise ValueError(f"Nunchaku target {target!r} does not exist in the model.") from exc
+
+        in_features = getattr(module, "in_features", None)
+        out_features = getattr(module, "out_features", None)
+        bias = getattr(module, "bias", None)
+        if not isinstance(in_features, int) or not isinstance(out_features, int):
+            raise TypeError(f"Nunchaku target {target!r} must expose integer in_features/out_features.")
+
+        if op == "svdq_w4a4":
+            replacement = SVDQW4A4Linear(
+                in_features,
+                out_features,
+                rank=rank,
+                bias=bias is not None,
+                precision="nvfp4" if precision == "fp4" else precision,
+                group_size=group_size,
+                torch_dtype=compute_dtype,
+                device=_module_device(module),
+            )
+        else:
+            replacement = AWQW4A16Linear(
+                in_features,
+                out_features,
+                bias=bias is not None,
+                group_size=group_size,
+                torch_dtype=compute_dtype,
+                device=_module_device(module),
+            )
+
+        _set_submodule(model, target, replacement)
+
+    return len(targets)
+
+
+def _set_submodule(model: nn.Module, path: str, module: nn.Module) -> None:
+    """Install `module` at dotted `path`; numeric names under `nn.Sequential`/`nn.ModuleList` are set by index."""
+    parent_path, _, child_name = path.rpartition(".")
+    parent = model.get_submodule(parent_path) if parent_path else model
+    if child_name.isdigit() and isinstance(parent, (nn.Sequential, nn.ModuleList)):
+        parent[int(child_name)] = module
+    else:
+        setattr(parent, child_name, module)
+
+
+def _module_device(module: nn.Module) -> torch.device:
+    """Return the device of the module's first direct parameter, or CPU when it has none."""
+    parameter = next(module.parameters(recurse=False), None)
+    if parameter is not None:
+        return parameter.device
+    return torch.device("cpu")
+
+
+def check_strict_state_dict_match(model: nn.Module, state_dict: dict[str, Any]) -> None:
+    """Raise `ValueError` unless `state_dict` has exactly the parameter and buffer names of `model`."""
+    import itertools
+
+    expected_keys = {n for n, _ in itertools.chain(model.named_parameters(), model.named_buffers())}
+    loaded_keys = set(state_dict.keys())
+    missing_keys = sorted(expected_keys - loaded_keys)
+    unexpected_keys = sorted(loaded_keys - expected_keys)
+    if missing_keys or unexpected_keys:
+        message = "Nunchaku checkpoint keys must exactly match the patched model state dict."
+        if missing_keys:
+            message += f" Missing keys: {missing_keys[:10]}"
+            if len(missing_keys) > 10:
+                message += f" and {len(missing_keys) - 10} more"
+            message += "."
+        if unexpected_keys:
+            message += f" Unexpected keys: {unexpected_keys[:10]}"
+            if len(unexpected_keys) > 10:
+                message += f" and {len(unexpected_keys) - 10} more"
+            message += "."
+        raise ValueError(message)
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index 0c98e40ba962..627382e81e4e 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -45,6 +45,7 @@
 class QuantizationMethod(str, Enum):
     BITS_AND_BYTES = "bitsandbytes"
     GGUF = "gguf"
+    NUNCHAKU_LITE = "nunchaku_lite"
     TORCHAO = "torchao"
     QUANTO = "quanto"
     MODELOPT = "modelopt"
@@ -429,6 +430,102 @@ def __init__(self, compute_dtype: "torch.dtype" | None = None):
             self.compute_dtype = torch.float32
 
 
+@dataclass
+class NunchakuLiteQuantizationConfig(QuantizationConfigMixin):
+    """Configuration for loading Nunchaku Lite checkpoints.
+
+    Args:
+        compute_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+            Runtime dtype used by the floating-point buffers in the quantized modules.
+        svdq_w4a4 (`dict`, *optional*):
+            Explicit SVDQ W4A4 target configuration with `precision`, `group_size`, `rank`, and `targets`.
+        awq_w4a16 (`dict`, *optional*):
+            Explicit AWQ W4A16 target configuration with `precision`, `group_size`, and `targets`.
+    """
+
+    def __init__(
+        self,
+        compute_dtype: "torch.dtype" | str | None = None,
+        svdq_w4a4: dict[str, Any] | None = None,
+        awq_w4a16: dict[str, Any] | None = None,
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.NUNCHAKU_LITE
+        if compute_dtype is None:
+            compute_dtype = torch.bfloat16
+        if isinstance(compute_dtype, str):
+            if not hasattr(torch, compute_dtype):
+                raise ValueError(f"Unsupported Nunchaku compute dtype: {compute_dtype!r}.")
+            compute_dtype = getattr(torch, compute_dtype)
+        if not isinstance(compute_dtype, torch.dtype):
+            raise ValueError("Nunchaku compute_dtype must be a string or a torch.dtype.")
+        self.compute_dtype = compute_dtype
+        self.pre_quantized = True
+        self.svdq_w4a4 = svdq_w4a4
+        self.awq_w4a16 = awq_w4a16
+
+        self.post_init()
+
+    def post_init(self):
+        """Validate the `svdq_w4a4` and `awq_w4a16` sections, raising `ValueError` on the first problem found."""
+        if self.svdq_w4a4 is None and self.awq_w4a16 is None:
+            raise ValueError(
+                "Nunchaku compact quantization config must include `svdq_w4a4.targets` or `awq_w4a16.targets`."
+            )
+
+        for op, raw in (("svdq_w4a4", self.svdq_w4a4), ("awq_w4a16", self.awq_w4a16)):
+            if raw is None:
+                continue
+            if not isinstance(raw, dict):
+                raise ValueError(f"Nunchaku compact config section {op!r} must be a JSON object.")
+
+            for key, expected_type in (("precision", str), ("group_size", int), ("targets", list)):
+                if key not in raw:
+                    raise ValueError(f"Nunchaku compact config section {op!r} is missing required field {key!r}.")
+                if not isinstance(raw[key], expected_type):
+                    raise ValueError(
+                        f"Nunchaku compact config section {op!r} field {key!r} must be {expected_type.__name__}."
+                    )
+
+            precision = raw["precision"]
+            group_size = raw["group_size"]
+            targets = raw["targets"]
+            if precision not in ("int4", "fp4"):
+                raise ValueError(f"Unsupported Nunchaku precision {precision!r} for {op!r}.")
+            if group_size <= 0:
+                raise ValueError(f"Nunchaku compact config section {op!r} must have positive group_size.")
+            if not targets:
+                raise ValueError(f"Nunchaku compact config section {op!r} must contain at least one target.")
+            if not all(isinstance(target, str) for target in targets):
+                raise ValueError(f"Nunchaku compact config section {op!r} targets must be strings.")
+
+            if op == "svdq_w4a4":
+                if "rank" not in raw:
+                    raise ValueError(f"Nunchaku compact config section {op!r} is missing required field 'rank'.")
+                if not isinstance(raw["rank"], int):
+                    raise ValueError(f"Nunchaku compact config section {op!r} field 'rank' must be int.")
+                if raw["rank"] < 0:
+                    raise ValueError(f"Nunchaku compact config section {op!r} must have non-negative rank.")
+                expected_group_size = 16 if precision == "fp4" else 64
+                if group_size != expected_group_size:
+                    raise ValueError(
+                        f"Nunchaku SVDQ config with precision={precision!r} requires "
+                        f"group_size={expected_group_size}, got {group_size}."
+                    )
+            else:
+                if precision != "int4":
+                    raise ValueError("Nunchaku AWQ target requires precision='int4'.")
+                # AWQW4A16Linear only accepts group_size=64, so reject other values here instead of at patch time.
+                if group_size != 64:
+                    raise ValueError(f"Nunchaku AWQ config requires group_size=64, got {group_size}.")
+
+    def to_dict(self) -> dict[str, Any]:
+        """Serialize the config, storing `compute_dtype` as its bare name (for example `"bfloat16"`)."""
+        output = super().to_dict()
+        output["compute_dtype"] = str(output["compute_dtype"]).split(".")[1]
+        return output
+
+
 @dataclass
 class TorchAoConfig(QuantizationConfigMixin):
     """This is a config class for torchao quantization/sparsity techniques.
diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py
index 728a7ac80248..11abe12151f0 100644
--- a/tests/models/testing_utils/__init__.py
+++ b/tests/models/testing_utils/__init__.py
@@ -31,6 +31,9 @@
     ModelOptCompileTesterMixin,
     ModelOptConfigMixin,
     ModelOptTesterMixin,
+    NunchakuLiteCompileTesterMixin,
+    NunchakuLiteConfigMixin,
+    NunchakuLiteTesterMixin,
     QuantizationCompileTesterMixin,
     QuantizationTesterMixin,
     QuantoCompileTesterMixin,
@@ -76,6 +79,9 @@
     "ModelOptConfigMixin",
     "ModelOptTesterMixin",
     "ModelTesterMixin",
+    "NunchakuLiteCompileTesterMixin",
+    "NunchakuLiteConfigMixin",
+    "NunchakuLiteTesterMixin",
     "PyramidAttentionBroadcastConfigMixin",
     "PyramidAttentionBroadcastTesterMixin",
     "TaylorSeerCacheConfigMixin",
diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 4849e28fb396..3352e60140da 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -22,13 +22,16 @@
     AutoRoundConfig,
     BitsAndBytesConfig,
     GGUFQuantizationConfig,
+    NunchakuLiteQuantizationConfig,
     NVIDIAModelOptConfig,
     QuantoConfig,
     TorchAoConfig,
 )
+from diffusers.quantizers.nunchaku.utils import AWQW4A16Linear, SVDQW4A4Linear
 from diffusers.utils.import_utils import (
     is_bitsandbytes_available,
     is_gguf_available,
+    is_kernels_available,
     is_nvidia_modelopt_available,
     is_optimum_quanto_available,
     is_torchao_available,
@@ -1399,6 +1402,107 @@ def test_gguf_torch_compile_with_group_offload(self):
         self._test_torch_compile_with_group_offload({"compute_dtype": torch.bfloat16})
 
 
+@pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available.")
+@require_accelerate
+@require_accelerator
+class NunchakuLiteConfigMixin:
+    """
+    Base mixin providing Nunchaku Lite quantization config and model creation.
+
+    Expected class attributes:
+        - model_class: The model class to test
+        - quantized_model_name_or_path: Hub repository ID or local path for the quantized model
+        - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained
+    """
+
+    config_dict = None
+
+    def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
+        kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
+        if config_kwargs is not None:
+            kwargs["quantization_config"] = NunchakuLiteQuantizationConfig(**config_kwargs)
+        kwargs.update(extra_kwargs)
+        return self.model_class.from_pretrained(self.quantized_model_name_or_path, **kwargs)
+
+    def _verify_if_layer_quantized(self, name, module, config_kwargs):
+        assert isinstance(module, (SVDQW4A4Linear, AWQW4A16Linear)), (
+            f"Layer {name} is not a Nunchaku Lite layer, got {type(module)}"
+        )
+
+
+@pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available.")
+@require_accelerate
+@require_accelerator
+class NunchakuLiteTesterMixin(NunchakuLiteConfigMixin, QuantizationTesterMixin):
+    """
+    Mixin class for testing Nunchaku Lite quantization on models.
+
+    Expected class attributes:
+        - model_class: The model class to test
+        - quantized_model_name_or_path: Hub repository ID or local path for the quantized model
+        - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained
+
+    Expected methods to be implemented by subclasses:
+        - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
+    """
+
+    def test_nunchaku_lite_quantization_inference(self):
+        self._test_quantization_inference(self.config_dict)
+
+    def _is_module_quantized(self, module):
+        return isinstance(module, (SVDQW4A4Linear, AWQW4A16Linear))
+
+    def _test_quantized_layers(self, config_kwargs):
+        """Check that loading produces Nunchaku Lite layers, one per configured target when a config is given."""
+        model = self._create_quantized_model(config_kwargs)
+
+        num_quantized_layers = 0
+        for name, module in model.named_modules():
+            if self._is_module_quantized(module):
+                self._verify_if_layer_quantized(name, module, config_kwargs)
+                num_quantized_layers += 1
+
+        assert num_quantized_layers > 0, "No Nunchaku Lite layers found in model"
+
+        # Comparing the count with itself can never fail, so derive the expectation from the configured targets.
+        # When no config is passed the targets are not available here and only the positive check above applies.
+        if config_kwargs is not None:
+            expected_quantized_layers = sum(
+                len(config_kwargs[op]["targets"])
+                for op in ("svdq_w4a4", "awq_w4a16")
+                if config_kwargs.get(op) is not None
+            )
+            assert num_quantized_layers == expected_quantized_layers, (
+                f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers}"
+            )
+
+    def test_nunchaku_lite_quantized_layers(self):
+        self._test_quantized_layers(self.config_dict)
+
+
+@pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available.")
+@require_accelerate
+@require_accelerator
+class NunchakuLiteCompileTesterMixin(NunchakuLiteConfigMixin, QuantizationCompileTesterMixin):
+    """
+    Mixin class for testing torch.compile with Nunchaku Lite quantized models.
+
+    Expected class attributes:
+        - model_class: The model class to test
+        - quantized_model_name_or_path: Hub repository ID or local path for the quantized model
+        - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained
+
+    Expected methods to be implemented by subclasses:
+        - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
+    """
+
+    def test_nunchaku_lite_torch_compile(self):
+        self._test_torch_compile(self.config_dict)
+
+    def test_nunchaku_lite_torch_compile_with_group_offload(self):
+        self._test_torch_compile_with_group_offload(self.config_dict)
+
+
 @is_modelopt
 @require_accelerator
 @require_accelerate
diff --git a/tests/quantization/nunchaku/__init__.py b/tests/quantization/nunchaku/__init__.py
new file mode 100644
index 000000000000..8b137891791f
--- /dev/null
+++ b/tests/quantization/nunchaku/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/quantization/nunchaku/test_nunchaku.py b/tests/quantization/nunchaku/test_nunchaku.py
new file mode 100644
index 000000000000..85928e35b38d
--- /dev/null
+++ b/tests/quantization/nunchaku/test_nunchaku.py
@@ -0,0 +1,164 @@
+import gc
+import json
+import os
+import tempfile
+import unittest
+
+import torch
+from safetensors.torch import save_file
+
+from diffusers import ConfigMixin, ModelMixin, NunchakuLiteQuantizationConfig
+from diffusers.configuration_utils import register_to_config
+from diffusers.quantizers import DiffusersAutoQuantizer
+from diffusers.quantizers.nunchaku.utils import AWQW4A16Linear, SVDQW4A4Linear
+from diffusers.utils import is_kernels_available
+
+from ...testing_utils import backend_empty_cache, nightly, require_accelerator, torch_device
+
+
+class TinyPretrainedModel(ModelMixin, ConfigMixin):
+    """Two-linear-layer model whose `svdq` and `awq` modules serve as quantization targets."""
+
+    config_name = "config.json"
+
+    @register_to_config
+    def __init__(self):
+        super().__init__()
+        self.svdq = torch.nn.Linear(64, 128, bias=True)
+        self.awq = torch.nn.Linear(64, 128, bias=False)
+
+
+def _state_dict(precision="int4"):
+    """Build a random state dict shaped for the patched `TinyPretrainedModel` at the given SVDQ precision."""
+    state_dict = {
+        "svdq.bias": torch.randn(128, dtype=torch.bfloat16),
+        "svdq.proj_down": torch.randn(64, 4, dtype=torch.bfloat16),
+        "svdq.proj_up": torch.randn(128, 4, dtype=torch.bfloat16),
+        "svdq.qweight": torch.randint(-8, 8, (128, 32), dtype=torch.int8),
+        "svdq.smooth_factor": torch.randn(64, dtype=torch.bfloat16),
+        "svdq.smooth_factor_orig": torch.randn(64, dtype=torch.bfloat16),
+        "awq.qweight": torch.randint(-8, 8, (32, 32), dtype=torch.int32),
+        "awq.wscales": torch.randn(1, 128, dtype=torch.bfloat16),
+        "awq.wzeros": torch.randn(1, 128, dtype=torch.bfloat16),
+    }
+    if precision == "fp4":
+        state_dict["svdq.wcscales"] = torch.randn(128, dtype=torch.bfloat16)
+        state_dict["svdq.wscales"] = torch.empty(4, 128, dtype=torch.float8_e4m3fn)
+        state_dict["svdq.wtscale"] = torch.randn(1, dtype=torch.bfloat16)
+    else:
+        state_dict["svdq.wscales"] = torch.randn(1, 128, dtype=torch.bfloat16)
+    return state_dict
+
+
+def _compact_config():
+    """Return config sections targeting `svdq` (fp4 SVDQ) and `awq` (int4 AWQ)."""
+    return {
+        "svdq_w4a4": {
+            "precision": "fp4",
+            "group_size": 16,
+            "rank": 4,
+            "targets": ["svdq"],
+        },
+        "awq_w4a16": {
+            "precision": "int4",
+            "group_size": 64,
+            "targets": ["awq"],
+        },
+    }
+
+
+@nightly
+@require_accelerator
+class NunchakuLiteCudaKernelsTests(unittest.TestCase):
+    def setUp(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_awq_cuda_kernels(self):
+        if torch_device != "cuda":
+            self.skipTest("Nunchaku Lite CUDA kernels test requires CUDA device")
+        if not is_kernels_available():
+            self.skipTest("Nunchaku Lite CUDA kernels test requires kernels")
+
+        torch.manual_seed(0)
+        layer = AWQW4A16Linear(64, 128, bias=True, group_size=64, torch_dtype=torch.bfloat16, device=torch_device)
+        layer.qweight.data = torch.randint(-8, 8, layer.qweight.shape, dtype=torch.int32, device=torch_device)
+        layer.wscales.data = torch.rand(layer.wscales.shape, dtype=torch.bfloat16, device=torch_device)
+        layer.wzeros.data = torch.rand(layer.wzeros.shape, dtype=torch.bfloat16, device=torch_device)
+        layer.bias.data.zero_()
+
+        for shape in [(1, 8, 64), (1, 16, 64)]:
+            x = torch.randn(shape, dtype=torch.bfloat16, device=torch_device)
+            with torch.no_grad():
+                output = layer(x)
+
+            self.assertEqual(output.shape, (*shape[:-1], 128))
+            self.assertFalse(torch.isnan(output).any())
+
+
+class NunchakuLiteBasicTests(unittest.TestCase):
+    model_cls = TinyPretrainedModel
+
+    def test_compact_config_round_trips_dtype_and_targets(self):
+        quantization_config = NunchakuLiteQuantizationConfig(compute_dtype=torch.bfloat16, **_compact_config())
+        config_dict = quantization_config.to_dict()
+
+        self.assertEqual(config_dict["compute_dtype"], "bfloat16")
+        self.assertEqual(config_dict["svdq_w4a4"]["precision"], "fp4")
+
+        reloaded_config = NunchakuLiteQuantizationConfig.from_dict(config_dict)
+        self.assertEqual(reloaded_config.compute_dtype, torch.bfloat16)
+        self.assertEqual(reloaded_config.svdq_w4a4["targets"], ["svdq"])
+
+    def test_compact_config_replaces_svdq_and_awq_without_state_dict(self):
+        model = self.model_cls()
+        quantizer = DiffusersAutoQuantizer.from_config(
+            NunchakuLiteQuantizationConfig(compute_dtype=torch.bfloat16, **_compact_config())
+        )
+
+        quantizer.preprocess_model(model)
+
+        self.assertIsInstance(model.svdq, SVDQW4A4Linear)
+        self.assertIsInstance(model.awq, AWQW4A16Linear)
+        self.assertEqual(model.svdq.precision, "nvfp4")
+        self.assertEqual(model.svdq.rank, 4)
+        self.assertIsNotNone(model.svdq.bias)
+        self.assertIsNone(model.awq.bias)
+
+    @unittest.skipIf(not is_kernels_available(), "Nunchaku Lite from_pretrained requires kernels.")
+    def test_nunchaku_lite_loads_with_from_pretrained(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model = self.model_cls()
+            model.save_config(tmpdir)
+
+            config_path = os.path.join(tmpdir, "config.json")
+            with open(config_path) as handle:
+                config = json.load(handle)
+
+            compact_config = _compact_config()
+            config["quantization_config"] = NunchakuLiteQuantizationConfig(
+                compute_dtype=torch.bfloat16, **compact_config
+            ).to_dict()
+
+            with open(config_path, "w") as handle:
+                json.dump(config, handle)
+
+            # `_state_dict` accepts the same precision names as the config, so pass the configured value through.
+            save_file(
+                _state_dict(precision=compact_config["svdq_w4a4"]["precision"]),
+                os.path.join(tmpdir, "diffusion_pytorch_model.safetensors"),
+            )
+
+            loaded_model = self.model_cls.from_pretrained(tmpdir)
+
+            self.assertIsInstance(loaded_model.svdq, SVDQW4A4Linear)
+            self.assertIsInstance(loaded_model.awq, AWQW4A16Linear)
+            self.assertEqual(loaded_model.svdq.precision, "nvfp4")
+
+
+if __name__ == "__main__":
+    unittest.main()