@@ -298,6 +298,55 @@ def test_mul_mixed_shape():
298298 assert _expr_equals (result [0 ], sympy .Integer (3 ) * sympy .Symbol ("s0" ))
299299
300300
301+ # Test MOD_SHAPE with constant values, which should perform modulo and return a constant shape.
302+ def test_mod_const_shape_no_target ():
303+ shape_env = ShapeEnv ()
304+ with TosaLoweringContext (
305+ TosaSpecification .create_from_string ("TOSA-1.1+FP+shape" ), shape_env
306+ ), FakeTensorMode ():
307+ const_0 = exir_ops .backend .tosa .CONST_SHAPE .default ([8 , 21 ])
308+ const_1 = exir_ops .backend .tosa .CONST_SHAPE .default ([3 , 5 ])
309+ result = exir_ops .backend .tosa .MOD_SHAPE .default (const_0 , const_1 )
310+ assert len (result ) == 2
311+ assert result == [2 , 1 ]
312+
313+
314+ # Test MOD_SHAPE with symbolic values, which should produce a Mod expression.
315+ def test_mod_symbolic_shape_no_target ():
316+ shape_env = ShapeEnv ()
317+ s0 = _make_symint (shape_env , "s0" , hint = 8 )
318+ s1 = _make_symint (shape_env , "s1" , hint = 3 )
319+
320+ with TosaLoweringContext (
321+ TosaSpecification .create_from_string ("TOSA-1.1+FP+shape" ), shape_env
322+ ), FakeTensorMode (shape_env = shape_env ) as mode :
323+ s0_tensor = torch .empty (size = (1 , 3 , s0 ))
324+ s1_tensor = torch .empty (size = (1 , 3 , s1 ))
325+ dim_s0 = exir_ops .backend .tosa .DIM .default (mode .from_tensor (s0_tensor ), axis = 2 )
326+ dim_s1 = exir_ops .backend .tosa .DIM .default (mode .from_tensor (s1_tensor ), axis = 2 )
327+ result = exir_ops .backend .tosa .MOD_SHAPE .default (dim_s0 , dim_s1 )
328+ assert len (result ) == 1
329+ assert isinstance (result [0 ], torch .SymInt )
330+ assert _expr_equals (result [0 ], sympy .Mod (sympy .Symbol ("s0" ), sympy .Symbol ("s1" )))
331+
332+
333+ def test_mod_mixed_shape_no_target ():
334+ shape_env = ShapeEnv ()
335+ s0 = _make_symint (shape_env , "s0" , hint = 4 )
336+
337+ with TosaLoweringContext (
338+ TosaSpecification .create_from_string ("TOSA-1.1+FP+shape" ), shape_env
339+ ), FakeTensorMode (shape_env = shape_env ) as mode :
340+ const_shape = exir_ops .backend .tosa .CONST_SHAPE .default ([8 ])
341+ s0_tensor = torch .empty (size = (1 , 3 , s0 ))
342+ dim_s0 = exir_ops .backend .tosa .DIM .default (mode .from_tensor (s0_tensor ), axis = 2 )
343+ result = exir_ops .backend .tosa .MOD_SHAPE .default (const_shape , dim_s0 )
344+
345+ assert len (result ) == 1
346+ assert isinstance (result [0 ], torch .SymInt )
347+ assert _expr_equals (result [0 ], sympy .Mod (sympy .Integer (8 ), sympy .Symbol ("s0" )))
348+
349+
301350# Test DIV_FLOOR_SHAPE with constant values, which should perform floor division and return a constant shape.
302351def test_div_floor_const_shape ():
303352 shape_env = ShapeEnv ()
0 commit comments