|
| 1 | +# Copyright 2026 Arm Limited and/or its affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | +"""Fake-op support for the generic TOSA ``CUSTOM`` dialect op. |
| 6 | +
|
| 7 | +The serialized TOSA ``CUSTOM`` op is intentionally generic: it carries a |
| 8 | +stable operator identity (for example ``myns.my_op``) plus an |
| 9 | +opaque payload in ``implementation_attrs``. That is enough for serialization, |
| 10 | +but not enough for FakeTensor propagation unless we also teach the compiler how |
| 11 | +to model the output tensors of the specific wrapped op. |
| 12 | +
|
| 13 | +This module provides a lightweight registration mechanism for those compiler |
| 14 | +side fake implementations: |
| 15 | +
|
| 16 | +1. A lowering pass rewrites an op to ``exir_ops.backend.tosa.CUSTOM.default``. |
| 17 | +2. The wrapped custom op registers a thin adapter with |
| 18 | + ``@register_fake_tosa("namespace::op")``. |
| 19 | +3. The generic ``CUSTOM`` fake implementation looks up that adapter by the |
| 20 | + ``operator_name`` argument and invokes it with the full custom-op calling |
| 21 | + convention ``(inputs, operator_name, domain_name, implementation_attrs)``. |
| 22 | +
|
| 23 | +The adapter should stay thin: it should only translate from the generic TOSA |
| 24 | +CUSTOM signature back to the wrapped op's fake semantics. The real semantic |
| 25 | +logic should continue to live in the original fake implementation where |
| 26 | +possible. |
| 27 | +
|
| 28 | +""" |
| 29 | + |
| 30 | +import inspect |
| 31 | +from collections.abc import Callable |
| 32 | + |
| 33 | +import torch |
| 34 | +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op |
| 35 | + |
| 36 | +from executorch.backends.arm.tosa.specification import ( |
| 37 | + get_context_spec, |
| 38 | + TosaSpecification, |
| 39 | +) |
| 40 | + |
| 41 | +_TOSA_CUSTOM_FAKE_IMPLS: dict[str, Callable] = {} |
| 42 | + |
| 43 | + |
| 44 | +def _normalize_tosa_custom_operator_name(operator_name: str) -> str: |
| 45 | + """Normalize operator names so ``ns::op`` and ``ns.op`` map identically.""" |
| 46 | + return operator_name.replace("::", ".") |
| 47 | + |
| 48 | + |
| 49 | +def validate_tosa_custom_fake_impl(fake_impl: object) -> Callable: |
| 50 | + """Validate the signature expected by ``register_fake_tosa``. |
| 51 | +
|
| 52 | + Registered fake implementations must accept the generic TOSA CUSTOM fake |
| 53 | + calling convention: |
| 54 | +
|
| 55 | + ``(inputs, operator_name, domain_name, implementation_attrs)`` |
| 56 | +
|
| 57 | + and return ``list[Tensor]``. |
| 58 | +
|
| 59 | + """ |
| 60 | + if not callable(fake_impl): |
| 61 | + raise TypeError( |
| 62 | + "Expected tosa.CUSTOM fake impl to be callable, " f"got {type(fake_impl)}" |
| 63 | + ) |
| 64 | + |
| 65 | + params = tuple(inspect.signature(fake_impl).parameters.values()) |
| 66 | + positional_kinds = { |
| 67 | + inspect.Parameter.POSITIONAL_ONLY, |
| 68 | + inspect.Parameter.POSITIONAL_OR_KEYWORD, |
| 69 | + } |
| 70 | + if len(params) != 4 or any(param.kind not in positional_kinds for param in params): |
| 71 | + raise TypeError( |
| 72 | + "tosa.CUSTOM fake impl must have signature " |
| 73 | + "(inputs, operator_name, domain_name, implementation_attrs)" |
| 74 | + ) |
| 75 | + return fake_impl |
| 76 | + |
| 77 | + |
| 78 | +def register_fake_tosa(operator_name: str) -> Callable[[Callable], Callable]: |
| 79 | + """Register a fake implementation for a specific wrapped TOSA custom op. |
| 80 | +
|
| 81 | + Args: |
| 82 | + operator_name: Stable custom operator identifier. Both ``ns::op`` and |
| 83 | + ``ns.op`` spellings are accepted. |
| 84 | +
|
| 85 | + Returns: |
| 86 | + A decorator that registers a callable with signature |
| 87 | + ``(inputs, operator_name, domain_name, implementation_attrs)`` and |
| 88 | + returning ``list[Tensor]``. |
| 89 | +
|
| 90 | + Example: |
| 91 | + ``@register_fake_tosa("my_namespace::my_op")`` |
| 92 | +
|
| 93 | + """ |
| 94 | + normalized_name = _normalize_tosa_custom_operator_name(operator_name) |
| 95 | + |
| 96 | + def decorator(fake_impl: Callable) -> Callable: |
| 97 | + validated = validate_tosa_custom_fake_impl(fake_impl) |
| 98 | + _TOSA_CUSTOM_FAKE_IMPLS[normalized_name] = validated |
| 99 | + return fake_impl |
| 100 | + |
| 101 | + return decorator |
| 102 | + |
| 103 | + |
| 104 | +def has_fake_tosa_impl(operator_name: str) -> bool: |
| 105 | + """Return whether a wrapped custom op has a registered fake impl.""" |
| 106 | + normalized_name = _normalize_tosa_custom_operator_name(operator_name) |
| 107 | + return normalized_name in _TOSA_CUSTOM_FAKE_IMPLS |
| 108 | + |
| 109 | + |
| 110 | +def run_registered_fake_tosa_impl( |
| 111 | + inputs: list[torch.Tensor], |
| 112 | + operator_name: str, |
| 113 | + domain_name: str, |
| 114 | + implementation_attrs: list[int], |
| 115 | +) -> list[torch.Tensor]: |
| 116 | + """Invoke the registered fake implementation for a wrapped custom op.""" |
| 117 | + normalized_name = _normalize_tosa_custom_operator_name(operator_name) |
| 118 | + fake_impl = _TOSA_CUSTOM_FAKE_IMPLS.get(normalized_name) |
| 119 | + if fake_impl is None: |
| 120 | + raise RuntimeError( |
| 121 | + f"tosa.CUSTOM requires a registered fake impl for {normalized_name}" |
| 122 | + ) |
| 123 | + outputs = fake_impl(inputs, operator_name, domain_name, implementation_attrs) |
| 124 | + if not isinstance(outputs, list): |
| 125 | + raise TypeError( |
| 126 | + "tosa.CUSTOM fake impl must return list[Tensor], " f"got {type(outputs)}" |
| 127 | + ) |
| 128 | + if not outputs: |
| 129 | + raise RuntimeError("tosa.CUSTOM fake impl must return at least one output") |
| 130 | + if not all(isinstance(output, torch.Tensor) for output in outputs): |
| 131 | + raise TypeError("tosa.CUSTOM fake impl must return list[Tensor]") |
| 132 | + return outputs |
| 133 | + |
| 134 | + |
| 135 | +@register_fake_tosa_op( |
| 136 | + "CUSTOM(Tensor[] inputs, str operator_name, str domain_name, int[] implementation_attrs) -> Tensor[]", |
| 137 | + TosaSpecification.all_versions_and_profiles(), |
| 138 | +) |
| 139 | +def CUSTOM( |
| 140 | + inputs: list[torch.Tensor], |
| 141 | + operator_name: str, |
| 142 | + domain_name: str, |
| 143 | + implementation_attrs: list[int], |
| 144 | +) -> list[torch.Tensor]: |
| 145 | + """Fake implementation for TOSA CUSTOM op. |
| 146 | +
|
| 147 | + The CUSTOM op is backend-defined. The fake implementation dispatches to a |
| 148 | + registered compiler-side fake implementation for the specific custom op. |
| 149 | +
|
| 150 | + """ |
| 151 | + _ = get_context_spec() # ensure a spec context exists |
| 152 | + if not inputs: |
| 153 | + raise RuntimeError("tosa.CUSTOM requires at least one input tensor") |
| 154 | + return run_registered_fake_tosa_impl( |
| 155 | + inputs, |
| 156 | + operator_name, |
| 157 | + domain_name, |
| 158 | + implementation_attrs, |
| 159 | + ) |
0 commit comments