Skip to content

Commit 95dcfb9

Browse files
Arm backend: Add TOSA dialect shape ops
Introduce TOSA shape ops for TOSA 1.1 to be able to materialize symints to TOSA. Shape ops are currently not serialized. Tracing functionality and validation is done in test_tosa_dialect_shape_ops.py. Also adds new field shape_env to TosaLoweringContext. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com> Co-authored-by: Per Åstrand <per.astrand@arm.com> Change-Id: I051934619d1e95a5f7493cc4fdf49d6c76eb8eae
1 parent 5193141 commit 95dcfb9

4 files changed

Lines changed: 534 additions & 5 deletions

File tree

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ def add_passes(self, passes: Sequence[ExportPass | None]):
220220
self.add_pass(p)
221221

222222
def _transform(self, graph_module: GraphModule):
223-
with TosaLoweringContext(self.tosa_spec):
223+
shape_env = graph_module.shape_env
224+
with TosaLoweringContext(self.tosa_spec, shape_env):
224225
return self(graph_module).graph_module
225226

226227
def add_pass(self, pipeline_pass):
Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
#
6+
7+
import executorch.backends.arm.tosa.dialect # noqa: F401
8+
import pytest
9+
import sympy # type: ignore
10+
import torch
11+
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
12+
from executorch.backends.arm.tosa.specification import (
13+
TosaLoweringContext,
14+
TosaSpecification,
15+
)
16+
from executorch.exir.dialects._ops import ops as exir_ops
17+
from torch._subclasses.fake_tensor import FakeTensorMode
18+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
19+
20+
21+
def _make_symint(
22+
shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64
23+
) -> torch.SymInt:
24+
"""Create a symbolic dimension backed by the provided ShapeEnv."""
25+
symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint)
26+
symbol_expr = symint.node.expr
27+
shape_env.constrain_symbol_range(symbol_expr, compiler_min=min, compiler_max=max)
28+
return symint
29+
30+
31+
def _expr(sym: torch.SymInt) -> str:
32+
"""Return the SymPy expression backing a SymInt."""
33+
return str(sym.node._expr)
34+
35+
36+
def _expr_equals(sym: torch.SymInt, expected: sympy.Expr) -> bool:
37+
"""Return True if the SymPy expressions are equivalent."""
38+
actual = sympy.sympify(_expr(sym))
39+
expected_expr = sympy.sympify(expected)
40+
return sympy.simplify(actual - expected_expr) == 0
41+
42+
43+
# Test that DIM can extract a symbolic dimension from a tensor when the TOSA specification supports the shape extension.
44+
def test_dim_extracts_symbolic_dimension_no_target():
45+
shape_env = ShapeEnv()
46+
s0 = _make_symint(shape_env, "s0", hint=4)
47+
48+
with TosaLoweringContext(
49+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
50+
), FakeTensorMode(shape_env=shape_env) as mode:
51+
s0_tensor = torch.empty(size=(1, 3, s0))
52+
result = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
53+
54+
assert isinstance(result, list)
55+
assert len(result) == 1
56+
assert isinstance(result[0], torch.SymInt)
57+
assert _expr(result[0]) == "s0"
58+
59+
60+
# Test that DIM raises an error when the TOSA specification doesn't support the shape extension, as DIM relies on shape
61+
# expressions to return symbolic dimensions.
62+
def test_dim_requires_shape_extension_no_target():
63+
spec_no_shape = TosaSpecification.create_from_string("TOSA-1.0+FP")
64+
shape_env = ShapeEnv()
65+
s0 = _make_symint(shape_env, "s0", hint=3)
66+
67+
with TosaLoweringContext(
68+
spec_no_shape,
69+
shape_env,
70+
), FakeTensorMode(shape_env=shape_env) as mode:
71+
s0_tensor = torch.empty(size=(1, 3, s0))
72+
with pytest.raises(TosaValueError, match="shape extension"):
73+
exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
74+
75+
76+
# Test that CONST_SHAPE creates a constant shape tensor and returns the expected shape list.
77+
def test_const_shape_no_target():
78+
with TosaLoweringContext(
79+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
80+
), FakeTensorMode():
81+
shape = exir_ops.backend.tosa.CONST_SHAPE.default([2, 3, 4])
82+
assert shape == [2, 3, 4]
83+
84+
85+
# Test that CONCAT_SHAPE with constant shapes performs concatenation and returns a constant shape.
86+
def test_concat_const_shapes_no_target():
87+
with TosaLoweringContext(
88+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
89+
), FakeTensorMode():
90+
const_shape_0 = exir_ops.backend.tosa.CONST_SHAPE.default([2, 3])
91+
const_shape_1 = exir_ops.backend.tosa.CONST_SHAPE.default([4])
92+
result = exir_ops.backend.tosa.CONCAT_SHAPE.default(
93+
[const_shape_0, const_shape_1]
94+
)
95+
assert result == [2, 3, 4]
96+
97+
98+
# Test that CONCAT_SHAPE with symbolic shapes produces a symbolic expression concatenating the dimensions.
99+
def test_concat_symbolic_shape_no_target():
100+
shape_env = ShapeEnv()
101+
s0 = _make_symint(shape_env, "s0", hint=2)
102+
s1 = _make_symint(shape_env, "s1", hint=3)
103+
104+
with TosaLoweringContext(
105+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
106+
), FakeTensorMode(shape_env=shape_env) as mode:
107+
s0_tensor = torch.empty(size=(1, 3, s0))
108+
s1_tensor = torch.empty(size=(1, 3, s1))
109+
110+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
111+
dim_s1 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s1_tensor), axis=2)
112+
result = exir_ops.backend.tosa.CONCAT_SHAPE.default([dim_s0, dim_s1])
113+
114+
assert len(result) == 2
115+
assert _expr(result[0]) == "s0"
116+
assert _expr(result[1]) == "s1"
117+
118+
119+
def test_concat_mixed_shape_no_target():
120+
shape_env = ShapeEnv()
121+
s0 = _make_symint(shape_env, "s0", hint=2)
122+
123+
with TosaLoweringContext(
124+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
125+
), FakeTensorMode(shape_env=shape_env) as mode:
126+
const_shape = exir_ops.backend.tosa.CONST_SHAPE.default([4, 5])
127+
s0_tensor = torch.empty(size=(1, 3, s0))
128+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
129+
result = exir_ops.backend.tosa.CONCAT_SHAPE.default([const_shape, dim_s0])
130+
131+
assert len(result) == 3
132+
assert result[0] == 4
133+
assert result[1] == 5
134+
assert _expr(result[2]) == "s0"
135+
136+
137+
# Test that CONCAT_SHAPE raises an error when given fewer than 2 shape tensors, as it requires at least 2 to
138+
# concatenate.
139+
def test_concat_shape_requires_arguments_no_target():
140+
with pytest.raises(
141+
TosaValueError, match="CONCAT_SHAPE expected 2 or more shape tensors"
142+
):
143+
with TosaLoweringContext(
144+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
145+
), FakeTensorMode():
146+
exir_ops.backend.tosa.CONCAT_SHAPE.default([])
147+
148+
149+
# Test ADD_SHAPE with constant values, which should perform elementwise addition and return a constant shape.
150+
def test_add_const_shape_no_target():
151+
shape_env = ShapeEnv()
152+
with TosaLoweringContext(
153+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
154+
), FakeTensorMode():
155+
const_0 = exir_ops.backend.tosa.CONST_SHAPE.default([2, 3])
156+
const_1 = exir_ops.backend.tosa.CONST_SHAPE.default([4, 5])
157+
result = exir_ops.backend.tosa.ADD_SHAPE.default(const_0, const_1)
158+
assert len(result) == 2
159+
assert result == [6, 8]
160+
161+
162+
# Test ADD_SHAPE with symbolic values, which should produce a symbolic expression adding the two dimensions.
163+
def test_add_symbolic_shape_no_target():
164+
shape_env = ShapeEnv()
165+
s0 = _make_symint(shape_env, "s0", hint=2)
166+
s1 = _make_symint(shape_env, "s1", hint=3)
167+
168+
with TosaLoweringContext(
169+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
170+
), FakeTensorMode(shape_env=shape_env) as mode:
171+
s0_tensor = torch.empty(size=(1, 3, s0))
172+
s1_tensor = torch.empty(size=(1, 3, s1))
173+
174+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
175+
dim_s1 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s1_tensor), axis=2)
176+
result = exir_ops.backend.tosa.ADD_SHAPE.default(dim_s0, dim_s1)
177+
assert len(result) == 1
178+
assert isinstance(result[0], torch.SymInt)
179+
assert _expr_equals(result[0], sympy.Symbol("s0") + sympy.Symbol("s1"))
180+
181+
182+
def test_add_mixed_shape_no_target():
183+
shape_env = ShapeEnv()
184+
s0 = _make_symint(shape_env, "s0", hint=2)
185+
186+
with TosaLoweringContext(
187+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
188+
), FakeTensorMode(shape_env=shape_env) as mode:
189+
const_shape = exir_ops.backend.tosa.CONST_SHAPE.default([4])
190+
s0_tensor = torch.empty(size=(1, 3, s0))
191+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
192+
result = exir_ops.backend.tosa.ADD_SHAPE.default(const_shape, dim_s0)
193+
194+
assert len(result) == 1
195+
assert isinstance(result[0], torch.SymInt)
196+
assert _expr_equals(result[0], sympy.Symbol("s0") + sympy.Integer(4))
197+
198+
199+
# Test SUB_SHAPE with constant values, which should perform subtraction and return a constant shape.
200+
def test_sub_const_shape_no_target():
201+
shape_env = ShapeEnv()
202+
with TosaLoweringContext(
203+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
204+
), FakeTensorMode():
205+
const_0 = exir_ops.backend.tosa.CONST_SHAPE.default([6, 5])
206+
const_1 = exir_ops.backend.tosa.CONST_SHAPE.default([2, 3])
207+
result = exir_ops.backend.tosa.SUB_SHAPE.default(const_0, const_1)
208+
assert len(result) == 2
209+
assert result == [4, 2]
210+
211+
212+
# Test SUB_SHAPE with symbolic values, which should produce a Sub expression.
213+
def test_sub_symbolic_shape_no_target():
214+
shape_env = ShapeEnv()
215+
s0 = _make_symint(shape_env, "s0", hint=2)
216+
s1 = _make_symint(shape_env, "s1", hint=3)
217+
218+
with TosaLoweringContext(
219+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"),
220+
shape_env,
221+
), FakeTensorMode(shape_env=shape_env) as mode:
222+
s0_tensor = torch.empty(size=(1, 3, s0))
223+
s1_tensor = torch.empty(size=(1, 3, s1))
224+
225+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
226+
dim_s1 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s1_tensor), axis=2)
227+
result = exir_ops.backend.tosa.SUB_SHAPE.default(dim_s0, dim_s1)
228+
assert len(result) == 1
229+
assert isinstance(result[0], torch.SymInt)
230+
assert _expr_equals(result[0], sympy.Symbol("s0") - sympy.Symbol("s1"))
231+
232+
233+
def test_sub_mixed_shape_no_target():
234+
shape_env = ShapeEnv()
235+
s0 = _make_symint(shape_env, "s0", hint=3)
236+
237+
with TosaLoweringContext(
238+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"),
239+
shape_env,
240+
), FakeTensorMode(shape_env=shape_env) as mode:
241+
const_shape = exir_ops.backend.tosa.CONST_SHAPE.default([6])
242+
s0_tensor = torch.empty(size=(1, 3, s0))
243+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
244+
result = exir_ops.backend.tosa.SUB_SHAPE.default(const_shape, dim_s0)
245+
246+
assert len(result) == 1
247+
assert isinstance(result[0], torch.SymInt)
248+
assert _expr_equals(result[0], sympy.Integer(6) - sympy.Symbol("s0"))
249+
250+
251+
# Test MUL_SHAPE with constant values, which should perform multiplication and return a constant shape.
252+
def test_mul_const_shape_no_target():
253+
shape_env = ShapeEnv()
254+
with TosaLoweringContext(
255+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
256+
), FakeTensorMode():
257+
const_0 = exir_ops.backend.tosa.CONST_SHAPE.default([2, 3])
258+
const_1 = exir_ops.backend.tosa.CONST_SHAPE.default([4, 5])
259+
result = exir_ops.backend.tosa.MUL_SHAPE.default(const_0, const_1)
260+
assert len(result) == 2
261+
assert result == [8, 15]
262+
263+
264+
# Test MUL_SHAPE with symbolic values, which should produce a Mul expression.
265+
def test_mul_symbolic_shape_no_target():
266+
shape_env = ShapeEnv()
267+
s0 = _make_symint(shape_env, "s0", hint=2)
268+
s1 = _make_symint(shape_env, "s1", hint=3)
269+
270+
with TosaLoweringContext(
271+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
272+
), FakeTensorMode(shape_env=shape_env) as mode:
273+
s0_tensor = torch.empty(size=(1, 3, s0))
274+
s1_tensor = torch.empty(size=(1, 3, s1))
275+
276+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
277+
dim_s1 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s1_tensor), axis=2)
278+
result = exir_ops.backend.tosa.MUL_SHAPE.default(dim_s0, dim_s1)
279+
assert len(result) == 1
280+
assert isinstance(result[0], torch.SymInt)
281+
assert _expr_equals(result[0], sympy.Symbol("s0") * sympy.Symbol("s1"))
282+
283+
284+
def test_mul_mixed_shape_no_target():
285+
shape_env = ShapeEnv()
286+
s0 = _make_symint(shape_env, "s0", hint=4)
287+
288+
with TosaLoweringContext(
289+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
290+
), FakeTensorMode(shape_env=shape_env) as mode:
291+
const_shape = exir_ops.backend.tosa.CONST_SHAPE.default([3])
292+
s0_tensor = torch.empty(size=(1, 3, s0))
293+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
294+
result = exir_ops.backend.tosa.MUL_SHAPE.default(const_shape, dim_s0)
295+
296+
assert len(result) == 1
297+
assert isinstance(result[0], torch.SymInt)
298+
assert _expr_equals(result[0], sympy.Integer(3) * sympy.Symbol("s0"))
299+
300+
301+
# Test DIV_FLOOR_SHAPE with constant values, which should perform floor division and return a constant shape.
302+
def test_div_floor_const_shape_no_target():
303+
shape_env = ShapeEnv()
304+
with TosaLoweringContext(
305+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
306+
), FakeTensorMode():
307+
const_0 = exir_ops.backend.tosa.CONST_SHAPE.default([8, 21])
308+
const_1 = exir_ops.backend.tosa.CONST_SHAPE.default([2, 4])
309+
result = exir_ops.backend.tosa.DIV_FLOOR_SHAPE.default(const_0, const_1)
310+
assert len(result) == 2
311+
assert result == [4, 5]
312+
313+
314+
# Test DIV_FLOOR_SHAPE with symbolic values, which should produce a FloorDiv expression.
315+
def test_div_floor_symbolic_shape_no_target():
316+
shape_env = ShapeEnv()
317+
s0 = _make_symint(shape_env, "s0", hint=8)
318+
s1 = _make_symint(shape_env, "s1", hint=3)
319+
320+
with TosaLoweringContext(
321+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
322+
), FakeTensorMode(shape_env=shape_env) as mode:
323+
s0_tensor = torch.empty(size=(1, 3, s0))
324+
s1_tensor = torch.empty(size=(1, 3, s1))
325+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
326+
dim_s1 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s1_tensor), axis=2)
327+
result = exir_ops.backend.tosa.DIV_FLOOR_SHAPE.default(dim_s0, dim_s1)
328+
assert len(result) == 1
329+
assert isinstance(result[0], torch.SymInt)
330+
assert _expr_equals(result[0], sympy.sympify("(s0//s1)"))
331+
332+
333+
def test_div_floor_mixed_shape_no_target():
334+
shape_env = ShapeEnv()
335+
s0 = _make_symint(shape_env, "s0", hint=4)
336+
337+
with TosaLoweringContext(
338+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
339+
), FakeTensorMode(shape_env=shape_env) as mode:
340+
const_shape = exir_ops.backend.tosa.CONST_SHAPE.default([8])
341+
s0_tensor = torch.empty(size=(1, 3, s0))
342+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
343+
result = exir_ops.backend.tosa.DIV_FLOOR_SHAPE.default(const_shape, dim_s0)
344+
345+
assert len(result) == 1
346+
assert isinstance(result[0], torch.SymInt)
347+
assert _expr_equals(result[0], sympy.sympify("8//s0"))

0 commit comments

Comments
 (0)