diff --git a/backends/arm/test/passes/test_fuse_duplicate_users_pass.py b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py index d94e01f9847..13748b4401e 100644 --- a/backends/arm/test/passes/test_fuse_duplicate_users_pass.py +++ b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py @@ -19,8 +19,9 @@ class ModuleWithOps(torch.nn.Module): class FuseaAvgPool(ModuleWithOps): + # CSE deduplicates the 3 identical avg(x) calls to 1 during to_edge ops_before_pass = { - "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 3, + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1, } ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1} @@ -33,8 +34,9 @@ def forward(self, x): class FuseAvgPoolChain(ModuleWithOps): + # CSE deduplicates the 3 identical avg(avg(x)) chains to 1 chain of 2 during to_edge ops_before_pass = { - "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 6, + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 2, } ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 2} diff --git a/backends/xnnpack/_passes/convert_to_linear.py b/backends/xnnpack/_passes/convert_to_linear.py index 2cef71bf927..c8c8715738b 100644 --- a/backends/xnnpack/_passes/convert_to_linear.py +++ b/backends/xnnpack/_passes/convert_to_linear.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import List +from typing import List, Set import torch @@ -120,10 +120,23 @@ def create_linear( # Ignore dynamic shape nodes outputs = [ - node - for node in src_partition.output_nodes - if node.target != torch.ops.aten.sym_size.int and node.op != "placeholder" + n + for n in src_partition.output_nodes + if n.target != torch.ops.aten.sym_size.int and n.op != "placeholder" ] + if len(outputs) > 1: + # CSE may merge nodes across source partitions, creating extra + # output nodes. Keep only the output reachable from the mm/addmm. + partition_nodes: Set[torch.fx.Node] = set(src_partition.nodes) + reachable: Set[torch.fx.Node] = set() + queue = list(node.users.keys()) + while queue: + cur = queue.pop() + if cur in reachable or cur not in partition_nodes: + continue + reachable.add(cur) + queue.extend(cur.users.keys()) + outputs = [n for n in outputs if n in reachable] assert ( len(outputs) == 1 ), f"Unexpected number of outputs for a torch.nn.Linear module, expecting 1 but got {outputs}" diff --git a/exir/passes/BUCK b/exir/passes/BUCK index 4647388b388..7922e98992d 100644 --- a/exir/passes/BUCK +++ b/exir/passes/BUCK @@ -11,6 +11,7 @@ fbcode_target(_kind = runtime.python_library, deps = [ ":const_prop_pass", ":convert_constant_dim_order_pass", + ":cse_pass", ":debug_handle_generator_pass", ":external_constants_pass", ":init_mutable_pass", diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index ede866549b2..09ed3b5a099 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -34,6 +34,7 @@ from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassManager, PassType from executorch.exir.passes.const_prop_pass import ConstPropPass +from executorch.exir.passes.cse_pass import CSEPass from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS @@ -72,6 +73,7 @@ __all__ = [ "ExportPass", "ConstPropPass", + "CSEPass", "QuantFusionPass", "OpReplacePass", "ToDevicePass", @@ -518,6 +520,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult base_post_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = ( PassManager( passes=[ + CSEPass(), dead_code_elimination_pass, DebugHandleGeneratorPass(), ] diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 1316dffb828..0306c346485 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -2284,6 +2284,8 @@ def _do_checks( m = Module() n = m.to_copy_count() + # CSE deduplicates the two identical branches, reducing 4 _to_copy ops to 2 + n_after_cse = 2 input = torch.randn([2, 3, 4, 5]).to(memory_format=torch.contiguous_format) # 1. vanilla export, no edge ops @@ -2302,7 +2304,7 @@ def _do_checks( _do_checks( edge_prog.graph_module.code, edge_aten_op_str, - n, + n_after_cse, [aten_op_str, edge_dim_order_op_str], ) @@ -2312,11 +2314,12 @@ def _do_checks( _do_checks( new_res.graph_module.code, edge_aten_op_str, - n, + n_after_cse, [aten_op_str, edge_dim_order_op_str], ) # 2b. let's try with dim order enabled, we should see edge dim order ops but not edge aten ops + # CSE does not deduplicate dim_order ops (non-aten ops are treated as unsafe) edge_prog_dim_order = to_edge( ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=False) )._edge_programs["forward"]