Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions backends/cortex_m/passes/convert_to_cortex_m_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,191 @@
)
return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args

def _lower_conv1d(self, node, graph_module):
"""Lower a quantized 3D aten.convolution.default (conv1d) end-to-end.

Wraps the runtime input in `aten.unsqueeze_copy +
dim_order_ops._clone_dim_order` to a 4D NHWC tensor with H=1,
AoT-reshapes the weight to 4D OHWI / IHWO with H=1, calls the existing
cortex_m.quantized_conv2d kernel, then clones + squeezes back to 3D
NCW. Replaces uses of `node` with the terminal squeeze and erases
`node`. The caller does not need to do any further graph mutation.
"""
(
x,
weight,
bias,
stride,
padding,
dilation,
_transposed,
_output_padding,
groups,
) = node.args

stride_2d = [1, stride[0]]
padding_2d = [0, padding[0]]
dilation_2d = [1, dilation[0]]

input_scale = node.meta["input_qparams"][0].scale
input_zero_point = node.meta["input_qparams"][0].zp
weight_scales = node.meta["input_qparams"][1].scale
if not isinstance(weight_scales, list):
fake_weight_tensor = get_first_fake_tensor(weight)
weight_scales = [weight_scales] * fake_weight_tensor.shape[0]

output_qparams = node.meta["output_qparams"][0]
output_scale = output_qparams.scale
output_zero_point = output_qparams.zp
output_qmin = output_qparams.qmin
output_qmax = output_qparams.qmax

quantized_multipliers = []
quantized_shifts = []
for weight_scale in weight_scales:
quantized_multiplier, quantized_shift = quantize_multiplier_aot(
input_scale * weight_scale / output_scale
)
quantized_multipliers.append(quantized_multiplier)
quantized_shifts.append(quantized_shift)

param_weight_tensor = get_param_tensor(self.exported_program, weight)
if param_weight_tensor is None:
raise RuntimeError(
f"Expected conv1d weight parameter tensor for node {node.name}."
)

# Conv1d weight shape: (out_channels, in_channels/groups, K).
in_channels = param_weight_tensor.shape[1] * groups
out_channels = param_weight_tensor.shape[0]
is_depthwise = (in_channels == groups) and (out_channels % in_channels == 0)
batch_size = self._get_batch_size_from_conv(node)
use_depthwise_conv = is_depthwise and (batch_size == 1)

# Lift the weight to the 4D layout the existing quantized_conv2d kernels
# already consume: unsqueeze a singleton H=1, then permute by the same
# axes the 4D path uses (OHWI for regular, IHWO for depthwise).
param_weight_4d = param_weight_tensor.unsqueeze(2)
if use_depthwise_conv:
weight_permuted = param_weight_4d.permute(1, 2, 3, 0).contiguous()
else:
weight_permuted = param_weight_4d.permute(0, 2, 3, 1).contiguous()

with node.graph.inserting_after(weight):
weight_nhwc = create_constant_placeholder(
self.exported_program,
node.graph,
node.name + "_weight_nhwc",
InputKind.PARAMETER,
weight_permuted,
)
quantized_multiplier_tensor = create_constant_placeholder(
self.exported_program,
node.graph,
node.name + "_quantized_multiplier",
InputKind.PARAMETER,
torch.tensor(quantized_multipliers, dtype=torch.int32),
)
quantized_shift_tensor = create_constant_placeholder(
self.exported_program,
node.graph,
node.name + "_quantized_shift",
InputKind.PARAMETER,
torch.tensor(quantized_shifts, dtype=torch.int32),
)

# Build the input chain (NCW -> 4D NHWC), the conv, and the output chain
# (4D NHWC -> NCW), all inserted in graph order before the original conv.
# Use view_copy (not unsqueeze_copy / squeeze_copy) so that
# backends/transforms/fuse_view_copy.FuseViewCopyTransform can collapse
# the view_copy <-> view_copy chain that forms between consecutive
# conv1d layers (e.g. Wav2Letter, Silero VAD encoder).
#
# For the input shape, prefer the explicit shape arg on x when x is a
# view_copy we just inserted for an earlier conv1d (its meta["val"]
# hasn't been repopulated yet at this point in the pass). Restrict to
# 3D targets so an unrelated view_copy producing a different rank
# can't silently feed a malformed shape into the conv reshape.
in_3d_shape = None
if (
isinstance(x, torch.fx.Node)
and x.target == exir_ops.edge.aten.view_copy.default
and len(x.args[1]) == 3
):
in_3d_shape = list(x.args[1])
if in_3d_shape is None:
in_3d_shape = list(get_first_fake_tensor(x).shape)
assert (
len(in_3d_shape) == 3
), f"_lower_conv1d expects a 3D input, got shape {in_3d_shape}"
x_4d_shape = [in_3d_shape[0], in_3d_shape[1], 1, in_3d_shape[2]]
out_3d_shape = list(node.meta["val"].shape)
out_4d_shape = [out_3d_shape[0], out_3d_shape[1], 1, out_3d_shape[2]]

Check warning on line 426 in backends/cortex_m/passes/convert_to_cortex_m_pass.py

View workflow job for this annotation

GitHub Actions / lintrunner

FLAKE8 F841

local variable 'out_4d_shape' is assigned to but never used See https://www.flake8rules.com/rules/F841.html.

with node.graph.inserting_before(node):
x_4d_nchw = node.graph.create_node(
"call_function",
target=exir_ops.edge.aten.view_copy.default,
args=(x, x_4d_shape),
)
x_4d_nhwc = node.graph.create_node(
"call_function",
target=exir_ops.edge.dim_order_ops._clone_dim_order.default,
args=(x_4d_nchw,),
kwargs={"dim_order": [0, 2, 3, 1]},
)
scratch = self._create_uninitialized_alloc_node()

# `is_depthwise` already required `out_channels % in_channels == 0`,
# so depth_multiplier is exact; pass it as the extra positional that
# the depthwise kernel takes between dilation and input_zero_point.
if use_depthwise_conv:
conv_op = exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default
depth_multiplier_args = (out_channels // in_channels,)
else:
conv_op = exir_ops.edge.cortex_m.quantized_conv2d.default
depth_multiplier_args = ()

conv_args = (
x_4d_nhwc,
weight_nhwc,
bias,
stride_2d,
padding_2d,
dilation_2d,
*depth_multiplier_args,
-input_zero_point,
output_zero_point,
quantized_multiplier_tensor,
quantized_shift_tensor,
output_qmin,
output_qmax,
scratch,
)

conv_node = node.graph.create_node(
"call_function",
target=conv_op,
args=conv_args,
kwargs={},
)
self._initialize_alloc_node_size(conv_node)

out_4d_nchw = node.graph.create_node(
"call_function",
target=exir_ops.edge.dim_order_ops._clone_dim_order.default,
args=(conv_node,),
kwargs={"dim_order": [0, 1, 2, 3]},
)
out_3d = node.graph.create_node(
"call_function",
target=exir_ops.edge.aten.view_copy.default,
args=(out_4d_nchw, out_3d_shape),
)

node.replace_all_uses_with(out_3d)
graph_module.graph.erase_node(node)

def _initialize_alloc_node_size(self, node: torch.fx.Node) -> None:
"""For nodes with a registered buffer size function for node.target, set the buffer sizes
of the last n args, which should be exir.memory.alloc nodes. For nodes without a
Expand Down Expand Up @@ -500,6 +685,15 @@
case exir_ops.edge.aten.convolution.default:
# Check if it's transposed convolution (arg index 6)
transposed = node.args[6] if len(node.args) > 6 else False
# stride length is 1 for conv1d, 2 for conv2d. Conv1d is
# lowered to the existing quantized_conv2d kernel with H=1
# by inserting unsqueeze + dim-order clone around the call;
# the helper handles its own replace_all_uses + erase.
is_conv1d = len(node.args[3]) == 1
if is_conv1d and not transposed:
self._lower_conv1d(node, graph_module)
modified = True
continue
if transposed:
op, args = self._get_transpose_conv2d_replacement(node)
else:
Expand Down
11 changes: 11 additions & 0 deletions backends/cortex_m/passes/cortex_m_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ScalarsToAttributePass,
)
from executorch.backends.cortex_m.target_config import CortexM, CortexMTargetConfig
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
Expand All @@ -27,6 +28,7 @@
from .convert_to_cortex_m_pass import ConvertToCortexMPass
from .decompose_hardswish_pass import DecomposeHardswishPass
from .decompose_mean_pass import DecomposeMeanPass
from .fold_inverse_dim_order_clone_pass import FoldInverseDimOrderClonePass
from .quantized_clamp_activation_pass import QuantizedClampActivationPass
from .quantized_op_fusion_pass import QuantizedOpFusionPass
from .replace_quant_nodes_pass import ReplaceQuantNodesPass
Expand All @@ -46,6 +48,15 @@ class CortexMPassManager(PassManager):
DecomposeHardswishPass,
QuantizedOpFusionPass,
ConvertToCortexMPass,
# Conv1d lowering inserts view_copy + _clone_dim_order wrappers around
# each conv2d call. Between consecutive conv1d layers these chain
# together to form an identity. FuseViewCopyTransform collapses the
# view_copy <-> view_copy chain (treating the dim_order clones as
# unary elementwise ops it walks through); FoldInverseDimOrderClonePass
# then removes the surviving _clone_dim_order pair whose composed
# dim_order is the identity.
FuseViewCopyTransform,
FoldInverseDimOrderClonePass,
]

pass_list_transform_for_annotation: list[PassClass] = [
Expand Down
66 changes: 66 additions & 0 deletions backends/cortex_m/passes/fold_inverse_dim_order_clone_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.fx
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from torch.fx.passes.infra.pass_manager import PassResult


class FoldInverseDimOrderClonePass(ExportPass):
"""Fold adjacent `_clone_dim_order` pairs whose net effect is identity.

The conv1d lowering inserts a `_clone_dim_order(dim_order=[0,2,3,1])`
before each conv and a `_clone_dim_order(dim_order=[0,1,2,3])` after. When
`FuseViewCopyTransform` collapses the intermediate view_copy chain between
two consecutive conv1d lowerings, the surviving graph is

... -> _clone_dim_order(to NCHW) -> _clone_dim_order(to NHWC) -> ...

where the second clone's dim_order is the inverse of the first applied to
the same shape -- two byte reorders that cancel. This pass detects that
exact pattern and replaces uses of the second clone with the first
clone's input, then lets dead code elimination remove both.
"""

_CLONE = exir_ops.edge.dim_order_ops._clone_dim_order.default

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
modified = False

for node in list(graph_module.graph.nodes):
if node.op != "call_function" or node.target != self._CLONE:
continue

second_clone = node
first_clone = second_clone.args[0]
if (
not isinstance(first_clone, torch.fx.Node)
or first_clone.op != "call_function"
or first_clone.target != self._CLONE
or len(first_clone.users) != 1
):
continue

original_input = first_clone.args[0]
if not isinstance(original_input, torch.fx.Node):
continue

# Net effect is identity iff the second clone's target dim_order
# equals the input tensor's dim_order before the first clone.
original_dim_order = tuple(original_input.meta["val"].dim_order())
second_dim_order = tuple(second_clone.kwargs.get("dim_order", ()))
if original_dim_order != second_dim_order:
continue

second_clone.replace_all_uses_with(original_input)
modified = True

if modified:
graph_module.graph.eliminate_dead_code()
graph_module.recompile()

return PassResult(graph_module, modified)
29 changes: 29 additions & 0 deletions backends/cortex_m/quantizer/pattern_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,35 @@ def check_quantization_config(
return is_int8 and is_ch_axis_0


class CortexMConv1DCheck(PatternCheck):
"""Accepts aten.conv1d.default with rank-3 NCW inputs.

The conv1d is lowered to cortex_m.quantized_conv2d via AoT weight
reshape (O, I, K) -> (O, 1, K, I) and graph-level input unsqueeze +
channels_last conversion in ConvertToCortexMPass.
"""

@classmethod
def check_pattern(cls, pattern):
for node in pattern:
tensor = get_first_fake_tensor(node)
if tensor is None or tensor.ndim != 3:
return False
return True

@classmethod
def check_quantization_config(
cls, pattern: list[Node], quantization_config: QuantizationConfig
) -> bool:
is_int8 = cls.is_int8_activations(quantization_config)
conv_node = pattern[0] if pattern else None
weight_qspec = quantization_config.get_weight_qspec(conv_node)
if not isinstance(weight_qspec, QuantizationSpec):
return False
is_ch_axis_0 = weight_qspec.ch_axis == 0 or weight_qspec.ch_axis is None
return is_int8 and is_ch_axis_0


class CortexMLinearCheck(PatternCheck):
@classmethod
def check_quantization_config(
Expand Down
22 changes: 22 additions & 0 deletions backends/cortex_m/quantizer/quantizer_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CortexMAddMulCheck,
CortexMAvgPool2DCheck,
CortexMBmmCheck,
CortexMConv1DCheck,
CortexMConv2DCheck,
CortexMConvTranspose2DCheck,
CortexMLinearCheck,
Expand Down Expand Up @@ -76,6 +77,27 @@
): CortexMConv2DCheck,
(torch.ops.aten.conv2d.default, torch.ops.aten.clamp.default): CortexMConv2DCheck,
(torch.ops.aten.conv2d.default, torch.ops.aten.clamp_.default): CortexMConv2DCheck,
(torch.ops.aten.conv1d.default,): CortexMConv1DCheck,
(torch.ops.aten.conv1d.default, torch.ops.aten.relu.default): CortexMConv1DCheck,
(torch.ops.aten.conv1d.default, torch.ops.aten.relu_.default): CortexMConv1DCheck,
(
torch.ops.aten.conv1d.default,
torch.ops.aten.hardtanh.default,
): CortexMConv1DCheck,
(
torch.ops.aten.conv1d.default,
torch.ops.aten.hardtanh_.default,
): CortexMConv1DCheck,
(
torch.ops.aten.conv1d.default,
torch.ops.aten.hardsigmoid.default,
): CortexMConv1DCheck,
(
torch.ops.aten.conv1d.default,
torch.ops.aten.hardsigmoid_.default,
): CortexMConv1DCheck,
(torch.ops.aten.conv1d.default, torch.ops.aten.clamp.default): CortexMConv1DCheck,
(torch.ops.aten.conv1d.default, torch.ops.aten.clamp_.default): CortexMConv1DCheck,
}

CONV_TRANSPOSE_OP_PATTERNS = {
Expand Down
2 changes: 2 additions & 0 deletions backends/cortex_m/test/build_test_runner.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ aten::cat.out,\
aten::full.out,\
aten::ge.Tensor_out,\
aten::unsqueeze_copy.out,\
aten::squeeze_copy.dims_out,\
aten::view_copy.out,\
aten::select_copy.int_out,\
aten::amax.out"

Expand Down
Loading
Loading