Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exir/program/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ fbcode_target(_kind = runtime.python_library,
"//executorch/exir/passes:spec_prop_pass",
"//executorch/exir/passes:weights_to_outputs_pass",
"//executorch/exir/passes:convert_constant_dim_order_pass",
"//executorch/exir/passes:cse_pass",
"//executorch/exir/verification:verifier",
"//executorch/extension/flat_tensor/serialize:serialize",
] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])
Expand Down
4 changes: 4 additions & 0 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
OpReplacePass,
remove_unused_parameters_pass,
)
from executorch.exir.passes.cse_pass import CSEPass
from executorch.exir.passes.external_constants_pass import (
external_constants_pass,
external_mutable_weights_pass,
Expand Down Expand Up @@ -759,6 +760,9 @@ def edge_to_executorch_passes(
Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass.
"""
passes: List[PassType] = [
# Run CSE before any user-supplied or downstream passes so redundant
# non-delegated ops are deduped before spec/device propagation runs.
CSEPass(),
# ExecuTorch backend ops are unable to handle unbacked symints. So after
# this pass, passes cannot be Interpreter-based, because it will fail if
# there exists an unbacked symint operation.
Expand Down
72 changes: 72 additions & 0 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2831,3 +2831,75 @@ def forward(self, x):
self.assertTrue(result.modified)
self.assertEqual(self._count_ops(result.graph_module, neg_target), 1)
self.assertEqual(self._count_ops(result.graph_module, abs_target), 1)


class TestCSEPassPipelineIntegration(unittest.TestCase):
"""
Integration tests for wiring `CSEPass` into `edge_to_executorch_passes`
(the pass list consumed by `EdgeProgramManager.to_executorch`). CSE is
prepended so it dedupes redundancies introduced by user-supplied or
downstream passes at the edge -> executorch boundary, regardless of
`ExecutorchBackendConfig(passes=[...])` overrides.
"""

@staticmethod
# pyre-ignore[2,3]: test helper, types are intentionally loose
def _count_op(gm, op_name: str) -> int:
return sum(
1
for n in gm.graph.nodes
if n.op == "call_function" and op_name in str(n.target)
)

def test_cse_runs_during_to_executorch(self) -> None:
"""`sin(x) + sin(x)` should keep two `aten.sin` calls in the edge
graph (CSE no longer runs in `to_edge`) but collapse to one in the
executorch graph, proving CSE fires at the edge -> executorch
boundary."""

class M(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.sin(x) + torch.sin(x)

x = torch.randn(4, 4)
ep = export(M(), (x,), strict=True)

edge = to_edge(ep)
edge_gm = edge.exported_program().graph_module
self.assertEqual(self._count_op(edge_gm, "aten.sin"), 2)

et_gm = edge.to_executorch().exported_program().graph_module
self.assertEqual(self._count_op(et_gm, "aten.sin"), 1)

def test_e2e_numerical_equivalence(self) -> None:
"""Applying CSE (as `to_executorch` does) to a model with redundancy
at multiple ops must produce numerically equivalent output vs eager."""

class M(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.sin(x)
c = torch.neg(torch.abs(x))
d = torch.neg(torch.abs(x))
return a + b + c + d

x = torch.randn(8, 8)
eager = M()(x)

ep = export(M(), (x,), strict=True)
edge_gm = to_edge(ep).exported_program().graph_module

# Apply CSE to a clone of the edge graph, mirroring the boundary pass.
cse_result = CSEPass()(edge_gm)
self.assertTrue(cse_result.modified)

# Confirm CSE actually fired: each redundant op collapses to 1.
self.assertEqual(self._count_op(cse_result.graph_module, "aten.sin"), 1)
self.assertEqual(self._count_op(cse_result.graph_module, "aten.abs"), 1)
self.assertEqual(self._count_op(cse_result.graph_module, "aten.neg"), 1)

# Run the post-CSE edge graph and compare against eager.
cse_out = cse_result.graph_module(x)
if isinstance(cse_out, (list, tuple)):
cse_out = cse_out[0]
torch.testing.assert_close(cse_out, eager)
Loading