diff --git a/backends/arm/test/misc/test_tosa_dialect_shape_ops.py b/backends/arm/test/misc/test_tosa_dialect_shape_ops.py index c4acdd98bf0..878869cf5ee 100644 --- a/backends/arm/test/misc/test_tosa_dialect_shape_ops.py +++ b/backends/arm/test/misc/test_tosa_dialect_shape_ops.py @@ -298,6 +298,55 @@ def test_mul_mixed_shape(): assert _expr_equals(result[0], sympy.Integer(3) * sympy.Symbol("s0")) +# Test MOD_SHAPE with constant values, which should perform modulo and return a constant shape. +def test_mod_const_shape_no_target(): + shape_env = ShapeEnv() + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env + ), FakeTensorMode(): + const_0 = exir_ops.backend.tosa.CONST_SHAPE.default([8, 21]) + const_1 = exir_ops.backend.tosa.CONST_SHAPE.default([3, 5]) + result = exir_ops.backend.tosa.MOD_SHAPE.default(const_0, const_1) + assert len(result) == 2 + assert result == [2, 1] + + +# Test MOD_SHAPE with symbolic values, which should produce a Mod expression. +def test_mod_symbolic_shape_no_target(): + shape_env = ShapeEnv() + s0 = _make_symint(shape_env, "s0", hint=8) + s1 = _make_symint(shape_env, "s1", hint=3) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env + ), FakeTensorMode(shape_env=shape_env) as mode: + s0_tensor = torch.empty(size=(1, 3, s0)) + s1_tensor = torch.empty(size=(1, 3, s1)) + dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2) + dim_s1 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s1_tensor), axis=2) + result = exir_ops.backend.tosa.MOD_SHAPE.default(dim_s0, dim_s1) + assert len(result) == 1 + assert isinstance(result[0], torch.SymInt) + assert _expr_equals(result[0], sympy.Mod(sympy.Symbol("s0"), sympy.Symbol("s1"))) + + +def test_mod_mixed_shape_no_target(): + shape_env = ShapeEnv() + s0 = _make_symint(shape_env, "s0", hint=4) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env + ), FakeTensorMode(shape_env=shape_env) as mode: + const_shape = exir_ops.backend.tosa.CONST_SHAPE.default([8]) + s0_tensor = torch.empty(size=(1, 3, s0)) + dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2) + result = exir_ops.backend.tosa.MOD_SHAPE.default(const_shape, dim_s0) + + assert len(result) == 1 + assert isinstance(result[0], torch.SymInt) + assert _expr_equals(result[0], sympy.Mod(sympy.Integer(8), sympy.Symbol("s0"))) + + # Test DIV_FLOOR_SHAPE with constant values, which should perform floor division and return a constant shape. def test_div_floor_const_shape(): shape_env = ShapeEnv() diff --git a/backends/arm/tosa/dialect/ops/shape_ops.py b/backends/arm/tosa/dialect/ops/shape_ops.py index 0f36420697b..edeb731620d 100644 --- a/backends/arm/tosa/dialect/ops/shape_ops.py +++ b/backends/arm/tosa/dialect/ops/shape_ops.py @@ -171,3 +171,18 @@ def MUL_SHAPE( """ return _combine_shapes(shape1, shape2, lambda a, b: a * b) + + +@register_fake_tosa_op( + "MOD_SHAPE(SymInt[] shape1, SymInt[] shape2) -> SymInt[]", # schema + TosaSpecification.all_profiles_for_version("1.1"), +) +def MOD_SHAPE( + shape1: list[IntLikeType], + shape2: list[IntLikeType], +) -> list[IntLikeType]: + """MOD_SHAPE operator computes the element-wise modulo of the first shape + tensor by the second. + """ + + return _combine_shapes(shape1, shape2, lambda a, b: a % b)