diff --git a/backends/transforms/postpone_permute_below_squeeze_view.py b/backends/transforms/postpone_permute_below_squeeze_view.py index f676e19fb65..e0e9a3ec198 100644 --- a/backends/transforms/postpone_permute_below_squeeze_view.py +++ b/backends/transforms/postpone_permute_below_squeeze_view.py @@ -1,12 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe -import copy from typing import cast, List import torch @@ -108,7 +108,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # view_node_shape is almost same as permute_node_shape # except it has one more dim somewhere # and the extra dim has value of 1. - new_view_shape = copy.deepcopy(pred_shape) + new_view_shape = list(pred_shape) new_view_shape.insert(index, 1) new_permute_dims = [x + 1 if x >= index else x for x in permute_dims] new_permute_dims.insert(index, index) @@ -132,7 +132,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # and the extra dim has value of 1. # Convert permute_dims to list of ints index_to_remove = permute_dims[index] - new_view_shape = copy.deepcopy(pred_shape) + new_view_shape = list(pred_shape) del new_view_shape[index_to_remove] new_permute_dims = [ x - 1 if x > index_to_remove else x for x in permute_dims diff --git a/backends/transforms/test/test_permute_optimization_passes.py b/backends/transforms/test/test_permute_optimization_passes.py index dd356aad8a2..550446da562 100644 --- a/backends/transforms/test/test_permute_optimization_passes.py +++ b/backends/transforms/test/test_permute_optimization_passes.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -25,6 +26,8 @@ from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import ( ReplaceNopTransposeOrPermuteWithViewPass, ) + +from executorch.exir import EdgeCompileConfig, to_edge from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import PassResult from torch.utils import _pytree as pytree @@ -477,6 +480,38 @@ def test_permute4_view3_chains(self) -> None: "PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView", ) + def test_postpone_permute_with_symbolic_shapes(self) -> None: + class DynamicPermuteViewModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = x.view(x.shape[0], 12, 64) + y = y.permute(1, 0, 2) + y = y.view(1, 12, x.shape[0], 64) + return y.permute(0, 1, 3, 2) + + exported_program = torch.export.export( + DynamicPermuteViewModule(), + (torch.randn(3, 1, 768),), + dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=8)}}, + ) + edge_program = to_edge( + exported_program, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + graph_module = edge_program.exported_program().graph_module + + result = cast( + PassResult, + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView().call(graph_module), + ) + + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.view_copy.default), 2 + ) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default), 2 + ) + def test_negative_not_squeeze_like(self) -> None: """View that reshapes (not just squeeze/unsqueeze) should NOT be reordered.""" builder = GraphBuilder() diff --git a/exir/tensor.py b/exir/tensor.py index 02295eb8013..4898ac79f7a 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -10,8 +10,6 @@ # pyre-ignore-all-errors[16] from __future__ import annotations -import copy - import math import typing from typing import Dict, List, NamedTuple, Optional, Tuple, Union @@ -112,7 +110,7 @@ def stride_from_dim_order(sizes: List[int], dim_order: List[int]) -> List[int]: """ if len(sizes) == 0: return [] - strides = copy.deepcopy(sizes) + strides = list(sizes) ndim = len(sizes) strides[dim_order[ndim - 1]] = 1 for i in range(ndim - 2, -1, -1): diff --git a/exir/tests/test_tensor.py b/exir/tests/test_tensor.py index 25bf2ea451e..6435ca98a13 100644 --- a/exir/tests/test_tensor.py +++ b/exir/tests/test_tensor.py @@ -388,6 +388,26 @@ def test_strides_from_dim_order(self) -> None: strides = stride_from_dim_order(sizes, dim_order) self.assertEqual(expected_strides, strides) + def test_strides_from_dim_order_with_symbolic_sizes(self) -> None: + class ViewModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.view(x.shape[0], -1) + + exported_program = torch.export.export( + ViewModule(), + (torch.randn(2, 3, 4),), + dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=8)}}, + ) + placeholder = next( + node + for node in exported_program.graph_module.graph.nodes + if node.op == "placeholder" + ) + sizes = list(placeholder.meta["val"].shape) + + self.assertIsInstance(sizes[0], torch.SymInt) + self.assertEqual([12, 4, 1], stride_from_dim_order(sizes, [0, 1, 2])) + def test_num_bytes_from_shape_and_dtype(self) -> None: shape = (2, 3, 4) self.assertEqual(24, num_bytes_from_shape_and_dtype(shape, torch.int8))