From 19fa1488486e9d8b309dcd61d1d686b853ffa552 Mon Sep 17 00:00:00 2001 From: atharvjairath <54663702+atharvjairath@users.noreply.github.com> Date: Wed, 27 May 2026 00:44:30 +0530 Subject: [PATCH] Add MLX hardtanh op handler --- backends/mlx/ops.py | 28 +++++++++++++++++ backends/mlx/test/test_ops.py | 57 +++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 204e45ba341..c0dcfa5d661 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -2926,6 +2926,34 @@ def _clamp_handler(P: MLXProgramBuilder, n: Node) -> Slot: return out +@REGISTRY.register(target=[torch.ops.aten.hardtanh.default]) +def _hardtanh_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.hardtanh by clamping input to [min_val, max_val].""" + args = P.args(n) + require_args(args, 1, 3, "aten.hardtanh") + require_kwargs(P.kwargs(n), set(), "aten.hardtanh") + + x = args[0] + min_val = float(args[1]) if len(args) > 1 else -1.0 + max_val = float(args[2]) if len(args) > 2 else 1.0 + + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for hardtanh") + dtype = x_meta.dtype + + out = P.make_or_get_slot(n) + P.emit( + ClipNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + a_min=P.slot_to_tid(emit_lifted_constant(P, min_val, dtype)), + a_max=P.slot_to_tid(emit_lifted_constant(P, max_val, dtype)), + ) + ) + return out + + @REGISTRY.register( target=[torch.ops.aten.expand.default, torch.ops.aten.expand_copy.default] ) diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 4471610519e..4ebead6137e 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -347,6 +347,63 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: return (x,) +class HardtanhModel(nn.Module): + """Model that applies hardtanh with custom bounds.""" + + def __init__(self, min_val: float, max_val: float): + super().__init__() + self.min_val = min_val + self.max_val = max_val + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.hardtanh( + x, min_val=self.min_val, max_val=self.max_val + ) + + +@register_test +class HardtanhTest(OpTestCase): + """Test case for hardtanh op with various min/max bounds.""" + + name = "hardtanh" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + min_val: float = -1.0, + max_val: float = 1.0, + ): + self.shape = shape + self.min_val = min_val + self.max_val = max_val + + shape_str = "x".join(str(s) for s in shape) + self.name = f"hardtanh_min{min_val}_max{max_val}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["HardtanhTest"]: + return [ + # Default bounds + cls(shape=(2, 3, 4), min_val=-1.0, max_val=1.0), + # ReLU6 + cls(shape=(4, 8), min_val=0.0, max_val=6.0), + # Symmetric custom bounds + cls(shape=(10,), min_val=-2.0, max_val=2.0), + # Asymmetric custom bounds, higher rank + cls(shape=(2, 8, 16), min_val=-0.25, max_val=0.75), + ] + + def create_model(self) -> nn.Module: + return HardtanhModel(self.min_val, self.max_val) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Values span well beyond the bounds so clamping is actually exercised + x = torch.randn(self.shape) * 4 + return (x,) + + class GELUModel(nn.Module): """Simple model using GELU activation."""