diff --git a/backends/cuda/triton/kernels/tq4_sdpa.py b/backends/cuda/triton/kernels/tq4_sdpa.py index a4748540342..c68ea086940 100644 --- a/backends/cuda/triton/kernels/tq4_sdpa.py +++ b/backends/cuda/triton/kernels/tq4_sdpa.py @@ -640,6 +640,7 @@ def tq4_sdpa( rotation: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, + scale: Optional[float] = None, ) -> torch.Tensor: """Fused TQ4 SDPA over nibble-packed compressed K/V cache. @@ -660,6 +661,10 @@ def tq4_sdpa( rotation: [D, D] orthogonal rotation matrix attn_mask: Optional [B, 1, L_Q, L_KV] bool mask is_causal: apply causal masking (requires L_Q == L_KV) + scale: softmax scale applied to ``Q @ K^T``. Defaults to + ``1/sqrt(HEAD_DIM)`` when ``None``. Models that handle their + own normalization (e.g. Gemma 4 with QK-norm uses ``1.0``) + should pass an explicit value. Returns: [B, H_Q, L_Q, D] bf16 attention output @@ -671,7 +676,7 @@ def tq4_sdpa( _validate_tq4_mask(attn_mask, B, N_Q, N_KV) - sm_scale = 1.0 / math.sqrt(D) + sm_scale = float(1.0 / math.sqrt(D)) if scale is None else float(scale) num_groups = H_Q // H_KV # Build [256] bf16 lookup tables from [16] centroids. @@ -752,5 +757,6 @@ def _tq4_sdpa_fake( rotation: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, + scale: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query) diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py new file mode 100644 index 00000000000..aeafd97f74e --- /dev/null +++ b/examples/models/gemma4_31b/cuda_source_transformations.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""CUDA source transformations for Gemma 4 31B-IT. 
def _turboquant_attention_forward(
    self,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    """Attention forward that runs fused TQ4 SDPA over a compressed KV cache.

    Intended to be bound onto an attention module in place of its regular
    ``forward``. Projection, QK/V normalization and RoPE are computed here
    directly; the cache write goes through ``self.kv_cache.update`` and the
    attention itself through ``torch.ops.triton.tq4_sdpa``.

    Args:
        self: Attention module providing ``q_proj``/``k_proj``/``v_proj``/
            ``o_proj``, ``q_norm``/``k_norm``/``v_norm``, ``inv_freq``,
            ``scaling``, ``k_eq_v``, the head-count/head-dim attributes and
            a ``kv_cache`` exposing ``update``, ``centroids`` and
            ``rotation``.
        x: Input activations; unpacked as ``(batch, seq_len, hidden)``.
        input_pos: Positions of the current tokens, used both for RoPE
            frequencies and as the cache write index.
        attn_mask: Mask forwarded unchanged to ``tq4_sdpa``.

    Returns:
        The output projection applied to the attention result, reshaped to
        ``(batch, seq_len, n_heads * head_dim)`` before ``o_proj``.
    """
    batch, seq_len, _ = x.shape
    kv_shape = (batch, seq_len, self.n_kv_heads, self.head_dim)

    # Projections, split into per-head layout.
    query = self.q_proj(x).view(batch, seq_len, self.n_heads, self.head_dim)
    key_raw = self.k_proj(x).view(*kv_shape)
    # When K and V share a projection the un-normalized K doubles as V.
    value_raw = key_raw if self.k_eq_v else self.v_proj(x).view(*kv_shape)

    # Per-head normalization, then move heads ahead of the sequence axis:
    # (B, T, H, D) -> (B, H, T, D), the layout SDPA and the cache consume.
    query = self.q_norm(query).transpose(1, 2)
    key = self.k_norm(key_raw).transpose(1, 2)
    value = self.v_norm(value_raw).transpose(1, 2)

    # Rotary embedding built from the absolute positions.
    angles = torch.outer(input_pos.float(), self.inv_freq)
    angles = torch.cat((angles, angles), dim=-1)
    query, key = apply_rotary_emb(query, key, torch.cos(angles), torch.sin(angles))

    # Compress and store the new K/V. The cache hands back its packed
    # tensors; per the kernel's contract they are decompressed tile by tile
    # inside tq4_sdpa rather than being expanded up front.
    packed_k, norms_k, packed_v, norms_v = self.kv_cache.update(input_pos, key, value)

    # The explicit scale replaces tq4_sdpa's 1/sqrt(D) default; the module's
    # own ``scaling`` is used because QK-norm already accounts for it.
    attn_out = torch.ops.triton.tq4_sdpa(
        query,
        packed_k,
        norms_k,
        packed_v,
        norms_v,
        self.kv_cache.centroids,
        self.kv_cache.rotation,
        attn_mask,
        False,  # is_causal: causality is expected to be encoded in attn_mask
        self.scaling,
    )

    # (B, H, T, D) -> (B, T, H * D) for the output projection.
    merged = attn_out.transpose(1, 2).contiguous()
    merged = merged.view(batch, seq_len, self.n_heads * self.head_dim)
    return self.o_proj(merged)


def cuda_source_transformations(
    model: nn.Module,
    *,
    use_turboquant: bool = False,
) -> None:
    """Rewrite a Gemma 4 31B model in place for the CUDA backend.

    The only transformation currently available is TurboQuant: every
    attention module whose ``is_sliding`` flag is false gets a fresh
    ``TurboQuantKVCache`` and has its ``forward`` rebound to
    ``_turboquant_attention_forward``. Modules flagged as sliding are left
    exactly as they were.

    Args:
        model: Model exposing ``config.max_seq_len`` and an iterable
            ``layers`` whose items carry a ``self_attn`` module.
        use_turboquant: Enable the TurboQuant swap. When false the function
            returns immediately without touching ``model``.
    """
    if not use_turboquant:
        return

    cfg = model.config
    n_swapped = 0
    for attention in (block.self_attn for block in model.layers):
        if not attention.is_sliding:
            attention.kv_cache = TurboQuantKVCache(
                n_heads=attention.n_kv_heads,
                head_dim=attention.head_dim,
                max_seq_len=cfg.max_seq_len,
            )
            # Bind per instance so only this layer's forward is replaced.
            attention.forward = types.MethodType(
                _turboquant_attention_forward, attention
            )
            n_swapped += 1

    print(
        f"[gemma4_31b cuda] TurboQuant: swapped {n_swapped} full-attention "
        f"KV caches with TurboQuantKVCache (TQ4)"
    )
- ) - _export_cuda(model, config, output_dir) + _export_cuda(model, config, output_dir, use_turboquant=use_turboquant) elif backend == "mlx": _export_mlx(model, config, output_dir, use_turboquant=use_turboquant) else: @@ -160,7 +154,12 @@ def export_and_lower( ) -def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: +def _export_cuda( + model: Gemma4_31B, + config: Gemma4_31BConfig, + output_dir: str, + use_turboquant: bool = False, +) -> None: import gc import torch._inductor.config as inductor_config @@ -184,6 +183,13 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - materialize_runtime_buffers(model, dtype=torch.bfloat16) + if use_turboquant: + from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( + cuda_source_transformations, + ) + + cuda_source_transformations(model, use_turboquant=True) + # Int4Tensor weights are used directly — no format conversion. # F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim). # Both decode and prefill share the same nibble-packed weights. @@ -440,14 +446,13 @@ def main() -> None: parser.add_argument( "--turboquant", action="store_true", - help="Use TurboQuant TQ4 KV cache compression (MLX backend only). " - "~3.8× cache memory savings; applies only to full-attention " - "(non-sliding) layers — sliding layers keep RingBufferKVCache.", + help="Use TurboQuant TQ4 KV cache compression. ~3.8× cache memory " + "savings; applies only to full-attention (non-sliding) layers — " + "sliding layers keep their default cache. Supported on both " + "--backend mlx and --backend cuda.", ) args = parser.parse_args() - if args.turboquant and args.backend != "mlx": - parser.error("--turboquant requires --backend mlx.") if args.backend == "cuda" and not torch.cuda.is_available(): parser.error("CUDA is required for the cuda backend.")