Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions devtools/inspector/tests/inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-unsafe

import copy
import functools
import os
import random
import statistics
Expand Down Expand Up @@ -90,6 +91,28 @@ def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
ETRECORD_PATH = "unittest_etrecord_path"


def disable_if(condition, reason):
"""Disable a test when condition is true, still reporting it as executed.

Conditional analogue of unittest.skipIf that keeps disabled tests visible in
logs instead of producing a skipped result, which some test runners handle
inconsistently.
"""

def decorator(fn):
if not condition:
return fn

@functools.wraps(fn)
def wrapper(*args, **kwargs):
print(f"DISABLED_TEST: {fn.__qualname__}: {reason}")
return None

return wrapper

return decorator


# TODO: write an E2E test: create an inspector instance, mock just the file reads, and then verify the external correctness
class TestInspector(unittest.TestCase):
def test_perf_data(self) -> None:
Expand Down Expand Up @@ -1504,7 +1527,7 @@ def test_calculate_numeric_gap_with_edge_dialect_exported_program_name(self):
self.assertIsInstance(df, pd.DataFrame)
self.assertEqual(len(df), 1)

@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
@disable_if(sys.platform.startswith("win"), "Skipping on Windows")
def test_transformer_block_xnnpack_numeric_gap_within_tolerance(self):
"""
Test that the numeric gap between AOT and runtime intermediate outputs
Expand Down Expand Up @@ -1693,7 +1716,7 @@ def forward(
f"Stack trace for {op_name} doesn't contain file info",
)

@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
@disable_if(sys.platform.startswith("win"), "Skipping on Windows")
def test_intermediate_tensor_comparison_with_torch_export(self):
"""Test intermediate tensor comparison using torch.export.export and to_edge_transform_and_lower.

Expand Down Expand Up @@ -1840,7 +1863,7 @@ def _gen_random_runtime_output(
) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
return [torch.randn(RAW_DATA_SIZE)]

@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
@disable_if(sys.platform.startswith("win"), "Skipping on Windows")
def test_disable_debug_handle_validation_with_symbolic_shapes(self):
"""
Test that demonstrates the issue with symbolic shape related nodes losing from_node info
Expand Down
Loading