diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py new file mode 100644 index 0000000000..525bfc9f5c --- /dev/null +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py @@ -0,0 +1,1476 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Bench NVFP4 per-token K1+K2 quant vs per-tensor RHT+SR baseline. + +bf16, M % 128 == 0, K % 128 == 0. + +Modes: + * default: 2-way quant-only (per-token vs per-tensor). Ratio = pt / pten. + * ``--rht``: 3-way quant-only (adds per-token + col-wise 16-pt RHT). + * ``--swizzle``: 3-way END-TO-END (quant + swizzle + cuBLAS LT NVFP4 GEMM). + Compares per-token (separate swizzle launch) vs per-token (fused + swizzle in K2) vs per-tensor. Ratio = per-token (+swizzle) / per-tensor. + * ``--gemm-only``: 2-way cuBLAS LT NVFP4 GEMM in isolation. + Inputs are pre-quantized + pre-swizzled before timing, so the loop + only times ``nvfp4_per_token_gemm(sf_swizzled=True)`` vs + ``nvfp4_per_tensor_gemm(sf_swizzled=True)``. Ratio = pt / pten + exposes the per-call cost of the per-token post-scale kernel + (both paths run the same cuBLAS LT call + alpha-fold kernel). + * ``--qs``: 2-way K1+K2 + standalone rowwise swizzle. NO GEMM. + - default (solo, 1 tensor): K1+K2(A) + swizzle(A); apples-to-apples + with --composite (which is also 1-tensor) -- the delta vs --composite + is the pure marginal swizzle cost. + - ``--pair`` (2 tensors): K1+K2(A) + K1+K2(B) + swizzle(A) + swizzle(B); + matches prod NVFP4 GEMM's per-call quant+swizzle pipeline (1 swizzle + per operand). Use this when you want "one GEMM call's worth of + non-GEMM cost". + - ``--fuse``: also bench per-token with fused-swizzle K2 (K2 directly + writes the rowwise SF in cuBLAS LT swizzled layout; no separate + swizzle launch). Prints a 3-way table: per-token / per-token(fuse) / + per-tensor. 
The (fuse) column saves 1 swizzle launch/operand vs the + non-fuse column. + Ratio = per-token / per-tensor (3-way mode adds a per-token(fuse) column). + * ``--k1-only``: K1 in isolation (orthogonal to --swizzle / --qs). +""" + +from __future__ import annotations + +import argparse +import math +import statistics +import sys +from dataclasses import dataclass +from typing import Callable, List, Tuple + +import torch + +# Import transformer_engine first so libtransformer_engine.so is dlopen'd +# before transformer_engine_torch tries to resolve its typeinfo symbols. +import transformer_engine.pytorch as te # noqa: F401 +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer + + +def cuda_time_ms(fn: Callable[[], None], *, warmup: int = 5, iters: int = 50) -> float: + """Median wall time of fn over iters invocations, in ms.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + for i in range(iters): + starts[i].record() + fn() + ends[i].record() + torch.cuda.synchronize() + samples = [starts[i].elapsed_time(ends[i]) for i in range(iters)] + return statistics.median(samples) + + +def cuda_graph_time_ms(fn: Callable[[], object], *, warmup: int = 5, iters: int = 50) -> float: + """Median g.replay() wall time of fn captured into a CUDA Graph (kernel-only floor). + + Returns nan if capture fails. 
+ """ + try: + side = torch.cuda.Stream() + side.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(side): + for _ in range(warmup): + fn() + torch.cuda.current_stream().wait_stream(side) + torch.cuda.synchronize() + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + except Exception as e: + print(f" [graph capture skipped: {type(e).__name__}: {e}]", file=sys.stderr) + return float("nan") + + starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + for i in range(iters): + starts[i].record() + g.replay() + ends[i].record() + torch.cuda.synchronize() + samples = [starts[i].elapsed_time(ends[i]) for i in range(iters)] + return statistics.median(samples) + + +def _make_baseline_quantizer() -> NVFP4Quantizer: + """Per-tensor baseline quantizer: RHT + SR + random sign mask.""" + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=True, + with_random_sign_mask=True, + ) + + +def _has_sm100() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 + + +@dataclass +class ShapeBench: + M: int + K: int + t_pt: float # per-token full K1+K2, no RHT (Eager pybind, ms) + t_pt_rht: float # per-token full K1+K2, +RHT col-wise (Eager pybind, ms) + t_pten: float # per-tensor full K1+K2 with RHT+SR (Eager pybind, ms) + t_pt_g: float # per-token under CUDA Graphs replay (ms) + t_pt_rht_g: float # per-token+RHT under CUDA Graphs replay (ms) + t_pten_g: float # per-tensor under CUDA Graphs replay (ms) + + +@dataclass +class K1ShapeBench: + M: int + K: int + # K1-only timings: 3 paths x 2 modes (Eager + CUDA Graphs). 
+ t_pt: float # per-token K1, no RHT (rowwise+columnwise amax vectors) + t_pt_rht: float # per-token K1, +RHT on col direction + t_prod: float # prod K1 hadamard_transform_amax (per-tensor scalar amax) + t_pt_g: float + t_pt_rht_g: float + t_prod_g: float + + +@dataclass +class E2EShapeBench: + """End-to-end (quant + GEMM) timing for --swizzle mode. N is bound to M.""" + + M: int + K: int + t_pt: float # per-token (no fused swizzle): quant + ext swizzle + GEMM + t_pt_swz: float # per-token (fused swizzle): quant_with_swizzle=True + GEMM + t_pten: float # per-tensor: NVFP4Quantizer + cuBLAS LT GEMM + t_pt_g: float + t_pt_swz_g: float + t_pten_g: float + + +@dataclass +class GemmOnlyShapeBench: + """NVFP4 GEMM in isolation. Inputs pre-quantized + pre-swizzled outside the + timed window; only the GEMM kernel call is timed. N = K. + + The single comparison that matters for shipping fused-EVT NVFP4 against + the current prod NVFP4 path: + ct_fused = forked CUTLASS NVFP4 GEMM with per-row * per-col rescale + fused into the EVT epilogue (single launch). This is what + WOULD ship for per-token NVFP4. + pten_gemm = cuBLAS LT NVFP4 + alpha-fold (single launch, no post-scale). + This is the CURRENT prod per-tensor NVFP4 GEMM. + + Ratio cf/pten = ct_fused / pten_gemm. < 1.0 ⇒ shippable (per-token at + least matches prod per-tensor wall-clock). + """ + + M: int + K: int + N: int + t_pten: float # nvfp4_per_tensor_gemm: cuBLAS LT + alpha-fold (prod baseline) + t_clf: float # nvfp4_cutlass_per_token_gemm: per-row * per-col fused EVT + t_pten_g: float + t_clf_g: float + + +@dataclass +class QSShapeBench: + """K1+K2 + rowwise swizzle, no GEMM. 
solo=3 launches, --pair=6, + --fuse adds per-token-fused column (K2 emits swizzled SF in 1 launch).""" + + M: int + K: int + t_pt: float # per-token K1+K2 + ext swizzle (1 or 2 operands depending on pair) + t_pten: float # per-tensor K1+K2 + ext swizzle (matching operand count) + t_pt_g: float + t_pten_g: float + t_pt_swz: float = float("nan") # per-token K1+K2 with fused swizzle (no ext swz launch) + t_pt_swz_g: float = float("nan") + + +# Default mask seed; matches prod's `te-nvfp4-build-overrides.mdc` convention. +_RHT_MASK_DEFAULT: int = 0xACE1 + + +def _bench_shape( + M: int, K: int, *, device: torch.device, with_rht: bool = False, mask_t: int = _RHT_MASK_DEFAULT +) -> ShapeBench: + """Composite K1+K2 at (M, K). pt = per-token (no RHT); pt_rht = +col-wise + 16-pt RHT (NaN unless with_rht=True); pten = per-tensor + RHT + SR.""" + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + quantizer = _make_baseline_quantizer() + dst_a = quantizer.make_empty(a.shape, dtype=torch.bfloat16, device=device) + + # Per-token A-side buffers reused across no-RHT and +RHT paths. 
+ BLOCK_K = 16 + ra_a = torch.empty((M,), dtype=torch.float32, device=device) + ca_a = torch.empty((K,), dtype=torch.float32, device=device) + q_row_a = torch.empty((M, K // 2), dtype=torch.uint8, device=device) + s_dec_row_a = torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device) + q_col_a = torch.empty((K, M // 2), dtype=torch.uint8, device=device) + s_dec_col_a = torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device) + + def _baseline_quant_fn(): + tex.quantize(a, quantizer, dst_a, None) + + def _pt_full_quant_fn(): + tex.nvfp4_per_token_quantize( + a, + q_row_a, + s_dec_row_a, + ra_a, + q_col_a, + s_dec_col_a, + ca_a, + True, + True, + with_rht=False, + random_sign_mask_t=0, + ) + + t_pten = cuda_time_ms(_baseline_quant_fn) + t_pt = cuda_time_ms(_pt_full_quant_fn) + t_pten_g = cuda_graph_time_ms(_baseline_quant_fn) + t_pt_g = cuda_graph_time_ms(_pt_full_quant_fn) + + if with_rht: + + def _pt_full_quant_rht_fn(): + tex.nvfp4_per_token_quantize( + a, + q_row_a, + s_dec_row_a, + ra_a, + q_col_a, + s_dec_col_a, + ca_a, + True, + True, + with_rht=True, + random_sign_mask_t=mask_t, + ) + + t_pt_rht = cuda_time_ms(_pt_full_quant_rht_fn) + t_pt_rht_g = cuda_graph_time_ms(_pt_full_quant_rht_fn) + else: + t_pt_rht = float("nan") + t_pt_rht_g = float("nan") + + return ShapeBench( + M=M, + K=K, + t_pt=t_pt, + t_pt_rht=t_pt_rht, + t_pten=t_pten, + t_pt_g=t_pt_g, + t_pt_rht_g=t_pt_rht_g, + t_pten_g=t_pten_g, + ) + + +def _bench_shape_e2e_swizzle( + M: int, + K: int, + *, + device: torch.device, + with_rht: bool = False, + mask_t: int = _RHT_MASK_DEFAULT, +) -> E2EShapeBench: + """E2E (quant + cuBLAS LT NVFP4 GEMM) for --swizzle, square N=M. 
+ pt: ext swizzle; pt_swz: fused-swizzle K2 (no internal swz launch); + pten: NVFP4Quantizer + nvfp4_per_tensor_gemm baseline.""" + from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace + + N = M # square; cuBLAS LT NVFP4 is TN-only -- A: (M, K), B: (N, K) + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + b = torch.randn((N, K), dtype=torch.bfloat16, device=device) + d = torch.empty((M, N), dtype=torch.bfloat16, device=device) + # torch.device("cuda").index is None (no explicit device index); resolve to + # an actual GPU index via the allocated tensor so get_cublas_workspace + # creates the workspace on the right CUDA device instead of CPU. + workspace = get_cublas_workspace(a.device.index, ub=False, grouped_gemm=False) + + # Per-token quant produces row + col directions on every call (matches the + # per-tensor baseline below which does both in one kernel). GEMM consumes + # only the rowwise side; the col allocation is realistic prod overhead. + BLOCK_K = 16 + + def _alloc_pt(R, C): + return ( + torch.empty((R, C // 2), dtype=torch.uint8, device=device), + torch.empty((R, C // BLOCK_K), dtype=torch.uint8, device=device), + torch.empty((R,), dtype=torch.float32, device=device), + torch.empty((C, R // 2), dtype=torch.uint8, device=device), + torch.empty((C, R // BLOCK_K), dtype=torch.uint8, device=device), + torch.empty((C,), dtype=torch.float32, device=device), + ) + + a_qr, a_sr, a_ra, a_qc, a_sc, a_ca = _alloc_pt(M, K) + b_qr, b_sr, b_ra, b_qc, b_sc, b_ca = _alloc_pt(N, K) + + def _pt_quant(t, qr, sr, ra_buf, qc, sc, ca_buf, *, fused_swizzle: bool): + tex.nvfp4_per_token_quantize( + t, + qr, + sr, + ra_buf, + qc, + sc, + ca_buf, + True, + True, # rowwise + columnwise (apples-to-apples vs per-tensor) + with_rht=with_rht, + random_sign_mask_t=mask_t if with_rht else 0, + with_swizzle=fused_swizzle, + ) + + def _pt_e2e_ext_swizzle(): + _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca, fused_swizzle=False) + _pt_quant(b, b_qr, 
b_sr, b_ra, b_qc, b_sc, b_ca, fused_swizzle=False) + tex.nvfp4_per_token_gemm( + a_qr, + b_qr, + a_sr.reshape(-1), + b_sr.reshape(-1), + a_ra, + b_ra, + d, + workspace, + M, + N, + K, + 1.0, + 0.0, + a_sf_swizzled=False, + b_sf_swizzled=False, + ) + + def _pt_e2e_fused_swizzle(): + _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca, fused_swizzle=True) + _pt_quant(b, b_qr, b_sr, b_ra, b_qc, b_sc, b_ca, fused_swizzle=True) + tex.nvfp4_per_token_gemm( + a_qr, + b_qr, + a_sr.reshape(-1), + b_sr.reshape(-1), + a_ra, + b_ra, + d, + workspace, + M, + N, + K, + 1.0, + 0.0, + a_sf_swizzled=True, + b_sf_swizzled=True, + ) + + # Per-tensor path: NVFP4Quantizer (RHT+SR) + bench-only nvfp4_per_tensor_gemm. + quantizer = _make_baseline_quantizer() + dst_a = quantizer.make_empty(a.shape, dtype=torch.bfloat16, device=device) + dst_b = quantizer.make_empty(b.shape, dtype=torch.bfloat16, device=device) + + def _pten_e2e(): + tex.quantize(a, quantizer, dst_a, None) + tex.quantize(b, quantizer, dst_b, None) + tex.nvfp4_per_tensor_gemm( + dst_a._rowwise_data, + dst_b._rowwise_data, + dst_a._rowwise_scale_inv, + dst_b._rowwise_scale_inv, + dst_a._amax_rowwise, + dst_b._amax_rowwise, + d, + workspace, + M, + N, + K, + 1.0, + 0.0, + ) + + t_pt = cuda_time_ms(_pt_e2e_ext_swizzle) + t_pt_swz = cuda_time_ms(_pt_e2e_fused_swizzle) + t_pten = cuda_time_ms(_pten_e2e) + t_pt_g = cuda_graph_time_ms(_pt_e2e_ext_swizzle) + t_pt_swz_g = cuda_graph_time_ms(_pt_e2e_fused_swizzle) + t_pten_g = cuda_graph_time_ms(_pten_e2e) + + return E2EShapeBench( + M=M, + K=K, + t_pt=t_pt, + t_pt_swz=t_pt_swz, + t_pten=t_pten, + t_pt_g=t_pt_g, + t_pt_swz_g=t_pt_swz_g, + t_pten_g=t_pten_g, + ) + + +def _bench_shape_gemm_only( + M: int, + K: int, + *, + device: torch.device, + with_rht: bool = False, + mask_t: int = _RHT_MASK_DEFAULT, +) -> GemmOnlyShapeBench: + """NVFP4 GEMM in isolation (N = K). Quant + swizzle run once before timing; + only the GEMM kernel call is timed. 
+ + Two paths timed: + - pten_gemm: cuBLAS LT NVFP4 per-tensor (current prod NVFP4 GEMM). + - ct_fused : forked CUTLASS NVFP4 GEMM with per-row * per-col rescale + fused into the EVT epilogue (per-token, single launch). + """ + from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace + + N = K + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + b = torch.randn((N, K), dtype=torch.bfloat16, device=device) + d = torch.empty((M, N), dtype=torch.bfloat16, device=device) + workspace = get_cublas_workspace(a.device.index, ub=False, grouped_gemm=False) + + BLOCK_K = 16 + + # Per-token rowwise quant for the fused CUTLASS path. Pre-swizzled SF so + # the timed window only covers the GEMM kernel (no swizzle launch). + a_qr = torch.empty((M, K // 2), dtype=torch.uint8, device=device) + a_sr = torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device) + a_ra = torch.empty((M,), dtype=torch.float32, device=device) + b_qr = torch.empty((N, K // 2), dtype=torch.uint8, device=device) + b_sr = torch.empty((N, K // BLOCK_K), dtype=torch.uint8, device=device) + b_ra = torch.empty((N,), dtype=torch.float32, device=device) + empty_u8 = torch.empty(0, dtype=torch.uint8, device=device) + empty_f32 = torch.empty(0, dtype=torch.float32, device=device) + tex.nvfp4_per_token_quantize( + a, + a_qr, + a_sr, + a_ra, + empty_u8, + empty_u8, + empty_f32, + True, + False, + with_rht=with_rht, + random_sign_mask_t=mask_t if with_rht else 0, + with_swizzle=True, + ) + tex.nvfp4_per_token_quantize( + b, + b_qr, + b_sr, + b_ra, + empty_u8, + empty_u8, + empty_f32, + True, + False, + with_rht=with_rht, + random_sign_mask_t=mask_t if with_rht else 0, + with_swizzle=True, + ) + a_sr_flat = a_sr.reshape(-1) + b_sr_flat = b_sr.reshape(-1) + + # Per-tensor: NVFP4Quantizer (RHT+SR) -> pre-swizzle SF once so prod GEMM + # call doesn't pay 2 swizzle launches inside the timed window either. 
+ quantizer = _make_baseline_quantizer() + dst_a = quantizer.make_empty(a.shape, dtype=torch.bfloat16, device=device) + dst_b = quantizer.make_empty(b.shape, dtype=torch.bfloat16, device=device) + tex.quantize(a, quantizer, dst_a, None) + tex.quantize(b, quantizer, dst_b, None) + + pten_a_sr_flat = dst_a._rowwise_scale_inv.reshape(-1) + pten_b_sr_flat = dst_b._rowwise_scale_inv.reshape(-1) + pten_a_sr_swz = torch.empty(pten_a_sr_flat.numel(), dtype=torch.uint8, device=device) + pten_b_sr_swz = torch.empty(pten_b_sr_flat.numel(), dtype=torch.uint8, device=device) + tex.nvfp4_per_token_swizzle_rowwise_sf(dst_a._rowwise_data, pten_a_sr_flat, pten_a_sr_swz) + tex.nvfp4_per_token_swizzle_rowwise_sf(dst_b._rowwise_data, pten_b_sr_flat, pten_b_sr_swz) + + def _pten_gemm(): + tex.nvfp4_per_tensor_gemm( + dst_a._rowwise_data, + dst_b._rowwise_data, + pten_a_sr_swz, + pten_b_sr_swz, + dst_a._amax_rowwise, + dst_b._amax_rowwise, + d, + workspace, + M, + N, + K, + 1.0, + 0.0, + a_sf_swizzled=True, + b_sf_swizzled=True, + ) + + # Forked CUTLASS NVFP4 GEMM with per-row * per-col rescale fused INTO the + # epilogue (EVT). Single launch; the M*N output never round-trips through + # HBM. This is the kernel that should beat pten_gemm at training shapes. 
+ d_clf = torch.empty_like(d) + + def _cutlass_fused(): + tex.nvfp4_cutlass_per_token_gemm( + a_qr, + b_qr, + a_sr_flat, + b_sr_flat, + a_ra, + b_ra, + d_clf, + M, + N, + K, + a_sf_swizzled=True, + b_sf_swizzled=True, + ) + + t_pten = cuda_time_ms(_pten_gemm) + t_clf = cuda_time_ms(_cutlass_fused) + t_pten_g = cuda_graph_time_ms(_pten_gemm) + t_clf_g = cuda_graph_time_ms(_cutlass_fused) + + return GemmOnlyShapeBench( + M=M, + K=K, + N=N, + t_pten=t_pten, + t_clf=t_clf, + t_pten_g=t_pten_g, + t_clf_g=t_clf_g, + ) + + +def _bench_shape_qs( + M: int, + K: int, + *, + device: torch.device, + with_rht: bool = False, + mask_t: int = _RHT_MASK_DEFAULT, + pair: bool = False, + fuse: bool = False, +) -> QSShapeBench: + """K1+K2 + standalone rowwise swizzle, no GEMM. solo=3 launches/operand, + --pair=6 (A+B). Swizzle binding identical across pt/pten -- only K1+K2 differs.""" + N = M # square; matches --swizzle's apples-to-apples convention. + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + BLOCK_K = 16 + + def _alloc_pt(R, C): + return ( + torch.empty((R, C // 2), dtype=torch.uint8, device=device), + torch.empty((R, C // BLOCK_K), dtype=torch.uint8, device=device), + torch.empty((R,), dtype=torch.float32, device=device), + torch.empty((C, R // 2), dtype=torch.uint8, device=device), + torch.empty((C, R // BLOCK_K), dtype=torch.uint8, device=device), + torch.empty((C,), dtype=torch.float32, device=device), + ) + + a_qr, a_sr, a_ra, a_qc, a_sc, a_ca = _alloc_pt(M, K) + a_sr_swz = torch.empty(a_sr.numel(), dtype=torch.uint8, device=device) + + # B-side allocation only when --pair (avoids spurious HBM pressure in solo). 
+ if pair: + b = torch.randn((N, K), dtype=torch.bfloat16, device=device) + b_qr, b_sr, b_ra, b_qc, b_sc, b_ca = _alloc_pt(N, K) + b_sr_swz = torch.empty(b_sr.numel(), dtype=torch.uint8, device=device) + + def _pt_quant(t, qr, sr, ra_buf, qc, sc, ca_buf): + tex.nvfp4_per_token_quantize( + t, + qr, + sr, + ra_buf, + qc, + sc, + ca_buf, + True, + True, + with_rht=with_rht, + random_sign_mask_t=mask_t if with_rht else 0, + with_swizzle=False, # explicit external swizzle, see below + ) + + if pair: + + def _pt_qs(): + _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca) + _pt_quant(b, b_qr, b_sr, b_ra, b_qc, b_sc, b_ca) + tex.nvfp4_per_token_swizzle_rowwise_sf(a_qr, a_sr.reshape(-1), a_sr_swz) + tex.nvfp4_per_token_swizzle_rowwise_sf(b_qr, b_sr.reshape(-1), b_sr_swz) + + else: + + def _pt_qs(): + _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca) + tex.nvfp4_per_token_swizzle_rowwise_sf(a_qr, a_sr.reshape(-1), a_sr_swz) + + # Per-tensor baseline path: NVFP4Quantizer (RHT+SR), reuse internal storage. 
+ quantizer = _make_baseline_quantizer() + dst_a = quantizer.make_empty(a.shape, dtype=torch.bfloat16, device=device) + pten_a_sr_swz = torch.empty(dst_a._rowwise_scale_inv.numel(), dtype=torch.uint8, device=device) + if pair: + dst_b = quantizer.make_empty(b.shape, dtype=torch.bfloat16, device=device) + pten_b_sr_swz = torch.empty( + dst_b._rowwise_scale_inv.numel(), dtype=torch.uint8, device=device + ) + + if pair: + + def _pten_qs(): + tex.quantize(a, quantizer, dst_a, None) + tex.quantize(b, quantizer, dst_b, None) + tex.nvfp4_per_token_swizzle_rowwise_sf( + dst_a._rowwise_data, dst_a._rowwise_scale_inv.reshape(-1), pten_a_sr_swz + ) + tex.nvfp4_per_token_swizzle_rowwise_sf( + dst_b._rowwise_data, dst_b._rowwise_scale_inv.reshape(-1), pten_b_sr_swz + ) + + else: + + def _pten_qs(): + tex.quantize(a, quantizer, dst_a, None) + tex.nvfp4_per_token_swizzle_rowwise_sf( + dst_a._rowwise_data, dst_a._rowwise_scale_inv.reshape(-1), pten_a_sr_swz + ) + + t_pt = cuda_time_ms(_pt_qs) + t_pten = cuda_time_ms(_pten_qs) + t_pt_g = cuda_graph_time_ms(_pt_qs) + t_pten_g = cuda_graph_time_ms(_pten_qs) + + t_pt_swz = float("nan") + t_pt_swz_g = float("nan") + if fuse: + # Fused-swizzle K2: writes rowwise SF directly in swizzled layout + # (same numel as compact, just byte-permuted). No external swizzle + # launch -- K1+K2 alone is the full pipeline. 
+ a_qr_f, a_sr_f, a_ra_f, a_qc_f, a_sc_f, a_ca_f = _alloc_pt(M, K) + if pair: + b_qr_f, b_sr_f, b_ra_f, b_qc_f, b_sc_f, b_ca_f = _alloc_pt(N, K) + + def _pt_quant_fused(t, qr, sr, ra_buf, qc, sc, ca_buf): + tex.nvfp4_per_token_quantize( + t, + qr, + sr, + ra_buf, + qc, + sc, + ca_buf, + True, + True, + with_rht=with_rht, + random_sign_mask_t=mask_t if with_rht else 0, + with_swizzle=True, # <-- fused: K2 emits swizzled rowwise SF + ) + + if pair: + + def _pt_qs_fused(): + _pt_quant_fused(a, a_qr_f, a_sr_f, a_ra_f, a_qc_f, a_sc_f, a_ca_f) + _pt_quant_fused(b, b_qr_f, b_sr_f, b_ra_f, b_qc_f, b_sc_f, b_ca_f) + + else: + + def _pt_qs_fused(): + _pt_quant_fused(a, a_qr_f, a_sr_f, a_ra_f, a_qc_f, a_sc_f, a_ca_f) + + t_pt_swz = cuda_time_ms(_pt_qs_fused) + t_pt_swz_g = cuda_graph_time_ms(_pt_qs_fused) + + return QSShapeBench( + M=M, + K=K, + t_pt=t_pt, + t_pten=t_pten, + t_pt_g=t_pt_g, + t_pten_g=t_pten_g, + t_pt_swz=t_pt_swz, + t_pt_swz_g=t_pt_swz_g, + ) + + +def _print_qs_table(records: List[QSShapeBench], *, fuse: bool) -> None: + """K1+K2 + rowwise swizzle (no GEMM). 2-way default, 3-way w/ --fuse. 
+ Ratio = per-token(fuse if --fuse else plain) / per-tensor.""" + + def _fmt(r: float) -> str: + return "nan" if math.isnan(r) else f"{r:.2f}x" + + if not fuse: + w_pt, w_pten, w_ratio = 14, 15, 8 + block_w = w_pt + 1 + w_pten + 1 + w_ratio + header1 = ( + f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + ) + header2 = ( + f"{'M':>7} {'K':>6}" + " |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + " |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + print(header1) + print(header2) + print("-" * len(header2)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt, rec.t_pten) + ratio_g = _ratio(rec.t_pt_g, rec.t_pten_g) + print( + f"{rec.M:>7} {rec.K:>6}" + " |" + f"{rec.t_pt:>{w_pt}.4f} {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" + " |" + f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" + ) + return + + # 3-way with fuse column + w_pt, w_swz, w_pten, w_ratio = 12, 14, 13, 8 + block_w = w_pt + 1 + w_swz + 1 + w_pten + 1 + w_ratio + header1 = f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + header2 = ( + f"{'M':>7} {'K':>6}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_swz}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_swz}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + header3 = ( + f"{'':>7} {'':>6}" + " |" + f"{'':>{w_pt}} {'(fuse)':>{w_swz}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + " |" + f"{'':>{w_pt}} {'(fuse)':>{w_swz}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + ) + print(header1) + print(header2) + print(header3) + print("-" * len(header2)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + # 3-way ratio uses the fused-swizzle column vs per-tensor. 
+ ratio = _ratio(rec.t_pt_swz, rec.t_pten) + ratio_g = _ratio(rec.t_pt_swz_g, rec.t_pten_g) + print( + f"{rec.M:>7} {rec.K:>6}" + " |" + f"{rec.t_pt:>{w_pt}.4f} {rec.t_pt_swz:>{w_swz}.4f}" + f" {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" + " |" + f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pt_swz_g:>{w_swz}.4f}" + f" {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" + ) + + +def _print_qs_legend(*, with_rht: bool, rht_mask: int, pair: bool, fuse: bool) -> None: + print() + n_tensors = 2 if pair else 1 + n_launches_ext = 3 * n_tensors # K1+K2+swz per tensor + n_launches_fused = 2 * n_tensors # K1+K2 only per tensor (swizzle folded into K2) + mode_tag = "--pair, 2 operands" if pair else "default solo, 1 operand" + n_kernels_tag = f"ext-swz pipeline {n_launches_ext} launches" + ( + f" / fused pipeline {n_launches_fused} launches" if fuse else "" + ) + print(f"Legend (K1+K2 + rowwise swizzle; NO GEMM; mode = {mode_tag}; {n_kernels_tag}):") + rht_suffix = ( + f"with_rht=True + random_sign_mask_t=0x{rht_mask:04X}" if with_rht else "with_rht=False" + ) + print( + f" per-token (ms) = {n_tensors} x nvfp4_per_token_quantize({rht_suffix})" + " # K1+K2 each" + ) + print( + f" + {n_tensors} x nvfp4_per_token_swizzle_rowwise_sf" + " # 1 swz each" + ) + print(" K1 = nvfp4_per_token_amax (per-row/per-col vec amax)") + print(" K2 = nvfp4_per_token_encode (cast + e4m3 SF + optional RHT)") + if fuse: + print( + f" per-token (fuse) (ms) = {n_tensors} x nvfp4_per_token_quantize(..., " + "with_swizzle=True)" + ) + print(" # K1+K2 each; K2 directly emits the swizzled rowwise") + print(" # SF in cuBLAS LT layout (no separate swizzle launch).") + print( + f" per-tensor (ms) = {n_tensors} x tex.quantize(NVFP4Quantizer(rht+sr))" + " # K1+K2 each" + ) + print( + f" + {n_tensors} x nvfp4_per_token_swizzle_rowwise_sf" + " # 1 swz each" + ) + print(" K1 = nvte_hadamard_transform_amax (post-RHT scalar amax)") + print(" K2 = nvte_quantize_with_hadamard_transform") + print(" (RHT + SR + cast 
fusion, rowwise + columnwise)") + if fuse: + print(" The (fuse) column saves 1 swizzle launch/operand vs the non-fuse column;") + print(" the K2 byte-output is identical (verified by pytest byte-equality test).") + if not pair: + print(" solo mode is apples-to-apples with --composite (also 1 operand): the delta") + print(" per-token(--qs) - per-token(--composite) ~= one nvte_swizzle launch.") + else: + print(" --pair mode = one prod NVFP4 GEMM call's quant+swizzle pipeline (1 swz/operand).") + if fuse: + print(" ratio = per-token(fuse) / per-tensor") + else: + print(" ratio = per-token / per-tensor") + print(" ** < 1.0 = this PR wins vs prod K1+K2+swizzle path **") + print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided).") + + +def _print_e2e_swizzle_table(records: List[E2EShapeBench]) -> None: + """3-way end-to-end (--swizzle). ratio = per-token (+swizzle) / per-tensor.""" + w_pt, w_swz, w_pten, w_ratio = 12, 14, 13, 8 + block_w = w_pt + 1 + w_swz + 1 + w_pten + 1 + w_ratio + header1 = f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + header2 = ( + f"{'M':>7} {'K':>6}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_swz}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_swz}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + header3 = ( + f"{'':>7} {'':>6}" + " |" + f"{'':>{w_pt}} {'(+swizzle)':>{w_swz}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + " |" + f"{'':>{w_pt}} {'(+swizzle)':>{w_swz}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + ) + print(header1) + print(header2) + print(header3) + print("-" * len(header2)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt_swz, rec.t_pten) + ratio_g = _ratio(rec.t_pt_swz_g, rec.t_pten_g) + + def _fmt(r: float) -> str: + return "nan" if math.isnan(r) else f"{r:.2f}x" + + print( + f"{rec.M:>7} {rec.K:>6}" + " |" + 
f"{rec.t_pt:>{w_pt}.4f} {rec.t_pt_swz:>{w_swz}.4f}" + f" {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" + " |" + f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pt_swz_g:>{w_swz}.4f}" + f" {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" + ) + + +def _print_e2e_swizzle_legend(*, with_rht: bool, rht_mask: int) -> None: + print() + print("Legend (end-to-end quant + cuBLAS LT NVFP4 GEMM, square N=M):") + rht_suffix = ( + f"with_rht=True + random_sign_mask_t=0x{rht_mask:04X}" if with_rht else "with_rht=False" + ) + print(f" per-token (ms) = nvfp4_per_token_quantize({rht_suffix}) +") + print(" nvfp4_per_token_gemm(sf_swizzled=False)") + print(" -> K1 + K2 + 2 swizzle launches + cuBLAS LT GEMM") + print(" + per-token post-scale.") + print(f" per-token (+swizzle) (ms) = nvfp4_per_token_quantize({rht_suffix},") + print(" with_swizzle=True) +") + print(" nvfp4_per_token_gemm(sf_swizzled=True)") + print(" -> K1 + K2 (fused swizzle) + cuBLAS LT GEMM") + print(" + per-token post-scale. (2 launches saved.)") + print(" per-tensor (ms) = tex.quantize(a, NVFP4Quantizer(rht+sr)) +") + print(" nvfp4_per_tensor_gemm (cuBLAS LT NVFP4)") + print(" -> fused RHT+quant + 2 swizzle launches + GEMM.") + print(" ratio = per-token (+swizzle) / per-tensor") + print(" ** < 1.0 = this PR wins vs prod baseline **") + print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided).") + + +def _print_gemm_only_table(records: List[GemmOnlyShapeBench]) -> None: + """GEMM-only (--gemm-only) timings: + pten_gemm = cuBLAS LT per-tensor NVFP4 GEMM (current PROD baseline). + ct_fused = forked CUTLASS per-token NVFP4 GEMM with per-row * per-col + rescale fused into the EVT epilogue (1 launch). 
+ + Ratio: + cf/pten = ct_fused / pten_gemm + ** < 1.0 = per-token fused CUTLASS matches/beats prod per-tensor ** + """ + w_pten, w_clf, w_ratio = 11, 11, 8 + block_w = w_pten + 1 + w_clf + 1 + w_ratio + header1 = f"{'':>7} {'':>6} {'':>6} |{'Eager':^{block_w}} |{'Graph':^{block_w}}" + body = f"{'pten_gemm':>{w_pten}} {'ct_fused':>{w_clf}} {'cf/pten':>{w_ratio}}" + header2 = f"{'M':>7} {'K':>6} {'N':>6} |{body}|{body}" + print(header1) + print(header2) + print("-" * len(header2)) + + def _fmt(r: float) -> str: + return "nan" if math.isnan(r) else f"{r:.2f}x" + + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + r_cf = _ratio(rec.t_clf, rec.t_pten) + r_cf_g = _ratio(rec.t_clf_g, rec.t_pten_g) + print( + f"{rec.M:>7} {rec.K:>6} {rec.N:>6}" + " |" + f"{rec.t_pten:>{w_pten}.4f} {rec.t_clf:>{w_clf}.4f}" + f" {_fmt(r_cf):>{w_ratio}}" + "|" + f"{rec.t_pten_g:>{w_pten}.4f} {rec.t_clf_g:>{w_clf}.4f}" + f" {_fmt(r_cf_g):>{w_ratio}}" + ) + + +def _print_gemm_only_legend() -> None: + print() + print("Legend (GEMM-only; inputs pre-quantized + pre-swizzled, N = K):") + print(" pten_gemm (ms) = nvfp4_per_tensor_gemm(sf_swizzled=True)") + print(" -> cuBLAS LT NVFP4 + alpha-fold (current PROD per-tensor GEMM).") + print(" ct_fused (ms) = nvfp4_cutlass_per_token_gemm(sf_swizzled=True)") + print(" -> forked CUTLASS NVFP4 GEMM with per-row * per-col rescale") + print(" FUSED into the EVT epilogue (1 launch, no post-scale).") + print(" D = bf16(alpha_a[i] * alpha_b[j] * (A @ B^T)[i, j]).") + print(" cf/pten = ct_fused / pten_gemm") + print(" ** < 1.0 = per-token fused CUTLASS matches/beats prod per-tensor **") + print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided).") + + +def _bench_shape_k1_only( + M: int, K: int, *, device: torch.device, with_rht: bool = False, mask_t: int = _RHT_MASK_DEFAULT +) -> K1ShapeBench: + """K1-only. 
pt = per-token (no RHT); pt_rht = +col RHT (NaN unless --rht); + prod = hadamard_transform_amax (scalar amax; not apples-to-apples).""" + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + # Per-token K1 amax buffers (vectors). + ra_pt = torch.empty((M,), dtype=torch.float32, device=device) + ca_pt = torch.empty((K,), dtype=torch.float32, device=device) + + # prod K1 amax buffers (scalars). + ra_prod = torch.empty((1,), dtype=torch.float32, device=device) + ca_prod = torch.empty((1,), dtype=torch.float32, device=device) + + def _pt_k1_fn(): + tex.nvfp4_per_token_amax( + a, + ra_pt, + ca_pt, + True, + True, + with_rht=False, + random_sign_mask_t=0, + ) + + def _prod_k1_fn(): + # row pre-RHT + col post-RHT scalar amax; both numel=1 buffers. + tex.hadamard_transform_amax(a, ra_prod, ca_prod, mask_t) + + t_pt = cuda_time_ms(_pt_k1_fn) + t_prod = cuda_time_ms(_prod_k1_fn) + t_pt_g = cuda_graph_time_ms(_pt_k1_fn) + t_prod_g = cuda_graph_time_ms(_prod_k1_fn) + + if with_rht: + ra_pt_rht = torch.empty((M,), dtype=torch.float32, device=device) + ca_pt_rht = torch.empty((K,), dtype=torch.float32, device=device) + + def _pt_k1_rht_fn(): + tex.nvfp4_per_token_amax( + a, + ra_pt_rht, + ca_pt_rht, + True, + True, + with_rht=True, + random_sign_mask_t=mask_t, + ) + + t_pt_rht = cuda_time_ms(_pt_k1_rht_fn) + t_pt_rht_g = cuda_graph_time_ms(_pt_k1_rht_fn) + else: + t_pt_rht = float("nan") + t_pt_rht_g = float("nan") + + return K1ShapeBench( + M=M, + K=K, + t_pt=t_pt, + t_pt_rht=t_pt_rht, + t_prod=t_prod, + t_pt_g=t_pt_g, + t_pt_rht_g=t_pt_rht_g, + t_prod_g=t_prod_g, + ) + + +# 6x3 sweep matching bench_nvfp4_per_token_group.py: M in {1024..32768}, K in {2048,4096,8192}. +_M_VALUES: Tuple[int, ...] = (1024, 2048, 4096, 8192, 16384, 32768) +_K_VALUES: Tuple[int, ...] = (2048, 4096, 8192) +_DEFAULT_SHAPES: Tuple[Tuple[int, int], ...] 
= tuple((m, k) for m in _M_VALUES for k in _K_VALUES) + + +def _parse_shape(s: str) -> Tuple[int, int]: + parts = s.split("x") + if len(parts) != 2: + raise argparse.ArgumentTypeError(f"Shape must be MxK, got '{s}'") + return tuple(int(p) for p in parts) # type: ignore[return-value] + + +def _ratio(num: float, den: float) -> float: + if den <= 0 or math.isnan(num) or math.isnan(den): + return float("nan") + return num / den + + +def _print_composite_table_2way(records: List[ShapeBench]) -> None: + """2-way composite (no RHT). ratio = per-token / per-tensor (< 1.0 wins).""" + w_pt, w_pten, w_ratio = 14, 15, 8 + block_w = w_pt + 1 + w_pten + 1 + w_ratio + header1 = f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + header2 = ( + f"{'M':>7} {'K':>6}" + " |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + " |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + print(header1) + print(header2) + print("-" * len(header2)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt, rec.t_pten) + ratio_g = _ratio(rec.t_pt_g, rec.t_pten_g) + + def _fmt(r: float) -> str: + return "nan" if math.isnan(r) else f"{r:.2f}x" + + print( + f"{rec.M:>7} {rec.K:>6}" + " |" + f"{rec.t_pt:>{w_pt}.4f} {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" + " |" + f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" + ) + + +def _print_composite_table(records: List[ShapeBench]) -> None: + """3-way composite (--rht). 
ratio = per-token (+rht) / per-tensor.""" + w_pt, w_pt_rht, w_pten, w_ratio = 12, 12, 13, 8 + block_w = w_pt + 1 + w_pt_rht + 1 + w_pten + 1 + w_ratio + header1 = f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + header2 = ( + f"{'M':>7} {'K':>6}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + header3 = ( + f"{'':>7} {'':>6}" + " |" + f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + " |" + f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + ) + print(header1) + print(header2) + print(header3) + print("-" * len(header2)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt_rht, rec.t_pten) + ratio_g = _ratio(rec.t_pt_rht_g, rec.t_pten_g) + + def _fmt(r: float) -> str: + return "nan" if math.isnan(r) else f"{r:.2f}x" + + print( + f"{rec.M:>7} {rec.K:>6}" + " |" + f"{rec.t_pt:>{w_pt}.4f} {rec.t_pt_rht:>{w_pt_rht}.4f}" + f" {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" + " |" + f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pt_rht_g:>{w_pt_rht}.4f}" + f" {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" + ) + + +def _print_k1_2way_table(records: List[K1ShapeBench]) -> None: + """2-way K1 (default --k1-only). 
def _print_k1_rht_cost_table(records: List[K1ShapeBench]) -> None:
    """Print Table A: per-token K1 without RHT vs. per-token K1 with RHT.

    Each row shows the eager timings and the CUDA Graphs timings, each with a
    ``pt_K1+RHT / pt_K1`` ratio. A blank line is printed whenever ``M``
    changes between consecutive records.
    """

    def _ratio_cell(num: float, den: float) -> str:
        # Render a ratio as "1.23x", or "nan" when it is undefined.
        value = _ratio(num, den)
        return "nan" if math.isnan(value) else f"{value:.2f}x"

    print("Table A -- K1-only RHT cost (pt = per-token, +RHT = col-wise FHT).")
    header = "".join(
        (
            f"{'M':>7} {'K':>6}",
            " |",
            f"{'pt_K1':>9} {'pt_K1+RHT':>11} {'ratio':>8}",
            " |",
            f"{'pt_K1(Graph)':>14} {'pt_K1+RHT(Graph)':>18} {'ratio(Graph)':>13}",
        )
    )
    print(header)
    print("-" * len(header))

    last_m = None
    for rec in records:
        if last_m is not None and last_m != rec.M:
            print()
        last_m = rec.M

        eager_cells = (
            f"{rec.t_pt:>9.4f} {rec.t_pt_rht:>11.4f}"
            f" {_ratio_cell(rec.t_pt_rht, rec.t_pt):>8}"
        )
        graph_cells = (
            f"{rec.t_pt_g:>14.4f} {rec.t_pt_rht_g:>18.4f}"
            f" {_ratio_cell(rec.t_pt_rht_g, rec.t_pt_g):>13}"
        )
        print(f"{rec.M:>7} {rec.K:>6} |{eager_cells} |{graph_cells}")
Fast-floor reference only.""" + print("Table B -- K1-only vs prod (NOT apples-to-apples; output shapes differ).") + header = ( + f"{'M':>7} {'K':>6}" + " |" + f"{'pt_K1+RHT':>11} {'prod_K1':>9} {'ratio':>8}" + " |" + f"{'pt_K1+RHT(Graph)':>18} {'prod_K1(Graph)':>16} {'ratio(Graph)':>13}" + ) + print(header) + print("-" * len(header)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt_rht, rec.t_prod) + ratio_g = _ratio(rec.t_pt_rht_g, rec.t_prod_g) + ratio_s = "nan" if math.isnan(ratio) else f"{ratio:.2f}x" + ratio_g_s = "nan" if math.isnan(ratio_g) else f"{ratio_g:.2f}x" + print( + f"{rec.M:>7} {rec.K:>6}" + " |" + f"{rec.t_pt_rht:>11.4f} {rec.t_prod:>9.4f} {ratio_s:>8}" + " |" + f"{rec.t_pt_rht_g:>18.4f} {rec.t_prod_g:>16.4f} {ratio_g_s:>13}" + ) + + +def _print_composite_legend(*, with_rht: bool, rht_mask: int) -> None: + """Prose legend mapping table labels to their C++ entry points.""" + print() + print("Legend:") + if with_rht: + print(" per-token (ms) = tex.nvfp4_per_token_quantize(a, ..., rowwise+colwise,") + print(" with_rht=False)") + print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") + print( + " per-token (+rht) (ms) = same, but with_rht=True +" + f" random_sign_mask_t=0x{rht_mask:04X}." + ) + print(" Applies a 16-point RHT along the columnwise direction in") + print(" BOTH K1 amax and K2 cast; rowwise stays raw. 
Length-16") + print(" matches the 1x16 inner-SF block of NVFP4, so each scale") + print(" window is decorrelated.") + print(" per-tensor (ms) = tex.quantize(a, NVFP4Quantizer(rht+sr), ...)") + print(" = nvte_quantize_with_hadamard_transform") + print( + " (1 fused launch: rowwise quant + col-wise RHT + col quant," + ) + print(" prod baseline).") + print(" ratio = per-token (+rht) / per-tensor") + print(" ** < 1.0 = this PR wins vs prod baseline **") + else: + print( + " per-token (ms) = tex.nvfp4_per_token_quantize(a, ..., rowwise+colwise," + " with_rht=False)" + ) + print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") + print(" per-tensor (ms) = tex.quantize(a, NVFP4Quantizer(rht+sr), ...)") + print(" = nvte_quantize_with_hadamard_transform") + print(" (1 fused launch: rowwise quant + col-wise RHT + col quant,") + print(" prod baseline).") + print( + " ratio = per-token / per-tensor ** < 1.0 = per-token wins vs prod" + " baseline **" + ) + print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided).") + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Benchmark NVFP4 per-token K1+K2 quant vs per-tensor production NVFP4." + ) + parser.add_argument( + "--shapes", + type=_parse_shape, + nargs="+", + default=None, + help=( + "Shapes to bench, in MxK form (e.g. 4096x4096). " + "Default: an internally-chosen production-shape sweep." + ), + ) + parser.add_argument( + "--rht", + action="store_true", + help=( + "Also time the per-token + RHT path (col-wise 16-pt RHT in K1 + K2). " + "Default OFF: prints a 2-way table (per-token vs per-tensor). With " + "--rht: prints a 3-way table with one ratio " + "(per-token (+rht) / per-tensor)." + ), + ) + parser.add_argument( + "--k1-only", + action="store_true", + help=( + "K1-only mode (no K2 cast). Without --rht: 2-way table (pt_K1 " + "vs prod_K1). 
With --rht: two tables back-to-back -- (A) RHT cost " + "pt_K1 vs pt_K1+RHT (apples-to-apples) and (B) pt_K1+RHT vs prod_K1 " + "(context only; output shapes differ)." + ), + ) + parser.add_argument( + "--swizzle", + action="store_true", + help=( + "End-to-end mode: quant + cuBLAS LT NVFP4 GEMM (square N=M). " + "Prints a 3-way table: per-token (external swizzle) vs per-token " + "(fused swizzle in K2, sf_swizzled=True) vs per-tensor. Ratio = " + "per-token (+swizzle) / per-tensor. --rht composes (adds 16-pt " + "col-wise RHT to the per-token paths)." + ), + ) + parser.add_argument( + "--qs", + action="store_true", + help=( + "K1+K2 + standalone rowwise swizzle. NO GEMM. 2-way table: " + "per-token vs per-tensor. Default solo (1 operand, 3 launches) is " + "apples-to-apples with --composite; add --pair for 2-operand " + "(6 launches, matches prod NVFP4 GEMM's per-call pipeline). " + "--rht composes." + ), + ) + parser.add_argument( + "--gemm-only", + action="store_true", + help=( + "GEMM-only mode (square N=M): inputs are pre-quantized + pre-swizzled " + "outside the timed window, so only the cuBLAS LT NVFP4 GEMM call is " + "timed. 2-way table: pt_gemm (per-token GEMM + per-row post-scale) " + "vs pten_gemm (per-tensor GEMM, alpha-folded). ratio = pt / pten " + "exposes the per-call cost of the per-token post-scale kernel. " + "--rht composes (RHT applied only to the per-token quant setup)." + ), + ) + parser.add_argument( + "--pair", + action="store_true", + help=( + "Modifier for --qs: bench the 2-operand (A + B) pipeline, matching " + "what prod NVFP4 GEMM does per call (1 K1+K2 + 1 swizzle per " + "operand). Default (no --pair) is solo (1 operand)." + ), + ) + parser.add_argument( + "--fuse", + action="store_true", + help=( + "Modifier for --qs: also bench per-token with fused-swizzle K2 " + "(K2 directly emits the rowwise SF in cuBLAS LT swizzled layout; " + "no separate swizzle launch). 
Adds a 'per-token(fuse)' column to " + "the table, and the ratio switches to per-token(fuse) / per-tensor." + ), + ) + parser.add_argument( + "--rht-mask", + type=lambda s: int(s, 0), + default=_RHT_MASK_DEFAULT, + help=( + "16-bit random sign mask for the RHT path (only matters with --rht). " + f"Default 0x{_RHT_MASK_DEFAULT:04X}; accepts hex (0x...) or decimal." + ), + ) + args = parser.parse_args() + + if not _has_sm100(): + print("SKIP: NVFP4 per-token quant requires SM100 (Blackwell).", file=sys.stderr) + return 1 + + device = torch.device("cuda") + shapes = list(args.shapes) if args.shapes else list(_DEFAULT_SHAPES) + mask = args.rht_mask & 0xFFFF + + # --pair / --fuse are modifiers for --qs; auto-imply --qs if either is set + # alone, so we don't silently fall through to --composite default and bake + # a confusing "looks-like-the-modifier-worked-but-didnt" table. + if (args.pair or args.fuse) and not args.qs: + modifiers = [] + if args.pair: + modifiers.append("--pair") + if args.fuse: + modifiers.append("--fuse") + print( + f"INFO: {' / '.join(modifiers)} implies --qs; running --qs " + f"{' '.join(modifiers)} (K1+K2 + swizzle, no GEMM).", + file=sys.stderr, + ) + args.qs = True + + exclusive = sum(int(x) for x in (args.k1_only, args.swizzle, args.qs, args.gemm_only)) + if exclusive > 1: + print( + "ERROR: --k1-only, --swizzle, --qs, and --gemm-only are mutually exclusive.", + file=sys.stderr, + ) + return 2 + + if args.k1_only: + records_k1: List[K1ShapeBench] = [ + _bench_shape_k1_only(M, K, device=device, with_rht=args.rht, mask_t=mask) + for (M, K) in shapes + ] + if args.rht: + _print_k1_rht_cost_table(records_k1) + print() + _print_k1_vs_prod_table(records_k1) + else: + _print_k1_2way_table(records_k1) + elif args.swizzle: + records_e2e: List[E2EShapeBench] = [ + _bench_shape_e2e_swizzle(M, K, device=device, with_rht=args.rht, mask_t=mask) + for (M, K) in shapes + ] + _print_e2e_swizzle_table(records_e2e) + _print_e2e_swizzle_legend(with_rht=args.rht, 
rht_mask=mask) + elif args.qs: + records_qs: List[QSShapeBench] = [ + _bench_shape_qs( + M, + K, + device=device, + with_rht=args.rht, + mask_t=mask, + pair=args.pair, + fuse=args.fuse, + ) + for (M, K) in shapes + ] + _print_qs_table(records_qs, fuse=args.fuse) + _print_qs_legend(with_rht=args.rht, rht_mask=mask, pair=args.pair, fuse=args.fuse) + elif args.gemm_only: + records_go: List[GemmOnlyShapeBench] = [ + _bench_shape_gemm_only(M, K, device=device, with_rht=args.rht, mask_t=mask) + for (M, K) in shapes + ] + _print_gemm_only_table(records_go) + _print_gemm_only_legend() + else: + records: List[ShapeBench] = [ + _bench_shape(M, K, device=device, with_rht=args.rht, mask_t=mask) for (M, K) in shapes + ] + if args.rht: + _print_composite_table(records) + else: + _print_composite_table_2way(records) + _print_composite_legend(with_rht=args.rht, rht_mask=mask) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py new file mode 100644 index 0000000000..d6f3a50da5 --- /dev/null +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py @@ -0,0 +1,460 @@ +"""Bench NVFP4 per-token grouped K1+K2 quant vs per-tensor RHT+SR baseline. + +Modes: + * default: 2-way (per-token vs per-tensor). Ratio = pt / pten. + * ``--rht``: 3-way (adds per-token + col-wise 16-pt RHT). Ratio = + per-token (+rht) / per-tensor. + +Default sweep: N=8 equal splits, sum_M in {1024..32768} x K in {2048,4096,8192}. +Requires bf16, K % 128 == 0, every split % 128 == 0, num_splits <= 64. + +CLI: + --shapes SUMMxK ... custom shapes (default: 18-row sweep) + --num-splits N equal splits per shape (default 8) + --rht enable 3-way RHT comparison + --rht-mask 0x... 
def cuda_graph_time_ms(fn: Callable[[], object], *, warmup: int = 5, iters: int = 50) -> float:
    """Return the median ``g.replay()`` time of ``fn`` under CUDA Graphs, in ms.

    ``fn`` is warmed up ``warmup`` times on a side stream, captured once into
    a ``torch.cuda.CUDAGraph``, then replayed ``iters`` times with each replay
    bracketed by a pair of CUDA events.

    Any exception during warmup or capture is reported on stderr and turned
    into a ``nan`` result, so one uncapturable kernel does not abort the
    whole benchmark.
    """
    try:
        warm_stream = torch.cuda.Stream()
        warm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(warm_stream):
            for _ in range(warmup):
                fn()
        torch.cuda.current_stream().wait_stream(warm_stream)
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            fn()
    except Exception as e:
        # Deliberate best-effort: report the failure and skip this measurement.
        print(f" [graph capture skipped: {type(e).__name__}: {e}]", file=sys.stderr)
        return float("nan")

    begin_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    for begin, end in zip(begin_events, end_events):
        begin.record()
        graph.replay()
        end.record()
    # Elapsed times are only valid once every recorded event has completed.
    torch.cuda.synchronize()

    samples = [begin.elapsed_time(end) for begin, end in zip(begin_events, end_events)]
    return statistics.median(samples)
def _time_split_quantize(x_concat, split_sections, quantizer_list, n_iters=20, n_warmup=5):
    """Time the per-tensor grouped baseline, ``tex.split_quantize``, eagerly.

    Runs ``n_warmup`` untimed calls, then brackets ``n_iters`` back-to-back
    calls with a single pair of CUDA events. Output allocation happens inside
    the binding, so it is part of the measured time.

    Returns the mean time per call in milliseconds.
    """
    for _ in range(n_warmup):
        result = tex.split_quantize(x_concat, split_sections, quantizer_list)
    torch.cuda.synchronize()

    begin = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    begin.record()
    for _ in range(n_iters):
        # Keep the result bound so each output lives until the next call,
        # matching the original memory-reuse pattern.
        result = tex.split_quantize(x_concat, split_sections, quantizer_list)
    end.record()
    torch.cuda.synchronize()

    total_ms = begin.elapsed_time(end)
    return total_ms / n_iters  # ms
bool = False, + mask: int = _RHT_MASK_DEFAULT, + n_iters: int = 20, + n_warmup: int = 5, +) -> float: + """Per-token grouped under CUDA Graphs replay.""" + + def fn() -> None: + _ = nvfp4_per_token_group_quantize( + x_concat, + split_sections, + rowwise=rowwise, + columnwise=columnwise, + with_rht=with_rht, + random_sign_mask_t=mask, + ) + + return cuda_graph_time_ms(fn, warmup=n_warmup, iters=n_iters) + + +# Default sweep: N = 8 equal splits (MoE-typical), sum_M in {1024..32768}, +# K in {2048..8192}. Override either via the CLI flags below. +_DEFAULT_NUM_SPLITS: int = 8 +_DEFAULT_SUM_M_VALUES: Tuple[int, ...] = (1024, 2048, 4096, 8192, 16384, 32768) +_DEFAULT_K_VALUES: Tuple[int, ...] = (2048, 4096, 8192) + + +def _parse_shape(s: str) -> Tuple[int, int]: + """Parse a `sum_MxK` CLI argument.""" + parts = s.split("x") + if len(parts) != 2: + raise argparse.ArgumentTypeError(f"Shape must be sum_MxK, got '{s}'") + return tuple(int(p) for p in parts) # type: ignore[return-value] + + +def _build_bench_cases( + shapes: List[Tuple[int, int]], num_splits: int +) -> List[Tuple[List[int], int]]: + """Turn (sum_M, K) pairs into (split_sections, K) cases; each split + must be a multiple of 128. + """ + cases: List[Tuple[List[int], int]] = [] + for sum_M, K in shapes: + if sum_M % num_splits != 0: + raise argparse.ArgumentTypeError( + f"sum_M={sum_M} not divisible by num_splits={num_splits}" + ) + M_i = sum_M // num_splits + if M_i % 128 != 0: + raise argparse.ArgumentTypeError( + f"sum_M={sum_M} / num_splits={num_splits} = M_i={M_i} must be a " + "multiple of 128 (NVFP4 per-token kernel constraint)" + ) + if K % 128 != 0: + raise argparse.ArgumentTypeError(f"K={K} must be a multiple of 128") + cases.append(([M_i] * num_splits, K)) + return cases + + +def main() -> int: + parser = argparse.ArgumentParser( + description=( + "Bench NVFP4 per-token grouped K1+K2 quant. Three-way: " + "per-token (no RHT) / per-token+RHT / per-tensor (RHT+SR)." 
+ ) + ) + parser.add_argument( + "--shapes", + type=_parse_shape, + nargs="+", + default=None, + help=( + "Shapes to bench, in sum_MxK form (e.g. 8192x4096). " + "Default: a 6x3 = 18-row internally-chosen sweep." + ), + ) + parser.add_argument( + "--num-splits", + type=int, + default=_DEFAULT_NUM_SPLITS, + help=( + f"Number of equal splits per shape (default {_DEFAULT_NUM_SPLITS}; " + "<= 64). M_i = sum_M / num_splits must be a multiple of 128." + ), + ) + parser.add_argument( + "--rht", + action="store_true", + help=( + "Enable 3-way table with per-token + col-wise 16-pt RHT path. " + "Default OFF prints 2-way (per-token vs per-tensor)." + ), + ) + parser.add_argument( + "--rht-mask", + type=lambda s: int(s, 0), + default=_RHT_MASK_DEFAULT, + help=( + f"16-bit RHT sign mask (default 0x{_RHT_MASK_DEFAULT:04X}; accepts " + "hex/dec). Only affects per-token+RHT; per-tensor uses its own mask." + ), + ) + args = parser.parse_args() + + if not torch.cuda.is_available(): + print("CUDA unavailable, skipping bench.") + return 1 + cap = torch.cuda.get_device_capability() + if cap[0] < 10: + print(f"NVFP4 per-token requires SM100+ (got SM{cap[0]}.{cap[1]}); skipping.") + return 1 + if args.num_splits <= 0 or args.num_splits > 64: + print(f"--num-splits must be in [1, 64], got {args.num_splits}") + return 2 + + if args.shapes is not None: + shapes_in = [tuple(s) for s in args.shapes] + else: + shapes_in = [(sm, k) for sm in _DEFAULT_SUM_M_VALUES for k in _DEFAULT_K_VALUES] + bench_cases = _build_bench_cases(shapes_in, args.num_splits) + rht_mask: int = args.rht_mask & 0xFFFF + with_rht: bool = args.rht + + device = torch.device("cuda") + print(f"# Device: {torch.cuda.get_device_name(0)} (cap {cap[0]}.{cap[1]})") + print(f"# Split structure: N={args.num_splits} equal splits, M_i = sum_M / {args.num_splits}") + if with_rht: + print( + f"# RHT mask: 0x{rht_mask:04X} (per-token+RHT col-wise; per-tensor uses its own" + " internal mask)" + ) + else: + print( + "# RHT: disabled 
(pass --rht to enable 3-way per-token / per-token (+rht) / per-tensor" + " table)" + ) + print() + + # Per-tensor baseline quantizer is fixed to row+col, so both enabled. + rowwise = True + columnwise = True + + def _fmt(r: float) -> str: + return "nan" if math.isnan(r) else f"{r:.2f}x" + + def _ratio(num: float, den: float) -> float: + if den <= 0 or math.isnan(num) or math.isnan(den): + return float("nan") + return num / den + + # Multi-line header: section label + column names (+ `(+rht)` sub-label + # row in 3-way mode), then separator + data rows. + if with_rht: + w_pt, w_pt_rht, w_pten, w_ratio = 12, 12, 13, 8 + block_w = w_pt + 1 + w_pt_rht + 1 + w_pten + 1 + w_ratio + header1 = ( + f"{'':>6} {'':>5} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + ) + header2 = ( + f"{'sum_M':>6} {'K':>5}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + header3 = ( + f"{'':>6} {'':>5}" + " |" + f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + " |" + f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + ) + print(header1) + print(header2) + print(header3) + else: + w_pt, w_pten, w_ratio = 14, 15, 8 + block_w = w_pt + 1 + w_pten + 1 + w_ratio + header1 = ( + f"{'':>6} {'':>5} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + ) + header2 = ( + f"{'sum_M':>6} {'K':>5}" + " |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + " |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + print(header1) + print(header2) + print("-" * len(header2)) + + prev_sum_M = None + for split_sections, K in bench_cases: + sum_M = sum(split_sections) + num_splits = len(split_sections) + + # Blank line between sum_M groups for readability. 
+ if prev_sum_M is not None and sum_M != prev_sum_M: + print() + prev_sum_M = sum_M + + x_concat = (torch.randn((sum_M, K), dtype=torch.bfloat16, device=device) * 3.0).contiguous() + quantizer_list = _make_baseline_quantizer_list(num_splits) + + t_pt = _time_grouped(x_concat, split_sections, rowwise, columnwise, with_rht=False) + t_pten = _time_split_quantize(x_concat, split_sections, quantizer_list) + t_pt_g = _time_grouped_graph( + x_concat, + split_sections, + rowwise, + columnwise, + with_rht=False, + ) + t_pten_g = _time_split_quantize_graph( + x_concat, + split_sections, + quantizer_list, + ) + + if with_rht: + t_pt_rht = _time_grouped( + x_concat, split_sections, rowwise, columnwise, with_rht=True, mask=rht_mask + ) + t_pt_rht_g = _time_grouped_graph( + x_concat, + split_sections, + rowwise, + columnwise, + with_rht=True, + mask=rht_mask, + ) + + ratio_eager = _ratio(t_pt_rht, t_pten) + ratio_graph = _ratio(t_pt_rht_g, t_pten_g) + + print( + f"{sum_M:>6d} {K:>5d}" + " |" + f"{t_pt:>{w_pt}.4f} {t_pt_rht:>{w_pt_rht}.4f}" + f" {t_pten:>{w_pten}.4f} {_fmt(ratio_eager):>{w_ratio}}" + " |" + f"{t_pt_g:>{w_pt}.4f} {t_pt_rht_g:>{w_pt_rht}.4f}" + f" {t_pten_g:>{w_pten}.4f} {_fmt(ratio_graph):>{w_ratio}}" + ) + else: + ratio_eager = _ratio(t_pt, t_pten) + ratio_graph = _ratio(t_pt_g, t_pten_g) + print( + f"{sum_M:>6d} {K:>5d}" + " |" + f"{t_pt:>{w_pt}.4f} {t_pten:>{w_pten}.4f} {_fmt(ratio_eager):>{w_ratio}}" + " |" + f"{t_pt_g:>{w_pt}.4f} {t_pten_g:>{w_pten}.4f} {_fmt(ratio_graph):>{w_ratio}}" + ) + + del x_concat, quantizer_list + torch.cuda.empty_cache() + + print() + print("Legend:") + if with_rht: + print(" per-token (ms) = nvfp4_per_token_group_quantize(x, splits,") + print(" rowwise+colwise, with_rht=False)") + print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") + print( + " per-token (+rht) (ms) = same, but with_rht=True +" + f" random_sign_mask_t=0x{rht_mask:04X}." 
+ ) + print(" Applies a 16-point RHT along the columnwise direction in") + print(" BOTH K1 amax and K2 cast; rowwise stays raw. Length-16") + print(" matches the 1x16 inner-SF block of NVFP4, so each scale") + print(" window is decorrelated.") + print( + " per-tensor (ms) = tex.split_quantize(x, splits, [NVFP4Quantizer(rht+sr)]*N)" + ) + print(" = nvte_group_hadamard_transform_amax") + print(" + nvte_group_hadamard_transform_cast_fusion") + print(" (2 launches, prod baseline).") + print(" ratio = per-token (+rht) / per-tensor") + print(" ** < 1.0 = this PR wins vs prod baseline **") + else: + print( + " per-token (ms) = nvfp4_per_token_group_quantize(x, splits, rowwise+colwise," + " with_rht=False)" + ) + print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") + print(" per-tensor (ms) = tex.split_quantize(x, splits, [NVFP4Quantizer(rht+sr)]*N)") + print(" = nvte_group_hadamard_transform_amax") + print( + " + nvte_group_hadamard_transform_cast_fusion (2 launches, prod" + " baseline)." + ) + print( + " ratio = per-token / per-tensor ** < 1.0 = per-token wins vs prod" + " baseline **" + ) + print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided).") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pytorch/nvfp4/test_nvfp4_cutlass_per_token_gemm.py b/tests/pytorch/nvfp4/test_nvfp4_cutlass_per_token_gemm.py new file mode 100644 index 0000000000..f1af0c99b7 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_cutlass_per_token_gemm.py @@ -0,0 +1,272 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Numerical tests for tex.nvfp4_cutlass_per_token_gemm (fused EVT) vs the +cuBLAS-LT per-token reference (GEMM + standalone post_scale). 
M, N, K multiples +of 256; rtol=2e-2 ~ 2.5x bf16 ULP for fp32 reduction-order noise.""" + +from __future__ import annotations + +from typing import Tuple + +import pytest +import torch + +# Must import transformer_engine first to dlopen libtransformer_engine.so. +import transformer_engine.pytorch as te # noqa: F401 +import transformer_engine_torch as tex # type: ignore + + +def _has_sm100() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 + + +_GATED_SM100 = pytest.mark.skipif( + not _has_sm100(), + reason="CUTLASS NVFP4 fused per-token GEMM requires SM100 (Blackwell).", +) + +_GATED_HAS_KERNEL = pytest.mark.skipif( + not hasattr(tex, "nvfp4_cutlass_per_token_gemm"), + reason="tex.nvfp4_cutlass_per_token_gemm not built into this binary.", +) + + +def _quantize_per_token( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Per-token quantize (rows, K) bf16 -> (q_FP4_packed (rows, K/2), + sf_FP8e4m3 (rows, K/16), row_amax fp32 (rows,)). 
+ """ + assert x.dim() == 2 and x.dtype == torch.bfloat16 + rows, K = x.shape + q_row = torch.empty((rows, K // 2), dtype=torch.uint8, device=x.device) + s_row = torch.empty((rows, K // 16), dtype=torch.uint8, device=x.device) + a_row = torch.empty((rows,), dtype=torch.float32, device=x.device) + q_col = torch.empty(0, dtype=torch.uint8, device=x.device) + s_col = torch.empty(0, dtype=torch.uint8, device=x.device) + a_col = torch.empty(0, dtype=torch.float32, device=x.device) + tex.nvfp4_per_token_quantize( + x, + q_row, + s_row, + a_row, + q_col, + s_col, + a_col, + rowwise=True, + columnwise=False, + with_rht=False, + random_sign_mask_t=int(0xACE1), + with_swizzle=False, + ) + return q_row, s_row, a_row + + +def _ref_pertoken_gemm_via_cublaslt( + a_q: torch.Tensor, + b_q: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha_a: torch.Tensor, + alpha_b: torch.Tensor, + M: int, + N: int, + K: int, +) -> torch.Tensor: + """Reference: tex.nvfp4_per_token_gemm = cuBLAS-LT NVFP4 GEMM + standalone + post-scale. Already correctness-tested in test_nvfp4_per_token.py. + """ + workspace = torch.empty(33_554_432, dtype=torch.uint8, device=a_q.device) + d = torch.empty((M, N), dtype=torch.bfloat16, device=a_q.device) + tex.nvfp4_per_token_gemm( + a_q, + b_q, + a_sf.reshape(-1), + b_sf.reshape(-1), + alpha_a, + alpha_b, + d, + workspace, + M, + N, + K, + 1.0, + 0.0, + a_sf_swizzled=False, + b_sf_swizzled=False, + ) + return d + + +def _run_fused( + a_q: torch.Tensor, + b_q: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha_a: torch.Tensor, + alpha_b: torch.Tensor, + M: int, + N: int, + K: int, +) -> torch.Tensor: + d = torch.empty((M, N), dtype=torch.bfloat16, device=a_q.device) + tex.nvfp4_cutlass_per_token_gemm( + a_q, + b_q, + a_sf.reshape(-1), + b_sf.reshape(-1), + alpha_a, + alpha_b, + d, + M, + N, + K, + a_sf_swizzled=False, + b_sf_swizzled=False, + ) + return d + + +# Shapes obey the kernel contract (M, N, K all multiples of 256). 
@_GATED_SM100
@_GATED_HAS_KERNEL
@pytest.mark.parametrize("M,N,K", _SHAPES)
def test_fused_matches_cublaslt_per_token(M: int, N: int, K: int) -> None:
    """Fused CUTLASS per-token GEMM must match the cuBLAS LT per-token reference.

    Both paths consume the same quantizer outputs (packed FP4 data, scale
    factors, and row amaxes). Any difference therefore comes from the GEMM
    kernel and from where the per-row * per-col fold happens (epilogue vs.
    separate kernel). Outputs are compared in fp32 with ``rtol=atol=2e-2``.

    The max/mean absolute and relative error are attached to the assertion
    message so they are visible when the comparison fails; they are also
    printed on success.
    """
    device = torch.device("cuda")
    torch.manual_seed(0xACE1)

    a = (torch.randn((M, K), dtype=torch.bfloat16, device=device) * 0.5).contiguous()
    b = (torch.randn((N, K), dtype=torch.bfloat16, device=device) * 0.5).contiguous()

    # Per-token quantize both operands.
    a_q, a_sf, a_row_amax = _quantize_per_token(a)
    b_q, b_sf, b_row_amax = _quantize_per_token(b)

    # The reference and fused helpers take the same positional arguments.
    gemm_args = (a_q, b_q, a_sf, b_sf, a_row_amax, b_row_amax, M, N, K)
    d_ref = _ref_pertoken_gemm_via_cublaslt(*gemm_args)
    d_fused = _run_fused(*gemm_args)

    # Float32 view for comparison; bf16 ULP is 2^-7 = 7.8e-3 relative.
    ref_f32 = d_ref.float()
    out_f32 = d_fused.float()

    # Build the diagnostics BEFORE asserting so they can be part of the
    # failure message.
    abs_diff = (out_f32 - ref_f32).abs()
    rel_diff = abs_diff / (ref_f32.abs().clamp_min(1e-6))
    diagnostics = (
        f" M={M:>5} N={N:>5} K={K:>5}: "
        f"max_abs={abs_diff.max().item():.3e} mean_abs={abs_diff.mean().item():.3e} "
        f"max_rel={rel_diff.max().item():.3e} mean_rel={rel_diff.mean().item():.3e}"
    )

    # Relative tolerance 2e-2 leaves ~2.5x headroom over the bf16 ULP floor.
    # A callable `msg` receives the default mismatch report; the diagnostics
    # are appended to it rather than replacing it.
    torch.testing.assert_close(
        out_f32,
        ref_f32,
        rtol=2e-2,
        atol=2e-2,
        msg=lambda default: f"{default}\n{diagnostics}",
    )

    print(diagnostics)
+ ), + ) + + +if __name__ == "__main__": + if not _has_sm100(): + print("SKIP: not SM100") + elif not hasattr(tex, "nvfp4_cutlass_per_token_gemm"): + print("SKIP: kernel not built") + else: + for shape in _SHAPES: + test_fused_matches_cublaslt_per_token(*shape) + test_fused_alpha_unity_matches_scalar_gemm_with_baked_const() + print("All tests passed.") diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token.py b/tests/pytorch/nvfp4/test_nvfp4_per_token.py new file mode 100644 index 0000000000..5d99d2919b --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token.py @@ -0,0 +1,967 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Correctness tests for NVFP4 per-token cast + cuBLAS LT NVFP4 GEMM. + +Covers byte-equal kernel-vs-reference quantize parity, K1/K2 split-vs-composite +parity, dequant + fp32 reference, optional RHT (K1 amax + K2 cast), and a +cuBLAS LT NVFP4 GEMM smoke. Requires bf16 input, M % 128 == 0, K % 128 == 0; +GEMM and RHT tests gated by SM100. +""" + +from __future__ import annotations + +from typing import Tuple + +import pytest +import torch + +# Must import transformer_engine first to dlopen libtransformer_engine.so so +# transformer_engine_torch.so can resolve typeinfo / vtable symbols at load time. 
import transformer_engine.pytorch as te  # noqa: F401
import transformer_engine_torch as tex  # type: ignore  # noqa: F401

from transformer_engine.pytorch.custom_recipes.gemm_nvfp4_per_token import (
    dequantize_nvfp4_per_token,
    nvfp4_per_token_gemm,
    nvfp4_per_token_gemm_dequant,
)
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token import (
    BLOCK_K,
    NVFP4QuantizerPerTokenRef,
    nvfp4_per_token_amax,
    nvfp4_per_token_encode,
    nvfp4_per_token_quantize,
)


def _has_sm100() -> bool:
    """Return True when CUDA is usable and the device's major capability is >= 10."""
    # Short-circuit keeps the capability query from running without CUDA.
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 10


# Skip mark for tests that need the SM100 GEMM path.
_GATED_SM100 = pytest.mark.skipif(
    not _has_sm100(),
    reason="NVFP4 per-token GEMM via cuBLAS LT requires SM100 (Blackwell).",
)

# Skip mark for tests that only need a CUDA device.
_GATED_FP4 = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="NVFP4 per-token cast requires CUDA.",
)


# (1) Quantize parity: kernel vs Python reference.

# Shapes obey the kernel contract (M % 128 == 0, K % 128 == 0).
_QUANT_SHAPES = [
    (128, 128),  # smallest legal shape
    (128, 256),  # K > inner SF window of single chunk
    (256, 128),  # M > inner SF window of single chunk
    (256, 512),
    (512, 1024),
]


def _unpack_fp4_byte_pairs(x: torch.Tensor) -> torch.Tensor:
    """Split each packed byte into two nibbles along dim 1.

    Every byte is duplicated side by side; the even copy keeps the low four
    bits and the odd copy is shifted down to expose the high four bits, so the
    result has twice as many columns as ``x``.
    """
    nibbles = x.repeat_interleave(2, dim=1)
    nibbles[:, 0::2] &= 0x0F  # even positions: low nibble
    nibbles[:, 1::2] >>= 4  # odd positions: high nibble
    return nibbles


@_GATED_FP4
@pytest.mark.parametrize("M,N", _QUANT_SHAPES)
@pytest.mark.parametrize("rowwise,columnwise", [(True, False), (False, True), (True, True)])
def test_per_token_quantize_byte_exact(M: int, N: int, rowwise: bool, columnwise: bool) -> None:
    """Kernel quantize output equals the Python reference with zero tolerance.

    For each requested direction the unpacked FP4 codes, the scale bytes and
    the amax vector are compared against ``NVFP4QuantizerPerTokenRef``.
    """

    def _identical(actual: torch.Tensor, expected: torch.Tensor) -> None:
        torch.testing.assert_close(actual, expected, atol=0.0, rtol=0.0)

    def _raw(t: torch.Tensor) -> torch.Tensor:
        return t.view(torch.uint8)

    torch.manual_seed(0xBEEF * (M + 17) + (N + 3))
    x = torch.randn((M, N), dtype=torch.bfloat16, device=torch.device("cuda")) * 4.0
    # Outliers so the per-row outer is exercised.
    if M >= 4:
        x[0, :] *= 8.0
        x[-1, :] *= 0.125

    ref = NVFP4QuantizerPerTokenRef(rowwise=rowwise, columnwise=columnwise).quantize(x)
    sut = nvfp4_per_token_quantize(x, rowwise=rowwise, columnwise=columnwise)

    # (data, scale, amax) attribute names of each direction under test.
    directions = []
    if rowwise:
        directions.append(("data", "scale", "row_amax"))
    if columnwise:
        directions.append(("columnwise_data", "columnwise_scale", "col_amax"))

    for data_attr, scale_attr, amax_attr in directions:
        codes_sut = _unpack_fp4_byte_pairs(_raw(getattr(sut, data_attr)))
        codes_ref = _unpack_fp4_byte_pairs(_raw(getattr(ref, data_attr)))
        _identical(codes_sut, codes_ref)
        _identical(_raw(getattr(sut, scale_attr)), _raw(getattr(ref, scale_attr)))
        _identical(getattr(sut, amax_attr), getattr(ref, amax_attr))


# (2) Split-kernel parity: K1 then K2 == composite K1+K2.


@_GATED_FP4
@pytest.mark.parametrize("M,N", _QUANT_SHAPES)
@pytest.mark.parametrize("rowwise,columnwise", [(True, False), (False, True), (True, True)])
def test_per_token_split_byte_equal(
    M: int,
    N: int,
    rowwise: bool,
    columnwise: bool,
) -> None:
    """Running amax then encode separately equals the composite call, byte for byte."""

    def _identical(actual: torch.Tensor, expected: torch.Tensor) -> None:
        torch.testing.assert_close(actual, expected, atol=0.0, rtol=0.0)

    def _raw(t: torch.Tensor) -> torch.Tensor:
        return t.view(torch.uint8)

    torch.manual_seed(0xC0FFEE * (M + 7) + (N + 11))
    x = torch.randn((M, N), dtype=torch.bfloat16, device=torch.device("cuda")) * 4.0
    if M >= 4:
        x[0, :] *= 8.0
        x[-1, :] *= 0.125

    composite = nvfp4_per_token_quantize(x, rowwise=rowwise, columnwise=columnwise)

    # Two-step path: K1 produces the amaxes that K2 then consumes.
    row_amax, col_amax = nvfp4_per_token_amax(
        x,
        rowwise=rowwise,
        columnwise=columnwise,
    )
    split = nvfp4_per_token_encode(
        x,
        row_amax=row_amax,
        col_amax=col_amax,
        rowwise=rowwise,
        columnwise=columnwise,
    )

    # (amax, data, scale) attribute names of each direction under test.
    directions = []
    if rowwise:
        directions.append(("row_amax", "data", "scale"))
    if columnwise:
        directions.append(("col_amax", "columnwise_data", "columnwise_scale"))

    for amax_attr, data_attr, scale_attr in directions:
        _identical(getattr(split, amax_attr), getattr(composite, amax_attr))
        _identical(_raw(getattr(split, data_attr)), _raw(getattr(composite, data_attr)))
        _identical(_raw(getattr(split, scale_attr)), _raw(getattr(composite, scale_attr)))


# (2b) Input-validation rejections.


@_GATED_FP4
def test_per_token_validation_rejects_fp32() -> None:
    """Quantizing an fp32 tensor must raise ``ValueError`` mentioning bf16."""
    fp32_input = torch.randn((128, 128), dtype=torch.float32, device=torch.device("cuda"))
    with pytest.raises(ValueError, match="bf16"):
        nvfp4_per_token_quantize(fp32_input, rowwise=True, columnwise=False)


@_GATED_FP4
def test_per_token_validation_rejects_unaligned() -> None:
    """A 64-wide K or a 64-tall M must raise ``ValueError`` naming the violated rule."""
    device = torch.device("cuda")
    # (shape, pattern expected in the error message); K is checked first, then M.
    cases = (
        ((128, 64), "K % 128"),
        ((64, 128), "M % 128"),
    )
    for shape, pattern in cases:
        bad_input = torch.randn(shape, dtype=torch.bfloat16, device=device)
        with pytest.raises(ValueError, match=pattern):
            nvfp4_per_token_quantize(bad_input, rowwise=True, columnwise=False)


# (3) Dequant + fp32 reference matmul sanity (pure-Python, no kernel).


@_GATED_FP4
@pytest.mark.parametrize("M,N", [(32, 64), (64, 256)])
def test_per_token_dequant_roundtrip_close(M: int, N: int) -> None:
    """Reference quantize followed by dequantize stays loosely close to the input.

    Only the mean relative error is bounded (< 0.5).
    """
    torch.manual_seed(0x1234)
    x = torch.randn((M, N), dtype=torch.float32, device=torch.device("cuda"))

    quantized = NVFP4QuantizerPerTokenRef(rowwise=True).quantize(x)
    restored = dequantize_nvfp4_per_token(quantized.data, quantized.scale, quantized.row_amax)

    # Loose bound: catches dequant-formula bugs, not quantization quality.
    rel_err = (restored - x).abs() / x.abs().clamp(min=1e-6)
    mean_rel_err = rel_err.mean().item()
    assert mean_rel_err < 0.5, f"mean rel error {mean_rel_err:.3g} > 0.5"


# (4) Production GEMM: cuBLAS LT NVFP4 + post-scale composite.
# Shapes need M, N % 128 == 0 and K % 16 == 0 for cuBLAS LT NVFP4.
+_GEMM_SHAPES = [ + (128, 128, 128), # smallest legal shape + (128, 128, 256), # exercise K > inner SF window + (256, 128, 256), # non-square (M != N) + (256, 256, 256), # square mid-size +] + + +def _three_pronged_bf16_close( + d_test: torch.Tensor, + d_ref: torch.Tensor, + *, + label: str, + rel_l2_floor: float = 2e-2, + bad_count_ratio: float = 1e-2, + atol: float = 1e-1, + bad_rtol: float = 5e-2, +) -> None: + """Dequant-vs-SUT closeness for random GEMM outputs. + + Three-pronged: energy-weighted rel_l2 (primary), torch.allclose-style + n_bad_mixed (localised faults), max_abs (NaN-like blow-up sanity). + """ + finite_mask = torch.isfinite(d_test) & torch.isfinite(d_ref) + d_t = d_test.float()[finite_mask] + d_r = d_ref.float()[finite_mask] + diff = (d_t - d_r).abs() + n = d_t.numel() + + diff_l2 = float(diff.norm().item()) + ref_l2 = float(d_r.norm().item()) + rel_l2 = diff_l2 / (ref_l2 + 1e-30) + + n_bad_mixed = int((diff > atol + bad_rtol * d_r.abs()).sum().item()) + + max_abs = float(diff.max().item()) if n else float("nan") + mean_ref_abs = float(d_r.abs().mean().item()) if n else float("nan") + max_abs_bound = atol + bad_rtol * mean_ref_abs + + rel = diff / d_r.abs().clamp(min=1e-30) + mean_rel = float(rel.mean().item()) if n else float("nan") + max_rel = float(rel.max().item()) if n else float("nan") + + diag = ( + f"[{label}] N_finite={n}/{int(finite_mask.numel())} " + f"rel_l2={rel_l2:.3g} max_abs={max_abs:.3g} n_bad_mixed={n_bad_mixed} " + f"mean_|d_ref|={mean_ref_abs:.3g} " + f"(diag: mean_rel={mean_rel:.3g} max_rel={max_rel:.3g} " + "— mean_rel/max_rel are NOT asserted; see helper docstring)" + ) + print(diag) + + bad_count_abs_floor = max(8, int(bad_count_ratio * n)) + assert rel_l2 <= rel_l2_floor, ( + f"{diag} -> rel_l2 > {rel_l2_floor} (energy-weighted global " + "relative error too high — possible structural bug)" + ) + assert n_bad_mixed <= bad_count_abs_floor, ( + f"{diag} -> n_bad_mixed > {bad_count_abs_floor} " + f"(|diff| > atol={atol} + 
rtol={bad_rtol} * |d_r| for too " + "many elements — possible localised broken row/col)" + ) + assert max_abs <= max_abs_bound, ( + f"{diag} -> max_abs > {max_abs_bound:.3g} = atol + " + "bad_rtol * mean_|d_ref| (worst element is way outside the " + "noise envelope — possible NaN-like blow-up)" + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,N,K", _GEMM_SHAPES) +def test_per_token_gemm_close_to_bf16(M: int, N: int, K: int) -> None: + """End-to-end per_token_gemm is structurally close to BF16 GEMM. + + Uses cos_sim + magnitude-ratio (direction + magnitude) instead of + per-element mean_rel, which is pathological on random GEMM outputs. + """ + torch.manual_seed(0xACE * M + K) + device = torch.device("cuda") + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + b = torch.randn((N, K), dtype=torch.bfloat16, device=device) + + a_q = nvfp4_per_token_quantize(a, rowwise=True) + b_q = nvfp4_per_token_quantize(b, rowwise=True) + + d_sut = nvfp4_per_token_gemm( + a_q.data, + a_q.scale, + a_q.row_amax, + b_q.data, + b_q.scale, + b_q.row_amax, + ) + + d_ref = (a.float() @ b.float().t()).to(torch.bfloat16) + + d_sut_f = d_sut.float().flatten() + d_ref_f = d_ref.float().flatten() + + sut_norm = d_sut_f.norm() + ref_norm = d_ref_f.norm() + cos_sim = float((d_sut_f @ d_ref_f) / (sut_norm * ref_norm + 1e-30)) + mag_ratio = float(sut_norm / (ref_norm + 1e-30)) + + # cos_sim >= 0.95 catches operand swap; mag in [0.7, 1.3] catches + # missing/duplicated scale or wrong alpha-by-constant. 
+ cos_sim_floor = 0.95 + mag_lo, mag_hi = 0.7, 1.3 + + diag = ( + f"[per_token({M}x{N}x{K})] cos_sim={cos_sim:.4f} " + f"mag_ratio={mag_ratio:.4f} " + f"||d_sut||={float(sut_norm):.4g} ||d_ref||={float(ref_norm):.4g}" + ) + assert cos_sim >= cos_sim_floor, ( + f"{diag} -> cos_sim < {cos_sim_floor} (structural mismatch; " + "likely wrong operand swap, missing scale, or indexing bug)" + ) + assert mag_lo <= mag_ratio <= mag_hi, ( + f"{diag} -> mag_ratio not in [{mag_lo}, {mag_hi}] " + "(systematic magnitude error; check alpha/post-scale)" + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,N,K", _GEMM_SHAPES) +def test_per_token_gemm_close_to_dequant_ref(M: int, N: int, K: int) -> None: + """End-to-end per_token_gemm close to dequant + fp32 matmul (TF32 envelope).""" + torch.manual_seed(0xDEAD * (M + 7) + (N + 1) * K) + device = torch.device("cuda") + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) * 0.5 + b = torch.randn((N, K), dtype=torch.bfloat16, device=device) * 0.5 + + a_q = nvfp4_per_token_quantize(a, rowwise=True) + b_q = nvfp4_per_token_quantize(b, rowwise=True) + + d_sut = nvfp4_per_token_gemm( + a_q.data, + a_q.scale, + a_q.row_amax, + b_q.data, + b_q.scale, + b_q.row_amax, + ).float() + + d_ref = nvfp4_per_token_gemm_dequant( + a_q.data, + a_q.scale, + a_q.row_amax, + b_q.data, + b_q.scale, + b_q.row_amax, + out_dtype=torch.float32, + ) + + _three_pronged_bf16_close( + d_sut, + d_ref, + label=f"vs_dequant({M}x{N}x{K})", + # Empirical rel_l2 ~5e-3..1.5e-2 on random N(0, 0.5), K=128-256. 
+ rel_l2_floor=2e-2, + atol=1e-1, + bad_rtol=5e-2, + bad_count_ratio=1e-2, + ) + + +@_GATED_SM100 +def test_per_token_gemm_rejects_beta_nonzero() -> None: + """beta != 0 raises until residual handling is added.""" + device = torch.device("cuda") + M, N, K = 128, 128, 128 + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + b = torch.randn((N, K), dtype=torch.bfloat16, device=device) + a_q = nvfp4_per_token_quantize(a, rowwise=True) + b_q = nvfp4_per_token_quantize(b, rowwise=True) + + with pytest.raises(ValueError, match=r"beta != 0"): + nvfp4_per_token_gemm( + a_q.data, + a_q.scale, + a_q.row_amax, + b_q.data, + b_q.scale, + b_q.row_amax, + beta=1.0, + ) + + +# ============================================================================= +# (5) RHT correctness: K1 amax + K2 cast with optional col-wise RHT. +# Opt-in via with_rht=True + random_sign_mask_t=; row direction never +# sees RHT. with_rht=False is byte-equal to the pre-RHT path. +# ============================================================================= + +_RHT_SHAPES = [ + (128, 128), + (256, 256), + (128, 1024), # K > single 64x64 sub-tile along col + (1024, 128), # M > single 64x64 sub-tile along row + (512, 512), +] + + +def _walsh_hadamard_16(device: torch.device) -> torch.Tensor: + """16x16 Sylvester / Walsh-Hadamard matrix, +/-1 entries (unnormalized).""" + H = torch.tensor([[1.0]], dtype=torch.float32, device=device) + for _ in range(4): + top = torch.cat([H, H], dim=1) + bot = torch.cat([H, -H], dim=1) + H = torch.cat([top, bot], dim=0) + return H + + +def _sign_diag_16(mask: int, device: torch.device) -> torch.Tensor: + """16-elt +/-1 vector; s_i = -1 iff bit i of `mask` is set.""" + bits = torch.tensor( + [1 - 2 * ((mask >> i) & 1) for i in range(16)], + dtype=torch.float32, + device=device, + ) + return bits + + +def _reference_col_amax_rht(x_bf16: torch.Tensor, mask: int) -> torch.Tensor: + """PyTorch reference for the per-token col-wise RHT amax: max over + 16-row blocks of 
|H * D * x_block| / 4. FHT may permute element order + but |y|.max() is permutation-invariant. + """ + M, K = x_bf16.shape + assert M % 16 == 0, "Test setup error: M must be a multiple of 16." + H = _walsh_hadamard_16(x_bf16.device) + sign = _sign_diag_16(mask, x_bf16.device) + x = x_bf16.to(torch.float32) + blocks = x.reshape(M // 16, 16, K) + masked = blocks * sign.view(1, 16, 1) + rotated = torch.einsum("ij,bjk->bik", H, masked) + return (rotated.abs() / 4.0).reshape(-1, K).amax(dim=0) + + +def _reference_amax_raw(x_bf16: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Raw per-row + per-col absolute max (no RHT, bf16 -> fp32 first).""" + x = x_bf16.to(torch.float32) + return x.abs().amax(dim=1), x.abs().amax(dim=0) + + +def _allocate_per_token_buffers(M: int, K: int, device: torch.device): + """Match the layout that ``tex.nvfp4_per_token_quantize`` writes.""" + return { + "q_row": torch.empty((M, K // 2), dtype=torch.uint8, device=device), + "s_row": torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device), + "ra": torch.empty((M,), dtype=torch.float32, device=device), + "q_col": torch.empty((K, M // 2), dtype=torch.uint8, device=device), + "s_col": torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device), + "ca": torch.empty((K,), dtype=torch.float32, device=device), + } + + +def _dequant_fp4_with_outer_amax( + q_packed: torch.Tensor, # (R, C // 2) uint8 packed FP4 + s_dec: torch.Tensor, # (R, C // 16) e4m3 held as uint8 + outer_amax: torch.Tensor, # (R,) fp32 +) -> torch.Tensor: + """Decode a rowwise FP4 tensor back to fp32 using the kernel's own + arithmetic: x_hat = qcode * s_dec_e4m3 * (6 / S_enc_row), + S_enc_row = (448 * 6) / max(outer_amax, 1e-12). 
+ """ + R, half_C = q_packed.shape + C = half_C * 2 + s_dec_f = s_dec.view(torch.float8_e4m3fn).to(torch.float32) + + lo = (q_packed & 0x0F).to(torch.int8) + hi = ((q_packed >> 4) & 0x0F).to(torch.int8) + interleaved = torch.stack([lo, hi], dim=-1).reshape(R, C) + # NVFP4 E2M1 LUT (sign-magnitude): 0000..0111 map to {0, 0.5, 1, 1.5, + # 2, 3, 4, 6}; 1000..1111 are the negatives. + fp4_lut = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], + dtype=torch.float32, + device=q_packed.device, + ) + fp4_val = fp4_lut[interleaved.to(torch.int64)] + + fp8_max = 448.0 + fp4_max = 6.0 + safe_amax = torch.clamp(outer_amax, min=1e-12) + S_enc_row = (fp8_max * fp4_max) / safe_amax + inv_S = (1.0 / S_enc_row).unsqueeze(1) + + block_scale_inv = s_dec_f * inv_S + block_scale_inv = block_scale_inv.repeat_interleave(BLOCK_K, dim=1) + + return fp4_val * block_scale_inv + + +# ----- (5a) K1 RHT: standalone amax kernel ---------------------------------- + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _RHT_SHAPES) +def test_per_token_k1_with_rht_false_equals_raw_amax(M: int, K: int) -> None: + """Regression: with_rht=False reproduces raw bf16->fp32 amax along each axis.""" + torch.manual_seed(0xABCD * (M + 1) + K) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + row_amax = torch.empty((M,), dtype=torch.float32, device=device) + col_amax = torch.empty((K,), dtype=torch.float32, device=device) + + tex.nvfp4_per_token_amax( + x, + row_amax, + col_amax, + True, + True, + with_rht=False, + random_sign_mask_t=0, + ) + + ref_row, ref_col = _reference_amax_raw(x) + torch.testing.assert_close( + row_amax, ref_row, rtol=0.0, atol=0.0, msg=f"row_amax mismatch at ({M}, {K})" + ) + torch.testing.assert_close( + col_amax, ref_col, rtol=0.0, atol=0.0, msg=f"col_amax mismatch at ({M}, {K})" + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _RHT_SHAPES) +@pytest.mark.parametrize("mask", 
[0x0000, 0xACE1, 0xFFFF, 0x5A5A]) +def test_per_token_k1_with_rht_matches_reference( + M: int, + K: int, + mask: int, +) -> None: + """with_rht=True col_amax matches max|H*D*x_block|/4; rowwise stays raw.""" + torch.manual_seed(0xDEAD * (M + 7) + (K + 3) + mask) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + row_amax = torch.empty((M,), dtype=torch.float32, device=device) + col_amax = torch.empty((K,), dtype=torch.float32, device=device) + + tex.nvfp4_per_token_amax( + x, + row_amax, + col_amax, + True, + True, + with_rht=True, + random_sign_mask_t=mask, + ) + + ref_row, _ = _reference_amax_raw(x) + torch.testing.assert_close( + row_amax, + ref_row, + rtol=0.0, + atol=0.0, + msg=f"row_amax mismatch at ({M}, {K}, mask=0x{mask:04X})", + ) + + # Col tolerance accounts for bf16->fp32 promotion noise + butterfly + # summation order vs. einsum reduction order. + ref_col = _reference_col_amax_rht(x, mask) + torch.testing.assert_close( + col_amax, + ref_col, + rtol=2e-3, + atol=1e-4, + msg=f"col_amax (RHT) mismatch at ({M}, {K}, mask=0x{mask:04X})", + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", [(128, 128), (256, 512)]) +def test_per_token_k1_with_rht_zero_mask_is_hadamard_only(M: int, K: int) -> None: + """mask=0 -> D=I; col_amax equals bare Hadamard amax max|H*x_block|/4.""" + torch.manual_seed(0xC0DE * (M + 11) + K) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + row_amax = torch.empty((M,), dtype=torch.float32, device=device) + col_amax = torch.empty((K,), dtype=torch.float32, device=device) + + tex.nvfp4_per_token_amax( + x, + row_amax, + col_amax, + True, + True, + with_rht=True, + random_sign_mask_t=0, + ) + + H = _walsh_hadamard_16(device) + x_fp32 = x.to(torch.float32) + blocks = x_fp32.reshape(M // 16, 16, K) + rotated = torch.einsum("ij,bjk->bik", H, blocks) + ref_col = (rotated.abs() / 4.0).reshape(-1, K).amax(dim=0) + + torch.testing.assert_close( + 
col_amax, + ref_col, + rtol=2e-3, + atol=1e-4, + msg=f"col_amax (RHT, mask=0) mismatch at ({M}, {K})", + ) + + +# ----- (5b) K2 + composite RHT: encode kernel and composite quantize -------- + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _RHT_SHAPES) +def test_per_token_composite_with_rht_false_byte_equal(M: int, K: int) -> None: + """Regression: with_rht=False composite byte-equals the default (no-kwargs) path.""" + torch.manual_seed(0xCAFE * (M + 1) + K) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + bufs_default = _allocate_per_token_buffers(M, K, device) + bufs_explicit = _allocate_per_token_buffers(M, K, device) + + tex.nvfp4_per_token_quantize( + x, + bufs_default["q_row"], + bufs_default["s_row"], + bufs_default["ra"], + bufs_default["q_col"], + bufs_default["s_col"], + bufs_default["ca"], + True, + True, + ) + tex.nvfp4_per_token_quantize( + x, + bufs_explicit["q_row"], + bufs_explicit["s_row"], + bufs_explicit["ra"], + bufs_explicit["q_col"], + bufs_explicit["s_col"], + bufs_explicit["ca"], + True, + True, + with_rht=False, + random_sign_mask_t=0xACE1, + ) + + for k in ("q_row", "s_row", "ra", "q_col", "s_col", "ca"): + assert torch.equal( + bufs_default[k], bufs_explicit[k] + ), f"with_rht=False not byte-equal to default path on `{k}` at ({M}, {K})" + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _RHT_SHAPES) +def test_per_token_composite_rowwise_unchanged_under_rht(M: int, K: int) -> None: + """Rowwise FP4 + inner SF + row amax byte-equal across with_rht=False / True.""" + torch.manual_seed(0xBEEF * (M + 3) + K) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + bufs_no_rht = _allocate_per_token_buffers(M, K, device) + bufs_with_rht = _allocate_per_token_buffers(M, K, device) + + tex.nvfp4_per_token_quantize( + x, + bufs_no_rht["q_row"], + bufs_no_rht["s_row"], + bufs_no_rht["ra"], + bufs_no_rht["q_col"], + bufs_no_rht["s_col"], + 
bufs_no_rht["ca"], + True, + True, + with_rht=False, + random_sign_mask_t=0, + ) + tex.nvfp4_per_token_quantize( + x, + bufs_with_rht["q_row"], + bufs_with_rht["s_row"], + bufs_with_rht["ra"], + bufs_with_rht["q_col"], + bufs_with_rht["s_col"], + bufs_with_rht["ca"], + True, + True, + with_rht=True, + random_sign_mask_t=0xACE1, + ) + + for k in ("q_row", "s_row", "ra"): + assert torch.equal(bufs_no_rht[k], bufs_with_rht[k]), ( + f"rowwise output differs between with_rht=False/True on `{k}` " + f"at ({M}, {K}) -- rowwise should never see RHT." + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", [(128, 128), (256, 512), (512, 512)]) +@pytest.mark.parametrize("mask", [0x0000, 0xACE1, 0xFFFF]) +def test_per_token_composite_with_rht_col_dequant_matches_reference( + M: int, + K: int, + mask: int, +) -> None: + """Dequant'd col FP4 (with_rht=True) ~ H*D*x_block/sqrt(16); checks + column-aggregate median + p99 relative error (FP4's 16-code grain and + butterfly permutation make element-wise comparison too loose). + """ + torch.manual_seed(0xFEED * (M + 5) + K + mask) + device = torch.device("cuda") + # Scale down so most blocks land in non-saturating FP4 (else we measure + # clamping noise, not RHT). 
+ x = torch.randn((M, K), dtype=torch.bfloat16, device=device) * 0.5 + + bufs = _allocate_per_token_buffers(M, K, device) + tex.nvfp4_per_token_quantize( + x, + bufs["q_row"], + bufs["s_row"], + bufs["ra"], + bufs["q_col"], + bufs["s_col"], + bufs["ca"], + True, + True, + with_rht=True, + random_sign_mask_t=mask, + ) + + H = _walsh_hadamard_16(device) + sign = _sign_diag_16(mask, device) + x_fp32 = x.to(torch.float32) + blocks = x_fp32.reshape(M // 16, 16, K) + masked = blocks * sign.view(1, 16, 1) + rotated = torch.einsum("ij,bjk->bik", H, masked) # (M/16, 16, K) + y_ref = rotated.reshape(M, K) / 4.0 # (M, K) + y_ref_col_view = y_ref.transpose(0, 1).contiguous() # (K, M) + + y_kernel = _dequant_fp4_with_outer_amax( + bufs["q_col"], + bufs["s_col"], + bufs["ca"], + ) # (K, M) + + diff = (y_kernel - y_ref_col_view).abs() + col_outer = bufs["ca"].unsqueeze(1).clamp(min=1e-6) + rel = diff / col_outer + p99 = torch.quantile(rel.flatten(), 0.99).item() + median = rel.median().item() + assert median < 0.1, ( + f"median per-element relative error too large: {median:.4f} > 0.1 " + f"at ({M}, {K}, mask=0x{mask:04X})" + ) + assert ( + p99 < 0.5 + ), f"p99 per-element relative error too large: {p99:.4f} > 0.5 at ({M}, {K}, mask=0x{mask:04X})" + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", [(128, 128), (256, 256)]) +def test_per_token_composite_with_rht_col_amax_matches_k1( + M: int, + K: int, +) -> None: + """Composite col_amax byte-equals standalone K1 amax with the same mask.""" + torch.manual_seed(0xDADA * (M + 13) + K) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + mask = 0xACE1 + + bufs = _allocate_per_token_buffers(M, K, device) + tex.nvfp4_per_token_quantize( + x, + bufs["q_row"], + bufs["s_row"], + bufs["ra"], + bufs["q_col"], + bufs["s_col"], + bufs["ca"], + True, + True, + with_rht=True, + random_sign_mask_t=mask, + ) + + ra_k1 = torch.empty((M,), dtype=torch.float32, device=device) + ca_k1 = torch.empty((K,), 
dtype=torch.float32, device=device) + tex.nvfp4_per_token_amax( + x, + ra_k1, + ca_k1, + True, + True, + with_rht=True, + random_sign_mask_t=mask, + ) + + torch.testing.assert_close( + bufs["ca"], ca_k1, rtol=0.0, atol=0.0, msg=f"composite ca != K1-only ca at ({M}, {K})" + ) + torch.testing.assert_close( + bufs["ra"], ra_k1, rtol=0.0, atol=0.0, msg=f"composite ra != K1-only ra at ({M}, {K})" + ) + + +# ============================================================================= +# (6) Fused-swizzle correctness: K2 with_swizzle=True emits rowwise SF in +# cuBLAS LT layout. Tests cover byte-equal vs Python reference, other-outputs +# identical to with_swizzle=False, and GEMM fast-path numerical equivalence. +# ============================================================================= + +_SWIZZLE_SHAPES = [ + (128, 128), + (256, 256), + (512, 512), + (256, 1024), + (1024, 256), +] + + +def _swizzle_sf_reference(sf_m_major: torch.Tensor) -> torch.Tensor: + """Reference M-major (M, K_SF) e4m3 -> cuBLAS LT swizzled flat bytes + (128Mx4K tile, 16-byte slot = 4 M-stripes x 4 K-bytes stripe-major).""" + M, K_SF = sf_m_major.shape + assert M % 128 == 0 + assert K_SF % 4 == 0 + device = sf_m_major.device + sf_u8 = sf_m_major.contiguous().view(torch.uint8) + out = torch.empty(M * K_SF, dtype=torch.uint8, device=device) + + m_idx = torch.arange(M, device=device, dtype=torch.int64).view(M, 1).expand(M, K_SF) + k_idx = torch.arange(K_SF, device=device, dtype=torch.int64).view(1, K_SF).expand(M, K_SF) + m_tile = m_idx // 128 + k_tile = k_idx // 4 + out_idx = ( + m_tile * 128 * K_SF + + k_tile * 512 + + (m_idx % 32) * 16 + + ((m_idx % 128) // 32) * 4 + + (k_idx % 4) + ) + out[out_idx.reshape(-1)] = sf_u8.reshape(-1) + return out + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _SWIZZLE_SHAPES) +def test_per_token_with_swizzle_sf_byte_equal_to_reference(M: int, K: int) -> None: + """Fused-swizzle rowwise scale_inv matches the Python byte-permutation + reference of the M-major 
SF (covers both rowwise-only and rowwise+colwise). + """ + device = torch.device("cuda") + torch.manual_seed(0) + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + out_plain = nvfp4_per_token_quantize(x, rowwise=True, columnwise=True, with_swizzle=False) + out_swz = nvfp4_per_token_quantize(x, rowwise=True, columnwise=True, with_swizzle=True) + + ref_swz_sf = _swizzle_sf_reference(out_plain.scale.view(torch.uint8)) + got_swz_sf = out_swz.scale.view(torch.uint8).reshape(-1) + + torch.testing.assert_close( + got_swz_sf, + ref_swz_sf, + rtol=0, + atol=0, + msg=f"fused-swizzle rowwise SF mismatch at ({M}, {K})", + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _SWIZZLE_SHAPES) +def test_per_token_with_swizzle_other_outputs_unchanged(M: int, K: int) -> None: + """Only the rowwise scale_inv layout differs: FP4 data, row_amax, colwise + data / scale_inv / col_amax must be byte-identical between with_swizzle + True and False. + """ + device = torch.device("cuda") + torch.manual_seed(0) + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + out_plain = nvfp4_per_token_quantize(x, rowwise=True, columnwise=True, with_swizzle=False) + out_swz = nvfp4_per_token_quantize(x, rowwise=True, columnwise=True, with_swizzle=True) + + torch.testing.assert_close(out_swz.data, out_plain.data, rtol=0, atol=0) + torch.testing.assert_close(out_swz.row_amax, out_plain.row_amax, rtol=0, atol=0) + torch.testing.assert_close(out_swz.columnwise_data, out_plain.columnwise_data, rtol=0, atol=0) + torch.testing.assert_close(out_swz.columnwise_scale, out_plain.columnwise_scale, rtol=0, atol=0) + torch.testing.assert_close(out_swz.col_amax, out_plain.col_amax, rtol=0, atol=0) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", [(256, 256), (512, 1024), (1024, 512)]) +def test_per_token_gemm_with_fused_swizzle_matches_unswizzled(M: int, K: int) -> None: + """E2E GEMM two paths: (A) with_swizzle=False + ext swizzle (sf_swizzled=False) + vs (B) with_swizzle=True + 
sf_swizzled=True. Same SF bytes to cuBLAS LT, + so C outputs must be byte-equal.""" + device = torch.device("cuda") + torch.manual_seed(0) + N = M # square; GEMM is TN with M, N free + A = torch.randn((M, K), dtype=torch.bfloat16, device=device) + B = torch.randn((N, K), dtype=torch.bfloat16, device=device) + + a_plain = nvfp4_per_token_quantize(A, rowwise=True, columnwise=False, with_swizzle=False) + b_plain = nvfp4_per_token_quantize(B, rowwise=True, columnwise=False, with_swizzle=False) + c_unswz = nvfp4_per_token_gemm( + a_plain.data, + a_plain.scale, + a_plain.row_amax, + b_plain.data, + b_plain.scale, + b_plain.row_amax, + ) + + a_swz = nvfp4_per_token_quantize(A, rowwise=True, columnwise=False, with_swizzle=True) + b_swz = nvfp4_per_token_quantize(B, rowwise=True, columnwise=False, with_swizzle=True) + c_swz = nvfp4_per_token_gemm( + a_swz.data, + a_swz.scale, + a_swz.row_amax, + b_swz.data, + b_swz.scale, + b_swz.row_amax, + a_sf_swizzled=True, + b_sf_swizzled=True, + ) + + torch.testing.assert_close( + c_swz, + c_unswz, + rtol=0, + atol=0, + msg=f"fused-swizzle GEMM output != unswizzled-input GEMM at ({M}, {K})", + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py b/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py new file mode 100644 index 0000000000..21fedabaab --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py @@ -0,0 +1,591 @@ +"""Correctness tests for grouped (multi-tensor) NVFP4 per-token cast. + +The grouped kernel must be byte-equal to a for-loop of single-tensor +calls. Covers composite K1+K2, K1-only, single-split, many-split, and +optional RHT (random Hadamard transform) on the column direction. +""" + +from __future__ import annotations + +from typing import List, Optional, Sequence, Tuple + +import pytest +import torch + +# Import transformer_engine first to dlopen libtransformer_engine.so so that +# transformer_engine_torch can resolve typeinfo / vtable symbols at load time. 
import transformer_engine.pytorch as te  # noqa: F401
import transformer_engine_torch as tex  # type: ignore # noqa: F401

from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token import (
    BLOCK_K,
    RefNVFP4TensorPerToken,
    nvfp4_per_token_quantize,
)
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token_group import (
    nvfp4_per_token_group_quantize,
)


def _has_fp4() -> bool:
    """Return True when a CUDA device with compute-capability major >= 10 is visible."""
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability()
        return major >= 10
    return False


# Skip marker shared by every test in this module; evaluated once at import time.
_GATED_FP4 = pytest.mark.skipif(
    not _has_fp4(),
    reason="NVFP4 per-token cast requires SM100 (Blackwell) + CUDA 12.8+",
)


def _alloc_per_token_buffers(
    M_i: int,
    K: int,
    rowwise: bool,
    columnwise: bool,
    device: torch.device,
) -> Tuple[
    Optional[torch.Tensor],  # q_row
    Optional[torch.Tensor],  # s_dec_row
    Optional[torch.Tensor],  # row_amax
    Optional[torch.Tensor],  # q_col
    Optional[torch.Tensor],  # s_dec_col
    Optional[torch.Tensor],  # col_amax
]:
    """Allocate uninitialised output buffers for one (M_i, K) split.

    Returns ``(q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax)``.
    The three entries of a direction that is switched off are ``None``.

    Rowwise buffers:    uint8 (M_i, K // 2), uint8 (M_i, K // BLOCK_K), float32 (M_i,).
    Columnwise buffers: uint8 (K, M_i // 2), uint8 (K, M_i // BLOCK_K), float32 (K,).
    """
    row_bufs = (None, None, None)
    col_bufs = (None, None, None)
    if rowwise:
        row_bufs = (
            torch.empty((M_i, K // 2), dtype=torch.uint8, device=device),
            torch.empty((M_i, K // BLOCK_K), dtype=torch.uint8, device=device),
            torch.empty((M_i,), dtype=torch.float32, device=device),
        )
    if columnwise:
        col_bufs = (
            torch.empty((K, M_i // 2), dtype=torch.uint8, device=device),
            torch.empty((K, M_i // BLOCK_K), dtype=torch.uint8, device=device),
            torch.empty((K,), dtype=torch.float32, device=device),
        )
    return (*row_bufs, *col_bufs)


def _group_quantize_py(
    x_concat: torch.Tensor,
    split_sections: List[int],
    rowwise: bool,
    columnwise: bool,
) -> List[RefNVFP4TensorPerToken]:
    """Run ``tex.nvfp4_per_token_group_quantize`` on pre-allocated per-split outputs.

    Args:
        x_concat: 2-D input whose rows are the concatenation of all splits.
        split_sections: Row count of each split; must sum to ``x_concat.shape[0]``.
        rowwise: Produce the rowwise outputs (data / scale / row_amax).
        columnwise: Produce the columnwise outputs (data / scale / col_amax).

    Returns:
        One ``RefNVFP4TensorPerToken`` per split, in split order. Fields of a
        disabled direction are ``None``; scale buffers are re-viewed as
        ``torch.float8_e4m3fn`` (same bytes).
    """
    assert x_concat.dim() == 2
    sum_M, K = x_concat.shape
    assert sum(split_sections) == sum_M
    device = x_concat.device

    # One 6-tuple of buffers per split, in the order returned by the allocator.
    per_split = [
        _alloc_per_token_buffers(M_i, K, rowwise, columnwise, device) for M_i in split_sections
    ]

    def _gather(slot: int, enabled: bool) -> List[torch.Tensor]:
        # The binding takes one list per output kind; a skipped direction gets an empty list.
        return [bufs[slot] for bufs in per_split] if enabled else []

    tex.nvfp4_per_token_group_quantize(
        x_concat,
        split_sections,
        _gather(0, rowwise),
        _gather(1, rowwise),
        _gather(2, rowwise),
        _gather(3, columnwise),
        _gather(4, columnwise),
        _gather(5, columnwise),
        rowwise,
        columnwise,
    )

    results: List[RefNVFP4TensorPerToken] = []
    for q_row, sf_row, amax_row, q_col, sf_col, amax_col in per_split:
        # Buffers of a disabled direction are already None, so they pass straight through.
        # The uint8 scale buffers are re-viewed as float8_e4m3fn (same bytes, expected dtype).
        results.append(
            RefNVFP4TensorPerToken(
                data=q_row,
                scale=sf_row.view(torch.float8_e4m3fn) if rowwise else None,
                row_amax=amax_row,
                columnwise_data=q_col,
                columnwise_scale=sf_col.view(torch.float8_e4m3fn) if columnwise else None,
                col_amax=amax_col,
            )
        )
    return results


# Test fixtures. Per-token kernel requires M_i % 128 == 0 and K % 128 == 0.
_SHAPES: List[Tuple[List[int], int]] = [
    # (split_sections, K)
    ([128], 128),  # trivial: 1 split, smallest legal shape
    ([128, 128], 128),  # 2 equal splits
    ([128, 256], 128),  # 2 unequal splits
    ([128, 256, 128], 256),  # 3 splits, mixed sizes
    ([128, 128, 128, 128], 256),  # 4 equal splits
    ([256, 128, 384, 128, 128], 512),  # 5-way unequal split, typical MoE
    ([256, 256], 1024),  # larger K, 2 splits
]


# (1) Composite K1+K2: grouped == for-loop of single-tensor, byte-equal.
@_GATED_FP4
@pytest.mark.parametrize("split_sections,K", _SHAPES)
@pytest.mark.parametrize("rowwise,columnwise", [(True, False), (False, True), (True, True)])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_group_per_token_quantize_byte_equal(
    split_sections: List[int],
    K: int,
    rowwise: bool,
    columnwise: bool,
    dtype: torch.dtype,
) -> None:
    """Grouped quantize must reproduce per-split single-tensor quantize exactly.

    For every split, the FP4 data, the scale factors (compared as raw uint8
    bytes) and the amax vectors of each enabled direction are checked with
    zero tolerance.
    """
    sum_M = sum(split_sections)
    torch.manual_seed(0xCAFE * (sum_M + 7) + K)
    device = torch.device("cuda")

    def _make_split(idx: int, rows: int) -> torch.Tensor:
        # Scale grows with the split index; first/last rows are pushed to
        # extreme magnitudes so the per-row outer scale is exercised.
        block = torch.randn((rows, K), dtype=dtype, device=device) * (2.0 + 0.5 * idx)
        if rows >= 4:
            block[0, :] *= 8.0
            block[-1, :] *= 0.125
        return block

    splits_in: List[torch.Tensor] = [
        _make_split(i, M_i) for i, M_i in enumerate(split_sections)
    ]

    x_concat = torch.cat(splits_in, dim=0)
    assert x_concat.shape == (sum_M, K)

    # Reference: one single-tensor call per split.
    oracle: List[RefNVFP4TensorPerToken] = []
    for block in splits_in:
        oracle.append(nvfp4_per_token_quantize(block, rowwise=rowwise, columnwise=columnwise))

    # System under test: one grouped call over the concatenated input.
    sut: List[RefNVFP4TensorPerToken] = _group_quantize_py(
        x_concat, split_sections, rowwise=rowwise, columnwise=columnwise
    )

    assert len(sut) == len(oracle) == len(split_sections)

    def _exact(got: torch.Tensor, want: torch.Tensor, msg: str) -> None:
        torch.testing.assert_close(got, want, atol=0.0, rtol=0.0, msg=msg)

    for i, (got, want) in enumerate(zip(sut, oracle)):
        if rowwise:
            _exact(
                got.data.view(torch.uint8),
                want.data.view(torch.uint8),
                f"rowwise q[{i}] mismatch",
            )
            _exact(
                got.scale.view(torch.uint8),
                want.scale.view(torch.uint8),
                f"rowwise s_dec[{i}] mismatch",
            )
            _exact(got.row_amax, want.row_amax, f"row_amax[{i}] mismatch")
        if columnwise:
            _exact(
                got.columnwise_data.view(torch.uint8),
                want.columnwise_data.view(torch.uint8),
                f"columnwise q[{i}] mismatch",
            )
            _exact(
                got.columnwise_scale.view(torch.uint8),
                want.columnwise_scale.view(torch.uint8),
                f"columnwise s_dec[{i}] mismatch",
            )
            _exact(got.col_amax, want.col_amax, f"col_amax[{i}] mismatch")


# (2) K1-only (amax) entry == K1-only of single-tensor, byte-equal.
@_GATED_FP4
@pytest.mark.parametrize("split_sections,K", _SHAPES[:3])  # subset, K1 is simple
@pytest.mark.parametrize("rowwise,columnwise", [(True, False), (False, True), (True, True)])
def test_group_per_token_amax_byte_equal(
    split_sections: List[int],
    K: int,
    rowwise: bool,
    columnwise: bool,
) -> None:
    """tex.nvfp4_per_token_group_amax matches K1 of the for-loop variant.

    The reference amax vectors are taken from the single-tensor quantize call
    on each split; the grouped amax binding must reproduce them with zero
    tolerance for every enabled direction.
    """
    torch.manual_seed(0xDEAD * sum(split_sections) + K)
    device = torch.device("cuda")
    n = len(split_sections)

    splits_in: List[torch.Tensor] = [
        torch.randn((M_i, K), dtype=torch.bfloat16, device=device) * 3.0
        for M_i in split_sections
    ]
    x_concat = torch.cat(splits_in, dim=0)

    # Oracle row_amax / col_amax via single-tensor quantize (shared K1).
    oracle_row = []
    oracle_col = []
    for s in splits_in:
        o = nvfp4_per_token_quantize(s, rowwise=rowwise, columnwise=columnwise)
        oracle_row.append(o.row_amax if rowwise else None)
        oracle_col.append(o.col_amax if columnwise else None)

    # Output buffers for the grouped binding; a disabled direction gets an empty list.
    row_amax_list = (
        [torch.empty((M_i,), dtype=torch.float32, device=device) for M_i in split_sections]
        if rowwise
        else []
    )
    col_amax_list = (
        [torch.empty((K,), dtype=torch.float32, device=device) for _ in range(n)]
        if columnwise
        else []
    )

    tex.nvfp4_per_token_group_amax(
        x_concat, split_sections, row_amax_list, col_amax_list, rowwise, columnwise
    )

    if rowwise:
        for i in range(n):
            torch.testing.assert_close(
                row_amax_list[i],
                oracle_row[i],
                atol=0.0,
                rtol=0.0,
                msg=f"row_amax[{i}] mismatch",
            )
    if columnwise:
        for i in range(n):
            torch.testing.assert_close(
                col_amax_list[i],
                oracle_col[i],
                atol=0.0,
                rtol=0.0,
                msg=f"col_amax[{i}] mismatch",
            )


# (3) Single-split call must equal the single-tensor kernel.
@_GATED_FP4
@pytest.mark.parametrize("M,K", [(128, 128), (128, 256), (256, 1024)])
@pytest.mark.parametrize("rowwise,columnwise", [(True, False), (False, True), (True, True)])
def test_group_single_split_matches_single_tensor(
    M: int, K: int, rowwise: bool, columnwise: bool
) -> None:
    """A grouped call with exactly one split must equal the single-tensor call.

    Data and amax tensors are compared directly; scale factors are compared
    through a uint8 view. All comparisons use zero tolerance.
    """
    torch.manual_seed(0xBABE * M + K)
    x = torch.randn((M, K), dtype=torch.bfloat16, device=torch.device("cuda")) * 4.0

    oracle = nvfp4_per_token_quantize(x, rowwise=rowwise, columnwise=columnwise)
    sut_list = _group_quantize_py(x, [M], rowwise=rowwise, columnwise=columnwise)
    assert len(sut_list) == 1
    sut = sut_list[0]

    def _exact(got: torch.Tensor, want: torch.Tensor) -> None:
        torch.testing.assert_close(got, want, atol=0.0, rtol=0.0)

    if rowwise:
        _exact(sut.data, oracle.data)
        _exact(sut.scale.view(torch.uint8), oracle.scale.view(torch.uint8))
        _exact(sut.row_amax, oracle.row_amax)
    if columnwise:
        _exact(sut.columnwise_data, oracle.columnwise_data)
        _exact(
            sut.columnwise_scale.view(torch.uint8),
            oracle.columnwise_scale.view(torch.uint8),
        )
        _exact(sut.col_amax, oracle.col_amax)


# (4) Many-split scaling test (close to the 64-tensor cap).
@_GATED_FP4
@pytest.mark.parametrize("n_splits", [8, 16, 32, 64])
def test_group_many_splits_byte_equal(n_splits: int) -> None:
    """Many small splits (MoE expert layout) still byte-equal to oracle.

    Every split is 128 rows by K=256 columns. For each split all six outputs
    are compared against the single-tensor call with zero tolerance: FP4
    data, scale factors and amax for both the rowwise and the columnwise
    direction. Scale factors are compared through a uint8 view so the check
    is on raw bytes.
    """
    torch.manual_seed(0xFEED * n_splits)
    device = torch.device("cuda")
    K = 256
    split_sections = [128] * n_splits

    splits_in = [
        torch.randn((128, K), dtype=torch.bfloat16, device=device) * (1.0 + 0.1 * i)
        for i in range(n_splits)
    ]
    x_concat = torch.cat(splits_in, dim=0)

    oracle = [nvfp4_per_token_quantize(s, rowwise=True, columnwise=True) for s in splits_in]
    sut = _group_quantize_py(x_concat, split_sections, rowwise=True, columnwise=True)

    # A short result list would otherwise silently shrink the loop below.
    assert len(sut) == len(oracle) == n_splits

    for i in range(n_splits):
        torch.testing.assert_close(sut[i].data, oracle[i].data, atol=0.0, rtol=0.0)
        # Scale factors were previously not compared here at all, so a wrong
        # SF in a late split could go unnoticed despite the byte-equal claim.
        torch.testing.assert_close(
            sut[i].scale.view(torch.uint8),
            oracle[i].scale.view(torch.uint8),
            atol=0.0,
            rtol=0.0,
            msg=f"rowwise s_dec[{i}] mismatch",
        )
        torch.testing.assert_close(sut[i].row_amax, oracle[i].row_amax, atol=0.0, rtol=0.0)
        torch.testing.assert_close(
            sut[i].columnwise_data, oracle[i].columnwise_data, atol=0.0, rtol=0.0
        )
        torch.testing.assert_close(
            sut[i].columnwise_scale.view(torch.uint8),
            oracle[i].columnwise_scale.view(torch.uint8),
            atol=0.0,
            rtol=0.0,
            msg=f"columnwise s_dec[{i}] mismatch",
        )
        torch.testing.assert_close(sut[i].col_amax, oracle[i].col_amax, atol=0.0, rtol=0.0)


# =============================================================================
# (5) RHT correctness: grouped K1+K2 with optional col-wise RHT.
# Contract: each split's 6 outputs MUST byte-equal single-tensor with the
# same mask. Row direction never sees RHT.
+# ============================================================================= + +_RHT_GROUP_SHAPES: List[Tuple[List[int], int]] = [ + ([128, 128], 128), # 2 splits, smallest legal shape + ([128, 256, 128], 256), # 3 splits, mixed sizes + ([256, 256, 256, 256], 512), # 4 equal splits, larger K + ([128, 384], 128), # 2 splits, very asymmetric +] + + +def _rht_pt_buffers(M: int, K: int, device: torch.device): + """Match the layout that ``tex.nvfp4_per_token_quantize`` writes.""" + return { + "q_row": torch.empty((M, K // 2), dtype=torch.uint8, device=device), + "s_row": torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device), + "ra": torch.empty((M,), dtype=torch.float32, device=device), + "q_col": torch.empty((K, M // 2), dtype=torch.uint8, device=device), + "s_col": torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device), + "ca": torch.empty((K,), dtype=torch.float32, device=device), + } + + +def _split_views(x_concat: torch.Tensor, splits: Sequence[int]) -> List[torch.Tensor]: + out, off = [], 0 + for s in splits: + out.append(x_concat[off : off + s].contiguous()) + off += int(s) + return out + + +@_GATED_FP4 +@pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES) +def test_group_with_rht_false_byte_equal_to_default( + splits: List[int], + K: int, +) -> None: + """Regression: with_rht=False grouped byte-equals the default (no-kwargs) path.""" + torch.manual_seed(0xCAFE * (sum(splits) + 1) + K + len(splits)) + device = torch.device("cuda") + sum_M = sum(splits) + x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() + + outs_default = nvfp4_per_token_group_quantize( + x, + splits, + rowwise=True, + columnwise=True, + ) + outs_explicit_false = nvfp4_per_token_group_quantize( + x, + splits, + rowwise=True, + columnwise=True, + with_rht=False, + random_sign_mask_t=0xACE1, + ) + + assert len(outs_default) == len(outs_explicit_false) == len(splits) + for i, (a, b) in enumerate(zip(outs_default, outs_explicit_false)): + for 
attr in ( + "data", + "scale", + "row_amax", + "columnwise_data", + "columnwise_scale", + "col_amax", + ): + ta, tb = getattr(a, attr), getattr(b, attr) + assert torch.equal(ta, tb), ( + f"split[{i}].{attr} differs between default and explicit " + f"with_rht=False at K={K}, splits={splits}" + ) + + +@_GATED_FP4 +@pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES) +def test_group_rowwise_unchanged_under_rht( + splits: List[int], + K: int, +) -> None: + """Rowwise outputs byte-equal across with_rht=False / True.""" + torch.manual_seed(0xBEEF * (sum(splits) + 3) + K) + device = torch.device("cuda") + sum_M = sum(splits) + x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() + + outs_no_rht = nvfp4_per_token_group_quantize( + x, + splits, + rowwise=True, + columnwise=True, + with_rht=False, + random_sign_mask_t=0, + ) + outs_with_rht = nvfp4_per_token_group_quantize( + x, + splits, + rowwise=True, + columnwise=True, + with_rht=True, + random_sign_mask_t=0xACE1, + ) + + for i, (a, b) in enumerate(zip(outs_no_rht, outs_with_rht)): + for attr in ("data", "scale", "row_amax"): + ta, tb = getattr(a, attr), getattr(b, attr) + assert torch.equal(ta, tb), ( + f"split[{i}].{attr} differs between with_rht=False and =True " + f"on the ROW direction at K={K}, splits={splits} -- " + "rowwise should never see RHT." 
+ ) + + +@_GATED_FP4 +@pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES) +@pytest.mark.parametrize("mask", [0x0000, 0xACE1, 0xFFFF]) +def test_group_with_rht_equals_single_tensor_per_split( + splits: List[int], + K: int, + mask: int, +) -> None: + """Each split's 6 outputs byte-equal single-tensor with the same mask.""" + torch.manual_seed(0xDADA * (sum(splits) + 11) + K + mask) + device = torch.device("cuda") + sum_M = sum(splits) + x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() + + outs_grouped = nvfp4_per_token_group_quantize( + x, + splits, + rowwise=True, + columnwise=True, + with_rht=True, + random_sign_mask_t=mask, + ) + + x_splits = _split_views(x, splits) + for i, (x_i, out_g) in enumerate(zip(x_splits, outs_grouped)): + M_i = x_i.size(0) + bufs = _rht_pt_buffers(M_i, K, device) + tex.nvfp4_per_token_quantize( + x_i, + bufs["q_row"], + bufs["s_row"], + bufs["ra"], + bufs["q_col"], + bufs["s_col"], + bufs["ca"], + True, + True, + with_rht=True, + random_sign_mask_t=mask, + ) + + mapping = { + "data": ("q_row", out_g.data), + "scale": ("s_row", out_g.scale.view(torch.uint8)), + "row_amax": ("ra", out_g.row_amax), + "columnwise_data": ("q_col", out_g.columnwise_data), + "columnwise_scale": ("s_col", out_g.columnwise_scale.view(torch.uint8)), + "col_amax": ("ca", out_g.col_amax), + } + for attr, (single_key, grouped_t) in mapping.items(): + single_t = bufs[single_key] + assert single_t.shape == grouped_t.shape, ( + f"split[{i}].{attr} shape mismatch: grouped={grouped_t.shape}, " + f"single-tensor={single_t.shape} at K={K}, splits={splits}, mask=0x{mask:04X}" + ) + assert torch.equal(grouped_t, single_t), ( + f"split[{i}].{attr} grouped result differs from single-tensor " + f"reference at K={K}, splits={splits}, mask=0x{mask:04X}" + ) + + +@_GATED_FP4 +@pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES[:2]) +def test_group_k1_amax_matches_single_tensor_per_split_under_rht( + splits: List[int], + K: int, +) -> None: + 
"""Grouped K1 amax byte-equals single-tensor K1 per split. Isolates K1 + via the lighter nvfp4_per_token_group_amax binding to catch K1-vs-K2 + divergences earlier than the full composite check. + """ + torch.manual_seed(0x1234 * (sum(splits) + 7) + K) + device = torch.device("cuda") + sum_M = sum(splits) + x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() + mask = 0xACE1 + + row_amax_list = [torch.empty((int(s),), dtype=torch.float32, device=device) for s in splits] + col_amax_list = [torch.empty((K,), dtype=torch.float32, device=device) for _ in splits] + tex.nvfp4_per_token_group_amax( + x, + [int(s) for s in splits], + row_amax_list, + col_amax_list, + True, + True, + with_rht=True, + random_sign_mask_t=mask, + ) + + x_splits = _split_views(x, splits) + for i, (x_i, ra_g, ca_g) in enumerate(zip(x_splits, row_amax_list, col_amax_list)): + M_i = x_i.size(0) + ra_s = torch.empty((M_i,), dtype=torch.float32, device=device) + ca_s = torch.empty((K,), dtype=torch.float32, device=device) + tex.nvfp4_per_token_amax( + x_i, + ra_s, + ca_s, + True, + True, + with_rht=True, + random_sign_mask_t=mask, + ) + torch.testing.assert_close( + ra_g, ra_s, rtol=0.0, atol=0.0, msg=f"split[{i}] row_amax mismatch (K1 only)" + ) + torch.testing.assert_close( + ca_g, ca_s, rtol=0.0, atol=0.0, msg=f"split[{i}] col_amax mismatch (K1 only)" + ) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 06d85b6d84..baabb37516 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -240,8 +240,12 @@ list(APPEND transformer_engine_cuda_arch_specific_sources multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu recipe/nvfp4.cu - transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise_fp4.cu) + cast/nvfp4/quantize_nvfp4_per_token.cu + cast/nvfp4/quantize_nvfp4_per_token_group.cu + gemm/nvfp4_per_token_post_scale.cu + 
gemm/nvfp4_cutlass_gemm.cu + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise_fp4.cu) # Compiling the files with the worst compilation time first to hopefully overlap # better with the faster-compiling cpp files diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu new file mode 100644 index 0000000000..4efad9237a --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu @@ -0,0 +1,1192 @@ +/************************************************************************* + * Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_nvfp4_per_token.cu + * \brief NVFP4 per-token cast on the bf16 fast path: + * TMA + mbarrier + 64x64 sub-tile + 2-buffer ping-pong. + * + * Pipeline structure mirrors the per-tensor cast kernel + * (``quantize_transpose_nvfp4_tuned_1D_kernel``) and the RHT + * amax kernel (``HadamardAmaxTmaKernel``): + * + * * Two-kernel design (amax pass + encode pass). Output of amax fed + * into encode via per-row/per-col buffers in ``output->amax`` + * and ``output->columnwise_amax`` (sized [M] and [K] respectively). + * * Each CTA covers a 128x128 chunk decomposed as 4 sequential 64x64 + * sub-tiles, double-buffered. Each sub-tile is one TMA bulk-2D + * tensor transaction. mbarrier expect_tx + parity wait gives + * one-iteration-overlap between HBM and compute. + * * Encode pass reads the input tile ONCE into SMEM, then dispatches + * both the rowwise (FP4 + per-row scale) and the columnwise (FP4 + * transpose + per-col scale) outputs from that same staged copy. + * Outer scaling factors S_enc are loaded from + * ``row_amax_in[M]`` / ``col_amax_in[K]`` once per CTA into a small + * SMEM cache (1 KiB total). 
+ */ + +#include + +#include "common/cast/core/common.cuh" +#include "common/cast/nvfp4/core_nvfp4.cuh" +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { +namespace nvfp4_per_token { + +#if FP4_TYPE_SUPPORTED + +using dispatch::common::align_smem_ptr_per_TMA_requirements; +using dispatch::nvfp4::nvfp4_scale_t; +using dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; +using dispatch::nvfp4::quantization_SF::compute_decoding_scaling_factor; + +constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows of input +constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols of input +constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height +constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width +constexpr int THREADS_NUM = 128; // threads per CTA +constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM +constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) +constexpr int PREFETCH_STAGES = 1; // 1-stage prefetch overlap +constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; // = 2 ping-pong input buffers + +// Derived (chunk / tile / stage) +constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; // 2 +constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X; // 2 +constexpr int STAGES = TILES_Y * TILES_X; // 4 + +constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; // 8 inner blocks per row of the chunk +constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; // 8 inner blocks per col of the chunk +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 + +// Encode helpers' thread layout (rowwise pass: 4x32 = K-dim x M-dim) +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 +constexpr int THREADS_PER_SCALE_ROWWISE = + SCALE_DIM / ELTS_PER_THREAD; // 1 (each block owned by 1 thread) +constexpr int 
ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 + +// Buffer dimensions (input bf16 SMEM tiles + FP4 output SMEM tiles for TMA store) +constexpr int BUFF_IN_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_IN_DIM_X = TILE_DIM_X; +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; // elements +constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (2 fp4 per byte) +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; +constexpr int BUFF_OUT_TR_DIM_Y = TILE_DIM_X; +constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 ping-pong (matches input) +constexpr int BUFFS_NUM_OUT_TR = 2; // 2 ping-pong for transpose + +// Manual swizzling parameters to reduce SMEM bank conflicts on rowwise loads +constexpr int PACK_SIZE = 8; +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 + +using IType = bf16; +using IType2 = ptx::FPx2; // = ptx::bf16x2 +using IType3D = IType[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; +using IType2x3D = IType2[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; + +// Compute the per-block (1x16) byte-equal arithmetic and emit FP4 codes into +// SMEM rowwise output buffer + e4m3 scale into SMEM rowwise scale buffer. 
+__device__ __forceinline__ void rowwise_scaling_per_token( + const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_ptr, + nvfp4_scale_t* __restrict__ sSFrowwise_ptr, + const float* __restrict__ sRowAmax, // [CHUNK_DIM_Y], indexed by chunk-local row + const int stage_Y, const int stage_X, const int buff_in, const int buff_out) { + const auto& sIn = *reinterpret_cast(sIn_ptr); + auto& sOut = *reinterpret_cast(sOut_ptr); + auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 + + const int thread_offset_X_rowwise = + tid_X_rowwise * ELTS_PER_THREAD; // K-elt offset in tile (0/16/32/48) + + const int SF_thread_offset_rowwise_X = + tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; // = tid_X_rowwise here + const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); + + const int stage_rowwise_scales_offset_X = + SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X; + +#pragma unroll + for (int it = 0; it < ITERATIONS_NORMAL; ++it) { + const int it_offset_Y_rowwise = tid_Y_rowwise + it * THREADS_Y_ROWWISE; // 0..63 over 2 iters + const int chunk_local_row = stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; // 0..127 + + // Per-row S_enc (look up from CTA-cached row amax buffer) + const float row_amax = sRowAmax[chunk_local_row]; + const float S_enc = compute_global_encode_scaling_factor_FP4(fmaxf(row_amax, 1e-12f)); + + __align__(16) IType2 rIn[WAVES][PACK_SIZE / 2]; + + // Read 16 elements (in PACK_SIZE=8 waves), swizzled to avoid bank conflicts, + // and reduce to a 1x16 block amax. 
+ IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + + __uint128_t& elts_8x = *reinterpret_cast<__uint128_t*>(&rIn[w]); + elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); + } + } + const float block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + + // Byte-equal compute path (matches the Python reference in + // ``NVFP4QuantizerPerTokenRef``): + const fp8e4m3 s_dec = compute_decoding_scaling_factor(block_amax, S_enc); + const float s_dec_f = static_cast(s_dec); + const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc, s_dec_f); + + // Store e4m3 scale to SMEM SF buffer (1 thread per 1x16 block stores). + if (SF_storing_thread) { + const int scales_offset_Y = chunk_local_row; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = s_dec; + } + + // Cast 16 elements to FP4 using mul_cvt_4x (4 elements per call, the + // byte-equal path against the Python reference). We've already pre-loaded + // into rIn[WAVES][4]. + // WAVES = 2, PACK_SIZE/2 = 4 elements per wave + // Total per iteration: 2 waves * (4 IType2 elts) = 16 elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + + // 4 fp4 quads from 8 bf16 elements (in PACK_SIZE=8 waves): + // rIn[w][0..3] = 4 IType2 pairs = 8 elements. + // Each mul_cvt_4x packs 4 elements; we need 2 calls per wave. 
+ fp4e2m1x4 qu0{}, qu1{}; + ptx::mul_cvt_4x(qu0, rIn[w][0], rIn[w][1], block_scale); + ptx::mul_cvt_4x(qu1, rIn[w][2], rIn[w][3], block_scale); + + // Pack into a 32-bit word and store to SMEM out (b32 store) + uint32_t out_x8 = (static_cast(*reinterpret_cast(&qu0))) | + (static_cast(*reinterpret_cast(&qu1)) << 16); + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + } + } +} + +// Randomized Hadamard Transform helpers (per-thread, 16-wide). Used by the +// optional col-wise RHT path (kWithRht=true) in K1 amax and K2 colwise cast; +// K1 and K2 must consume identical helper output for the encoded FP4 and +// outer SF to be self-consistent (mismatch -> saturated codes / wrong SF). + +// Apply +/-1 sign diagonal D then a 16-pt Walsh-Hadamard butterfly in place. +// Output is NOT normalized; caller multiplies by k16HadamardNorm (0.25). +// Sign-flip is a branchless XOR on the fp32 sign bit (bit-exact == r = -r on +// finite fp32, which is all this helper sees from bf16 SMEM reads). +__device__ __forceinline__ void apply_signed_fht16_inplace(float r[16], uint32_t random_sign_mask) { +#pragma unroll + for (int i = 0; i < 16; ++i) { + const uint32_t bits = __float_as_uint(r[i]); + const uint32_t flip = ((random_sign_mask >> i) & 1u) << 31; + r[i] = __uint_as_float(bits ^ flip); + } +#pragma unroll + for (int stride = 1; stride < 16; stride <<= 1) { +#pragma unroll + for (int g = 0; g < 16; g += stride << 1) { +#pragma unroll + for (int j = 0; j < stride; ++j) { + const float a = r[g + j]; + const float b = r[g + j + stride]; + r[g + j] = a + b; + r[g + j + stride] = a - b; + } + } + } +} + +__device__ __forceinline__ float amax_16_abs(const float r[16]) { + float m = 0.f; +#pragma unroll + for (int i = 0; i < 16; ++i) m = fmaxf(m, fabsf(r[i])); + return m; +} + +// 1/sqrt(16) normalization for the 16-pt Hadamard so H*H^T = I after sign +// scaling. Applied once per block on K1 amax / K2 block_scale. 
+constexpr float k16HadamardNorm = 0.25f;
+
+// Per-block (1x16 along M) columnwise FP4 cast; writes transposed FP4 +
+// e4m3 SF to SMEM. When kWithRht=true, each thread's 16-row strip is rotated
+// through the FHT with random_sign_mask_t; K1 amax must use the same mask so
+// sColAmax already reflects the rotated columns.
+//
+// Parameters:
+//   sIn_ptr             input tile buffers; read as sIn2x[buff_in][row][col_pair].
+//   sOut_tr_ptr         transposed FP4 output buffers; written at
+//                       [buff_out_tr][tile-local col][M byte index].
+//   sSFcolwise_ptr      colwise e4m3 scale buffer; written at
+//                       [chunk-local col][chunk-local M-block].
+//   sColAmax            per-column amax, indexed by chunk-local column.
+//   stage_Y / stage_X   position of the current tile inside the chunk (used
+//                       only to form the scale-buffer indices).
+//   random_sign_mask_t  forwarded to apply_signed_fht16_inplace (kWithRht only).
+//
+// Each thread owns a strip of SCALE_DIM rows x 2 adjacent columns and, per
+// column, writes one scale and one 64-bit packed FP4 word. The function issues
+// no barrier or fence itself.
+// NOTE(review): the template parameter list and the static_cast /
+// reinterpret_cast target types are missing in this copy of the patch; the
+// body branches on kWithRht -- confirm against the upstream file.
+template
+__device__ __forceinline__ void colwise_scaling_per_token(
+    const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_tr_ptr,
+    nvfp4_scale_t* __restrict__ sSFcolwise_ptr,
+    const float* __restrict__ sColAmax,  // [CHUNK_DIM_X], indexed by chunk-local col
+    const int stage_Y, const int stage_X, const int buff_in, const int buff_out_tr,
+    const uint32_t random_sign_mask_t = 0u) {
+  const auto& sIn2x = *reinterpret_cast(sIn_ptr);
+  auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr);
+  auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr);
+
+  const int warp = threadIdx.x / THREADS_PER_WARP;  // 0..3
+  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
+
+  // For a fixed lane, the four warps map to four different M-blocks, so every
+  // (M-block, col-pair) strip of the tile is owned by exactly one thread.
+  const int tid_Y_colwise = (thread_lane % 4 + warp) % 4;  // 0..3 (M-block index in tile)
+  const int tid_X_colwise = thread_lane;                   // 0..31 (col-pair index in tile)
+
+  const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM;  // 0/16/32/48
+  const int thread_offset_X_colwise = tid_X_colwise * 2;  // 0/2/.../62 (2 cols per thread)
+
+  const int in_thread_offset_Y = thread_offset_Y_colwise;
+  const int in_thread_offset_X = thread_offset_X_colwise / 2;  // index into IType2[]
+
+  const int out_tr_thread_offset_Y = thread_offset_X_colwise;      // transpose: X becomes Y
+  const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2;  // /2 for fp4e2m1x2 byte index
+
+  const int scale_tr_offset_Y =
+      (stage_X * TILE_DIM_X) + 2 * tid_X_colwise;  // chunk-local col index (×1)
+  const int scale_tr_offset_X =
+      (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise;  // chunk-local M-block index
+
+  __align__(8) IType rIn[2][SCALE_DIM];
+  // RHT staging in fp32 from FHT through mul_cvt_4x: avoids the lossy
+  // fp32->bf16->fp32 round-trip and lets us fold the 0.25 normalization into
+  // block_scale. Untouched by the non-RHT instantiation (nvcc DCE).
+  float rRht[2][SCALE_DIM];
+
+  // Non-RHT path accumulates the 1x16 block amax during the load; RHT path
+  // recomputes it after the butterfly so we skip abs_max_2x here.
+  IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)};
+#pragma unroll
+  for (int i = 0; i < SCALE_DIM; ++i) {
+    // One 32-bit shared load fetches the two adjacent columns of row i.
+    const IType2 elt_pair =
+        ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]);
+    rIn[0][i] = elt_pair.x;
+    rIn[1][i] = elt_pair.y;
+    if constexpr (!kWithRht) {
+      ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair);
+    }
+  }
+
+  // 1x16 block amax used to calibrate the inner FP4 scale.
+  float block_amax[2];
+  if constexpr (kWithRht) {
+#pragma unroll
+    for (int w = 0; w < 2; ++w) {
+#pragma unroll
+      for (int i = 0; i < SCALE_DIM; ++i) {
+        rRht[w][i] = static_cast(rIn[w][i]);
+      }
+      apply_signed_fht16_inplace(rRht[w], random_sign_mask_t);
+      float local_max = 0.f;
+#pragma unroll
+      for (int i = 0; i < SCALE_DIM; ++i) {
+        local_max = fmaxf(local_max, fabsf(rRht[w][i]));
+      }
+      // amax(|r * 0.25|) == amax(|r|) * 0.25 (exact: 0.25 = 2^-2). One
+      // post-amax mul instead of 16 per-element muls; matching 0.25 folded
+      // into block_scale_rht below.
+      block_amax[w] = local_max * k16HadamardNorm;
+    }
+  } else {
+    block_amax[0] = static_cast(__habs(thread_amax_2x.x));
+    block_amax[1] = static_cast(__habs(thread_amax_2x.y));
+  }
+
+#pragma unroll
+  for (int w = 0; w < 2; ++w) {
+    // Per-col S_enc lookup (each of the 2 cols this thread owns has its own amax/S_enc).
+    const int chunk_local_col = scale_tr_offset_Y + w;
+    const float col_amax = sColAmax[chunk_local_col];
+    // The 1e-12f floor keeps an all-zero column from passing amax == 0 to the helper.
+    const float S_enc_col = compute_global_encode_scaling_factor_FP4(fmaxf(col_amax, 1e-12f));
+
+    const fp8e4m3 s_dec = compute_decoding_scaling_factor(block_amax[w], S_enc_col);
+    const float s_dec_f = static_cast(s_dec);
+    // A zero decode scale yields a zero multiplier instead of a division by zero.
+    const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc_col, s_dec_f);
+
+    // Store e4m3 scale to SMEM colwise SF buffer.
+    sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = s_dec;
+
+    // 4x mul_cvt_4x emits 16 FP4 codes. RHT path feeds fp32 staging so we
+    // skip the bf16 round-trip; block_scale_rht folds in 0.25.
+    fp4e2m1x4 qu[4];
+    if constexpr (kWithRht) {
+      const float block_scale_rht = block_scale * k16HadamardNorm;
+#pragma unroll
+      for (int e = 0; e < 4; ++e) {
+        const ptx::floatx2 in01{rRht[w][4 * e + 0], rRht[w][4 * e + 1]};
+        const ptx::floatx2 in23{rRht[w][4 * e + 2], rRht[w][4 * e + 3]};
+        ptx::mul_cvt_4x(qu[e], in01, in23, block_scale_rht);
+      }
+    } else {
+#pragma unroll
+      for (int e = 0; e < 4; ++e) {
+        IType2 in01{rIn[w][4 * e + 0], rIn[w][4 * e + 1]};
+        IType2 in23{rIn[w][4 * e + 2], rIn[w][4 * e + 3]};
+        ptx::mul_cvt_4x(qu[e], in01, in23, block_scale);
+      }
+    }
+
+    // Pack 4 fp4e2m1x4 (= 16 fp4) into a 64-bit value and store to SMEM transpose buffer.
+    uint64_t out_pack_16x = (static_cast(*reinterpret_cast(&qu[0])) << 0) |
+                            (static_cast(*reinterpret_cast(&qu[1])) << 16) |
+                            (static_cast(*reinterpret_cast(&qu[2])) << 32) |
+                            (static_cast(*reinterpret_cast(&qu[3])) << 48);
+    ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X],
+                       out_pack_16x);
+  }
+}
+
+// =============================================================================
+// Kernel 2: per-token encode (rowwise + optional colwise transpose).
+// kWithRht=true: col-wise FP4 cast over RHT-rotated strips (row never sees RHT).
+// kWithSwizzle=true: rowwise SF emitted directly in cuBLAS LT 128x4 tile layout.
+// =============================================================================
+// Grid mapping (from the index math below): blockIdx.x selects a
+// CHUNK_DIM_X-wide column chunk, blockIdx.y a CHUNK_DIM_Y-tall row chunk. The
+// chunk is walked in STAGES tiles of TILE_DIM_Y x TILE_DIM_X.
+//
+// Per stage:
+//   1. thread 0 issues the TMA load for the tile PREFETCH_STAGES ahead;
+//   2. all threads wait on the mbarrier of the current input buffer;
+//   3. the rowwise / colwise helpers write FP4 data + scales to SMEM;
+//   4. after a proxy fence and __syncthreads(), thread 0 issues the TMA
+//      store(s) and commits the bulk group.
+// Scale factors are copied from SMEM to global memory once, after the loop.
+//
+// Returns immediately when noop is non-null and noop[0] == 1.0f. For
+// __CUDA_ARCH__ < 1000 the body is replaced by NVTE_DEVICE_ERROR.
+// NOTE(review): the template parameter list, cast target types and the
+// template argument of cp_async_bulk_wait_group_read are missing in this copy
+// of the patch; the body uses DO_ROW, DO_COL and kWithSwizzle -- confirm
+// against the upstream file.
+template
+__global__ void __launch_bounds__(THREADS_NUM)
+    per_token_encode_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
+                            const __grid_constant__ CUtensorMap tensor_map_output,
+                            const __grid_constant__ CUtensorMap tensor_map_output_t,
+                            nvfp4_scale_t* const scales_ptr, nvfp4_scale_t* const scales_t_ptr,
+                            const float* const row_amax_in, const float* const col_amax_in,
+                            const float* noop, const size_t rows, const size_t cols,
+                            const size_t scale_stride, const size_t scale_stride_t,
+                            const uint32_t random_sign_mask_t) {
+#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  if (noop != nullptr && noop[0] == 1.0f) {
+    return;
+  }
+
+  const bool leading_thread = (threadIdx.x == 0);
+
+  // -------------------------------------------------------------------------
+  // SMEM layout (the first five regions are carved out of dynamic_shmem; the
+  // amax caches and mbarriers are static __shared__ arrays declared below)
+  //   sIn:        2 buffers x (64x64 bf16)       = 16 KiB
+  //   sOut:       2 buffers x (64x32 fp4 packed) =  4 KiB   (rowwise FP4)
+  //   sOut_tr:    2 buffers x (64x32 fp4 packed) =  4 KiB   (colwise FP4)
+  //   sSFrowwise: 128 x 8 e4m3                   =  1 KiB
+  //   sSFcolwise: 128 x 8 e4m3                   =  1 KiB
+  //   sRowAmax:   128 fp32                       = 512 B
+  //   sColAmax:   128 fp32                       = 512 B
+  //   IN_buff_readable_mbar: 2 x 8 B             =  16 B
+  //   Total: ~27 KiB + alignment padding.
+  // -------------------------------------------------------------------------
+  constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE;
+  constexpr int buff_size_aligned_in =
+      DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
+  constexpr int buff_size_aligned_out =
+      DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT);
+  constexpr int buff_size_aligned_out_t =
+      DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT);
+  // Regions for a disabled direction collapse to 0 bytes, so the offsets below
+  // stay valid for rowwise-only and colwise-only instantiations.
+  constexpr int out_mem_rowwise_data = DO_ROW ? buff_size_aligned_out : 0;
+  constexpr int out_mem_colwise_data = DO_COL ? buff_size_aligned_out_t : 0;
+  constexpr int out_mem_rowwise_scales =
+      DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t),
+                                 TMA_SHMEM_ALIGNMENT)
+             : 0;
+  // NOTE(review): out_mem_colwise_scales is not referenced again in this kernel.
+  constexpr int out_mem_colwise_scales =
+      DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t),
+                                 TMA_SHMEM_ALIGNMENT)
+             : 0;
+
+  extern __shared__ unsigned char dynamic_shmem[];
+  unsigned char* dshmem = align_smem_ptr_per_TMA_requirements(dynamic_shmem);
+
+  IType* sIn_ptr = reinterpret_cast(dshmem);
+  fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in);
+  fp4e2m1x2* sOut_tr_ptr =
+      reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data);
+
+  nvfp4_scale_t* sSFrowwise_ptr = reinterpret_cast(
+      dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data);
+  nvfp4_scale_t* sSFcolwise_ptr =
+      reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data +
+                       out_mem_colwise_data + out_mem_rowwise_scales);
+
+  // Per-CTA row/col amax SMEM cache (128 floats each).
+  __shared__ float sRowAmax[CHUNK_DIM_Y];
+  __shared__ float sColAmax[CHUNK_DIM_X];
+  __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM];
+
+  auto& sIn = *reinterpret_cast(sIn_ptr);
+
+  // Byte count announced to the mbarrier for one input-tile TMA load.
+  constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
+
+  const int32_t ctaid_X = blockIdx.x;
+  const int32_t ctaid_Y = blockIdx.y;
+  const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y;
+  const int block_offset_X = ctaid_X * CHUNK_DIM_X;
+  // Transpose-output block offsets: row-CTA(X) -> col-tensor's M; col-CTA(Y) -> col-tensor's N.
+  const int block_offset_Y_tr = ctaid_X * CHUNK_DIM_X;
+  const int block_offset_X_tr = ctaid_Y * CHUNK_DIM_Y;
+
+  const int scales_block_offset_Y_rowwise = ctaid_Y * CHUNK_DIM_Y;
+  const int scales_block_offset_X_rowwise = ctaid_X * SCALES_PER_CHUNK_X;
+  const int scales_block_offset_Y_tr = ctaid_X * CHUNK_DIM_X;
+  const int scales_block_offset_X_tr = ctaid_Y * SCALES_PER_CHUNK_Y;
+
+  // Load per-row / per-col amax into SMEM cache (cooperative, full chunk = 128 entries each).
+  // These writes are published by the __syncthreads() after the mbarrier init.
+  if (DO_ROW && threadIdx.x < CHUNK_DIM_Y) {
+    sRowAmax[threadIdx.x] = row_amax_in[block_offset_Y + threadIdx.x];
+  }
+  if (DO_COL && threadIdx.x < CHUNK_DIM_X) {
+    sColAmax[threadIdx.x] = col_amax_in[block_offset_X + threadIdx.x];
+  }
+
+  // Initialize mbarriers.
+  if (leading_thread) {
+#pragma unroll
+    for (int buff = 0; buff < BUFFS_NUM; ++buff) {
+      ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1);
+    }
+    ptx::fence_proxy_async_shared_cta();
+  }
+  __syncthreads();
+
+  // Prefetch stage 0 (one-iteration overlap throughout main loop).
+#pragma unroll
+  for (int stage = 0; stage < PREFETCH_STAGES; ++stage) {
+    const int buff_in = stage;
+    const int stage_Y = stage / TILES_X;
+    const int stage_X = stage % TILES_X;
+    const int global_offset_Y = block_offset_Y + stage_Y * TILE_DIM_Y;
+    const int global_offset_X = block_offset_X + stage_X * TILE_DIM_X;
+    if (leading_thread) {
+      uint64_t* dst = reinterpret_cast(&sIn[buff_in]);
+      const uint64_t* src = reinterpret_cast(&tensor_map_input);
+      ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size);
+      ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y,
+                                                    &IN_buff_readable_mbar[buff_in]);
+    }
+  }
+
+  int buff_in = 0;
+  int buff_out = 0;
+  int buff_out_tr = 0;
+  int IN_buff_readable_parity[BUFFS_NUM] = {0, 0};
+
+#pragma unroll
+  for (int stage = 0; stage < STAGES; ++stage) {
+    const int stage_Y = stage / TILES_X;
+    const int stage_X = stage % TILES_X;
+    const int stage_offset_Y = stage_Y * TILE_DIM_Y;
+    const int stage_offset_X = stage_X * TILE_DIM_X;
+
+    // Prefetch next stage's input (skip after the second-to-last stage).
+    // The target buffer was last read in the previous iteration, which ended
+    // with a __syncthreads() before thread 0 could reach this point.
+    if (stage < STAGES - PREFETCH_STAGES) {
+      const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM;
+      const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES;
+      const int next_stage_Y = next_prefetch_stage / TILES_X;
+      const int next_stage_X = next_prefetch_stage % TILES_X;
+      const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y;
+      const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X;
+
+      if (leading_thread) {
+        uint64_t* dst = reinterpret_cast(&sIn[next_prefetch_buff]);
+        const uint64_t* src = reinterpret_cast(&tensor_map_input);
+        ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size);
+        ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, next_global_offset_X,
+                                                      next_global_offset_Y,
+                                                      &IN_buff_readable_mbar[next_prefetch_buff]);
+      }
+      ptx::fence_proxy_async_shared_cta();
+    }
+
+    // Wait for current stage's input to land.
+    ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in],
+                                                     IN_buff_readable_parity[buff_in]);
+    IN_buff_readable_parity[buff_in] ^= 1;
+
+    // Wait for any prior TMA store to have finished reading the output SMEM
+    // buffers (so we can overwrite them).
+    ptx::cp_async_bulk_wait_group_read();
+
+    // ----- Compute: rowwise + colwise from the same SMEM tile -----
+    if (DO_ROW) {
+      rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, sRowAmax, stage_Y, stage_X,
+                                buff_in, buff_out);
+    }
+    if (DO_COL) {
+      colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, sColAmax, stage_Y,
+                                stage_X, buff_in, buff_out_tr, random_sign_mask_t);
+    }
+
+    // Fence + sync so all threads' SMEM writes are visible to TMA store.
+    ptx::fence_proxy_async_shared_cta();
+    __syncthreads();
+
+    // Issue TMA store(s) for this stage's outputs.
+    if (leading_thread) {
+      const int global_offset_Y = block_offset_Y + stage_offset_Y;
+      const int global_offset_X = block_offset_X + stage_offset_X;
+      // The transposed tensor swaps the roles of the X and Y offsets.
+      const int global_offset_Y_tr = block_offset_Y_tr + stage_offset_X;
+      const int global_offset_X_tr = block_offset_X_tr + stage_offset_Y;
+
+      if (DO_ROW) {
+        auto& sOut = *reinterpret_cast(sOut_ptr);
+        ptx::cp_async_bulk_tensor_2d_shared_to_global(
+            reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y,
+            reinterpret_cast(&sOut[buff_out]));
+      }
+      if (DO_COL) {
+        auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr);
+        ptx::cp_async_bulk_tensor_2d_shared_to_global(
+            reinterpret_cast(&tensor_map_output_t), global_offset_X_tr,
+            global_offset_Y_tr, reinterpret_cast(&sOut_tr[buff_out_tr]));
+      }
+      ptx::cp_async_bulk_commit_group();
+    }
+
+    buff_in = (buff_in + 1) % BUFFS_NUM;
+    buff_out = (buff_out + 1) % BUFFS_NUM_OUT;
+    buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR;
+  }  // end of stages
+  // NOTE(review): no cp_async_bulk_wait_group_read follows the last committed
+  // store group before the kernel ends -- confirm this is intended.
+
+  // Vectorized SF scatter. kWithSwizzle=false: compact M-major (downstream
+  // nvte_swizzle_scaling_factors re-permutes). kWithSwizzle=true: emit cuBLAS
+  // LT 128Mx4K tile layout directly; thread mapping below is perf-critical.
+  if (DO_ROW) {
+    auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr);
+    if constexpr (kWithSwizzle) {
+      // uint64_t SMEM load below assumes each sSFrowwise row is exactly 8 bytes
+      // (= 2 K-tiles of 4 K-bytes each); any other geometry needs a different
+      // pack/store split.
+      static_assert(SCALES_PER_CHUNK_X == 8,
+                    "fused-swizzle rowwise scatter assumes SCALES_PER_CHUNK_X == 8");
+      const int tid = threadIdx.x;
+      const int b = tid & 3;       // M-stripe [0, 4), fast axis -> coalesced gmem
+      const int ty = tid >> 2;     // slot index within K-tile [0, 32)
+      const int lm = b * 32 + ty;  // [0, 128)
+      const size_t M_tile_idx = ctaid_Y;
+      const size_t K_tile_global_base = ctaid_X * (SCALES_PER_CHUNK_X / 4);  // 2
+
+      // Single 8-byte SMEM load (vs 2 x 4-byte) halves the SMEM access count
+      // and degrades the bank conflict from 4-way to 2-way (each lane touches
+      // 2 adjacent banks at the same lm row instead of 1 bank twice).
+      const uint64_t packed_all = *reinterpret_cast(&sSFrowwise[lm][0]);
+      const uint32_t packed_lo = static_cast(packed_all);
+      const uint32_t packed_hi = static_cast(packed_all >> 32);
+
+      // Each K-tile occupies 512 bytes; inside a tile row lm lands at
+      // (lm % 32) * 16 + (lm / 32) * 4, i.e. ty * 16 + b * 4.
+      const size_t base_byte = M_tile_idx * CHUNK_DIM_Y * scale_stride + K_tile_global_base * 512 +
+                               static_cast(ty) * 16 + static_cast(b) * 4;
+      *reinterpret_cast(&scales_ptr[base_byte]) = packed_lo;
+      *reinterpret_cast(&scales_ptr[base_byte + 512]) = packed_hi;
+    } else {
+      using ScalesVec = Vec;
+      const int chunk_cols = static_cast(cols) - block_offset_X;
+      const int count = min(SCALES_PER_CHUNK_X, chunk_cols / SCALE_DIM);
+
+      for (size_t row = threadIdx.x; row < CHUNK_DIM_Y; row += THREADS_NUM) {
+        const size_t row_global = scales_block_offset_Y_rowwise + row;
+        if (row_global < rows) {
+          ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]);
+          const size_t scale_idx_global = row_global * scale_stride + scales_block_offset_X_rowwise;
+          scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count);
+        }
+      }
+    }
+  }
+  if (DO_COL) {
+    auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr);
+    using ScalesVec = Vec;
+    const int chunk_rows = static_cast(rows) - block_offset_Y;
+    const int count = min(SCALES_PER_CHUNK_Y, chunk_rows / SCALE_DIM);
+
+    // One SMEM scale row per input column; always written in the compact layout.
+    for (size_t row_tr = threadIdx.x; row_tr < CHUNK_DIM_X; row_tr += THREADS_NUM) {
+      const size_t row_tr_global = scales_block_offset_Y_tr + row_tr;
+      if (row_tr_global < cols) {
+        ScalesVec& scales_vec = *reinterpret_cast(sSFcolwise[row_tr]);
+        const size_t scale_idx_global = row_tr_global * scale_stride_t + scales_block_offset_X_tr;
+        scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count);
+      }
+    }
+  }
+
+  if (leading_thread) {
+#pragma unroll
+    for (int buff = 0; buff < BUFFS_NUM; ++buff) {
+      ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]);
+    }
+  }
+#else
+  NVTE_DEVICE_ERROR("Per-token encode kernel requires SM 10.0+ (Blackwell).");
+#endif  // __CUDA_ARCH__ >= 1000
+}
+
+// =============================================================================
+// Kernel 1: per-token amax (rowwise + colwise atomicMaxFloat).
+//
+// Same TMA + mbarrier + 64x64 sub-tile + ping-pong pipeline as the encode
+// kernel above, just with compute = abs + reduce instead of FP4 encode.
+//
+// Compute mapping (one thread per output slot):
+//   tid t in [0, 128):
+//     row partial: max over (cols 0..127) for row (row_base + t)
+//     col partial: max over (rows 0..127) for col (col_base + t)
+//   For each 64x64 sub-tile in stage (stage_Y, stage_X):
+//     if t in [stage_Y*64, stage_Y*64+64): scan 64 cols of sub-tile for row t
+//     if t in [stage_X*64, stage_X*64+64): scan 64 rows of sub-tile for col t
+//   After all 4 stages, emit one atomicMaxFloat per row slot + one per col slot.
+//
+// kWithRht=true: col-wise amax over RHT-rotated 16-row strips (per-thread
+// FHT with random_sign_mask_t). Row direction never sees RHT.
+// =============================================================================
+// Returns immediately when noop is non-null and noop[0] == 1.0f. Results are
+// merged into row_amax_out / col_amax_out with atomicMaxFloat, so the caller
+// must initialise those buffers (launch_amax zero-fills them).
+// NOTE(review): the template parameter list and cast target types are missing
+// in this copy of the patch; the body uses DO_ROW, DO_COL and kWithRht --
+// confirm against the upstream file.
+template
+__global__ void __launch_bounds__(THREADS_NUM)
+    per_token_amax_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
+                          float* __restrict__ row_amax_out,  // [M], nullptr if !DO_ROW
+                          float* __restrict__ col_amax_out,  // [K], nullptr if !DO_COL
+                          const float* noop, const size_t rows, const size_t cols,
+                          const uint32_t random_sign_mask_t) {  // col-only; low 16 bits = signs
+#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  if (noop != nullptr && noop[0] == 1.0f) {
+    return;
+  }
+
+  const bool leading_thread = (threadIdx.x == 0);
+  const int tid = threadIdx.x;
+
+  constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE;
+  constexpr int buff_size_aligned_in =
+      DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
+
+  extern __shared__ unsigned char dynamic_shmem[];
+  unsigned char* dshmem = align_smem_ptr_per_TMA_requirements(dynamic_shmem);
+  IType* sIn_ptr = reinterpret_cast(dshmem);
+  auto& sIn = *reinterpret_cast(sIn_ptr);
+
+  __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM];
+  // Byte count announced to the mbarrier for one input-tile TMA load.
+  constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
+
+  const int32_t ctaid_X = blockIdx.x;
+  const int32_t ctaid_Y = blockIdx.y;
+  const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y;
+  const int block_offset_X = ctaid_X * CHUNK_DIM_X;
+
+  // Per-thread row & col partial accumulators (each thread owns 1 of each).
+  float row_partial = 0.f;
+  float col_partial = 0.f;
+
+  // Which row / col does THIS thread own within the 128x128 chunk?
+  //   row owned: row_base + tid  -> needs sub-tile rows [stage_Y*64, +64)
+  //              i.e., this thread contributes to row partial in stages
+  //              where stage_Y == tid / 64.
+  //   col owned: col_base + tid  -> stage_X == tid / 64.
+  const int my_row_stage_Y = tid / TILE_DIM_Y;     // 0 or 1
+  const int my_col_stage_X = tid / TILE_DIM_X;     // 0 or 1
+  const int my_row_in_subtile = tid % TILE_DIM_Y;  // 0..63
+  const int my_col_in_subtile = tid % TILE_DIM_X;  // 0..63
+
+  if (leading_thread) {
+#pragma unroll
+    for (int buff = 0; buff < BUFFS_NUM; ++buff) {
+      ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1);
+    }
+    ptx::fence_proxy_async_shared_cta();
+  }
+  __syncthreads();
+
+  // Prefetch stage 0.
+#pragma unroll
+  for (int stage = 0; stage < PREFETCH_STAGES; ++stage) {
+    const int buff_in = stage;
+    const int stage_Y = stage / TILES_X;
+    const int stage_X = stage % TILES_X;
+    const int global_offset_Y = block_offset_Y + stage_Y * TILE_DIM_Y;
+    const int global_offset_X = block_offset_X + stage_X * TILE_DIM_X;
+    if (leading_thread) {
+      ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size);
+      ptx::cp_async_bulk_tensor_2d_global_to_shared(
+          reinterpret_cast(&sIn[buff_in]),
+          reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y,
+          &IN_buff_readable_mbar[buff_in]);
+    }
+  }
+
+  int buff_in = 0;
+  int IN_buff_readable_parity[BUFFS_NUM] = {0, 0};
+
+#pragma unroll
+  for (int stage = 0; stage < STAGES; ++stage) {
+    const int stage_Y = stage / TILES_X;
+    const int stage_X = stage % TILES_X;
+
+    // Prefetch next stage.
+    // The target buffer was last read in the previous iteration, which ends
+    // with the __syncthreads() at the bottom of this loop.
+    if (stage < STAGES - PREFETCH_STAGES) {
+      const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM;
+      const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES;
+      const int next_stage_Y = next_prefetch_stage / TILES_X;
+      const int next_stage_X = next_prefetch_stage % TILES_X;
+      const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y;
+      const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X;
+      if (leading_thread) {
+        ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size);
+        ptx::cp_async_bulk_tensor_2d_global_to_shared(
+            reinterpret_cast(&sIn[next_prefetch_buff]),
+            reinterpret_cast(&tensor_map_input), next_global_offset_X,
+            next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]);
+      }
+      ptx::fence_proxy_async_shared_cta();
+    }
+
+    // Wait for this stage's tile.
+    ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in],
+                                                     IN_buff_readable_parity[buff_in]);
+    IN_buff_readable_parity[buff_in] ^= 1;
+
+    // ----- Row partial update: walk this thread's row across the sub-tile -----
+    if (DO_ROW && stage_Y == my_row_stage_Y) {
+      // 32 warp lanes each own a distinct row but read col-offset e in lockstep;
+      // SMEM row stride is 64*sizeof(bf16) = 128 B = exactly 32 banks, so every
+      // lane lands on the same bank set -> 32-way bank conflict per LDS.128.
+      // Rotate the e-iter visit order by (my_row_in_subtile >> 2) so that lanes
+      // in distinct row-quads pick distinct e values per iter, splitting the
+      // warp into 8 disjoint bank groups (4-way conflict, 8x reduction).
+      // Per-thread data set unchanged; max() is associative & commutative => byte-equal.
+      float local_max = row_partial;
+      const int row_bank_group = (my_row_in_subtile >> 2) & 0x7;
+#pragma unroll
+      for (int e_iter = 0; e_iter < 8; ++e_iter) {
+        const int e = ((e_iter + row_bank_group) & 0x7) << 3;
+        __uint128_t elts_8x = ptx::ld_shared_b128(&sIn[buff_in][my_row_in_subtile][e]);
+        const IType2* pairs = reinterpret_cast(&elts_8x);
+        IType2 amax_2x = {static_cast(0.0f), static_cast(0.0f)};
+#pragma unroll
+        for (int p = 0; p < 4; ++p) {
+          ptx::abs_max_2x(amax_2x, amax_2x, pairs[p]);
+        }
+        local_max =
+            fmaxf(local_max, static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y))));
+      }
+      row_partial = local_max;
+    }
+
+    // ----- Col partial update: walk this thread's col down the sub-tile -----
+    if (DO_COL && stage_X == my_col_stage_X) {
+      if constexpr (kWithRht) {
+        // 4 contiguous 16-row blocks per sub-tile, one FHT per block; amax
+        // is taken over the rotated values.
+#pragma unroll
+        for (int blk = 0; blk < TILE_DIM_Y / 16; ++blk) {
+          float r[16];
+#pragma unroll
+          for (int i = 0; i < 16; ++i) {
+            r[i] = static_cast(sIn[buff_in][blk * 16 + i][my_col_in_subtile]);
+          }
+          apply_signed_fht16_inplace(r, random_sign_mask_t);
+          col_partial = fmaxf(col_partial, amax_16_abs(r) * k16HadamardNorm);
+        }
+      } else {
+        // Scan 64 rows for our col. Single-column access pattern (1 byte stride
+        // per row in SMEM); we read 1 bf16 at a time. Bank conflicts mitigated
+        // by 64-wide tile (column stride = TILE_DIM_X * 2 = 128 bytes, which is
+        // 1 bank * 32 rows; with 32 threads on different cols, conflicts hit
+        // groups of 32 -> serialized 32-way, accepted for v1).
+        // NOTE(review): lanes of one warp read consecutive columns of the same
+        // row e, so the conflict analysis above may be pessimistic -- confirm
+        // with a profiler before optimising.
+        float local_max = col_partial;
+#pragma unroll
+        for (int e = 0; e < TILE_DIM_Y; ++e) {
+          const IType v = sIn[buff_in][e][my_col_in_subtile];
+          local_max = fmaxf(local_max, fabsf(static_cast(v)));
+        }
+        col_partial = local_max;
+      }
+    }
+
+    // All reads of sIn[buff_in] must finish before it is refilled by a later prefetch.
+    __syncthreads();
+    buff_in = (buff_in + 1) % BUFFS_NUM;
+  }
+
+  // ----- Cross-CTA reduction: 1 atomicMaxFloat per row/col slot per CTA -----
+  if (DO_ROW) {
+    atomicMaxFloat(&row_amax_out[block_offset_Y + tid], row_partial);
+  }
+  if (DO_COL) {
+    atomicMaxFloat(&col_amax_out[block_offset_X + tid], col_partial);
+  }
+
+  if (leading_thread) {
+#pragma unroll
+    for (int buff = 0; buff < BUFFS_NUM; ++buff) {
+      ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]);
+    }
+  }
+#else
+  NVTE_DEVICE_ERROR("Per-token amax kernel requires SM 10.0+ (Blackwell).");
+#endif  // __CUDA_ARCH__ >= 1000
+}
+
+#endif  // FP4_TYPE_SUPPORTED (closes the kernels block opened at line 69)
+
+// =============================================================================
+// Launchers
+// =============================================================================
+
+#if FP4_TYPE_SUPPORTED
+// Launch Kernel 1 (amax). Pre-zeroes the amax buffers (atomicMax identity).
+// with_rht=true applies a 16-pt RHT on the col direction before amax;
+// random_sign_mask_t carries the 16-bit sign pattern (ignored when false).
+//
+// Row / column amax are computed for whichever of output->amax and
+// output->columnwise_amax is allocated; returns without launching when
+// neither is. Shape divisibility is not re-checked here.
+// NOTE(review): the cudaMemsetAsync calls run even when the noop flag makes
+// the kernel return early, so a skipped launch still zeroes the amax buffers
+// -- confirm this is the intended noop behaviour.
+inline void launch_amax(const Tensor& input, Tensor* output, const Tensor& noop,
+                        const bool with_rht, const uint32_t random_sign_mask_t,
+                        cudaStream_t stream) {
+  const size_t M = input.flat_first_dim();
+  const size_t K = input.flat_last_dim();
+
+  const bool do_row = (output->amax.dptr != nullptr);
+  const bool do_col = (output->columnwise_amax.dptr != nullptr);
+  if (!do_row && !do_col) return;
+
+  // Pre-zero amax buffers (atomicMaxFloat identity for non-negative values).
+  if (do_row) {
+    NVTE_CHECK(output->amax.numel() == M, "Per-token amax: output->amax numel must equal M = ", M,
+               ", got ", output->amax.numel());
+    NVTE_CHECK_CUDA(cudaMemsetAsync(output->amax.dptr, 0, M * sizeof(float), stream));
+  }
+  if (do_col) {
+    NVTE_CHECK(output->columnwise_amax.numel() == K,
+               "Per-token amax: output->columnwise_amax numel must equal K = ", K, ", got ",
+               output->columnwise_amax.numel());
+    NVTE_CHECK_CUDA(cudaMemsetAsync(output->columnwise_amax.dptr, 0, K * sizeof(float), stream));
+  }
+
+  checkCuDriverContext(stream);
+
+  alignas(64) CUtensorMap tmap_in{};
+  create_2D_tensor_map(tmap_in, input.data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8);
+
+  constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE;
+  constexpr int buff_size_aligned_in =
+      DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
+  constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT;  // + align pad
+
+  dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(M / CHUNK_DIM_Y), 1);
+  dim3 block(THREADS_NUM, 1, 1);
+
+  const float* noop_ptr =
+      (noop.data.dptr != nullptr) ? reinterpret_cast(noop.data.dptr) : nullptr;
+
+  // RHT only matters when colwise amax is computed; collapse to the
+  // kWithRht=false instantiation otherwise.
+  const bool with_rht_effective = with_rht && do_col;
+  TRANSFORMER_ENGINE_SWITCH_CONDITION(
+      do_row, DO_ROW,
+      TRANSFORMER_ENGINE_SWITCH_CONDITION(
+          do_col, DO_COL, TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, {
+            auto kernel = per_token_amax_kernel;
+            // Surface a failed dynamic-SMEM opt-in here instead of as a later,
+            // less specific launch error.
+            NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
+            kernel<<>>(
+                tmap_in, do_row ? reinterpret_cast(output->amax.dptr) : nullptr,
+                do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr, noop_ptr,
+                M, K, random_sign_mask_t);
+          })));
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
+// Launch K2 encode.
Requires pre-filled amax/columnwise_amax; writes data +
+// scale_inv (both directions). with_rht requires K1 to have run with the
+// SAME mask. with_swizzle: rowwise SF in cuBLAS LT layout (rowwise-only).
+//
+// Direction selection comes from output->has_data() /
+// output->has_columnwise_data(); returns without launching when neither is
+// set. Shape divisibility is not re-checked here.
+// NOTE(review): the kernel template arguments, the <<<...>>> launch
+// configuration and the cast target types are missing in this copy of the
+// patch -- confirm against the upstream file.
+inline void launch_encode(const Tensor& input, Tensor* output, const Tensor& noop,
+                          const bool with_rht, const uint32_t random_sign_mask_t,
+                          const bool with_swizzle, cudaStream_t stream) {
+  const size_t M = input.flat_first_dim();
+  const size_t K = input.flat_last_dim();
+
+  const bool do_row = output->has_data();
+  const bool do_col = output->has_columnwise_data();
+  if (!do_row && !do_col) return;
+
+  if (do_row) {
+    NVTE_CHECK(output->amax.dptr != nullptr,
+               "Per-token encode: output->amax (per-row, [M]) must be pre-filled.");
+    NVTE_CHECK(output->data.dptr != nullptr,
+               "Per-token encode: output->data (rowwise FP4) must be allocated.");
+    NVTE_CHECK(output->scale_inv.dptr != nullptr,
+               "Per-token encode: output->scale_inv must be allocated.");
+  }
+  if (do_col) {
+    NVTE_CHECK(output->columnwise_amax.dptr != nullptr,
+               "Per-token encode: output->columnwise_amax (per-col, [K]) must be pre-filled.");
+    NVTE_CHECK(output->columnwise_data.dptr != nullptr,
+               "Per-token encode: output->columnwise_data must be allocated.");
+    NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
+               "Per-token encode: output->columnwise_scale_inv must be allocated.");
+  }
+
+  checkCuDriverContext(stream);
+
+  alignas(64) CUtensorMap tmap_in{};
+  alignas(64) CUtensorMap tmap_out{};
+  alignas(64) CUtensorMap tmap_out_t{};
+
+  // Tensor maps for a disabled direction stay value-initialised; the kernel
+  // only touches them under the matching DO_ROW / DO_COL branch.
+  create_2D_tensor_map(tmap_in, input.data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8);
+  if (do_row) {
+    create_2D_tensor_map(tmap_out, output->data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, 4);
+  }
+  if (do_col) {
+    create_2D_tensor_map(tmap_out_t, output->columnwise_data, K, M, TILE_DIM_X, TILE_DIM_Y, M, 0,
+                         4);
+  }
+
+  constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE;
+  constexpr int buff_size_aligned_in =
+      DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
+  constexpr int buff_size_aligned_out =
+      DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT);
+  constexpr int buff_size_aligned_out_t =
+      DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT);
+  constexpr int buff_size_scales = DIVUP_TO_MULTIPLE(
+      CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);
+  constexpr int buff_size_scales_t = DIVUP_TO_MULTIPLE(
+      CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);
+
+  // Total dyn SMEM: input + output FP4 (row + col) + SF (row + col) + 128B align.
+  const int dshmem_size = buff_size_aligned_in + (do_row ? buff_size_aligned_out : 0) +
+                          (do_col ? buff_size_aligned_out_t : 0) + (do_row ? buff_size_scales : 0) +
+                          (do_col ? buff_size_scales_t : 0) + TMA_SHMEM_ALIGNMENT;
+
+  dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(M / CHUNK_DIM_Y), 1);
+  dim3 block(THREADS_NUM, 1, 1);
+
+  const float* noop_ptr =
+      (noop.data.dptr != nullptr) ? reinterpret_cast(noop.data.dptr) : nullptr;
+  const size_t scale_stride = do_row ? output->scale_inv.shape[1] : 0;
+  const size_t scale_stride_t = do_col ? output->columnwise_scale_inv.shape[1] : 0;
+
+  nvfp4_scale_t* scales_ptr =
+      do_row ? reinterpret_cast(output->scale_inv.dptr) : nullptr;
+  nvfp4_scale_t* scales_t_ptr =
+      do_col ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr;
+  const float* row_amax_in = do_row ? reinterpret_cast(output->amax.dptr) : nullptr;
+  const float* col_amax_in =
+      do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr;
+
+  // RHT only matters with colwise FP4 -> collapse to kWithRht=false for
+  // rowwise-only callers; swizzle only matters with rowwise FP4 ->
+  // collapse to kWithSwizzle=false for colwise-only callers.
+  const bool with_rht_effective = with_rht && do_col;
+  const bool with_swizzle_effective = with_swizzle && do_row;
+  TRANSFORMER_ENGINE_SWITCH_CONDITION(
+      do_row, DO_ROW,
+      TRANSFORMER_ENGINE_SWITCH_CONDITION(
+          do_col, DO_COL,
+          TRANSFORMER_ENGINE_SWITCH_CONDITION(
+              with_rht_effective, kWithRht,
+              TRANSFORMER_ENGINE_SWITCH_CONDITION(with_swizzle_effective, kWithSwizzle, {
+                auto kernel = per_token_encode_kernel;
+                // Surface a failed dynamic-SMEM opt-in here instead of as a
+                // later, less specific launch error.
+                NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
+                kernel<<>>(
+                    tmap_in, tmap_out, tmap_out_t, scales_ptr, scales_t_ptr, row_amax_in,
+                    col_amax_in, noop_ptr, M, K, scale_stride, scale_stride_t, random_sign_mask_t);
+              }))));
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+#endif  // FP4_TYPE_SUPPORTED
+
+// =============================================================================
+// Impls (validation + dispatch). The K1 amax / K2 encode passes are exposed
+// as separately callable entry points alongside the composite K1+K2 entry,
+// to enable per-kernel benchmarking and diagnostic use.
+// =============================================================================
+
+#if FP4_TYPE_SUPPORTED
+// Common input + shape validation, shared by all 3 entry points.
+// Output constraints differ by entry point (see validate_*_output helpers below).
+// Fails via NVTE_CHECK when the input has no data, is not bf16, or has M / K
+// that is not a multiple of CHUNK_DIM_Y / CHUNK_DIM_X. A zero dimension passes
+// the modulo checks; callers handle empty inputs separately.
+inline void validate_input_shape(const Tensor& input) {
+  NVTE_CHECK(input.has_data(), "Per-token cast: input has no data.");
+  NVTE_CHECK(input.dtype() == DType::kBFloat16, "Per-token cast is bf16-only. Got dtype enum ",
+             static_cast(input.dtype()));
+  const size_t M = input.flat_first_dim();
+  const size_t K = input.flat_last_dim();
+  NVTE_CHECK(M % CHUNK_DIM_Y == 0, "Per-token cast: M must be a multiple of ", CHUNK_DIM_Y,
+             ", got M=", M);
+  NVTE_CHECK(K % CHUNK_DIM_X == 0, "Per-token cast: K must be a multiple of ", CHUNK_DIM_X,
+             ", got K=", K);
+}
+
+// K1 (amax-only) requires at least one amax buffer allocated; FP4 output is not used.
+// Fails via NVTE_CHECK when both output->amax.dptr and
+// output->columnwise_amax.dptr are null.
+inline void validate_amax_output(const Tensor* output) {
+  NVTE_CHECK(output->amax.dptr != nullptr || output->columnwise_amax.dptr != nullptr,
+             "Per-token K1 (amax): at least one of rowwise/columnwise amax buffer "
+             "must be allocated.");
+}
+
+// K2 / composite require >=1 FP4 output buffer. with_swizzle=true: rowwise
+// SF emitted in cuBLAS LT swizzled layout (caller sets with_gemm_swizzled_scales).
+// Only the with_swizzle=false case is cross-checked against
+// output->with_gemm_swizzled_scales; with_swizzle=true is accepted whatever
+// that flag holds.
+inline void validate_encode_output(const Tensor* output, const bool with_swizzle) {
+  NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
+             "Per-token K2 (encode): at least one of rowwise/columnwise FP4 output "
+             "must be allocated.");
+  if (!with_swizzle) {
+    NVTE_CHECK(!output->with_gemm_swizzled_scales,
+               "Per-token cast emits compact (non-swizzled) inner SF unless "
+               "with_swizzle=true is passed.");
+  }
+}
+
+// K1 amax with optional col-wise RHT. with_rht=false is byte-equal to the
+// pre-RHT per-token K1 path regardless of random_sign_mask_t.
+// Validation runs before the empty-input early return, so a malformed call
+// fails even when there is nothing to launch.
+void per_token_amax_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output,
+                                 const bool with_rht, const uint32_t random_sign_mask_t,
+                                 cudaStream_t stream) {
+  validate_input_shape(input);
+  validate_amax_output(output);
+  // Nothing to reduce for an empty tensor; also avoids a zero-sized grid.
+  if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return;
+  launch_amax(input, output, noop, with_rht, random_sign_mask_t, stream);
+}
+
+// K2 encode with optional col-wise RHT + fused rowwise swizzle. with_rht
+// requires K1 amax to have been launched with the SAME mask, else the inner
+// SF + FP4 codes are calibrated against mismatched data and saturate.
+void per_token_encode_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, + const bool with_rht, const uint32_t random_sign_mask_t, + const bool with_swizzle, cudaStream_t stream) { + validate_input_shape(input); + validate_encode_output(output, with_swizzle); + if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; + launch_encode(input, output, noop, with_rht, random_sign_mask_t, with_swizzle, stream); +} + +// Composite K1+K2. Both launches receive the same with_rht / mask so the +// colwise amax and FP4 cast see byte-identical data. +// Validation is the same as the K2 entry point (validate_input_shape + +// validate_encode_output); validate_amax_output is not called on this path. +// NOTE(review): confirm launch_amax tolerates a missing amax buffer, or that +// every caller of the composite path allocates one. +void per_token_quantize_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, + const bool with_rht, const uint32_t random_sign_mask_t, + const bool with_swizzle, cudaStream_t stream) { + validate_input_shape(input); + validate_encode_output(output, with_swizzle); + // Empty tensor: skip both launches. + if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; + // K1 then K2 are issued back to back on the same stream. + launch_amax(input, output, noop, with_rht, random_sign_mask_t, stream); + launch_encode(input, output, noop, with_rht, random_sign_mask_t, with_swizzle, stream); +} + +// Non-failing counterpart of validate_input_shape: true only for bf16 input +// with M % CHUNK_DIM_Y == 0 and K % CHUNK_DIM_X == 0. +bool can_use_per_token(size_t M, size_t K, DType dtype) { + return (dtype == DType::kBFloat16) && (M % CHUNK_DIM_Y == 0) && (K % CHUNK_DIM_X == 0); +} +#else // !FP4_TYPE_SUPPORTED +// Build without FP4 support: the three impl entry points only report an error +// via NVTE_ERROR, and can_use_per_token always returns false. +void per_token_amax_blocked_impl(const Tensor&, const Tensor&, Tensor*, bool, uint32_t, + cudaStream_t) { + NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); +} +void per_token_encode_blocked_impl(const Tensor&, const Tensor&, Tensor*, bool, uint32_t, bool, + cudaStream_t) { + NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); +} +void per_token_quantize_blocked_impl(const Tensor&, const Tensor&, Tensor*, bool, uint32_t, bool, + cudaStream_t) { + NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); +} +bool can_use_per_token(size_t, size_t, DType) { return false; } +#endif // FP4_TYPE_SUPPORTED + +} // namespace nvfp4_per_token 
+} // namespace transformer_engine + +// ============================================================================= +// C-API entry points +// ============================================================================= + +// C entry point for K1 (amax only). Converts the NVTETensor handles with +// convertNVTETensorCheck, substitutes a default-constructed Tensor when `noop` +// is null, and forwards to per_token_amax_blocked_impl. `with_rht` is treated +// as a boolean (non-zero = true) and `random_sign_mask_t` is masked to its low +// 16 bits. Without FP4_TYPE_SUPPORTED the call only reports NVTE_ERROR. +void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NVTETensor output, + const int with_rht, const int random_sign_mask_t, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_nvfp4_per_token_amax); + using namespace transformer_engine; + const Tensor* input_tensor = convertNVTETensorCheck(input); + Tensor* output_tensor = convertNVTETensorCheck(output); + Tensor dummy_noop; + const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; + // C-API takes `int` to match prod's nvte_hadamard_transform_amax convention; + // internally we treat the low 16 bits as a uint32_t bitmask. + nvfp4_per_token::per_token_amax_blocked_impl( + *input_tensor, *noop_tensor, output_tensor, with_rht != 0, + static_cast(random_sign_mask_t) & 0xFFFFu, stream); +#else + (void)input; + (void)noop; + (void)output; + (void)with_rht; + (void)random_sign_mask_t; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} + +// C entry point for K2 (encode only). Same handle conversion and null-noop +// fallback as nvte_nvfp4_per_token_amax; `with_rht` and `with_swizzle` are +// treated as booleans (non-zero = true) and the mask keeps its low 16 bits. +void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output, + const int with_rht, const int random_sign_mask_t, + const int with_swizzle, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_nvfp4_per_token_encode); + using namespace transformer_engine; + const Tensor* input_tensor = convertNVTETensorCheck(input); + Tensor* output_tensor = convertNVTETensorCheck(output); + Tensor dummy_noop; + const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; + // C-API mirrors nvte_nvfp4_per_token_amax: `int` for cross-language ABI + // safety, internal kernel arg is uint32_t with only the low 16 bits used. 
+ nvfp4_per_token::per_token_encode_blocked_impl( + *input_tensor, *noop_tensor, output_tensor, with_rht != 0, + static_cast(random_sign_mask_t) & 0xFFFFu, with_swizzle != 0, stream); +#else + (void)input; + (void)noop; + (void)output; + (void)with_rht; + (void)random_sign_mask_t; + (void)with_swizzle; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} +/* C entry point for the composite K1+K2 path: converts the handles, falls back to a default-constructed Tensor for a null noop, and forwards to per_token_quantize_blocked_impl with with_rht / with_swizzle as booleans and the mask reduced to its low 16 bits. */ +void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output, + const int with_rht, const int random_sign_mask_t, + const int with_swizzle, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_nvfp4_per_token_quantize); + using namespace transformer_engine; + const Tensor* input_tensor = convertNVTETensorCheck(input); + Tensor* output_tensor = convertNVTETensorCheck(output); + Tensor dummy_noop; + const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; + nvfp4_per_token::per_token_quantize_blocked_impl( + *input_tensor, *noop_tensor, output_tensor, with_rht != 0, + static_cast(random_sign_mask_t) & 0xFFFFu, with_swizzle != 0, stream); +#else + (void)input; + (void)noop; + (void)output; + (void)with_rht; + (void)random_sign_mask_t; + (void)with_swizzle; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} +/* Returns 1 when can_use_per_token(M, K, dtype) holds and 0 otherwise; input_dtype_enum is cast to DType with no range check in this function. */ +int nvte_nvfp4_per_token_can_dispatch(size_t M, size_t K, int input_dtype_enum) { + using namespace transformer_engine; + const DType dtype = static_cast(input_dtype_enum); + return nvfp4_per_token::can_use_per_token(M, K, dtype) ? 
1 : 0; +} diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu new file mode 100644 index 0000000000..a0be8d184a --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu @@ -0,0 +1,1123 @@ +/************************************************************************* + * Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_nvfp4_per_token_group.cu + * \brief Grouped NVFP4 per-token cast: bf16 input (sum_M, K), splits along + * M; K1 fused row+col amax + K2 row + col cast. Requires K % 128 == 0 + * and every split_sections[i] % 128 == 0. + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include "common/cast/core/common.cuh" +#include "common/cast/nvfp4/core_nvfp4.cuh" +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { +namespace nvfp4_per_token_group { + +#if FP4_TYPE_SUPPORTED + +using dispatch::nvfp4::nvfp4_scale_t; +using dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; +using dispatch::nvfp4::quantization_SF::compute_decoding_scaling_factor; +using ptx::FPx2; + +constexpr int kInnerK = 16; // NVFP4 inner block: 16 elements per e4m3 SF + +// 64-tensor cap so the args struct fits under the 4 KB launch-param limit. +constexpr int kMaxTensorsPerKernel = 64; + +// Per-launch arg table; passed as __grid_constant__ for constant-cache reads. 
+struct NVFP4PerTokenMultiArgs { + // K1 outputs (per-tensor pointers; one fp32 array per tensor) + void* row_amax_list[kMaxTensorsPerKernel]; // each: float* (M_i,) + void* col_amax_list[kMaxTensorsPerKernel]; // each: float* (K,) + + // K2 outputs (per-tensor pointers; FP4 codes + e4m3 inner SF) + void* q_row_list[kMaxTensorsPerKernel]; // each: uint8* (M_i, K/2) + void* s_dec_row_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (M_i, K/16) + void* q_col_list[kMaxTensorsPerKernel]; // each: uint8* (K, M_i/2) + void* s_dec_col_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (K, M_i/16) + + // Shared layout info + int split_sections_range[kMaxTensorsPerKernel + 1]; // prefix sum w/ leading 0 + int num_tensors; // tensors in this launch; presumably <= kMaxTensorsPerKernel (TODO confirm at the launcher) +}; + +// Maps a global row index to the tensor whose row range +// [split_sections_range[i], split_sections_range[i + 1]) contains it, using a +// linear scan over the prefix sums. Rows at or beyond the last boundary are +// clamped to the final tensor (num_tensors - 1). Zero-length sections are +// skipped because the scan advances while the next boundary is <= global_row. +__device__ __forceinline__ int GetTensorId(const NVFP4PerTokenMultiArgs& args, int global_row) { + const int n = args.num_tensors; + // Clamp first so the scan below always terminates before index n. + if (global_row >= args.split_sections_range[n]) return n - 1; + int tid = 0; + while (args.split_sections_range[tid + 1] <= global_row) ++tid; + return tid; +} + +// Fused K1: TMA-loaded SMEM tile feeds row+col amax; routes atomicMax to the +// per-tensor buffer via tensor_id lookup at CTA entry. 
+namespace fused { + +constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows +constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols +constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height +constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width +constexpr int THREADS_NUM = 128; +constexpr int PREFETCH_STAGES = 1; +constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; +constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; // 2 +constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X; // 2 +constexpr int STAGES = TILES_Y * TILES_X; // 4 + +constexpr int BUFF_IN_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_IN_DIM_X = TILE_DIM_X; +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; // 64 * 64 = 4096 elements per buffer + +using FusedIType = bf16; +using FusedIType2 = ptx::FPx2; +using FusedIType3D = FusedIType[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; + +// Randomized Hadamard Transform helpers (per-thread, 16-wide). Direct copy +// of the single-tensor helpers in quantize_nvfp4_per_token.cu; K1 and K2 +// must consume identical output for FP4 + outer SF to be self-consistent. +// TODO: hoist into a shared core header. +// Step 1 flips the sign of r[i] when bit i of random_sign_mask is set (XOR on +// the fp32 sign bit). Step 2 runs an in-place 16-point butterfly with strides +// 1, 2, 4, 8, replacing each pair (a, b) by (a + b, a - b). No normalization +// is applied here; the callers in this file scale by k16HadamardNorm. +__device__ __forceinline__ void apply_signed_fht16_inplace(float r[16], uint32_t random_sign_mask) { +#pragma unroll + for (int i = 0; i < 16; ++i) { + const uint32_t bits = __float_as_uint(r[i]); + const uint32_t flip = ((random_sign_mask >> i) & 1u) << 31; + r[i] = __uint_as_float(bits ^ flip); + } +#pragma unroll + for (int stride = 1; stride < 16; stride <<= 1) { +#pragma unroll + for (int g = 0; g < 16; g += stride << 1) { +#pragma unroll + for (int j = 0; j < stride; ++j) { + const float a = r[g + j]; + const float b = r[g + j + stride]; + r[g + j] = a + b; + r[g + j + stride] = a - b; + } + } + } +} + +// Returns the largest |r[i]| over the 16 entries, starting from 0. +__device__ __forceinline__ float amax_16_abs(const float r[16]) { + float m = 0.f; +#pragma unroll + for (int i = 0; i < 16; ++i) m = fmaxf(m, fabsf(r[i])); + return m; +} + +// 1/sqrt(16) Hadamard normalization, folded once per 1x16 block. 
+constexpr float k16HadamardNorm = 0.25f; + +// Pre-zero amax buffers (identity for atomicMax). +// One CTA per tensor (tensor_id = blockIdx.x); its threads stride by blockDim.x +// over the M_i row slots and the K column slots. CTAs beyond num_tensors and +// null list entries are skipped. +template +__global__ void group_per_token_fused_zero_amax_kernel(NVFP4PerTokenMultiArgs args, int K) { + const int tensor_id = blockIdx.x; + if (tensor_id >= args.num_tensors) return; + if (DO_ROW) { + float* row_amax = reinterpret_cast(args.row_amax_list[tensor_id]); + if (row_amax != nullptr) { + const int M_i = + args.split_sections_range[tensor_id + 1] - args.split_sections_range[tensor_id]; + for (int m = threadIdx.x; m < M_i; m += blockDim.x) { + row_amax[m] = 0.0f; + } + } + } + if (DO_COL) { + float* col_amax = reinterpret_cast(args.col_amax_list[tensor_id]); + if (col_amax != nullptr) { + for (int k = threadIdx.x; k < K; k += blockDim.x) { + col_amax[k] = 0.0f; + } + } + } +} + +// kWithRht=true: col-wise amax over RHT-rotated 16-row strips. Row direction +// never sees RHT. +// Each CTA handles the CHUNK_DIM_Y x CHUNK_DIM_X chunk at (blockIdx.y, +// blockIdx.x), streamed as STAGES sub-tiles of TILE_DIM_Y x TILE_DIM_X through +// BUFFS_NUM shared-memory buffers filled by bulk tensor copies. Thread `tid` +// accumulates the amax of chunk-row `tid` and chunk-column `tid`, then merges +// both into the per-tensor buffers with atomicMaxFloat. The kernel returns +// immediately when noop is non-null and noop[0] == 1.0f. +// NOTE(review): unlike the zeroing kernel above, the amax list pointers are not +// null-checked before the atomics; confirm the launcher guarantees them. +template +__global__ void __launch_bounds__(THREADS_NUM) + group_per_token_fused_amax_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ NVFP4PerTokenMultiArgs args, + const float* noop, const size_t rows, const size_t cols, + const uint32_t random_sign_mask_t) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const bool leading_thread = (threadIdx.x == 0); + const int tid = threadIdx.x; + + constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); + + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char* dshmem = dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + FusedIType* sIn_ptr = reinterpret_cast(dshmem); + auto& sIn = *reinterpret_cast(sIn_ptr); + + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const int32_t ctaid_X = blockIdx.x; + const 
int32_t ctaid_Y = blockIdx.y; + const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * CHUNK_DIM_X; + + // Tile lies fully inside one tensor (split_sections[i] % 128 == 0). + const int tensor_id = GetTensorId(args, block_offset_Y); + const int local_row_base = block_offset_Y - args.split_sections_range[tensor_id]; + float* row_amax_out = DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; + float* col_amax_out = DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; + + // Each thread owns chunk-row `tid` (for row amax) and chunk-col `tid` (for col amax). + float row_partial = 0.f; + float col_partial = 0.f; + const int my_row_stage_Y = tid / TILE_DIM_Y; + const int my_col_stage_X = tid / TILE_DIM_X; + const int my_row_in_subtile = tid % TILE_DIM_Y; + const int my_col_in_subtile = tid % TILE_DIM_X; + + // Thread 0 sets up one mbarrier per input buffer; the __syncthreads() below + // orders that setup before any other thread waits on a barrier. + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + // Prefetch stage 0. +#pragma unroll + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const int buff_in = stage; + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + const int global_offset_Y = block_offset_Y + stage_Y * TILE_DIM_Y; + const int global_offset_X = block_offset_X + stage_X * TILE_DIM_X; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&sIn[buff_in]), + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + &IN_buff_readable_mbar[buff_in]); + } + } + + int buff_in = 0; + int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + + // Prefetch next stage. 
+ if (stage < STAGES - PREFETCH_STAGES) { + const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + const int next_stage_Y = next_prefetch_stage / TILES_X; + const int next_stage_X = next_prefetch_stage % TILES_X; + const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; + const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&sIn[next_prefetch_buff]), + reinterpret_cast(&tensor_map_input), next_global_offset_X, + next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]); + } + ptx::fence_proxy_async_shared_cta(); + } + + // Wait for this stage's tile. + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + + // Row partial: rotate e-iter by bank group to split warp into 8 groups. + // A thread only contributes on the sub-tiles whose stage_Y holds its row; + // each of the 8 iterations reads 8 bf16 with one 128-bit shared load, which + // together cover the 64 columns of the sub-tile. + if (DO_ROW && stage_Y == my_row_stage_Y) { + float local_max = row_partial; + const int row_bank_group = (my_row_in_subtile >> 2) & 0x7; +#pragma unroll + for (int e_iter = 0; e_iter < 8; ++e_iter) { + const int e = ((e_iter + row_bank_group) & 0x7) << 3; + __uint128_t elts_8x = ptx::ld_shared_b128(&sIn[buff_in][my_row_in_subtile][e]); + const FusedIType2* pairs = reinterpret_cast(&elts_8x); + FusedIType2 amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int p = 0; p < 4; ++p) { + ptx::abs_max_2x(amax_2x, amax_2x, pairs[p]); + } + local_max = + fmaxf(local_max, static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); + } + row_partial = local_max; + } + + // Col partial: 1 thread per column scans down 64 rows of the sub-tile. 
+ if (DO_COL && stage_X == my_col_stage_X) { + if constexpr (kWithRht) { + // 4 contiguous 16-row blocks per sub-tile, one FHT per block; 0.25 + // is folded post-amax (exact, since 0.25 = 2^-2). +#pragma unroll + for (int blk = 0; blk < TILE_DIM_Y / 16; ++blk) { + float r[16]; +#pragma unroll + for (int i = 0; i < 16; ++i) { + r[i] = static_cast(sIn[buff_in][blk * 16 + i][my_col_in_subtile]); + } + apply_signed_fht16_inplace(r, random_sign_mask_t); + col_partial = fmaxf(col_partial, amax_16_abs(r) * k16HadamardNorm); + } + } else { + float local_max = col_partial; +#pragma unroll + for (int e = 0; e < TILE_DIM_Y; ++e) { + const FusedIType v = sIn[buff_in][e][my_col_in_subtile]; + local_max = fmaxf(local_max, fabsf(static_cast(v))); + } + col_partial = local_max; + } + } + + // Every thread must finish reading sIn[buff_in] before a later prefetch + // refills that buffer. + __syncthreads(); + buff_in = (buff_in + 1) % BUFFS_NUM; + } + + // CTAs across (ctaid_X) share row_amax slots; across (ctaid_Y) share col_amax slots. + if (DO_ROW) { + atomicMaxFloat(&row_amax_out[local_row_base + tid], row_partial); + } + if (DO_COL) { + atomicMaxFloat(&col_amax_out[block_offset_X + tid], col_partial); + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + } +#else + (void)tensor_map_input; + (void)args; + (void)noop; + (void)rows; + (void)cols; + (void)random_sign_mask_t; + NVTE_DEVICE_ERROR("Fused grouped per-token amax kernel requires SM 10.0+ (Blackwell)."); +#endif // __CUDA_ARCH__ >= 1000 +} + +// K2 (encode) constants + helpers; byte-equal port of the single-tensor +// per-token cooperative 4x32 / 32x4 threading + ld_shared_b128 + mul_cvt_4x. 
+constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM +constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) +constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; // 8 +constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; // 8 +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 + +// Rowwise pass: 4 (K-dim) x 32 (M-dim) -> 1 NVFP4 block per thread. +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 +constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 + +// Output / SF SMEM buffer dims (sub-tile sized, double-buffered for ping-pong). +constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (fp4e2m1x2 bytes) +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; // 64 * 32 = 2048 +constexpr int BUFF_OUT_TR_DIM_Y = TILE_DIM_X; +constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; // 64 * 32 = 2048 +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 +constexpr int BUFFS_NUM_OUT_TR = 2; + +// Manual SMEM swizzling parameters (matches single-tensor encode kernel). 
+constexpr int PACK_SIZE = 8; +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 + +using IType = FusedIType; +using IType2 = FusedIType2; +using IType2x3D = IType2[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; + +// Rowwise encode helper: reads sRowAmax (pre-populated by K1), writes FP4 + +// e4m3 SFs into sOut / sSFrowwise. Byte-equal to the single-tensor version. +// Per thread and iteration: load 16 contiguous elements of one sub-tile row +// (two 128-bit shared loads in bank-rotated order), reduce them to a block +// amax, derive the e4m3 decode scale from that amax and the row's encode scale +// S_enc (row amax clamped below at 1e-12), store the scale at +// [chunk_local_row][stage_X * SCALES_PER_TILE_X + tid_X], then convert the 16 +// values to FP4 using S_enc / s_dec (0 when s_dec is 0) and write them to sOut. +__device__ __forceinline__ void rowwise_scaling_per_token( + const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_ptr, + nvfp4_scale_t* __restrict__ sSFrowwise_ptr, const float* __restrict__ sRowAmax, + const int stage_Y, const int stage_X, const int buff_in, const int buff_out) { + const auto& sIn = *reinterpret_cast(sIn_ptr); + auto& sOut = *reinterpret_cast(sOut_ptr); + auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 + + const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; + + const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; + const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); + + const int stage_rowwise_scales_offset_X = + SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X; + +#pragma unroll + for (int it = 0; it < ITERATIONS_NORMAL; ++it) { + const int it_offset_Y_rowwise = tid_Y_rowwise + it * 
THREADS_Y_ROWWISE; + const int chunk_local_row = stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; + + const float row_amax = sRowAmax[chunk_local_row]; + const float S_enc = compute_global_encode_scaling_factor_FP4(fmaxf(row_amax, 1e-12f)); + + __align__(16) IType2 rIn[WAVES][PACK_SIZE / 2]; + + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + + __uint128_t& elts_8x = *reinterpret_cast<__uint128_t*>(&rIn[w]); + elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); + } + } + const float block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + + const fp8e4m3 s_dec = compute_decoding_scaling_factor(block_amax, S_enc); + const float s_dec_f = static_cast(s_dec); + const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc, s_dec_f); + + if (SF_storing_thread) { + const int scales_offset_Y = chunk_local_row; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = s_dec; + } + +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + + fp4e2m1x4 qu0{}, qu1{}; + ptx::mul_cvt_4x(qu0, rIn[w][0], rIn[w][1], block_scale); + ptx::mul_cvt_4x(qu1, rIn[w][2], rIn[w][3], block_scale); + + uint32_t out_x8 = (static_cast(*reinterpret_cast(&qu0))) | + (static_cast(*reinterpret_cast(&qu1)) << 16); + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + } + } +} + +// Colwise encode helper. 
kWithRht=true rotates each thread's 16-row strip +// via the FHT before block_amax + cast; K1 amax must have used the same +// mask so the per-col outer amax matches. +// Each thread handles two adjacent columns (w = 0, 1) of one 16-row strip of +// the sub-tile: it reads one element pair per row, forms a per-column block +// amax (after the FHT and the 0.25 normalisation when kWithRht), derives the +// e4m3 decode scale against that column's encode scale (column amax clamped +// below at 1e-12), stores the scale in sSFcolwise, and writes the 16 FP4 codes +// of each column as one 64-bit word into the transposed output buffer. +template +__device__ __forceinline__ void colwise_scaling_per_token( + const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_tr_ptr, + nvfp4_scale_t* __restrict__ sSFcolwise_ptr, const float* __restrict__ sColAmax, + const int stage_Y, const int stage_X, const int buff_in, const int buff_out_tr, + const uint32_t random_sign_mask_t = 0u) { + const auto& sIn2x = *reinterpret_cast(sIn_ptr); + auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); + auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 + const int tid_X_colwise = thread_lane; // 0..31 + + const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; + const int thread_offset_X_colwise = tid_X_colwise * 2; + + const int in_thread_offset_Y = thread_offset_Y_colwise; + const int in_thread_offset_X = thread_offset_X_colwise / 2; + + const int out_tr_thread_offset_Y = thread_offset_X_colwise; + const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; + + const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; + const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; + + __align__(8) IType rIn[2][SCALE_DIM]; + // RHT staging in fp32 (DCE'd in the non-RHT instantiation). 
+ float rRht[2][SCALE_DIM]; + + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const IType2 elt_pair = + ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); + rIn[0][i] = elt_pair.x; + rIn[1][i] = elt_pair.y; + if constexpr (!kWithRht) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + } + } + + float block_amax[2]; + if constexpr (kWithRht) { +#pragma unroll + for (int w = 0; w < 2; ++w) { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + rRht[w][i] = static_cast(rIn[w][i]); + } + apply_signed_fht16_inplace(rRht[w], random_sign_mask_t); + float local_max = 0.f; +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + local_max = fmaxf(local_max, fabsf(rRht[w][i])); + } + // amax(|r * 0.25|) == amax(|r|) * 0.25; 0.25 also folded into + // block_scale_rht below (bit-exact: 0.25 = 2^-2). + block_amax[w] = local_max * k16HadamardNorm; + } + } else { + block_amax[0] = static_cast(__habs(thread_amax_2x.x)); + block_amax[1] = static_cast(__habs(thread_amax_2x.y)); + } + +#pragma unroll + for (int w = 0; w < 2; ++w) { + const int chunk_local_col = scale_tr_offset_Y + w; + const float col_amax = sColAmax[chunk_local_col]; + const float S_enc_col = compute_global_encode_scaling_factor_FP4(fmaxf(col_amax, 1e-12f)); + + const fp8e4m3 s_dec = compute_decoding_scaling_factor(block_amax[w], S_enc_col); + const float s_dec_f = static_cast(s_dec); + const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc_col, s_dec_f); + + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = s_dec; + + fp4e2m1x4 qu[4]; + if constexpr (kWithRht) { + // ptx::floatx2 keeps mul_cvt_4x's input fp32 (no bf16 round-trip). 
+ const float block_scale_rht = block_scale * k16HadamardNorm; +#pragma unroll + for (int e = 0; e < 4; ++e) { + const ptx::floatx2 in01{rRht[w][4 * e + 0], rRht[w][4 * e + 1]}; + const ptx::floatx2 in23{rRht[w][4 * e + 2], rRht[w][4 * e + 3]}; + ptx::mul_cvt_4x(qu[e], in01, in23, block_scale_rht); + } + } else { +#pragma unroll + for (int e = 0; e < 4; ++e) { + IType2 in01{rIn[w][4 * e + 0], rIn[w][4 * e + 1]}; + IType2 in23{rIn[w][4 * e + 2], rIn[w][4 * e + 3]}; + ptx::mul_cvt_4x(qu[e], in01, in23, block_scale); + } + } + + uint64_t out_pack_16x = (static_cast(*reinterpret_cast(&qu[0])) << 0) | + (static_cast(*reinterpret_cast(&qu[1])) << 16) | + (static_cast(*reinterpret_cast(&qu[2])) << 32) | + (static_cast(*reinterpret_cast(&qu[3])) << 48); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], + out_pack_16x); + } +} + +// Fused K2: TMA-loads input, runs cooperative row+col encode helpers, scatters +// FP4 + SFs to per-tensor outputs via st.global (multi-dest, no TMA store). +// kWithRht=true (and DO_COL=true): col-wise FHT with random_sign_mask_t, +// matching the K1 amax launch. Row direction never sees RHT. +template +__global__ void __launch_bounds__(THREADS_NUM) + group_per_token_fused_cast_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ NVFP4PerTokenMultiArgs args, + const float* noop, const size_t rows, const size_t cols, + const uint32_t random_sign_mask_t) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + (void)rows; + + const bool leading_thread = (threadIdx.x == 0); + + // Dynamic SMEM layout (~28 KiB): sIn (16K) + sOut (4K) + sOut_tr (4K) + + // sSF_row (1K) + sSF_col (1K) + sRowAmax/sColAmax (512B each). 
+ constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int out_mem_rowwise_data = DO_ROW ? buff_size_aligned_out : 0; + constexpr int out_mem_colwise_data = DO_COL ? buff_size_aligned_out_t : 0; + constexpr int out_mem_rowwise_scales = + DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int out_mem_colwise_scales = + DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + (void)out_mem_colwise_scales; + + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char* dshmem = dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + + IType* sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); + fp4e2m1x2* sOut_tr_ptr = + reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data); + nvfp4_scale_t* sSFrowwise_ptr = reinterpret_cast( + dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t* sSFcolwise_ptr = + reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data + + out_mem_colwise_data + out_mem_rowwise_scales); + + __shared__ float sRowAmax[CHUNK_DIM_Y]; + __shared__ float sColAmax[CHUNK_DIM_X]; + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + + auto& sIn = *reinterpret_cast(sIn_ptr); + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const int32_t ctaid_X = blockIdx.x; + const int32_t ctaid_Y = blockIdx.y; + const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; + const int block_offset_X = ctaid_X 
* CHUNK_DIM_X; + + // Chunk Y stays inside one tensor (split_sections[i] % 128 == 0). + const int tensor_id = GetTensorId(args, block_offset_Y); + const int local_row_base = block_offset_Y - args.split_sections_range[tensor_id]; + const int M_t = args.split_sections_range[tensor_id + 1] - args.split_sections_range[tensor_id]; + + // Per-tensor output bases (one constant-cache lookup per CTA). + uint8_t* const q_row_base = + DO_ROW ? reinterpret_cast(args.q_row_list[tensor_id]) : nullptr; + uint8_t* const q_col_base = + DO_COL ? reinterpret_cast(args.q_col_list[tensor_id]) : nullptr; + nvfp4_scale_t* const s_dec_row_base = + DO_ROW ? reinterpret_cast(args.s_dec_row_list[tensor_id]) : nullptr; + nvfp4_scale_t* const s_dec_col_base = + DO_COL ? reinterpret_cast(args.s_dec_col_list[tensor_id]) : nullptr; + const float* const row_amax_base = + DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; + const float* const col_amax_base = + DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; + + const size_t data_stride_row = static_cast(cols) / 2; + const size_t data_stride_col = static_cast(M_t) / 2; + const size_t scale_stride_row = static_cast(cols) / SCALE_DIM; + const size_t scale_stride_col = static_cast(M_t) / SCALE_DIM; + + // Load per-row / per-col amax into SMEM cache. + if (DO_ROW && threadIdx.x < CHUNK_DIM_Y) { + sRowAmax[threadIdx.x] = row_amax_base[local_row_base + threadIdx.x]; + } + if (DO_COL && threadIdx.x < CHUNK_DIM_X) { + sColAmax[threadIdx.x] = col_amax_base[block_offset_X + threadIdx.x]; + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + // Prefetch stage 0. 
+#pragma unroll + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const int buff_in_p = stage; + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + const int global_offset_Y = block_offset_Y + stage_Y * TILE_DIM_Y; + const int global_offset_X = block_offset_X + stage_X * TILE_DIM_X; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in_p], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&sIn[buff_in_p]), + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + &IN_buff_readable_mbar[buff_in_p]); + } + } + + int buff_in = 0; + int buff_out = 0; + int buff_out_tr = 0; + int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + + if (stage < STAGES - PREFETCH_STAGES) { + const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + const int next_stage_Y = next_prefetch_stage / TILES_X; + const int next_stage_X = next_prefetch_stage % TILES_X; + const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; + const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&sIn[next_prefetch_buff]), + reinterpret_cast(&tensor_map_input), next_global_offset_X, + next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]); + } + ptx::fence_proxy_async_shared_cta(); + } + + // Wait for current stage's input tile to land. 
+ ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + + // 4x32 cooperative row + col encode helpers. + if (DO_ROW) { + rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, sRowAmax, stage_Y, stage_X, + buff_in, buff_out); + } + if (DO_COL) { + colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, sColAmax, stage_Y, + stage_X, buff_in, buff_out_tr, random_sign_mask_t); + } + + // Make helper SMEM writes visible before the scatter epilogue. + __syncthreads(); + + // Scatter sOut / sOut_tr to per-tensor buffers via cooperative b128 stores; + // 2 threads per row/col x 16 B = 2048 B per sub-tile per direction. + if (DO_ROW) { + auto& sOut = *reinterpret_cast(sOut_ptr); + const int row_in_subtile = static_cast(threadIdx.x) >> 1; // 0..63 + const int half = static_cast(threadIdx.x) & 1; // 0..1 + const int local_row = local_row_base + stage_Y * TILE_DIM_Y + row_in_subtile; + const int byte_off_X = (block_offset_X / 2) + stage_X * (TILE_DIM_X / 2) + half * 16; + const uint4* src = reinterpret_cast(&sOut[buff_out][row_in_subtile][half * 16]); + uint4* dst = reinterpret_cast( + q_row_base + static_cast(local_row) * data_stride_row + byte_off_X); + *dst = *src; + } + if (DO_COL) { + auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); + const int col_in_subtile = static_cast(threadIdx.x) >> 1; // 0..63 + const int half = static_cast(threadIdx.x) & 1; // 0..1 + const int global_col = block_offset_X + stage_X * TILE_DIM_X + col_in_subtile; + const int byte_off_M = (local_row_base / 2) + stage_Y * (TILE_DIM_Y / 2) + half * 16; + const uint4* src = + reinterpret_cast(&sOut_tr[buff_out_tr][col_in_subtile][half * 16]); + uint4* dst = reinterpret_cast( + q_col_base + static_cast(global_col) * data_stride_col + byte_off_M); + *dst = *src; + } + + // Sync so the scatter completes before next stage overwrites the buffer. 
+ __syncthreads(); + + buff_in = (buff_in + 1) % BUFFS_NUM; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; + } + + // SF epilogue: cooperative store of sSFrowwise / sSFcolwise to global. + if (DO_ROW) { + auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + using ScalesVec = Vec; + const size_t scales_block_offset_X_rowwise = static_cast(ctaid_X) * SCALES_PER_CHUNK_X; + for (int row = static_cast(threadIdx.x); row < CHUNK_DIM_Y; row += THREADS_NUM) { + ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); + const size_t local_row = static_cast(local_row_base) + row; + const size_t scale_idx_global = local_row * scale_stride_row + scales_block_offset_X_rowwise; + scales_vec.store_to_elts(&s_dec_row_base[scale_idx_global], 0, SCALES_PER_CHUNK_X); + } + } + if (DO_COL) { + auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + using ScalesVec = Vec; + // M-block offset within s_dec_col[global_col] (shape (K, M_i/16) row-major). + const size_t local_block_offset_M = static_cast(local_row_base) / SCALE_DIM; + for (int row_tr = static_cast(threadIdx.x); row_tr < CHUNK_DIM_X; row_tr += THREADS_NUM) { + ScalesVec& scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); + const size_t global_col = static_cast(block_offset_X) + row_tr; + const size_t scale_idx_global = global_col * scale_stride_col + local_block_offset_M; + scales_vec.store_to_elts(&s_dec_col_base[scale_idx_global], 0, SCALES_PER_CHUNK_Y); + } + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + } +#else + (void)tensor_map_input; + (void)args; + (void)noop; + (void)rows; + (void)cols; + (void)random_sign_mask_t; + NVTE_DEVICE_ERROR("Fused grouped per-token cast kernel requires SM 10.0+ (Blackwell)."); +#endif // __CUDA_ARCH__ >= 1000 +} + +// Host launcher for the fused K2 path. bf16-only. 
+// with_rht=true applies a 16-pt RHT on the col direction; K1 amax must have +// used the same flag + mask, else inner SF + FP4 saturate against mismatched +// data. +inline void launch_grouped_fused_cast_bf16(const NVFP4PerTokenMultiArgs& args, + const SimpleTensor& input_data, int sum_M, int K, + bool do_row, bool do_col, bool with_rht, + uint32_t random_sign_mask_t, const float* noop, + cudaStream_t stream) { + if (!do_row && !do_col) return; + + checkCuDriverContext(stream); + + alignas(64) CUtensorMap tmap_in{}; + create_2D_tensor_map(tmap_in, input_data, sum_M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, + sizeof(FusedIType) * 8); + + dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(sum_M / CHUNK_DIM_Y), 1); + dim3 block(THREADS_NUM, 1, 1); + + // Collapse to kWithRht=false when no colwise output is requested. + const bool with_rht_effective = with_rht && do_col; + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_col, DO_COL, TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { + constexpr int sz_in = DIVUP_TO_MULTIPLE(BUFFS_NUM * BUFF_IN_SIZE * sizeof(FusedIType), + TMA_SHMEM_ALIGNMENT); + constexpr int sz_out_r = + DO_ROW ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT) : 0; + constexpr int sz_out_c = + DO_COL ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int sz_sf_r = + DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int sz_sf_c = + DO_COL ? 
DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int dshmem_size = + sz_in + sz_out_r + sz_out_c + sz_sf_r + sz_sf_c + TMA_SHMEM_ALIGNMENT; + auto kernel = group_per_token_fused_cast_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tmap_in, args, noop, static_cast(sum_M), static_cast(K), + random_sign_mask_t); + }));); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Host launcher for the fused K1 path. bf16-only. +// with_rht=true applies a 16-pt RHT on the col amax (rowwise raw). The +// downstream K2 cast MUST use the same flag + mask. +inline void launch_grouped_fused_amax_bf16(const NVFP4PerTokenMultiArgs& args, + const SimpleTensor& input_data, int sum_M, int K, + bool do_row, bool do_col, bool with_rht, + uint32_t random_sign_mask_t, const float* noop, + cudaStream_t stream) { + if (!do_row && !do_col) return; + + // Pre-zero amax slots (atomicMax identity). + { + dim3 grid_zero(static_cast(args.num_tensors)); + dim3 block_zero(256); + if (do_row && do_col) { + group_per_token_fused_zero_amax_kernel + <<>>(args, K); + } else if (do_row) { + group_per_token_fused_zero_amax_kernel + <<>>(args, K); + } else { + group_per_token_fused_zero_amax_kernel + <<>>(args, K); + } + NVTE_CHECK_CUDA(cudaGetLastError()); + } + + checkCuDriverContext(stream); + + alignas(64) CUtensorMap tmap_in{}; + create_2D_tensor_map(tmap_in, input_data, sum_M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, + sizeof(FusedIType) * 8); + + constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); + constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; + + dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(sum_M / CHUNK_DIM_Y), 1); + dim3 block(THREADS_NUM, 1, 1); + + // Collapse to kWithRht=false when no colwise amax is requested. 
+ const bool with_rht_effective = with_rht && do_col; + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_col, DO_COL, TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { + auto kernel = group_per_token_fused_amax_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tmap_in, args, noop, static_cast(sum_M), static_cast(K), + random_sign_mask_t); + }));); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace fused + +// Populate per-tensor pointer tables + split_sections prefix-sum. +// which_buffers bitmask: kBufRowAmax | kBufColAmax | kBufRowCast | kBufColCast. +enum BufferFlags : int { + kBufRowAmax = 0x1, + kBufColAmax = 0x2, + kBufRowCast = 0x4, + kBufColCast = 0x8, +}; + +void populate_args(NVFP4PerTokenMultiArgs* args, std::vector& outputs, + const size_t* split_sections, size_t num_tensors, int which_buffers, + int expected_sum_M, int K) { + std::memset(args, 0, sizeof(*args)); + args->num_tensors = static_cast(num_tensors); + args->split_sections_range[0] = 0; + for (size_t i = 0; i < num_tensors; ++i) { + Tensor* o = outputs[i]; + NVTE_CHECK(split_sections[i] % 128 == 0, "split_sections[", i, "] = ", split_sections[i], + " must be a multiple of 128"); + args->split_sections_range[i + 1] = + args->split_sections_range[i] + static_cast(split_sections[i]); + if (split_sections[i] == 0) continue; + if (which_buffers & kBufRowAmax) { + NVTE_CHECK(o->amax.dptr != nullptr, "NVFP4 per-token grouped: outputs[", i, + "].amax must be allocated for rowwise"); + args->row_amax_list[i] = o->amax.dptr; + } + if (which_buffers & kBufColAmax) { + NVTE_CHECK(o->columnwise_amax.dptr != nullptr, "NVFP4 per-token grouped: outputs[", i, + "].columnwise_amax must be allocated for columnwise"); + args->col_amax_list[i] = o->columnwise_amax.dptr; + } + if (which_buffers & kBufRowCast) { + NVTE_CHECK(o->data.dptr != nullptr && o->scale_inv.dptr != 
nullptr, + "NVFP4 per-token grouped: outputs[", i, + "].data + .scale_inv must be allocated for rowwise cast"); + args->q_row_list[i] = o->data.dptr; + args->s_dec_row_list[i] = o->scale_inv.dptr; + } + if (which_buffers & kBufColCast) { + NVTE_CHECK(o->columnwise_data.dptr != nullptr && o->columnwise_scale_inv.dptr != nullptr, + "NVFP4 per-token grouped: outputs[", i, + "].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast"); + args->q_col_list[i] = o->columnwise_data.dptr; + args->s_dec_col_list[i] = o->columnwise_scale_inv.dptr; + } + } + NVTE_CHECK(args->split_sections_range[num_tensors] == expected_sum_M, + "NVFP4 per-token grouped: sum(split_sections) = ", + args->split_sections_range[num_tensors], " must equal input rows ", expected_sum_M); + (void)K; +} + +// Host entry. do_amax / do_cast select K1 / K2 phases (composite = both). +// with_rht / mask are threaded into BOTH K1 and K2; the caller must use the +// same flag/mask if they invoke amax + cast separately. 
+void quantize_per_token_grouped(const Tensor& input, std::vector& outputs, + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, bool do_amax, bool do_cast, bool with_rht, + uint32_t random_sign_mask_t, cudaStream_t stream) { + NVTE_CHECK(num_tensors > 0, "NVFP4 per-token grouped: num_tensors must be > 0"); + NVTE_CHECK(num_tensors <= static_cast(kMaxTensorsPerKernel), + "NVFP4 per-token grouped: num_tensors (", num_tensors, + ") exceeds kMaxTensorsPerKernel = ", kMaxTensorsPerKernel); + NVTE_CHECK(rowwise || columnwise, + "NVFP4 per-token grouped: at least one of rowwise/columnwise must be true"); + NVTE_CHECK(input.has_data(), "NVFP4 per-token grouped: input has no data"); + NVTE_CHECK(input.dtype() == DType::kBFloat16, + "NVFP4 per-token grouped: input dtype must be bf16 (got ", + static_cast(input.dtype()), ")"); + + const int sum_M = static_cast(input.flat_first_dim()); + const int K = static_cast(input.flat_last_dim()); + if (sum_M == 0 || K == 0) return; + NVTE_CHECK(K % 128 == 0, "NVFP4 per-token grouped: K (", K, ") must be a multiple of 128"); + + int which_buffers = 0; + if ((do_amax || do_cast) && rowwise) which_buffers |= kBufRowAmax; + if ((do_amax || do_cast) && columnwise) which_buffers |= kBufColAmax; + if (do_cast && rowwise) which_buffers |= kBufRowCast; + if (do_cast && columnwise) which_buffers |= kBufColCast; + + NVFP4PerTokenMultiArgs args; + populate_args(&args, outputs, split_sections, num_tensors, which_buffers, sum_M, K); + + // K1 + K2 = 2 fused launches; K1 must complete before K2 reads its amax. 
+ if (do_amax) { + fused::launch_grouped_fused_amax_bf16(args, input.data, sum_M, K, + /*do_row=*/rowwise, + /*do_col=*/columnwise, + /*with_rht=*/with_rht, + /*random_sign_mask_t=*/random_sign_mask_t, + /*noop=*/nullptr, stream); + } + if (do_cast) { + fused::launch_grouped_fused_cast_bf16(args, input.data, sum_M, K, + /*do_row=*/rowwise, + /*do_col=*/columnwise, + /*with_rht=*/with_rht, + /*random_sign_mask_t=*/random_sign_mask_t, + /*noop=*/nullptr, stream); + } +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace nvfp4_per_token_group +} // namespace transformer_engine + +// C-API entries. +namespace { + +std::vector collect_outputs(NVTETensor* outputs, size_t num_tensors) { + std::vector v; + v.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + v.push_back(transformer_engine::convertNVTETensorCheck(outputs[i])); + } + return v; +} + +} // namespace + +void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, int with_rht, int random_sign_mask_t, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_group_nvfp4_per_token_amax); + using namespace transformer_engine; + if (num_tensors == 0) return; + const Tensor* in = convertNVTETensorCheck(input); + std::vector outs = collect_outputs(outputs, num_tensors); + // C-API mirrors nvte_nvfp4_per_token_amax: `int` for cross-language ABI + // safety; internal kernel arg is uint32_t with only the low 16 bits used. 
+ nvfp4_per_token_group::quantize_per_token_grouped( + *in, outs, split_sections, num_tensors, rowwise, columnwise, + /*do_amax=*/true, /*do_cast=*/false, + /*with_rht=*/with_rht != 0, + /*random_sign_mask_t=*/ + static_cast(random_sign_mask_t) & 0xFFFFu, stream); +#else + (void)input; + (void)outputs; + (void)split_sections; + (void)num_tensors; + (void)rowwise; + (void)columnwise; + (void)with_rht; + (void)random_sign_mask_t; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} + +void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, int with_rht, int random_sign_mask_t, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_group_nvfp4_per_token_cast); + using namespace transformer_engine; + if (num_tensors == 0) return; + const Tensor* in = convertNVTETensorCheck(input); + std::vector outs = collect_outputs(outputs, num_tensors); + nvfp4_per_token_group::quantize_per_token_grouped( + *in, outs, split_sections, num_tensors, rowwise, columnwise, + /*do_amax=*/false, /*do_cast=*/true, + /*with_rht=*/with_rht != 0, + /*random_sign_mask_t=*/ + static_cast(random_sign_mask_t) & 0xFFFFu, stream); +#else + (void)input; + (void)outputs; + (void)split_sections; + (void)num_tensors; + (void)rowwise; + (void)columnwise; + (void)with_rht; + (void)random_sign_mask_t; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} + +void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, int with_rht, + int random_sign_mask_t, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_group_nvfp4_per_token_quantize); + using namespace transformer_engine; + if (num_tensors == 0) return; + const 
Tensor* in = convertNVTETensorCheck(input); + std::vector outs = collect_outputs(outputs, num_tensors); + nvfp4_per_token_group::quantize_per_token_grouped( + *in, outs, split_sections, num_tensors, rowwise, columnwise, + /*do_amax=*/true, /*do_cast=*/true, + /*with_rht=*/with_rht != 0, + /*random_sign_mask_t=*/ + static_cast(random_sign_mask_t) & 0xFFFFu, stream); +#else + (void)input; + (void)outputs; + (void)split_sections; + (void)num_tensors; + (void)rowwise; + (void)columnwise; + (void)with_rht; + (void)random_sign_mask_t; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} diff --git a/transformer_engine/common/gemm/nvfp4_cutlass_gemm.cu b/transformer_engine/common/gemm/nvfp4_cutlass_gemm.cu new file mode 100644 index 0000000000..f929017d70 --- /dev/null +++ b/transformer_engine/common/gemm/nvfp4_cutlass_gemm.cu @@ -0,0 +1,427 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + **************************************************************************************************/ + +// CUTLASS NVFP4xNVFP4 -> BF16 GEMM kernels (modeled on CUTLASS example 72a). +// Two C-API entry points: scalar (alpha, beta), and per-row*per-col fused EVT. 
+ +#include +#include + +#include + +#include "../common.h" +#include "../util/logging.h" +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/functional.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/util/packed_stride.hpp" + +namespace transformer_engine { +namespace nvfp4_cutlass { + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +namespace cute_ = cute; + +// CUTLASS GEMM type config (mirrors 72a). BF16 output matches the production +// TE NVFP4 GEMM (cublasLt path), making this a drop-in at the GEMM boundary. + +using ElementA = cutlass::nv_float4_t; +using LayoutATag = cutlass::layout::RowMajor; +constexpr int AlignmentA = 32; + +using ElementB = cutlass::nv_float4_t; +using LayoutBTag = cutlass::layout::ColumnMajor; +constexpr int AlignmentB = 32; + +using ElementC = cutlass::bfloat16_t; +using ElementD = cutlass::bfloat16_t; +using LayoutCTag = cutlass::layout::RowMajor; +using LayoutDTag = cutlass::layout::RowMajor; +// CUTLASS epilogue uses 128-bit vector loads/stores; bf16 packs 8 elts/128b. 
+constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +using ElementAccumulator = float; +using ArchTag = cutlass::arch::Sm100; +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + +using MmaTileShape = cute_::Shape; +using ClusterShape = cute_::Shape; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + +using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, + CollectiveEpilogue, void>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + +using ElementADataPtr = typename ElementA::DataType const*; +using ElementBDataPtr = typename ElementB::DataType const*; +using ElementASfPtr = typename ElementA::ScaleFactorType const*; +using ElementBSfPtr = typename 
ElementB::ScaleFactorType const*; + +// Core launcher (scalar alpha/beta). + +static void run_cutlass_gemm(void const* a_data_ptr, void const* b_data_ptr, void const* a_sf_ptr, + void const* b_sf_ptr, void* d_ptr, int M, int N, int K, float alpha, + float beta, cudaStream_t stream) { + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute_::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute_::make_shape(M, N, K, 1)); + + typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + {reinterpret_cast(a_data_ptr), stride_A, + reinterpret_cast(b_data_ptr), stride_B, + reinterpret_cast(a_sf_ptr), layout_SFA, + reinterpret_cast(b_sf_ptr), layout_SFB}, + {{alpha, beta}, + reinterpret_cast(d_ptr), + stride_C, + reinterpret_cast(d_ptr), + stride_D}}; + + Gemm gemm; + + // Stream-ordered workspace alloc; tight perf loops should pre-allocate. 
+ size_t workspace_size = Gemm::get_workspace_size(args); + void* workspace = nullptr; + if (workspace_size > 0) { + NVTE_CHECK_CUDA(cudaMallocAsync(&workspace, workspace_size, stream)); + } + + cutlass::Status status = gemm.can_implement(args); + NVTE_CHECK(status == cutlass::Status::kSuccess, + "CUTLASS NVFP4 GEMM cannot implement: ", cutlassGetStatusString(status), " (M=", M, + " N=", N, " K=", K, ")"); + + status = gemm.initialize(args, workspace, stream); + NVTE_CHECK(status == cutlass::Status::kSuccess, + "CUTLASS NVFP4 GEMM initialize failed: ", cutlassGetStatusString(status)); + + status = gemm.run(stream); + NVTE_CHECK(status == cutlass::Status::kSuccess, + "CUTLASS NVFP4 GEMM run failed: ", cutlassGetStatusString(status)); + + if (workspace != nullptr) { + NVTE_CHECK_CUDA(cudaFreeAsync(workspace, stream)); + } +} + +// Per-token fused variant: same NVFP4 mainloop, custom EVT folds the +// cuBLAS-LT-equivalent D[i,j] = bf16(NVFP4_DEQUANT_K * alpha_a[i] * +// alpha_b[j] * acc) in one launch; no separate post-scale, no HBM round-trip. + +// NVFP4 has TWO-LEVEL dequant: per-block SF (mainloop) + outer 1/2688^2. +// cuBLAS-LT auto-folds the outer via amax slot (see nvte_nvfp4_compute_per_tensor_scale); +// CUTLASS NVFP4 is "raw", the EVT must apply 1/2688^2 explicitly. + +// EVT (all fp32 until final cast): +// L1 tmp1 = alpha_a[i] * acc; L2 tmp2 = alpha_b[j] * tmp1; +// L3 out = bf16(NVFP4_DEQUANT_K * tmp2). + +// CUTLASS naming note: Sm90Row/ColBroadcast = "load a row/col vector AND +// broadcast across the orthogonal dim". Sm90ColBroadcast (Stride<_1,_0,_0>) +// indexes M -> per-row; Sm90RowBroadcast (Stride<_0,_1,_0>) indexes N -> per-col. + +namespace fusion = cutlass::epilogue::fusion; +constexpr cutlass::FloatRoundStyle kRoundStyleFused = cutlass::FloatRoundStyle::round_to_nearest; + +using ElementScale = float; + +// NVFP4 spec constant: 1 / (fp4_max^2 * fp8_max^2) = 1/(6^2 * 448^2) = 1/7,225,344. 
+constexpr float kNvfp4DequantFactor = 1.0f / (6.0f * 6.0f * 448.0f * 448.0f); + +using AccFetchNode = fusion::Sm90AccFetch; + +using RowScaleNode = fusion::Sm90ColBroadcast< + /*Stages=*/0, + /*CtaTileShapeMNK=*/MmaTileShape, + /*ElementInput=*/ElementScale, + /*ElementCompute=*/ElementAccumulator>; + +using ColScaleNode = fusion::Sm90RowBroadcast< + /*Stages=*/0, + /*CtaTileShapeMNK=*/MmaTileShape, + /*ElementInput=*/ElementScale, + /*ElementCompute=*/ElementAccumulator>; + +// Tile-wide constant (the NVFP4 spec factor 1/2688^2); same pattern as +// Sm90LinCombPerRowBias scalar alpha/beta in sm90_callbacks_tma_warpspecialized. +using ConstScaleNode = fusion::Sm90ScalarBroadcast; + +// L1: tmp1 = alpha_a[i] * acc. +using MulAccByRowEVT = fusion::Sm90EVT, + RowScaleNode, AccFetchNode>; + +// L2: tmp2 = alpha_b[j] * tmp1 (still fp32; bf16 cast deferred to L3). +using MulByColEVT = fusion::Sm90EVT, + ColScaleNode, MulAccByRowEVT>; + +// L3: D = bf16(NVFP4_DEQUANT_K * tmp2). ElementD=bf16 forces round-to-nearest. 
+using FusedEVT = fusion::Sm90EVT< + fusion::Sm90Compute, + ConstScaleNode, MulByColEVT>; + +using CollectiveEpilogueFused = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusedEVT>::CollectiveOp; + +using CollectiveMainloopFused = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogueFused::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + +using GemmKernelFused = + cutlass::gemm::kernel::GemmUniversal, CollectiveMainloopFused, + CollectiveEpilogueFused, void>; + +using GemmFused = cutlass::gemm::device::GemmUniversalAdapter; + +static void run_cutlass_per_token_gemm(void const* a_data_ptr, void const* b_data_ptr, + void const* a_sf_ptr, void const* b_sf_ptr, + float const* alpha_a_ptr, float const* alpha_b_ptr, + void* d_ptr, int M, int N, int K, cudaStream_t stream) { + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute_::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute_::make_shape(M, N, K, 1)); + + // EVT args order = children first, then this node's args (empty for + // Sm90Compute). 
Sm90ScalarBroadcast has 3 ARRAY fields + // (BroadcastCount=1) -> constant takes {{value}, {nullptr}, {Stride{}}}. + typename FusedEVT::Arguments fusion_args{ + // L3 child[0]: ConstScaleNode args -- NVFP4 spec factor 1/2688^2. + {/*scalars=*/{kNvfp4DequantFactor}, + /*scalar_ptrs=*/{nullptr}, + /*dScalar=*/{}}, + // L3 child[1]: L2 (MulByColEVT) args. + { + // L2 child[0]: alpha_b per-col broadcast (Sm90RowBroadcast Arguments). + {alpha_b_ptr, /*null_default=*/ElementScale{0}, /*dRow=*/{}}, + // L2 child[1]: L1 (MulAccByRowEVT) args. + { + // L1 child[0]: alpha_a per-row broadcast (Sm90ColBroadcast Arguments). + {alpha_a_ptr, /*null_default=*/ElementScale{0}, /*dCol=*/{}}, + // L1 child[1]: AccFetch (empty Arguments). + {}, + // L1 node: Sm90Compute (empty Arguments). + {}, + }, + // L2 node: Sm90Compute (empty Arguments). + {}, + }, + // L3 node: Sm90Compute (empty Arguments). + {}, + }; + + typename GemmFused::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + {reinterpret_cast(a_data_ptr), stride_A, + reinterpret_cast(b_data_ptr), stride_B, + reinterpret_cast(a_sf_ptr), layout_SFA, + reinterpret_cast(b_sf_ptr), layout_SFB}, + {fusion_args, + /*ptr_C=*/nullptr, stride_C, // EVT has no SrcFetch; C unused. 
+ reinterpret_cast(d_ptr), stride_D}}; + + GemmFused gemm; + + size_t workspace_size = GemmFused::get_workspace_size(args); + void* workspace = nullptr; + if (workspace_size > 0) { + NVTE_CHECK_CUDA(cudaMallocAsync(&workspace, workspace_size, stream)); + } + + cutlass::Status status = gemm.can_implement(args); + NVTE_CHECK(status == cutlass::Status::kSuccess, + "CUTLASS NVFP4 per-token fused GEMM cannot implement: ", + cutlassGetStatusString(status), " (M=", M, " N=", N, " K=", K, ")"); + + status = gemm.initialize(args, workspace, stream); + NVTE_CHECK( + status == cutlass::Status::kSuccess, + "CUTLASS NVFP4 per-token fused GEMM initialize failed: ", cutlassGetStatusString(status)); + + status = gemm.run(stream); + NVTE_CHECK(status == cutlass::Status::kSuccess, + "CUTLASS NVFP4 per-token fused GEMM run failed: ", cutlassGetStatusString(status)); + + if (workspace != nullptr) { + NVTE_CHECK_CUDA(cudaFreeAsync(workspace, stream)); + } +} + +#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED + +} // namespace nvfp4_cutlass +} // namespace transformer_engine + +// C API. + +void nvte_nvfp4_cutlass_gemm(const NVTETensor a_data, const NVTETensor b_data, + const NVTETensor a_sf, const NVTETensor b_sf, NVTETensor d, + float alpha, float beta, cudaStream_t stream) { + using namespace transformer_engine; + + auto* a_t = convertNVTETensorCheck(a_data); + auto* b_t = convertNVTETensorCheck(b_data); + auto* sa_t = convertNVTETensorCheck(a_sf); + auto* sb_t = convertNVTETensorCheck(b_sf); + auto* d_t = convertNVTETensorCheck(d); + + // Logical shapes are interpreted in elements (FP4 storage is packed 2/byte). 
+ const auto a_shape = a_t->data.shape; + const auto b_shape = b_t->data.shape; + const auto d_shape = d_t->data.shape; + + NVTE_CHECK(a_shape.size() == 2, "A must be 2D (M, K), got rank=", a_shape.size()); + NVTE_CHECK(b_shape.size() == 2, "B must be 2D (N, K), got rank=", b_shape.size()); + NVTE_CHECK(d_shape.size() == 2, "D must be 2D (M, N), got rank=", d_shape.size()); + + const int M = static_cast(a_shape[0]); + const int K = static_cast(a_shape[1]); + const int N = static_cast(b_shape[0]); + + NVTE_CHECK(static_cast(b_shape[1]) == K, "A.K (", K, ") and B.K (", b_shape[1], + ") must match"); + NVTE_CHECK(static_cast(d_shape[0]) == M, "D.M (", d_shape[0], ") must match A.M (", M, ")"); + NVTE_CHECK(static_cast(d_shape[1]) == N, "D.N (", d_shape[1], ") must match B.N (", N, ")"); + + NVTE_CHECK(a_t->data.dtype == DType::kFloat4E2M1, "A data must be FP4 e2m1"); + NVTE_CHECK(b_t->data.dtype == DType::kFloat4E2M1, "B data must be FP4 e2m1"); + NVTE_CHECK(d_t->data.dtype == DType::kBFloat16, "D must be BF16"); + + // CUTLASS mainloop expects e4m3 SF; accept raw uint8 (PyTorch's wire type). + NVTE_CHECK(sa_t->data.dtype == DType::kFloat8E4M3 || sa_t->data.dtype == DType::kByte, + "A scale must be FP8 e4m3 (or raw uint8 byte)"); + NVTE_CHECK(sb_t->data.dtype == DType::kFloat8E4M3 || sb_t->data.dtype == DType::kByte, + "B scale must be FP8 e4m3 (or raw uint8 byte)"); + + NVTE_CHECK(M > 0 && N > 0 && K > 0, "M, N, K must be positive"); + NVTE_CHECK(M % 256 == 0 && N % 256 == 0 && K % 256 == 0, + "CUTLASS NVFP4 GEMM (Stage 1) requires M, N, K to be multiples of 256, got M=", M, + " N=", N, " K=", K, ". Use a TileShape-aware variant for smaller K."); + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + nvfp4_cutlass::run_cutlass_gemm(a_t->data.dptr, b_t->data.dptr, sa_t->data.dptr, sb_t->data.dptr, + d_t->data.dptr, M, N, K, alpha, beta, stream); +#else + NVTE_ERROR( + "CUTLASS NVFP4 GEMM requires SM100 (Blackwell). 
Build with the sm_100a/sm_100f arch flag."); +#endif +} + +void nvte_nvfp4_cutlass_per_token_gemm(const NVTETensor a_data, const NVTETensor b_data, + const NVTETensor a_sf, const NVTETensor b_sf, + const NVTETensor alpha_a, const NVTETensor alpha_b, + NVTETensor d, cudaStream_t stream) { + using namespace transformer_engine; + + auto* a_t = convertNVTETensorCheck(a_data); + auto* b_t = convertNVTETensorCheck(b_data); + auto* sa_t = convertNVTETensorCheck(a_sf); + auto* sb_t = convertNVTETensorCheck(b_sf); + auto* aa_t = convertNVTETensorCheck(alpha_a); + auto* ab_t = convertNVTETensorCheck(alpha_b); + auto* d_t = convertNVTETensorCheck(d); + + const auto a_shape = a_t->data.shape; + const auto b_shape = b_t->data.shape; + const auto d_shape = d_t->data.shape; + + NVTE_CHECK(a_shape.size() == 2, "A must be 2D (M, K), got rank=", a_shape.size()); + NVTE_CHECK(b_shape.size() == 2, "B must be 2D (N, K), got rank=", b_shape.size()); + NVTE_CHECK(d_shape.size() == 2, "D must be 2D (M, N), got rank=", d_shape.size()); + + const int M = static_cast(a_shape[0]); + const int K = static_cast(a_shape[1]); + const int N = static_cast(b_shape[0]); + + NVTE_CHECK(static_cast(b_shape[1]) == K, "A.K (", K, ") and B.K (", b_shape[1], + ") must match"); + NVTE_CHECK(static_cast(d_shape[0]) == M, "D.M (", d_shape[0], ") must match A.M (", M, ")"); + NVTE_CHECK(static_cast(d_shape[1]) == N, "D.N (", d_shape[1], ") must match B.N (", N, ")"); + + NVTE_CHECK(a_t->data.dtype == DType::kFloat4E2M1, "A data must be FP4 e2m1"); + NVTE_CHECK(b_t->data.dtype == DType::kFloat4E2M1, "B data must be FP4 e2m1"); + NVTE_CHECK(d_t->data.dtype == DType::kBFloat16, "D must be BF16"); + NVTE_CHECK(sa_t->data.dtype == DType::kFloat8E4M3 || sa_t->data.dtype == DType::kByte, + "A scale must be FP8 e4m3 (or raw uint8 byte)"); + NVTE_CHECK(sb_t->data.dtype == DType::kFloat8E4M3 || sb_t->data.dtype == DType::kByte, + "B scale must be FP8 e4m3 (or raw uint8 byte)"); + NVTE_CHECK(aa_t->data.dtype == 
DType::kFloat32, "alpha_a must be FP32"); + NVTE_CHECK(ab_t->data.dtype == DType::kFloat32, "alpha_b must be FP32"); + + // alpha_a/b accepted as 1D or (M,1)/(N,1); only element count is validated. + const size_t aa_numel = aa_t->data.numel(); + const size_t ab_numel = ab_t->data.numel(); + NVTE_CHECK(aa_numel == static_cast(M), "alpha_a must have M=", M, " elements, got ", + aa_numel); + NVTE_CHECK(ab_numel == static_cast(N), "alpha_b must have N=", N, " elements, got ", + ab_numel); + + NVTE_CHECK(M > 0 && N > 0 && K > 0, "M, N, K must be positive"); + NVTE_CHECK(M % 256 == 0 && N % 256 == 0 && K % 256 == 0, + "CUTLASS NVFP4 per-token fused GEMM requires M, N, K to be multiples of 256, got M=", + M, " N=", N, " K=", K, "."); + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + nvfp4_cutlass::run_cutlass_per_token_gemm( + a_t->data.dptr, b_t->data.dptr, sa_t->data.dptr, sb_t->data.dptr, + reinterpret_cast(aa_t->data.dptr), + reinterpret_cast(ab_t->data.dptr), d_t->data.dptr, M, N, K, stream); +#else + NVTE_ERROR( + "CUTLASS NVFP4 per-token fused GEMM requires SM100 (Blackwell). Build with sm_100a/sm_100f."); +#endif +} diff --git a/transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu b/transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu new file mode 100644 index 0000000000..4f4d22d22a --- /dev/null +++ b/transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu @@ -0,0 +1,140 @@ +/************************************************************************* + * Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file nvfp4_per_token_post_scale.cu + * \brief NVFP4 per-token GEMM-output post-scale: d[i,j] *= r_A[i] * r_B[j]. + * + * Standalone bf16 epilogue applied after cuBLAS LT NVFP4 GEMM with the + * operand amaxes pinned to 1.0. See nvfp4_per_token.h for the math chain. 
+ */ + +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../util/ptx.cuh" + +namespace transformer_engine { +namespace nvfp4_per_token { + +namespace { + +// Each block tiles 16 rows x 256 cols of the output: amaxes are loaded +// once into SMEM, then each thread handles 8 cols via a 16-byte int4 LD/ST +// for peak HBM coalescing on SM100. Wrapper enforces M, N % 128 alignment. +constexpr int kTileCols = 256; +constexpr int kTileRows = 16; +constexpr int kElemsPerThread = 8; // bf16x8 = 16-byte vector +constexpr int kThreadsX = kTileCols / kElemsPerThread; +constexpr int kThreadsY = kTileRows; +constexpr int kThreadsPerBlock = kThreadsX * kThreadsY; +static_assert(kTileCols % kElemsPerThread == 0, "kTileCols must be a multiple of kElemsPerThread"); +static_assert(kElemsPerThread * sizeof(__nv_bfloat16) == sizeof(int4), + "kElemsPerThread bf16 must pack into a single int4 (16 bytes)"); + +__global__ void __launch_bounds__(kThreadsPerBlock) + per_token_post_scale_kernel(__nv_bfloat16* __restrict__ d, const float* __restrict__ row_amax_a, + const float* __restrict__ row_amax_b, const int M, const int N) { + __shared__ float s_row_amax[kTileRows]; + __shared__ float s_col_amax[kTileCols]; + + const int row_tile = blockIdx.y * kTileRows; + const int col_tile = blockIdx.x * kTileCols; + + // Cooperatively load row + col amaxes into SMEM (272 floats / 512 threads). + const int tid = threadIdx.y * kThreadsX + threadIdx.x; + if (tid < kTileRows) { + const int gi = row_tile + tid; + s_row_amax[tid] = (gi < M) ? row_amax_a[gi] : 0.0f; + } + if (tid < kTileCols) { + const int gj = col_tile + tid; + s_col_amax[tid] = (gj < N) ? 
row_amax_b[gj] : 0.0f; + } + __syncthreads(); + + const int i = row_tile + threadIdx.y; + const int j0 = col_tile + threadIdx.x * kElemsPerThread; + if (i >= M || j0 >= N) return; + + const float a = s_row_amax[threadIdx.y]; + const size_t base = static_cast(i) * N + j0; + + // Fast path = 16-byte aligned LD/ST; slow path = boundary tile fallback. + if (j0 + kElemsPerThread <= N) { + // __align__(16) is required for the int4 reinterpret_cast to be defined. + __nv_bfloat16 __align__(16) chunk[kElemsPerThread]; + *reinterpret_cast(chunk) = *reinterpret_cast(&d[base]); +#pragma unroll + for (int e = 0; e < kElemsPerThread; ++e) { + const float b = s_col_amax[threadIdx.x * kElemsPerThread + e]; + const float current = static_cast(chunk[e]); + chunk[e] = static_cast<__nv_bfloat16>(current * a * b); + } + *reinterpret_cast(&d[base]) = *reinterpret_cast(chunk); + } else { +#pragma unroll + for (int e = 0; e < kElemsPerThread; ++e) { + const int j = j0 + e; + if (j >= N) break; + const float b = s_col_amax[threadIdx.x * kElemsPerThread + e]; + const size_t idx = base + e; + const float current = static_cast(d[idx]); + d[idx] = static_cast<__nv_bfloat16>(current * a * b); + } + } +} + +} // namespace + +void per_token_post_scale(Tensor* d, const Tensor& row_amax_a, const Tensor& row_amax_b, + cudaStream_t stream) { + NVTE_CHECK(d->has_data(), "NVFP4 per-token post-scale: d has no data."); + NVTE_CHECK(d->data.dtype == DType::kBFloat16, + "NVFP4 per-token post-scale: d must be BF16 (got non-BF16 dtype)."); + NVTE_CHECK(row_amax_a.data.dtype == DType::kFloat32, + "NVFP4 per-token post-scale: row_amax_a must be FP32."); + NVTE_CHECK(row_amax_b.data.dtype == DType::kFloat32, + "NVFP4 per-token post-scale: row_amax_b must be FP32."); + + const auto& d_shape = d->data.shape; + NVTE_CHECK(d_shape.size() == 2, + "NVFP4 per-token post-scale: d must be 2D, got rank=", d_shape.size()); + const int M = static_cast(d_shape[0]); + const int N = static_cast(d_shape[1]); + 
NVTE_CHECK(row_amax_a.data.numel() == static_cast(M), + "NVFP4 per-token post-scale: row_amax_a numel must equal M=", M, ", got ", + row_amax_a.data.numel()); + NVTE_CHECK(row_amax_b.data.numel() == static_cast(N), + "NVFP4 per-token post-scale: row_amax_b numel must equal N=", N, ", got ", + row_amax_b.data.numel()); + + if (M == 0 || N == 0) { + return; + } + + // 32 x 16 threads = 512/block; covers 256 cols x 16 rows = 4096 elems/block. + dim3 block(kThreadsX, kThreadsY, 1); + dim3 grid((N + kTileCols - 1) / kTileCols, (M + kTileRows - 1) / kTileRows, 1); + per_token_post_scale_kernel<<>>( + reinterpret_cast<__nv_bfloat16*>(d->data.dptr), + reinterpret_cast(row_amax_a.data.dptr), + reinterpret_cast(row_amax_b.data.dptr), M, N); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace nvfp4_per_token +} // namespace transformer_engine + +void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a, + const NVTETensor row_amax_b, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_per_token_post_scale); + using namespace transformer_engine; + + transformer_engine::nvfp4_per_token::per_token_post_scale( + convertNVTETensorCheck(d), *convertNVTETensorCheck(row_amax_a), + *convertNVTETensorCheck(row_amax_b), stream); +} diff --git a/transformer_engine/common/include/transformer_engine/nvfp4_cutlass_gemm.h b/transformer_engine/common/include/transformer_engine/nvfp4_cutlass_gemm.h new file mode 100644 index 0000000000..37b5bd6930 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/nvfp4_cutlass_gemm.h @@ -0,0 +1,41 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file nvfp4_cutlass_gemm.h + * \brief CUTLASS NVFP4 GEMM kernels: scalar (alpha, beta) and per-row*per-col + * fused EVT variants. 
BF16 output matches the cublasLt NVFP4 path. */ + +#ifndef TRANSFORMER_ENGINE_NVFP4_CUTLASS_GEMM_H_ +#define TRANSFORMER_ENGINE_NVFP4_CUTLASS_GEMM_H_ + +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/*! \brief D = alpha * (A @ B^T) + beta * C. A row-major (M,K), B col-major + * (K,N), D/C row-major (M,N). A/B FP4-e2m1 packed; SFs FP8-e4m3 in CUTLASS + * Sm1xxBlkScaledConfig layout; D BF16. M, N, K must be multiples of 256. */ +void nvte_nvfp4_cutlass_gemm(const NVTETensor a_data, const NVTETensor b_data, + const NVTETensor a_sf, const NVTETensor b_sf, NVTETensor d, + float alpha, float beta, cudaStream_t stream); + +/*! \brief D[i,j] = bf16(alpha_a[i] * alpha_b[j] * (A @ B^T)[i,j]). Per-row * + * per-col rescale fused into the EVT epilogue (replaces the trailing + * nvte_nvfp4_per_token_post_scale kernel). alpha_a/b are FP32 (M,)/(N,). */ +void nvte_nvfp4_cutlass_per_token_gemm(const NVTETensor a_data, const NVTETensor b_data, + const NVTETensor a_sf, const NVTETensor b_sf, + const NVTETensor alpha_a, const NVTETensor alpha_b, + NVTETensor d, cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TRANSFORMER_ENGINE_NVFP4_CUTLASS_GEMM_H_ diff --git a/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h new file mode 100644 index 0000000000..c8b6b630f0 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h @@ -0,0 +1,166 @@ +/************************************************************************* + * Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_NVFP4_PER_TOKEN_H_ +#define TRANSFORMER_ENGINE_NVFP4_PER_TOKEN_H_ + +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Composite K1+K2: per-row + per-col amax (K1) then FP4 + 1x16 + * e4m3 SF encode (K2), back-to-back on the same stream. + * + * Production entry point for the per-token cast on bf16 + 128-aligned shapes. + * + * \param[in] with_rht non-zero -> apply 16-pt RHT on the col direction in + * both K1 and K2. Rowwise stays raw; zero is byte-equal + * to the pre-RHT path. + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern shared by + * K1 and K2. Ignored when with_rht == 0. + * \param[in] with_swizzle non-zero -> K2 emits rowwise scale_inv directly + * in the cuBLAS LT swizzled tile layout (rowwise only; + * colwise stays compact M-major). + */ +void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output, + int with_rht, int random_sign_mask_t, int with_swizzle, + cudaStream_t stream); + +/*! \brief Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax. + * Pre-zeroes the amax buffers and merges per-CTA partials into + * ``output->amax`` (size [M]) / ``output->columnwise_amax`` + * (size [K]). Does NOT touch FP4 data / scale_inv slots. + * + * \param[in] with_rht non-zero -> apply 16-pt RHT on the col direction + * before columnwise_amax (rowwise stays raw); zero is + * byte-equal to the pre-RHT K1. + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; ignored + * when with_rht == 0. Type matches prod's + * nvte_hadamard_transform_amax convention. + */ +void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NVTETensor output, + int with_rht, int random_sign_mask_t, cudaStream_t stream); + +/*! 
\brief Kernel 2 in isolation: FP4 + 1x16 e4m3 SF encode given a + * pre-filled ``output->amax`` / ``output->columnwise_amax``. Reads + * the outer amax buffer(s) and writes the FP4 data / scale_inv + * tensors only. + * + * \param[in] with_rht non-zero -> col-wise cast applies the same 16-pt RHT + * that K1 amax must have used (caller's responsibility + * to thread the same flag + mask through K1 and K2). + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; ignored + * when with_rht == 0. + * \param[in] with_swizzle non-zero -> write rowwise scale_inv directly in + * the cuBLAS LT swizzled tile layout (rowwise only; + * colwise stays compact M-major). + */ +void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output, + int with_rht, int random_sign_mask_t, int with_swizzle, + cudaStream_t stream); + +/*! \brief Returns 1 iff the per-token kernels accept ``(M, K, dtype)``. + * + * Currently returns 1 iff ``dtype`` is bf16 AND ``M % 128 == 0`` AND + * ``K % 128 == 0``. Cheap host-side query (no CUDA call). + * + * \param[in] M first-dim (rows). + * \param[in] K last-dim (cols). + * \param[in] input_dtype_enum NVTE_DType cast to int. + */ +int nvte_nvfp4_per_token_can_dispatch(size_t M, size_t K, int input_dtype_enum); + +/*! \brief Apply per-row * per-col outer-scale to a (M, N) bf16 GEMM output. + * + * Computes: + * + * d[i, j] = d[i, j] * row_amax_a[i] * row_amax_b[j] + */ +void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a, + const NVTETensor row_amax_b, cudaStream_t stream); + +/* ============================================================================ + * Grouped (multi-tensor) per-token quantize. + * + * \param[in] input (sum_M, K) bf16/fp32, row-major contiguous + * \param[in,out] outputs array of `num_tensors` NVTETensors; on + * return, amax/columnwise_amax slots are filled. 
+ * \param[in] split_sections array of `num_tensors` size_t values, + * each a multiple of 64; sum must equal sum_M. + * \param[in] num_tensors <= 64 + * \param[in] rowwise emit per-row amax in `outputs[i].amax` + * \param[in] columnwise emit per-col amax in `outputs[i].columnwise_amax` + * \param[in] with_rht non-zero -> 16-pt RHT on the col direction + * (rowwise stays raw). + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; must + * match the value passed to the matching cast + * if amax + cast are launched separately. + * \param[in] stream CUDA stream + */ +void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, int with_rht, int random_sign_mask_t, + cudaStream_t stream); + +/*! \brief Grouped per-token encode (FP4 + 1x16 e4m3 inner SF) using the + * row_amax / col_amax values already populated by + * `nvte_group_nvfp4_per_token_amax`. + * + * \param[in] input same as `nvte_group_nvfp4_per_token_amax` + * \param[in,out] outputs on entry: amax/columnwise_amax populated; + * on return: data/scale_inv + columnwise_data/ + * columnwise_scale_inv populated. + * \param[in] split_sections same as `nvte_group_nvfp4_per_token_amax` + * \param[in] num_tensors <= 64 + * \param[in] rowwise emit per-row FP4 + inner SF + * \param[in] columnwise emit per-col FP4 + inner SF + * \param[in] with_rht must match the preceding amax call's + * with_rht; applies the same 16-pt RHT on the + * colwise cast. + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; must + * match K1. + * \param[in] stream CUDA stream + */ +void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, int with_rht, int random_sign_mask_t, + cudaStream_t stream); + +/*! \brief Composite K1+K2 grouped per-token quantize. 
Calls the amax + cast + * kernels on the same stream. This is the external API + * `tex.split_quantize(per_token=True)` should call. + * + * \param[in] input (sum_M, K) bf16/fp32, row-major contiguous + * \param[in,out] outputs on entry: amax / columnwise_amax / data / + * scale_inv / columnwise_data / + * columnwise_scale_inv slots allocated; + * on return: all populated. + * \param[in] split_sections array of `num_tensors` size_t values, + * each a multiple of 64; sum must equal sum_M. + * \param[in] num_tensors <= 64 + * \param[in] rowwise emit rowwise output + * \param[in] columnwise emit columnwise output + * \param[in] with_rht non-zero -> 16-pt RHT on the col direction + * in BOTH K1 and K2; zero is byte-equal to the + * pre-RHT path. + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern shared + * between K1 and K2; ignored when with_rht==0. + * \param[in] stream CUDA stream + */ +void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, int with_rht, + int random_sign_mask_t, cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_NVFP4_PER_TOKEN_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 8082ff07ed..2b84318235 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -411,6 +411,9 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads void compute_amax(const at::Tensor &tensor, at::Tensor &amax); +void hadamard_transform_amax(const at::Tensor &tensor, at::Tensor &rowwise_amax, + at::Tensor &columnwise_amax, int64_t rht_matrix_random_sign_mask); + void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, std::vector scales, @@ -448,6 +451,91 @@ void mxfp8_scaling_partial_cast(const at::Tensor 
&input, at::Tensor output_rowwi const at::Tensor &scale_inv_colwise, int rows, int cols, size_t start_offset); +// Stage-1 forked CUTLASS NVFP4 x NVFP4 -> BF16 GEMM with scalar (alpha, beta) +// epilogue. Drop-in replacement for cuBLAS LT NVFP4 (the production path). +// To get a CUTLASS per-token GEMM, pair this with nvte_nvfp4_per_token_post_scale +// (same trick the cuBLAS LT per-token path uses). a_sf_swizzled / b_sf_swizzled +// = true skip the corresponding internal swizzle and consume already-swizzled +// SFs directly (apples-to-apples vs cuBLAS LT in --gemm-only). +void nvfp4_cutlass_gemm(const at::Tensor &a_data, const at::Tensor &b_data, const at::Tensor &a_sf, + const at::Tensor &b_sf, at::Tensor d, int64_t m, int64_t n, int64_t k, + double alpha, double beta, bool a_sf_swizzled, bool b_sf_swizzled); + +// CUTLASS NVFP4 GEMM with per-token rescale fused into the epilogue: +// D[i, j] = bf16(alpha_a[i] * alpha_b[j] * (A @ B^T)[i, j]) +// One launch, no separate post-scale kernel. alpha_a / alpha_b are fp32 +// (M,) / (N,) outer-scale vectors. +void nvfp4_cutlass_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, + const at::Tensor &a_sf, const at::Tensor &b_sf, + const at::Tensor &alpha_a, const at::Tensor &alpha_b, + at::Tensor d, int64_t m, int64_t n, int64_t k, bool a_sf_swizzled, + bool b_sf_swizzled); + +// with_swizzle=true makes K2 write rowwise scale_inv in the cuBLAS LT +// swizzled tile layout (skips the standalone nvte_swizzle_scaling_factors). +// Has no effect on colwise scale_inv (rowwise-only for now). 
+void nvfp4_per_token_quantize(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, + int64_t random_sign_mask_t, bool with_swizzle); + +void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax, at::Tensor col_amax, + bool rowwise, bool columnwise, bool with_rht, int64_t random_sign_mask_t); + +void nvfp4_per_token_encode(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, + int64_t random_sign_mask_t, bool with_swizzle); + +void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, + const at::Tensor &row_amax_b); + +// Standalone rowwise-SF swizzle for one NVFP4 operand. One launch per call, +// mirrors prod NVFP4 GEMM's per-operand swizzle. Used by --qs bench mode. +void nvfp4_per_token_swizzle_rowwise_sf(const at::Tensor &data, const at::Tensor &sf_in, + at::Tensor sf_out); + +void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, + const at::Tensor &a_sf, const at::Tensor &b_sf, + const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, at::Tensor d, + const at::Tensor &workspace, int64_t m, int64_t n, int64_t k, + double alpha, double beta, bool a_sf_swizzled, bool b_sf_swizzled, + bool skip_post_scale = false); + +// Bench-only per-tensor twin of nvfp4_per_token_gemm: scalar amaxes folded +// into cuBLAS LT alpha via the amax slot; no trailing post-scale. 
+void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, + const at::Tensor &a_sf, const at::Tensor &b_sf, const at::Tensor &a_amax, + const at::Tensor &b_amax, at::Tensor d, const at::Tensor &workspace, + int64_t m, int64_t n, int64_t k, double alpha, double beta, + bool a_sf_swizzled, bool b_sf_swizzled); + +// with_rht=true applies a 16-pt RHT on the col direction in BOTH K1 and K2; +// random_sign_mask_t low 16 bits = sign pattern (ignored when with_rht=false). +void nvfp4_per_token_group_quantize( + const at::Tensor &input, const std::vector &split_sections, + std::vector q_row_list, std::vector s_dec_row_list, + std::vector row_amax_list, std::vector q_col_list, + std::vector s_dec_col_list, std::vector col_amax_list, bool rowwise, + bool columnwise, bool with_rht, int64_t random_sign_mask_t); + +// Amax-only variant of the grouped quantize. Useful for multi-rank training +// where amax is allReduced before the cast pass. Caller must thread the +// matching with_rht / mask into the subsequent cast launch. +void nvfp4_per_token_group_amax(const at::Tensor &input, const std::vector &split_sections, + std::vector row_amax_list, + std::vector col_amax_list, bool rowwise, + bool columnwise, bool with_rht, int64_t random_sign_mask_t); + +// Bulk grouped quantize: allocate-view-dispatch all in one pybind hop. +// Returns 6 per-split vectors (q_row, s_dec_row_fp8, row_amax, q_col, +// s_dec_col_fp8, col_amax); disabled directions return empty vectors. 
+std::tuple, std::vector, std::vector, + std::vector, std::vector, std::vector> +nvfp4_per_token_group_quantize_bulk(const at::Tensor &input, + const std::vector &split_sections, bool rowwise, + bool columnwise, bool with_rht, int64_t random_sign_mask_t); + /*************************************************************************************************** * Rotary positional embedding **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_cutlass_gemm.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_cutlass_gemm.cpp new file mode 100644 index 0000000000..e5ec253662 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_cutlass_gemm.cpp @@ -0,0 +1,226 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include "../extensions.h" + +namespace transformer_engine::pytorch { + +void nvfp4_cutlass_gemm(const at::Tensor &a_data, const at::Tensor &b_data, const at::Tensor &a_sf, + const at::Tensor &b_sf, at::Tensor d, int64_t m, int64_t n, int64_t k, + double alpha, double beta, bool a_sf_swizzled, bool b_sf_swizzled) { + TORCH_CHECK( + a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && d.is_cuda(), + "All tensors must be CUDA tensors"); + TORCH_CHECK(a_data.is_contiguous() && b_data.is_contiguous() && a_sf.is_contiguous() && + b_sf.is_contiguous() && d.is_contiguous(), + "All tensors must be contiguous"); + + // FP4 packed 2/byte and FP8-e4m3 SFs are both stored as uint8 (TE quantizer + // wire type). Accumulator is fp32 in TMEM; only the final epilogue cast is bf16. 
+ TORCH_CHECK(a_data.scalar_type() == at::ScalarType::Byte, "a_data must be uint8 (FP4 packed)"); + TORCH_CHECK(b_data.scalar_type() == at::ScalarType::Byte, "b_data must be uint8 (FP4 packed)"); + TORCH_CHECK(a_sf.scalar_type() == at::ScalarType::Byte, "a_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(b_sf.scalar_type() == at::ScalarType::Byte, "b_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); + + TORCH_CHECK(a_data.dim() == 2, "a_data must be 2D, got rank=", a_data.dim()); + TORCH_CHECK(b_data.dim() == 2, "b_data must be 2D, got rank=", b_data.dim()); + TORCH_CHECK(d.dim() == 2, "d must be 2D, got rank=", d.dim()); + + // Storage shapes must match caller-declared (M, N, K). + TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, + "a_data storage shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", + a_data.size(0), ", ", a_data.size(1), ")"); + TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, + "b_data storage shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", + b_data.size(0), ", ", b_data.size(1), ")"); + TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, + "), got (", d.size(0), ", ", d.size(1), ")"); + + // CUTLASS NVFP4 mainloop wants SF in SM100 Sm1xxBlkScaledConfig layout; + // swizzle internally so the caller can pass linear (M, K/16) too. 
+ const auto stream = at::cuda::getCurrentCUDAStream(); + + const std::vector a_data_shape = {static_cast(m), static_cast(k)}; + const std::vector b_data_shape = {static_cast(n), static_cast(k)}; + const std::vector a_sf_shape = {static_cast(m), static_cast(k / 16)}; + const std::vector b_sf_shape = {static_cast(n), static_cast(k / 16)}; + + TORCH_CHECK(a_sf.numel() == static_cast(m * k / 16), + "a_sf size mismatch: expected M*K/16=", m * k / 16, ", got ", a_sf.numel()); + TORCH_CHECK(b_sf.numel() == static_cast(n * k / 16), + "b_sf size mismatch: expected N*K/16=", n * k / 16, ", got ", b_sf.numel()); + + // a_sf_swizzled / b_sf_swizzled = true skip the per-operand swizzle and + // consume the caller's buffer directly (bench-only fast-path for --gemm-only). + auto byte_opts = a_sf.options().dtype(at::kByte); + at::Tensor a_sf_swz_buf; + at::Tensor b_sf_swz_buf; + void *a_sf_swz_ptr = nullptr; + void *b_sf_swz_ptr = nullptr; + + if (a_sf_swizzled) { + a_sf_swz_ptr = a_sf.data_ptr(); + } else { + a_sf_swz_buf = at::empty({a_sf.numel()}, byte_opts); + a_sf_swz_ptr = a_sf_swz_buf.data_ptr(); + + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + in_nvte.set_rowwise_scale_inv(a_sf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + out_nvte.set_rowwise_scale_inv(a_sf_swz_ptr, DType::kFloat8E4M3, a_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + if (b_sf_swizzled) { + b_sf_swz_ptr = b_sf.data_ptr(); + } else { + b_sf_swz_buf = at::empty({b_sf.numel()}, byte_opts); + b_sf_swz_ptr = b_sf_swz_buf.data_ptr(); + + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + 
in_nvte.set_rowwise_scale_inv(b_sf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + out_nvte.set_rowwise_scale_inv(b_sf_swz_ptr, DType::kFloat8E4M3, b_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + + // Logical FP4/FP8 dtypes for the C API; pointers reference uint8 storage. + TensorWrapper a_te = + makeTransformerEngineTensor(a_data.data_ptr(), a_data_shape, DType::kFloat4E2M1); + TensorWrapper b_te = + makeTransformerEngineTensor(b_data.data_ptr(), b_data_shape, DType::kFloat4E2M1); + TensorWrapper a_sf_te = makeTransformerEngineTensor( + a_sf_swz_ptr, std::vector{static_cast(a_sf.numel())}, DType::kFloat8E4M3); + TensorWrapper b_sf_te = makeTransformerEngineTensor( + b_sf_swz_ptr, std::vector{static_cast(b_sf.numel())}, DType::kFloat8E4M3); + TensorWrapper d_te = makeTransformerEngineTensor( + d.data_ptr(), std::vector{static_cast(m), static_cast(n)}, + DType::kBFloat16); + + nvte_nvfp4_cutlass_gemm(a_te.data(), b_te.data(), a_sf_te.data(), b_sf_te.data(), d_te.data(), + static_cast(alpha), static_cast(beta), stream); +} + +// D[i,j] = bf16(alpha_a[i] * alpha_b[j] * (A @ B^T)[i,j]) -- per-row*per-col +// fold REPLACES the trailing nvfp4_per_token_post_scale kernel. Same SF-swizzle +// contract as nvfp4_cutlass_gemm above. 
+void nvfp4_cutlass_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, + const at::Tensor &a_sf, const at::Tensor &b_sf, + const at::Tensor &alpha_a, const at::Tensor &alpha_b, + at::Tensor d, int64_t m, int64_t n, int64_t k, bool a_sf_swizzled, + bool b_sf_swizzled) { + TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && + alpha_a.is_cuda() && alpha_b.is_cuda() && d.is_cuda(), + "All tensors must be CUDA tensors"); + TORCH_CHECK(a_data.is_contiguous() && b_data.is_contiguous() && a_sf.is_contiguous() && + b_sf.is_contiguous() && alpha_a.is_contiguous() && alpha_b.is_contiguous() && + d.is_contiguous(), + "All tensors must be contiguous"); + + TORCH_CHECK(a_data.scalar_type() == at::ScalarType::Byte, "a_data must be uint8 (FP4 packed)"); + TORCH_CHECK(b_data.scalar_type() == at::ScalarType::Byte, "b_data must be uint8 (FP4 packed)"); + TORCH_CHECK(a_sf.scalar_type() == at::ScalarType::Byte, "a_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(b_sf.scalar_type() == at::ScalarType::Byte, "b_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); + TORCH_CHECK(alpha_a.scalar_type() == at::ScalarType::Float, "alpha_a must be float32"); + TORCH_CHECK(alpha_b.scalar_type() == at::ScalarType::Float, "alpha_b must be float32"); + + TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, + "a_data / b_data / d must all be 2D"); + + TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, + "a_data storage shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", + a_data.size(0), ", ", a_data.size(1), ")"); + TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, + "b_data storage shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", + b_data.size(0), ", ", b_data.size(1), ")"); + TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, + "), got (", d.size(0), ", ", d.size(1), ")"); + 
TORCH_CHECK(alpha_a.numel() == m, "alpha_a must have M=", m, " elements, got ", alpha_a.numel()); + TORCH_CHECK(alpha_b.numel() == n, "alpha_b must have N=", n, " elements, got ", alpha_b.numel()); + + const auto stream = at::cuda::getCurrentCUDAStream(); + + const std::vector a_data_shape = {static_cast(m), static_cast(k)}; + const std::vector b_data_shape = {static_cast(n), static_cast(k)}; + const std::vector a_sf_shape = {static_cast(m), static_cast(k / 16)}; + const std::vector b_sf_shape = {static_cast(n), static_cast(k / 16)}; + + TORCH_CHECK(a_sf.numel() == static_cast(m * k / 16), + "a_sf size mismatch: expected M*K/16=", m * k / 16, ", got ", a_sf.numel()); + TORCH_CHECK(b_sf.numel() == static_cast(n * k / 16), + "b_sf size mismatch: expected N*K/16=", n * k / 16, ", got ", b_sf.numel()); + + // SF swizzle (shared logic with the scalar-alpha entry point above). + auto byte_opts = a_sf.options().dtype(at::kByte); + at::Tensor a_sf_swz_buf; + at::Tensor b_sf_swz_buf; + void *a_sf_swz_ptr = nullptr; + void *b_sf_swz_ptr = nullptr; + + if (a_sf_swizzled) { + a_sf_swz_ptr = a_sf.data_ptr(); + } else { + a_sf_swz_buf = at::empty({a_sf.numel()}, byte_opts); + a_sf_swz_ptr = a_sf_swz_buf.data_ptr(); + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + in_nvte.set_rowwise_scale_inv(a_sf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + out_nvte.set_rowwise_scale_inv(a_sf_swz_ptr, DType::kFloat8E4M3, a_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + if (b_sf_swizzled) { + b_sf_swz_ptr = b_sf.data_ptr(); + } else { + b_sf_swz_buf = at::empty({b_sf.numel()}, byte_opts); + b_sf_swz_ptr = b_sf_swz_buf.data_ptr(); + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + 
in_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + in_nvte.set_rowwise_scale_inv(b_sf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + out_nvte.set_rowwise_scale_inv(b_sf_swz_ptr, DType::kFloat8E4M3, b_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + + TensorWrapper a_te = + makeTransformerEngineTensor(a_data.data_ptr(), a_data_shape, DType::kFloat4E2M1); + TensorWrapper b_te = + makeTransformerEngineTensor(b_data.data_ptr(), b_data_shape, DType::kFloat4E2M1); + TensorWrapper a_sf_te = makeTransformerEngineTensor( + a_sf_swz_ptr, std::vector{static_cast(a_sf.numel())}, DType::kFloat8E4M3); + TensorWrapper b_sf_te = makeTransformerEngineTensor( + b_sf_swz_ptr, std::vector{static_cast(b_sf.numel())}, DType::kFloat8E4M3); + TensorWrapper aa_te = makeTransformerEngineTensor( + alpha_a.data_ptr(), std::vector{static_cast(m)}, DType::kFloat32); + TensorWrapper ab_te = makeTransformerEngineTensor( + alpha_b.data_ptr(), std::vector{static_cast(n)}, DType::kFloat32); + TensorWrapper d_te = makeTransformerEngineTensor( + d.data_ptr(), std::vector{static_cast(m), static_cast(n)}, + DType::kBFloat16); + + nvte_nvfp4_cutlass_per_token_gemm(a_te.data(), b_te.data(), a_sf_te.data(), b_sf_te.data(), + aa_te.data(), ab_te.data(), d_te.data(), stream); +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp new file mode 100644 index 0000000000..3575401731 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp @@ -0,0 +1,838 @@ +/************************************************************************* + * Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../extensions.h" + +namespace transformer_engine::pytorch { + +// NVFP4 per-token cast bindings. Shared TensorWrapper assembler dispatches +// composite (K1+K2), K1-only and K2-only via `mode`. bf16-only, M/K % 128 == 0. +// SFs emit in compact (non-swizzled) layout; swizzle for cuBLAS LT lives elsewhere. +namespace { + +// Validates the input and assembles ``out_te`` for all 3 modes; caller +// dispatches to the right C-API entry on the caller's stream. +// `mode` as passed by the bindings below: 0 = composite K1+K2, 1 = K1 only, 2 = K2 only; +// any other value sets only the amax slots (same as mode 1). +// Output tensors of a disabled direction (rowwise/columnwise == false) are not inspected. +void assemble_per_token_tensors(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, int mode, + TensorWrapper& in_te, TensorWrapper& out_te) { + TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(input.dim() == 2, "input must be 2D"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, + "Per-token cast is bf16-only. Got dtype ", input.scalar_type()); + const int64_t M = input.size(0); + const int64_t K = input.size(1); + TORCH_CHECK(M % 128 == 0, "Per-token cast requires M % 128 == 0; got M=", M); + TORCH_CHECK(K % 128 == 0, "Per-token cast requires K % 128 == 0; got K=", K); + + const std::vector in_shape = {static_cast(M), static_cast(K)}; + in_te = makeTransformerEngineTensor(input.data_ptr(), in_shape, DType::kBFloat16); + + // K1 (mode==1) populates ONLY amax slots; K2 / composite (mode==0/2) + // populate the FP4 + e4m3 SF slots too. The amax slots are also wired + // for K2 because the kernel READS them. 
+ const bool needs_fp4_outputs = (mode == 0) || (mode == 2); + + if (rowwise) { + // Rowwise direction: one fp32 amax per input row (M values). + TORCH_CHECK(row_amax.is_cuda() && row_amax.is_contiguous(), + "row_amax must be a contiguous CUDA tensor"); + TORCH_CHECK(row_amax.scalar_type() == at::ScalarType::Float, "row_amax must be float32"); + TORCH_CHECK(row_amax.numel() == M, "row_amax numel mismatch: expected M=", M, ", got ", + row_amax.numel()); + out_te.set_amax(row_amax.data_ptr(), DType::kFloat32, + std::vector{static_cast(M)}); + + if (needs_fp4_outputs) { + // Only element counts are checked for q_row / s_dec_row; their dims are not inspected. + TORCH_CHECK(q_row.is_cuda() && q_row.is_contiguous(), + "q_row must be a contiguous CUDA tensor"); + TORCH_CHECK(s_dec_row.is_cuda() && s_dec_row.is_contiguous(), + "s_dec_row must be a contiguous CUDA tensor"); + TORCH_CHECK(q_row.scalar_type() == at::ScalarType::Byte, "q_row must be uint8 (FP4 packed)"); + TORCH_CHECK(s_dec_row.scalar_type() == at::ScalarType::Byte, + "s_dec_row must be uint8 (FP8 e4m3 raw bytes)"); + TORCH_CHECK(q_row.numel() == M * K / 2, "q_row numel mismatch: expected M*K/2=", M * K / 2, + ", got ", q_row.numel()); + TORCH_CHECK(s_dec_row.numel() == M * K / 16, + "s_dec_row numel mismatch: expected M*K/16=", M * K / 16, ", got ", + s_dec_row.numel()); + out_te.set_rowwise_data(q_row.data_ptr(), DType::kFloat4E2M1, in_shape); + out_te.set_rowwise_scale_inv( + s_dec_row.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(M), static_cast(K / 16)}); + } + } + if (columnwise) { + // Columnwise direction: one fp32 amax per input column (K values). + TORCH_CHECK(col_amax.is_cuda() && col_amax.is_contiguous(), + "col_amax must be a contiguous CUDA tensor"); + TORCH_CHECK(col_amax.scalar_type() == at::ScalarType::Float, "col_amax must be float32"); + TORCH_CHECK(col_amax.numel() == K, "col_amax numel mismatch: expected K=", K, ", got ", + col_amax.numel()); + out_te.set_columnwise_amax(col_amax.data_ptr(), DType::kFloat32, + std::vector{static_cast(K)}); + + if (needs_fp4_outputs) { + TORCH_CHECK(q_col.is_cuda() && q_col.is_contiguous(), + "q_col must be a contiguous CUDA tensor"); + 
TORCH_CHECK(s_dec_col.is_cuda() && s_dec_col.is_contiguous(), + "s_dec_col must be a contiguous CUDA tensor"); + TORCH_CHECK(q_col.scalar_type() == at::ScalarType::Byte, "q_col must be uint8 (FP4 packed)"); + TORCH_CHECK(s_dec_col.scalar_type() == at::ScalarType::Byte, + "s_dec_col must be uint8 (FP8 e4m3 raw bytes)"); + TORCH_CHECK(q_col.numel() == K * M / 2, "q_col numel mismatch: expected K*M/2=", K * M / 2, + ", got ", q_col.numel()); + TORCH_CHECK(s_dec_col.numel() == K * M / 16, + "s_dec_col numel mismatch: expected K*M/16=", K * M / 16, ", got ", + s_dec_col.numel()); + // Columnwise outputs are described with the transposed logical shape (K, M). + out_te.set_columnwise_data( + q_col.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(K), static_cast(M)}); + out_te.set_columnwise_scale_inv( + s_dec_col.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(K), static_cast(M / 16)}); + } + } +} + +} // namespace + +// Composite K1 + K2 (back-to-back). with_rht: 16-pt col-wise RHT in both +// (keeps outer + inner SFs consistent). with_swizzle: K2 emits rowwise +// scale_inv in cuBLAS LT swizzled layout (skips downstream swizzle). +void nvfp4_per_token_quantize(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, + int64_t random_sign_mask_t, bool with_swizzle) { + TensorWrapper in_te; + TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); + assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, + columnwise, /*mode=*/0, in_te, out_te); + // Flag the output as carrying GEMM-swizzled scales before the kernel is launched. + if (with_swizzle) out_te.set_with_gemm_swizzled_scales(true); + const auto stream = at::cuda::getCurrentCUDAStream(); + // Only the low 16 bits of random_sign_mask_t are forwarded. + nvte_nvfp4_per_token_quantize(in_te.data(), nullptr, out_te.data(), with_rht ? 1 : 0, + static_cast(random_sign_mask_t & 0xFFFF), with_swizzle ? 1 : 0, + stream); +} + +// K1-only (diagnostic / bench): populates only amax buffers. 
with_rht=true +// applies the 16-pt col-wise RHT before amax (rowwise unaffected); +// random_sign_mask_t low 16 bits = sign-flip pattern. +void nvfp4_per_token_amax(const at::Tensor& input, at::Tensor row_amax, at::Tensor col_amax, + bool rowwise, bool columnwise, bool with_rht, + int64_t random_sign_mask_t) { + at::Tensor empty_u8; // not consumed by K1 + TensorWrapper in_te; + TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); + assemble_per_token_tensors(input, empty_u8, empty_u8, row_amax, empty_u8, empty_u8, col_amax, + rowwise, columnwise, /*mode=*/1, in_te, out_te); + const auto stream = at::cuda::getCurrentCUDAStream(); + // C-API matches prod's `int` convention; only low 16 bits are consumed. + nvte_nvfp4_per_token_amax(in_te.data(), nullptr, out_te.data(), with_rht ? 1 : 0, + static_cast(random_sign_mask_t & 0xFFFF), stream); +} + +// K2-only (bench): reads pre-filled amax, emits FP4 + SFs. with_rht needs +// col_amax from K1 with the SAME mask (else inner SFs miscalibrate). +// with_swizzle: rowwise scale_inv in cuBLAS LT swizzled layout. +void nvfp4_per_token_encode(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, + int64_t random_sign_mask_t, bool with_swizzle) { + TensorWrapper in_te; + TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); + // mode=2: the FP4 / SF output slots are validated and wired in addition to the amax slots. + assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, + columnwise, /*mode=*/2, in_te, out_te); + if (with_swizzle) out_te.set_with_gemm_swizzled_scales(true); + const auto stream = at::cuda::getCurrentCUDAStream(); + nvte_nvfp4_per_token_encode(in_te.data(), nullptr, out_te.data(), with_rht ? 1 : 0, + static_cast(random_sign_mask_t & 0xFFFF), with_swizzle ? 1 : 0, + stream); +} + +// Apply per-token post-scale to a GEMM output (see nvfp4_per_token.h for math). 
+void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor& row_amax_a, + const at::Tensor& row_amax_b) { + TORCH_CHECK(d.is_cuda() && d.is_contiguous(), "d must be a contiguous CUDA tensor"); + TORCH_CHECK(row_amax_a.is_cuda() && row_amax_a.is_contiguous(), + "row_amax_a must be a contiguous CUDA tensor"); + TORCH_CHECK(row_amax_b.is_cuda() && row_amax_b.is_contiguous(), + "row_amax_b must be a contiguous CUDA tensor"); + TORCH_CHECK(d.dim() == 2, "d must be 2D"); + TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bf16"); + TORCH_CHECK(row_amax_a.scalar_type() == at::ScalarType::Float, "row_amax_a must be fp32"); + TORCH_CHECK(row_amax_b.scalar_type() == at::ScalarType::Float, "row_amax_b must be fp32"); + + const int64_t M = d.size(0); + const int64_t N = d.size(1); + TORCH_CHECK(row_amax_a.numel() == M, "row_amax_a numel mismatch: expected M=", M, ", got ", + row_amax_a.numel()); + TORCH_CHECK(row_amax_b.numel() == N, "row_amax_b numel mismatch: expected N=", N, ", got ", + row_amax_b.numel()); + + const auto stream = at::cuda::getCurrentCUDAStream(); + + TensorWrapper d_te = makeTransformerEngineTensor( + d.data_ptr(), std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16); + TensorWrapper ra_te = makeTransformerEngineTensor( + row_amax_a.data_ptr(), std::vector{static_cast(M)}, DType::kFloat32); + TensorWrapper rb_te = makeTransformerEngineTensor( + row_amax_b.data_ptr(), std::vector{static_cast(N)}, DType::kFloat32); + + nvte_nvfp4_per_token_post_scale(d_te.data(), ra_te.data(), rb_te.data(), stream); +} + +// Standalone rowwise-SF swizzle for one NVFP4 operand: 1 launch == +// 1 nvte_swizzle_scaling_factors, mirrors prod's per-operand swizzle. +// Bench-only (--qs); sf_in M-major (M, K/16) -> sf_out swizzled. 
+void nvfp4_per_token_swizzle_rowwise_sf(const at::Tensor& data, const at::Tensor& sf_in, + at::Tensor sf_out) { + TORCH_CHECK(data.is_cuda() && sf_in.is_cuda() && sf_out.is_cuda(), + "All tensors must be CUDA tensors"); + TORCH_CHECK(data.is_contiguous() && sf_in.is_contiguous() && sf_out.is_contiguous(), + "All tensors must be contiguous"); + TORCH_CHECK(data.scalar_type() == at::ScalarType::Byte, "data must be uint8 (FP4 packed)"); + TORCH_CHECK(sf_in.scalar_type() == at::ScalarType::Byte, "sf_in must be uint8 (FP8 e4m3)"); + TORCH_CHECK(sf_out.scalar_type() == at::ScalarType::Byte, "sf_out must be uint8 (FP8 e4m3)"); + TORCH_CHECK(data.dim() == 2, "data must be 2D (M, K/2)"); + TORCH_CHECK(sf_in.numel() == sf_out.numel(), "sf_in/sf_out numel mismatch: ", sf_in.numel(), + " vs ", sf_out.numel()); + + const int64_t m = data.size(0); + const int64_t k = data.size(1) * 2; // FP4 packed + TORCH_CHECK(k % 16 == 0, "k must be a multiple of 16 (NVFP4 inner SFVecSize), got ", k); + TORCH_CHECK(sf_in.numel() == m * k / 16, "sf_in numel mismatch: expected m*k/16=", m * k / 16, + ", got ", sf_in.numel()); + + const std::vector data_shape = {static_cast(m), static_cast(k)}; + const std::vector sf_shape = {static_cast(m), static_cast(k / 16)}; + + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(data.data_ptr(), DType::kFloat4E2M1, data_shape); + in_nvte.set_rowwise_scale_inv(sf_in.data_ptr(), DType::kFloat8E4M3, sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(data.data_ptr(), DType::kFloat4E2M1, data_shape); + out_nvte.set_rowwise_scale_inv(sf_out.data_ptr(), DType::kFloat8E4M3, sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + const auto stream = at::cuda::getCurrentCUDAStream(); + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); +} + +// E2E NVFP4 per-token GEMM: swizzle SFs -> cuBLAS LT (amax pinned to 1.0 +// to cancel 2688^2 inner-SF) -> per-row post-scale. 
beta must be 0. +// a_sf_swizzled/b_sf_swizzled=true skips the in-binding swizzle for that operand. +// skip_post_scale=true returns right after the cuBLAS call, so d is left without the +// a_row_amax / b_row_amax post-scale. +// Raises (TORCH_CHECK) on device / layout / dtype / shape mismatches and on beta != 0. +void nvfp4_per_token_gemm(const at::Tensor& a_data, const at::Tensor& b_data, + const at::Tensor& a_sf, const at::Tensor& b_sf, + const at::Tensor& a_row_amax, const at::Tensor& b_row_amax, at::Tensor d, + const at::Tensor& workspace, int64_t m, int64_t n, int64_t k, + double alpha, double beta, bool a_sf_swizzled, bool b_sf_swizzled, + bool skip_post_scale) { + TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && + a_row_amax.is_cuda() && b_row_amax.is_cuda() && d.is_cuda() && + workspace.is_cuda(), + "All tensors must be CUDA tensors"); + TORCH_CHECK(a_data.is_contiguous() && b_data.is_contiguous() && a_sf.is_contiguous() && + b_sf.is_contiguous() && a_row_amax.is_contiguous() && + b_row_amax.is_contiguous() && d.is_contiguous() && workspace.is_contiguous(), + "All tensors must be contiguous"); + + TORCH_CHECK(a_data.scalar_type() == at::ScalarType::Byte, "a_data must be uint8 (FP4 packed)"); + TORCH_CHECK(b_data.scalar_type() == at::ScalarType::Byte, "b_data must be uint8 (FP4 packed)"); + TORCH_CHECK(a_sf.scalar_type() == at::ScalarType::Byte, "a_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(b_sf.scalar_type() == at::ScalarType::Byte, "b_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(a_row_amax.scalar_type() == at::ScalarType::Float, "a_row_amax must be float32"); + TORCH_CHECK(b_row_amax.scalar_type() == at::ScalarType::Float, "b_row_amax must be float32"); + TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); + TORCH_CHECK(workspace.scalar_type() == at::ScalarType::Byte, "workspace must be uint8"); + + TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, "a_data/b_data/d must be 2D"); + TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, + "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", a_data.size(0), + ", ", a_data.size(1), ")"); + 
TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, + "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", b_data.size(0), + ", ", b_data.size(1), ")"); + TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, + "), got (", d.size(0), ", ", d.size(1), ")"); + + // k % 16 is checked before the k / 16 based size checks so the division is exact. + TORCH_CHECK(k % 16 == 0, "k must be a multiple of 16 (NVFP4 inner SFVecSize)"); + TORCH_CHECK(a_sf.numel() == static_cast(m * k / 16), + "a_sf numel mismatch: expected M*K/16=", m * k / 16, ", got ", a_sf.numel()); + TORCH_CHECK(b_sf.numel() == static_cast(n * k / 16), + "b_sf numel mismatch: expected N*K/16=", n * k / 16, ", got ", b_sf.numel()); + TORCH_CHECK(a_row_amax.numel() == m, "a_row_amax numel mismatch: expected M=", m, ", got ", + a_row_amax.numel()); + TORCH_CHECK(b_row_amax.numel() == n, "b_row_amax numel mismatch: expected N=", n, ", got ", + b_row_amax.numel()); + + TORCH_CHECK(static_cast(beta) == 0.0f, + "nvfp4_per_token_gemm: beta != 0 not yet supported. Got beta=", beta); + + const auto stream = at::cuda::getCurrentCUDAStream(); + + const std::vector a_data_shape = {static_cast(m), static_cast(k)}; + const std::vector b_data_shape = {static_cast(n), static_cast(k)}; + const std::vector a_sf_shape = {static_cast(m), static_cast(k / 16)}; + const std::vector b_sf_shape = {static_cast(n), static_cast(k / 16)}; + + // SF buffers for cuBLAS LT: reuse caller's buffer if already swizzled, + // else allocate a swizzled copy. 0/1/2 swizzle launches total. 
+ // Both scratch buffers are allocated with a_sf's tensor options (byte_opts). + auto byte_opts = a_sf.options().dtype(at::kByte); + at::Tensor a_sf_buf; + at::Tensor b_sf_buf; + if (a_sf_swizzled) { + a_sf_buf = a_sf; + } else { + a_sf_buf = at::empty({a_sf.numel()}, byte_opts); + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + in_nvte.set_rowwise_scale_inv(a_sf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + out_nvte.set_rowwise_scale_inv(a_sf_buf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + if (b_sf_swizzled) { + b_sf_buf = b_sf; + } else { + b_sf_buf = at::empty({b_sf.numel()}, byte_opts); + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + in_nvte.set_rowwise_scale_inv(b_sf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + out_nvte.set_rowwise_scale_inv(b_sf_buf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + + // Pin operand amaxes to 1.0 so cuBLAS-internal alpha cancels the 2688^2 + // inner-SF factor. Cache one fp32 "1.0" tensor per device to avoid the + // ~30-50us per-call cost of at::ones({1}) at small shapes. 
+ // The cache is indexed by a_data's device index; std::call_once makes the first fill per + // device run once. The cached tensor is built from a_data.options(), so it lives on + // a_data's device. + // NOTE(review): the std::array sizes (template arguments) are not visible in this view; + // they bound the device indices accepted by the check below. + static std::array s_amax_one_cache; + static std::array s_amax_one_init; + const int dev_idx = a_data.device().index(); + TORCH_CHECK(dev_idx >= 0 && dev_idx < static_cast(s_amax_one_cache.size()), + "nvfp4_per_token_gemm: unexpected device index ", dev_idx); + std::call_once(s_amax_one_init[dev_idx], [&]() { + auto fp32_opts = a_data.options().dtype(at::kFloat); + s_amax_one_cache[dev_idx] = at::ones({1}, fp32_opts); + }); + at::Tensor& amax_one = s_amax_one_cache[dev_idx]; + + // Assemble A's NVTE tensor: NVFP4_1D_SCALING + swizzled SF + amax=1.0. + TensorWrapper a_te(NVTE_NVFP4_1D_SCALING); + a_te.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + a_te.set_rowwise_scale_inv(a_sf_buf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + a_te.set_amax(amax_one.data_ptr(), DType::kFloat32, std::vector{1}); + a_te.set_with_gemm_swizzled_scales(true); + + // B is assembled the same way and shares the cached 1.0 amax tensor. + TensorWrapper b_te(NVTE_NVFP4_1D_SCALING); + b_te.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + b_te.set_rowwise_scale_inv(b_sf_buf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + b_te.set_amax(amax_one.data_ptr(), DType::kFloat32, std::vector{1}); + b_te.set_with_gemm_swizzled_scales(true); + + TensorWrapper d_te = makeTransformerEngineTensor( + d.data_ptr(), std::vector{static_cast(m), static_cast(n)}, + DType::kBFloat16); + + TensorWrapper workspace_te = makeTransformerEngineTensor( + workspace.data_ptr(), std::vector{static_cast(workspace.numel())}, + DType::kByte); + + // Operands SWAPPED so cuBLAS column-major D = op(B) @ op(A) matches the + // row-major (M, N) PyTorch expects. transa=T forced (NVFP4 is TN-only). + // C and D alias (no separate accumulator). 
+ const float alpha_f = static_cast(alpha); + const float beta_f = static_cast(beta); + nvte_cublas_gemm_v2(/*transa=*/1, /*transb=*/0, &alpha_f, + b_te.data(), // cuBLAS-A := caller's B (N, K) + a_te.data(), // cuBLAS-B := caller's A (M, K) + &beta_f, d_te.data(), d_te.data(), workspace_te.data(), + /*config=*/nullptr, stream); + + // Per-row * per-col post-scale to recover C_true from D_cublas. + // skip_post_scale=true is bench-only: isolates the cuBLAS LT GEMM cost + // from the trailing M*N bf16 epilogue (D will hold raw cuBLAS output). + if (!skip_post_scale) { + TensorWrapper ra_te = makeTransformerEngineTensor( + a_row_amax.data_ptr(), std::vector{static_cast(m)}, DType::kFloat32); + TensorWrapper rb_te = makeTransformerEngineTensor( + b_row_amax.data_ptr(), std::vector{static_cast(n)}, DType::kFloat32); + + nvte_nvfp4_per_token_post_scale(d_te.data(), ra_te.data(), rb_te.data(), stream); + } +} + +// Per-tensor twin of nvfp4_per_token_gemm: scalar amax goes through cuBLAS's +// own amax slot (no post-scale). a_sf_swizzled/b_sf_swizzled=true skips the +// in-binding swizzle (mirrors nvfp4_per_token_gemm). Bench-only baseline. 
+void nvfp4_per_tensor_gemm(const at::Tensor& a_data, const at::Tensor& b_data, + const at::Tensor& a_sf, const at::Tensor& b_sf, const at::Tensor& a_amax, + const at::Tensor& b_amax, at::Tensor d, const at::Tensor& workspace, + int64_t m, int64_t n, int64_t k, double alpha, double beta, + bool a_sf_swizzled, bool b_sf_swizzled) { + TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && + a_amax.is_cuda() && b_amax.is_cuda() && d.is_cuda() && workspace.is_cuda(), + "All tensors must be CUDA tensors"); + TORCH_CHECK(a_data.is_contiguous() && b_data.is_contiguous() && a_sf.is_contiguous() && + b_sf.is_contiguous() && a_amax.is_contiguous() && b_amax.is_contiguous() && + d.is_contiguous() && workspace.is_contiguous(), + "All tensors must be contiguous"); + TORCH_CHECK(a_data.scalar_type() == at::ScalarType::Byte, "a_data must be uint8 (FP4 packed)"); + TORCH_CHECK(b_data.scalar_type() == at::ScalarType::Byte, "b_data must be uint8 (FP4 packed)"); + TORCH_CHECK(a_sf.scalar_type() == at::ScalarType::Byte, "a_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(b_sf.scalar_type() == at::ScalarType::Byte, "b_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(a_amax.scalar_type() == at::ScalarType::Float, "a_amax must be float32"); + TORCH_CHECK(b_amax.scalar_type() == at::ScalarType::Float, "b_amax must be float32"); + TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); + TORCH_CHECK(workspace.scalar_type() == at::ScalarType::Byte, "workspace must be uint8"); + + TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, "a_data/b_data/d must be 2D"); + TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, + "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", a_data.size(0), + ", ", a_data.size(1), ")"); + TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, + "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", b_data.size(0), + ", ", b_data.size(1), ")"); + 
TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, + "), got (", d.size(0), ", ", d.size(1), ")"); + + TORCH_CHECK(k % 16 == 0, "k must be a multiple of 16 (NVFP4 inner SFVecSize)"); + TORCH_CHECK(a_sf.numel() == static_cast(m * k / 16), + "a_sf numel mismatch: expected M*K/16=", m * k / 16, ", got ", a_sf.numel()); + TORCH_CHECK(b_sf.numel() == static_cast(n * k / 16), + "b_sf numel mismatch: expected N*K/16=", n * k / 16, ", got ", b_sf.numel()); + TORCH_CHECK(a_amax.numel() == 1, "a_amax must be a scalar (numel=1), got ", a_amax.numel()); + TORCH_CHECK(b_amax.numel() == 1, "b_amax must be a scalar (numel=1), got ", b_amax.numel()); + + TORCH_CHECK(static_cast(beta) == 0.0f, + "nvfp4_per_tensor_gemm: beta != 0 not yet supported. Got beta=", beta); + + const auto stream = at::cuda::getCurrentCUDAStream(); + + const std::vector a_data_shape = {static_cast(m), static_cast(k)}; + const std::vector b_data_shape = {static_cast(n), static_cast(k)}; + const std::vector a_sf_shape = {static_cast(m), static_cast(k / 16)}; + const std::vector b_sf_shape = {static_cast(n), static_cast(k / 16)}; + + // SF buffers for cuBLAS LT: reuse caller's buffer if already swizzled, + // else allocate a swizzled copy. 0/1/2 swizzle launches total. 
+ auto byte_opts = a_sf.options().dtype(at::kByte); + at::Tensor a_sf_buf; + at::Tensor b_sf_buf; + if (a_sf_swizzled) { + a_sf_buf = a_sf; + } else { + a_sf_buf = at::empty({a_sf.numel()}, byte_opts); + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + in_nvte.set_rowwise_scale_inv(a_sf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + out_nvte.set_rowwise_scale_inv(a_sf_buf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + if (b_sf_swizzled) { + b_sf_buf = b_sf; + } else { + b_sf_buf = at::empty({b_sf.numel()}, byte_opts); + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + in_nvte.set_rowwise_scale_inv(b_sf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + out_nvte.set_rowwise_scale_inv(b_sf_buf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + + // Per-tensor amaxes go in the amax slot; cuBLAS LT folds them into alpha. 
+ TensorWrapper a_te(NVTE_NVFP4_1D_SCALING); + a_te.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + a_te.set_rowwise_scale_inv(a_sf_buf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + a_te.set_amax(a_amax.data_ptr(), DType::kFloat32, std::vector{1}); + a_te.set_with_gemm_swizzled_scales(true); + + TensorWrapper b_te(NVTE_NVFP4_1D_SCALING); + b_te.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + b_te.set_rowwise_scale_inv(b_sf_buf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + b_te.set_amax(b_amax.data_ptr(), DType::kFloat32, std::vector{1}); + b_te.set_with_gemm_swizzled_scales(true); + + TensorWrapper d_te = makeTransformerEngineTensor( + d.data_ptr(), std::vector{static_cast(m), static_cast(n)}, + DType::kBFloat16); + + TensorWrapper workspace_te = makeTransformerEngineTensor( + workspace.data_ptr(), std::vector{static_cast(workspace.numel())}, + DType::kByte); + + // Operand swap: see nvfp4_per_token_gemm. + const float alpha_f = static_cast(alpha); + const float beta_f = static_cast(beta); + nvte_cublas_gemm_v2(/*transa=*/1, /*transb=*/0, &alpha_f, + b_te.data(), // cuBLAS-A := caller's B (N, K) + a_te.data(), // cuBLAS-B := caller's A (M, K) + &beta_f, d_te.data(), d_te.data(), workspace_te.data(), + /*config=*/nullptr, stream); + // No post-scale: per-tensor amaxes already folded into cuBLAS-internal alpha. +} + +// Grouped (multi-tensor) per-token quantize. Each direction takes 3 lists +// of per-split tensors; ``split_sections[i] = M_i`` (% 128, sum = sum_M). +// Disabled direction's lists are ignored. 
+namespace { + +void build_per_token_output_wrapper(TensorWrapper& out_te, int64_t M_i, int64_t K, bool rowwise, + bool columnwise, const at::Tensor& q_row, + const at::Tensor& s_dec_row, const at::Tensor& row_amax, + const at::Tensor& q_col, const at::Tensor& s_dec_col, + const at::Tensor& col_amax) { + if (rowwise) { + TORCH_CHECK(q_row.is_cuda() && q_row.is_contiguous(), "q_row must be a contiguous CUDA tensor"); + TORCH_CHECK(s_dec_row.is_cuda() && s_dec_row.is_contiguous(), + "s_dec_row must be a contiguous CUDA tensor"); + TORCH_CHECK(row_amax.is_cuda() && row_amax.is_contiguous(), + "row_amax must be a contiguous CUDA tensor"); + TORCH_CHECK(q_row.scalar_type() == at::ScalarType::Byte, "q_row must be uint8"); + TORCH_CHECK(s_dec_row.scalar_type() == at::ScalarType::Byte, "s_dec_row must be uint8"); + TORCH_CHECK(row_amax.scalar_type() == at::ScalarType::Float, "row_amax must be fp32"); + TORCH_CHECK(q_row.numel() == M_i * K / 2, "q_row numel mismatch for split: expected ", + M_i * K / 2, ", got ", q_row.numel()); + TORCH_CHECK(s_dec_row.numel() == M_i * K / 16, "s_dec_row numel mismatch for split"); + TORCH_CHECK(row_amax.numel() == M_i, "row_amax numel mismatch for split"); + out_te.set_rowwise_data(q_row.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(M_i), static_cast(K)}); + out_te.set_rowwise_scale_inv( + s_dec_row.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(M_i), static_cast(K / 16)}); + out_te.set_amax(row_amax.data_ptr(), DType::kFloat32, + std::vector{static_cast(M_i)}); + } + if (columnwise) { + TORCH_CHECK(q_col.is_cuda() && q_col.is_contiguous(), "q_col must be a contiguous CUDA tensor"); + TORCH_CHECK(s_dec_col.is_cuda() && s_dec_col.is_contiguous(), + "s_dec_col must be a contiguous CUDA tensor"); + TORCH_CHECK(col_amax.is_cuda() && col_amax.is_contiguous(), + "col_amax must be a contiguous CUDA tensor"); + TORCH_CHECK(q_col.scalar_type() == at::ScalarType::Byte, "q_col must be uint8"); + 
TORCH_CHECK(s_dec_col.scalar_type() == at::ScalarType::Byte, "s_dec_col must be uint8"); + TORCH_CHECK(col_amax.scalar_type() == at::ScalarType::Float, "col_amax must be fp32"); + TORCH_CHECK(q_col.numel() == K * M_i / 2, "q_col numel mismatch for split"); + TORCH_CHECK(s_dec_col.numel() == K * M_i / 16, "s_dec_col numel mismatch for split"); + TORCH_CHECK(col_amax.numel() == K, "col_amax numel mismatch for split"); + out_te.set_columnwise_data( + q_col.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(K), static_cast(M_i)}); + out_te.set_columnwise_scale_inv( + s_dec_col.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(K), static_cast(M_i / 16)}); + out_te.set_columnwise_amax(col_amax.data_ptr(), DType::kFloat32, + std::vector{static_cast(K)}); + } +} + +DType resolve_input_dtype(const at::Tensor& input) { + if (input.scalar_type() == at::ScalarType::BFloat16) return DType::kBFloat16; + if (input.scalar_type() == at::ScalarType::Float) return DType::kFloat32; + if (input.scalar_type() == at::ScalarType::Half) return DType::kFloat16; + TORCH_CHECK(false, "input dtype must be bf16/fp16/fp32, got ", input.scalar_type()); + return DType::kBFloat16; // unreachable +} + +} // namespace + +void nvfp4_per_token_group_quantize( + const at::Tensor& input, const std::vector& split_sections, + std::vector q_row_list, std::vector s_dec_row_list, + std::vector row_amax_list, std::vector q_col_list, + std::vector s_dec_col_list, std::vector col_amax_list, bool rowwise, + bool columnwise, bool with_rht, int64_t random_sign_mask_t) { + TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(input.is_cuda() && input.is_contiguous(), "input must be a contiguous CUDA tensor"); + TORCH_CHECK(input.dim() == 2, "input must be 2D"); + const int64_t sum_M = input.size(0); + const int64_t K = input.size(1); + const size_t num_tensors = split_sections.size(); + TORCH_CHECK(num_tensors > 0, "split_sections must not be empty"); + 
+ // Sum + 64-multiple constraint. + int64_t acc = 0; + for (size_t i = 0; i < num_tensors; ++i) { + TORCH_CHECK(split_sections[i] >= 0, "split_sections[", i, "] must be non-negative"); + TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, "] = ", split_sections[i], + " must be a multiple of 64"); + acc += split_sections[i]; + } + TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, " must equal input.size(0) = ", sum_M); + + if (rowwise) { + TORCH_CHECK(q_row_list.size() == num_tensors, "q_row_list size mismatch"); + TORCH_CHECK(s_dec_row_list.size() == num_tensors, "s_dec_row_list size mismatch"); + TORCH_CHECK(row_amax_list.size() == num_tensors, "row_amax_list size mismatch"); + } + if (columnwise) { + TORCH_CHECK(q_col_list.size() == num_tensors, "q_col_list size mismatch"); + TORCH_CHECK(s_dec_col_list.size() == num_tensors, "s_dec_col_list size mismatch"); + TORCH_CHECK(col_amax_list.size() == num_tensors, "col_amax_list size mismatch"); + } + + const DType in_dtype = resolve_input_dtype(input); + const auto stream = at::cuda::getCurrentCUDAStream(); + + TensorWrapper in_te = makeTransformerEngineTensor( + input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)}, + in_dtype); + + // One TensorWrapper per split; raw NVTETensor handles go into `handles`. + std::vector wrappers; + wrappers.reserve(num_tensors); + std::vector handles; + handles.reserve(num_tensors); + std::vector split_sections_sz(num_tensors); + + at::Tensor empty_dummy; // for slots we don't populate + for (size_t i = 0; i < num_tensors; ++i) { + const int64_t M_i = split_sections[i]; + split_sections_sz[i] = static_cast(M_i); + wrappers.emplace_back(NVTE_NVFP4_1D_SCALING); + if (M_i == 0) { + handles.push_back(wrappers.back().data()); + continue; // empty split is allowed (skipped inside the kernel) + } + build_per_token_output_wrapper( + wrappers.back(), M_i, K, rowwise, columnwise, rowwise ? q_row_list[i] : empty_dummy, + rowwise ? 
s_dec_row_list[i] : empty_dummy, rowwise ? row_amax_list[i] : empty_dummy,
+        columnwise ? q_col_list[i] : empty_dummy, columnwise ? s_dec_col_list[i] : empty_dummy,
+        columnwise ? col_amax_list[i] : empty_dummy);
+    handles.push_back(wrappers.back().data());
+  }
+
+  nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), split_sections_sz.data(),
+                                      num_tensors, rowwise, columnwise, static_cast(with_rht),
+                                      static_cast(random_sign_mask_t), stream);
+}
+
+// Amax-only grouped variant (K1 only); for allReduce-before-cast flows.
+//
+// Validates `input` (contiguous 2D CUDA tensor) and `split_sections` (each entry
+// a non-negative multiple of 64, summing to input.size(0)), binds the caller's
+// per-split amax buffers to one output wrapper per split, and launches
+// nvte_group_nvfp4_per_token_amax on the current CUDA stream.
+//
+//   row_amax_list[i]: fp32 CUDA tensor with split_sections[i] elements (if rowwise)
+//   col_amax_list[i]: fp32 CUDA tensor with K elements (if columnwise)
+//
+// A split of size 0 is accepted: its wrapper is left unpopulated but its handle
+// is still pushed, so handle indices stay aligned with split_sections.
+void nvfp4_per_token_group_amax(const at::Tensor& input, const std::vector& split_sections,
+                                std::vector row_amax_list,
+                                std::vector col_amax_list, bool rowwise,
+                                bool columnwise, bool with_rht, int64_t random_sign_mask_t) {
+  TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True.");
+  TORCH_CHECK(input.is_cuda() && input.is_contiguous(), "input must be a contiguous CUDA tensor");
+  TORCH_CHECK(input.dim() == 2, "input must be 2D");
+  const int64_t sum_M = input.size(0);
+  const int64_t K = input.size(1);
+  const size_t num_tensors = split_sections.size();
+  TORCH_CHECK(num_tensors > 0, "split_sections must not be empty");
+  int64_t acc = 0;
+  for (size_t i = 0; i < num_tensors; ++i) {
+    // Same guard as nvfp4_per_token_group_quantize: a negative multiple of 64
+    // passes the modulo check below and can still sum to sum_M, and in
+    // columnwise-only mode no later check would catch it before the value is
+    // cast for the C API.
+    TORCH_CHECK(split_sections[i] >= 0, "split_sections[", i, "] must be non-negative");
+    TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, "] must be a multiple of 64");
+    acc += split_sections[i];
+  }
+  TORCH_CHECK(acc == sum_M, "sum(split_sections) must equal input.size(0)");
+  if (rowwise) TORCH_CHECK(row_amax_list.size() == num_tensors, "row_amax_list size mismatch");
+  if (columnwise) TORCH_CHECK(col_amax_list.size() == num_tensors, "col_amax_list size mismatch");
+
+  const DType in_dtype = resolve_input_dtype(input);
+  const auto stream = at::cuda::getCurrentCUDAStream();
+
+  TensorWrapper in_te = makeTransformerEngineTensor(
+      input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)},
+      in_dtype);
+
+  // One wrapper per split; `handles` holds the raw handle of each wrapper in
+  // the same order. Both vectors are reserved up front.
+  std::vector wrappers;
+  wrappers.reserve(num_tensors);
+  std::vector handles;
+  handles.reserve(num_tensors);
+  std::vector split_sections_sz(num_tensors);
+
+  for (size_t i = 0; i < num_tensors; ++i) {
+    const int64_t M_i = split_sections[i];
+    split_sections_sz[i] = static_cast(M_i);
+    wrappers.emplace_back(NVTE_NVFP4_1D_SCALING);
+    if (M_i == 0) {
+      // Empty split: no amax buffers are bound or validated.
+      handles.push_back(wrappers.back().data());
+      continue;
+    }
+    if (rowwise) {
+      const at::Tensor& ra = row_amax_list[i];
+      TORCH_CHECK(ra.is_cuda() && ra.scalar_type() == at::ScalarType::Float, "bad row_amax");
+      TORCH_CHECK(ra.numel() == M_i, "row_amax numel mismatch");
+      wrappers.back().set_amax(ra.data_ptr(), DType::kFloat32,
+                               std::vector{static_cast(M_i)});
+    }
+    if (columnwise) {
+      const at::Tensor& ca = col_amax_list[i];
+      TORCH_CHECK(ca.is_cuda() && ca.scalar_type() == at::ScalarType::Float, "bad col_amax");
+      TORCH_CHECK(ca.numel() == K, "col_amax numel mismatch");
+      wrappers.back().set_columnwise_amax(ca.data_ptr(), DType::kFloat32,
+                                          std::vector{static_cast(K)});
+    }
+    handles.push_back(wrappers.back().data());
+  }
+
+  nvte_group_nvfp4_per_token_amax(in_te.data(), handles.data(), split_sections_sz.data(),
+                                  num_tensors, rowwise, columnwise, static_cast(with_rht),
+                                  static_cast(random_sign_mask_t), stream);
+}
+
+// BULK grouped per-token quantize: alloc + view + dispatch in ONE C++ call.
+// Returns 6 per-split tensor lists (s_dec_* pre-cast to Float8_e4m3fn).
+// Byte-equal to the prior Python wrap (saves ~70-90us at N=8).
+std::tuple, std::vector, std::vector,
+           std::vector, std::vector, std::vector>
+nvfp4_per_token_group_quantize_bulk(const at::Tensor& input,
+                                    const std::vector& split_sections, bool rowwise,
+                                    bool columnwise, bool with_rht, int64_t random_sign_mask_t) {
+  // Validation mirrors _validate_per_token_group_input in Python.
+ TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + TORCH_CHECK(input.is_contiguous(), "x_concat must be contiguous (row-major)"); + TORCH_CHECK(input.dim() == 2, "nvfp4_per_token_group_quantize expects a 2D input, got ", + input.dim(), "D"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, + "Per-token grouped kernel is bf16-only; got dtype ", input.scalar_type()); + + const int64_t sum_M = input.size(0); + const int64_t K = input.size(1); + constexpr int64_t kPerTokenTile = 128; + constexpr int64_t kBlockK = 16; + + TORCH_CHECK(K % kPerTokenTile == 0, "Per-token grouped kernel requires K % ", kPerTokenTile, + " == 0; got K=", K); + + const size_t num_tensors = split_sections.size(); + TORCH_CHECK(num_tensors > 0, "split_sections must not be empty"); + TORCH_CHECK(num_tensors <= 64, "num_tensors must be <= 64 (kernel arg-struct cap); got ", + num_tensors); + + int64_t acc = 0; + for (size_t i = 0; i < num_tensors; ++i) { + const int64_t M_i = split_sections[i]; + TORCH_CHECK(M_i > 0, "split_sections[", i, "] must be > 0, got ", M_i); + TORCH_CHECK(M_i % kPerTokenTile == 0, "split_sections[", i, "] = ", M_i, + " must be a multiple of ", kPerTokenTile); + acc += M_i; + } + TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, " must equal input.size(0) = ", sum_M); + + // Bulk allocation: one at::empty per output type, covers all splits. 
+ auto opts_u8 = input.options().dtype(at::kByte); + auto opts_f32 = input.options().dtype(at::kFloat); + + at::Tensor q_row_bulk, s_dec_row_bulk, row_amax_bulk; + at::Tensor q_col_bulk, s_dec_col_bulk, col_amax_bulk; + + if (rowwise) { + q_row_bulk = at::empty({sum_M, K / 2}, opts_u8); + s_dec_row_bulk = at::empty({sum_M, K / kBlockK}, opts_u8); + row_amax_bulk = at::empty({sum_M}, opts_f32); + } + if (columnwise) { + q_col_bulk = at::empty({K * sum_M / 2}, opts_u8); + s_dec_col_bulk = at::empty({K * sum_M / kBlockK}, opts_u8); + col_amax_bulk = at::empty({static_cast(num_tensors), K}, opts_f32); + } + + // Per-split views built in C++; s_dec_* kept in both uint8 (for binding) + // and fp8_e4m3fn (returned to Python directly). + std::vector q_row_list, s_dec_row_u8_list, row_amax_list; + std::vector q_col_list, s_dec_col_u8_list, col_amax_list; + std::vector s_dec_row_fp8_list, s_dec_col_fp8_list; + if (rowwise) { + q_row_list.reserve(num_tensors); + s_dec_row_u8_list.reserve(num_tensors); + row_amax_list.reserve(num_tensors); + s_dec_row_fp8_list.reserve(num_tensors); + } + if (columnwise) { + q_col_list.reserve(num_tensors); + s_dec_col_u8_list.reserve(num_tensors); + col_amax_list.reserve(num_tensors); + s_dec_col_fp8_list.reserve(num_tensors); + } + + int64_t m_off = 0; + for (size_t i = 0; i < num_tensors; ++i) { + const int64_t M_i = split_sections[i]; + if (rowwise) { + q_row_list.emplace_back(q_row_bulk.narrow(0, m_off, M_i)); + s_dec_row_u8_list.emplace_back(s_dec_row_bulk.narrow(0, m_off, M_i)); + row_amax_list.emplace_back(row_amax_bulk.narrow(0, m_off, M_i)); + s_dec_row_fp8_list.emplace_back(s_dec_row_u8_list.back().view(at::kFloat8_e4m3fn)); + } + if (columnwise) { + auto q_col_flat = q_col_bulk.narrow(0, K * m_off / 2, K * M_i / 2); + q_col_list.emplace_back(q_col_flat.view({K, M_i / 2})); + auto s_dec_col_flat = s_dec_col_bulk.narrow(0, K * m_off / kBlockK, K * M_i / kBlockK); + s_dec_col_u8_list.emplace_back(s_dec_col_flat.view({K, M_i / 
kBlockK})); + col_amax_list.emplace_back(col_amax_bulk.select(0, static_cast(i))); + s_dec_col_fp8_list.emplace_back(s_dec_col_u8_list.back().view(at::kFloat8_e4m3fn)); + } + m_off += M_i; + } + + // Dispatch K1+K2 grouped kernel via the same C-API the thin entry uses. + const auto stream = at::cuda::getCurrentCUDAStream(); + TensorWrapper in_te = makeTransformerEngineTensor( + input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)}, + DType::kBFloat16); + + std::vector wrappers; + wrappers.reserve(num_tensors); + std::vector handles; + handles.reserve(num_tensors); + std::vector split_sections_sz(num_tensors); + + at::Tensor empty_dummy; + for (size_t i = 0; i < num_tensors; ++i) { + const int64_t M_i = split_sections[i]; + split_sections_sz[i] = static_cast(M_i); + wrappers.emplace_back(NVTE_NVFP4_1D_SCALING); + build_per_token_output_wrapper( + wrappers.back(), M_i, K, rowwise, columnwise, rowwise ? q_row_list[i] : empty_dummy, + rowwise ? s_dec_row_u8_list[i] : empty_dummy, rowwise ? row_amax_list[i] : empty_dummy, + columnwise ? q_col_list[i] : empty_dummy, columnwise ? s_dec_col_u8_list[i] : empty_dummy, + columnwise ? 
col_amax_list[i] : empty_dummy); + handles.push_back(wrappers.back().data()); + } + + nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), split_sections_sz.data(), + num_tensors, rowwise, columnwise, static_cast(with_rht), + static_cast(random_sign_mask_t), stream); + + return std::make_tuple(std::move(q_row_list), std::move(s_dec_row_fp8_list), + std::move(row_amax_list), std::move(q_col_list), + std::move(s_dec_col_fp8_list), std::move(col_amax_list)); +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a4571c64e2..a5bd586dfa 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -342,6 +342,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("compute_amax", &transformer_engine::pytorch::compute_amax, "Compute absolute max value in tensor", py::arg("input"), py::arg("amax"), py::call_guard()); + m.def("hadamard_transform_amax", &transformer_engine::pytorch::hadamard_transform_amax, + "K1 of the NVFP4Quantizer RHT+post_rht_amax path: rowwise (pre-RHT) + " + "columnwise (RHT(input.T)) amax in one launch. Bench-only entry.", + py::arg("input"), py::arg("rowwise_amax"), py::arg("columnwise_amax"), + py::arg("rht_matrix_random_sign_mask"), py::call_guard()); m.def("fused_amax_and_scale_update_after_reduction", &transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction, "Update amax history and FP8 scale/scale_inv after reduction", @@ -390,6 +395,96 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output_rowwise"), py::arg("output_colwise"), py::arg("scale_inv_rowwise"), py::arg("scale_inv_colwise"), py::arg("rows"), py::arg("cols"), py::arg("start_offset"), py::call_guard()); + m.def("nvfp4_per_token_quantize", &transformer_engine::pytorch::nvfp4_per_token_quantize, + "NVFP4 per-token cast (composite K1 amax + K2 encode). 
" + "with_rht=True: 16-pt col-wise RHT in K1+K2; " + "with_swizzle=True: rowwise scale_inv in cuBLAS LT swizzled layout.", + py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), + py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), py::arg("rowwise"), + py::arg("columnwise"), py::arg("with_rht") = false, + py::arg("random_sign_mask_t") = static_cast(0xACE1), + py::arg("with_swizzle") = false); + m.def("nvfp4_per_token_amax", &transformer_engine::pytorch::nvfp4_per_token_amax, + "K1-only: per-row/per-col outer amax via TMA + atomicMax. Bench/diagnostic. " + "with_rht=True applies a 16-pt col-wise RHT before amax.", + py::arg("input"), py::arg("row_amax"), py::arg("col_amax"), py::arg("rowwise"), + py::arg("columnwise"), py::arg("with_rht") = false, + py::arg("random_sign_mask_t") = static_cast(0xACE1)); + m.def("nvfp4_per_token_encode", &transformer_engine::pytorch::nvfp4_per_token_encode, + "K2-only: FP4 + e4m3 SF encode given pre-filled amax buffers. Bench/diagnostic. " + "with_rht=True requires col_amax produced by a K1 launch with the same mask; " + "with_swizzle=True writes rowwise scale_inv directly in the swizzled layout.", + py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), + py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), py::arg("rowwise"), + py::arg("columnwise"), py::arg("with_rht") = false, + py::arg("random_sign_mask_t") = static_cast(0xACE1), + py::arg("with_swizzle") = false); + m.def("nvfp4_cutlass_gemm", &transformer_engine::pytorch::nvfp4_cutlass_gemm, + "Stage-1 forked CUTLASS NVFP4 x NVFP4 -> BF16 GEMM with scalar (alpha, " + "beta) epilogue. Drop-in replacement for cuBLAS LT NVFP4. Pair with " + "nvfp4_per_token_post_scale to get a CUTLASS-based per-token GEMM. 
" + "a_sf_swizzled / b_sf_swizzled = true skip the internal swizzle for " + "bench parity with the cuBLAS LT path.", + py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), py::arg("d"), + py::arg("m"), py::arg("n"), py::arg("k"), py::arg("alpha"), py::arg("beta"), + py::arg("a_sf_swizzled") = false, py::arg("b_sf_swizzled") = false); + m.def("nvfp4_cutlass_per_token_gemm", &transformer_engine::pytorch::nvfp4_cutlass_per_token_gemm, + "Forked CUTLASS NVFP4 GEMM with per-token rescale fused into the " + "epilogue: D = bf16(alpha_a[i] * alpha_b[j] * (A @ B^T)[i, j]). One " + "launch, no separate post-scale kernel. alpha_a (M,) and alpha_b (N,) " + "are fp32 outer-scale vectors. a_sf_swizzled / b_sf_swizzled = true " + "skip the internal swizzle for bench parity.", + py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), py::arg("alpha_a"), + py::arg("alpha_b"), py::arg("d"), py::arg("m"), py::arg("n"), py::arg("k"), + py::arg("a_sf_swizzled") = false, py::arg("b_sf_swizzled") = false); + m.def("nvfp4_per_token_post_scale", &transformer_engine::pytorch::nvfp4_per_token_post_scale, + "Apply d[i,j] *= row_amax_a[i] * row_amax_b[j] in-place on bf16 D.", py::arg("d"), + py::arg("row_amax_a"), py::arg("row_amax_b")); + m.def("nvfp4_per_token_swizzle_rowwise_sf", + &transformer_engine::pytorch::nvfp4_per_token_swizzle_rowwise_sf, + "Standalone rowwise SF swizzle (1 launch); mirrors prod's per-operand swizzle. " + "data (M, K/2) FP4; sf_in (M, K/16) M-major; sf_out (M, K/16) swizzled.", + py::arg("data"), py::arg("sf_in"), py::arg("sf_out")); + m.def("nvfp4_per_token_gemm", &transformer_engine::pytorch::nvfp4_per_token_gemm, + "E2E NVFP4 per-token GEMM: swizzle SFs -> cuBLAS LT -> row*col post-scale. " + "beta must be 0. a_sf_swizzled/b_sf_swizzled=True skips that operand's swizzle. 
" + "skip_post_scale=True is bench-only (isolates cuBLAS LT GEMM cost).", + py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), + py::arg("a_row_amax"), py::arg("b_row_amax"), py::arg("d"), py::arg("workspace"), + py::arg("m"), py::arg("n"), py::arg("k"), py::arg("alpha"), py::arg("beta"), + py::arg("a_sf_swizzled") = false, py::arg("b_sf_swizzled") = false, + py::arg("skip_post_scale") = false); + m.def("nvfp4_per_tensor_gemm", &transformer_engine::pytorch::nvfp4_per_tensor_gemm, + "Skinny prod NVFP4 GEMM twin of nvfp4_per_token_gemm: per-tensor amaxes " + "folded into cuBLAS alpha, no trailing post-scale. Bench-only. " + "a_sf_swizzled/b_sf_swizzled=True skips that operand's swizzle.", + py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), py::arg("a_amax"), + py::arg("b_amax"), py::arg("d"), py::arg("workspace"), py::arg("m"), py::arg("n"), + py::arg("k"), py::arg("alpha"), py::arg("beta"), py::arg("a_sf_swizzled") = false, + py::arg("b_sf_swizzled") = false); + m.def("nvfp4_per_token_group_quantize", + &transformer_engine::pytorch::nvfp4_per_token_group_quantize, + "Grouped (multi-tensor) NVFP4 per-token cast: K1 + K2 across <= 64 splits " + "of a single (sum_M, K) input. Byte-equal to a for-loop of single-tensor. " + "with_rht=True applies a 16-pt col-wise RHT in both K1 and K2.", + py::arg("input"), py::arg("split_sections"), py::arg("q_row_list"), + py::arg("s_dec_row_list"), py::arg("row_amax_list"), py::arg("q_col_list"), + py::arg("s_dec_col_list"), py::arg("col_amax_list"), py::arg("rowwise"), + py::arg("columnwise"), py::arg("with_rht") = false, + py::arg("random_sign_mask_t") = static_cast(0xACE1)); + m.def("nvfp4_per_token_group_amax", &transformer_engine::pytorch::nvfp4_per_token_group_amax, + "K1-only variant of nvfp4_per_token_group_quantize: only fills amax slots. 
" + "with_rht / random_sign_mask_t must match the trailing cast launch.", + py::arg("input"), py::arg("split_sections"), py::arg("row_amax_list"), + py::arg("col_amax_list"), py::arg("rowwise"), py::arg("columnwise"), + py::arg("with_rht") = false, py::arg("random_sign_mask_t") = static_cast(0xACE1)); + m.def("nvfp4_per_token_group_quantize_bulk", + &transformer_engine::pytorch::nvfp4_per_token_group_quantize_bulk, + "Bulk grouped quantize: allocates per-split buffers + view-slices inside " + "the binding (one pybind hop instead of 1 + 6N), then dispatches the K1+K2 " + "kernel. with_rht=True applies a 16-pt col-wise RHT in both K1 and K2.", + py::arg("input"), py::arg("split_sections"), py::arg("rowwise"), py::arg("columnwise"), + py::arg("with_rht") = false, py::arg("random_sign_mask_t") = static_cast(0xACE1)); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index c02d2ec616..d9d21a78bf 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -29,6 +29,32 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream()); } +// Thin pybind for nvte_hadamard_transform_amax: K1 of the production +// NVFP4Quantizer(with_rht, with_post_rht_amax) path. Computes rowwise (pre-RHT) +// and columnwise (RHT(input.T)) amax in one launch. Bench-only entry. 
+void hadamard_transform_amax(const at::Tensor& tensor, at::Tensor& rowwise_amax, + at::Tensor& columnwise_amax, int64_t rht_matrix_random_sign_mask) { + auto input_tensor = tensor.contiguous(); + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + + TORCH_CHECK(rowwise_amax.scalar_type() == at::kFloat, "rowwise_amax must be a float tensor"); + TORCH_CHECK(rowwise_amax.numel() == 1, "rowwise_amax must have exactly one element"); + TORCH_CHECK(columnwise_amax.scalar_type() == at::kFloat, + "columnwise_amax must be a float tensor"); + TORCH_CHECK(columnwise_amax.numel() == 1, "columnwise_amax must have exactly one element"); + + // Mirror NVFP4Quantizer: empty NVFP4_1D_SCALING with two amax slots. + TensorWrapper te_output(NVTE_NVFP4_1D_SCALING); + te_output.set_amax(rowwise_amax.data_ptr(), DType::kFloat32, std::vector{1}); + te_output.set_columnwise_amax(columnwise_amax.data_ptr(), DType::kFloat32, + std::vector{1}); + + nvte_hadamard_transform_amax(te_input.data(), te_output.data(), + /*random_sign_mask=*/0, + static_cast(rht_matrix_random_sign_mask), + at::cuda::getCurrentCUDAStream()); +} + void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer, std::vector amax_histories, std::vector scales, diff --git a/transformer_engine/pytorch/custom_recipes/__init__.py b/transformer_engine/pytorch/custom_recipes/__init__.py index f115ffe743..6d21422bb3 100644 --- a/transformer_engine/pytorch/custom_recipes/__init__.py +++ b/transformer_engine/pytorch/custom_recipes/__init__.py @@ -3,3 +3,35 @@ # See LICENSE for license information. """Experimental features and APIs.""" + +# Per-token NVFP4: per-row outer + 1x16 e4m3 inner SF; cuBLAS LT NVFP4 GEMM +# with operand amaxes pinned to 1.0 and a trailing row-amax post-scale. +# See quantization_nvfp4_per_token.py / gemm_nvfp4_per_token.py for the math. 
+from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token import ( + NVFP4QuantizerPerTokenRef, + RefNVFP4TensorPerToken, + nvfp4_per_token_amax, + nvfp4_per_token_encode, + nvfp4_per_token_quantize, +) +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token_group import ( + nvfp4_per_token_group_quantize, +) +from transformer_engine.pytorch.custom_recipes.gemm_nvfp4_per_token import ( + dequantize_nvfp4_per_token, + nvfp4_per_token_gemm, + nvfp4_per_token_gemm_dequant, +) + + +__all__ = [ + "NVFP4QuantizerPerTokenRef", + "RefNVFP4TensorPerToken", + "nvfp4_per_token_quantize", + "nvfp4_per_token_group_quantize", + "nvfp4_per_token_amax", + "nvfp4_per_token_encode", + "dequantize_nvfp4_per_token", + "nvfp4_per_token_gemm", + "nvfp4_per_token_gemm_dequant", +] diff --git a/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py b/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py new file mode 100644 index 0000000000..033c98729e --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py @@ -0,0 +1,217 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Reference + production GEMM for the NVFP4 per-token quantization scheme. + +Per-token GEMM reuses cuBLAS LT NVFP4 (no TE fork) + a trailing row-amax +post-scale. Each side is a (data, scale, row_amax) triple matching what +tex.nvfp4_per_token_quantize emits. See include/transformer_engine/nvfp4_per_token.h. +""" + +from __future__ import annotations + +from typing import Optional + +import torch + +# get_cublas_workspace is imported lazily inside nvfp4_per_token_gemm to +# avoid a circular import with cpp_extensions.gemm at module load time. 
+from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import cast_from_fp4x2
+from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token import (
+    BLOCK_K,
+    FLOAT4_E2M1_MAX,
+    FLOAT8_E4M3_MAX,
+    _AMAX_FLOOR,
+    RefNVFP4TensorPerToken,
+)
+
+
+__all__ = [
+    "dequantize_nvfp4_per_token",
+    "nvfp4_per_token_gemm_dequant",
+    "nvfp4_per_token_gemm",
+]
+
+
+# Reference: dequantize + reference matmul.
+
+
+def _validate_per_token_triple(
+    data: torch.Tensor, scale: torch.Tensor, row_amax: torch.Tensor, side: str
+) -> int:
+    """Check that one (data, scale, row_amax) triple is self-consistent; return K.
+
+    ``data`` stores two FP4 codes per byte, so the logical K is twice its
+    second dimension. ``side`` is only used to label error messages.
+
+    Raises ``ValueError`` on any rank or shape mismatch.
+    """
+    ranks_ok = data.ndim == 2 and scale.ndim == 2 and row_amax.ndim == 1
+    if not ranks_ok:
+        raise ValueError(
+            f"{side}: expected 2D data/scale + 1D row_amax, got dims "
+            f"data={data.ndim}, scale={scale.ndim}, row_amax={row_amax.ndim}"
+        )
+    rows, packed_cols = data.shape
+    K = 2 * packed_cols  # FP4 packs 2 values/byte.
+    if K % BLOCK_K:
+        raise ValueError(f"{side}: K={K} must be a multiple of BLOCK_K={BLOCK_K}")
+    expected_scale_shape = (rows, K // BLOCK_K)
+    if scale.shape != expected_scale_shape:
+        raise ValueError(f"{side}: scale shape {tuple(scale.shape)} != ({rows}, {K // BLOCK_K})")
+    if row_amax.shape != (rows,):
+        raise ValueError(f"{side}: row_amax shape {tuple(row_amax.shape)} != ({rows},)")
+    return K
+
+
+def dequantize_nvfp4_per_token(
+    data: torch.Tensor, scale: torch.Tensor, row_amax: torch.Tensor
+) -> torch.Tensor:
+    """Expand a per-token NVFP4 (data, scale, row_amax) triple back to fp32.
+
+    x[i, k] = code[i, k] * s_dec[i, k//16] * row_amax[i] / (FP4_MAX * E4M3_MAX).
+
+    ``scale`` may already be ``float8_e4m3fn``; any other dtype is
+    reinterpreted (viewed) as ``float8_e4m3fn`` before widening to fp32.
+    """
+    K = _validate_per_token_triple(data, scale, row_amax, "dequant")
+    num_rows = data.shape[0]
+
+    # Decode the packed bytes through cast_from_fp4x2 into fp32 values.
+    fp4_values = cast_from_fp4x2(data.contiguous().view(dtype=torch.uint8), torch.float32)
+
+    scale_e4m3 = scale if scale.dtype == torch.float8_e4m3fn else scale.view(torch.float8_e4m3fn)
+    s_dec = scale_e4m3.to(torch.float32)
+
+    # Per-row outer factor, then per-block factor, then one factor per element.
+    outer_decode = row_amax.to(torch.float32) / (FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX)
+    block_decode = s_dec * outer_decode.unsqueeze(-1)
+    elem_decode = block_decode.repeat_interleave(BLOCK_K, dim=1)
+    assert elem_decode.shape == (num_rows, K)
+    return fp4_values * elem_decode
+
+
+def nvfp4_per_token_gemm_dequant(
+    a_data: torch.Tensor,
+    a_scale: torch.Tensor,
+    a_row_amax: torch.Tensor,
+    b_data: torch.Tensor,
+    b_scale: torch.Tensor,
+    b_row_amax: torch.Tensor,
+    *,
+    out_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """Reference C = A @ B^T: dequantize both operands, then matmul in fp32.
+
+    Both triples are validated and must share the same logical K; the
+    product is cast to ``out_dtype`` before it is returned.
+
+    Agrees with the cuBLAS LT path at TF32 precision (~1e-3 relative).
+    Exists as executable docs of the math chain and a sanity oracle.
+    """
+    k_of_a = _validate_per_token_triple(a_data, a_scale, a_row_amax, "A")
+    k_of_b = _validate_per_token_triple(b_data, b_scale, b_row_amax, "B")
+    if k_of_a != k_of_b:
+        raise ValueError(f"K mismatch between A and B: {k_of_a} vs {k_of_b}")
+
+    a_deq = dequantize_nvfp4_per_token(a_data, a_scale, a_row_amax)
+    b_deq = dequantize_nvfp4_per_token(b_data, b_scale, b_row_amax)
+    return torch.matmul(a_deq, b_deq.t()).to(out_dtype)
+
+
+# Production wrapper: cuBLAS LT NVFP4 GEMM + per-token post-scale.
+
+
+def nvfp4_per_token_gemm(
+    a_data: torch.Tensor,
+    a_scale: torch.Tensor,
+    a_row_amax: torch.Tensor,
+    b_data: torch.Tensor,
+    b_scale: torch.Tensor,
+    b_row_amax: torch.Tensor,
+    *,
+    out: Optional[torch.Tensor] = None,
+    alpha: float = 1.0,
+    beta: float = 0.0,
+    out_dtype: torch.dtype = torch.bfloat16,
+    a_sf_swizzled: bool = False,
+    b_sf_swizzled: bool = False,
+) -> torch.Tensor:
+    """Production C = alpha * (A @ B^T) via cuBLAS LT NVFP4 + per-token post-scale.
+
+    Binding swizzles compact SFs in-flight, runs cuBLAS LT NVFP4 with operand
+    amaxes pinned to 1.0, then applies the row_amax_A * row_amax_B post-scale.
+    Output is bf16 (cuBLAS LT NVFP4 locks D to bf16/fp32); beta != 0 unsupported.
+
+    ``a_sf_swizzled`` / ``b_sf_swizzled = True`` skips the in-binding swizzle
+    for that operand (caller's SF is already in the cuBLAS LT swizzled layout
+    e.g. from ``nvfp4_per_token_quantize(..., with_swizzle=True)``).
+
+    Parameters
+    ----------
+    a_data, a_scale, a_row_amax:
+        Operand A as a per-token triple: ``(M, K // 2)`` packed FP4 bytes,
+        ``(M, K // 16)`` block scales (e4m3 view or raw uint8 storage) and
+        ``(M,)`` row amax. ``M`` must be a non-zero multiple of 128.
+    b_data, b_scale, b_row_amax:
+        Operand B in the same layout with ``N`` rows. ``N`` must be a
+        non-zero multiple of 128; B must share A's device and K.
+    out:
+        Optional preallocated destination. Must be a contiguous ``(M, N)``
+        bf16 tensor on the operands' device; it is written in place.
+    alpha:
+        Scalar forwarded to the binding.
+    beta:
+        Must be 0; any other value raises ``ValueError``.
+    out_dtype:
+        Dtype of the returned tensor. The binding always fills a bf16
+        buffer; any other dtype is returned as a converted copy.
+
+    Returns
+    -------
+    torch.Tensor
+        The ``(M, N)`` result: the bf16 buffer itself when ``out_dtype`` is
+        bf16, otherwise ``.to(out_dtype)`` of it.
+
+    Raises
+    ------
+    ValueError
+        On non-zero ``beta``, an inconsistent triple, mismatched K or
+        devices, unsupported M/N, or an unusable ``out`` tensor.
+    """
+    import transformer_engine_torch as tex  # type: ignore
+
+    # Fail fast on the unsupported accumulate mode, before any shape
+    # validation or output allocation.
+    if float(beta) != 0.0:
+        raise ValueError(
+            f"nvfp4_per_token_gemm: beta != 0 not yet supported, got beta={beta}. "
+            "Use beta=0 and accumulate outside the call if needed."
+        )
+
+    K_a = _validate_per_token_triple(a_data, a_scale, a_row_amax, "A")
+    K_b = _validate_per_token_triple(b_data, b_scale, b_row_amax, "B")
+    if K_a != K_b:
+        raise ValueError(f"K mismatch between A and B: {K_a} vs {K_b}")
+    K = K_a
+    M = a_data.shape[0]
+    N = b_data.shape[0]
+
+    if K % 16 != 0:
+        raise ValueError(f"K must be a multiple of 16 (got K={K})")
+    # cuBLAS LT NVFP4 SF buffer is padded to (roundup(rows, 128), roundup(K/16, 4)).
+    # Our compact quantize emits (rows, K/16); SF padding is a TODO so reject M/N < 128.
+    if M < 128 or M % 128 != 0:
+        raise ValueError(f"M must be a multiple of 128 (got M={M}); SF padding is a TODO.")
+    if N < 128 or N % 128 != 0:
+        raise ValueError(f"N must be a multiple of 128 (got N={N}); SF padding is a TODO.")
+    if a_data.device != b_data.device:
+        raise ValueError(
+            f"A and B must be on the same device (got {a_data.device} vs {b_data.device})"
+        )
+    device = a_data.device
+
+    if out is None:
+        out_bf16 = torch.empty((M, N), dtype=torch.bfloat16, device=device)
+    else:
+        if out.shape != (M, N):
+            raise ValueError(f"out shape {tuple(out.shape)} != ({M}, {N})")
+        if out.dtype != torch.bfloat16:
+            raise ValueError(
+                f"out dtype must be bf16 for in-place use, got {out.dtype}. "
+                "(The binding produces bf16; pass `out=None` for non-bf16 dtypes "
+                "and the result will be cast at the end.)"
+            )
+        if out.device != device:
+            raise ValueError(f"out must be on the operands' device {device}, got {out.device}")
+        # The binding is handed only M/N/K alongside `out`, with no strides, so
+        # a strided view cannot be described to it -- presumably it assumes
+        # dense row-major storage. Reject such views instead of risking
+        # writes to the wrong elements.
+        if not out.is_contiguous():
+            raise ValueError("out must be contiguous")
+        out_bf16 = out
+
+    a_data_u8 = a_data.contiguous().view(dtype=torch.uint8)
+    b_data_u8 = b_data.contiguous().view(dtype=torch.uint8)
+
+    # Binding expects uint8 SFs (accepts both e4m3 view and raw uint8 storage).
+    a_scale_u8 = a_scale.contiguous().view(dtype=torch.uint8)
+    b_scale_u8 = b_scale.contiguous().view(dtype=torch.uint8)
+    a_scale_u8_flat = a_scale_u8.reshape(-1)
+    b_scale_u8_flat = b_scale_u8.reshape(-1)
+
+    a_row_amax_f32 = a_row_amax.to(torch.float32).contiguous()
+    b_row_amax_f32 = b_row_amax.to(torch.float32).contiguous()
+
+    # Lazy import to break the cpp_extensions.gemm circular import.
+    from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace
+
+    workspace = get_cublas_workspace(device.index, ub=False, grouped_gemm=False)
+
+    tex.nvfp4_per_token_gemm(
+        a_data_u8,
+        b_data_u8,
+        a_scale_u8_flat,
+        b_scale_u8_flat,
+        a_row_amax_f32,
+        b_row_amax_f32,
+        out_bf16,
+        workspace,
+        M,
+        N,
+        K,
+        float(alpha),
+        float(beta),
+        a_sf_swizzled=a_sf_swizzled,
+        b_sf_swizzled=b_sf_swizzled,
+    )
+
+    return out_bf16 if out_dtype is torch.bfloat16 else out_bf16.to(out_dtype)
diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py
new file mode 100644
index 0000000000..7af52c89ac
--- /dev/null
+++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py
@@ -0,0 +1,422 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+from __future__ import annotations
+
+import dataclasses
+from typing import Optional, Tuple
+
+import torch
+
+from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import cast_to_fp4x2
+
+# Inner sub-block size along K is fixed by the NVFP4 spec (one E4M3
+# ``s_dec`` per 16 FP4 samples); only the outer-amax granularity changes
+# between per-token / per-tensor / blocked / 2D.
+BLOCK_K: int = 16
+
+# E2M1 / E4M3 numeric extrema (matches ``TypeExtrema`` in core_nvfp4.cuh).
+FLOAT4_E2M1_MAX: float = 6.0
+FLOAT8_E4M3_MAX: float = 448.0
+
+# Matches the kernel's ``fmaxf(row_amax, 1e-12f)`` clamp on the divisor of
+# ``compute_global_encode_scaling_factor_FP4``.
+_AMAX_FLOOR: float = 1e-12
+
+
+@dataclasses.dataclass
+class RefNVFP4TensorPerToken:
+    """Container for the per-token reference output.
+
+    Attributes
+    ----------
+    data:
+        Packed rowwise FP4 bytes, ``(M, N // 2)`` ``uint8``.
+    scale:
+        Per-1x16-block rowwise decode scale (E4M3), ``(M, N // 16)``
+        ``float8_e4m3fn``.
+ row_amax: + Per-row outer amax, ``(M,)`` ``float32``. This replaces the + per-tensor path's single-scalar ``amax`` and the blocked path's + per-window ``window_amax``. + columnwise_data, columnwise_scale, col_amax: + Their columnwise (transposed) counterparts. Shapes are + ``(N, M // 2)``, ``(N, M // 16)``, and ``(N,)`` respectively. + ``None`` if columnwise was not requested. + """ + + data: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + row_amax: Optional[torch.Tensor] = None + columnwise_data: Optional[torch.Tensor] = None + columnwise_scale: Optional[torch.Tensor] = None + col_amax: Optional[torch.Tensor] = None + + +class NVFP4QuantizerPerTokenRef: + """Pure-PyTorch reference for the NVFP4 per-token cast kernel. + + Constructor takes the two output-direction switches (``rowwise`` and + ``columnwise``). RHT, 2D scaling, and stochastic rounding are not + exposed because the per-token CUDA kernel does not implement them + (the per-token path is target-shape simple-and-fast: per-row outer + + 1x16 inner SF, nothing else). + + The arithmetic chain (``S_enc``, ``s_dec``, ``block_scale``, FP4 cast) + matches ``NVFP4Quantizer1x64Ref`` / ``NVFP4QuantizerBlockedRef``; + only the outer-amax granularity differs: + + * 1x64Ref / BlockedRef : one outer amax per ``OUTER_K``-K-window + * **PerTokenRef** : one outer amax per row (full K window) + """ + + def __init__( + self, + rowwise: bool = True, + columnwise: bool = False, + ) -> None: + if not rowwise and not columnwise: + raise ValueError("At least one of rowwise / columnwise must be True.") + self.rowwise = rowwise + self.columnwise = columnwise + + def _quantize_2d(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run the per-token reference math on a 2D input along its trailing dim. + + Returns ``(qx, sx, row_amax)`` where ``qx`` is ``(M, N // 2)`` + ``uint8``, ``sx`` is ``(M, N // BLOCK_K)`` ``float8_e4m3fn``, + and ``row_amax`` is ``(M,)`` ``float32``. 
+ + The columnwise pass is implemented by calling this routine on + ``x.transpose(0, 1).contiguous()``. + """ + if x.ndim != 2: + raise ValueError(f"NVFP4QuantizerPerTokenRef expects a 2D tensor, got {x.ndim}D") + M, N = x.shape + if N % BLOCK_K != 0: + raise ValueError(f"N={N} must be a multiple of BLOCK_K={BLOCK_K}") + + device = x.device + fp32_max = torch.tensor(torch.finfo(torch.float32).max, device=device, dtype=torch.float32) + fp4_max = torch.tensor(FLOAT4_E2M1_MAX, device=device, dtype=torch.float32) + fp8_max = torch.tensor(FLOAT8_E4M3_MAX, device=device, dtype=torch.float32) + + n_blk = N // BLOCK_K + x_fp32 = x.to(torch.float32).contiguous() + x_blk = x_fp32.view(M, n_blk, BLOCK_K) + + # Outer = whole row. The kernel applies ``fmaxf(row_amax, 1e-12f)`` + # to the divisor; do the same here. + row_amax = torch.amax(torch.abs(x_fp32), dim=-1) # (M,) fp32 -- raw, pre-floor + row_amax_safe = torch.clamp(row_amax, min=_AMAX_FLOOR).unsqueeze(-1) # (M, 1) + + # Same ``compute_global_encode_scaling_factor_FP4`` form as the + # per-tensor / blocked paths (just with ``row_amax`` instead of + # ``window_amax`` / ``global_amax``). + S_enc_row = (fp8_max * fp4_max) / row_amax_safe # (M, 1) + S_enc_row = torch.minimum(S_enc_row, fp32_max) + S_enc_row = torch.where( + (row_amax_safe == 0) | (S_enc_row == 0), + torch.ones_like(S_enc_row), + S_enc_row, + ) + + # Fold ``1 / fp4_max`` into the multiplier the same way the kernel + # does in ``compute_decoding_scaling_factor`` (``S_enc * fp4_max_inv``). + S_enc_row_mul_inv6 = S_enc_row * torch.reciprocal(fp4_max) # (M, 1) + + # 1x16 block amax. Broadcast row's S_enc across n_blk blocks. + vec_max = torch.amax(torch.abs(x_blk), dim=-1, keepdim=True) # (M, n_blk, 1) + S_enc_per_blk = S_enc_row.unsqueeze(-1) # (M, 1, 1) -> broadcasts to (M, n_blk, 1) + S_enc_per_blk_mul = S_enc_row_mul_inv6.unsqueeze(-1) + + # decode_scale = saturating_cast(vec_max * S_enc / 6). 
+ # Kernel does NOT clamp before the cast; we clamp here because + # PyTorch's ``.to(float8_e4m3fn)`` does not match CUDA's saturating + # cast for values above FP8_MAX. After the explicit clamp the two + # paths agree byte-for-byte. + decode_scale_fp32 = vec_max * S_enc_per_blk_mul + decode_scale_fp32 = torch.minimum(decode_scale_fp32, fp32_max) + decode_scale_fp32 = torch.clamp(decode_scale_fp32, min=-fp8_max, max=fp8_max) + decode_scale_e4m3 = decode_scale_fp32.to(torch.float8_e4m3fn) + decode_scale_back_fp32 = decode_scale_e4m3.to(torch.float32) + + # block_scale = S_enc / s_dec, matching ``__fdiv_rn`` in the + # kernel. All-zero blocks: s_dec saturates to 0, naive S_enc/0 + # would NaN; short-circuit to 0 to mirror the kernel. + zero_blk = decode_scale_back_fp32 == 0 + denom = torch.where( + zero_blk, torch.ones_like(decode_scale_back_fp32), decode_scale_back_fp32 + ) + encode_scale = S_enc_per_blk / denom + encode_scale = torch.where(zero_blk, torch.zeros_like(encode_scale), encode_scale) + encode_scale = torch.minimum(encode_scale, fp32_max) + + # Apply scale, clamp to FP4 range, pack two FP4 values per byte. + scaled_x = x_blk * encode_scale + clipped_x = torch.clamp(scaled_x, -fp4_max, fp4_max).reshape(M, N) + qx = cast_to_fp4x2(clipped_x).contiguous() # (M, N // 2) + + sx = decode_scale_e4m3.squeeze(-1).contiguous() # (M, n_blk) + row_amax_out = row_amax.to(torch.float32).contiguous() # (M,) -- raw, no floor + return qx, sx, row_amax_out + + def quantize(self, tensor: torch.Tensor) -> RefNVFP4TensorPerToken: + """Quantize ``tensor`` and return a ``RefNVFP4TensorPerToken``.""" + out = RefNVFP4TensorPerToken() + if self.rowwise: + qx, sx, ra = self._quantize_2d(tensor) + out.data = qx + out.scale = sx + out.row_amax = ra + if self.columnwise: + # The columnwise output is the rowwise quantization of the + # transpose; both directions share the same math chain. 
+ qx_t, sx_t, ca = self._quantize_2d(tensor.transpose(0, 1).contiguous()) + out.columnwise_data = qx_t + out.columnwise_scale = sx_t + out.col_amax = ca + return out + + +# ============================================================================ +# Production wrapper (calls the CUDA kernel via the C-API binding). +# ============================================================================ + +# ---------------------------------------------------------------------------- +# Shape / dtype gate shared by all three entries. +# ---------------------------------------------------------------------------- +_PER_TOKEN_TILE: int = 128 # CHUNK_DIM_Y / CHUNK_DIM_X in the kernel + + +def _validate_per_token_input(x: torch.Tensor) -> Tuple[int, int]: + """Enforce the per-token kernel's hard constraints. Returns ``(M, K)``.""" + if x.ndim != 2: + raise ValueError(f"nvfp4_per_token expects a 2D tensor, got {x.ndim}D") + if x.dtype != torch.bfloat16: + raise ValueError( + f"Per-token kernel is bf16-only; got dtype {x.dtype}. " + "Non-bf16 inputs are not supported (no fallback path)." + ) + M, K = x.shape + if M % _PER_TOKEN_TILE != 0: + raise ValueError(f"Per-token kernel requires M % {_PER_TOKEN_TILE} == 0; got M={M}") + if K % _PER_TOKEN_TILE != 0: + raise ValueError(f"Per-token kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}") + return M, K + + +def nvfp4_per_token_quantize( + x: torch.Tensor, + *, + rowwise: bool = True, + columnwise: bool = False, + with_rht: bool = False, + random_sign_mask_t: int = 0xACE1, + with_swizzle: bool = False, +) -> RefNVFP4TensorPerToken: + """Production NVFP4 per-token cast through ``tex.nvfp4_per_token_quantize``. + + Composite K1 (per-row/per-col amax) + K2 (FP4 + e4m3 SF) on the same + stream. ``with_rht``: 16-pt col-wise RHT in K1+K2 (rowwise unaffected); + ``random_sign_mask_t`` low 16 bits = sign pattern (default ``0xACE1``). + + ``with_swizzle=True``: rowwise ``scale_inv`` in cuBLAS LT layout + (colwise stays compact). 
Downstream ``nvfp4_per_token_gemm`` must + use ``sf_swizzled=True`` to skip its built-in swizzle. + + Raises ``ValueError`` on non-bf16 input or non-128-aligned shapes. + """ + # Import lazily so the module does not require the binary at import time. + # (Mirrors the pattern in ``gemm_nvfp4_blocked.py``.) + import transformer_engine_torch as tex # type: ignore + + if not (rowwise or columnwise): + raise ValueError("At least one of rowwise / columnwise must be True.") + M, K = _validate_per_token_input(x) + + device = x.device + # Empty placeholders for the direction(s) we don't request -- the + # binding still expects the argument slots (typed-empty is fine). + empty = torch.empty(0, dtype=torch.uint8, device=device) + empty_f32 = torch.empty(0, dtype=torch.float32, device=device) + + if rowwise: + q_row = torch.empty((M, K // 2), dtype=torch.uint8, device=device) + s_dec_row = torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device) + row_amax = torch.empty((M,), dtype=torch.float32, device=device) + else: + q_row, s_dec_row, row_amax = empty, empty, empty_f32 + + if columnwise: + q_col = torch.empty((K, M // 2), dtype=torch.uint8, device=device) + s_dec_col = torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device) + col_amax = torch.empty((K,), dtype=torch.float32, device=device) + else: + q_col, s_dec_col, col_amax = empty, empty, empty_f32 + + tex.nvfp4_per_token_quantize( + x, + q_row, + s_dec_row, + row_amax, + q_col, + s_dec_col, + col_amax, + rowwise, + columnwise, + with_rht=with_rht, + random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, + with_swizzle=with_swizzle, + ) + + out = RefNVFP4TensorPerToken() + if rowwise: + out.data = q_row + out.scale = s_dec_row.view(torch.float8_e4m3fn) + out.row_amax = row_amax + if columnwise: + out.columnwise_data = q_col + out.columnwise_scale = s_dec_col.view(torch.float8_e4m3fn) + out.col_amax = col_amax + return out + + +# ============================================================================ 
+# Split entries (K1 = amax-only, K2 = encode-only). +# +# Diagnostic / benchmark interface, mirroring the production per-tensor +# kernel split (``HadamardAmaxTmaKernel`` for amax + the row_col_rht_gemm +# cast pass). Production callers should use ``nvfp4_per_token_quantize`` +# above; the composite handles K1 + K2 ordering on the same stream. +# ============================================================================ + + +def nvfp4_per_token_amax( + x: torch.Tensor, + *, + rowwise: bool = True, + columnwise: bool = True, + with_rht: bool = False, + random_sign_mask_t: int = 0xACE1, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax. + Returns ``(row_amax, col_amax)``; either may be ``None`` if the + corresponding direction is not requested. + + Lets the benchmark compare K1 wall-time against the production + ``HadamardAmaxTmaKernel``. Production callers should use the + composite ``nvfp4_per_token_quantize`` instead. + + ``with_rht=True`` applies a 16-pt col-wise RHT before amax; rowwise + never sees RHT. ``random_sign_mask_t`` low 16 bits = sign pattern + (default ``0xACE1``). + + Raises ``ValueError`` on non-bf16 input or non-128-aligned shapes. 
+ """ + import transformer_engine_torch as tex # type: ignore + + if not (rowwise or columnwise): + raise ValueError("At least one of rowwise / columnwise must be True.") + M, K = _validate_per_token_input(x) + + device = x.device + row_amax = ( + torch.empty((M,), dtype=torch.float32, device=device) + if rowwise + else torch.empty(0, dtype=torch.float32, device=device) + ) + col_amax = ( + torch.empty((K,), dtype=torch.float32, device=device) + if columnwise + else torch.empty(0, dtype=torch.float32, device=device) + ) + + tex.nvfp4_per_token_amax( + x, + row_amax, + col_amax, + rowwise, + columnwise, + with_rht=with_rht, + random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, + ) + + return (row_amax if rowwise else None, col_amax if columnwise else None) + + +def nvfp4_per_token_encode( + x: torch.Tensor, + *, + row_amax: Optional[torch.Tensor] = None, + col_amax: Optional[torch.Tensor] = None, + rowwise: bool = True, + columnwise: bool = True, + with_rht: bool = False, + random_sign_mask_t: int = 0xACE1, + with_swizzle: bool = False, +) -> RefNVFP4TensorPerToken: + """K2 in isolation: FP4 + e4m3 SF given pre-filled amax buffer(s) + (``row_amax`` ``(M,)`` and/or ``col_amax`` ``(K,)`` from a prior + ``nvfp4_per_token_amax`` call). + + ``with_rht=True`` requires ``col_amax`` from a K1 call with the SAME + mask. ``with_swizzle=True`` emits rowwise ``scale_inv`` in cuBLAS LT + swizzled layout (skips a downstream swizzle launch). + + Raises ``ValueError`` on non-bf16 input, non-128-aligned shapes, or + missing / mis-shaped amax buffers. 
+ """ + import transformer_engine_torch as tex # type: ignore + + if not (rowwise or columnwise): + raise ValueError("At least one of rowwise / columnwise must be True.") + M, K = _validate_per_token_input(x) + if rowwise and (row_amax is None or row_amax.shape != (M,)): + raise ValueError(f"row_amax must be (M={M},) fp32 when rowwise=True") + if columnwise and (col_amax is None or col_amax.shape != (K,)): + raise ValueError(f"col_amax must be (K={K},) fp32 when columnwise=True") + + device = x.device + empty = torch.empty(0, dtype=torch.uint8, device=device) + empty_f32 = torch.empty(0, dtype=torch.float32, device=device) + + if rowwise: + q_row = torch.empty((M, K // 2), dtype=torch.uint8, device=device) + s_dec_row = torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device) + row_amax_t = row_amax # type: ignore[assignment] + else: + q_row, s_dec_row, row_amax_t = empty, empty, empty_f32 + if columnwise: + q_col = torch.empty((K, M // 2), dtype=torch.uint8, device=device) + s_dec_col = torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device) + col_amax_t = col_amax # type: ignore[assignment] + else: + q_col, s_dec_col, col_amax_t = empty, empty, empty_f32 + + tex.nvfp4_per_token_encode( + x, + q_row, + s_dec_row, + row_amax_t, + q_col, + s_dec_col, + col_amax_t, + rowwise, + columnwise, + with_rht=with_rht, + random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, + with_swizzle=with_swizzle, + ) + + out = RefNVFP4TensorPerToken() + if rowwise: + out.data = q_row + out.scale = s_dec_row.view(torch.float8_e4m3fn) + out.row_amax = row_amax_t + if columnwise: + out.columnwise_data = q_col + out.columnwise_scale = s_dec_col.view(torch.float8_e4m3fn) + out.col_amax = col_amax_t + return out diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py new file mode 100644 index 0000000000..7cd48bad7b --- /dev/null +++ 
def _validate_per_token_group_input(
    x_concat: torch.Tensor, split_sections: Sequence[int]
) -> tuple[int, int]:
    """Check the grouped per-token kernel's input contract.

    ``x_concat`` must be a contiguous 2D bf16 tensor whose trailing dim is
    a multiple of ``_PER_TOKEN_TILE``. ``split_sections`` must hold between
    1 and 64 positive row counts, each a multiple of ``_PER_TOKEN_TILE``,
    that add up to the number of rows of ``x_concat``.

    Returns ``(sum_M, K)``, the shape of ``x_concat``. Raises
    ``ValueError`` on the first violated constraint.
    """
    # --- tensor-level checks -------------------------------------------
    rank = x_concat.ndim
    if rank != 2:
        raise ValueError(f"nvfp4_per_token_group_quantize expects a 2D input, got {rank}D")
    if not x_concat.is_contiguous():
        raise ValueError("x_concat must be contiguous (row-major)")
    in_dtype = x_concat.dtype
    if in_dtype != torch.bfloat16:
        raise ValueError(f"Per-token grouped kernel is bf16-only; got dtype {in_dtype}.")

    total_rows, K = x_concat.shape
    if K % _PER_TOKEN_TILE != 0:
        raise ValueError(f"Per-token grouped kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}")

    # --- split-list checks ---------------------------------------------
    num_splits = len(split_sections)
    if num_splits == 0:
        raise ValueError("split_sections must not be empty")
    if num_splits > 64:
        raise ValueError(f"num_tensors must be <= 64 (kernel arg-struct cap); got {num_splits}")

    rows_seen = 0
    for pos, rows in enumerate(split_sections):
        if rows <= 0:
            raise ValueError(f"split_sections[{pos}] must be > 0, got {rows}")
        if rows % _PER_TOKEN_TILE != 0:
            raise ValueError(
                f"split_sections[{pos}] = {rows} must be a multiple of {_PER_TOKEN_TILE}"
            )
        rows_seen += rows

    if rows_seen != total_rows:
        raise ValueError(
            f"sum(split_sections) = {rows_seen} must equal input.size(0) = {total_rows}"
        )
    return total_rows, K


# Default RHT sign-flip mask seed; matches the single-tensor wrapper.
_RHT_MASK_DEFAULT: int = 0xACE1


def nvfp4_per_token_group_quantize(
    x_concat: torch.Tensor,
    split_sections: Sequence[int],
    *,
    rowwise: bool = True,
    columnwise: bool = False,
    with_rht: bool = False,
    random_sign_mask_t: int = _RHT_MASK_DEFAULT,
) -> List[RefNVFP4TensorPerToken]:
    """Grouped NVFP4 per-token cast; one ``RefNVFP4TensorPerToken`` per split.

    Args:
        x_concat: (sum_M, K) bf16, row-major contiguous.
        split_sections: per-split row counts (each a multiple of 128).
        rowwise / columnwise: which directions to emit.
        with_rht: True -> apply a 16-pt col-wise RHT in BOTH K1 and K2;
            downstream GEMM must consume RHT-rotated weights to stay
            unbiased. Rowwise never sees RHT.
        random_sign_mask_t: low 16 bits = sign pattern shared by K1+K2.

    Returns:
        A list with one entry per split, in ``split_sections`` order. Only
        the fields of the requested direction(s) are filled in.

    Raises ``ValueError`` on shape / dtype / split-size violations.
    """
    import transformer_engine_torch as tex  # type: ignore

    if not rowwise and not columnwise:
        raise ValueError("At least one of rowwise / columnwise must be True.")

    _validate_per_token_group_input(x_concat, split_sections)
    sections = [int(rows) for rows in split_sections]

    # Bulk C++ call returns per-split views; s_dec_* already in fp8_e4m3fn dtype.
    bulk = tex.nvfp4_per_token_group_quantize_bulk(
        x_concat,
        sections,
        rowwise,
        columnwise,
        with_rht=bool(with_rht),
        random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF,
    )
    row_data, row_scale, row_amax, col_data, col_scale, col_amax = bulk

    results: List[RefNVFP4TensorPerToken] = []
    for idx in range(len(sections)):
        # Index a per-direction list only when that direction was
        # requested; the lists of the other direction are never read.
        fields = {}
        if rowwise:
            fields["data"] = row_data[idx]
            fields["scale"] = row_scale[idx]
            fields["row_amax"] = row_amax[idx]
        if columnwise:
            fields["columnwise_data"] = col_data[idx]
            fields["columnwise_scale"] = col_scale[idx]
            fields["col_amax"] = col_amax[idx]
        results.append(RefNVFP4TensorPerToken(**fields))
    return results