Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backends/transforms/postpone_permute_below_squeeze_view.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import copy
from typing import cast, List

import torch
Expand Down Expand Up @@ -108,7 +108,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# view_node_shape is almost same as permute_node_shape
# except it has one more dim somewhere
# and the extra dim has value of 1.
new_view_shape = copy.deepcopy(pred_shape)
new_view_shape = list(pred_shape)
new_view_shape.insert(index, 1)
new_permute_dims = [x + 1 if x >= index else x for x in permute_dims]
new_permute_dims.insert(index, index)
Expand All @@ -132,7 +132,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# and the extra dim has value of 1.
# Convert permute_dims to list of ints
index_to_remove = permute_dims[index]
new_view_shape = copy.deepcopy(pred_shape)
new_view_shape = list(pred_shape)
del new_view_shape[index_to_remove]
new_permute_dims = [
x - 1 if x > index_to_remove else x for x in permute_dims
Expand Down
35 changes: 35 additions & 0 deletions backends/transforms/test/test_permute_optimization_passes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -25,6 +26,8 @@
from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import (
ReplaceNopTransposeOrPermuteWithViewPass,
)

from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.utils import _pytree as pytree
Expand Down Expand Up @@ -477,6 +480,38 @@ def test_permute4_view3_chains(self) -> None:
"PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView",
)

def test_postpone_permute_with_symbolic_shapes(self) -> None:
class DynamicPermuteViewModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = x.view(x.shape[0], 12, 64)
y = y.permute(1, 0, 2)
y = y.view(1, 12, x.shape[0], 64)
return y.permute(0, 1, 3, 2)

exported_program = torch.export.export(
DynamicPermuteViewModule(),
(torch.randn(3, 1, 768),),
dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=8)}},
)
edge_program = to_edge(
exported_program,
compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
graph_module = edge_program.exported_program().graph_module

result = cast(
PassResult,
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView().call(graph_module),
)

self.assertTrue(result.modified)
self.assertEqual(
count_node(result.graph_module, exir_ops.edge.aten.view_copy.default), 2
)
self.assertEqual(
count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default), 2
)

def test_negative_not_squeeze_like(self) -> None:
"""View that reshapes (not just squeeze/unsqueeze) should NOT be reordered."""
builder = GraphBuilder()
Expand Down
4 changes: 1 addition & 3 deletions exir/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
# pyre-ignore-all-errors[16]
from __future__ import annotations

import copy

import math
import typing
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
Expand Down Expand Up @@ -112,7 +110,7 @@ def stride_from_dim_order(sizes: List[int], dim_order: List[int]) -> List[int]:
"""
if len(sizes) == 0:
return []
strides = copy.deepcopy(sizes)
strides = list(sizes)
ndim = len(sizes)
strides[dim_order[ndim - 1]] = 1
for i in range(ndim - 2, -1, -1):
Expand Down
20 changes: 20 additions & 0 deletions exir/tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,26 @@ def test_strides_from_dim_order(self) -> None:
strides = stride_from_dim_order(sizes, dim_order)
self.assertEqual(expected_strides, strides)

def test_strides_from_dim_order_with_symbolic_sizes(self) -> None:
class ViewModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.view(x.shape[0], -1)

exported_program = torch.export.export(
ViewModule(),
(torch.randn(2, 3, 4),),
dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=8)}},
)
placeholder = next(
node
for node in exported_program.graph_module.graph.nodes
if node.op == "placeholder"
)
sizes = list(placeholder.meta["val"].shape)

self.assertIsInstance(sizes[0], torch.SymInt)
self.assertEqual([12, 4, 1], stride_from_dim_order(sizes, [0, 1, 2]))

def test_num_bytes_from_shape_and_dtype(self) -> None:
shape = (2, 3, 4)
self.assertEqual(24, num_bytes_from_shape_and_dtype(shape, torch.int8))
Expand Down
Loading