Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backends/arm/ethosu/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import final, Optional, Sequence

import torch
from executorch.backends.arm.ethosu import EthosUBackend, EthosUCompileSpec
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
from executorch.exir.backend.partitioner import DelegationSpec
Expand Down Expand Up @@ -33,3 +33,4 @@ def __init__(
)
self.additional_checks = additional_checks
self.tosa_spec = compile_spec.tosa_spec
self._custom_partition_ops: set[torch._ops.OpOverload] = set()
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
op_to_dim_order_copy,
op_tosa_conv2d,
op_tosa_conv3d,
op_tosa_custom,
op_tosa_depthwise_conv2d,
op_tosa_gather,
op_tosa_matmul,
Expand Down
85 changes: 85 additions & 0 deletions backends/arm/operators/op_tosa_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List

import torch
import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa.mapping import TosaArg


@register_node_visitor
class CustomVisitor(NodeVisitor):
    """Serialize the TOSA backend-dialect CUSTOM op into the TOSA graph."""

    target = "tosa.CUSTOM.default"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit a TOSA CUSTOM operator for ``node``.

        Expects the kwargs ``operator_name`` and ``domain_name`` (required)
        and ``implementation_attrs`` (optional ``list[int]``).

        Raises:
            ValueError: on unexpected kwargs, missing operator/domain name,
                or more than one output.
            TypeError: when ``implementation_attrs`` is neither ``None`` nor
                a list of ints.
        """
        # Only the CUSTOM calling-convention kwargs are permitted.
        recognized = {"operator_name", "domain_name", "implementation_attrs"}
        extras = set(node.kwargs.keys()) - recognized
        if extras:
            raise ValueError(
                f"tosa.CUSTOM received unexpected kwargs: {sorted(extras)}"
            )

        op_name = node.kwargs.get("operator_name")
        domain = node.kwargs.get("domain_name")
        raw_attrs = node.kwargs.get("implementation_attrs")

        if op_name is None or domain is None:
            raise ValueError(
                "tosa.CUSTOM requires operator_name and domain_name in kwargs"
            )

        # NOTE: PyTorch schemas do not support a bytes type; the opaque
        # payload travels as int[] representing raw bytes.
        if raw_attrs is None:
            payload: List[int] = []
        elif isinstance(raw_attrs, list):
            payload = [int(byte) for byte in raw_attrs]
        else:
            raise TypeError(
                f"implementation_attrs must be None or list[int]; got {type(raw_attrs)}"
            )

        attr = ts.TosaSerializerAttribute()
        attr.CustomAttribute(
            operator_name=op_name,
            domain_name=domain,
            implementation_attrs=payload,
        )

        # inputs[0].special carries the packed Tensor[] argument; wrap each
        # entry as its own TosaArg to recover the individual tensor names.
        tensor_args = [TosaArg(entry, self.tosa_spec) for entry in inputs[0].special]
        input_names = [arg.name for arg in tensor_args]

        multi_names = getattr(output, "multiple_output_names", None)
        output_names = multi_names if multi_names else [output.name]
        if len(output_names) != 1:
            # TODO: Support multi-output CUSTOM ops with per-output meta/shape.
            raise ValueError(
                f"tosa.CUSTOM currently requires a single output, got {len(output_names)}"
            )

        self._serialize_operator(
            node,
            tosa_graph,
            ts.Op.CUSTOM,
            input_names,
            output_names,
            attr,
        )
8 changes: 8 additions & 0 deletions backends/arm/public_api_manifests/api_manifest_running.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ signature = "EthosUPartitioner.ops_to_not_decompose(self, ep: torch.export.expor
kind = "function"
signature = "EthosUPartitioner.partition(self, exported_program: torch.export.exported_program.ExportedProgram) -> executorch.exir.backend.partitioner.PartitionResult"

[python.EthosUPartitioner.register_custom_partition_op]
kind = "function"
signature = "EthosUPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None"

[python.EthosUQuantizer]
kind = "class"
signature = "EthosUQuantizer(compile_spec: 'EthosUCompileSpec', use_composable_quantizer: 'bool' = False) -> 'None'"
Expand Down Expand Up @@ -136,6 +140,10 @@ signature = "VgfPartitioner.ops_to_not_decompose(self, ep: torch.export.exported
kind = "function"
signature = "VgfPartitioner.partition(self, exported_program: torch.export.exported_program.ExportedProgram) -> executorch.exir.backend.partitioner.PartitionResult"

[python.VgfPartitioner.register_custom_partition_op]
kind = "function"
signature = "VgfPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None"

[python.VgfQuantizer]
kind = "class"
signature = "VgfQuantizer(compile_spec: 'VgfCompileSpec', use_composable_quantizer: 'bool' = False) -> 'None'"
Expand Down
1 change: 1 addition & 0 deletions backends/arm/tosa/dialect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401
conv2d,
conv3d,
custom,
depthwise_conv2d,
gather,
matmul,
Expand Down
159 changes: 159 additions & 0 deletions backends/arm/tosa/dialect/ops/custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Fake-op support for the generic TOSA ``CUSTOM`` dialect op.

The serialized TOSA ``CUSTOM`` op is intentionally generic: it carries a
stable operator identity (for example ``myns.my_op``) plus an
opaque payload in ``implementation_attrs``. That is enough for serialization,
but not enough for FakeTensor propagation unless we also teach the compiler how
to model the output tensors of the specific wrapped op.

This module provides a lightweight registration mechanism for those
compiler-side fake implementations:

1. A lowering pass rewrites an op to ``exir_ops.backend.tosa.CUSTOM.default``.
2. The wrapped custom op registers a thin adapter with
``@register_fake_tosa("namespace::op")``.
3. The generic ``CUSTOM`` fake implementation looks up that adapter by the
``operator_name`` argument and invokes it with the full custom-op calling
convention ``(inputs, operator_name, domain_name, implementation_attrs)``.

The adapter should stay thin: it should only translate from the generic TOSA
CUSTOM signature back to the wrapped op's fake semantics. The real semantic
logic should continue to live in the original fake implementation where
possible.

"""

import inspect
from collections.abc import Callable

import torch
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op

from executorch.backends.arm.tosa.specification import (
get_context_spec,
TosaSpecification,
)

_TOSA_CUSTOM_FAKE_IMPLS: dict[str, Callable] = {}


def _normalize_tosa_custom_operator_name(operator_name: str) -> str:
"""Normalize operator names so ``ns::op`` and ``ns.op`` map identically."""
return operator_name.replace("::", ".")


def validate_tosa_custom_fake_impl(fake_impl: object) -> Callable:
    """Check that ``fake_impl`` matches the generic TOSA CUSTOM fake signature.

    A valid implementation is a callable taking exactly four positional
    parameters — ``(inputs, operator_name, domain_name, implementation_attrs)``
    — and returning ``list[Tensor]``.

    Returns:
        The validated callable, unchanged.

    Raises:
        TypeError: when ``fake_impl`` is not callable or its signature does
            not match the required calling convention.
    """
    if not callable(fake_impl):
        raise TypeError(
            f"Expected tosa.CUSTOM fake impl to be callable, got {type(fake_impl)}"
        )

    accepted_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    params = list(inspect.signature(fake_impl).parameters.values())
    signature_ok = len(params) == 4 and all(
        param.kind in accepted_kinds for param in params
    )
    if not signature_ok:
        raise TypeError(
            "tosa.CUSTOM fake impl must have signature "
            "(inputs, operator_name, domain_name, implementation_attrs)"
        )
    return fake_impl


def register_fake_tosa(operator_name: str) -> Callable[[Callable], Callable]:
    """Build a decorator that registers a fake impl for one wrapped custom op.

    Args:
        operator_name: Stable custom operator identifier; both the ``ns::op``
            and ``ns.op`` spellings are accepted.

    Returns:
        A decorator that validates and records a callable with signature
        ``(inputs, operator_name, domain_name, implementation_attrs)``
        returning ``list[Tensor]``, then hands the callable back unchanged.

    Example:
        ``@register_fake_tosa("my_namespace::my_op")``
    """
    registry_key = _normalize_tosa_custom_operator_name(operator_name)

    def decorator(fake_impl: Callable) -> Callable:
        # Validation raises before anything is stored in the registry.
        _TOSA_CUSTOM_FAKE_IMPLS[registry_key] = validate_tosa_custom_fake_impl(
            fake_impl
        )
        return fake_impl

    return decorator


def has_fake_tosa_impl(operator_name: str) -> bool:
    """Return True when a fake impl is registered under ``operator_name``."""
    key = _normalize_tosa_custom_operator_name(operator_name)
    return key in _TOSA_CUSTOM_FAKE_IMPLS


def run_registered_fake_tosa_impl(
    inputs: list[torch.Tensor],
    operator_name: str,
    domain_name: str,
    implementation_attrs: list[int],
) -> list[torch.Tensor]:
    """Dispatch to the fake implementation registered for ``operator_name``.

    Raises:
        RuntimeError: when no impl is registered, or the impl returns an
            empty list.
        TypeError: when the impl returns anything other than a non-empty
            ``list`` of ``torch.Tensor``.
    """
    key = _normalize_tosa_custom_operator_name(operator_name)
    fake_impl = _TOSA_CUSTOM_FAKE_IMPLS.get(key)
    if fake_impl is None:
        raise RuntimeError(f"tosa.CUSTOM requires a registered fake impl for {key}")

    results = fake_impl(inputs, operator_name, domain_name, implementation_attrs)

    # Enforce the documented list[Tensor] contract on the impl's output.
    if not isinstance(results, list):
        raise TypeError(
            f"tosa.CUSTOM fake impl must return list[Tensor], got {type(results)}"
        )
    if not results:
        raise RuntimeError("tosa.CUSTOM fake impl must return at least one output")
    for item in results:
        if not isinstance(item, torch.Tensor):
            raise TypeError("tosa.CUSTOM fake impl must return list[Tensor]")
    return results


@register_fake_tosa_op(
    "CUSTOM(Tensor[] inputs, str operator_name, str domain_name, int[] implementation_attrs) -> Tensor[]",
    TosaSpecification.all_versions_and_profiles(),
)
def CUSTOM(
    inputs: list[torch.Tensor],
    operator_name: str,
    domain_name: str,
    implementation_attrs: list[int],
) -> list[torch.Tensor]:
    """Fake implementation for the TOSA CUSTOM op.

    CUSTOM is backend-defined, so this generic fake implementation cannot
    model outputs itself; it forwards to the compiler-side fake impl that was
    registered for the specific wrapped op via ``register_fake_tosa``.

    Raises:
        RuntimeError: when ``inputs`` is empty.
    """
    get_context_spec()  # ensure a spec context exists
    if not inputs:
        raise RuntimeError("tosa.CUSTOM requires at least one input tensor")
    return run_registered_fake_tosa_impl(
        inputs,
        operator_name,
        domain_name,
        implementation_attrs,
    )
Comment on lines +135 to +159
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding targeted tests for the new tosa.CUSTOM fake op dispatch/registration path (e.g., successful dispatch to a registered fake impl, and the failure mode when no impl is registered). The repo already has dialect fake-op tests under backends/arm/test/misc/ (e.g. conv2d/shape_ops), so this new behavior should be covered similarly.

Copilot uses AI. Check for mistakes.
5 changes: 5 additions & 0 deletions backends/arm/tosa/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ def extract_tensor_meta(meta):
if type(val) is tuple:
# TODO: should use first concrete representation
val = val[0]
if isinstance(val, list):
if not val:
raise ValueError("Expected node.meta['val'] list to be non-empty")
# Use first concrete representation for multi-output ops.
val = val[0]

if not isinstance(val, torch._subclasses.fake_tensor.FakeTensor):
raise ValueError(
Expand Down
35 changes: 34 additions & 1 deletion backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,24 @@
from torch.export.exported_program import ExportedProgram
from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

logger = logging.getLogger(__name__)


def _is_custom_partition_op(
custom_ops: set[torch._ops.OpOverload], target: object
) -> bool:
if target in custom_ops:
return True
if hasattr(target, "_op"):
try:
return target._op in custom_ops
except Exception:
return False
return False


def _is_noop_clone(node: torch.fx.node.Node) -> bool:
    """Return True when ``node`` is a dim-order ``_clone_dim_order`` op."""
    clone_target = exir_ops.edge.dim_order_ops._clone_dim_order.default
    return node.target == clone_target

Expand Down Expand Up @@ -149,6 +162,13 @@ def __init__(
)
self.tosa_spec = compile_spec.tosa_spec
self.additional_checks = additional_checks
self._custom_partition_ops: set[torch._ops.OpOverload] = set()

def register_custom_partition_op(self, op: torch._ops.OpOverload) -> None:
    """Register a custom op to be considered supported by this
    partitioner.

    The op is added to ``self._custom_partition_ops``, which this
    partitioner consults when judging node support and when deciding
    which ops to keep from decomposing.

    Args:
        op: The custom operator overload to treat as supported.
    """
    self._custom_partition_ops.add(op)

def _detag_boundary_nodes(
self, module: GraphModule, tag: str, reporter: WhyNoPartitionReporter
Expand Down Expand Up @@ -233,6 +253,16 @@ def _tag_module( # noqa
operator_support = tosa_support_factory(
self.tosa_spec, containing_program, reporter, self.additional_checks
)
if self._custom_partition_ops:
custom_ops = set(self._custom_partition_ops)

class CustomOpSupported(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and _is_custom_partition_op(
custom_ops, node.target
)

operator_support = any_chain(operator_support, CustomOpSupported())
capability_partitioner = CapabilityBasedPartitioner(
module,
operator_support,
Expand Down Expand Up @@ -368,6 +398,8 @@ def filter_fn(node: torch.fx.Node) -> bool:
bool: True to keep the op intact; otherwise, False.

"""
if _is_custom_partition_op(self._custom_partition_ops, node.target):
return True
if (
self.tosa_spec.support_float()
and node.target in ops_to_not_decompose_if_fp
Expand Down Expand Up @@ -444,6 +476,7 @@ def filter_fn(node: torch.fx.Node) -> bool:
| ops_to_not_decompose_if_fp
| ops_to_not_decompose_if_integer
)
ops_to_not_decompose.extend(self._custom_partition_ops)

if not self.tosa_spec.is_U55_subset:
# Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d
Expand Down
Loading
Loading