From 90c78e5c5d53319d0a35852248c3fd46fd8233f8 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 28 May 2026 18:01:04 -0700 Subject: [PATCH 1/8] up --- backends/mlx/llm/turboquant_cache.py | 227 ++++++++++++++ backends/mlx/model_ops/test_tq4_compress.py | 183 +++++++++++ backends/mlx/model_ops/test_tq_dequant.py | 166 ++++++++++ backends/mlx/model_ops/test_tq_norm.py | 150 +++++++++ backends/mlx/model_ops/tq4_compress.py | 280 +++++++++++++++++ backends/mlx/model_ops/tq_dequant.py | 289 ++++++++++++++++++ backends/mlx/model_ops/tq_norm.py | 254 +++++++++++++++ backends/mlx/test/op_test_runner.cpp | 12 + backends/mlx/test/test_utils.py | 5 + examples/models/gemma4_31b/README.md | 18 ++ examples/models/gemma4_31b/export.py | 45 ++- .../gemma4_31b/mlx_source_transformations.py | 73 +++-- 12 files changed, 1673 insertions(+), 29 deletions(-) create mode 100644 backends/mlx/llm/turboquant_cache.py create mode 100644 backends/mlx/model_ops/test_tq4_compress.py create mode 100644 backends/mlx/model_ops/test_tq_dequant.py create mode 100644 backends/mlx/model_ops/test_tq_norm.py create mode 100644 backends/mlx/model_ops/tq4_compress.py create mode 100644 backends/mlx/model_ops/tq_dequant.py create mode 100644 backends/mlx/model_ops/tq_norm.py diff --git a/backends/mlx/llm/turboquant_cache.py b/backends/mlx/llm/turboquant_cache.py new file mode 100644 index 00000000000..265dc0b89b5 --- /dev/null +++ b/backends/mlx/llm/turboquant_cache.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +TurboQuant TQ4 KV cache for the MLX backend. + +Subclass of the backend-agnostic +``extension/llm/modules/turboquant/kv_cache.py::TurboQuantKVCache``. 
+ +The cache stores K and V in **rotated space** (post-multiplied by R^T) +as nibble-packed uint8 codebook indices plus per-vector bf16 norms. +SDPA runs in rotated space and undoes the rotation on the output side +(both Q and output rotations are ``T_q × D²``, much smaller than +applying the inverse rotation to K/V which would be ``T_kv × D²``). + +Reference: + TurboQuant: Online Vector Quantization with Near-optimal + Distortion Rate. arXiv:2504.19874 (ICLR 2026). +""" + +from typing import Optional, Tuple + +# Register the MLX custom ops used by this cache. +import executorch.backends.mlx.custom_ops # noqa: F401 mlx::custom_sdpa, mlx::kv_cache_update +import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 mlx::tq4_compress +import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 mlx::tq_dequant +import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 mlx::tq_norm + +import torch + +from executorch.extension.llm.modules.turboquant.kv_cache import ( + TurboQuantKVCache as _SharedTurboQuantKVCache, +) + + +class TurboQuantKVCache(_SharedTurboQuantKVCache): + """ + TurboQuant TQ4 KV cache, MLX-backend variant. + + Drop-in replacement for ``backends/mlx/llm/cache.py::KVCache``. + + Args: + max_batch_size: Must be 1 (TQ4 is batch=1 only). + max_context_length: Maximum sequence length. + n_heads: Number of KV heads. + head_dim: Per-head dimension. Must be even and a multiple of 32. + enable_dynamic_shape: Accepted for interface parity; ignored. + dtype: Compute dtype (bf16). Used for pre-cast buffers. + bits: Quantization bits (must be 4). + seed: RNG seed for the orthogonal rotation matrix. 
+ """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool, + dtype: torch.dtype = torch.bfloat16, + bits: int = 4, + seed: int = 42, + ): + if max_batch_size != 1: + raise ValueError( + f"TurboQuantKVCache only supports max_batch_size=1, " + f"got {max_batch_size}" + ) + super().__init__( + n_heads=n_heads, + head_dim=head_dim, + max_seq_len=max_context_length, + bits=bits, + seed=seed, + ) + self.max_batch_size = max_batch_size + self.max_context_length = max_context_length + self.enable_dynamic_shape = enable_dynamic_shape + + # Replace parent's fp32 ``rotation`` and ``centroids`` buffers + # with compute-dtype versions in-place. Avoids a per-call + # ``_to_copy`` cast in the lowered graph at every use site. + # Parent's ``_decompress`` (testing-only) is the sole consumer + # of these as fp32 and is not called at runtime. + self.register_buffer( + "rotation", + self.rotation.to(dtype).contiguous(), + persistent=False, + ) + self.register_buffer( + "centroids", + self.centroids.to(dtype).contiguous(), + persistent=False, + ) + # Pre-cast eps for the divide-by-zero guard in _compress. + self.register_buffer( + "norm_eps", + torch.tensor(1e-10, dtype=dtype), + persistent=False, + ) + + def _compress(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Compress ``(1, H, T, D)`` → packed ``(1, H, T, D//2)`` u8 + + norms ``(1, H, T, 1)`` bf16. + + The L2-norm reduction uses ``mlx::tq_norm`` (one Metal kernel + with fp32 sum-of-squares in registers via ``simd_sum``); the + bucketize + nibble-pack tail uses ``mlx::tq4_compress`` (one + Metal kernel for both steps). 
+ """ + orig_shape = x.shape + flat = x.reshape(-1, self.head_dim) + + norms = torch.ops.mlx.tq_norm(flat) + normalized = flat / (norms + self.norm_eps) + rotated = normalized @ self.rotation_T + packed = torch.ops.mlx.tq4_compress(rotated, self.boundaries) + + return ( + packed.reshape(*orig_shape[:-1], self.half_dim), + norms.reshape(*orig_shape[:-1], 1), + ) + + def update( + self, + input_pos, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compress + write K/V at ``input_pos``, return the full + compressed cache buffers. + + Accepts ``input_pos`` as either a ``(T,)`` LongTensor of + positions or a Python int / SymInt ``start_pos``. Writes go + through ``mlx::kv_cache_update`` (matching the non-TQ + ``MLXKVCache`` path) which lowers to a tighter in-place + scatter than ``index_copy_`` would. + """ + if isinstance(input_pos, torch.Tensor): + start_pos = input_pos[0].item() + seq_len = k_val.size(2) + torch._check(seq_len == v_val.size(2)) + torch._check(start_pos >= 0) + torch._check(start_pos + seq_len <= self.max_context_length) + else: + start_pos = input_pos + + k_packed, k_norms = self._compress(k_val) + v_packed, v_norms = self._compress(v_val) + + torch.ops.mlx.kv_cache_update(self.k_packed, k_packed, start_pos) + torch.ops.mlx.kv_cache_update(self.k_norms, k_norms, start_pos) + torch.ops.mlx.kv_cache_update(self.v_packed, v_packed, start_pos) + torch.ops.mlx.kv_cache_update(self.v_norms, v_norms, start_pos) + + # Slices on the return create new graph nodes so the same node + # is not both BUFFER_MUTATION and USER_OUTPUT. + return ( + self.k_packed[:, :, :, :], + self.k_norms[:, :, :, :], + self.v_packed[:, :, :, :], + self.v_norms[:, :, :, :], + ) + + # forward() is inherited from the parent (delegates to update). + + def sdpa( + self, + query: torch.Tensor, + start_pos, + scale: Optional[float] = None, + ) -> torch.Tensor: + """SDPA over the compressed cache. 
+ + Runs attention in rotated space: + 1. Q_rot = Q @ R^T (T_q x D^2) + 2. K_rot, V_rot = tq_dequant(...) (rotated-space K/V) + 3. out_rot = custom_sdpa(Q_rot, K_rot, V_rot, ...) + 4. out = out_rot @ R (T_q x D^2) + + Since R is orthogonal, score = (Q·R^T)·(K·R^T)^T = Q·K^T, so + attention is invariant under matched rotation of Q and K. The + ``T_kv x D^2`` inverse-rotation matmul on K/V is replaced with + two ``T_q x D^2`` matmuls (Q and output). + + Args: + query: ``(B, H_q, T_q, D)`` bf16. + start_pos: int or SymInt — absolute position of the first + query token. + scale: 1/sqrt(D) if None. + + Returns: + ``(B, H_q, T_q, D)`` bf16 attention output, in original + (un-rotated) space. + """ + seq_len = query.size(2) + end_pos = start_pos + seq_len + torch._check(start_pos >= 0) + torch._check(end_pos <= self.max_context_length) + + q_rot = query @ self.rotation_T + + k_packed_live = self.k_packed[:, :, :end_pos, :] + k_norms_live = self.k_norms[:, :, :end_pos, :] + v_packed_live = self.v_packed[:, :, :end_pos, :] + v_norms_live = self.v_norms[:, :, :end_pos, :] + + k_rot = torch.ops.mlx.tq_dequant(k_packed_live, k_norms_live, self.centroids) + v_rot = torch.ops.mlx.tq_dequant(v_packed_live, v_norms_live, self.centroids) + + out_rot = torch.ops.mlx.custom_sdpa( + q_rot, + k_rot, + v_rot, + start_pos, + None, # attn_mask + 0.0, # dropout_p + True, # is_causal + scale, + ) + + return out_rot @ self.rotation diff --git a/backends/mlx/model_ops/test_tq4_compress.py b/backends/mlx/model_ops/test_tq4_compress.py new file mode 100644 index 00000000000..c2aaa13afa7 --- /dev/null +++ b/backends/mlx/model_ops/test_tq4_compress.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for ``mlx::tq4_compress``. 
+ +Verifies the fused Metal kernel produces byte-exact output vs the +eager Python implementation across head_dim values used by TurboQuant. + +Usage:: + + python -m executorch.backends.mlx.model_ops.test_tq4_compress run + python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v + python -m executorch.backends.mlx.model_ops.test_tq4_compress run --rebuild +""" + +from typing import List, Tuple + +import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 + +import torch +import torch.nn as nn + +from executorch.backends.mlx.test.test_utils import OpTestCase + + +class TQ4CompressModel(nn.Module): + """``values → packed`` via ``mlx::tq4_compress``. + + Boundaries are stored as a buffer so the model is exportable + without feeding them as a graph input. + """ + + def __init__(self, head_dim: int, dtype: torch.dtype = torch.bfloat16): + super().__init__() + # 15 sorted thresholds (4-bit codebook). + self.register_buffer( + "boundaries", + torch.linspace(-0.2, 0.2, 15, dtype=dtype), + ) + + def forward(self, values: torch.Tensor) -> torch.Tensor: + return torch.ops.mlx.tq4_compress(values, self.boundaries) + + +class TQ4CompressTest(OpTestCase): + """Byte-exact comparison vs eager bucketize + nibble-pack.""" + + name = "tq4_compress" + rtol = 0.0 + atol = 0.0 + + def __init__( + self, + batch_size: int = 1, + n_heads: int = 8, + seq_len: int = 4, + head_dim: int = 128, + dtype: torch.dtype = torch.bfloat16, + ): + self.batch_size = batch_size + self.n_heads = n_heads + self.seq_len = seq_len + self.head_dim = head_dim + self.dtype = dtype + + parts = [ + "tq4_compress", + f"b{batch_size}", + f"h{n_heads}", + f"t{seq_len}", + f"d{head_dim}", + ] + if dtype != torch.bfloat16: + parts.append(str(dtype).split(".")[-1]) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["TQ4CompressTest"]: + return [ + # head_dim=128 (Qwen3.5 MoE / Gemma 4 sliding) + cls(seq_len=1, head_dim=128), + cls(seq_len=8, head_dim=128), + 
cls(seq_len=64, head_dim=128), + cls(n_heads=1, seq_len=1, head_dim=128), + # head_dim=256 (Gemma 4 sliding-attention) + cls(head_dim=256), + cls(seq_len=16, head_dim=256), + # head_dim=512 (Gemma 4 31B full-attention) + cls(n_heads=4, seq_len=4, head_dim=512), + cls(n_heads=4, seq_len=64, head_dim=512), + # Smaller D for sanity + cls(head_dim=64, n_heads=2, seq_len=4), + ] + + def create_model(self) -> nn.Module: + return TQ4CompressModel(head_dim=self.head_dim, dtype=self.dtype).to(self.dtype) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Activation-scale values; the kernel is byte-exact regardless + # of magnitude as long as values fall within the bucketize + # comparison range. + values = torch.randn( + self.batch_size, + self.n_heads, + self.seq_len, + self.head_dim, + dtype=self.dtype, + ) * (1.0 / (self.head_dim**0.5)) + return (values,) + + +if __name__ == "__main__": # noqa: C901 + import argparse + import sys + + from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner + + parser = argparse.ArgumentParser(description="Test mlx::tq4_compress op") + parser.add_argument( + "action", + choices=["generate", "compare", "run", "list"], + help="Action: generate (export), compare (check outputs), run (full), list (show configs)", + ) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument( + "--rebuild", action="store_true", help="Rebuild C++ runner first" + ) + parser.add_argument( + "--config", type=str, default=None, help="Run specific config by name" + ) + args = parser.parse_args() + + if args.rebuild and not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + configs = TQ4CompressTest.get_test_configs() + + if args.action == "list": + for cfg in configs: + print(f" {cfg.name}") + sys.exit(0) + + if args.config: + configs = [c for c in configs if c.name == args.config] + if not configs: + print(f"No config matching '{args.config}'") + sys.exit(1) + + passed = 0 + failed = 0 + 
failed_names: List[str] = [] + + for test in configs: + if args.action == "generate": + pte_path, _, _ = test.generate_test_files(verbose=args.verbose) + print(f"Generated: {pte_path}") + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + ok, msg = test.compare_with_actual(actual_path) + print(f"{'✓' if ok else '✗'} {test.name}: {msg}") + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + elif args.action == "run": + ok = test.run_test(verbose=args.verbose) + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + + if args.action in ("run", "compare"): + print(f"\nPassed: {passed}, Failed: {failed}") + if failed_names: + print(f"Failed: {', '.join(failed_names)}") + sys.exit(0 if failed == 0 else 1) diff --git a/backends/mlx/model_ops/test_tq_dequant.py b/backends/mlx/model_ops/test_tq_dequant.py new file mode 100644 index 00000000000..07d9deb895a --- /dev/null +++ b/backends/mlx/model_ops/test_tq_dequant.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for ``mlx::tq_dequant``. + +Verifies the fused unpack + gather + multiply Metal kernel matches +the eager reference at head_dim values used by TurboQuant +(D ∈ {128, 256, 512}). Output is byte-exact — no fp32 promotion in +either path. 
+ +Usage:: + + python -m executorch.backends.mlx.model_ops.test_tq_dequant run + python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v + python -m executorch.backends.mlx.model_ops.test_tq_dequant run --rebuild +""" + +from typing import List, Tuple + +import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 + +import torch +import torch.nn as nn + +from executorch.backends.mlx.test.test_utils import OpTestCase + + +class TQDequantModel(nn.Module): + """``packed, norms, centroids → dequantized values (rotated space)``. + + Thin wrapper that forwards its three inputs to ``mlx::tq_dequant``; + no inverse rotation is applied here. + """ + + def forward( + self, + packed: torch.Tensor, + norms: torch.Tensor, + centroids: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.mlx.tq_dequant(packed, norms, centroids) + + +class TQDequantTest(OpTestCase): + """Byte-exact comparison vs eager unpack + gather + multiply.""" + + name = "tq_dequant" + rtol = 0.0 + atol = 0.0 + + def __init__( + self, + batch_size: int = 1, + n_heads: int = 8, + seq_len: int = 4, + head_dim: int = 128, + ): + self.batch_size = batch_size + self.n_heads = n_heads + self.seq_len = seq_len + self.head_dim = head_dim + self.half_dim = head_dim // 2 + self.name = f"tq_dequant_b{batch_size}_h{n_heads}_t{seq_len}_d{head_dim}" + + @classmethod + def get_test_configs(cls) -> List["TQDequantTest"]: + return [ + # head_dim=128 (Qwen3.5 MoE / Gemma 4 sliding) + cls(seq_len=1, head_dim=128), + cls(seq_len=8, head_dim=128), + cls(seq_len=64, head_dim=128), + cls(n_heads=1, seq_len=1, head_dim=128), + # head_dim=256 (Gemma 4 sliding-attention) + cls(seq_len=4, head_dim=256), + cls(seq_len=16, head_dim=256), + # head_dim=512 (Gemma 4 31B full-attention) + cls(n_heads=4, seq_len=4, head_dim=512), + cls(n_heads=4, seq_len=64, head_dim=512), + ] + + def create_model(self) -> nn.Module: + return TQDequantModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Random packed bytes exercise every codebook entry.
+ packed = torch.randint( + 0, + 256, + (self.batch_size, self.n_heads, self.seq_len, self.half_dim), + dtype=torch.uint8, + ) + norms = ( + torch.randn( + self.batch_size, + self.n_heads, + self.seq_len, + 1, + dtype=torch.bfloat16, + ).abs() + + 0.1 + ) + # Deterministic codebook covering [-1, 1]. + centroids = torch.linspace(-1.0, 1.0, 16, dtype=torch.bfloat16) + return (packed, norms, centroids) + + +if __name__ == "__main__": # noqa: C901 + import argparse + import sys + + from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner + + parser = argparse.ArgumentParser(description="Test mlx::tq_dequant op") + parser.add_argument("action", choices=["generate", "compare", "run", "list"]) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--rebuild", action="store_true") + parser.add_argument("--config", type=str, default=None) + args = parser.parse_args() + + if args.rebuild and not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + configs = TQDequantTest.get_test_configs() + + if args.action == "list": + for cfg in configs: + print(f" {cfg.name}") + sys.exit(0) + + if args.config: + configs = [c for c in configs if c.name == args.config] + if not configs: + print(f"No config matching '{args.config}'") + sys.exit(1) + + passed = 0 + failed = 0 + failed_names: List[str] = [] + + for test in configs: + if args.action == "generate": + pte_path, _, _ = test.generate_test_files(verbose=args.verbose) + print(f"Generated: {pte_path}") + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + ok, msg = test.compare_with_actual(actual_path) + print(f"{'✓' if ok else '✗'} {test.name}: {msg}") + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + elif args.action == "run": + ok = test.run_test(verbose=args.verbose) + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + + if args.action in ("run", "compare"): + print(f"\nPassed: 
{passed}, Failed: {failed}") + if failed_names: + print(f"Failed: {', '.join(failed_names)}") + sys.exit(0 if failed == 0 else 1) diff --git a/backends/mlx/model_ops/test_tq_norm.py b/backends/mlx/model_ops/test_tq_norm.py new file mode 100644 index 00000000000..35c4491d8ae --- /dev/null +++ b/backends/mlx/model_ops/test_tq_norm.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for ``mlx::tq_norm``. + +Verifies the fused L2-norm Metal kernel matches eager ``vector_norm`` +at head_dim values used by TurboQuant (D ∈ {128, 256, 512}). + +Usage:: + + python -m executorch.backends.mlx.model_ops.test_tq_norm run + python -m executorch.backends.mlx.model_ops.test_tq_norm run -v + python -m executorch.backends.mlx.model_ops.test_tq_norm run --rebuild +""" + +from typing import List, Tuple + +import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 + +import torch +import torch.nn as nn + +from executorch.backends.mlx.test.test_utils import OpTestCase + + +class TQNormModel(nn.Module): + """``x → ||x||₂`` over the last dim.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.mlx.tq_norm(x) + + +class TQNormTest(OpTestCase): + """Compare ``mlx::tq_norm`` to eager ``vector_norm`` within bf16 ULPs.""" + + name = "tq_norm" + rtol = 1e-2 + atol = 1e-2 + + def __init__( + self, + batch_size: int = 1, + n_heads: int = 8, + seq_len: int = 4, + head_dim: int = 128, + ): + self.batch_size = batch_size + self.n_heads = n_heads + self.seq_len = seq_len + self.head_dim = head_dim + self.name = f"tq_norm_b{batch_size}_h{n_heads}_t{seq_len}_d{head_dim}" + + @classmethod + def get_test_configs(cls) -> List["TQNormTest"]: + return [ + # head_dim=128 (Qwen3.5 MoE / Gemma 4 sliding) + cls(seq_len=1, head_dim=128), + cls(seq_len=8, 
head_dim=128), + cls(seq_len=64, head_dim=128), + cls(n_heads=1, seq_len=1, head_dim=128), + # head_dim=256 (Gemma 4 sliding-attention) + cls(seq_len=4, head_dim=256), + cls(seq_len=16, head_dim=256), + # head_dim=512 (Gemma 4 31B full-attention) + cls(n_heads=4, seq_len=4, head_dim=512), + cls(n_heads=4, seq_len=64, head_dim=512), + ] + + def create_model(self) -> nn.Module: + return TQNormModel().to(torch.bfloat16) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Activation-scale bf16 inputs. + x = torch.randn( + self.batch_size, + self.n_heads, + self.seq_len, + self.head_dim, + dtype=torch.bfloat16, + ) * (1.0 / (self.head_dim**0.5)) + return (x,) + + +if __name__ == "__main__": # noqa: C901 + import argparse + import sys + + from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner + + parser = argparse.ArgumentParser(description="Test mlx::tq_norm op") + parser.add_argument( + "action", + choices=["generate", "compare", "run", "list"], + ) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--rebuild", action="store_true") + parser.add_argument("--config", type=str, default=None) + args = parser.parse_args() + + if args.rebuild and not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + configs = TQNormTest.get_test_configs() + + if args.action == "list": + for cfg in configs: + print(f" {cfg.name}") + sys.exit(0) + + if args.config: + configs = [c for c in configs if c.name == args.config] + if not configs: + print(f"No config matching '{args.config}'") + sys.exit(1) + + passed = 0 + failed = 0 + failed_names: List[str] = [] + + for test in configs: + if args.action == "generate": + pte_path, _, _ = test.generate_test_files(verbose=args.verbose) + print(f"Generated: {pte_path}") + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + ok, msg = test.compare_with_actual(actual_path) + print(f"{'✓' if ok else '✗'} {test.name}: {msg}") + if ok: + passed 
+= 1 + else: + failed += 1 + failed_names.append(test.name) + elif args.action == "run": + ok = test.run_test(verbose=args.verbose) + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + + if args.action in ("run", "compare"): + print(f"\nPassed: {passed}, Failed: {failed}") + if failed_names: + print(f"Failed: {', '.join(failed_names)}") + sys.exit(0 if failed == 0 else 1) diff --git a/backends/mlx/model_ops/tq4_compress.py b/backends/mlx/model_ops/tq4_compress.py new file mode 100644 index 00000000000..0d5c9748200 --- /dev/null +++ b/backends/mlx/model_ops/tq4_compress.py @@ -0,0 +1,280 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +``mlx::tq4_compress``: TurboQuant TQ4 quantize + nibble-pack. + +Maps ``(..., D)`` floats to ``(..., D/2)`` uint8 by: + 1. Bucketizing each value against ``boundaries`` (15 sorted thresholds). + 2. Packing pairs of 4-bit indices into one byte: high nibble holds + the even-position index, low nibble holds the odd-position index. + +Constraints: + * ``boundaries`` must be 1-D length 15 (4-bit codebook). + * Last dim of ``values`` must be even and statically known. + +Usage:: + + import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 + + packed = torch.ops.mlx.tq4_compress(rotated, boundaries) + # rotated: (..., D) float + # boundaries: (15,) same dtype as rotated + # packed: (..., D/2) uint8 +""" + +from __future__ import annotations + +from functools import reduce +from operator import mul +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.fx.node import Node + + +@torch.library.custom_op("mlx::tq4_compress", mutates_args=()) +def tq4_compress(values: Tensor, boundaries: Tensor) -> Tensor: + """TurboQuant TQ4 quantize + nibble-pack. 
+ + Args: + values: ``(..., D)`` float, last dim must be even. + boundaries: ``(15,)`` 1-D sorted, same dtype as ``values``. + + Returns: + ``(..., D/2)`` uint8. Each byte holds two 4-bit indices: high + nibble is the even-position index, low nibble is the odd. + """ + if boundaries.dim() != 1 or boundaries.shape[0] != 15: + raise ValueError( + f"mlx::tq4_compress: boundaries must be 1-D length 15; " + f"got shape {tuple(boundaries.shape)}" + ) + if values.shape[-1] % 2 != 0: + raise ValueError( + f"mlx::tq4_compress: input last dim must be even; got " + f"{values.shape[-1]}" + ) + + indices = torch.bucketize(values, boundaries).to(torch.uint8) + packed = (indices[..., 0::2] << 4) | indices[..., 1::2] + return packed + + +@torch.library.register_fake("mlx::tq4_compress") +def tq4_compress_fake(values: Tensor, boundaries: Tensor) -> Tensor: + out_shape = list(values.shape) + out_shape[-1] = out_shape[-1] // 2 + return values.new_empty(out_shape, dtype=torch.uint8) + + +# --------------------------------------------------------------------------- +# MLX handler +# --------------------------------------------------------------------------- + +from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MetalKernelNode, + MultiplyIntNode, + SymSizeNode, +) + + +# One thread per output byte: reads ``values[2*gid]``, ``values[2*gid+1]``, +# bucketizes against the 15 boundaries (loop unrolled, ``B`` is a template +# constant), and packs the two 4-bit indices into one byte. 
+_TQ4_COMPRESS_SOURCE = """ + uint gid = thread_position_in_grid.x; + if (gid >= N_OUT) return; + float v_hi = float(values[2 * gid]); + float v_lo = float(values[2 * gid + 1]); + uchar idx_hi = 0; + uchar idx_lo = 0; + #pragma unroll + for (uint i = 0; i < B; ++i) { + float bnd = float(boundaries[i]); + idx_hi += (uchar)(v_hi > bnd); + idx_lo += (uchar)(v_lo > bnd); + } + out[gid] = (idx_hi << 4) | idx_lo; +""" + + +def _compute_output_numel(P: MLXProgramBuilder, node: Node) -> Union[int, IntOrVid]: + """Output numel = numel(input) / 2. Returns a static int when the + full shape is known, else an IntOrVid built from SymSize + + MultiplyInt nodes.""" + val = node.meta.get("val") + if val is None: + raise ValueError("mlx::tq4_compress: input node has no meta['val']") + shape = val.shape + + if all(isinstance(s, int) for s in shape): + return reduce(mul, [int(s) for s in shape], 1) // 2 + + in_slot = P.slot_map([node])[0] + in_tid = P.slot_to_tid(in_slot) + + last_idx = len(shape) - 1 + acc_iov: Optional[IntOrVid] = None + for dim_idx in range(len(shape)): + s = shape[dim_idx] + if isinstance(s, int): + d = int(s) + if dim_idx == last_idx: + d //= 2 + d_iov = IntOrVid.from_literal(d) + else: + if dim_idx == last_idx: + # The schema has no integer-divide-by-Vid op; require the + # last dim be static so the /2 stays a literal. 
+ raise NotImplementedError( + "mlx::tq4_compress: dynamic last-dim is not supported" + ) + _, d_val = P.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=in_tid, + dim=dim_idx, + out=P.slot_to_vid(d_val), + ) + ) + d_iov = P.to_int_or_vid(d_val) + + if acc_iov is None: + acc_iov = d_iov + else: + _, acc_val = P.make_tmp_value_slot() + P.emit( + MultiplyIntNode( + a=acc_iov, + b=d_iov, + out=P.slot_to_vid(acc_val), + ) + ) + acc_iov = P.to_int_or_vid(acc_val) + + assert acc_iov is not None + return acc_iov + + +def _output_shape_flat(P: MLXProgramBuilder, node: Node, in_slot: Slot) -> list: + """Output shape: same as input but with last dim halved.""" + val = node.meta["val"] + shape = val.shape + last_idx = len(shape) - 1 + out: list = [] + for dim_idx, s in enumerate(shape): + if isinstance(s, int): + d = int(s) + if dim_idx == last_idx: + d //= 2 + out.append(IntOrVid.from_literal(d)) + else: + if dim_idx == last_idx: + raise NotImplementedError( + "mlx::tq4_compress: dynamic last-dim is not supported" + ) + _, d_val = P.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(in_slot), + dim=dim_idx, + out=P.slot_to_vid(d_val), + ) + ) + out.append(P.to_int_or_vid(d_val)) + return out + + +@REGISTRY.register(target=[torch.ops.mlx.tq4_compress.default]) +def _tq4_compress_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Lower ``mlx::tq4_compress`` to a fused Metal kernel.""" + args = P.args(n) + if len(args) != 2: + raise ValueError( + f"mlx::tq4_compress: expected 2 args (values, boundaries), " + f"got {len(args)}" + ) + + values_slot, boundaries_slot = args + values_node = n.args[0] + boundaries_node = n.args[1] + + values_meta = values_node.meta["val"] + boundaries_meta = boundaries_node.meta["val"] + + # Validate boundaries length: must be 15 for 4-bit nibble pack. 
+ bnd_shape = boundaries_meta.shape + if ( + len(bnd_shape) != 1 + or not isinstance(bnd_shape[0], int) + or int(bnd_shape[0]) != 15 + ): + raise ValueError( + f"mlx::tq4_compress: boundaries must be 1-D length 15; " + f"got shape {tuple(bnd_shape)}" + ) + + last_dim = values_meta.shape[-1] + if not isinstance(last_dim, int): + raise NotImplementedError( + "mlx::tq4_compress: last dim must be statically known" + ) + if int(last_dim) % 2 != 0: + raise ValueError(f"mlx::tq4_compress: last dim must be even; got {last_dim}") + + in_dtype_int = torch_dtype_to_scalar_type(values_meta.dtype) + + out = P.make_or_get_slot(n) + out_shape_flat = _output_shape_flat(P, values_node, values_slot) + + n_out = _compute_output_numel(P, values_node) + n_out_iov: IntOrVid = ( + IntOrVid.from_literal(int(n_out)) if isinstance(n_out, int) else n_out + ) + + P.emit( + MetalKernelNode( + name="tq4_compress", + source=_TQ4_COMPRESS_SOURCE, + inputs=[ + P.slot_to_tid(values_slot), + P.slot_to_tid(boundaries_slot), + ], + outputs=[P.slot_to_tid(out)], + grid=[n_out_iov, IntOrVid.from_literal(1), IntOrVid.from_literal(1)], + threadgroup=[ + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=["values", "boundaries"], + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[torch_dtype_to_scalar_type(torch.uint8)], + template_arg_names=["InT", "B", "N_OUT"], + template_arg_kinds=[2, 0, 0], # 2=dtype, 0=int + template_arg_values=[ + in_dtype_int, + 15, + int(n_out) if isinstance(n_out, int) else (1 << 30), + ], + ) + ) + + return out + + +_registered = True diff --git a/backends/mlx/model_ops/tq_dequant.py b/backends/mlx/model_ops/tq_dequant.py new file mode 100644 index 00000000000..4b6605e903e --- /dev/null +++ b/backends/mlx/model_ops/tq_dequant.py @@ -0,0 +1,289 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +``mlx::tq_dequant``: TurboQuant TQ4 unpack + centroid gather + multiply-by-norm. + + indices = unpack 4-bit nibbles from packed bytes (..., D) + centvals = centroids[indices] (..., D) + out = centvals * norms (..., D) + +Output is in **rotated space** — the inverse rotation, if needed, is +left to the caller (typically MLX's tuned bf16 GEMM). + +Constraints: + * ``D`` (= ``packed.shape[-1] * 2``) must be a multiple of 32. + * ``centroids`` must be a 1-D tensor of length 16. + * Output dtype matches ``norms.dtype``. + +Usage:: + + import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 + + out = torch.ops.mlx.tq_dequant(packed, norms, centroids) + # packed: (..., D/2) uint8 + # norms: (..., 1) bf16 + # centroids: (16,) bf16 + # out: (..., D) bf16 (in rotated space) +""" + +from __future__ import annotations + +from functools import reduce +from operator import mul +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.fx.node import Node + + +# --------------------------------------------------------------------------- +# Custom op + eager fallback +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("mlx::tq_dequant", mutates_args=()) +def tq_dequant( + packed: Tensor, + norms: Tensor, + centroids: Tensor, +) -> Tensor: + """Fused unpack + centroid gather + multiply-by-norm. + + Args: + packed: ``(..., D/2)`` uint8. High nibble = even-position index, + low nibble = odd-position index. + norms: ``(..., 1)`` of compute dtype, broadcasts over D. + centroids: ``(16,)`` of compute dtype. + + Returns: + ``(..., D)`` of compute dtype, in rotated space. 
+ """ + if centroids.dim() != 1 or centroids.shape[0] != 16: + raise ValueError( + f"mlx::tq_dequant: centroids must be 1-D length 16; got " + f"shape {tuple(centroids.shape)}" + ) + high = (packed >> 4).long() + low = (packed & 0x0F).long() + indices = torch.stack([high, low], dim=-1).reshape( + *packed.shape[:-1], packed.shape[-1] * 2 + ) + return centroids[indices] * norms + + +@torch.library.register_fake("mlx::tq_dequant") +def tq_dequant_fake(packed: Tensor, norms: Tensor, centroids: Tensor) -> Tensor: + out_shape = list(packed.shape) + out_shape[-1] = out_shape[-1] * 2 + return packed.new_empty(out_shape, dtype=norms.dtype) + + +# --------------------------------------------------------------------------- +# MLX handler +# --------------------------------------------------------------------------- + +from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MetalKernelNode, + MultiplyIntNode, + SymSizeNode, +) + + +_TQ_DEQUANT_HEADER = """ +#include <metal_stdlib> +using namespace metal; +""" + + +# Per-vector decompress: +# * Grid (32, 1, M), threadgroup (32, 1, 1): one simdgroup per vector. +# * Each lane handles DIMS_PER_LANE = D/32 output values, sourced +# from BYTES_PER_LANE = DIMS_PER_LANE/2 packed bytes. NOTE(review): both are integer divisions, so D must be a multiple of 64 (not only 32) or some outputs are never written. +# * The 16-entry codebook is preloaded into per-lane registers. 
+_TQ_DEQUANT_SOURCE = """ + constexpr uint DIMS_PER_LANE = D / 32; + constexpr uint BYTES_PER_LANE = DIMS_PER_LANE / 2; + + uint vec_id = thread_position_in_grid.z; + uint lane_id = thread_position_in_threadgroup.x; + + InT cent[16]; + for (uint c = 0; c < 16; ++c) { + cent[c] = centroids[c]; + } + + InT norm = norms[vec_id]; + + uint packed_base = vec_id * (D / 2) + lane_id * BYTES_PER_LANE; + uint out_base = vec_id * D + lane_id * DIMS_PER_LANE; + + for (uint i = 0; i < BYTES_PER_LANE; ++i) { + uchar byte = packed[packed_base + i]; + uchar idx_hi = (byte >> 4) & 0x0F; + uchar idx_lo = byte & 0x0F; + out[out_base + 2 * i + 0] = cent[idx_hi] * norm; + out[out_base + 2 * i + 1] = cent[idx_lo] * norm; + } +""" + + +def _compute_M(P: MLXProgramBuilder, packed_node: Node) -> Union[int, IntOrVid]: + """``M`` = numel(packed) / (D/2) = product of leading dims of ``packed``.""" + val = packed_node.meta.get("val") + if val is None: + raise ValueError("mlx::tq_dequant: input has no meta['val']") + shape = val.shape + + if not isinstance(shape[-1], int): + raise NotImplementedError( + "mlx::tq_dequant: last dim of packed must be statically known" + ) + + leading = list(shape[:-1]) + if all(isinstance(s, int) for s in leading): + return reduce(mul, [int(s) for s in leading], 1) + + in_slot = P.slot_map([packed_node])[0] + in_tid = P.slot_to_tid(in_slot) + + acc_iov: Optional[IntOrVid] = None + for dim_idx, s in enumerate(leading): + if isinstance(s, int): + d_iov = IntOrVid.from_literal(int(s)) + else: + _, d_val = P.make_tmp_value_slot() + P.emit(SymSizeNode(a=in_tid, dim=dim_idx, out=P.slot_to_vid(d_val))) + d_iov = P.to_int_or_vid(d_val) + + if acc_iov is None: + acc_iov = d_iov + else: + _, acc_val = P.make_tmp_value_slot() + P.emit(MultiplyIntNode(a=acc_iov, b=d_iov, out=P.slot_to_vid(acc_val))) + acc_iov = P.to_int_or_vid(acc_val) + + assert acc_iov is not None + return acc_iov + + +def _output_shape_flat( + P: MLXProgramBuilder, packed_node: Node, packed_slot: Slot +) 
-> list: + """Output shape: same as packed but with last dim doubled.""" + val = packed_node.meta["val"] + shape = val.shape + last_idx = len(shape) - 1 + out: list = [] + for dim_idx, s in enumerate(shape): + if isinstance(s, int): + d = int(s) * 2 if dim_idx == last_idx else int(s) + out.append(IntOrVid.from_literal(d)) + else: + if dim_idx == last_idx: + raise NotImplementedError( + "mlx::tq_dequant: dynamic last-dim is not supported" + ) + _, d_val = P.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(packed_slot), + dim=dim_idx, + out=P.slot_to_vid(d_val), + ) + ) + out.append(P.to_int_or_vid(d_val)) + return out + + +@REGISTRY.register(target=[torch.ops.mlx.tq_dequant.default]) +def _tq_dequant_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Lower ``mlx::tq_dequant`` to a single fused Metal kernel.""" + args = P.args(n) + if len(args) != 3: + raise ValueError( + f"mlx::tq_dequant: expected 3 args (packed, norms, centroids); " + f"got {len(args)}" + ) + packed_slot, norms_slot, centroids_slot = args + packed_node = n.args[0] + norms_node = n.args[1] + centroids_node = n.args[2] + + packed_meta = packed_node.meta["val"] + norms_meta = norms_node.meta["val"] + centroids_meta = centroids_node.meta["val"] + + if centroids_meta.dim() != 1 or int(centroids_meta.shape[0]) != 16: + raise ValueError( + f"mlx::tq_dequant: centroids must be 1-D length 16; got " + f"shape {tuple(centroids_meta.shape)}" + ) + + last_dim_packed = packed_meta.shape[-1] + if not isinstance(last_dim_packed, int): + raise NotImplementedError( + "mlx::tq_dequant: packed last dim must be statically known" + ) + half_D = int(last_dim_packed) + D = half_D * 2 + if D % 32 != 0: + raise NotImplementedError( + f"mlx::tq_dequant: unpacked dim must be a multiple of 32 (one " + f"per SIMD lane); got D={D}" + ) + + out_dtype_int = torch_dtype_to_scalar_type(norms_meta.dtype) + + out = P.make_or_get_slot(n) + out_shape_flat = _output_shape_flat(P, packed_node, packed_slot) + M = 
_compute_M(P, packed_node) + M_iov: IntOrVid = IntOrVid.from_literal(int(M)) if isinstance(M, int) else M + + P.emit( + MetalKernelNode( + name="tq_dequant", + source=_TQ_DEQUANT_SOURCE, + header=_TQ_DEQUANT_HEADER, + inputs=[ + P.slot_to_tid(packed_slot), + P.slot_to_tid(norms_slot), + P.slot_to_tid(centroids_slot), + ], + outputs=[P.slot_to_tid(out)], + grid=[ + IntOrVid.from_literal(32), + IntOrVid.from_literal(1), + M_iov, + ], + threadgroup=[ + IntOrVid.from_literal(32), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=["packed", "norms", "centroids"], + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[out_dtype_int], + template_arg_names=["InT", "D"], + template_arg_kinds=[2, 0], # 2=dtype, 0=int + template_arg_values=[out_dtype_int, D], + ) + ) + + return out + + +_registered = True diff --git a/backends/mlx/model_ops/tq_norm.py b/backends/mlx/model_ops/tq_norm.py new file mode 100644 index 00000000000..64a210e5704 --- /dev/null +++ b/backends/mlx/model_ops/tq_norm.py @@ -0,0 +1,254 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +``mlx::tq_norm``: L2 norm along the last dim, lowered to a single Metal kernel. + + norms[..., 0] = sqrt(sum_i x[..., i]^2) + +Reads / writes ``x.dtype`` directly (no graph-level dtype casts). +Reduces in fp32 inside Metal registers via ``simd_sum`` for precision +on large ``D`` (bf16 sum-of-squares loses too much for D>=128). + +Constraints: + * Last dim ``D`` must be statically known and a multiple of 32. 
+ +Usage:: + + import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 + + norms = torch.ops.mlx.tq_norm(x) + # x: (..., D) bf16 + # norms: (..., 1) bf16, equal to vector_norm(x, dim=-1, keepdim=True) +""" + +from __future__ import annotations + +from functools import reduce +from operator import mul +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.fx.node import Node + + +# --------------------------------------------------------------------------- +# Custom op + eager fallback +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("mlx::tq_norm", mutates_args=()) +def tq_norm(x: Tensor) -> Tensor: + """L2 norm along last dim. + + Args: + x: ``(..., D)``. For MLX lowering, ``D`` must be a multiple of 32. + + Returns: + ``(..., 1)`` of the same dtype as ``x``. + """ + return torch.linalg.vector_norm(x, dim=-1, keepdim=True).to(x.dtype) + + +@torch.library.register_fake("mlx::tq_norm") +def tq_norm_fake(x: Tensor) -> Tensor: + out_shape = list(x.shape) + out_shape[-1] = 1 + return x.new_empty(out_shape, dtype=x.dtype) + + +# --------------------------------------------------------------------------- +# MLX handler +# --------------------------------------------------------------------------- + +from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MetalKernelNode, + MultiplyIntNode, + SymSizeNode, +) + + +_TQ_NORM_HEADER = """ +#include <metal_stdlib> +using namespace metal; +""" + + +# Per-vector reduction: +# * Grid (32, 1, M), threadgroup (32, 1, 1): one simdgroup per vector. 
+# * Each lane covers DIMS_PER_LANE = D/32 elements; partial sums are +# accumulated in an fp32 register. +# * ``simd_sum`` reduces across the 32 lanes; lane 0 sqrts and writes. +_TQ_NORM_SOURCE = """ + constexpr uint DIMS_PER_LANE = D / 32; + + uint vec_id = thread_position_in_grid.z; + uint lane_id = thread_position_in_threadgroup.x; + + uint base = vec_id * D + lane_id * DIMS_PER_LANE; + + float local_sum_sq = 0.0f; + for (uint i = 0; i < DIMS_PER_LANE; ++i) { + float v = float(x[base + i]); + local_sum_sq += v * v; + } + + float total_sum_sq = simd_sum(local_sum_sq); + + if (lane_id == 0) { + norms[vec_id] = (InT)sqrt(total_sum_sq); + } +""" + + +def _compute_M(P: MLXProgramBuilder, node: Node) -> Union[int, IntOrVid]: + """``M = numel(x) / D`` (product of leading dims). Returns a static + int when known, else an IntOrVid built from SymSize + MultiplyInt.""" + val = node.meta.get("val") + if val is None: + raise ValueError("mlx::tq_norm: input node has no meta['val']") + shape = val.shape + + last_dim = shape[-1] + if not isinstance(last_dim, int): + raise NotImplementedError("mlx::tq_norm: last dim must be statically known") + + leading_shape = list(shape[:-1]) + + if all(isinstance(s, int) for s in leading_shape): + return reduce(mul, [int(s) for s in leading_shape], 1) + + in_slot = P.slot_map([node])[0] + in_tid = P.slot_to_tid(in_slot) + + acc_iov: Optional[IntOrVid] = None + for dim_idx, s in enumerate(leading_shape): + if isinstance(s, int): + d_iov = IntOrVid.from_literal(int(s)) + else: + _, d_val = P.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=in_tid, + dim=dim_idx, + out=P.slot_to_vid(d_val), + ) + ) + d_iov = P.to_int_or_vid(d_val) + + if acc_iov is None: + acc_iov = d_iov + else: + _, acc_val = P.make_tmp_value_slot() + P.emit( + MultiplyIntNode( + a=acc_iov, + b=d_iov, + out=P.slot_to_vid(acc_val), + ) + ) + acc_iov = P.to_int_or_vid(acc_val) + + assert acc_iov is not None + return acc_iov + + +def _output_shape_flat(P: MLXProgramBuilder, 
node: Node, in_slot: Slot) -> list: + """Output shape: same as input but with last dim = 1.""" + val = node.meta["val"] + shape = val.shape + last_idx = len(shape) - 1 + out: list = [] + for dim_idx, s in enumerate(shape): + if isinstance(s, int): + d = 1 if dim_idx == last_idx else int(s) + out.append(IntOrVid.from_literal(d)) + else: + if dim_idx == last_idx: + raise NotImplementedError( + "mlx::tq_norm: dynamic last-dim is not supported" + ) + _, d_val = P.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(in_slot), + dim=dim_idx, + out=P.slot_to_vid(d_val), + ) + ) + out.append(P.to_int_or_vid(d_val)) + return out + + +@REGISTRY.register(target=[torch.ops.mlx.tq_norm.default]) +def _tq_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Lower ``mlx::tq_norm`` to a single fused Metal kernel.""" + args = P.args(n) + if len(args) != 1: + raise ValueError(f"mlx::tq_norm: expected 1 arg (x), got {len(args)}") + + (x_slot,) = args + x_node = n.args[0] + + x_meta = x_node.meta["val"] + + last_dim = x_meta.shape[-1] + if not isinstance(last_dim, int): + raise NotImplementedError("mlx::tq_norm: last dim must be statically known") + D = int(last_dim) + if D % 32 != 0: + raise NotImplementedError( + f"mlx::tq_norm: last dim must be a multiple of 32 (one per " + f"SIMD lane); got D={D}" + ) + + in_dtype_int = torch_dtype_to_scalar_type(x_meta.dtype) + + out = P.make_or_get_slot(n) + out_shape_flat = _output_shape_flat(P, x_node, x_slot) + M = _compute_M(P, x_node) + M_iov: IntOrVid = IntOrVid.from_literal(int(M)) if isinstance(M, int) else M + + P.emit( + MetalKernelNode( + name="tq_norm", + source=_TQ_NORM_SOURCE, + header=_TQ_NORM_HEADER, + inputs=[P.slot_to_tid(x_slot)], + outputs=[P.slot_to_tid(out)], + grid=[ + IntOrVid.from_literal(32), + IntOrVid.from_literal(1), + M_iov, + ], + threadgroup=[ + IntOrVid.from_literal(32), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=["x"], + output_names=["norms"], + 
output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[in_dtype_int], + template_arg_names=["InT", "D"], + template_arg_kinds=[2, 0], # 2=dtype, 0=int + template_arg_values=[in_dtype_int, D], + ) + ) + + return out + + +_registered = True diff --git a/backends/mlx/test/op_test_runner.cpp b/backends/mlx/test/op_test_runner.cpp index 6bed13d7a56..925ff410f42 100644 --- a/backends/mlx/test/op_test_runner.cpp +++ b/backends/mlx/test/op_test_runner.cpp @@ -58,6 +58,7 @@ enum class DType : uint32_t { Int64 = 3, BFloat16 = 4, Bool = 5, + UInt8 = 6, }; size_t dtype_size(DType dtype) { @@ -74,6 +75,8 @@ size_t dtype_size(DType dtype) { return 2; case DType::Bool: return 1; + case DType::UInt8: + return 1; default: return 4; } @@ -93,6 +96,8 @@ exec_aten::ScalarType dtype_to_scalar_type(DType dtype) { return exec_aten::ScalarType::BFloat16; case DType::Bool: return exec_aten::ScalarType::Bool; + case DType::UInt8: + return exec_aten::ScalarType::Byte; default: return exec_aten::ScalarType::Float; } @@ -112,6 +117,8 @@ DType scalar_type_to_dtype(exec_aten::ScalarType stype) { return DType::BFloat16; case exec_aten::ScalarType::Bool: return DType::Bool; + case exec_aten::ScalarType::Byte: + return DType::UInt8; default: return DType::Float32; } @@ -316,6 +323,11 @@ int main(int argc, char* argv[]) { std::memcpy(data.data(), t.data.data(), t.data.size()); tensor_ptr = make_tensor_ptr( sizes, std::move(data), {}, {}, exec_aten::ScalarType::Bool); + } else if (t.dtype == DType::UInt8) { + std::vector<uint8_t> data(t.data.size()); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr( + sizes, std::move(data), {}, {}, exec_aten::ScalarType::Byte); } else { std::cerr << "Unsupported dtype: " << static_cast<int>(t.dtype) << std::endl; diff --git a/backends/mlx/test/test_utils.py b/backends/mlx/test/test_utils.py index 660968195b7..5dbc35b824d 100644 --- a/backends/mlx/test/test_utils.py +++ 
b/backends/mlx/test/test_utils.py @@ -44,6 +44,7 @@ class TestTimeoutError(Exception): DTYPE_INT64 = 3 DTYPE_BFLOAT16 = 4 DTYPE_BOOL = 5 +DTYPE_UINT8 = 6 # Default tolerance presets for different data types. @@ -110,6 +111,7 @@ def torch_dtype_to_bin_dtype(dtype: torch.dtype) -> int: torch.int64: DTYPE_INT64, torch.bfloat16: DTYPE_BFLOAT16, torch.bool: DTYPE_BOOL, + torch.uint8: DTYPE_UINT8, } if dtype not in mapping: raise ValueError(f"Unsupported dtype: {dtype}") @@ -125,6 +127,7 @@ def bin_dtype_to_torch_dtype(dtype_val: int) -> torch.dtype: DTYPE_INT64: torch.int64, DTYPE_BFLOAT16: torch.bfloat16, DTYPE_BOOL: torch.bool, + DTYPE_UINT8: torch.uint8, } if dtype_val not in mapping: raise ValueError(f"Unknown dtype value: {dtype_val}") @@ -208,6 +211,7 @@ def load_tensors_from_bin(path: Union[str, Path]) -> List[torch.Tensor]: torch.int32: np.int32, torch.int64: np.int64, torch.bool: np.bool_, + torch.uint8: np.uint8, # bfloat16 needs special handling - read as uint16 } @@ -219,6 +223,7 @@ def load_tensors_from_bin(path: Union[str, Path]) -> List[torch.Tensor]: torch.int64: 8, torch.bfloat16: 2, torch.bool: 1, + torch.uint8: 1, } tensors = [] diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md index c6ac10748d8..ae3bcb24c19 100644 --- a/examples/models/gemma4_31b/README.md +++ b/examples/models/gemma4_31b/README.md @@ -93,6 +93,24 @@ method with dynamic sequence length and host-side sampling. Writes `model.pte` (and optionally `model.ptd`) into `--output-dir`. +#### TurboQuant KV cache (long context, MLX only) + +For long-context inference, add `--turboquant` to swap the full-attention +layers' KV cache for a TurboQuant TQ4 cache (4-bit codebook + nibble pack). +This gives ~3.8× cache memory savings on the full-attention layers and lets +you fit context lengths that wouldn't fit in bf16. Sliding-window layers are unaffected. 
+ +```bash +python examples/models/gemma4_31b/export.py \ + --prequantized ./gemma4_31b_int4 \ + --output-dir ./gemma4_31b_exports_mlx_tq \ + --max-seq-len 65536 \ + --backend mlx \ + --turboquant +``` + +Use TurboQuant when you need context beyond what bf16 fits; otherwise leave it off. + ## Eager inference The prompt is automatically wrapped with the Gemma 4 IT chat template. diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index bd648f534b5..f4e7e03c71e 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -141,12 +141,19 @@ def export_and_lower( config: Gemma4_31BConfig, output_dir: str, backend: str = "cuda", + use_turboquant: bool = False, ) -> None: """Export and lower the model to ExecuTorch for the given backend.""" if backend == "cuda": + if use_turboquant: + raise ValueError( + "--turboquant is only supported with --backend mlx " + "(the CUDA path here uses a different TurboQuant integration; " + "see examples/models/qwen3_5_moe/export.py)." + ) _export_cuda(model, config, output_dir) elif backend == "mlx": - _export_mlx(model, config, output_dir) + _export_mlx(model, config, output_dir, use_turboquant=use_turboquant) else: raise ValueError( f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}." @@ -279,7 +286,12 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - print("Done.") -def _export_mlx(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: +def _export_mlx( + model: Gemma4_31B, + config: Gemma4_31BConfig, + output_dir: str, + use_turboquant: bool = False, +) -> None: """Export to .pte via torch.export + MLX backend. Unlike CUDA (which exports separate decode/prefill methods with an @@ -287,6 +299,10 @@ def _export_mlx(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> sequence length. 
No int4_dispatch import — IntxUnpackedToInt8Tensor's default dispatch produces the ``dequantize_affine → linear`` pattern that MLX's QuantizedLinearHandler matches. + + When ``use_turboquant=True``, full-attention layers swap to + ``MLXTurboQuantKVCache`` for ~3.8× KV cache memory savings. Sliding + layers are unaffected (already use ``RingBufferKVCache``). """ import gc @@ -304,10 +320,13 @@ def _export_mlx(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> from executorch.exir.passes import MemoryPlanningPass from torch.export import Dim, export - mlx_source_transformations(model, dtype=torch.bfloat16) + mlx_source_transformations( + model, dtype=torch.bfloat16, use_turboquant=use_turboquant + ) + materialize_runtime_buffers(model, dtype=torch.bfloat16) - max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2) + max_prefill = 256 seq_dim = Dim("seq_len", min=1, max=max_prefill) print(f"Exporting (T in [1, {max_prefill}])...") @@ -418,8 +437,17 @@ def main() -> None: choices=list(_SUPPORTED_BACKENDS), help="Target backend for export.", ) + parser.add_argument( + "--turboquant", + action="store_true", + help="Use TurboQuant TQ4 KV cache compression (MLX backend only). 
" + "~3.8× cache memory savings; applies only to full-attention " + "(non-sliding) layers — sliding layers keep RingBufferKVCache.", + ) args = parser.parse_args() + if args.turboquant and args.backend != "mlx": + parser.error("--turboquant requires --backend mlx.") if args.backend == "cuda" and not torch.cuda.is_available(): parser.error("CUDA is required for the cuda backend.") @@ -446,10 +474,15 @@ def main() -> None: if args.gguf and args.backend == "mlx": os.environ["ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS"] = "1" try: - export_and_lower(model, config, args.output_dir, backend=args.backend) + export_and_lower( + model, + config, + args.output_dir, + backend=args.backend, + use_turboquant=args.turboquant, + ) finally: os.environ.pop("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", None) - if __name__ == "__main__": main() diff --git a/examples/models/gemma4_31b/mlx_source_transformations.py b/examples/models/gemma4_31b/mlx_source_transformations.py index 3a8ae4420e3..0bbd4f7b250 100644 --- a/examples/models/gemma4_31b/mlx_source_transformations.py +++ b/examples/models/gemma4_31b/mlx_source_transformations.py @@ -24,6 +24,9 @@ KVCache as MLXKVCache, RingBufferKVCache as MLXRingKVCache, ) +from executorch.backends.mlx.llm.turboquant_cache import ( + TurboQuantKVCache as MLXTurboQuantKVCache, +) def _replace_attention_forward(attn: nn.Module) -> None: @@ -68,30 +71,34 @@ def _mlx_forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor q = torch.ops.mlx.rope(q, rotary_dim, start_pos, False, 0.0, 1.0, mlx_freqs) k = torch.ops.mlx.rope(k, rotary_dim, start_pos, False, 0.0, 1.0, mlx_freqs) - k_cache, v_cache = self.kv_cache.update(start_pos, k, v) - - if self.is_sliding: - sdpa_mask = self.kv_cache.create_sliding_window_mask(start_pos, T) - y = torch.ops.mlx.custom_sdpa( - q, - k_cache, - v_cache, - start_pos=self.kv_cache.buffer_size - T, - attn_mask=sdpa_mask, - dropout_p=0.0, - is_causal=False, - scale=self.scaling, - ) + if getattr(self, "is_turboquant", 
False): + self.kv_cache.update(start_pos, k, v) + y = self.kv_cache.sdpa(q, start_pos, scale=self.scaling) else: - y = torch.ops.mlx.custom_sdpa( - q, - k_cache, - v_cache, - start_pos=start_pos, - dropout_p=0.0, - is_causal=True, - scale=self.scaling, - ) + k_cache, v_cache = self.kv_cache.update(start_pos, k, v) + + if self.is_sliding: + sdpa_mask = self.kv_cache.create_sliding_window_mask(start_pos, T) + y = torch.ops.mlx.custom_sdpa( + q, + k_cache, + v_cache, + start_pos=self.kv_cache.buffer_size - T, + attn_mask=sdpa_mask, + dropout_p=0.0, + is_causal=False, + scale=self.scaling, + ) + else: + y = torch.ops.mlx.custom_sdpa( + q, + k_cache, + v_cache, + start_pos=start_pos, + dropout_p=0.0, + is_causal=True, + scale=self.scaling, + ) y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) return self.o_proj(y) @@ -150,6 +157,7 @@ def _mlx_model_forward( def mlx_source_transformations( model: nn.Module, dtype: torch.dtype = torch.bfloat16, + use_turboquant: bool = False, ) -> None: """Apply MLX source transformations to a Gemma 4 31B model in-place. @@ -162,6 +170,13 @@ def mlx_source_transformations( - Rewrites layer forward to drop mask parameters (each attention builds its own mask via ``custom_sdpa``) - Rewrites model forward to drop the sampler and ``_build_masks`` + + Args: + model: Gemma4_31B model to transform in place. + dtype: dtype for KV cache buffers (bf16 by default). + use_turboquant: If True, swap full-attention layers' KV caches + for ``MLXTurboQuantKVCache`` (~3.8× cache memory savings). + Sliding-window layers are unaffected. 
""" config = model.config @@ -176,6 +191,17 @@ def mlx_source_transformations( head_dim=attn.head_dim, dtype=dtype, ) + attn.is_turboquant = False + elif use_turboquant: + attn.kv_cache = MLXTurboQuantKVCache( + max_batch_size=1, + max_context_length=config.max_seq_len, + n_heads=attn.n_kv_heads, + head_dim=attn.head_dim, + enable_dynamic_shape=True, + dtype=dtype, + ) + attn.is_turboquant = True else: attn.kv_cache = MLXKVCache( max_batch_size=1, @@ -185,6 +211,7 @@ def mlx_source_transformations( enable_dynamic_shape=True, dtype=dtype, ) + attn.is_turboquant = False _replace_attention_forward(attn) _replace_layer_forward(layer) From 84836fe2f0f3d764cf501ad4ec3c8e924c40c9bb Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 29 May 2026 10:53:54 -0700 Subject: [PATCH 2/8] up --- backends/mlx/llm/turboquant_cache.py | 12 ++++++++++++ backends/mlx/model_ops/tq4_compress.py | 9 ++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/backends/mlx/llm/turboquant_cache.py b/backends/mlx/llm/turboquant_cache.py index 265dc0b89b5..19a65704311 100644 --- a/backends/mlx/llm/turboquant_cache.py +++ b/backends/mlx/llm/turboquant_cache.py @@ -71,6 +71,18 @@ def __init__( f"TurboQuantKVCache only supports max_batch_size=1, " f"got {max_batch_size}" ) + if bits != 4: + raise ValueError( + f"TurboQuantKVCache only supports bits=4 " + f"(16-entry codebook), got bits={bits}" + ) + # MLX-backend Metal kernels (``tq_dequant``, ``tq_norm``) hard-code + # 32 SIMD lanes per vector, so ``head_dim`` must be a multiple of 32 + if head_dim % 32 != 0: + raise ValueError( + f"TurboQuantKVCache requires head_dim to be " + f"a multiple of 32 (Metal SIMD constraint), got {head_dim}" + ) super().__init__( n_heads=n_heads, head_dim=head_dim, diff --git a/backends/mlx/model_ops/tq4_compress.py b/backends/mlx/model_ops/tq4_compress.py index 0d5c9748200..69db558abde 100644 --- a/backends/mlx/model_ops/tq4_compress.py +++ b/backends/mlx/model_ops/tq4_compress.py @@ -95,7 +95,6 @@ def 
tq4_compress_fake(values: Tensor, boundaries: Tensor) -> Tensor: # constant), and packs the two 4-bit indices into one byte. _TQ4_COMPRESS_SOURCE = """ uint gid = thread_position_in_grid.x; - if (gid >= N_OUT) return; float v_hi = float(values[2 * gid]); float v_lo = float(values[2 * gid + 1]); uchar idx_hi = 0; @@ -254,8 +253,9 @@ def _tq4_compress_handler(P: MLXProgramBuilder, n: Node) -> Slot: ], outputs=[P.slot_to_tid(out)], grid=[n_out_iov, IntOrVid.from_literal(1), IntOrVid.from_literal(1)], + # 32 threads per threadgroup so each TG fills one Apple-GPU SIMD group threadgroup=[ - IntOrVid.from_literal(1), + IntOrVid.from_literal(32), IntOrVid.from_literal(1), IntOrVid.from_literal(1), ], @@ -264,12 +264,11 @@ def _tq4_compress_handler(P: MLXProgramBuilder, n: Node) -> Slot: output_shapes_flat=out_shape_flat, output_shape_lengths=[len(out_shape_flat)], output_dtypes=[torch_dtype_to_scalar_type(torch.uint8)], - template_arg_names=["InT", "B", "N_OUT"], - template_arg_kinds=[2, 0, 0], # 2=dtype, 0=int + template_arg_names=["InT", "B"], + template_arg_kinds=[2, 0], # 2=dtype, 0=int template_arg_values=[ in_dtype_int, 15, - int(n_out) if isinstance(n_out, int) else (1 << 30), ], ) ) From faacfe7b3d8a310b58743461fdab0a5f14dfa1c3 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 29 May 2026 14:32:39 -0700 Subject: [PATCH 3/8] up --- .github/workflows/mlx.yml | 12 +++ backends/mlx/builder/op_helpers.py | 112 ++++++++++++++++++++++++ backends/mlx/model_ops/tq4_compress.py | 114 +++---------------------- backends/mlx/model_ops/tq_dequant.py | 89 ++----------------- backends/mlx/model_ops/tq_norm.py | 100 ++-------------------- 5 files changed, 152 insertions(+), 275 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 027101ba7f0..c51f126dbe6 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -80,6 +80,18 @@ jobs: ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v echo 
"::endgroup::" + echo "::group::Run tq_norm op tests" + ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_norm run -v + echo "::endgroup::" + + echo "::group::Run tq4_compress op tests" + ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v + echo "::endgroup::" + + echo "::group::Run tq_dequant op tests" + ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v + echo "::endgroup::" + test-mlx-qwen35-moe: uses: pytorch/test-infra/.github/workflows/macos_job.yml@main with: diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py index 7740546cc2c..be199f75340 100644 --- a/backends/mlx/builder/op_helpers.py +++ b/backends/mlx/builder/op_helpers.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + from executorch.backends.mlx.serialization.mlx_graph_schema import IntOrVid # When True, always serialize the biases tensor for quantized ops. # When False, use init-time computation when zero_point is all zeros, @@ -173,6 +174,117 @@ def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> S return slot +def emit_shape( + P: "MLXProgramBuilder", + node: Node, + slot: Slot, + *, + end_dim: "Optional[int]" = None, +) -> "list[IntOrVid]": + """Return the shape of ``node`` as a list of ``IntOrVid``. + + Each static dim becomes a literal ``IntOrVid``; each dynamic dim + emits a ``SymSizeNode`` against ``slot`` and is wrapped via + ``P.to_int_or_vid``. + + Args: + P: program builder. + node: FX node whose shape to walk (must have ``meta['val']``). + slot: slot corresponding to ``node`` (used as the + ``SymSize`` source for any dynamic dim). + end_dim: stop index (exclusive). ``None`` means the full ndim. + Negative values index from the end (e.g. ``-1`` is "all + leading dims, drop the last"). + + Returns: + ``list[IntOrVid]`` of length ``end_dim`` (after normalization). 
+ """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + SymSizeNode, + ) + + shape = node.meta["val"].shape + ndim = len(shape) + if end_dim is None: + end_dim = ndim + elif end_dim < 0: + end_dim += ndim + + out: "list[IntOrVid]" = [] + for dim_idx in range(end_dim): + s = shape[dim_idx] + if isinstance(s, int): + out.append(IntOrVid.from_literal(int(s))) + else: + _, d_val = P.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(slot), + dim=dim_idx, + out=P.slot_to_vid(d_val), + ) + ) + out.append(P.to_int_or_vid(d_val)) + return out + + +def emit_product( + P: "MLXProgramBuilder", + dims: "list[IntOrVid]", +) -> "IntOrVid": + """Multiplicative reduction over a list of ``IntOrVid`` values. + + Folds all literal entries AOT into a single static product, then + emits ``MultiplyIntNode`` only for the dynamic entries (and one + final node combining the static product with the dynamic accumulator + when both contribute). + + Args: + P: program builder. + dims: list of ``IntOrVid``. May be empty (returns + ``IntOrVid.from_literal(1)``), all literals, or a mix. + + Returns: + An ``IntOrVid`` representing the product. Always literal when + every entry is literal (or ``dims`` is empty). 
+ """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MultiplyIntNode, + ) + + static_product = 1 + dynamic_dims: "list[IntOrVid]" = [] + for d in dims: + if d.is_vid: + dynamic_dims.append(d) + else: + static_product *= d.literal + + if not dynamic_dims: + return IntOrVid.from_literal(static_product) + + acc = dynamic_dims[0] + for d in dynamic_dims[1:]: + _, acc_val = P.make_tmp_value_slot() + P.emit(MultiplyIntNode(a=acc, b=d, out=P.slot_to_vid(acc_val))) + acc = P.to_int_or_vid(acc_val) + + if static_product == 1: + return acc + + _, final_val = P.make_tmp_value_slot() + P.emit( + MultiplyIntNode( + a=IntOrVid.from_literal(static_product), + b=acc, + out=P.slot_to_vid(final_val), + ) + ) + return P.to_int_or_vid(final_val) + + def emit_quantized_biases( P: "MLXProgramBuilder", zero_point_key: str, diff --git a/backends/mlx/model_ops/tq4_compress.py b/backends/mlx/model_ops/tq4_compress.py index 69db558abde..f08d47b9a11 100644 --- a/backends/mlx/model_ops/tq4_compress.py +++ b/backends/mlx/model_ops/tq4_compress.py @@ -30,10 +30,6 @@ from __future__ import annotations -from functools import reduce -from operator import mul -from typing import Optional, Union - import torch from torch import Tensor from torch.fx.node import Node @@ -78,15 +74,17 @@ def tq4_compress_fake(values: Tensor, boundaries: Tensor) -> Tensor: # MLX handler # --------------------------------------------------------------------------- -from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) from executorch.backends.mlx.builder.op_registry import REGISTRY from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder from executorch.backends.mlx.builder.slot_manager import Slot from executorch.backends.mlx.serialization.mlx_graph_schema import ( IntOrVid, MetalKernelNode, - MultiplyIntNode, - 
SymSizeNode, ) @@ -109,93 +107,6 @@ def tq4_compress_fake(values: Tensor, boundaries: Tensor) -> Tensor: """ -def _compute_output_numel(P: MLXProgramBuilder, node: Node) -> Union[int, IntOrVid]: - """Output numel = numel(input) / 2. Returns a static int when the - full shape is known, else an IntOrVid built from SymSize + - MultiplyInt nodes.""" - val = node.meta.get("val") - if val is None: - raise ValueError("mlx::tq4_compress: input node has no meta['val']") - shape = val.shape - - if all(isinstance(s, int) for s in shape): - return reduce(mul, [int(s) for s in shape], 1) // 2 - - in_slot = P.slot_map([node])[0] - in_tid = P.slot_to_tid(in_slot) - - last_idx = len(shape) - 1 - acc_iov: Optional[IntOrVid] = None - for dim_idx in range(len(shape)): - s = shape[dim_idx] - if isinstance(s, int): - d = int(s) - if dim_idx == last_idx: - d //= 2 - d_iov = IntOrVid.from_literal(d) - else: - if dim_idx == last_idx: - # The schema has no integer-divide-by-Vid op; require the - # last dim be static so the /2 stays a literal. 
- raise NotImplementedError( - "mlx::tq4_compress: dynamic last-dim is not supported" - ) - _, d_val = P.make_tmp_value_slot() - P.emit( - SymSizeNode( - a=in_tid, - dim=dim_idx, - out=P.slot_to_vid(d_val), - ) - ) - d_iov = P.to_int_or_vid(d_val) - - if acc_iov is None: - acc_iov = d_iov - else: - _, acc_val = P.make_tmp_value_slot() - P.emit( - MultiplyIntNode( - a=acc_iov, - b=d_iov, - out=P.slot_to_vid(acc_val), - ) - ) - acc_iov = P.to_int_or_vid(acc_val) - - assert acc_iov is not None - return acc_iov - - -def _output_shape_flat(P: MLXProgramBuilder, node: Node, in_slot: Slot) -> list: - """Output shape: same as input but with last dim halved.""" - val = node.meta["val"] - shape = val.shape - last_idx = len(shape) - 1 - out: list = [] - for dim_idx, s in enumerate(shape): - if isinstance(s, int): - d = int(s) - if dim_idx == last_idx: - d //= 2 - out.append(IntOrVid.from_literal(d)) - else: - if dim_idx == last_idx: - raise NotImplementedError( - "mlx::tq4_compress: dynamic last-dim is not supported" - ) - _, d_val = P.make_tmp_value_slot() - P.emit( - SymSizeNode( - a=P.slot_to_tid(in_slot), - dim=dim_idx, - out=P.slot_to_vid(d_val), - ) - ) - out.append(P.to_int_or_vid(d_val)) - return out - - @REGISTRY.register(target=[torch.ops.mlx.tq4_compress.default]) def _tq4_compress_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Lower ``mlx::tq4_compress`` to a fused Metal kernel.""" @@ -232,16 +143,18 @@ def _tq4_compress_handler(P: MLXProgramBuilder, n: Node) -> Slot: ) if int(last_dim) % 2 != 0: raise ValueError(f"mlx::tq4_compress: last dim must be even; got {last_dim}") + half_last = int(last_dim) // 2 in_dtype_int = torch_dtype_to_scalar_type(values_meta.dtype) out = P.make_or_get_slot(n) - out_shape_flat = _output_shape_flat(P, values_node, values_slot) + leading = emit_shape(P, values_node, values_slot, end_dim=-1) + half_last_iov = IntOrVid.from_literal(half_last) + out_shape_flat = leading + [half_last_iov] - n_out = _compute_output_numel(P, 
values_node) - n_out_iov: IntOrVid = ( - IntOrVid.from_literal(int(n_out)) if isinstance(n_out, int) else n_out - ) + # One thread per output byte, so the grid size is the output numel + # (product of leading dims times the halved last dim). + n_out_iov = emit_product(P, leading + [half_last_iov]) P.emit( MetalKernelNode( @@ -274,6 +187,3 @@ def _tq4_compress_handler(P: MLXProgramBuilder, n: Node) -> Slot: ) return out - - -_registered = True diff --git a/backends/mlx/model_ops/tq_dequant.py b/backends/mlx/model_ops/tq_dequant.py index 4b6605e903e..23efe5692b9 100644 --- a/backends/mlx/model_ops/tq_dequant.py +++ b/backends/mlx/model_ops/tq_dequant.py @@ -34,10 +34,6 @@ from __future__ import annotations -from functools import reduce -from operator import mul -from typing import Optional, Union - import torch from torch import Tensor from torch.fx.node import Node @@ -89,15 +85,17 @@ def tq_dequant_fake(packed: Tensor, norms: Tensor, centroids: Tensor) -> Tensor: # MLX handler # --------------------------------------------------------------------------- -from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) from executorch.backends.mlx.builder.op_registry import REGISTRY from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder from executorch.backends.mlx.builder.slot_manager import Slot from executorch.backends.mlx.serialization.mlx_graph_schema import ( IntOrVid, MetalKernelNode, - MultiplyIntNode, - SymSizeNode, ) @@ -139,74 +137,6 @@ def tq_dequant_fake(packed: Tensor, norms: Tensor, centroids: Tensor) -> Tensor: """ -def _compute_M(P: MLXProgramBuilder, packed_node: Node) -> Union[int, IntOrVid]: - """``M`` = numel(packed) / (D/2) = product of leading dims of ``packed``.""" - val = packed_node.meta.get("val") - if val is None: - raise ValueError("mlx::tq_dequant: input has no meta['val']") - 
shape = val.shape - - if not isinstance(shape[-1], int): - raise NotImplementedError( - "mlx::tq_dequant: last dim of packed must be statically known" - ) - - leading = list(shape[:-1]) - if all(isinstance(s, int) for s in leading): - return reduce(mul, [int(s) for s in leading], 1) - - in_slot = P.slot_map([packed_node])[0] - in_tid = P.slot_to_tid(in_slot) - - acc_iov: Optional[IntOrVid] = None - for dim_idx, s in enumerate(leading): - if isinstance(s, int): - d_iov = IntOrVid.from_literal(int(s)) - else: - _, d_val = P.make_tmp_value_slot() - P.emit(SymSizeNode(a=in_tid, dim=dim_idx, out=P.slot_to_vid(d_val))) - d_iov = P.to_int_or_vid(d_val) - - if acc_iov is None: - acc_iov = d_iov - else: - _, acc_val = P.make_tmp_value_slot() - P.emit(MultiplyIntNode(a=acc_iov, b=d_iov, out=P.slot_to_vid(acc_val))) - acc_iov = P.to_int_or_vid(acc_val) - - assert acc_iov is not None - return acc_iov - - -def _output_shape_flat( - P: MLXProgramBuilder, packed_node: Node, packed_slot: Slot -) -> list: - """Output shape: same as packed but with last dim doubled.""" - val = packed_node.meta["val"] - shape = val.shape - last_idx = len(shape) - 1 - out: list = [] - for dim_idx, s in enumerate(shape): - if isinstance(s, int): - d = int(s) * 2 if dim_idx == last_idx else int(s) - out.append(IntOrVid.from_literal(d)) - else: - if dim_idx == last_idx: - raise NotImplementedError( - "mlx::tq_dequant: dynamic last-dim is not supported" - ) - _, d_val = P.make_tmp_value_slot() - P.emit( - SymSizeNode( - a=P.slot_to_tid(packed_slot), - dim=dim_idx, - out=P.slot_to_vid(d_val), - ) - ) - out.append(P.to_int_or_vid(d_val)) - return out - - @REGISTRY.register(target=[torch.ops.mlx.tq_dequant.default]) def _tq_dequant_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Lower ``mlx::tq_dequant`` to a single fused Metal kernel.""" @@ -247,9 +177,9 @@ def _tq_dequant_handler(P: MLXProgramBuilder, n: Node) -> Slot: out_dtype_int = torch_dtype_to_scalar_type(norms_meta.dtype) out = 
P.make_or_get_slot(n) - out_shape_flat = _output_shape_flat(P, packed_node, packed_slot) - M = _compute_M(P, packed_node) - M_iov: IntOrVid = IntOrVid.from_literal(int(M)) if isinstance(M, int) else M + leading = emit_shape(P, packed_node, packed_slot, end_dim=-1) + out_shape_flat = leading + [IntOrVid.from_literal(D)] + M_iov = emit_product(P, leading) P.emit( MetalKernelNode( @@ -284,6 +214,3 @@ def _tq_dequant_handler(P: MLXProgramBuilder, n: Node) -> Slot: ) return out - - -_registered = True diff --git a/backends/mlx/model_ops/tq_norm.py b/backends/mlx/model_ops/tq_norm.py index 64a210e5704..7e6a4d657f3 100644 --- a/backends/mlx/model_ops/tq_norm.py +++ b/backends/mlx/model_ops/tq_norm.py @@ -29,10 +29,6 @@ from __future__ import annotations -from functools import reduce -from operator import mul -from typing import Optional, Union - import torch from torch import Tensor from torch.fx.node import Node @@ -67,15 +63,17 @@ def tq_norm_fake(x: Tensor) -> Tensor: # MLX handler # --------------------------------------------------------------------------- -from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) from executorch.backends.mlx.builder.op_registry import REGISTRY from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder from executorch.backends.mlx.builder.slot_manager import Slot from executorch.backends.mlx.serialization.mlx_graph_schema import ( IntOrVid, MetalKernelNode, - MultiplyIntNode, - SymSizeNode, ) @@ -112,85 +110,6 @@ def tq_norm_fake(x: Tensor) -> Tensor: """ -def _compute_M(P: MLXProgramBuilder, node: Node) -> Union[int, IntOrVid]: - """``M = numel(x) / D`` (product of leading dims). 
Returns a static - int when known, else an IntOrVid built from SymSize + MultiplyInt.""" - val = node.meta.get("val") - if val is None: - raise ValueError("mlx::tq_norm: input node has no meta['val']") - shape = val.shape - - last_dim = shape[-1] - if not isinstance(last_dim, int): - raise NotImplementedError("mlx::tq_norm: last dim must be statically known") - - leading_shape = list(shape[:-1]) - - if all(isinstance(s, int) for s in leading_shape): - return reduce(mul, [int(s) for s in leading_shape], 1) - - in_slot = P.slot_map([node])[0] - in_tid = P.slot_to_tid(in_slot) - - acc_iov: Optional[IntOrVid] = None - for dim_idx, s in enumerate(leading_shape): - if isinstance(s, int): - d_iov = IntOrVid.from_literal(int(s)) - else: - _, d_val = P.make_tmp_value_slot() - P.emit( - SymSizeNode( - a=in_tid, - dim=dim_idx, - out=P.slot_to_vid(d_val), - ) - ) - d_iov = P.to_int_or_vid(d_val) - - if acc_iov is None: - acc_iov = d_iov - else: - _, acc_val = P.make_tmp_value_slot() - P.emit( - MultiplyIntNode( - a=acc_iov, - b=d_iov, - out=P.slot_to_vid(acc_val), - ) - ) - acc_iov = P.to_int_or_vid(acc_val) - - assert acc_iov is not None - return acc_iov - - -def _output_shape_flat(P: MLXProgramBuilder, node: Node, in_slot: Slot) -> list: - """Output shape: same as input but with last dim = 1.""" - val = node.meta["val"] - shape = val.shape - last_idx = len(shape) - 1 - out: list = [] - for dim_idx, s in enumerate(shape): - if isinstance(s, int): - d = 1 if dim_idx == last_idx else int(s) - out.append(IntOrVid.from_literal(d)) - else: - if dim_idx == last_idx: - raise NotImplementedError( - "mlx::tq_norm: dynamic last-dim is not supported" - ) - _, d_val = P.make_tmp_value_slot() - P.emit( - SymSizeNode( - a=P.slot_to_tid(in_slot), - dim=dim_idx, - out=P.slot_to_vid(d_val), - ) - ) - out.append(P.to_int_or_vid(d_val)) - return out - - @REGISTRY.register(target=[torch.ops.mlx.tq_norm.default]) def _tq_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Lower 
``mlx::tq_norm`` to a single fused Metal kernel.""" @@ -216,9 +135,9 @@ def _tq_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot: in_dtype_int = torch_dtype_to_scalar_type(x_meta.dtype) out = P.make_or_get_slot(n) - out_shape_flat = _output_shape_flat(P, x_node, x_slot) - M = _compute_M(P, x_node) - M_iov: IntOrVid = IntOrVid.from_literal(int(M)) if isinstance(M, int) else M + leading = emit_shape(P, x_node, x_slot, end_dim=-1) + out_shape_flat = leading + [IntOrVid.from_literal(1)] + M_iov = emit_product(P, leading) P.emit( MetalKernelNode( @@ -249,6 +168,3 @@ def _tq_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot: ) return out - - -_registered = True From 9dff8a9593b49e974b00fa1bfab2d64de40bf213 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 29 May 2026 16:11:57 -0700 Subject: [PATCH 4/8] up --- backends/mlx/llm/turboquant_cache.py | 14 +- backends/mlx/model_ops/tq_dequant.py | 8 +- backends/mlx/test/test_ops.py | 396 +++++++++++++++++++++++++++ 3 files changed, 409 insertions(+), 9 deletions(-) diff --git a/backends/mlx/llm/turboquant_cache.py b/backends/mlx/llm/turboquant_cache.py index 19a65704311..7f2109ba074 100644 --- a/backends/mlx/llm/turboquant_cache.py +++ b/backends/mlx/llm/turboquant_cache.py @@ -48,7 +48,7 @@ class TurboQuantKVCache(_SharedTurboQuantKVCache): max_batch_size: Must be 1 (TQ4 is batch=1 only). max_context_length: Maximum sequence length. n_heads: Number of KV heads. - head_dim: Per-head dimension. Must be even and a multiple of 32. + head_dim: Per-head dimension. Must be even and a multiple of 64. enable_dynamic_shape: Accepted for interface parity; ignored. dtype: Compute dtype (bf16). Used for pre-cast buffers. bits: Quantization bits (must be 4). 
@@ -76,12 +76,15 @@ def __init__( f"TurboQuantKVCache only supports bits=4 " f"(16-entry codebook), got bits={bits}" ) - # MLX-backend Metal kernels (``tq_dequant``, ``tq_norm``) hard-code - # 32 SIMD lanes per vector, so ``head_dim`` must be a multiple of 32 - if head_dim % 32 != 0: + # MLX-backend Metal kernels need ``head_dim % 64 == 0``: ``tq_norm`` + # uses 32 SIMD lanes (so D must be a multiple of 32), and + # ``tq_dequant`` packs 2 dims per byte across 32 lanes (so D must + # be a multiple of 64). Take the stricter constraint here. + if head_dim % 64 != 0: raise ValueError( f"TurboQuantKVCache requires head_dim to be " - f"a multiple of 32 (Metal SIMD constraint), got {head_dim}" + f"a multiple of 64 (Metal SIMD + 4-bit pack constraint), " + f"got {head_dim}" ) super().__init__( n_heads=n_heads, @@ -222,6 +225,7 @@ def sdpa( v_packed_live = self.v_packed[:, :, :end_pos, :] v_norms_live = self.v_norms[:, :, :end_pos, :] + # TODO: optimize with a fused dequant + SDPA k_rot = torch.ops.mlx.tq_dequant(k_packed_live, k_norms_live, self.centroids) v_rot = torch.ops.mlx.tq_dequant(v_packed_live, v_norms_live, self.centroids) diff --git a/backends/mlx/model_ops/tq_dequant.py b/backends/mlx/model_ops/tq_dequant.py index 23efe5692b9..28a168e9be0 100644 --- a/backends/mlx/model_ops/tq_dequant.py +++ b/backends/mlx/model_ops/tq_dequant.py @@ -17,7 +17,7 @@ left to the caller (typically MLX's tuned bf16 GEMM). Constraints: - * ``D`` (= ``packed.shape[-1] * 2``) must be a multiple of 32. + * ``D`` (= ``packed.shape[-1] * 2``) must be a multiple of 64. * ``centroids`` must be a 1-D tensor of length 16. * Output dtype matches ``norms.dtype``. 
@@ -168,10 +168,10 @@ def _tq_dequant_handler(P: MLXProgramBuilder, n: Node) -> Slot: ) half_D = int(last_dim_packed) D = half_D * 2 - if D % 32 != 0: + if D % 64 != 0: raise NotImplementedError( - f"mlx::tq_dequant: unpacked dim must be a multiple of 32 (one " - f"per SIMD lane); got D={D}" + f"mlx::tq_dequant: unpacked dim must be a multiple of 64 " + f"(2 dims per packed byte, 32 SIMD lanes); got D={D}" ) out_dtype_int = torch_dtype_to_scalar_type(norms_meta.dtype) diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 45ea024f0e8..ec80b1d3911 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -2236,6 +2236,402 @@ def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: } +from executorch.backends.mlx.llm.turboquant_cache import TurboQuantKVCache + + +class TurboQuantKVCacheModel(nn.Module): + """ + Test model wrapping TurboQuantKVCache.update(). + + TurboQuantKVCache stores K/V in rotated 4-bit packed form. ``update`` + returns the four cache buffers (k_packed, k_norms, v_packed, v_norms) + rather than uncompressed K/V. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.cache = TurboQuantKVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return self.cache.update(input_pos, k_val, v_val) + + +@register_test +class TurboQuantKVCacheTest(OpTestCase): + """ + Test case for TurboQuantKVCache with tensor input_pos. + + Verifies eager-vs-MLX consistency for the compress + write path + (``mlx::tq_norm``, ``mlx::tq4_compress``, ``mlx::kv_cache_update``). 
+ The packed cache is uint8 (byte-exact), norms are bf16 (loose tol). + """ + + name = "turboquant_kv_cache" + # uint8 packed cache stays effectively exact under atol<1; bf16 + # norms need ~1e-1 absolute slack for the eager-vs-MLX bf16 path. + rtol = 1e-5 + atol = 1e-1 + + def __init__( + self, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + # TurboQuantKVCache requires batch=1. + self.max_batch_size = 1 + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["TurboQuantKVCacheTest"]: + return [ + cls(), # default: head_dim=64 (smallest valid) + cls(head_dim=128), + cls(enable_dynamic_shape=False), + ] + + def create_model(self) -> nn.Module: + return TurboQuantKVCacheModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, + self.n_heads, + self.seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + v_val = torch.randn( + self.max_batch_size, + self.n_heads, + self.seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + # With static shape, test inputs must match the exported seq length. 
+ test_seq_step = ( + self.seq_step if not self.enable_dynamic_shape else self.seq_step + 4 + ) + input_pos = torch.tensor([16], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, + self.n_heads, + test_seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + v_val = torch.randn( + self.max_batch_size, + self.n_heads, + test_seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + if not self.enable_dynamic_shape: + return None + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +class TurboQuantKVCacheIntModel(nn.Module): + """ + Test model that passes int/SymInt (not tensor) to + ``TurboQuantKVCache.update`` — the multi-layer pattern. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.cache = TurboQuantKVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + start_pos = input_pos[0].item() + return self.cache.update(start_pos, k_val, v_val) + + +@register_test +class TurboQuantKVCacheIntTest(OpTestCase): + """Test case for TurboQuantKVCache with int/SymInt input_pos.""" + + name = "turboquant_kv_cache_int" + rtol = 1e-5 + atol = 1e-1 + + def __init__( + self, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + self.max_batch_size = 1 + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step 
+ self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["TurboQuantKVCacheIntTest"]: + return [ + cls(), + cls(head_dim=128), + ] + + def create_model(self) -> nn.Module: + return TurboQuantKVCacheIntModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, + self.n_heads, + self.seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + v_val = torch.randn( + self.max_batch_size, + self.n_heads, + self.seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + test_seq_step = self.seq_step + 4 + input_pos = torch.tensor([16], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, + self.n_heads, + test_seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + v_val = torch.randn( + self.max_batch_size, + self.n_heads, + test_seq_step, + self.head_dim, + dtype=torch.bfloat16, + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + if not self.enable_dynamic_shape: + return None + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +class TurboQuantKVCacheSdpaModel(nn.Module): + """ + Test model wrapping ``TurboQuantKVCache.update + .sdpa`` — the full + prefill/decode flow (compress, dequant, attention in rotated space, + un-rotate output). 
+ """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.max_context_length = max_context_length + self.cache = TurboQuantKVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + query: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + start_pos = input_pos[0].item() + seq_len = k_val.size(2) + torch._check(start_pos >= 0) + torch._check(start_pos + seq_len <= self.max_context_length) + + k_packed, k_norms, v_packed, v_norms = self.cache.update( + start_pos, k_val, v_val + ) + out = self.cache.sdpa(query, start_pos) + return out, k_packed, k_norms, v_packed, v_norms + + +@register_test +class TurboQuantKVCacheSdpaTest(OpTestCase): + """ + Test case for ``TurboQuantKVCache.update`` + ``.sdpa``. + + Exercises the full forward path: compress + write through + ``mlx::tq_norm`` / ``mlx::tq4_compress`` / ``mlx::kv_cache_update``, + then dequantize and attend via ``mlx::tq_dequant`` / + ``mlx::custom_sdpa`` with Q rotated in and output rotated back. + Looser tolerance is needed because attention runs in bf16. 
+ """ + + name = "turboquant_kv_cache_sdpa" + rtol = 1e-5 + atol = 5e-2 # bf16 SDPA output + + def __init__( + self, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + self.max_batch_size = 1 + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["TurboQuantKVCacheSdpaTest"]: + return [ + cls(), + cls(head_dim=128), + ] + + def create_model(self) -> nn.Module: + return TurboQuantKVCacheSdpaModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def _make_inputs( + self, start: int, q_len: int, kv_len: int + ) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([start], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, + self.n_heads, + kv_len, + self.head_dim, + dtype=torch.bfloat16, + ) + v_val = torch.randn( + self.max_batch_size, + self.n_heads, + kv_len, + self.head_dim, + dtype=torch.bfloat16, + ) + query = torch.randn( + self.max_batch_size, + self.n_heads, + q_len, + self.head_dim, + dtype=torch.bfloat16, + ) + return (input_pos, k_val, v_val, query) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Prefill-style: start=0, q_len == kv_len. + return self._make_inputs(start=0, q_len=self.seq_step, kv_len=self.seq_step) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + # Decode-style: write a single token into the existing cache. 
+ return self._make_inputs(start=16, q_len=1, kv_len=1) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + if not self.enable_dynamic_shape: + return None + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + "query": {2: seq_dim}, + } + + class RingBufferKVCacheModel(nn.Module): """ Test model wrapping RingBufferKVCache from cache.py. From ec4cc843d48da1a3d6f548133240e70ad0adce03 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 29 May 2026 16:16:54 -0700 Subject: [PATCH 5/8] up --- examples/models/gemma4_31b/export.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index f4e7e03c71e..ed3dcdba9c3 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -484,5 +484,6 @@ def main() -> None: finally: os.environ.pop("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", None) + if __name__ == "__main__": main() From a21ed983938392e0c4d8be754749c5a9563c63c8 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 29 May 2026 17:01:17 -0700 Subject: [PATCH 6/8] up --- backends/cuda/triton/kernels/tq4_sdpa.py | 8 +- .../gemma4_31b/cuda_source_transformations.py | 138 ++++++++++++++++++ examples/models/gemma4_31b/export.py | 31 ++-- 3 files changed, 163 insertions(+), 14 deletions(-) create mode 100644 examples/models/gemma4_31b/cuda_source_transformations.py diff --git a/backends/cuda/triton/kernels/tq4_sdpa.py b/backends/cuda/triton/kernels/tq4_sdpa.py index a4748540342..c68ea086940 100644 --- a/backends/cuda/triton/kernels/tq4_sdpa.py +++ b/backends/cuda/triton/kernels/tq4_sdpa.py @@ -640,6 +640,7 @@ def tq4_sdpa( rotation: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, + scale: Optional[float] = None, ) -> torch.Tensor: """Fused TQ4 SDPA over nibble-packed compressed K/V cache. 
@@ -660,6 +661,10 @@ def tq4_sdpa( rotation: [D, D] orthogonal rotation matrix attn_mask: Optional [B, 1, L_Q, L_KV] bool mask is_causal: apply causal masking (requires L_Q == L_KV) + scale: softmax scale applied to ``Q @ K^T``. Defaults to + ``1/sqrt(HEAD_DIM)`` when ``None``. Models that handle their + own normalization (e.g. Gemma 4 with QK-norm uses ``1.0``) + should pass an explicit value. Returns: [B, H_Q, L_Q, D] bf16 attention output @@ -671,7 +676,7 @@ def tq4_sdpa( _validate_tq4_mask(attn_mask, B, N_Q, N_KV) - sm_scale = 1.0 / math.sqrt(D) + sm_scale = float(1.0 / math.sqrt(D)) if scale is None else float(scale) num_groups = H_Q // H_KV # Build [256] bf16 lookup tables from [16] centroids. @@ -752,5 +757,6 @@ def _tq4_sdpa_fake( rotation: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, + scale: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query) diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py new file mode 100644 index 00000000000..9d0884a8e14 --- /dev/null +++ b/examples/models/gemma4_31b/cuda_source_transformations.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""CUDA source transformations for Gemma 4 31B-IT. + +Currently only adds optional TurboQuant TQ4 KV cache compression for +full-attention layers, leaving sliding-window layers untouched. When +``use_turboquant=True`` is passed: + +- ``Gemma4Attention.kv_cache`` is replaced with + ``extension.llm.modules.turboquant.TurboQuantKVCache`` on every + full-attention layer (sliding layers keep their ``RingKVCache``). +- The attention forward is monkey-patched to call + ``torch.ops.triton.tq4_sdpa`` (the fused TQ4 attention kernel) instead + of ``F.scaled_dot_product_attention``. 
+ +The model file (``model.py``) stays backend-agnostic — all CUDA +TurboQuant specifics live here. +""" + +from __future__ import annotations + +import types +from typing import Optional + +import torch +import torch.nn as nn + +# Importing this module registers ``torch.ops.triton.tq4_sdpa``. +import executorch.backends.cuda.triton.kernels.tq4_sdpa # noqa: F401 + +from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb +from executorch.extension.llm.modules.turboquant import TurboQuantKVCache + + +def _turboquant_attention_forward( + self, + x: torch.Tensor, + input_pos: torch.Tensor, + attn_mask: torch.Tensor, +) -> torch.Tensor: + """Drop-in replacement for ``Gemma4Attention.forward`` that uses + ``torch.ops.triton.tq4_sdpa`` over a ``TurboQuantKVCache``. + + Mirrors the default forward up to (and including) RoPE; only the + cache update and SDPA call differ. + """ + B, T, _ = x.shape + + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim) + raw_k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + if self.k_eq_v: + raw_v = raw_k + else: + raw_v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(raw_k) + v = self.v_norm(raw_v) + + # (B, H, T, D) for SDPA / KV cache. + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # RoPE: same code path as default forward. + freqs = torch.outer(input_pos.float(), self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos = torch.cos(emb) + sin = torch.sin(emb) + q, k = apply_rotary_emb(q, k, cos, sin) + + # Compress + write. Returns the full compressed cache tensors — + # tq4_sdpa decompresses per tile in its inner loop, so the full + # uncompressed K/V is never materialized. 
+ k_packed, k_norms, v_packed, v_norms = self.kv_cache.update( + input_pos, k, v + ) + + # ``scale=self.scaling`` (= 1.0 for Gemma 4) — overrides tq4_sdpa's + # default ``1/sqrt(D)`` because Gemma's QK-norm has absorbed the + # 1/sqrt(d) factor into trained weights. + y = torch.ops.triton.tq4_sdpa( + q, + k_packed, + k_norms, + v_packed, + v_norms, + self.kv_cache.centroids, + self.kv_cache.rotation, + attn_mask, + False, # is_causal — attn_mask already encodes causal masking + self.scaling, + ) + + y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) + return self.o_proj(y) + + +def cuda_source_transformations( + model: nn.Module, + *, + use_turboquant: bool = False, +) -> None: + """Apply CUDA source transformations to a Gemma 4 31B model in place. + + Args: + model: ``Gemma4_31B`` instance to transform. + use_turboquant: When True, swap full-attention layers' KV caches + for the backend-agnostic ``TurboQuantKVCache`` (~3.8× cache + memory savings) and route their SDPA through + ``torch.ops.triton.tq4_sdpa``. Sliding-window layers are + unaffected. 
+ """ + if not use_turboquant: + return + + config = model.config + n_swapped = 0 + for layer in model.layers: + attn = layer.self_attn + if attn.is_sliding: + continue + attn.kv_cache = TurboQuantKVCache( + n_heads=attn.n_kv_heads, + head_dim=attn.head_dim, + max_seq_len=config.max_seq_len, + ) + attn.forward = types.MethodType(_turboquant_attention_forward, attn) + n_swapped += 1 + + print( + f"[gemma4_31b cuda] TurboQuant: swapped {n_swapped} full-attention " + f"KV caches with TurboQuantKVCache (TQ4)" + ) diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index ed3dcdba9c3..1de00097d4f 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -145,13 +145,7 @@ def export_and_lower( ) -> None: """Export and lower the model to ExecuTorch for the given backend.""" if backend == "cuda": - if use_turboquant: - raise ValueError( - "--turboquant is only supported with --backend mlx " - "(the CUDA path here uses a different TurboQuant integration; " - "see examples/models/qwen3_5_moe/export.py)." 
- ) - _export_cuda(model, config, output_dir) + _export_cuda(model, config, output_dir, use_turboquant=use_turboquant) elif backend == "mlx": _export_mlx(model, config, output_dir, use_turboquant=use_turboquant) else: @@ -160,7 +154,12 @@ def export_and_lower( ) -def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: +def _export_cuda( + model: Gemma4_31B, + config: Gemma4_31BConfig, + output_dir: str, + use_turboquant: bool = False, +) -> None: import gc import torch._inductor.config as inductor_config @@ -184,6 +183,13 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - materialize_runtime_buffers(model, dtype=torch.bfloat16) + if use_turboquant: + from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( + cuda_source_transformations, + ) + + cuda_source_transformations(model, use_turboquant=True) + # Int4Tensor weights are used directly — no format conversion. # F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim). # Both decode and prefill share the same nibble-packed weights. @@ -440,14 +446,13 @@ def main() -> None: parser.add_argument( "--turboquant", action="store_true", - help="Use TurboQuant TQ4 KV cache compression (MLX backend only). " - "~3.8× cache memory savings; applies only to full-attention " - "(non-sliding) layers — sliding layers keep RingBufferKVCache.", + help="Use TurboQuant TQ4 KV cache compression. ~3.8× cache memory " + "savings; applies only to full-attention (non-sliding) layers — " + "sliding layers keep their default cache. 
Supported on both " + "--backend mlx and --backend cuda.", ) args = parser.parse_args() - if args.turboquant and args.backend != "mlx": - parser.error("--turboquant requires --backend mlx.") if args.backend == "cuda" and not torch.cuda.is_available(): parser.error("CUDA is required for the cuda backend.") From 5760cf3eec8ff0a0ba977d508242f15eebf53694 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 29 May 2026 17:04:29 -0700 Subject: [PATCH 7/8] up --- .../models/gemma4_31b/cuda_source_transformations.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py index 9d0884a8e14..aeafd97f74e 100644 --- a/examples/models/gemma4_31b/cuda_source_transformations.py +++ b/examples/models/gemma4_31b/cuda_source_transformations.py @@ -24,14 +24,13 @@ from __future__ import annotations import types -from typing import Optional - -import torch -import torch.nn as nn # Importing this module registers ``torch.ops.triton.tq4_sdpa``. import executorch.backends.cuda.triton.kernels.tq4_sdpa # noqa: F401 +import torch +import torch.nn as nn + from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb from executorch.extension.llm.modules.turboquant import TurboQuantKVCache @@ -76,9 +75,7 @@ def _turboquant_attention_forward( # Compress + write. Returns the full compressed cache tensors — # tq4_sdpa decompresses per tile in its inner loop, so the full # uncompressed K/V is never materialized. 
- k_packed, k_norms, v_packed, v_norms = self.kv_cache.update( - input_pos, k, v - ) + k_packed, k_norms, v_packed, v_norms = self.kv_cache.update(input_pos, k, v) # ``scale=self.scaling`` (= 1.0 for Gemma 4) — overrides tq4_sdpa's # default ``1/sqrt(D)`` because Gemma's QK-norm has absorbed the From 0d0c1bad0e273f58ae996ad1c9554608f8924dd5 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 29 May 2026 17:19:33 -0700 Subject: [PATCH 8/8] up --- examples/models/gemma4_31b/export.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index f17caf674c1..1de00097d4f 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -453,8 +453,6 @@ def main() -> None: ) args = parser.parse_args() - if args.turboquant and args.backend != "mlx": - parser.error("--turboquant requires --backend mlx.") if args.backend == "cuda" and not torch.cuda.is_available(): parser.error("CUDA is required for the cuda backend.")