From de4511115ba5b5817d5a36b1bf565bd6311427dd Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 29 May 2026 07:29:56 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- devtools/inspector/tests/inspector_test.py | 29 +++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index b33c5b37164..4c59190650c 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -7,6 +7,7 @@ # pyre-unsafe import copy +import functools import os import random import statistics @@ -90,6 +91,28 @@ def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor: ETRECORD_PATH = "unittest_etrecord_path" +def disable_if(condition, reason): + """Disable a test when condition is true, still reporting it as executed. + + Conditional analogue of unittest.skipIf that keeps disabled tests visible in + logs instead of producing a skipped result, which some test runners handle + inconsistently. + """ + + def decorator(fn): + if not condition: + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + print(f"DISABLED_TEST: {fn.__qualname__}: {reason}") + return None + + return wrapper + + return decorator + + # TODO: write an E2E test: create an inspector instance, mock just the file reads, and then verify the external correctness class TestInspector(unittest.TestCase): def test_perf_data(self) -> None: @@ -1504,7 +1527,7 @@ def test_calculate_numeric_gap_with_edge_dialect_exported_program_name(self): self.assertIsInstance(df, pd.DataFrame) self.assertEqual(len(df), 1) - @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") + @disable_if(sys.platform.startswith("win"), "Skipping on Windows") def test_transformer_block_xnnpack_numeric_gap_within_tolerance(self): """ Test that the numeric gap between AOT and runtime intermediate outputs @@ -1693,7 +1716,7 @@ def forward( f"Stack trace for {op_name} doesn't contain file info", ) - @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") + @disable_if(sys.platform.startswith("win"), "Skipping on Windows") def test_intermediate_tensor_comparison_with_torch_export(self): """Test intermediate tensor comparison using torch.export.export and to_edge_transform_and_lower. @@ -1840,7 +1863,7 @@ def _gen_random_runtime_output( ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]: return [torch.randn(RAW_DATA_SIZE)] - @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") + @disable_if(sys.platform.startswith("win"), "Skipping on Windows") def test_disable_debug_handle_validation_with_symbolic_shapes(self): """ Test that demonstrates the issue with symbolic shape related nodes losing from_node info