Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion backends/cuda/triton/kernels/tq4_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ def tq4_sdpa(
rotation: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
) -> torch.Tensor:
"""Fused TQ4 SDPA over nibble-packed compressed K/V cache.

Expand All @@ -660,6 +661,10 @@ def tq4_sdpa(
rotation: [D, D] orthogonal rotation matrix
attn_mask: Optional [B, 1, L_Q, L_KV] bool mask
is_causal: apply causal masking (requires L_Q == L_KV)
scale: softmax scale applied to ``Q @ K^T``. Defaults to
``1/sqrt(HEAD_DIM)`` when ``None``. Models that handle their
own normalization (e.g. Gemma 4 with QK-norm uses ``1.0``)
should pass an explicit value.

Returns:
[B, H_Q, L_Q, D] bf16 attention output
Expand All @@ -671,7 +676,7 @@ def tq4_sdpa(

_validate_tq4_mask(attn_mask, B, N_Q, N_KV)

sm_scale = 1.0 / math.sqrt(D)
sm_scale = float(1.0 / math.sqrt(D)) if scale is None else float(scale)
num_groups = H_Q // H_KV

# Build [256] bf16 lookup tables from [16] centroids.
Expand Down Expand Up @@ -752,5 +757,6 @@ def _tq4_sdpa_fake(
rotation: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
135 changes: 135 additions & 0 deletions examples/models/gemma4_31b/cuda_source_transformations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""CUDA source transformations for Gemma 4 31B-IT.

Currently only adds optional TurboQuant TQ4 KV cache compression for
full-attention layers, leaving sliding-window layers untouched. When
``use_turboquant=True`` is passed:

- ``Gemma4Attention.kv_cache`` is replaced with
``extension.llm.modules.turboquant.TurboQuantKVCache`` on every
full-attention layer (sliding layers keep their ``RingKVCache``).
- The attention forward is monkey-patched to call
``torch.ops.triton.tq4_sdpa`` (the fused TQ4 attention kernel) instead
of ``F.scaled_dot_product_attention``.

The model file (``model.py``) stays backend-agnostic — all CUDA
TurboQuant specifics live here.
"""

from __future__ import annotations

import types

# Importing this module registers ``torch.ops.triton.tq4_sdpa``.
import executorch.backends.cuda.triton.kernels.tq4_sdpa # noqa: F401

import torch
import torch.nn as nn

from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb
from executorch.extension.llm.modules.turboquant import TurboQuantKVCache


def _turboquant_attention_forward(
self,
x: torch.Tensor,
input_pos: torch.Tensor,
attn_mask: torch.Tensor,
) -> torch.Tensor:
"""Drop-in replacement for ``Gemma4Attention.forward`` that uses
``torch.ops.triton.tq4_sdpa`` over a ``TurboQuantKVCache``.

Mirrors the default forward up to (and including) RoPE; only the
cache update and SDPA call differ.
"""
B, T, _ = x.shape

q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
raw_k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim)
if self.k_eq_v:
raw_v = raw_k
else:
raw_v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim)

q = self.q_norm(q)
k = self.k_norm(raw_k)
v = self.v_norm(raw_v)

# (B, H, T, D) for SDPA / KV cache.
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# RoPE: same code path as default forward.
freqs = torch.outer(input_pos.float(), self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos = torch.cos(emb)
sin = torch.sin(emb)
q, k = apply_rotary_emb(q, k, cos, sin)

# Compress + write. Returns the full compressed cache tensors —
# tq4_sdpa decompresses per tile in its inner loop, so the full
# uncompressed K/V is never materialized.
k_packed, k_norms, v_packed, v_norms = self.kv_cache.update(input_pos, k, v)

# ``scale=self.scaling`` (= 1.0 for Gemma 4) — overrides tq4_sdpa's
# default ``1/sqrt(D)`` because Gemma's QK-norm has absorbed the
# 1/sqrt(d) factor into trained weights.
y = torch.ops.triton.tq4_sdpa(
q,
k_packed,
k_norms,
v_packed,
v_norms,
self.kv_cache.centroids,
self.kv_cache.rotation,
attn_mask,
False, # is_causal — attn_mask already encodes causal masking
self.scaling,
)

y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
return self.o_proj(y)


def cuda_source_transformations(
model: nn.Module,
*,
use_turboquant: bool = False,
) -> None:
"""Apply CUDA source transformations to a Gemma 4 31B model in place.

Args:
model: ``Gemma4_31B`` instance to transform.
use_turboquant: When True, swap full-attention layers' KV caches
for the backend-agnostic ``TurboQuantKVCache`` (~3.8× cache
memory savings) and route their SDPA through
``torch.ops.triton.tq4_sdpa``. Sliding-window layers are
unaffected.
"""
if not use_turboquant:
return

config = model.config
n_swapped = 0
for layer in model.layers:
attn = layer.self_attn
if attn.is_sliding:
continue
attn.kv_cache = TurboQuantKVCache(
n_heads=attn.n_kv_heads,
head_dim=attn.head_dim,
max_seq_len=config.max_seq_len,
)
attn.forward = types.MethodType(_turboquant_attention_forward, attn)
n_swapped += 1

print(
f"[gemma4_31b cuda] TurboQuant: swapped {n_swapped} full-attention "
f"KV caches with TurboQuantKVCache (TQ4)"
)
31 changes: 18 additions & 13 deletions examples/models/gemma4_31b/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,7 @@ def export_and_lower(
) -> None:
"""Export and lower the model to ExecuTorch for the given backend."""
if backend == "cuda":
if use_turboquant:
raise ValueError(
"--turboquant is only supported with --backend mlx "
"(the CUDA path here uses a different TurboQuant integration; "
"see examples/models/qwen3_5_moe/export.py)."
)
_export_cuda(model, config, output_dir)
_export_cuda(model, config, output_dir, use_turboquant=use_turboquant)
elif backend == "mlx":
_export_mlx(model, config, output_dir, use_turboquant=use_turboquant)
else:
Expand All @@ -160,7 +154,12 @@ def export_and_lower(
)


def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None:
def _export_cuda(
model: Gemma4_31B,
config: Gemma4_31BConfig,
output_dir: str,
use_turboquant: bool = False,
) -> None:
import gc

import torch._inductor.config as inductor_config
Expand All @@ -184,6 +183,13 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -

materialize_runtime_buffers(model, dtype=torch.bfloat16)

if use_turboquant:
from executorch.examples.models.gemma4_31b.cuda_source_transformations import (
cuda_source_transformations,
)

cuda_source_transformations(model, use_turboquant=True)

# Int4Tensor weights are used directly — no format conversion.
# F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim).
# Both decode and prefill share the same nibble-packed weights.
Expand Down Expand Up @@ -440,14 +446,13 @@ def main() -> None:
parser.add_argument(
"--turboquant",
action="store_true",
help="Use TurboQuant TQ4 KV cache compression (MLX backend only). "
"~3.8× cache memory savings; applies only to full-attention "
"(non-sliding) layers — sliding layers keep RingBufferKVCache.",
help="Use TurboQuant TQ4 KV cache compression. ~3.8× cache memory "
"savings; applies only to full-attention (non-sliding) layers — "
"sliding layers keep their default cache. Supported on both "
"--backend mlx and --backend cuda.",
)
args = parser.parse_args()

if args.turboquant and args.backend != "mlx":
parser.error("--turboquant requires --backend mlx.")
if args.backend == "cuda" and not torch.cuda.is_available():
parser.error("CUDA is required for the cuda backend.")

Expand Down
Loading