Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,8 @@ def __init__(
self.post_init()

def post_init(self):
if is_torchao_version("<=", "0.9.0"):
raise ValueError("TorchAoConfig requires torchao > 0.9.0. Please upgrade with `pip install -U torchao`.")
if is_torchao_version("<", "0.15.0"):
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")

from torchao.quantization.quant_api import AOBaseConfig

Expand All @@ -495,8 +495,8 @@ def to_dict(self):
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
"""Create configuration from a dictionary."""
if not is_torchao_version(">", "0.9.0"):
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
if not is_torchao_version(">=", "0.15.0"):
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")

Expand Down
8 changes: 4 additions & 4 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _update_torch_safe_globals():
is_torch_available()
and is_torch_version(">=", "2.6.0")
and is_torchao_available()
and is_torchao_version(">=", "0.7.0")
and is_torchao_version(">=", "0.15.0")
):
_update_torch_safe_globals()

Expand Down Expand Up @@ -168,10 +168,10 @@ def validate_environment(self, *args, **kwargs):
raise ImportError(
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
)
torchao_version = version.parse(importlib.metadata.version("torch"))
if torchao_version < version.parse("0.7.0"):
torchao_version = version.parse(importlib.metadata.version("torchao"))
if torchao_version < version.parse("0.15.0"):
raise RuntimeError(
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
)

self.offload = False
Expand Down
21 changes: 9 additions & 12 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@
# limitations under the License.

import gc
import importlib.metadata
import tempfile
import unittest
from typing import List

import numpy as np
from packaging import version
from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

Expand Down Expand Up @@ -82,18 +80,17 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
Float8WeightOnlyConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8DynamicActivationIntxWeightConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
)
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torchao.utils import get_model_size_in_bytes

if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.10.0"):
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig


@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Expand Down Expand Up @@ -128,7 +125,7 @@ def test_repr(self):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
Expand Down Expand Up @@ -527,7 +524,7 @@ def test_sequential_cpu_offload(self):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)

@require_torchao_version_greater_or_equal("0.9.0")
@require_torchao_version_greater_or_equal("0.15.0")
def test_aobase_config(self):
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config)
Expand All @@ -540,7 +537,7 @@ def test_aobase_config(self):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"

Expand Down Expand Up @@ -650,7 +647,7 @@ def test_aobase_config(self):
self._check_serialization_expected_slice(quant_type, expected_slice, device)


@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
Expand Down Expand Up @@ -696,7 +693,7 @@ def test_torch_compile_with_group_offload_leaf(self, use_stream):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
Expand Down Expand Up @@ -854,7 +851,7 @@ def test_memory_footprint_int8wo(self):

@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
Expand Down
Loading