From e4f83d1046bd8c12aebb35e0071c5c5564456f92 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 24 Mar 2026 17:01:12 +0530
Subject: [PATCH] fix to device and to dtype tests.

---
 tests/pipelines/test_pipelines_common.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index af3573ce84cb..4d9d1717ba86 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1534,14 +1534,18 @@ def test_to_device(self):
         pipe.set_progress_bar_config(disable=None)
 
         pipe.to("cpu")
-        model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
+        model_devices = [
+            component.device.type for component in components.values() if getattr(component, "device", None)
+        ]
         self.assertTrue(all(device == "cpu" for device in model_devices))
 
         output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
         self.assertTrue(np.isnan(output_cpu).sum() == 0)
 
         pipe.to(torch_device)
-        model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
+        model_devices = [
+            component.device.type for component in components.values() if getattr(component, "device", None)
+        ]
         self.assertTrue(all(device == torch_device for device in model_devices))
 
         output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -1552,11 +1556,11 @@ def test_to_dtype(self):
         pipe = self.pipeline_class(**components)
         pipe.set_progress_bar_config(disable=None)
 
-        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
 
         pipe.to(dtype=torch.float16)
-        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
 
     def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):