Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions backends/arm/test/passes/test_fuse_duplicate_users_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ class ModuleWithOps(torch.nn.Module):


class FuseaAvgPool(ModuleWithOps):
# CSE deduplicates the 3 identical avg(x) calls to 1 during to_edge
ops_before_pass = {
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 3,
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1,
}
ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1}

Expand All @@ -33,8 +34,9 @@ def forward(self, x):


class FuseAvgPoolChain(ModuleWithOps):
# CSE deduplicates the 3 identical avg(avg(x)) chains to 1 chain of 2 during to_edge
ops_before_pass = {
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 6,
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 2,
}
ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 2}

Expand Down
21 changes: 17 additions & 4 deletions backends/xnnpack/_passes/convert_to_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import List
from typing import List, Set

import torch

Expand Down Expand Up @@ -120,10 +120,23 @@ def create_linear(

# Ignore dynamic shape nodes
outputs = [
node
for node in src_partition.output_nodes
if node.target != torch.ops.aten.sym_size.int and node.op != "placeholder"
n
for n in src_partition.output_nodes
if n.target != torch.ops.aten.sym_size.int and n.op != "placeholder"
]
if len(outputs) > 1:
# CSE may merge nodes across source partitions, creating extra
# output nodes. Keep only the output reachable from the mm/addmm.
partition_nodes: Set[torch.fx.Node] = set(src_partition.nodes)
reachable: Set[torch.fx.Node] = set()
queue = list(node.users.keys())
while queue:
cur = queue.pop()
if cur in reachable or cur not in partition_nodes:
continue
reachable.add(cur)
queue.extend(cur.users.keys())
outputs = [n for n in outputs if n in reachable]
assert (
len(outputs) == 1
), f"Unexpected number of outputs for a torch.nn.Linear module, expecting 1 but got {outputs}"
Expand Down
1 change: 1 addition & 0 deletions exir/passes/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ fbcode_target(_kind = runtime.python_library,
deps = [
":const_prop_pass",
":convert_constant_dim_order_pass",
":cse_pass",
":debug_handle_generator_pass",
":external_constants_pass",
":init_mutable_pass",
Expand Down
3 changes: 3 additions & 0 deletions exir/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes.const_prop_pass import ConstPropPass
from executorch.exir.passes.cse_pass import CSEPass
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass

from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
Expand Down Expand Up @@ -72,6 +73,7 @@
__all__ = [
"ExportPass",
"ConstPropPass",
"CSEPass",
"QuantFusionPass",
"OpReplacePass",
"ToDevicePass",
Expand Down Expand Up @@ -518,6 +520,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
base_post_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = (
PassManager(
passes=[
CSEPass(),
dead_code_elimination_pass,
DebugHandleGeneratorPass(),
]
Expand Down
7 changes: 5 additions & 2 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2284,6 +2284,8 @@ def _do_checks(

m = Module()
n = m.to_copy_count()
# CSE deduplicates the two identical branches, reducing 4 _to_copy ops to 2
n_after_cse = 2
input = torch.randn([2, 3, 4, 5]).to(memory_format=torch.contiguous_format)

# 1. vanilla export, no edge ops
Expand All @@ -2302,7 +2304,7 @@ def _do_checks(
_do_checks(
edge_prog.graph_module.code,
edge_aten_op_str,
n,
n_after_cse,
[aten_op_str, edge_dim_order_op_str],
)

Expand All @@ -2312,11 +2314,12 @@ def _do_checks(
_do_checks(
new_res.graph_module.code,
edge_aten_op_str,
n,
n_after_cse,
[aten_op_str, edge_dim_order_op_str],
)

# 2b. let's try with dim order enabled, we should see edge dim order ops but not edge aten ops
# CSE does not deduplicate dim_order ops (non-aten ops are treated as unsafe)
edge_prog_dim_order = to_edge(
ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=False)
)._edge_programs["forward"]
Expand Down
Loading