From 5e5b575fb3413fbaf04a949c3d0fa5796b79e4f4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 25 Mar 2026 09:38:49 +0530 Subject: [PATCH 1/3] fix torchao tests --- tests/models/testing_utils/quantization.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 0f1fbde72485..ec74422741c3 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -177,6 +177,11 @@ def _test_quantization_inference(self, config_kwargs): model_quantized.to(torch_device) inputs = self.get_dummy_inputs() + model_dtype = next(model_quantized.parameters()).dtype + inputs = { + k: v.to(dtype=model_dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v + for k, v in inputs.items() + } output = model_quantized(**inputs, return_dict=False)[0] assert output is not None, "Model output is None" @@ -930,6 +935,7 @@ def test_torchao_device_map(self): """Test that device_map='auto' works correctly with quantization.""" self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) + @pytest.mark.xfail(reason="dequantize is not implemented in torchao") def test_torchao_dequantize(self): """Test that dequantize() works correctly.""" self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) From 4e01e02395145cd79e258cd40ad4ec0d62f4c42c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 25 Mar 2026 09:41:04 +0530 Subject: [PATCH 2/3] add mslk for additional dependencies. --- .github/workflows/nightly_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 416d2af3fc2e..e242b4b57cb0 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -341,7 +341,7 @@ jobs: additional_deps: ["peft", "kernels"] - backend: "torchao" test_location: "torchao" - additional_deps: [] + additional_deps: [mslk-cuda] - backend: "optimum_quanto" test_location: "quanto" additional_deps: [] From d742b19f8b571205582a539342498277724d7887 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 1 May 2026 12:16:22 +0530 Subject: [PATCH 3/3] add dtype --- tests/models/testing_utils/quantization.py | 5 ----- tests/models/transformers/test_models_transformer_flux.py | 4 ++++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 64ab15d272d7..b08bd8d2e37d 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -175,11 +175,6 @@ def _test_quantization_inference(self, config_kwargs): model_quantized.to(torch_device) inputs = self.get_dummy_inputs() - model_dtype = next(model_quantized.parameters()).dtype - inputs = { - k: v.to(dtype=model_dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v - for k, v in inputs.items() - } output = model_quantized(**inputs, return_dict=False)[0] assert output is not None, "Model output is None" diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index e4e91e52fb80..5eaadf7a8ad2 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -362,6 +362,10 @@ def pretrained_model_kwargs(self): class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin): """TorchAO quantization tests for Flux Transformer.""" + @property + def torch_dtype(self): + return torch.bfloat16 + class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin): @property