From 9e822c10e3ce7b30b806351b3b203296aa45594b Mon Sep 17 00:00:00 2001 From: Devin Lai Date: Fri, 29 May 2026 10:22:39 +0800 Subject: [PATCH] [MLX] Add aten.bitwise_or op handler Add MLX delegate support for aten.bitwise_or Tensor and Scalar overloads, following the existing bitwise_and path through the FlatBuffer schema, Python op handler, C++ interpreter, and tests. Fixes #18926. --- backends/mlx/ops.py | 7 ++++ backends/mlx/runtime/MLXInterpreter.h | 9 +++++ backends/mlx/serialization/schema.fbs | 9 ++++- backends/mlx/test/test_ops.py | 47 +++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 1 deletion(-) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 204e45ba341..8a23aeb196e 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -52,6 +52,7 @@ Atan2Node, BitwiseAndNode, BitwiseInvertNode, + BitwiseOrNode, BroadcastToNode, CeilNode, ClipNode, @@ -490,6 +491,12 @@ def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot: "aten.bitwise_and", True, ), + ( + [torch.ops.aten.bitwise_or.Tensor, torch.ops.aten.bitwise_or.Scalar], + BitwiseOrNode, + "aten.bitwise_or", + True, + ), ( [torch.ops.aten.lt.Tensor, torch.ops.aten.lt.Scalar], LessNode, diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index fb6597d171e..5bb19d4cca9 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -1416,6 +1416,12 @@ inline void exec_bitwise_and( bitwise_and(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); } +inline void +exec_bitwise_or(const BitwiseOrNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, bitwise_or(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + inline void exec_tri(const TriNode& n, ExecutionState& st, StreamOrDevice s) { int rows = resolve_int(n.n, st); int cols = resolve_int(n.m, st); @@ -2069,6 +2075,9 @@ class Interpreter { case OpCode::BITWISE_AND: ops::exec_bitwise_and(std::get(instr.node), st, s); break; + case OpCode::BITWISE_OR: + ops::exec_bitwise_or(std::get(instr.node), st, s); + break; case OpCode::TRI: ops::exec_tri(std::get(instr.node), st, s); break; diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 774e6454926..a7a58a4d878 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -585,6 +585,12 @@ table BitwiseAndNode { out: Tid (required); } +table BitwiseOrNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + // Triangular matrix ops table TriNode { out: Tid (required); @@ -1137,7 +1143,8 @@ union OpNode { MetalKernelNode, BitwiseInvertNode, RollNode, - BitwiseAndNode + BitwiseAndNode, + BitwiseOrNode // BC: Add new op nodes here (append only) } diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 4471610519e..28257e7a0c9 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -4352,6 +4352,8 @@ def create_model(self) -> nn.Module: # logical {"op_name": "bitwise_and_bool", "op_fn": torch.bitwise_and, "shapes": _SHAPES_3, "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()}, {"op_name": "bitwise_and_int", "op_fn": torch.bitwise_and, "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], "input_fn_a": _int_input_fn(0, 256), "input_fn_b": _int_input_fn(0, 256)}, + {"op_name": "bitwise_or_bool", "op_fn": torch.bitwise_or, "shapes": _SHAPES_3, "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()}, + {"op_name": "bitwise_or_int", "op_fn": torch.bitwise_or, "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], "input_fn_a": _int_input_fn(0, 256), "input_fn_b": _int_input_fn(0, 256)}, {"op_name": "logical_and", "op_fn": torch.logical_and, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()}, {"op_name": "logical_or", "op_fn": torch.logical_or, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()}, ] @@ -4409,6 +4411,51 @@ def create_model(self) -> nn.Module: return BitwiseAndScalarModel(self.scalar) +class BitwiseOrScalarModel(nn.Module): + def __init__(self, scalar): + super().__init__() + self.scalar = scalar + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return torch.bitwise_or(a, self.scalar) + + +@register_test +class BitwiseOrScalarTest(OpTestCase): + """Test case for aten.bitwise_or op (Tensor_Scalar variant).""" + + name = "bitwise_or_scalar" + + def __init__( + self, + shape: Tuple[int, ...], + dtype: torch.dtype, + scalar, + ): + self.shape = shape + self.dtype = dtype + self.scalar = scalar + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"bitwise_or_scalar_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["BitwiseOrScalarTest"]: + return [ + cls(shape=(16,), dtype=torch.bool, scalar=True), + cls(shape=(4, 4), dtype=torch.int32, scalar=7), + cls(shape=(2, 3, 4), dtype=torch.int64, scalar=13), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + if self.dtype == torch.bool: + return _bool_input_fn()(self.shape, self.dtype) + return _int_input_fn(0, 256)(self.shape, self.dtype) + + def create_model(self) -> nn.Module: + return BitwiseOrScalarModel(self.scalar) + + @register_test class PowerScalarTest(OpTestCase): """Test case for aten.pow op (Tensor_Scalar variant)."""