Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions backends/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
Atan2Node,
BitwiseAndNode,
BitwiseInvertNode,
BitwiseOrNode,
BroadcastToNode,
CeilNode,
ClipNode,
Expand Down Expand Up @@ -490,6 +491,12 @@ def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot:
"aten.bitwise_and",
True,
),
(
[torch.ops.aten.bitwise_or.Tensor, torch.ops.aten.bitwise_or.Scalar],
BitwiseOrNode,
"aten.bitwise_or",
True,
),
(
[torch.ops.aten.lt.Tensor, torch.ops.aten.lt.Scalar],
LessNode,
Expand Down
9 changes: 9 additions & 0 deletions backends/mlx/runtime/MLXInterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,12 @@ inline void exec_bitwise_and(
bitwise_and(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s));
}

inline void
exec_bitwise_or(const BitwiseOrNode& n, ExecutionState& st, StreamOrDevice s) {
st.set_tensor(
n.out, bitwise_or(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s));
}

inline void exec_tri(const TriNode& n, ExecutionState& st, StreamOrDevice s) {
int rows = resolve_int(n.n, st);
int cols = resolve_int(n.m, st);
Expand Down Expand Up @@ -2069,6 +2075,9 @@ class Interpreter {
case OpCode::BITWISE_AND:
ops::exec_bitwise_and(std::get<BitwiseAndNode>(instr.node), st, s);
break;
case OpCode::BITWISE_OR:
ops::exec_bitwise_or(std::get<BitwiseOrNode>(instr.node), st, s);
break;
case OpCode::TRI:
ops::exec_tri(std::get<TriNode>(instr.node), st, s);
break;
Expand Down
9 changes: 8 additions & 1 deletion backends/mlx/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,12 @@ table BitwiseAndNode {
out: Tid (required);
}

table BitwiseOrNode {
a: Tid (required);
b: Tid (required);
out: Tid (required);
}

// Triangular matrix ops
table TriNode {
out: Tid (required);
Expand Down Expand Up @@ -1137,7 +1143,8 @@ union OpNode {
MetalKernelNode,
BitwiseInvertNode,
RollNode,
BitwiseAndNode
BitwiseAndNode,
BitwiseOrNode
// BC: Add new op nodes here (append only)
}

Expand Down
47 changes: 47 additions & 0 deletions backends/mlx/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4352,6 +4352,8 @@ def create_model(self) -> nn.Module:
# logical
{"op_name": "bitwise_and_bool", "op_fn": torch.bitwise_and, "shapes": _SHAPES_3, "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()},
{"op_name": "bitwise_and_int", "op_fn": torch.bitwise_and, "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], "input_fn_a": _int_input_fn(0, 256), "input_fn_b": _int_input_fn(0, 256)},
{"op_name": "bitwise_or_bool", "op_fn": torch.bitwise_or, "shapes": _SHAPES_3, "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()},
{"op_name": "bitwise_or_int", "op_fn": torch.bitwise_or, "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], "input_fn_a": _int_input_fn(0, 256), "input_fn_b": _int_input_fn(0, 256)},
{"op_name": "logical_and", "op_fn": torch.logical_and, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()},
{"op_name": "logical_or", "op_fn": torch.logical_or, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()},
]
Expand Down Expand Up @@ -4409,6 +4411,51 @@ def create_model(self) -> nn.Module:
return BitwiseAndScalarModel(self.scalar)


class BitwiseOrScalarModel(nn.Module):
def __init__(self, scalar):
super().__init__()
self.scalar = scalar

def forward(self, a: torch.Tensor) -> torch.Tensor:
return torch.bitwise_or(a, self.scalar)


@register_test
class BitwiseOrScalarTest(OpTestCase):
"""Test case for aten.bitwise_or op (Tensor_Scalar variant)."""

name = "bitwise_or_scalar"

def __init__(
self,
shape: Tuple[int, ...],
dtype: torch.dtype,
scalar,
):
self.shape = shape
self.dtype = dtype
self.scalar = scalar
shape_str = "x".join(str(s) for s in shape)
dtype_str = str(dtype).replace("torch.", "")
self.name = f"bitwise_or_scalar_{shape_str}_{dtype_str}"

@classmethod
def get_test_configs(cls) -> List["BitwiseOrScalarTest"]:
return [
cls(shape=(16,), dtype=torch.bool, scalar=True),
cls(shape=(4, 4), dtype=torch.int32, scalar=7),
cls(shape=(2, 3, 4), dtype=torch.int64, scalar=13),
]

def create_inputs(self) -> Tuple[torch.Tensor, ...]:
if self.dtype == torch.bool:
return _bool_input_fn()(self.shape, self.dtype)
return _int_input_fn(0, 256)(self.shape, self.dtype)

def create_model(self) -> nn.Module:
return BitwiseOrScalarModel(self.scalar)


@register_test
class PowerScalarTest(OpTestCase):
"""Test case for aten.pow op (Tensor_Scalar variant)."""
Expand Down
Loading