Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions exir/passes/prune_empty_tensors_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ class PruneEmptyTensorsPass(ExportPass):

def remove_empty_tensors_from_cat(
self, graph_module: GraphModule, cat_node: Node
) -> None:
) -> bool:
"""
Removes empty tensors from the graph that are inputs to aten.cat.default
Removes empty tensors from the graph that are inputs to aten.cat.default.
Returns True if the cat node was rewritten.
"""
concat_list = cast(List[Node], cat_node.args[0])
pruned_concat_list = []
Expand All @@ -38,6 +39,9 @@ def remove_empty_tensors_from_cat(
if input_arg_tensor.numel() != 0:
pruned_concat_list.append(input_arg)

if len(pruned_concat_list) == len(concat_list):
return False

cat_node.args = (pruned_concat_list,) + cat_node.args[1:]
if len(pruned_concat_list) == 0:
# if all the inputs to the cat are empty tensors, then we can replace
Expand All @@ -52,16 +56,19 @@ def remove_empty_tensors_from_cat(
)
full_like.meta = cat_node.meta
cat_node.replace_all_uses_with(full_like)
return True

def call(self, graph_module: GraphModule) -> PassResult:
modified = False
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue

if node.target == torch.ops.aten.cat.default:
self.remove_empty_tensors_from_cat(graph_module, node)
modified |= self.remove_empty_tensors_from_cat(graph_module, node)

graph_module.graph.eliminate_dead_code()
graph_module.graph.lint()
if modified:
graph_module.graph.eliminate_dead_code()
graph_module.graph.lint()

return PassResult(graph_module, True)
return PassResult(graph_module, modified)
75 changes: 40 additions & 35 deletions exir/passes/remove_graph_asserts_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,45 @@
from torch.fx.passes.infra.pass_base import PassBase, PassResult


def _erase_asserts_from_modules(
graph_module: torch.fx.GraphModule,
targets: tuple,
) -> bool:
modified = False
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
module_modified = False
for node in module.graph.nodes:
if node.op == "call_function" and node.target in targets:
module.graph.erase_node(node)
module_modified = True
if module_modified:
module.recompile()
module.graph.eliminate_dead_code()
modified = True
return modified


_CORE_ASSERT_TARGETS: tuple = (
torch.ops.aten._assert_async.msg,
torch.ops.aten._assert_scalar.default,
torch.ops.aten.sym_constrain_range_for_size.default,
torch.ops.aten.sym_constrain_range.default,
torch.ops.aten._assert_tensor_metadata.default,
)


class RemoveGraphAssertsPass(PassBase):
"""
Temporary pass to remove all the assert ops until runtime decides to address it.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue

for node in module.graph.nodes:
if node.op == "call_function" and (
node.target
in (
torch.ops.aten._assert_async.msg,
torch.ops.aten._assert_scalar.default,
torch.ops.aten.sym_constrain_range_for_size.default,
torch.ops.aten.sym_constrain_range.default,
torch.ops.aten._assert_tensor_metadata.default,
)
):
module.graph.erase_node(node)

module.recompile()
module.graph.eliminate_dead_code()

return PassResult(graph_module, True)
return PassResult(
graph_module,
_erase_asserts_from_modules(graph_module, _CORE_ASSERT_TARGETS),
)


class RemoveNonCoreAtenOpGraphAssertsPass(PassBase):
Expand All @@ -46,17 +58,10 @@ class RemoveNonCoreAtenOpGraphAssertsPass(PassBase):
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue

for node in module.graph.nodes:
if node.op == "call_function" and (
node.target in (torch.ops.aten._assert_tensor_metadata.default,)
):
module.graph.erase_node(node)

module.recompile()
module.graph.eliminate_dead_code()

return PassResult(graph_module, True)
return PassResult(
graph_module,
_erase_asserts_from_modules(
graph_module,
(torch.ops.aten._assert_tensor_metadata.default,),
),
)
29 changes: 20 additions & 9 deletions exir/passes/remove_noop_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
def eliminate_dq_q(
graph_module: GraphModule,
dequant_nodes: List[torch.fx.Node],
) -> None:
) -> bool:
modified = False
for node in dequant_nodes:
assert node.target in _DEQUANT_OPS
for user in list(node.users):
Expand All @@ -41,6 +42,8 @@ def eliminate_dq_q(
if qparams_dq != qparams_q:
continue
user.replace_all_uses_with(node.args[0]) # pyre-fixme[6]
modified = True
return modified


class RemoveNoopPass(ExportPass):
Expand All @@ -54,6 +57,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
# are removed in this pass and later check for redundant dq->q patterns and
# remove them.
dequant_nodes = []
modified = False

for node in graph_module.graph.nodes:
if node.op != "call_function":
Expand All @@ -74,6 +78,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
if node.args[0].target in _DEQUANT_OPS:
dequant_nodes += [node.args[0]]
node.replace_all_uses_with(node.args[0])
modified = True
continue

if node.target == torch.ops.aten.slice_copy.Tensor:
Expand All @@ -91,13 +96,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
if node.args[0].target in _DEQUANT_OPS:
dequant_nodes += [node.args[0]]
node.replace_all_uses_with(node.args[0])
modified = True

graph_module.graph.eliminate_dead_code()
eliminate_dq_q(graph_module, dequant_nodes)
graph_module.graph.lint()
graph_module.graph.eliminate_dead_code()
if modified:
graph_module.graph.eliminate_dead_code()
modified |= eliminate_dq_q(graph_module, dequant_nodes)
if modified:
graph_module.graph.lint()
graph_module.graph.eliminate_dead_code()

return PassResult(graph_module, True)
return PassResult(graph_module, modified)


class RemoveToCopyPass(ExportPass):
Expand All @@ -106,6 +114,7 @@ class RemoveToCopyPass(ExportPass):
"""

def call(self, graph_module: GraphModule) -> PassResult:
modified = False
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
Expand All @@ -122,8 +131,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
and orig_tensor.stride() == node.meta["val"].stride()
):
node.replace_all_uses_with(node.args[0])
modified = True

graph_module.graph.eliminate_dead_code()
graph_module.graph.lint()
if modified:
graph_module.graph.eliminate_dead_code()
graph_module.graph.lint()

return PassResult(graph_module, True)
return PassResult(graph_module, modified)
4 changes: 3 additions & 1 deletion exir/passes/replace_sym_size_op_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ class ReplaceSymSizeOpPass(PassBase):
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
modified = False
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if node.target in replacements:
node.target = replacements[node.target]
return PassResult(graph_module, True)
modified = True
return PassResult(graph_module, modified)
2 changes: 1 addition & 1 deletion exir/passes/to_device_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
if modified:
graph_module.recompile()

return PassResult(graph_module, True)
return PassResult(graph_module, modified)

def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
"""Reimplement __call__ to avoid Optional[PassResult] type hint."""
Expand Down
Loading
Loading