diff --git a/exir/program/BUCK b/exir/program/BUCK index 11f62edd99e..9273fb755d4 100644 --- a/exir/program/BUCK +++ b/exir/program/BUCK @@ -49,6 +49,7 @@ fbcode_target(_kind = runtime.python_library, "//executorch/exir/passes:spec_prop_pass", "//executorch/exir/passes:weights_to_outputs_pass", "//executorch/exir/passes:convert_constant_dim_order_pass", + "//executorch/exir/passes:cse_pass", "//executorch/exir/verification:verifier", "//executorch/extension/flat_tensor/serialize:serialize", ] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else []) diff --git a/exir/program/_program.py b/exir/program/_program.py index 485d72bbe45..8de134c4b7e 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -49,6 +49,7 @@ OpReplacePass, remove_unused_parameters_pass, ) +from executorch.exir.passes.cse_pass import CSEPass from executorch.exir.passes.external_constants_pass import ( external_constants_pass, external_mutable_weights_pass, @@ -759,6 +760,9 @@ def edge_to_executorch_passes( Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass. """ passes: List[PassType] = [ + # Run CSE before any user-supplied or downstream passes so redundant + # non-delegated ops are deduped before spec/device propagation runs. + CSEPass(), # ExecuTorch backend ops are unable to handle unbacked symints. So after # this pass, passes cannot be Interpreter-based, because it will fail if # there exists an unbacked symint operation. 
class TestCSEPassPipelineIntegration(unittest.TestCase):
    """
    Integration tests for wiring `CSEPass` into `edge_to_executorch_passes`.

    `CSEPass()` is the first entry of the pass list built by
    `edge_to_executorch_passes`, so it runs ahead of the remaining passes in
    that list. These tests check that duplicated ops still present in the
    edge graph are collapsed once `to_executorch()` has run, and that a
    graph rewritten by `CSEPass` still matches eager execution numerically.
    """

    @staticmethod
    # pyre-ignore[2,3]: test helper, types are intentionally loose
    def _count_op(gm, op_name: str) -> int:
        """Count `call_function` nodes in `gm.graph` whose target names `op_name`.

        Matching is done on `str(node.target)`. `op_name` must not be
        immediately followed by another word character, so `"aten.sin"`
        matches `aten.sin.default` / `aten.sin.out` but not `aten.sinh.*`
        or the in-place `aten.sin_.*`.

        Args:
            gm: Object exposing `graph.nodes`, each node having `op` and
                `target` attributes.
            op_name: Dotted operator name to look for, e.g. `"aten.sin"`.

        Returns:
            Number of matching `call_function` nodes.
        """
        # Local import keeps this helper self-contained; `re` is stdlib.
        import re

        # A plain substring test would also count ops that merely share a
        # prefix with `op_name` (sin/sinh, abs/absolute, neg/negative).
        pattern = re.compile(re.escape(op_name) + r"(?!\w)")
        return sum(
            1
            for n in gm.graph.nodes
            if n.op == "call_function" and pattern.search(str(n.target))
        )

    def test_cse_runs_during_to_executorch(self) -> None:
        """`sin(x) + sin(x)` keeps two `aten.sin` calls in the edge graph
        (i.e. `to_edge` has not deduplicated them) but collapses to one in
        the executorch graph, showing CSE fires at the edge -> executorch
        boundary."""

        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.sin(x) + torch.sin(x)

        x = torch.randn(4, 4)
        ep = export(M(), (x,), strict=True)

        # Before lowering: both sin calls are still present.
        edge = to_edge(ep)
        edge_gm = edge.exported_program().graph_module
        self.assertEqual(self._count_op(edge_gm, "aten.sin"), 2)

        # After lowering: the duplicate has been eliminated.
        et_gm = edge.to_executorch().exported_program().graph_module
        self.assertEqual(self._count_op(et_gm, "aten.sin"), 1)

    def test_e2e_numerical_equivalence(self) -> None:
        """Running `CSEPass` on an edge graph with redundancy at multiple ops
        must produce output numerically equivalent to eager execution."""

        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = torch.sin(x)
                b = torch.sin(x)
                c = torch.neg(torch.abs(x))
                d = torch.neg(torch.abs(x))
                return a + b + c + d

        x = torch.randn(8, 8)
        eager = M()(x)

        ep = export(M(), (x,), strict=True)
        edge_gm = to_edge(ep).exported_program().graph_module

        # Apply CSE directly to the edge graph module (no copy is made here),
        # using the same pass class that `edge_to_executorch_passes` prepends.
        cse_result = CSEPass()(edge_gm)
        self.assertTrue(cse_result.modified)

        # Confirm CSE actually fired: each redundant op collapses to 1.
        self.assertEqual(self._count_op(cse_result.graph_module, "aten.sin"), 1)
        self.assertEqual(self._count_op(cse_result.graph_module, "aten.abs"), 1)
        self.assertEqual(self._count_op(cse_result.graph_module, "aten.neg"), 1)

        # Run the post-CSE edge graph and compare against eager. The graph
        # module may return its outputs wrapped in a list/tuple.
        cse_out = cse_result.graph_module(x)
        if isinstance(cse_out, (list, tuple)):
            cse_out = cse_out[0]
        torch.testing.assert_close(cse_out, eager)