diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 4bf5f886330e..b82c195519bf 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -341,7 +341,7 @@ jobs: additional_deps: ["peft", "kernels"] - backend: "torchao" test_location: "torchao" - additional_deps: [] + additional_deps: ["mslk-cuda"] - backend: "optimum_quanto" test_location: "quanto" additional_deps: [] diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 1aab0b240148..b08bd8d2e37d 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -920,6 +920,7 @@ def test_torchao_device_map(self): """Test that device_map='auto' works correctly with quantization.""" self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) + @pytest.mark.xfail(reason="dequantize is not implemented in torchao") def test_torchao_dequantize(self): """Test that dequantize() works correctly.""" self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index e4e91e52fb80..5eaadf7a8ad2 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -362,6 +362,10 @@ def pretrained_model_kwargs(self): class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin): """TorchAO quantization tests for Flux Transformer.""" + @property + def torch_dtype(self): + return torch.bfloat16 + class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin): @property