Skip to content

Commit 2c545f8

Browse files
Arm backend: Add MOD_SHAPE backend dialect op (#18836)
Adds new TOSA backend dialect op for MOD_SHAPE. MOD_SHAPE computes arg0 modulo arg1, i.e: def MOD_SHAPE(arg0, arg1): out_shape = [] for dim0, dim1 in zip(arg0, arg1): out_shape.append(dim0 % dim1) return out_shape Change-Id: I2f934eb251263a1973a0fc6c3d7723c8fb2a7bc1 cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent a6e4ade commit 2c545f8

2 files changed

Lines changed: 64 additions & 0 deletions

File tree

backends/arm/test/misc/test_tosa_dialect_shape_ops.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,55 @@ def test_mul_mixed_shape():
298298
assert _expr_equals(result[0], sympy.Integer(3) * sympy.Symbol("s0"))
299299

300300

301+
# Test MOD_SHAPE with constant values, which should perform modulo and return a constant shape.
302+
def test_mod_const_shape_no_target():
303+
shape_env = ShapeEnv()
304+
with TosaLoweringContext(
305+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
306+
), FakeTensorMode():
307+
const_0 = exir_ops.backend.tosa.CONST_SHAPE.default([8, 21])
308+
const_1 = exir_ops.backend.tosa.CONST_SHAPE.default([3, 5])
309+
result = exir_ops.backend.tosa.MOD_SHAPE.default(const_0, const_1)
310+
assert len(result) == 2
311+
assert result == [2, 1]
312+
313+
314+
# Test MOD_SHAPE with symbolic values, which should produce a Mod expression.
315+
def test_mod_symbolic_shape_no_target():
316+
shape_env = ShapeEnv()
317+
s0 = _make_symint(shape_env, "s0", hint=8)
318+
s1 = _make_symint(shape_env, "s1", hint=3)
319+
320+
with TosaLoweringContext(
321+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
322+
), FakeTensorMode(shape_env=shape_env) as mode:
323+
s0_tensor = torch.empty(size=(1, 3, s0))
324+
s1_tensor = torch.empty(size=(1, 3, s1))
325+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
326+
dim_s1 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s1_tensor), axis=2)
327+
result = exir_ops.backend.tosa.MOD_SHAPE.default(dim_s0, dim_s1)
328+
assert len(result) == 1
329+
assert isinstance(result[0], torch.SymInt)
330+
assert _expr_equals(result[0], sympy.Mod(sympy.Symbol("s0"), sympy.Symbol("s1")))
331+
332+
333+
def test_mod_mixed_shape_no_target():
334+
shape_env = ShapeEnv()
335+
s0 = _make_symint(shape_env, "s0", hint=4)
336+
337+
with TosaLoweringContext(
338+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
339+
), FakeTensorMode(shape_env=shape_env) as mode:
340+
const_shape = exir_ops.backend.tosa.CONST_SHAPE.default([8])
341+
s0_tensor = torch.empty(size=(1, 3, s0))
342+
dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2)
343+
result = exir_ops.backend.tosa.MOD_SHAPE.default(const_shape, dim_s0)
344+
345+
assert len(result) == 1
346+
assert isinstance(result[0], torch.SymInt)
347+
assert _expr_equals(result[0], sympy.Mod(sympy.Integer(8), sympy.Symbol("s0")))
348+
349+
301350
# Test DIV_FLOOR_SHAPE with constant values, which should perform floor division and return a constant shape.
302351
def test_div_floor_const_shape():
303352
shape_env = ShapeEnv()

backends/arm/tosa/dialect/ops/shape_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,18 @@ def MUL_SHAPE(
171171
"""
172172

173173
return _combine_shapes(shape1, shape2, lambda a, b: a * b)
174+
175+
176+
@register_fake_tosa_op(
177+
"MOD_SHAPE(SymInt[] shape1, SymInt[] shape2) -> SymInt[]", # schema
178+
TosaSpecification.all_profiles_for_version("1.1"),
179+
)
180+
def MOD_SHAPE(
181+
shape1: list[IntLikeType],
182+
shape2: list[IntLikeType],
183+
) -> list[IntLikeType]:
184+
"""MOD_SHAPE operator computes the element-wise modulo of the first shape
185+
tensor by the second.
186+
"""
187+
188+
return _combine_shapes(shape1, shape2, lambda a, b: a % b)

0 commit comments

Comments
 (0)