Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions backends/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2926,6 +2926,34 @@ def _clamp_handler(P: MLXProgramBuilder, n: Node) -> Slot:
return out


@REGISTRY.register(target=[torch.ops.aten.hardtanh.default])
def _hardtanh_handler(P: MLXProgramBuilder, n: Node) -> Slot:
"""Handle aten.hardtanh by clamping input to [min_val, max_val]."""
args = P.args(n)
require_args(args, 1, 3, "aten.hardtanh")
require_kwargs(P.kwargs(n), set(), "aten.hardtanh")

x = args[0]
min_val = float(args[1]) if len(args) > 1 else -1.0
max_val = float(args[2]) if len(args) > 2 else 1.0

x_meta = n.args[0].meta.get("val")
if x_meta is None:
raise ValueError("Input tensor metadata not found for hardtanh")
dtype = x_meta.dtype

out = P.make_or_get_slot(n)
P.emit(
ClipNode(
x=P.slot_to_tid(x),
out=P.slot_to_tid(out),
a_min=P.slot_to_tid(emit_lifted_constant(P, min_val, dtype)),
a_max=P.slot_to_tid(emit_lifted_constant(P, max_val, dtype)),
)
)
return out


@REGISTRY.register(
target=[torch.ops.aten.expand.default, torch.ops.aten.expand_copy.default]
)
Expand Down
57 changes: 57 additions & 0 deletions backends/mlx/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,63 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
return (x,)


class HardtanhModel(nn.Module):
"""Model that applies hardtanh with custom bounds."""

def __init__(self, min_val: float, max_val: float):
super().__init__()
self.min_val = min_val
self.max_val = max_val

def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.hardtanh(
x, min_val=self.min_val, max_val=self.max_val
)


@register_test
class HardtanhTest(OpTestCase):
"""Test case for hardtanh op with various min/max bounds."""

name = "hardtanh"
rtol = 1e-5
atol = 1e-5

def __init__(
self,
shape: Tuple[int, ...] = (2, 3, 4),
min_val: float = -1.0,
max_val: float = 1.0,
):
self.shape = shape
self.min_val = min_val
self.max_val = max_val

shape_str = "x".join(str(s) for s in shape)
self.name = f"hardtanh_min{min_val}_max{max_val}_{shape_str}"

@classmethod
def get_test_configs(cls) -> List["HardtanhTest"]:
return [
# Default bounds
cls(shape=(2, 3, 4), min_val=-1.0, max_val=1.0),
# ReLU6
cls(shape=(4, 8), min_val=0.0, max_val=6.0),
# Symmetric custom bounds
cls(shape=(10,), min_val=-2.0, max_val=2.0),
# Asymmetric custom bounds, higher rank
cls(shape=(2, 8, 16), min_val=-0.25, max_val=0.75),
]

def create_model(self) -> nn.Module:
return HardtanhModel(self.min_val, self.max_val)

def create_inputs(self) -> Tuple[torch.Tensor, ...]:
# Values span well beyond the bounds so clamping is actually exercised
x = torch.randn(self.shape) * 4
return (x,)


class GELUModel(nn.Module):
"""Simple model using GELU activation."""

Expand Down
Loading