From ff26c0ba19306c7882581599cc3ed077a76b7350 Mon Sep 17 00:00:00 2001
From: Howard Zhang
Date: Fri, 27 Mar 2026 15:08:50 -0700
Subject: [PATCH] change minimum version guard for torchao to 0.15.0

---
 .../quantizers/quantization_config.py         |  8 +++----
 .../quantizers/torchao/torchao_quantizer.py   |  8 +++----
 tests/quantization/torchao/test_torchao.py    | 21 ++++++++-----------
 3 files changed, 17 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index 6f5e0c007294..c3d829fde8cf 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -470,8 +470,8 @@ def __init__(
         self.post_init()

     def post_init(self):
-        if is_torchao_version("<=", "0.9.0"):
-            raise ValueError("TorchAoConfig requires torchao > 0.9.0. Please upgrade with `pip install -U torchao`.")
+        if is_torchao_version("<", "0.15.0"):
+            raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")

         from torchao.quantization.quant_api import AOBaseConfig
@@ -495,8 +495,8 @@ def to_dict(self):
     @classmethod
     def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
         """Create configuration from a dictionary."""
-        if not is_torchao_version(">", "0.9.0"):
-            raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
+        if not is_torchao_version(">=", "0.15.0"):
+            raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")

         config_dict = config_dict.copy()
         quant_type = config_dict.pop("quant_type")
diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
index 8fe2ef711046..88b45349daea 100644
--- a/src/diffusers/quantizers/torchao/torchao_quantizer.py
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -113,7 +113,7 @@ def _update_torch_safe_globals():
     is_torch_available()
     and is_torch_version(">=", "2.6.0")
     and is_torchao_available()
-    and is_torchao_version(">=", "0.7.0")
+    and is_torchao_version(">=", "0.15.0")
 ):
     _update_torch_safe_globals()

@@ -168,10 +168,10 @@ def validate_environment(self, *args, **kwargs):
             raise ImportError(
                 "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
             )
-        torchao_version = version.parse(importlib.metadata.version("torch"))
-        if torchao_version < version.parse("0.7.0"):
+        torchao_version = version.parse(importlib.metadata.version("torchao"))
+        if torchao_version < version.parse("0.15.0"):
             raise RuntimeError(
-                f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
+                f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
             )

         self.offload = False
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 92d5bd42ee28..7a05582cbfba 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -14,13 +14,11 @@
 # limitations under the License.

 import gc
-import importlib.metadata
 import tempfile
 import unittest
 from typing import List

 import numpy as np
-from packaging import version
 from parameterized import parameterized
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

@@ -82,18 +80,17 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
         Float8WeightOnlyConfig,
         Int4WeightOnlyConfig,
         Int8DynamicActivationInt8WeightConfig,
+        Int8DynamicActivationIntxWeightConfig,
         Int8WeightOnlyConfig,
+        IntxWeightOnlyConfig,
     )
     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
     from torchao.utils import get_model_size_in_bytes

-    if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.10.0"):
-        from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig
-

 @require_torch
 @require_torch_accelerator
-@require_torchao_version_greater_or_equal("0.14.0")
+@require_torchao_version_greater_or_equal("0.15.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -128,7 +125,7 @@ def test_repr(self):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_accelerator
-@require_torchao_version_greater_or_equal("0.14.0")
+@require_torchao_version_greater_or_equal("0.15.0")
 class TorchAoTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
@@ -527,7 +524,7 @@ def test_sequential_cpu_offload(self):
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)

-    @require_torchao_version_greater_or_equal("0.9.0")
+    @require_torchao_version_greater_or_equal("0.15.0")
     def test_aobase_config(self):
         quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
         components = self.get_dummy_components(quantization_config)
@@ -540,7 +537,7 @@ def test_aobase_config(self):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_accelerator
-@require_torchao_version_greater_or_equal("0.14.0")
+@require_torchao_version_greater_or_equal("0.15.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"
@@ -650,7 +647,7 @@ def test_aobase_config(self):
         self._check_serialization_expected_slice(quant_type, expected_slice, device)


-@require_torchao_version_greater_or_equal("0.14.0")
+@require_torchao_version_greater_or_equal("0.15.0")
 class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
@@ -696,7 +693,7 @@ def test_torch_compile_with_group_offload_leaf(self, use_stream):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_accelerator
-@require_torchao_version_greater_or_equal("0.14.0")
+@require_torchao_version_greater_or_equal("0.15.0")
 @slow
 @nightly
 class SlowTorchAoTests(unittest.TestCase):
@@ -854,7 +851,7 @@ def test_memory_footprint_int8wo(self):

 @require_torch
 @require_torch_accelerator
-@require_torchao_version_greater_or_equal("0.14.0")
+@require_torchao_version_greater_or_equal("0.15.0")
 @slow
 @nightly
 class SlowTorchAoPreserializedModelTests(unittest.TestCase):