From 31f28de3db244f9714f244b93ba560d5e492aa93 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Thu, 26 Mar 2026 13:04:25 +0100 Subject: [PATCH] Arm backend: Add MOD_SHAPE backend dialect op Adds new TOSA backend dialect op for MOD_SHAPE. MOD_SHAPE computes arg0 modulo arg1, i.e: def MOD_SHAPE(arg0, arg1): out_shape = [] for dim0, dim1 in zip(arg0, arg1): out_shape.append(dim0 % dim1) return out_shape Signed-off-by: Oscar Andersson Change-Id: I2f934eb251263a1973a0fc6c3d7723c8fb2a7bc1 --- .../test/misc/test_tosa_dialect_shape_ops.py | 49 +++++++++++++++++++ backends/arm/tosa/dialect/ops/shape_ops.py | 15 ++++++ 2 files changed, 64 insertions(+) diff --git a/backends/arm/test/misc/test_tosa_dialect_shape_ops.py b/backends/arm/test/misc/test_tosa_dialect_shape_ops.py index c4acdd98bf0..878869cf5ee 100644 --- a/backends/arm/test/misc/test_tosa_dialect_shape_ops.py +++ b/backends/arm/test/misc/test_tosa_dialect_shape_ops.py @@ -298,6 +298,55 @@ def test_mul_mixed_shape(): assert _expr_equals(result[0], sympy.Integer(3) * sympy.Symbol("s0")) +# Test MOD_SHAPE with constant values, which should perform modulo and return a constant shape. +def test_mod_const_shape_no_target(): + shape_env = ShapeEnv() + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env + ), FakeTensorMode(): + const_0 = exir_ops.backend.tosa.CONST_SHAPE.default([8, 21]) + const_1 = exir_ops.backend.tosa.CONST_SHAPE.default([3, 5]) + result = exir_ops.backend.tosa.MOD_SHAPE.default(const_0, const_1) + assert len(result) == 2 + assert result == [2, 1] + + +# Test MOD_SHAPE with symbolic values, which should produce a Mod expression. +def test_mod_symbolic_shape_no_target(): + shape_env = ShapeEnv() + s0 = _make_symint(shape_env, "s0", hint=8) + s1 = _make_symint(shape_env, "s1", hint=3) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env + ), FakeTensorMode(shape_env=shape_env) as mode: + s0_tensor = torch.empty(size=(1, 3, s0)) + s1_tensor = torch.empty(size=(1, 3, s1)) + dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2) + dim_s1 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s1_tensor), axis=2) + result = exir_ops.backend.tosa.MOD_SHAPE.default(dim_s0, dim_s1) + assert len(result) == 1 + assert isinstance(result[0], torch.SymInt) + assert _expr_equals(result[0], sympy.Mod(sympy.Symbol("s0"), sympy.Symbol("s1"))) + + +def test_mod_mixed_shape_no_target(): + shape_env = ShapeEnv() + s0 = _make_symint(shape_env, "s0", hint=4) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env + ), FakeTensorMode(shape_env=shape_env) as mode: + const_shape = exir_ops.backend.tosa.CONST_SHAPE.default([8]) + s0_tensor = torch.empty(size=(1, 3, s0)) + dim_s0 = exir_ops.backend.tosa.DIM.default(mode.from_tensor(s0_tensor), axis=2) + result = exir_ops.backend.tosa.MOD_SHAPE.default(const_shape, dim_s0) + + assert len(result) == 1 + assert isinstance(result[0], torch.SymInt) + assert _expr_equals(result[0], sympy.Mod(sympy.Integer(8), sympy.Symbol("s0"))) + + # Test DIV_FLOOR_SHAPE with constant values, which should perform floor division and return a constant shape. def test_div_floor_const_shape(): shape_env = ShapeEnv() diff --git a/backends/arm/tosa/dialect/ops/shape_ops.py b/backends/arm/tosa/dialect/ops/shape_ops.py index 0f36420697b..edeb731620d 100644 --- a/backends/arm/tosa/dialect/ops/shape_ops.py +++ b/backends/arm/tosa/dialect/ops/shape_ops.py @@ -171,3 +171,18 @@ def MUL_SHAPE( """ return _combine_shapes(shape1, shape2, lambda a, b: a * b) + + +@register_fake_tosa_op( + "MOD_SHAPE(SymInt[] shape1, SymInt[] shape2) -> SymInt[]", # schema + TosaSpecification.all_profiles_for_version("1.1"), +) +def MOD_SHAPE( + shape1: list[IntLikeType], + shape2: list[IntLikeType], +) -> list[IntLikeType]: + """MOD_SHAPE operator computes the element-wise modulo of the first shape + tensor by the second. + """ + + return _combine_shapes(shape1, shape2, lambda a, b: a % b)