Skip to content

Commit a6e4ade

Browse files
robell and Copilot
authored
Add structure for use of TOSA CUSTOM ops (#18837)
TOSA custom ops allow for wrapped custom operators of their own dialect. This adds the tosa op and recognition in partitioning to enable custom torch.library operators to be passed into the arm backend and mapped to backend provided implementations. The broader implementation using these mechanisms will follow. - Enables operators to be registered in the partitioner - add tosa.CUSTOM fake op registration in the dialect - for backend passes (in tree or registered) to create+use tosa.custom only within partition - register a TOSA CUSTOM node visitor for serialization - as custom wraps an operator, adds support for register_fake_tosa to register the shape-only operator in tosa cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson --------- Signed-off-by: Rob Elliott <Robert.Elliott@arm.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 3d4be1d commit a6e4ade

9 files changed

Lines changed: 297 additions & 3 deletions

File tree

backends/arm/ethosu/partitioner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
76
from typing import final, Optional, Sequence
87

8+
import torch
99
from executorch.backends.arm.ethosu import EthosUBackend, EthosUCompileSpec
1010
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
1111
from executorch.exir.backend.partitioner import DelegationSpec
@@ -33,3 +33,4 @@ def __init__(
3333
)
3434
self.additional_checks = additional_checks
3535
self.tosa_spec = compile_spec.tosa_spec
36+
self._custom_partition_ops: set[torch._ops.OpOverload] = set()

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
op_to_dim_order_copy,
5252
op_tosa_conv2d,
5353
op_tosa_conv3d,
54+
op_tosa_custom,
5455
op_tosa_depthwise_conv2d,
5556
op_tosa_gather,
5657
op_tosa_matmul,
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, List
7+
8+
import torch
9+
import tosa_serializer as ts
10+
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa.mapping import TosaArg
16+
17+
18+
@register_node_visitor
class CustomVisitor(NodeVisitor):
    """Serialize the TOSA backend-dialect CUSTOM op.

    The fx node carries the custom-op identity and an opaque payload in its
    kwargs; this visitor validates them and emits a TOSA CUSTOM operator.
    """

    target = "tosa.CUSTOM.default"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit one TOSA CUSTOM operator for *node*.

        Raises:
            ValueError: unknown/missing kwargs, or more than one output.
            TypeError: ``implementation_attrs`` is neither None nor a list.
        """
        # Only the three schema-defined kwargs are permitted.
        allowed_kwargs = {"operator_name", "domain_name", "implementation_attrs"}
        extra_kwargs = set(node.kwargs.keys()) - allowed_kwargs
        if extra_kwargs:
            raise ValueError(
                f"tosa.CUSTOM received unexpected kwargs: {sorted(extra_kwargs)}"
            )

        operator_name = node.kwargs.get("operator_name")
        domain_name = node.kwargs.get("domain_name")
        implementation_attrs = node.kwargs.get("implementation_attrs")

        if operator_name is None or domain_name is None:
            raise ValueError(
                "tosa.CUSTOM requires operator_name and domain_name in kwargs"
            )

        if implementation_attrs is None:
            attr_bytes: List[int] = []
        elif isinstance(implementation_attrs, list):
            # NOTE: PyTorch schemas do not support a bytes type; we pass
            # implementation_attrs as int[] representing raw bytes.
            attr_bytes = [int(x) for x in implementation_attrs]
        else:
            raise TypeError(
                "implementation_attrs must be None or list[int]; "
                f"got {type(implementation_attrs)}"
            )

        attr = ts.TosaSerializerAttribute()
        attr.CustomAttribute(
            operator_name=operator_name,
            domain_name=domain_name,
            implementation_attrs=attr_bytes,
        )

        # The single fx argument wraps a list of tensors (Tensor[] in the
        # schema); expand each element into its own TosaArg to obtain the
        # serialized operand names.
        unpacked_args = [TosaArg(item, self.tosa_spec) for item in inputs[0].special]
        input_names = [arg.name for arg in unpacked_args]
        multi_names = getattr(output, "multiple_output_names", None)
        output_names = multi_names if multi_names else [output.name]
        if len(output_names) != 1:
            # TODO: Support multi-output CUSTOM ops with per-output meta/shape.
            raise ValueError(
                f"tosa.CUSTOM currently requires a single output, got {len(output_names)}"
            )
        self._serialize_operator(
            node,
            tosa_graph,
            ts.Op.CUSTOM,
            input_names,
            output_names,
            attr,
        )

backends/arm/public_api_manifests/api_manifest_running.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ signature = "EthosUPartitioner.ops_to_not_decompose(self, ep: torch.export.expor
5656
kind = "function"
5757
signature = "EthosUPartitioner.partition(self, exported_program: torch.export.exported_program.ExportedProgram) -> executorch.exir.backend.partitioner.PartitionResult"
5858

59+
[python.EthosUPartitioner.register_custom_partition_op]
60+
kind = "function"
61+
signature = "EthosUPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None"
62+
5963
[python.EthosUQuantizer]
6064
kind = "class"
6165
signature = "EthosUQuantizer(compile_spec: 'EthosUCompileSpec', use_composable_quantizer: 'bool' = False) -> 'None'"
@@ -136,6 +140,10 @@ signature = "VgfPartitioner.ops_to_not_decompose(self, ep: torch.export.exported
136140
kind = "function"
137141
signature = "VgfPartitioner.partition(self, exported_program: torch.export.exported_program.ExportedProgram) -> executorch.exir.backend.partitioner.PartitionResult"
138142

143+
[python.VgfPartitioner.register_custom_partition_op]
144+
kind = "function"
145+
signature = "VgfPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None"
146+
139147
[python.VgfQuantizer]
140148
kind = "class"
141149
signature = "VgfQuantizer(compile_spec: 'VgfCompileSpec', use_composable_quantizer: 'bool' = False) -> 'None'"

backends/arm/tosa/dialect/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401
77
conv2d,
88
conv3d,
9+
custom,
910
depthwise_conv2d,
1011
gather,
1112
matmul,
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
"""Fake-op support for the generic TOSA ``CUSTOM`` dialect op.
6+
7+
The serialized TOSA ``CUSTOM`` op is intentionally generic: it carries a
8+
stable operator identity (for example ``myns.my_op``) plus an
9+
opaque payload in ``implementation_attrs``. That is enough for serialization,
10+
but not enough for FakeTensor propagation unless we also teach the compiler how
11+
to model the output tensors of the specific wrapped op.
12+
13+
This module provides a lightweight registration mechanism for those compiler
14+
side fake implementations:
15+
16+
1. A lowering pass rewrites an op to ``exir_ops.backend.tosa.CUSTOM.default``.
17+
2. The wrapped custom op registers a thin adapter with
18+
``@register_fake_tosa("namespace::op")``.
19+
3. The generic ``CUSTOM`` fake implementation looks up that adapter by the
20+
``operator_name`` argument and invokes it with the full custom-op calling
21+
convention ``(inputs, operator_name, domain_name, implementation_attrs)``.
22+
23+
The adapter should stay thin: it should only translate from the generic TOSA
24+
CUSTOM signature back to the wrapped op's fake semantics. The real semantic
25+
logic should continue to live in the original fake implementation where
26+
possible.
27+
28+
"""
29+
30+
import inspect
31+
from collections.abc import Callable
32+
33+
import torch
34+
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
35+
36+
from executorch.backends.arm.tosa.specification import (
37+
get_context_spec,
38+
TosaSpecification,
39+
)
40+
41+
_TOSA_CUSTOM_FAKE_IMPLS: dict[str, Callable] = {}
42+
43+
44+
def _normalize_tosa_custom_operator_name(operator_name: str) -> str:
45+
"""Normalize operator names so ``ns::op`` and ``ns.op`` map identically."""
46+
return operator_name.replace("::", ".")
47+
48+
49+
def validate_tosa_custom_fake_impl(fake_impl: object) -> Callable:
    """Check that *fake_impl* matches the generic TOSA CUSTOM fake signature.

    A registered fake implementation must accept exactly four positional
    parameters — ``(inputs, operator_name, domain_name,
    implementation_attrs)`` — and return ``list[Tensor]``.

    Returns:
        The validated callable, unchanged.

    Raises:
        TypeError: *fake_impl* is not callable or its signature differs.
    """
    if not callable(fake_impl):
        raise TypeError(
            "Expected tosa.CUSTOM fake impl to be callable, " f"got {type(fake_impl)}"
        )

    allowed_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    params = list(inspect.signature(fake_impl).parameters.values())
    signature_ok = len(params) == 4 and all(p.kind in allowed_kinds for p in params)
    if not signature_ok:
        raise TypeError(
            "tosa.CUSTOM fake impl must have signature "
            "(inputs, operator_name, domain_name, implementation_attrs)"
        )
    return fake_impl
76+
77+
78+
def register_fake_tosa(operator_name: str) -> Callable[[Callable], Callable]:
    """Decorator factory registering a fake impl for one wrapped custom op.

    Args:
        operator_name: Stable custom operator identifier; both ``ns::op``
            and ``ns.op`` spellings are accepted.

    Returns:
        A decorator that validates and registers a callable with signature
        ``(inputs, operator_name, domain_name, implementation_attrs)``
        returning ``list[Tensor]``, then hands the callable back unchanged.

    Example:
        ``@register_fake_tosa("my_namespace::my_op")``

    """
    registry_key = _normalize_tosa_custom_operator_name(operator_name)

    def decorator(fake_impl: Callable) -> Callable:
        _TOSA_CUSTOM_FAKE_IMPLS[registry_key] = validate_tosa_custom_fake_impl(
            fake_impl
        )
        return fake_impl

    return decorator
102+
103+
104+
def has_fake_tosa_impl(operator_name: str) -> bool:
    """Return True when a fake impl is registered for *operator_name*."""
    key = _normalize_tosa_custom_operator_name(operator_name)
    return key in _TOSA_CUSTOM_FAKE_IMPLS
108+
109+
110+
def run_registered_fake_tosa_impl(
    inputs: list[torch.Tensor],
    operator_name: str,
    domain_name: str,
    implementation_attrs: list[int],
) -> list[torch.Tensor]:
    """Look up and invoke the fake impl registered for *operator_name*.

    Raises:
        RuntimeError: no impl is registered, or the impl returns no outputs.
        TypeError: the impl returns something other than ``list[Tensor]``.
    """
    normalized_name = _normalize_tosa_custom_operator_name(operator_name)
    fake_impl = _TOSA_CUSTOM_FAKE_IMPLS.get(normalized_name)
    if fake_impl is None:
        raise RuntimeError(
            f"tosa.CUSTOM requires a registered fake impl for {normalized_name}"
        )
    results = fake_impl(inputs, operator_name, domain_name, implementation_attrs)
    # The generic CUSTOM schema promises Tensor[]; enforce that contract here
    # so registration mistakes surface at trace time, not serialization time.
    if not isinstance(results, list):
        raise TypeError(
            "tosa.CUSTOM fake impl must return list[Tensor], " f"got {type(results)}"
        )
    if not results:
        raise RuntimeError("tosa.CUSTOM fake impl must return at least one output")
    if any(not isinstance(item, torch.Tensor) for item in results):
        raise TypeError("tosa.CUSTOM fake impl must return list[Tensor]")
    return results
133+
134+
135+
@register_fake_tosa_op(
    "CUSTOM(Tensor[] inputs, str operator_name, str domain_name, int[] implementation_attrs) -> Tensor[]",
    TosaSpecification.all_versions_and_profiles(),
)
def CUSTOM(
    inputs: list[torch.Tensor],
    operator_name: str,
    domain_name: str,
    implementation_attrs: list[int],
) -> list[torch.Tensor]:
    """Generic fake implementation for the TOSA CUSTOM dialect op.

    CUSTOM is backend-defined, so shape/dtype propagation is delegated to
    the per-op fake implementation registered via ``register_fake_tosa``.
    """
    # A TOSA spec context must be active; this raises otherwise.
    get_context_spec()
    if not inputs:
        raise RuntimeError("tosa.CUSTOM requires at least one input tensor")
    return run_registered_fake_tosa_impl(
        inputs,
        operator_name,
        domain_name,
        implementation_attrs,
    )

backends/arm/tosa/mapping.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ def extract_tensor_meta(meta):
139139
if type(val) is tuple:
140140
# TODO: should use first concrete representation
141141
val = val[0]
142+
if isinstance(val, list):
143+
if not val:
144+
raise ValueError("Expected node.meta['val'] list to be non-empty")
145+
# Use first concrete representation for multi-output ops.
146+
val = val[0]
142147

143148
if not isinstance(val, torch._subclasses.fake_tensor.FakeTensor):
144149
raise ValueError(

backends/arm/tosa/partitioner.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,24 @@
4343
from torch.export.exported_program import ExportedProgram
4444
from torch.fx import GraphModule
4545
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
46-
from torch.fx.passes.operator_support import OperatorSupportBase
46+
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
4747

4848
logger = logging.getLogger(__name__)
4949

5050

51+
def _is_custom_partition_op(
52+
custom_ops: set[torch._ops.OpOverload], target: object
53+
) -> bool:
54+
if target in custom_ops:
55+
return True
56+
if hasattr(target, "_op"):
57+
try:
58+
return target._op in custom_ops
59+
except Exception:
60+
return False
61+
return False
62+
63+
5164
def _is_noop_clone(node: torch.fx.node.Node) -> bool:
    """Return True if *node* is the edge-dialect dim-order clone op."""
    clone_target = exir_ops.edge.dim_order_ops._clone_dim_order.default
    return node.target == clone_target
5366

@@ -149,6 +162,13 @@ def __init__(
149162
)
150163
self.tosa_spec = compile_spec.tosa_spec
151164
self.additional_checks = additional_checks
165+
self._custom_partition_ops: set[torch._ops.OpOverload] = set()
166+
167+
def register_custom_partition_op(self, op: torch._ops.OpOverload) -> None:
168+
"""Register a custom op to be considered supported by this
169+
partitioner.
170+
"""
171+
self._custom_partition_ops.add(op)
152172

153173
def _detag_boundary_nodes(
154174
self, module: GraphModule, tag: str, reporter: WhyNoPartitionReporter
@@ -233,6 +253,16 @@ def _tag_module( # noqa
233253
operator_support = tosa_support_factory(
234254
self.tosa_spec, containing_program, reporter, self.additional_checks
235255
)
256+
if self._custom_partition_ops:
257+
custom_ops = set(self._custom_partition_ops)
258+
259+
class CustomOpSupported(OperatorSupportBase):
260+
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
261+
return node.op == "call_function" and _is_custom_partition_op(
262+
custom_ops, node.target
263+
)
264+
265+
operator_support = any_chain(operator_support, CustomOpSupported())
236266
capability_partitioner = CapabilityBasedPartitioner(
237267
module,
238268
operator_support,
@@ -368,6 +398,8 @@ def filter_fn(node: torch.fx.Node) -> bool:
368398
bool: True to keep the op intact; otherwise, False.
369399
370400
"""
401+
if _is_custom_partition_op(self._custom_partition_ops, node.target):
402+
return True
371403
if (
372404
self.tosa_spec.support_float()
373405
and node.target in ops_to_not_decompose_if_fp
@@ -444,6 +476,7 @@ def filter_fn(node: torch.fx.Node) -> bool:
444476
| ops_to_not_decompose_if_fp
445477
| ops_to_not_decompose_if_integer
446478
)
479+
ops_to_not_decompose.extend(self._custom_partition_ops)
447480

448481
if not self.tosa_spec.is_U55_subset:
449482
# Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d

0 commit comments

Comments
 (0)