Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
],
}

_import_structure["quantizers.quantization_config"].append("NunchakuLiteQuantizationConfig")

try:
if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
raise OptionalDependencyNotAvailable()
Expand Down Expand Up @@ -968,6 +970,7 @@
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
from .quantizers import PipelineQuantizationConfig
from .quantizers.quantization_config import NunchakuLiteQuantizationConfig

try:
if not is_bitsandbytes_available():
Expand Down
1 change: 0 additions & 1 deletion src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,6 @@ def load_single_file_checkpoint(
)

checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)

# some checkpoints contain the model state dict under a "state_dict" key
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/quantizers/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .modelopt import NVIDIAModelOptQuantizer
from .nunchaku import NunchakuLiteQuantizer
from .quantization_config import (
AutoRoundConfig,
BitsAndBytesConfig,
GGUFQuantizationConfig,
NunchakuLiteQuantizationConfig,
NVIDIAModelOptConfig,
QuantizationConfigMixin,
QuantizationMethod,
Expand All @@ -44,6 +46,7 @@
"torchao": TorchAoHfQuantizer,
"modelopt": NVIDIAModelOptQuantizer,
"auto-round": AutoRoundQuantizer,
"nunchaku_lite": NunchakuLiteQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
Expand All @@ -54,6 +57,7 @@
"torchao": TorchAoConfig,
"modelopt": NVIDIAModelOptConfig,
"auto-round": AutoRoundConfig,
"nunchaku_lite": NunchakuLiteQuantizationConfig,
}


Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/quantizers/nunchaku/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .nunchaku_quantizer import NunchakuLiteQuantizer

67 changes: 67 additions & 0 deletions src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from ..base import DiffusersQuantizer
from .utils import (
check_strict_state_dict_match,
replace_with_nunchaku_linear,
)


if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin


from ...utils import is_kernels_available, logging


logger = logging.get_logger(__name__)


class NunchakuLiteQuantizer(DiffusersQuantizer):
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.compute_dtype = quantization_config.compute_dtype
self.pre_quantized = quantization_config.pre_quantized

def validate_environment(self, *args, **kwargs):
if not is_kernels_available():
raise ImportError(
"Loading Nunchaku checkpoints requires the Hugging Face `kernels` package. "
"Install it with `pip install kernels`."
)

def update_torch_dtype(self, torch_dtype):
if torch_dtype is None:
torch_dtype = self.compute_dtype
return torch_dtype

def _process_model_before_weight_loading(
self,
model: "ModelMixin",
state_dict: dict[str, Any] | None = None,
metadata: dict[str, str] | None = None,
**kwargs,
):
quantization_config = self.quantization_config.to_dict()
num_replaced = replace_with_nunchaku_linear(model, quantization_config, self.compute_dtype)

if state_dict is not None:
check_strict_state_dict_match(model, state_dict)
logger.info(f"Applied Nunchaku quantization config with {num_replaced} targets.")

def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
return model

@property
def is_serializable(self):
return False

@property

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should set is_compileable() property too:

def is_compileable(self) -> bool:

def is_trainable(self) -> bool:
return False

@property
def is_compileable(self) -> bool:
return True
Loading
Loading