diff --git a/exir/passes/prune_empty_tensors_pass.py b/exir/passes/prune_empty_tensors_pass.py index e9addfadced..402217c125b 100644 --- a/exir/passes/prune_empty_tensors_pass.py +++ b/exir/passes/prune_empty_tensors_pass.py @@ -27,9 +27,10 @@ class PruneEmptyTensorsPass(ExportPass): def remove_empty_tensors_from_cat( self, graph_module: GraphModule, cat_node: Node - ) -> None: + ) -> bool: """ - Removes empty tensors from the graph that are inputs to aten.cat.default + Removes empty tensors from the graph that are inputs to aten.cat.default. + Returns True if the cat node was rewritten. """ concat_list = cast(List[Node], cat_node.args[0]) pruned_concat_list = [] @@ -38,6 +39,9 @@ def remove_empty_tensors_from_cat( if input_arg_tensor.numel() != 0: pruned_concat_list.append(input_arg) + if len(pruned_concat_list) == len(concat_list): + return False + cat_node.args = (pruned_concat_list,) + cat_node.args[1:] if len(pruned_concat_list) == 0: # if all the inputs to the cat are empty tensors, then we can replace @@ -52,16 +56,19 @@ def remove_empty_tensors_from_cat( ) full_like.meta = cat_node.meta cat_node.replace_all_uses_with(full_like) + return True def call(self, graph_module: GraphModule) -> PassResult: + modified = False for node in graph_module.graph.nodes: if node.op != "call_function": continue if node.target == torch.ops.aten.cat.default: - self.remove_empty_tensors_from_cat(graph_module, node) + modified |= self.remove_empty_tensors_from_cat(graph_module, node) - graph_module.graph.eliminate_dead_code() - graph_module.graph.lint() + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.graph.lint() - return PassResult(graph_module, True) + return PassResult(graph_module, modified) diff --git a/exir/passes/remove_graph_asserts_pass.py b/exir/passes/remove_graph_asserts_pass.py index 870621876e8..d6f5ff72363 100644 --- a/exir/passes/remove_graph_asserts_pass.py +++ b/exir/passes/remove_graph_asserts_pass.py @@ -11,33 +11,45 @@ from torch.fx.passes.infra.pass_base import PassBase, PassResult +def _erase_asserts_from_modules( + graph_module: torch.fx.GraphModule, + targets: tuple, +) -> bool: + modified = False + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + module_modified = False + for node in module.graph.nodes: + if node.op == "call_function" and node.target in targets: + module.graph.erase_node(node) + module_modified = True + if module_modified: + module.recompile() + module.graph.eliminate_dead_code() + modified = True + return modified + + +_CORE_ASSERT_TARGETS: tuple = ( + torch.ops.aten._assert_async.msg, + torch.ops.aten._assert_scalar.default, + torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten.sym_constrain_range.default, + torch.ops.aten._assert_tensor_metadata.default, +) + + class RemoveGraphAssertsPass(PassBase): """ Temporary pass to remove all the assert ops until runtime decides to address it. """ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for module in graph_module.modules(): - if not isinstance(module, torch.fx.GraphModule): - continue - - for node in module.graph.nodes: - if node.op == "call_function" and ( - node.target - in ( - torch.ops.aten._assert_async.msg, - torch.ops.aten._assert_scalar.default, - torch.ops.aten.sym_constrain_range_for_size.default, - torch.ops.aten.sym_constrain_range.default, - torch.ops.aten._assert_tensor_metadata.default, - ) - ): - module.graph.erase_node(node) - - module.recompile() - module.graph.eliminate_dead_code() - - return PassResult(graph_module, True) + return PassResult( + graph_module, + _erase_asserts_from_modules(graph_module, _CORE_ASSERT_TARGETS), + ) class RemoveNonCoreAtenOpGraphAssertsPass(PassBase): @@ -46,17 +58,10 @@ class RemoveNonCoreAtenOpGraphAssertsPass(PassBase): """ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for module in graph_module.modules(): - if not isinstance(module, torch.fx.GraphModule): - continue - - for node in module.graph.nodes: - if node.op == "call_function" and ( - node.target in (torch.ops.aten._assert_tensor_metadata.default,) - ): - module.graph.erase_node(node) - - module.recompile() - module.graph.eliminate_dead_code() - - return PassResult(graph_module, True) + return PassResult( + graph_module, + _erase_asserts_from_modules( + graph_module, + (torch.ops.aten._assert_tensor_metadata.default,), + ), + ) diff --git a/exir/passes/remove_noop_pass.py b/exir/passes/remove_noop_pass.py index e2d92909e53..fefb1e3655e 100644 --- a/exir/passes/remove_noop_pass.py +++ b/exir/passes/remove_noop_pass.py @@ -30,7 +30,8 @@ def eliminate_dq_q( graph_module: GraphModule, dequant_nodes: List[torch.fx.Node], -) -> None: +) -> bool: + modified = False for node in dequant_nodes: assert node.target in _DEQUANT_OPS for user in list(node.users): @@ -41,6 +42,8 @@ def eliminate_dq_q( if qparams_dq != qparams_q: continue user.replace_all_uses_with(node.args[0]) # pyre-fixme[6] + modified = True + return modified class RemoveNoopPass(ExportPass): @@ -54,6 +57,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # are removed in this pass and later check for redundant dq->q patterns and # remove them. dequant_nodes = [] + modified = False for node in graph_module.graph.nodes: if node.op != "call_function": @@ -74,6 +78,7 @@ def call(self, graph_module: GraphModule) -> PassResult: if node.args[0].target in _DEQUANT_OPS: dequant_nodes += [node.args[0]] node.replace_all_uses_with(node.args[0]) + modified = True continue if node.target == torch.ops.aten.slice_copy.Tensor: @@ -91,13 +96,16 @@ def call(self, graph_module: GraphModule) -> PassResult: if node.args[0].target in _DEQUANT_OPS: dequant_nodes += [node.args[0]] node.replace_all_uses_with(node.args[0]) + modified = True - graph_module.graph.eliminate_dead_code() - eliminate_dq_q(graph_module, dequant_nodes) - graph_module.graph.lint() - graph_module.graph.eliminate_dead_code() + if modified: + graph_module.graph.eliminate_dead_code() + modified |= eliminate_dq_q(graph_module, dequant_nodes) + if modified: + graph_module.graph.lint() + graph_module.graph.eliminate_dead_code() - return PassResult(graph_module, True) + return PassResult(graph_module, modified) class RemoveToCopyPass(ExportPass): @@ -106,6 +114,7 @@ class RemoveToCopyPass(ExportPass): """ def call(self, graph_module: GraphModule) -> PassResult: + modified = False for node in graph_module.graph.nodes: if node.op != "call_function": continue @@ -122,8 +131,10 @@ def call(self, graph_module: GraphModule) -> PassResult: and orig_tensor.stride() == node.meta["val"].stride() ): node.replace_all_uses_with(node.args[0]) + modified = True - graph_module.graph.eliminate_dead_code() - graph_module.graph.lint() + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.graph.lint() - return PassResult(graph_module, True) + return PassResult(graph_module, modified) diff --git a/exir/passes/replace_sym_size_op_pass.py b/exir/passes/replace_sym_size_op_pass.py index 8066f75c0a1..99e823b18cd 100644 --- a/exir/passes/replace_sym_size_op_pass.py +++ b/exir/passes/replace_sym_size_op_pass.py @@ -26,10 +26,12 @@ class ReplaceSymSizeOpPass(PassBase): """ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False for module in graph_module.modules(): if not isinstance(module, torch.fx.GraphModule): continue for node in module.graph.nodes: if node.target in replacements: node.target = replacements[node.target] - return PassResult(graph_module, True) + modified = True + return PassResult(graph_module, modified) diff --git a/exir/passes/to_device_pass.py b/exir/passes/to_device_pass.py index 10c324bb9fb..64e9b20fbed 100644 --- a/exir/passes/to_device_pass.py +++ b/exir/passes/to_device_pass.py @@ -38,7 +38,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: if modified: graph_module.recompile() - return PassResult(graph_module, True) + return PassResult(graph_module, modified) def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult: """Reimplement __call__ to avoid Optional[PassResult] type hint.""" diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 1316dffb828..7445d79dee6 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -67,8 +67,13 @@ from executorch.exir.passes.normalize_view_copy_base_pass import ( NormalizeViewCopyBasePass, ) -from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass +from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass +from executorch.exir.passes.remove_graph_asserts_pass import ( + RemoveGraphAssertsPass, + RemoveNonCoreAtenOpGraphAssertsPass, +) from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators +from executorch.exir.passes.remove_noop_pass import RemoveToCopyPass from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass from executorch.exir.passes.replace_view_copy_with_view_pass import ( ReplaceViewCopyWithViewPass, @@ -77,6 +82,7 @@ from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass +from executorch.exir.passes.to_device_pass import ToDevicePass from executorch.exir.program._program import lift_constant_tensor_pass from executorch.exir.schema import TensorShapeDynamism from executorch.exir.sym_util import eval_upper_bound @@ -2831,3 +2837,185 @@ def forward(self, x): self.assertTrue(result.modified) self.assertEqual(self._count_ops(result.graph_module, neg_target), 1) self.assertEqual(self._count_ops(result.graph_module, abs_target), 1) + + +class TestPassModifiedFlag(unittest.TestCase): + """ + PassResult.modified is read by _transform_with_pass_manager in + exir/program/_program.py; a false-positive forces a full + ExportedProgram rebuild + verifier re-run. These tests pin the + contract: a pass with nothing to do must report modified=False. + """ + + @staticmethod + def _aten_gm(module, example_inputs): + return torch.export.export(module, example_inputs, strict=True).graph_module + + @staticmethod + def _identity_aten_gm(): + class Identity(torch.nn.Module): + def forward(self, x): + return x + 1 + + return TestPassModifiedFlag._aten_gm(Identity(), (torch.ones(2),)) + + # ---- RemoveNoopPass ---- + + def test_remove_noop_pass_noop_when_nothing_to_remove(self): + gm = self._identity_aten_gm() + result = RemoveNoopPass()(gm) + self.assertFalse(result.modified) + + def test_remove_noop_pass_modified_when_redundant_to_dtype(self): + graph = torch.fx.Graph() + with FakeTensorMode() as fake_mode: + fake_input = fake_mode.from_tensor(torch.randn(2, 3)) + x = graph.placeholder("x") + x.meta["val"] = fake_input + to_node = graph.call_function( + torch.ops.aten.to.dtype, args=(x, torch.float32) + ) + to_node.meta["val"] = fake_input + graph.output(to_node) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + result = RemoveNoopPass()(gm) + self.assertTrue(result.modified) + + # ---- RemoveToCopyPass ---- + + def test_remove_to_copy_pass_noop_when_nothing_to_remove(self): + gm = self._identity_aten_gm() + result = RemoveToCopyPass()(gm) + self.assertFalse(result.modified) + + def test_remove_to_copy_pass_modified_when_redundant_copy(self): + # Build a graph with a redundant aten._to_copy.default whose + # output FakeTensor matches the input on dtype/device/shape/stride. + graph = torch.fx.Graph() + with FakeTensorMode() as fake_mode: + fake_input = fake_mode.from_tensor(torch.randn(2, 3)) + x = graph.placeholder("x") + x.meta["val"] = fake_input + copy_node = graph.call_function(torch.ops.aten._to_copy.default, args=(x,)) + copy_node.meta["val"] = fake_mode.from_tensor(torch.randn(2, 3)) + graph.output(copy_node) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + result = RemoveToCopyPass()(gm) + self.assertTrue(result.modified) + + # ---- PruneEmptyTensorsPass ---- + + def test_prune_empty_tensors_pass_noop_when_no_cat(self): + gm = self._identity_aten_gm() + result = PruneEmptyTensorsPass()(gm) + self.assertFalse(result.modified) + + def test_prune_empty_tensors_pass_noop_when_cat_has_no_empty(self): + class Cat(torch.nn.Module): + def forward(self, x, y): + return torch.cat([x, y]) + + gm = self._aten_gm(Cat(), (torch.ones(2, 3), torch.ones(2, 3))) + result = PruneEmptyTensorsPass()(gm) + self.assertFalse(result.modified) + + def test_prune_empty_tensors_pass_modified_when_empty_input(self): + class Cat(torch.nn.Module): + def forward(self, x): + return torch.cat([torch.empty((0, 3)), x]) + + gm = self._aten_gm(Cat(), (torch.ones(2, 3),)) + result = PruneEmptyTensorsPass()(gm) + self.assertTrue(result.modified) + + # ---- RemoveGraphAssertsPass ---- + + def test_remove_graph_asserts_pass_noop_when_no_asserts(self): + gm = self._identity_aten_gm() + result = RemoveGraphAssertsPass()(gm) + self.assertFalse(result.modified) + + def test_remove_graph_asserts_pass_modified_when_asserts_present(self): + # Construct a graph with an _assert_async node directly so the test + # does not depend on torch.export's heuristics for which asserts + # survive constant folding. + graph = torch.fx.Graph() + x = graph.placeholder("x") + add = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + graph.call_function(torch.ops.aten._assert_async.msg, args=(add, "asserted")) + graph.output(add) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + result = RemoveGraphAssertsPass()(gm) + self.assertTrue(result.modified) + remaining = [ + n + for n in result.graph_module.graph.nodes + if n.op == "call_function" and n.target == torch.ops.aten._assert_async.msg + ] + self.assertEqual(remaining, []) + + # ---- RemoveNonCoreAtenOpGraphAssertsPass ---- + + def test_remove_noncore_asserts_pass_noop_when_no_asserts(self): + gm = self._identity_aten_gm() + result = RemoveNonCoreAtenOpGraphAssertsPass()(gm) + self.assertFalse(result.modified) + + def test_remove_noncore_asserts_pass_modified_when_metadata_assert(self): + graph = torch.fx.Graph() + x = graph.placeholder("x") + add = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + graph.call_function( + torch.ops.aten._assert_tensor_metadata.default, + args=(add, None, None, torch.float32), + ) + graph.output(add) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + result = RemoveNonCoreAtenOpGraphAssertsPass()(gm) + self.assertTrue(result.modified) + + # ---- ReplaceSymSizeOpPass ---- + + def test_replace_sym_size_op_pass_noop_when_no_packets(self): + gm = self._identity_aten_gm() + result = ReplaceSymSizeOpPass()(gm) + self.assertFalse(result.modified) + + def test_replace_sym_size_op_pass_modified_when_packet_present(self): + # Build a graph that references torch.ops.aten.sym_size (the + # OpOverloadPacket form) so the pass has something to rewrite. + graph = torch.fx.Graph() + x = graph.placeholder("x") + sym_node = graph.call_function(torch.ops.aten.sym_size, args=(x, 0)) + graph.output(sym_node) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + result = ReplaceSymSizeOpPass()(gm) + self.assertTrue(result.modified) + new_targets = { + n.target for n in result.graph_module.graph.nodes if n.op == "call_function" + } + self.assertIn(torch.ops.aten.sym_size.int, new_targets) + self.assertNotIn(torch.ops.aten.sym_size, new_targets) + + # ---- ToDevicePass ---- + + def test_to_device_pass_noop_when_already_target_device(self): + # The identity model has no device= kwargs and is already on CPU. + gm = self._identity_aten_gm() + result = ToDevicePass("cpu")(gm) + self.assertFalse(result.modified) + + def test_to_device_pass_modified_when_kwarg_device_differs(self): + # arange has an explicit device kwarg in the exported graph. + class Arange(torch.nn.Module): + def forward(self, x): + return torch.arange(0, 4, device="cpu") + x + + gm = self._aten_gm(Arange(), (torch.zeros(4),)) + result = ToDevicePass("meta")(gm) + self.assertTrue(result.modified)