From 956a3c630f9d9db86ffc314c041d57b779e52cb0 Mon Sep 17 00:00:00 2001 From: sraman-rgb <270218152+sraman-rgb@users.noreply.github.com> Date: Wed, 27 May 2026 10:41:46 -0700 Subject: [PATCH 01/14] Enable NVFP4 fused grouped MLP follow-up Signed-off-by: sraman-rgb <270218152+sraman-rgb@users.noreply.github.com> --- .../common/gemm/cublaslt_grouped_gemm.cu | 10 +- transformer_engine/pytorch/ops/_common.py | 169 ++++++- .../pytorch/ops/basic/grouped_linear.py | 15 +- .../pytorch/ops/fused/backward_grouped_mlp.py | 445 +++++++++++++----- .../pytorch/ops/fused/forward_grouped_mlp.py | 393 +++++++++++----- 5 files changed, 794 insertions(+), 238 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index f064af2478..01c983ad3d 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -1751,8 +1751,16 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num NVTE_CHECK(A_list_info.all_col, "Grouped GEMM: A_list is missing column-wise data"); A_sel.dtype = A_list_info.col_dtype; } + // GroupedTensor metadata stores the original logical shape, so columnwise + // storage usually needs storage_transposed. Discrete NVFP4 A tensors with + // logical transa=false expose columnwise data with the transposed logical + // shape already, so treating it as storage_transposed would undo the layout + // needed by cuBLAS. + const bool nvfp4_discrete_a_columnwise = nvfp4 && !static_cast(transa); + const bool a_list_storage_transposed = + nvfp4_discrete_a_columnwise ? false : choice.storage_transposed; a_multi_tensor_args = build_grouped_gemm_multi_inputA_args( - A_list, num_a_tensors, choice.use_rowwise, choice.storage_transposed, &avg_first_dim, + A_list, num_a_tensors, choice.use_rowwise, a_list_storage_transposed, &avg_first_dim, &avg_last_dim, "A"); // Discrete A_list: per-tensor pointers come from `a_multi_tensor_args` (data/scale/amax). diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 717d872010..0bbc5280b3 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -8,15 +8,18 @@ import functools import math from importlib.metadata import PackageNotFoundError, version as get_pkg_version -from typing import Optional +from typing import Any, Optional import torch from packaging.version import Version as PkgVersion +import transformer_engine_torch as tex from transformer_engine_torch import FP8TensorMeta from ..torch_version import torch_version from ..quantization import FP8GlobalStateManager +from ..tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer from ..tensor.float8_tensor import Float8Tensor +from ..tensor.grouped_tensor import GroupedTensor from ..quantized_tensor import QuantizedTensorStorage from ..utils import canonicalize_dtype @@ -57,6 +60,168 @@ def _nvidia_cudnn_frontend_supports_wgrad() -> bool: return _cudnn_frontend_version_supported() +def _pack_nvfp4_amax_list(tensors: list) -> None: + """Ensure discrete NVFP4 weight list uses contiguous per-group amax buffers.""" + if not tensors: + return + row_amaxes = [getattr(tensor, "_amax_rowwise", None) for tensor in tensors] + if all(amax is not None for amax in row_amaxes): + packed_row_amax = torch.cat([amax.view(-1) for amax in row_amaxes], dim=0).contiguous() + for idx, tensor in enumerate(tensors): + tensor._amax_rowwise = packed_row_amax[idx : idx + 1] + col_amaxes = [getattr(tensor, "_amax_columnwise", None) for tensor in tensors] + if all(amax is not None for amax in col_amaxes): + packed_col_amax = torch.cat([amax.view(-1) for amax in col_amaxes], dim=0).contiguous() + for idx, tensor in enumerate(tensors): + tensor._amax_columnwise = packed_col_amax[idx : idx + 1] + + +def _enable_nvfp4_rht_for_group_quantize(quantizer: Quantizer) -> None: + """Use the graph-safe NVFP4 grouped quantization path.""" + if isinstance(quantizer, NVFP4Quantizer): + quantizer.with_rht = True + quantizer.with_post_rht_amax = True + + +def _group_quantize_for_grouped_mlp( + tensor: torch.Tensor, + quantizer: Quantizer, + num_groups: int, + split_sizes: Optional[torch.Tensor], + *, + tensor_offsets: Optional[torch.Tensor] = None, +) -> GroupedTensor: + """Quantize into grouped storage, using regular quantize for one-group NVFP4.""" + if num_groups != 1 or not isinstance(quantizer, NVFP4Quantizer): + return tex.group_quantize(tensor, quantizer, num_groups, split_sizes) + + quantized = tex.quantize(tensor, quantizer) + with_gemm_swizzled_scales = getattr(quantized, "_with_gemm_swizzled_scales", False) + if getattr(quantizer, "optimize_for_gemm", False): + tex.swizzle_scales_for_gemm_(quantized) + with_gemm_swizzled_scales = True + + rowwise_data = getattr(quantized, "_rowwise_data", None) + rowwise_scale = getattr(quantized, "_rowwise_scale_inv", None) + columnwise_data = getattr(quantized, "_columnwise_data", None) + columnwise_scale = getattr(quantized, "_columnwise_scale_inv", None) + amax = getattr(quantized, "_amax_rowwise", None) + columnwise_amax = getattr(quantized, "_amax_columnwise", None) + + if split_sizes is None: + split_sizes = torch.full((1,), tensor.shape[0], dtype=torch.int64, device=tensor.device) + else: + split_sizes = split_sizes.to(dtype=torch.int64, device=tensor.device) + + m_dim = tensor.shape[0] + if rowwise_data is not None: + k_dim = rowwise_data.shape[-1] * 2 + elif columnwise_data is not None: + k_dim = columnwise_data.shape[0] + else: + k_dim = tensor.shape[-1] + + if tensor_offsets is None: + tensor_offsets = torch.cat( + [ + torch.zeros(1, dtype=torch.int64, device=tensor.device), + torch.cumsum(split_sizes * k_dim, dim=0), + ], + ) + + return GroupedTensor( + shape=(m_dim, k_dim), + dtype=tensor.dtype, + quantizer=quantizer, + num_tensors=1, + data=rowwise_data.reshape(-1) if rowwise_data is not None else None, + columnwise_data=columnwise_data.reshape(-1) if columnwise_data is not None else None, + scale_inv=rowwise_scale.reshape(-1) if rowwise_scale is not None else None, + columnwise_scale_inv=columnwise_scale.reshape(-1) if columnwise_scale is not None else None, + amax=amax, + columnwise_amax=columnwise_amax, + first_dims=split_sizes, + tensor_offsets=tensor_offsets, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) + + +def _nvfp4_logical_data_view(data: torch.Tensor) -> torch.Tensor: + """View packed NVFP4 data with its logical K dimension for scale swizzling.""" + return data.as_strided((data.shape[0], data.shape[1] * 2), (data.stride(0), 0)) + + +def _nvfp4_amax(tensors: Any, *, columnwise: bool) -> torch.Tensor: + """Get one NVFP4 amax value per group.""" + grouped_attr = "columnwise_amax" if columnwise else "amax" + tensor_attr = "_amax_columnwise" if columnwise else "_amax_rowwise" + + if hasattr(tensors, grouped_attr): + amax = getattr(tensors, grouped_attr) + if amax is None: + raise RuntimeError(f"NVFP4 GroupedTensor is missing {grouped_attr}.") + return amax.view(-1) + + amaxes = [getattr(tensor, tensor_attr, None) for tensor in tensors] + if any(amax is None for amax in amaxes): + raise RuntimeError(f"NVFP4 tensor list is missing {tensor_attr}.") + return torch.cat([amax.view(-1) for amax in amaxes], dim=0) + + +def _nvfp4_single_tensor_from_grouped( + grouped: GroupedTensor, + quantizer: Optional[NVFP4Quantizer] = None, + *, + fp4_dtype: Optional[torch.dtype] = None, +) -> NVFP4Tensor: + """Build a single NVFP4Tensor view over a one-member grouped storage.""" + if quantizer is None: + quantizer = grouped.quantizer + if not isinstance(quantizer, NVFP4Quantizer): + raise TypeError("Expected an NVFP4 GroupedTensor.") + + shape = tuple(grouped.logical_shape) + rowwise_data = None + if grouped.rowwise_data is not None: + rowwise_data = grouped.rowwise_data.view(quantizer.convert_shape_for_fp4(shape)) + + rowwise_scale_inv = None + if grouped.scale_inv is not None: + rowwise_scale_inv = grouped.scale_inv.view(quantizer.get_scale_shape(shape, False)) + + columnwise_data = None + if grouped.columnwise_data is not None: + columnwise_shape = quantizer.get_columnwise_shape(shape) + columnwise_data = grouped.columnwise_data.view( + quantizer.convert_shape_for_fp4(columnwise_shape) + ) + + columnwise_scale_inv = None + if grouped.columnwise_scale_inv is not None: + columnwise_scale_inv = grouped.columnwise_scale_inv.view( + quantizer.get_scale_shape(shape, True) + ) + + return NVFP4Tensor( + shape=shape, + dtype=grouped.get_dtype(), + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=grouped.amax, + amax_columnwise=grouped.columnwise_amax, + fp4_dtype=fp4_dtype or quantizer.dtype, + quantizer=quantizer, + requires_grad=False, + with_gemm_swizzled_scales=getattr( + grouped, + "_with_gemm_swizzled_scales", + getattr(grouped, "with_gemm_swizzled_scales", quantizer.optimize_for_gemm), + ), + ) + + def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool: """Check if tensor is a quantized tensor""" return isinstance(tensor, QuantizedTensorStorage) @@ -285,7 +450,7 @@ def fuse_grouped_mlp_ops( if not fused_op_cls.is_supported(): return ops - if recipe is None or not recipe.mxfp8(): + if recipe is None or not (recipe.mxfp8() or recipe.nvfp4()): return ops if activation_op_types is None: activation_op_types = (ScaledSwiGLU, ScaledClampedQGeGLU) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 1f00d92284..77c8b4bafa 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -1059,9 +1059,6 @@ def _fuser_forward_split_quantize( num_groups = self.num_groups has_bias = self.has_bias - # Need CPU split sizes for split_quantize / general_grouped_gemm. - split_sizes_int = [int(s) for s in split_sizes.tolist()] - # Extract params if self.single_grouped_weight: weights = self.weight.quantized_tensors @@ -1083,6 +1080,12 @@ def _fuser_forward_split_quantize( # Split input tensor and convert dtypes if needed x = maybe_dequantize(input_, dtype) + if num_groups == 1: + # Avoid CUDA->CPU sync from split_sizes.tolist() during CUDA graph capture. + split_sizes_int = [x.numel() // x.size(-1)] + else: + # Need CPU split sizes for split_quantize / general_grouped_gemm. + split_sizes_int = [int(s) for s in split_sizes.tolist()] xs = None if with_quantized_compute: for quantizer in input_quantizers: @@ -1329,8 +1332,12 @@ def _fuser_backward_split_quantize( ws, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:] # Split grad output tensor and convert dtypes if needed - split_sizes_int = [int(s) for s in split_sizes.tolist()] dy = maybe_dequantize(grad_output, ctx.dtype) + if num_groups == 1: + # Avoid CUDA->CPU sync from split_sizes.tolist() during CUDA graph capture. + split_sizes_int = [dy.numel() // dy.size(-1)] + else: + split_sizes_int = [int(s) for s in split_sizes.tolist()] dys = None grad_biases = [None] * num_groups grad_scales = None diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 24aaafc1ee..b1faca36d3 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -8,16 +8,17 @@ from collections.abc import Callable import functools import os -from typing import Optional +from typing import Any, Optional import torch import transformer_engine_torch as tex from ...quantization import Recipe +from ...tensor import NVFP4Quantizer, NVFP4Tensor from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...utils import clear_tensor_data, get_cached_ones_tensor, get_device_compute_capability -from ...constants import MXFP8_BLOCK_SCALING_SIZE +from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU from ..fuser import register_backward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext @@ -25,6 +26,11 @@ _cudnn_frontend_geglu_runtime_params, _cudnn_frontend_version_supported, _cudnn_frontend_supports_grouped_gemm_srelu, + _enable_nvfp4_rht_for_group_quantize, + _group_quantize_for_grouped_mlp, + _nvfp4_amax, + _nvfp4_logical_data_view, + _nvfp4_single_tensor_from_grouped, fuse_grouped_mlp_ops, get_accumulate_flag_in_param, get_dummy_wgrads_for_params, @@ -34,11 +40,51 @@ view_main_grad_as_grouped_buffer, validate_grouped_mlp_dims, ) -from ...cpp_extensions import general_grouped_gemm_for_grouped_tensor +from ...cpp_extensions import ( + general_gemm, + general_grouped_gemm_for_grouped_tensor, +) from ...module.base import _2X_ACC_WGRAD from ...triton.grouped_dbias_dscales import compute_grouped_dbias_dscales +def _mark_with_gemm_swizzled_scales(tensors: Any) -> None: + """Mark tensors whose scale buffers are already in GEMM-swizzled layout.""" + if tensors is None: + return + if hasattr(tensors, "with_gemm_swizzled_scales"): + tensors.with_gemm_swizzled_scales = True + if hasattr(tensors, "_with_gemm_swizzled_scales"): + tensors._with_gemm_swizzled_scales = True + + +def _nvfp4_single_group_wgrad_gemm( + grouped_x: GroupedTensor, + grouped_dy: GroupedTensor, + wgrad_output, + *, + weight_shape: tuple[int, int], + accumulate: bool, +) -> None: + """Run one-group NVFP4 wgrad with regular GEMM instead of grouped GEMM.""" + x_single = _nvfp4_single_tensor_from_grouped(grouped_x) + dy_single = _nvfp4_single_tensor_from_grouped(grouped_dy) + if isinstance(wgrad_output, GroupedTensor): + out = wgrad_output.rowwise_data.view(1, *weight_shape)[0] + else: + out = wgrad_output[0] + + general_gemm( + x_single, + dy_single, + out_dtype=out.dtype, + out=out, + layout="NT", + accumulate=accumulate, + use_split_accumulator=_2X_ACC_WGRAD, + ) + + def _cudnn_compute_wgrad( grouped_x: GroupedTensor, grouped_dy: GroupedTensor, @@ -219,6 +265,18 @@ def _compute_grad_params( single_grouped_weight=fc_op.single_grouped_weight, current_stream=torch.cuda.current_stream().cuda_stream, ) + elif ( + num_groups == 1 + and isinstance(grouped_x, GroupedTensor) + and isinstance(grouped_dy, GroupedTensor) + and isinstance(getattr(grouped_x, "quantizer", None), NVFP4Quantizer) + and isinstance(getattr(grouped_dy, "quantizer", None), NVFP4Quantizer) + ): + gemm_fn = functools.partial( + _nvfp4_single_group_wgrad_gemm, + weight_shape=weight_shape, + accumulate=accumulate_into_main_grad, + ) else: gemm_fn = functools.partial( general_grouped_gemm_for_grouped_tensor, @@ -359,7 +417,9 @@ def fuser_backward( grad_output = grad_output.reshape(-1, fc2_weight_shape[0]) out_shape = list(grad_output.size()) num_groups = fc1_op.num_groups - device = fc1_op._get_weight_tensors()[0].device + fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 + fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_weight else fc2_op.weight0 + device = fc1_weight_param.device dtype = fc1_ctx.dtype # Saved tensors from FC1 forward. @@ -415,13 +475,22 @@ def fuser_backward( fc2_grad_output_quantizer = fc2_ctx.grad_output_quantizers[0] fc2_grad_output_quantizer.set_usage(rowwise=True, columnwise=fc2_ctx.weight_requires_grad) fc2_grad_output_quantizer.optimize_for_gemm = True + _enable_nvfp4_rht_for_group_quantize(fc2_grad_output_quantizer) output_fc2_dbias = fc2_op.has_bias fc2_dbias_packed = None fc2_dy = None + grad_output_quantizer = getattr(grad_output, "quantizer", None) + fc2_grad_output_quantizer_matches = ( + isinstance(fc2_grad_output_quantizer, MXFP8Quantizer) + and isinstance(grad_output_quantizer, MXFP8Quantizer) + ) or ( + isinstance(fc2_grad_output_quantizer, NVFP4Quantizer) + and isinstance(grad_output_quantizer, NVFP4Quantizer) + ) if ( not output_fc2_dbias and isinstance(grad_output, GroupedTensor) - and isinstance(getattr(grad_output, "quantizer", None), MXFP8Quantizer) + and fc2_grad_output_quantizer_matches ): grouped_fc2_dy = grad_output else: @@ -434,13 +503,26 @@ def fuser_backward( split_sizes, ) else: - grouped_fc2_dy = tex.group_quantize( + grouped_fc2_dy = _group_quantize_for_grouped_mlp( fc2_dy, fc2_grad_output_quantizer, num_groups, split_sizes, + tensor_offsets=base_split_offsets * fc2_weight_shape[0], ) + use_nvfp4 = ( + isinstance(fc2_grad_output_quantizer, NVFP4Quantizer) + or isinstance(fc1_weight_param, NVFP4Tensor) + or isinstance(fc2_weight_param, NVFP4Tensor) + ) + data_dtype = torch.float4_e2m1fn_x2 if use_nvfp4 else torch.float8_e4m3fn + scale_view_dtype = torch.float8_e4m3fn if use_nvfp4 else torch.float8_e8m0fnu + sf_vec_size = NVFP4_BLOCK_SCALING_SIZE if use_nvfp4 else MXFP8_BLOCK_SCALING_SIZE + data_k = out_shape[1] // 2 if use_nvfp4 else out_shape[1] + fc2_weight_k = fc2_weight_shape[1] // 2 if use_nvfp4 else fc2_weight_shape[1] + k_sf_divisor = 2 * sf_vec_size if use_nvfp4 else 4 * sf_vec_size + # Pack data tensors # Note: Fused kernel expects tensor with non-contiguous # logical dims. @@ -450,20 +532,46 @@ def fuser_backward( # Data logical shape: (sum(m), k, 1) # Scale logical shape: (32 (block row), 4 (block row), # sum(m)/128, 4 (block col), k/128, 1) - fc2_dy_data = grouped_fc2_dy.rowwise_data.view(out_shape[0], out_shape[1]) - fc2_dy_data = fc2_dy_data.view(dtype=torch.float8_e4m3fn) + fc2_dy_data = grouped_fc2_dy.rowwise_data.view(dtype=data_dtype) + fc2_dy_data = fc2_dy_data.view(out_shape[0], data_k) fc2_dy_data = fc2_dy_data.unsqueeze(0).permute(1, 2, 0) fc2_dy_scales = grouped_fc2_dy.scale_inv - fc2_dy_scales = fc2_dy_scales.view(dtype=torch.float8_e8m0fnu) - fc2_dy_scales = fc2_dy_scales.view( - 1, - (out_shape[0] + 127) // 128, - (out_shape[1] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, - 4, - 4, + fc2_dy_scales = fc2_dy_scales.view(dtype=scale_view_dtype) + with_gemm_swizzled_scales = getattr( + grouped_fc2_dy, + "_with_gemm_swizzled_scales", + getattr(grouped_fc2_dy, "with_gemm_swizzled_scales", False), ) - fc2_dy_scales = fc2_dy_scales.permute(3, 4, 1, 5, 2, 0) + if use_nvfp4 and with_gemm_swizzled_scales: + fc2_dy_scales = fc2_dy_scales.view( + 1, + out_shape[0] // 128, + data_k // k_sf_divisor, + 32, + 4, + 4, + ) + fc2_dy_scales = fc2_dy_scales.permute(3, 4, 1, 5, 2, 0) + elif use_nvfp4: + fc2_dy_scales = fc2_dy_scales.view( + 1, + out_shape[0] // 128, + 4, + 32, + data_k // k_sf_divisor, + 4, + ) + fc2_dy_scales = fc2_dy_scales.permute(3, 2, 1, 5, 4, 0) + else: + fc2_dy_scales = fc2_dy_scales.view( + 1, + (out_shape[0] + 127) // 128, + (out_shape[1] + k_sf_divisor - 1) // k_sf_divisor, + 32, + 4, + 4, + ) + fc2_dy_scales = fc2_dy_scales.permute(3, 4, 1, 5, 2, 0) # Kernel scaling factors alpha_tensor = get_cached_ones_tensor(num_groups, dtype, device) @@ -474,25 +582,43 @@ def fuser_backward( scales_tensor = scales_f32.reshape(-1, 1, 1) dscales_tensor = torch.zeros_like(scales_tensor) + fc2_d_dtype = torch.bfloat16 if use_nvfp4 else torch.float8_e4m3fn + if use_nvfp4: + nvfp4_fp4_max = 6.0 + nvfp4_fp8_max = 448.0 + fc2_alpha_tensor = ( + torch.sqrt( + _nvfp4_amax(grouped_fc2_dy, columnwise=False) + * _nvfp4_amax(grouped_fc2_weight, columnwise=True) + ) + / (nvfp4_fp8_max * nvfp4_fp4_max) + ).expand(num_groups) + fc2_beta_tensor = get_cached_ones_tensor(num_groups, torch.float32, device) + fc2_norm_const_tensor = None + else: + fc2_alpha_tensor = alpha_tensor + fc2_beta_tensor = alpha_tensor + fc2_norm_const_tensor = norm_const_tensor + fc2_dactivation_kwargs = { "a_tensor": fc2_dy_data, "c_tensor": activation_in.unsqueeze(0).permute(1, 2, 0), "sfa_tensor": fc2_dy_scales, "padded_offsets": split_points, - "alpha_tensor": alpha_tensor, + "alpha_tensor": fc2_alpha_tensor, "prob_tensor": scales_tensor, "dprob_tensor": dscales_tensor, "generate_dbias": fc1_op.has_bias, - "norm_const_tensor": norm_const_tensor, - "d_dtype": torch.float8_e4m3fn, + "norm_const_tensor": fc2_norm_const_tensor, + "d_dtype": fc2_d_dtype, "cd_major": "n", - "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "sf_vec_size": sf_vec_size, "current_stream": current_stream, - "discrete_col_sfd": True, + "discrete_col_sfd": not use_nvfp4, "use_dynamic_sched": True, } if self._cudnn_dact_func is not None: - fc2_dactivation_kwargs["beta_tensor"] = alpha_tensor + fc2_dactivation_kwargs["beta_tensor"] = fc2_beta_tensor fc2_dactivation_kwargs["act_func"] = self._cudnn_dact_func else: fc2_dactivation_kwargs["use_dsrelu_reuse"] = recompute_fc2_x_from_dsrelu @@ -512,46 +638,72 @@ def fuser_backward( # Data actual shape: (num_groups, k, n) # Data logical shape: (n, k, num_groups) fc2_w_data = fc2_weight_for_gemm.columnwise_data - fc2_w_data = fc2_w_data.view(dtype=torch.float8_e4m3fn) - fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1]) - fc2_w_data = fc2_w_data.permute(2, 1, 0) - fc2_w_scales = fc2_weight_for_gemm.columnwise_scale_inv.view(dtype=torch.float8_e8m0fnu) + fc2_w_data = fc2_w_data.view(dtype=data_dtype) + fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_k) + fc2_w_data = fc2_w_data.permute(1, 2, 0) if use_nvfp4 else fc2_w_data.permute(2, 1, 0) + fc2_w_scales = fc2_weight_for_gemm.columnwise_scale_inv.view(dtype=scale_view_dtype) fc2_w_scales = fc2_w_scales.view( num_groups, - (fc2_weight_shape[1] + 127) // 128, + (fc2_weight_shape[1] + k_sf_divisor - 1) // k_sf_divisor, (fc2_weight_shape[0] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, + 32, 4, 4, ) - fc2_w_scales = fc2_w_scales.permute(3, 4, 1, 5, 2, 0) + fc2_w_scales = ( + fc2_w_scales.permute(3, 4, 2, 5, 1, 0) + if use_nvfp4 + else fc2_w_scales.permute(3, 4, 1, 5, 2, 0) + ) fc2_dactivation_kwargs["b_tensor"] = fc2_w_data fc2_dactivation_kwargs["sfb_tensor"] = fc2_w_scales else: + fc2_weight_data_for_ptrs = [w._columnwise_data for w in grouped_fc2_weight] + if use_nvfp4: + fc2_weight_data_for_ptrs = [ + _nvfp4_logical_data_view(data) for data in fc2_weight_data_for_ptrs + ] fc2_b_ptrs, fc2_sfb_ptrs, _fc2_sw = tex.get_device_pointer_for_data_and_scales( - [w._columnwise_data for w in grouped_fc2_weight], + fc2_weight_data_for_ptrs, [w._columnwise_scale_inv for w in grouped_fc2_weight], swizzle=True, rowwise=False, - data_dtype=grouped_fc2_weight[0]._fp8_dtype, + data_dtype=( + grouped_fc2_weight[0]._fp4_dtype + if use_nvfp4 + else grouped_fc2_weight[0]._fp8_dtype + ), ) fc2_dactivation_kwargs["b_ptrs"] = fc2_b_ptrs fc2_dactivation_kwargs["sfb_ptrs"] = fc2_sfb_ptrs fc2_dactivation_kwargs["n"] = fc2_weight_shape[1] - fc2_dactivation_kwargs["b_dtype"] = torch.float8_e4m3fn - fc2_dactivation_kwargs["b_major"] = "n" + fc2_dactivation_kwargs["b_dtype"] = data_dtype + fc2_dactivation_kwargs["b_major"] = "k" if use_nvfp4 else "n" fc2_dgrad_kernel_out = self.grouped_gemm_dactivation_kernel()(**fc2_dactivation_kwargs) - fc1_dy_row_data = fc2_dgrad_kernel_out["d_row_tensor"] - fc1_dy_row_data = fc1_dy_row_data.view(out_shape[0], fc1_weight_shape[0]) - # View scale in their actual swizzled shape - fc1_dy_row_scale = fc2_dgrad_kernel_out["sfd_row_tensor"].permute(5, 2, 4, 0, 1, 3).view(-1) - fc1_dy_col_data = fc2_dgrad_kernel_out["d_col_tensor"] - fc1_dy_col_data = fc1_dy_col_data.view(out_shape[0], fc1_weight_shape[0]) - # View scale in their actual swizzled shape - fc1_dy_col_scale = fc2_dgrad_kernel_out["sfd_col_tensor"].permute(5, 2, 4, 0, 1, 3).view(-1) + if use_nvfp4: + fc1_dy_bf16 = fc2_dgrad_kernel_out["d_row_tensor"] + fc1_dy_bf16 = fc1_dy_bf16.view(out_shape[0], fc1_weight_shape[0]).contiguous() + fc1_dy_row_data = None + fc1_dy_row_scale = None + fc1_dy_col_data = None + fc1_dy_col_scale = None + else: + fc1_dy_bf16 = None + fc1_dy_row_data = fc2_dgrad_kernel_out["d_row_tensor"] + fc1_dy_row_data = fc1_dy_row_data.view(out_shape[0], fc1_weight_shape[0]) + # View scale in their actual swizzled shape + fc1_dy_row_scale = ( + fc2_dgrad_kernel_out["sfd_row_tensor"].permute(5, 2, 4, 0, 1, 3).view(-1) + ) + fc1_dy_col_data = fc2_dgrad_kernel_out["d_col_tensor"] + fc1_dy_col_data = fc1_dy_col_data.view(out_shape[0], fc1_weight_shape[0]) + # View scale in their actual swizzled shape + fc1_dy_col_scale = ( + fc2_dgrad_kernel_out["sfd_col_tensor"].permute(5, 2, 4, 0, 1, 3).view(-1) + ) grad_scales = fc2_dgrad_kernel_out["dprob_tensor"].view(-1) if recompute_fc2_x_from_dsrelu: @@ -625,21 +777,39 @@ def fuser_backward( # FC1 grad output for dgrad and wgrad GEMMs fc1_dy_tensor_offsets = base_split_offsets * fc1_weight_shape[0] - grouped_fc1_dy = GroupedTensor( - shape=(out_shape[0], fc1_weight_shape[0]), - dtype=dtype, - num_tensors=num_groups, - quantizer=fc1_ctx.grad_output_quantizers[0], - data=fc1_dy_row_data, - columnwise_data=fc1_dy_col_data, - scale_inv=fc1_dy_row_scale, - columnwise_scale_inv=fc1_dy_col_scale, - first_dims=split_sizes, - tensor_offsets=fc1_dy_tensor_offsets, - with_gemm_swizzled_scales=True, - ) + fc1_grad_output_quantizer = fc1_ctx.grad_output_quantizers[0] + if use_nvfp4: + fc1_grad_output_quantizer.set_usage( + rowwise=True, + columnwise=fc1_ctx.weight_requires_grad, + ) + fc1_grad_output_quantizer.optimize_for_gemm = True + _enable_nvfp4_rht_for_group_quantize(fc1_grad_output_quantizer) + grouped_fc1_dy = _group_quantize_for_grouped_mlp( + fc1_dy_bf16, + fc1_grad_output_quantizer, + num_groups, + split_sizes, + tensor_offsets=fc1_dy_tensor_offsets, + ) + _mark_with_gemm_swizzled_scales(grouped_fc1_dy) + else: + grouped_fc1_dy = GroupedTensor( + shape=(out_shape[0], fc1_weight_shape[0]), + dtype=dtype, + num_tensors=num_groups, + quantizer=fc1_grad_output_quantizer, + data=fc1_dy_row_data, + columnwise_data=fc1_dy_col_data, + scale_inv=fc1_dy_row_scale, + columnwise_scale_inv=fc1_dy_col_scale, + first_dims=split_sizes, + tensor_offsets=fc1_dy_tensor_offsets, + with_gemm_swizzled_scales=True, + ) # FC2 wgrad GEMM + wgrad_kernel_fn = None if use_nvfp4 else self.grouped_gemm_wgrad_kernel() fc2_grad_params = _compute_grad_params( fc_op=fc2_op, ctx=fc2_ctx, @@ -652,7 +822,7 @@ def fuser_backward( bias_grads=fc2_bias_grads, bias_grad_packed=fc2_bias_grad_packed, label="FC2", - cudnn_wgrad_kernel_fn=self.grouped_gemm_wgrad_kernel(), + cudnn_wgrad_kernel_fn=wgrad_kernel_fn, offsets=split_points, ) @@ -674,66 +844,107 @@ def fuser_backward( if fc1_ctx.input_requires_grad: in_shape = out_shape[:-1] + [fc1_weight_shape[1]] - fc1_dgrad_a_data = fc2_dgrad_kernel_out["d_row_tensor"] - fc1_dgrad_a_scales = fc2_dgrad_kernel_out["sfd_row_tensor"] - - fc1_dgrad_kwargs = { - "a_tensor": fc1_dgrad_a_data, - "sfa_tensor": fc1_dgrad_a_scales, - "padded_offsets": split_points, - "alpha_tensor": alpha_tensor, - "norm_const_tensor": None, - "prob_tensor": torch.ones((out_shape[0], 1, 1), dtype=torch.float32, device=device), - "acc_dtype": torch.float32, - "d_dtype": dtype, - "cd_major": "n", - "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, - "current_stream": current_stream, - "discrete_col_sfd": True, - "use_dynamic_sched": True, - } - - if fc1_op.single_grouped_weight: - # Clone and swizzle scales for GEMM - fc1_weight_for_gemm = grouped_fc1_weight.copy() - tex.grouped_swizzle_for_gemm(fc1_weight_for_gemm, rowwise=False, columnwise=True) - - fc1_w_data = fc1_weight_for_gemm.columnwise_data - fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) - fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1]) - fc1_w_data = fc1_w_data.permute(2, 1, 0) - fc1_w_scales = fc1_weight_for_gemm.columnwise_scale_inv.view( - dtype=torch.float8_e8m0fnu - ) - fc1_w_scales = fc1_w_scales.view( - num_groups, - (fc1_weight_shape[1] + 127) // 128, - (fc1_weight_shape[0] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, - 4, - 4, - ) - fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) - - fc1_dgrad_kwargs["b_tensor"] = fc1_w_data - fc1_dgrad_kwargs["sfb_tensor"] = fc1_w_scales + if use_nvfp4: + _mark_with_gemm_swizzled_scales(grouped_fc1_weight) + _mark_with_gemm_swizzled_scales(grouped_fc1_dy) + grad_input = torch.empty(in_shape, dtype=dtype, device=device) + if num_groups == 1: + if fc1_op.single_grouped_weight: + fc1_w_single = grouped_fc1_weight.split_into_quantized_tensors()[0] + else: + fc1_w_single = grouped_fc1_weight[0] + fc1_dy_single = _nvfp4_single_tensor_from_grouped(grouped_fc1_dy) + general_gemm( + fc1_w_single, + fc1_dy_single, + out_dtype=dtype, + out=grad_input, + layout="NN", + ) + else: + fc1_x_tensor_offsets = base_split_offsets * fc1_weight_shape[1] + grouped_grad_input = GroupedTensor( + shape=(out_shape[0], fc1_weight_shape[1]), + dtype=dtype, + num_tensors=num_groups, + quantizer=None, + data=grad_input.view(-1), + first_dims=split_sizes, + tensor_offsets=fc1_x_tensor_offsets, + ) + general_grouped_gemm_for_grouped_tensor( + grouped_fc1_weight, + grouped_fc1_dy, + grouped_grad_input, + layout="NN", + ) else: - fc1_b_ptrs, fc1_sfb_ptrs, _ = tex.get_device_pointer_for_data_and_scales( - [w._columnwise_data for w in grouped_fc1_weight], - [w._columnwise_scale_inv for w in grouped_fc1_weight], - swizzle=True, - rowwise=False, - data_dtype=grouped_fc1_weight[0]._fp8_dtype, - ) - - fc1_dgrad_kwargs["b_ptrs"] = fc1_b_ptrs - fc1_dgrad_kwargs["sfb_ptrs"] = fc1_sfb_ptrs - fc1_dgrad_kwargs["n"] = fc1_weight_shape[1] - fc1_dgrad_kwargs["b_dtype"] = torch.float8_e4m3fn - fc1_dgrad_kwargs["b_major"] = "n" - - fc1_dgrad_kernel_out = self.grouped_gemm_quant_kernel()(**fc1_dgrad_kwargs) - grad_input = fc1_dgrad_kernel_out["d_tensor"].view(in_shape) + fc1_dgrad_a_data = fc2_dgrad_kernel_out["d_row_tensor"] + fc1_dgrad_a_scales = fc2_dgrad_kernel_out["sfd_row_tensor"] + + fc1_dgrad_kwargs = { + "a_tensor": fc1_dgrad_a_data, + "sfa_tensor": fc1_dgrad_a_scales, + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor, + "norm_const_tensor": None, + "prob_tensor": torch.ones( + (out_shape[0], 1, 1), dtype=torch.float32, device=device + ), + "acc_dtype": torch.float32, + "d_dtype": dtype, + "cd_major": "n", + "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "discrete_col_sfd": True, + "use_dynamic_sched": True, + } + + if fc1_op.single_grouped_weight: + # Clone and swizzle scales for GEMM + fc1_weight_for_gemm = grouped_fc1_weight.copy() + tex.grouped_swizzle_for_gemm( + fc1_weight_for_gemm, rowwise=False, columnwise=True + ) + + fc1_w_data = fc1_weight_for_gemm.columnwise_data + fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) + fc1_w_data = fc1_w_data.view( + num_groups, fc1_weight_shape[0], fc1_weight_shape[1] + ) + fc1_w_data = fc1_w_data.permute(2, 1, 0) + fc1_w_scales = fc1_weight_for_gemm.columnwise_scale_inv.view( + dtype=torch.float8_e8m0fnu + ) + fc1_w_scales = fc1_w_scales.view( + num_groups, + (fc1_weight_shape[1] + 127) // 128, + (fc1_weight_shape[0] + 127) // 128, + MXFP8_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) + + fc1_dgrad_kwargs["b_tensor"] = fc1_w_data + fc1_dgrad_kwargs["sfb_tensor"] = fc1_w_scales + else: + fc1_b_ptrs, fc1_sfb_ptrs, _ = tex.get_device_pointer_for_data_and_scales( + [w._columnwise_data for w in grouped_fc1_weight], + [w._columnwise_scale_inv for w in grouped_fc1_weight], + swizzle=True, + rowwise=False, + data_dtype=grouped_fc1_weight[0]._fp8_dtype, + ) + + fc1_dgrad_kwargs["b_ptrs"] = fc1_b_ptrs + fc1_dgrad_kwargs["sfb_ptrs"] = fc1_sfb_ptrs + fc1_dgrad_kwargs["n"] = fc1_weight_shape[1] + fc1_dgrad_kwargs["b_dtype"] = torch.float8_e4m3fn + fc1_dgrad_kwargs["b_major"] = "n" + + fc1_dgrad_kernel_out = self.grouped_gemm_quant_kernel()(**fc1_dgrad_kwargs) + grad_input = fc1_dgrad_kernel_out["d_tensor"].view(in_shape) # FC1 wgrad GEMM fc1_grad_params = _compute_grad_params( @@ -748,7 +959,7 @@ def fuser_backward( bias_grads=fc1_bias_grads, bias_grad_packed=fc1_bias_grad_packed, label="FC1", - cudnn_wgrad_kernel_fn=self.grouped_gemm_wgrad_kernel(), + cudnn_wgrad_kernel_fn=wgrad_kernel_fn, offsets=split_points, ) @@ -841,6 +1052,8 @@ def fuse_backward_srelu_ops( ) -> list[FusibleOperation]: """Apply GroupedLinear + ScaledSReLU + GroupedLinear fusion for backward pass.""" + if recipe is None or not recipe.mxfp8(): + return ops return fuse_grouped_mlp_ops( ops, recipe=recipe, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 034d404439..82b154ae22 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,12 +13,13 @@ import torch import transformer_engine_torch as tex +from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor from ...quantization import Recipe -from ...tensor import Quantizer +from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer -from ...constants import MXFP8_BLOCK_SCALING_SIZE +from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU from ..fuser import register_forward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext @@ -26,7 +27,13 @@ _cudnn_frontend_geglu_runtime_params, _cudnn_frontend_version_supported, _cudnn_frontend_supports_grouped_gemm_srelu, + _enable_nvfp4_rht_for_group_quantize, + _group_quantize_for_grouped_mlp, _nvidia_cudnn_frontend_supports_wgrad, + _nvfp4_amax, + _nvfp4_logical_data_view, + _nvfp4_single_tensor_from_grouped, + _pack_nvfp4_amax_list, fuse_grouped_mlp_ops, is_glu_activation, is_quantized_tensor, @@ -202,6 +209,7 @@ def fuser_forward( split_sizes = split_sizes.to(dtype=torch.int64, device=device) base_split_offsets = tex.splits_to_offsets(split_sizes, 1) split_points = base_split_offsets[1:].to(dtype=torch.int) + fc1_x_tensor_offsets = base_split_offsets * fc1_weight_shape[1] fc2_x_tensor_offsets = base_split_offsets * fc2_weight_shape[1] # Extract per-row activation probabilities from the middle op. @@ -224,7 +232,7 @@ def fuser_forward( if fc1_op.weight.rowwise_data is None: raise RuntimeError("FC1 grouped weight has no rowwise_data to quantize.") fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) - grouped_fc1_weight = tex.group_quantize( + grouped_fc1_weight = _group_quantize_for_grouped_mlp( fc1_op.weight.rowwise_data.view(fc1_op.weight.logical_shape), fc1_weight_quantizer, num_groups, @@ -241,6 +249,8 @@ def fuser_forward( else: quantized_fc1_weights.append(weight) grouped_fc1_weight = quantized_fc1_weights + if isinstance(fc1_input_quantizer, NVFP4Quantizer): + _pack_nvfp4_amax_list(grouped_fc1_weight) # Prepare FC2 grouped weight tensor for fused kernels. if fc2_op.single_grouped_weight: @@ -256,7 +266,7 @@ def fuser_forward( if fc2_op.weight.rowwise_data is None: raise RuntimeError("FC2 grouped weight has no rowwise_data to quantize.") fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) - grouped_fc2_weight = tex.group_quantize( + grouped_fc2_weight = _group_quantize_for_grouped_mlp( fc2_op.weight.rowwise_data.view(fc2_op.weight.logical_shape), fc2_weight_quantizer, num_groups, @@ -274,6 +284,8 @@ def fuser_forward( else: quantized_fc2_weights.append(weight) grouped_fc2_weight = quantized_fc2_weights + if isinstance(fc2_input_quantizer, NVFP4Quantizer): + _pack_nvfp4_amax_list(grouped_fc2_weight) # Some wrapper-copy paths may drop grouped storage metadata; enforce defaults. if getattr(grouped_fc1_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( @@ -288,13 +300,34 @@ def fuser_forward( # Group-quantize input tensor and convert dtypes if needed fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) fc1_input_quantizer.optimize_for_gemm = True - if isinstance(input_, GroupedTensor) and isinstance( - getattr(input_, "quantizer", None), MXFP8Quantizer + _enable_nvfp4_rht_for_group_quantize(fc1_input_quantizer) + input_quantizer = getattr(input_, "quantizer", None) + if isinstance(input_, GroupedTensor) and ( + isinstance(fc1_input_quantizer, MXFP8Quantizer) + and isinstance(input_quantizer, MXFP8Quantizer) + or isinstance(fc1_input_quantizer, NVFP4Quantizer) + and isinstance(input_quantizer, NVFP4Quantizer) ): grouped_fc1_x = input_ else: fc1_x = maybe_dequantize(input_, dtype) - grouped_fc1_x = tex.group_quantize(fc1_x, fc1_input_quantizer, num_groups, split_sizes) + grouped_fc1_x = _group_quantize_for_grouped_mlp( + fc1_x, + fc1_input_quantizer, + num_groups, + split_sizes, + tensor_offsets=fc1_x_tensor_offsets, + ) + + use_nvfp4 = isinstance(fc1_input_quantizer, NVFP4Quantizer) or isinstance( + fc1_weight_param, NVFP4Tensor + ) + data_dtype = torch.float4_e2m1fn_x2 if use_nvfp4 else torch.float8_e4m3fn + scale_view_dtype = torch.float8_e4m3fn if use_nvfp4 else torch.float8_e8m0fnu + sf_vec_size = NVFP4_BLOCK_SCALING_SIZE if use_nvfp4 else MXFP8_BLOCK_SCALING_SIZE + data_in_k = in_shape[1] // 2 if use_nvfp4 else in_shape[1] + fc1_weight_k = fc1_weight_shape[1] // 2 if use_nvfp4 else fc1_weight_shape[1] + k_sf_divisor = 2 * sf_vec_size if use_nvfp4 else 4 * sf_vec_size # Pack data tensors # Note: Fused kernel expects tensor with non-contiguous @@ -305,20 +338,46 @@ def fuser_forward( # Data logical shape: (sum(m), k, 1) # Scale logical shape: (32 (block row), 4 (block row), # sum(m)/128, 4 (block col), k/128, 1) - fc1_x_data = grouped_fc1_x.rowwise_data.view(in_shape[0], in_shape[1]) - fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn) + fc1_x_data = grouped_fc1_x.rowwise_data.view(dtype=data_dtype) + fc1_x_data = fc1_x_data.view(in_shape[0], data_in_k) fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) fc1_x_scales = grouped_fc1_x.scale_inv - fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu) - fc1_x_scales = fc1_x_scales.view( - 1, - (in_shape[0] + 127) // 128, - (in_shape[1] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, - 4, - 4, + fc1_x_scales = fc1_x_scales.view(dtype=scale_view_dtype) + with_gemm_swizzled_scales = getattr( + grouped_fc1_x, + "_with_gemm_swizzled_scales", + getattr(grouped_fc1_x, "with_gemm_swizzled_scales", False), ) - fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) + if use_nvfp4 and with_gemm_swizzled_scales: + fc1_x_scales = fc1_x_scales.view( + 1, + in_shape[0] // 128, + data_in_k // k_sf_divisor, + 32, + 4, + 4, + ) + fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) + elif use_nvfp4: + fc1_x_scales = fc1_x_scales.view( + 1, + in_shape[0] // 128, + 4, + 32, + data_in_k // k_sf_divisor, + 4, + ) + fc1_x_scales = fc1_x_scales.permute(3, 2, 1, 5, 4, 0) + else: + fc1_x_scales = fc1_x_scales.view( + 1, + (in_shape[0] + 127) // 128, + (in_shape[1] + k_sf_divisor - 1) // k_sf_divisor, + 32, + 4, + 4, + ) + fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) alpha_tensor = get_cached_ones_tensor(num_groups, dtype, device) norm_const_tensor = get_cached_ones_tensor(1, torch.float32, device) @@ -327,21 +386,37 @@ def fuser_forward( fc1_bias_packed = _pack_grouped_linear_bias_for_cudnn(fc1_op) fc2_bias_packed = _pack_grouped_linear_bias_for_cudnn(fc2_op) + fc1_d_dtype = torch.bfloat16 if use_nvfp4 else torch.float8_e4m3fn + fc1_prob_tensor = ( + scales.detach().to(dtype=torch.float32 if use_nvfp4 else dtype).reshape(-1, 1, 1) + ) + fc1_norm_const_tensor = None if use_nvfp4 else norm_const_tensor + if use_nvfp4: + nvfp4_fp4_max = 6.0 + nvfp4_fp8_max = 448.0 + fc1_alpha_tensor = ( + _nvfp4_amax(grouped_fc1_x, columnwise=False) + * _nvfp4_amax(grouped_fc1_weight, columnwise=False) + / (nvfp4_fp4_max**2 * nvfp4_fp8_max**2) + ).to(torch.float32) + else: + fc1_alpha_tensor = alpha_tensor + fc1_activation_kwargs = { "a_tensor": fc1_x_data, "sfa_tensor": fc1_x_scales, "padded_offsets": split_points, - "alpha_tensor": alpha_tensor, + "alpha_tensor": fc1_alpha_tensor, "bias_tensor": fc1_bias_packed, - "norm_const_tensor": norm_const_tensor, - "prob_tensor": scales.detach().to(dtype=dtype).reshape(-1, 1, 1), + "norm_const_tensor": fc1_norm_const_tensor, + "prob_tensor": fc1_prob_tensor, "acc_dtype": torch.float32, "c_dtype": torch.bfloat16, - "d_dtype": torch.float8_e4m3fn, + "d_dtype": fc1_d_dtype, "cd_major": "n", - "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "sf_vec_size": sf_vec_size, "current_stream": current_stream, - "discrete_col_sfd": True, + "discrete_col_sfd": not use_nvfp4, "use_dynamic_sched": True, } if self._cudnn_act_func is not None: @@ -363,15 +438,15 @@ def fuser_forward( # Data actual shape: (num_groups, n, k) # Data logical shape: (n, k, num_groups) fc1_w_data = fc1_weight_for_gemm.rowwise_data - fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) - fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1]) + fc1_w_data = fc1_w_data.view(dtype=data_dtype) + fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_k) fc1_w_data = fc1_w_data.permute(1, 2, 0) - fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu) + fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=scale_view_dtype) fc1_w_scales = fc1_w_scales.view( num_groups, (fc1_weight_shape[0] + 127) // 128, - (fc1_weight_shape[1] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, + (fc1_weight_shape[1] + k_sf_divisor - 1) // k_sf_divisor, + 32, 4, 4, ) @@ -381,17 +456,26 @@ def fuser_forward( fc1_activation_kwargs["sfb_tensor"] = fc1_w_scales else: # Discrete-weight kernel: per-expert data/scale pointers + fc1_weight_data_for_ptrs = [w._rowwise_data for w in grouped_fc1_weight] + if use_nvfp4: + fc1_weight_data_for_ptrs = [ + _nvfp4_logical_data_view(data) for data in fc1_weight_data_for_ptrs + ] fc1_b_ptrs, fc1_sfb_ptrs, _fc1_sw = tex.get_device_pointer_for_data_and_scales( - [w._rowwise_data for w in grouped_fc1_weight], + fc1_weight_data_for_ptrs, [w._rowwise_scale_inv for w in grouped_fc1_weight], swizzle=True, rowwise=True, - data_dtype=grouped_fc1_weight[0]._fp8_dtype, + data_dtype=( + grouped_fc1_weight[0]._fp4_dtype + if use_nvfp4 + else grouped_fc1_weight[0]._fp8_dtype + ), ) fc1_activation_kwargs["b_ptrs"] = fc1_b_ptrs fc1_activation_kwargs["sfb_ptrs"] = fc1_sfb_ptrs fc1_activation_kwargs["n"] = fc1_weight_shape[0] - fc1_activation_kwargs["b_dtype"] = torch.float8_e4m3fn + fc1_activation_kwargs["b_dtype"] = data_dtype fc1_activation_kwargs["b_major"] = "k" fc1_kernel_out = self.grouped_gemm_activation_kernel()(**fc1_activation_kwargs) @@ -407,94 +491,169 @@ def fuser_forward( # k/128, 4 (block row), sum(m_splits)/128, 1) activation_in = fc1_kernel_out["c_tensor"] activation_in = activation_in.view(in_shape[0], fc1_weight_shape[0]) - fc2_in_row_data = fc1_kernel_out["d_tensor"] - fc2_in_row_data = fc2_in_row_data.view(in_shape[0], fc2_weight_shape[1]) - fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] - fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 4, 0, 1, 3) - - fc2_in_col_data = fc1_kernel_out["d_col_tensor"] - fc2_in_col_data = fc2_in_col_data.view(in_shape[0], fc2_weight_shape[1]) - fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] - fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3) - # Repack columnwise scales on GPU to preserve group ordering. - - # FC2 inputs scales are already swizzled/optimized for GEMM - grouped_fc2_x = GroupedTensor( - shape=(in_shape[0], fc2_weight_shape[1]), - dtype=dtype, - num_tensors=num_groups, - quantizer=fc2_input_quantizer, - data=fc2_in_row_data.reshape(-1), - columnwise_data=fc2_in_col_data.reshape(-1), - scale_inv=fc2_in_row_scale.reshape(-1), - columnwise_scale_inv=fc2_in_col_scale.reshape(-1), - first_dims=split_sizes, - tensor_offsets=fc2_x_tensor_offsets, - with_gemm_swizzled_scales=True, - ) # FC2 GEMM fc2_out_shape = in_shape[:-1] + [fc2_weight_shape[0]] fc2_scales = basic_op_extra_inputs[2][1] if fc2_op._scale_bias else None - fc2_scales_tensor = ( - fc2_scales.detach().to(dtype=torch.float32).reshape(-1, 1, 1) - if fc2_scales is not None - else torch.ones((in_shape[0], 1, 1), dtype=torch.float32, device=device) - ) - fc2_quant_kwargs = { - "a_tensor": fc1_kernel_out["d_tensor"], - "sfa_tensor": fc1_kernel_out["sfd_row_tensor"], - "padded_offsets": split_points, - "alpha_tensor": alpha_tensor, - "bias_tensor": fc2_bias_packed, - "norm_const_tensor": None, - "prob_tensor": fc2_scales_tensor, - "acc_dtype": torch.float32, - "d_dtype": dtype, - "cd_major": "n", - "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, - "current_stream": current_stream, - "use_dynamic_sched": True, - } - - if fc2_op.single_grouped_weight: - # Clone and swizzle scales for GEMM (original stays unmodified for save_for_backward) - fc2_weight_for_gemm = grouped_fc2_weight.copy() - tex.grouped_swizzle_for_gemm(fc2_weight_for_gemm, rowwise=True, columnwise=False) - - fc2_w_data = fc2_weight_for_gemm.rowwise_data - fc2_w_data = fc2_w_data.view(dtype=torch.float8_e4m3fn) - fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1]) - fc2_w_data = fc2_w_data.permute(1, 2, 0) - fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu) - fc2_w_scales = fc2_w_scales.view( + if use_nvfp4: + fc2_bias_for_gemm = None + fc2_bias_scale = None + if fc2_bias_packed is not None: + fc2_bias_for_gemm = fc2_op._get_grouped_bias_for_gemm(dtype) + if fc2_scales is not None: + fc2_bias_scale = fc2_scales.reshape(-1) + if fc2_bias_scale.dtype != torch.float32: + fc2_bias_scale = fc2_bias_scale.to(dtype=torch.float32) + + fc2_in = fc1_kernel_out["d_tensor"] + fc2_in = fc2_in.view(in_shape[0], fc2_weight_shape[1]).contiguous() + fc2_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + fc2_input_quantizer.optimize_for_gemm = True + _enable_nvfp4_rht_for_group_quantize(fc2_input_quantizer) + grouped_fc2_x = _group_quantize_for_grouped_mlp( + fc2_in, + fc2_input_quantizer, num_groups, - (fc2_weight_shape[0] + 127) // 128, - (fc2_weight_shape[1] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, - 4, - 4, + split_sizes, + tensor_offsets=fc2_x_tensor_offsets, ) - fc2_w_scales = fc2_w_scales.permute(3, 4, 1, 5, 2, 0) - fc2_quant_kwargs["b_tensor"] = fc2_w_data - fc2_quant_kwargs["sfb_tensor"] = fc2_w_scales + + fc2_out_buf = torch.empty(fc2_out_shape, dtype=dtype, device=device) + if ( + num_groups == 1 + and grouped_fc2_x.columnwise_data is not None + and grouped_fc2_x.columnwise_scale_inv is not None + ): + if fc2_op.single_grouped_weight: + fc2_w_single = grouped_fc2_weight.split_into_quantized_tensors()[0] + else: + fc2_w_single = grouped_fc2_weight[0] + fc2_x_single = _nvfp4_single_tensor_from_grouped( + grouped_fc2_x, + fc2_input_quantizer, + fp4_dtype=getattr(fc2_w_single, "_fp4_dtype", fc2_input_quantizer.dtype), + ) + general_gemm( + fc2_w_single, + fc2_x_single, + out_dtype=dtype, + out=fc2_out_buf, + layout="TN", + use_split_accumulator=False, + ) + if fc2_bias_packed is not None: + token_bias = fc2_bias_packed.transpose(0, 1).contiguous().expand( + in_shape[0], -1 + ) + if fc2_scales is not None: + fc2_out_buf = fc2_out_buf + token_bias * fc2_scales.view(-1, 1) + else: + fc2_out_buf = fc2_out_buf + token_bias + else: + fc2_out_offsets = base_split_offsets * fc2_weight_shape[0] + fc2_out_grouped = GroupedTensor( + shape=(in_shape[0], fc2_weight_shape[0]), + dtype=dtype, + num_tensors=num_groups, + quantizer=None, + data=fc2_out_buf.view(-1), + first_dims=split_sizes, + tensor_offsets=fc2_out_offsets, + ) + general_grouped_gemm_for_grouped_tensor( + grouped_fc2_weight, + grouped_fc2_x, + fc2_out_grouped, + layout="TN", + bias=fc2_bias_for_gemm, + bias_scale=fc2_bias_scale, + ) + fc2_out = fc2_out_buf else: - fc2_b_ptrs, fc2_sfb_ptrs, _ = tex.get_device_pointer_for_data_and_scales( - [w._rowwise_data for w in grouped_fc2_weight], - [w._rowwise_scale_inv for w in grouped_fc2_weight], - swizzle=True, - rowwise=True, - data_dtype=grouped_fc2_weight[0]._fp8_dtype, + fc2_in_row_data = fc1_kernel_out["d_tensor"] + fc2_in_row_data = fc2_in_row_data.view(in_shape[0], fc2_weight_shape[1]) + fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] + fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 4, 0, 1, 3) + + fc2_in_col_data = fc1_kernel_out["d_col_tensor"] + fc2_in_col_data = fc2_in_col_data.view(in_shape[0], fc2_weight_shape[1]) + fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] + fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3) + + grouped_fc2_x = GroupedTensor( + shape=(in_shape[0], fc2_weight_shape[1]), + dtype=dtype, + num_tensors=num_groups, + quantizer=fc2_input_quantizer, + data=fc2_in_row_data.reshape(-1), + columnwise_data=fc2_in_col_data.reshape(-1), + scale_inv=fc2_in_row_scale.reshape(-1), + columnwise_scale_inv=fc2_in_col_scale.reshape(-1), + first_dims=split_sizes, + tensor_offsets=fc2_x_tensor_offsets, + with_gemm_swizzled_scales=True, ) - fc2_quant_kwargs["b_ptrs"] = fc2_b_ptrs - fc2_quant_kwargs["sfb_ptrs"] = fc2_sfb_ptrs - fc2_quant_kwargs["n"] = fc2_weight_shape[0] - fc2_quant_kwargs["b_dtype"] = torch.float8_e4m3fn - fc2_quant_kwargs["b_major"] = "k" - fc2_kernel_out = self.grouped_gemm_quant_kernel()(**fc2_quant_kwargs) - fc2_out = fc2_kernel_out["d_tensor"].permute(2, 0, 1).view(fc2_out_shape).contiguous() + fc2_scales_tensor = ( + fc2_scales.detach().to(dtype=torch.float32).reshape(-1, 1, 1) + if fc2_scales is not None + else torch.ones((in_shape[0], 1, 1), dtype=torch.float32, device=device) + ) + fc2_quant_kwargs = { + "a_tensor": fc1_kernel_out["d_tensor"], + "sfa_tensor": fc1_kernel_out["sfd_row_tensor"], + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor, + "bias_tensor": fc2_bias_packed, + "norm_const_tensor": None, + "prob_tensor": fc2_scales_tensor, + "acc_dtype": torch.float32, + "d_dtype": dtype, + "cd_major": "n", + "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "use_dynamic_sched": True, + } + + if fc2_op.single_grouped_weight: + # Clone and swizzle scales for GEMM (original stays unmodified for save_for_backward) + fc2_weight_for_gemm = grouped_fc2_weight.copy() + tex.grouped_swizzle_for_gemm(fc2_weight_for_gemm, rowwise=True, columnwise=False) + + fc2_w_data = fc2_weight_for_gemm.rowwise_data + fc2_w_data = fc2_w_data.view(dtype=torch.float8_e4m3fn) + fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1]) + fc2_w_data = fc2_w_data.permute(1, 2, 0) + + fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu) + fc2_w_scales = fc2_w_scales.view( + num_groups, + (fc2_weight_shape[0] + 127) // 128, + (fc2_weight_shape[1] + 127) // 128, + MXFP8_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc2_w_scales = fc2_w_scales.permute(3, 4, 1, 5, 2, 0) + fc2_quant_kwargs["b_tensor"] = fc2_w_data + fc2_quant_kwargs["sfb_tensor"] = fc2_w_scales + else: + fc2_b_ptrs, fc2_sfb_ptrs, _ = tex.get_device_pointer_for_data_and_scales( + [w._rowwise_data for w in grouped_fc2_weight], + [w._rowwise_scale_inv for w in grouped_fc2_weight], + swizzle=True, + rowwise=True, + data_dtype=grouped_fc2_weight[0]._fp8_dtype, + ) + fc2_quant_kwargs["b_ptrs"] = fc2_b_ptrs + fc2_quant_kwargs["sfb_ptrs"] = fc2_sfb_ptrs + fc2_quant_kwargs["n"] = fc2_weight_shape[0] + fc2_quant_kwargs["b_dtype"] = torch.float8_e4m3fn + fc2_quant_kwargs["b_major"] = "k" + + fc2_kernel_out = self.grouped_gemm_quant_kernel()(**fc2_quant_kwargs) + fc2_out = fc2_kernel_out["d_tensor"].permute(2, 0, 1).view(fc2_out_shape).contiguous() # Save state for backward pass if requires_grad: @@ -513,11 +672,13 @@ def fuser_forward( ) saved_grouped_fc2_x = None if recompute_srelu_fc2_x else grouped_fc2_x - # Save the input ``GroupedTensor``s themselves for the activations. - for grouped_fc_x in (grouped_fc1_x, saved_grouped_fc2_x): - if grouped_fc_x is not None: - grouped_fc_x.rowwise_data = None - grouped_fc_x.scale_inv = None + # MXFP8 wgrad only needs columnwise tiles. NVFP4 generic GEMM fallbacks + # need the full grouped tensor state, including rowwise data and amax. + if not use_nvfp4: + for grouped_fc_x in (grouped_fc1_x, saved_grouped_fc2_x): + if grouped_fc_x is not None: + grouped_fc_x.rowwise_data = None + grouped_fc_x.scale_inv = None # FC1 saved-tensor layout. # [split_sizes, base_split_offsets, split_points, @@ -649,6 +810,8 @@ def fuse_forward_srelu_ops( ) -> list[FusibleOperation]: """Apply GroupedLinear + ScaledSReLU + GroupedLinear fusion for forward pass.""" + if recipe is None or not recipe.mxfp8(): + return ops return fuse_grouped_mlp_ops( ops, recipe=recipe, From c01eaa41d196748b991225428925bc0784cfd9bb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 May 2026 20:26:28 +0000 Subject: [PATCH 02/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 82b154ae22..69862fb9c0 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -543,8 +543,8 @@ def fuser_forward( use_split_accumulator=False, ) if fc2_bias_packed is not None: - token_bias = fc2_bias_packed.transpose(0, 1).contiguous().expand( - in_shape[0], -1 + token_bias = ( + fc2_bias_packed.transpose(0, 1).contiguous().expand(in_shape[0], -1) ) if fc2_scales is not None: fc2_out_buf = fc2_out_buf + token_bias * fc2_scales.view(-1, 1) From 6c2ee782901e274e84e777eb2f1f4c8cf9fcaad6 Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Wed, 27 May 2026 17:26:45 -0700 Subject: [PATCH 03/14] Address NVFP4 grouped MLP review comments Signed-off-by: Siddhartha Raman S --- transformer_engine/pytorch/ops/_common.py | 16 ---------------- .../pytorch/ops/basic/grouped_linear.py | 15 ++++----------- .../pytorch/ops/fused/backward_grouped_mlp.py | 15 +-------------- .../pytorch/ops/fused/forward_grouped_mlp.py | 5 ----- 4 files changed, 5 insertions(+), 46 deletions(-) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 0bbc5280b3..6da737d032 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -60,22 +60,6 @@ def _nvidia_cudnn_frontend_supports_wgrad() -> bool: return _cudnn_frontend_version_supported() -def _pack_nvfp4_amax_list(tensors: list) -> None: - """Ensure discrete NVFP4 weight list uses contiguous per-group amax buffers.""" - if not tensors: - return - row_amaxes = [getattr(tensor, "_amax_rowwise", None) for tensor in tensors] - if all(amax is not None for amax in row_amaxes): - packed_row_amax = torch.cat([amax.view(-1) for amax in row_amaxes], dim=0).contiguous() - for idx, tensor in enumerate(tensors): - tensor._amax_rowwise = packed_row_amax[idx : idx + 1] - col_amaxes = [getattr(tensor, "_amax_columnwise", None) for tensor in tensors] - if all(amax is not None for amax in col_amaxes): - packed_col_amax = torch.cat([amax.view(-1) for amax in col_amaxes], dim=0).contiguous() - for idx, tensor in enumerate(tensors): - tensor._amax_columnwise = packed_col_amax[idx : idx + 1] - - def _enable_nvfp4_rht_for_group_quantize(quantizer: Quantizer) -> None: """Use the graph-safe NVFP4 grouped quantization path.""" if isinstance(quantizer, NVFP4Quantizer): diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 77c8b4bafa..1f00d92284 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -1059,6 +1059,9 @@ def _fuser_forward_split_quantize( num_groups = self.num_groups has_bias = self.has_bias + # Need CPU split sizes for split_quantize / general_grouped_gemm. + split_sizes_int = [int(s) for s in split_sizes.tolist()] + # Extract params if self.single_grouped_weight: weights = self.weight.quantized_tensors @@ -1080,12 +1083,6 @@ def _fuser_forward_split_quantize( # Split input tensor and convert dtypes if needed x = maybe_dequantize(input_, dtype) - if num_groups == 1: - # Avoid CUDA->CPU sync from split_sizes.tolist() during CUDA graph capture. - split_sizes_int = [x.numel() // x.size(-1)] - else: - # Need CPU split sizes for split_quantize / general_grouped_gemm. - split_sizes_int = [int(s) for s in split_sizes.tolist()] xs = None if with_quantized_compute: for quantizer in input_quantizers: @@ -1332,12 +1329,8 @@ def _fuser_backward_split_quantize( ws, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:] # Split grad output tensor and convert dtypes if needed + split_sizes_int = [int(s) for s in split_sizes.tolist()] dy = maybe_dequantize(grad_output, ctx.dtype) - if num_groups == 1: - # Avoid CUDA->CPU sync from split_sizes.tolist() during CUDA graph capture. - split_sizes_int = [dy.numel() // dy.size(-1)] - else: - split_sizes_int = [int(s) for s in split_sizes.tolist()] dys = None grad_biases = [None] * num_groups grad_scales = None diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index b1faca36d3..87ffd34969 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -8,7 +8,7 @@ from collections.abc import Callable import functools import os -from typing import Any, Optional +from typing import Optional import torch @@ -48,16 +48,6 @@ from ...triton.grouped_dbias_dscales import compute_grouped_dbias_dscales -def _mark_with_gemm_swizzled_scales(tensors: Any) -> None: - """Mark tensors whose scale buffers are already in GEMM-swizzled layout.""" - if tensors is None: - return - if hasattr(tensors, "with_gemm_swizzled_scales"): - tensors.with_gemm_swizzled_scales = True - if hasattr(tensors, "_with_gemm_swizzled_scales"): - tensors._with_gemm_swizzled_scales = True - - def _nvfp4_single_group_wgrad_gemm( grouped_x: GroupedTensor, grouped_dy: GroupedTensor, @@ -792,7 +782,6 @@ def fuser_backward( split_sizes, tensor_offsets=fc1_dy_tensor_offsets, ) - _mark_with_gemm_swizzled_scales(grouped_fc1_dy) else: grouped_fc1_dy = GroupedTensor( shape=(out_shape[0], fc1_weight_shape[0]), @@ -845,8 +834,6 @@ def fuser_backward( in_shape = out_shape[:-1] + [fc1_weight_shape[1]] if use_nvfp4: - _mark_with_gemm_swizzled_scales(grouped_fc1_weight) - _mark_with_gemm_swizzled_scales(grouped_fc1_dy) grad_input = torch.empty(in_shape, dtype=dtype, device=device) if num_groups == 1: if fc1_op.single_grouped_weight: diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 69862fb9c0..52fcf74fb5 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -33,7 +33,6 @@ _nvfp4_amax, _nvfp4_logical_data_view, _nvfp4_single_tensor_from_grouped, - _pack_nvfp4_amax_list, fuse_grouped_mlp_ops, is_glu_activation, is_quantized_tensor, @@ -249,8 +248,6 @@ def fuser_forward( else: quantized_fc1_weights.append(weight) grouped_fc1_weight = quantized_fc1_weights - if isinstance(fc1_input_quantizer, NVFP4Quantizer): - _pack_nvfp4_amax_list(grouped_fc1_weight) # Prepare FC2 grouped weight tensor for fused kernels. if fc2_op.single_grouped_weight: @@ -284,8 +281,6 @@ def fuser_forward( else: quantized_fc2_weights.append(weight) grouped_fc2_weight = quantized_fc2_weights - if isinstance(fc2_input_quantizer, NVFP4Quantizer): - _pack_nvfp4_amax_list(grouped_fc2_weight) # Some wrapper-copy paths may drop grouped storage metadata; enforce defaults. if getattr(grouped_fc1_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( From 343633418a3592801f27f609c2eee0315749228b Mon Sep 17 00:00:00 2001 From: Siddhartha Raman Sundara Raman Date: Wed, 27 May 2026 19:40:15 -0500 Subject: [PATCH 04/14] Update transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Siddhartha Raman Sundara Raman --- transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 52fcf74fb5..8ea0f650c6 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -346,7 +346,7 @@ def fuser_forward( if use_nvfp4 and with_gemm_swizzled_scales: fc1_x_scales = fc1_x_scales.view( 1, - in_shape[0] // 128, + (in_shape[0] + 127) // 128, data_in_k // k_sf_divisor, 32, 4, @@ -356,7 +356,7 @@ def fuser_forward( elif use_nvfp4: fc1_x_scales = fc1_x_scales.view( 1, - in_shape[0] // 128, + (in_shape[0] + 127) // 128, 4, 32, data_in_k // k_sf_divisor, From 71427e7c10c88455f33d10bffb4abed6655f725b Mon Sep 17 00:00:00 2001 From: Siddhartha Raman Sundara Raman Date: Wed, 27 May 2026 19:40:43 -0500 Subject: [PATCH 05/14] Update transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Siddhartha Raman Sundara Raman --- transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 87ffd34969..a2196c9e94 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -535,7 +535,7 @@ def fuser_backward( if use_nvfp4 and with_gemm_swizzled_scales: fc2_dy_scales = fc2_dy_scales.view( 1, - out_shape[0] // 128, + (out_shape[0] + 127) // 128, data_k // k_sf_divisor, 32, 4, @@ -545,7 +545,7 @@ def fuser_backward( elif use_nvfp4: fc2_dy_scales = fc2_dy_scales.view( 1, - out_shape[0] // 128, + (out_shape[0] + 127) // 128, 4, 32, data_k // k_sf_divisor, From 78fcc3ea4b4219b9785d848c343f94617a5ec82c Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Thu, 28 May 2026 08:00:24 -0700 Subject: [PATCH 06/14] Address grouped MLP NVFP4 review feedback Signed-off-by: Siddhartha Raman S --- tests/pytorch/test_fusible_ops.py | 15 ++++++++++++++- .../pytorch/csrc/extensions/utils.cpp | 5 +++++ transformer_engine/pytorch/ops/_common.py | 12 ------------ .../pytorch/ops/fused/backward_grouped_mlp.py | 16 ++++------------ .../pytorch/ops/fused/forward_grouped_mlp.py | 16 ++++------------ 5 files changed, 27 insertions(+), 37 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1ced32e1a5..69d793a287 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3762,6 +3762,14 @@ def test_grouped_mlp( # TODO: ksivaman: Need to debug numerics for this case. pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size + use_nvfp4_rht_recipe = ( + quantization == "nvfp4" + and activation_is_glu + and glu_interleave_size == 32 + and dtype in (torch.bfloat16, torch.float16) + and te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported() + and te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported() + ) # Activation parameters for clamped QGeGLU variants if activation == "scaled_clamped_qgeglu_custom": @@ -3883,7 +3891,10 @@ def test_grouped_mlp( y_ref.backward(dy_ref) # Construct operations - recipe = make_recipe(quantization) + recipe_kwargs = {} + if use_nvfp4_rht_recipe: + recipe_kwargs["disable_rht"] = False + recipe = make_recipe(quantization, **recipe_kwargs) if activation == "scaled_clamped_qgeglu_custom": scaled_act = te_ops.ScaledClampedQGeGLU( glu_interleave_size=glu_interleave_size, @@ -3976,6 +3987,8 @@ def test_grouped_mlp( fc2.backward_dw() # Check for expected fusions + if use_nvfp4_rht_recipe: + return if ( quantization == "mxfp8" and dtype in (torch.bfloat16, torch.float16) diff --git a/transformer_engine/pytorch/csrc/extensions/utils.cpp b/transformer_engine/pytorch/csrc/extensions/utils.cpp index 9a093608d4..1133e00ede 100644 --- a/transformer_engine/pytorch/csrc/extensions/utils.cpp +++ b/transformer_engine/pytorch/csrc/extensions/utils.cpp @@ -66,6 +66,11 @@ std::tuple get_device_pointer_for_data_and_s data_shape.ndim = 2; data_shape.data[0] = static_cast(data_tensors[0].size(0)); data_shape.data[1] = static_cast(data_tensors[0].size(1)); + if (is_fp4_dtype(data_dtype)) { + // FP4 tensors are packed with two logical values per byte. The data pointers still refer to + // the packed physical storage, but TensorWrapper needs the logical shape for scale swizzling. + data_shape.data[1] *= 2; + } // Collect data device pointers std::vector data_host_ptrs(num_tensors); diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 6da737d032..9ca630811f 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -60,13 +60,6 @@ def _nvidia_cudnn_frontend_supports_wgrad() -> bool: return _cudnn_frontend_version_supported() -def _enable_nvfp4_rht_for_group_quantize(quantizer: Quantizer) -> None: - """Use the graph-safe NVFP4 grouped quantization path.""" - if isinstance(quantizer, NVFP4Quantizer): - quantizer.with_rht = True - quantizer.with_post_rht_amax = True - - def _group_quantize_for_grouped_mlp( tensor: torch.Tensor, quantizer: Quantizer, @@ -130,11 +123,6 @@ def _group_quantize_for_grouped_mlp( ) -def _nvfp4_logical_data_view(data: torch.Tensor) -> torch.Tensor: - """View packed NVFP4 data with its logical K dimension for scale swizzling.""" - return data.as_strided((data.shape[0], data.shape[1] * 2), (data.stride(0), 0)) - - def _nvfp4_amax(tensors: Any, *, columnwise: bool) -> torch.Tensor: """Get one NVFP4 amax value per group.""" grouped_attr = "columnwise_amax" if columnwise else "amax" diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index a2196c9e94..fe7c214c45 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -26,10 +26,8 @@ _cudnn_frontend_geglu_runtime_params, _cudnn_frontend_version_supported, _cudnn_frontend_supports_grouped_gemm_srelu, - _enable_nvfp4_rht_for_group_quantize, _group_quantize_for_grouped_mlp, _nvfp4_amax, - _nvfp4_logical_data_view, _nvfp4_single_tensor_from_grouped, fuse_grouped_mlp_ops, get_accumulate_flag_in_param, @@ -299,8 +297,8 @@ def _compute_grad_params( return w_list + bias_list -class _BackwardGroupedMLP_CuTeGEMMDBase_MXFP8(FusedOperation): - """Base fused backward op for MXFP8 GroupedLinear + activation + GroupedLinear. +class _BackwardGroupedMLP_CuTeGEMMDBase(FusedOperation): + """Base fused backward op for block-scaled GroupedLinear + activation + GroupedLinear. Uses experimental CuTe DSL kernel from cuDNN front-end. @@ -465,7 +463,6 @@ def fuser_backward( fc2_grad_output_quantizer = fc2_ctx.grad_output_quantizers[0] fc2_grad_output_quantizer.set_usage(rowwise=True, columnwise=fc2_ctx.weight_requires_grad) fc2_grad_output_quantizer.optimize_for_gemm = True - _enable_nvfp4_rht_for_group_quantize(fc2_grad_output_quantizer) output_fc2_dbias = fc2_op.has_bias fc2_dbias_packed = None fc2_dy = None @@ -650,10 +647,6 @@ def fuser_backward( fc2_dactivation_kwargs["sfb_tensor"] = fc2_w_scales else: fc2_weight_data_for_ptrs = [w._columnwise_data for w in grouped_fc2_weight] - if use_nvfp4: - fc2_weight_data_for_ptrs = [ - _nvfp4_logical_data_view(data) for data in fc2_weight_data_for_ptrs - ] fc2_b_ptrs, fc2_sfb_ptrs, _fc2_sw = tex.get_device_pointer_for_data_and_scales( fc2_weight_data_for_ptrs, [w._columnwise_scale_inv for w in grouped_fc2_weight], @@ -774,7 +767,6 @@ def fuser_backward( columnwise=fc1_ctx.weight_requires_grad, ) fc1_grad_output_quantizer.optimize_for_gemm = True - _enable_nvfp4_rht_for_group_quantize(fc1_grad_output_quantizer) grouped_fc1_dy = _group_quantize_for_grouped_mlp( fc1_dy_bf16, fc1_grad_output_quantizer, @@ -972,7 +964,7 @@ def fuser_backward( ) -class BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8(_BackwardGroupedMLP_CuTeGEMMDBase_MXFP8): +class BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8(_BackwardGroupedMLP_CuTeGEMMDBase): """Fused backward op for GroupedLinear + scaled GLU + GroupedLinear.""" @classmethod @@ -984,7 +976,7 @@ def grouped_gemm_dactivation_kernel(cls) -> Callable: return grouped_gemm_dglu_wrapper_sm100 -class BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8(_BackwardGroupedMLP_CuTeGEMMDBase_MXFP8): +class BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8(_BackwardGroupedMLP_CuTeGEMMDBase): """Fused backward op for GroupedLinear + scaled unary activation + GroupedLinear.""" @classmethod diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 8ea0f650c6..e1c9143a7f 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -27,11 +27,9 @@ _cudnn_frontend_geglu_runtime_params, _cudnn_frontend_version_supported, _cudnn_frontend_supports_grouped_gemm_srelu, - _enable_nvfp4_rht_for_group_quantize, _group_quantize_for_grouped_mlp, _nvidia_cudnn_frontend_supports_wgrad, _nvfp4_amax, - _nvfp4_logical_data_view, _nvfp4_single_tensor_from_grouped, fuse_grouped_mlp_ops, is_glu_activation, @@ -73,8 +71,8 @@ def _grouped_gemm_dsrelu_backward_supported() -> bool: return grouped_gemm_dsrelu_wrapper_sm100 is not None -class _ForwardGroupedMLP_CuTeGEMMBase_MXFP8(FusedOperation): - """Base fused op for MXFP8 GroupedLinear + activation + GroupedLinear. +class _ForwardGroupedMLP_CuTeGEMMBase(FusedOperation): + """Base fused op for block-scaled GroupedLinear + activation + GroupedLinear. Uses experimental CuTe DSL kernel from cuDNN front-end. @@ -295,7 +293,6 @@ def fuser_forward( # Group-quantize input tensor and convert dtypes if needed fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) fc1_input_quantizer.optimize_for_gemm = True - _enable_nvfp4_rht_for_group_quantize(fc1_input_quantizer) input_quantizer = getattr(input_, "quantizer", None) if isinstance(input_, GroupedTensor) and ( isinstance(fc1_input_quantizer, MXFP8Quantizer) @@ -452,10 +449,6 @@ def fuser_forward( else: # Discrete-weight kernel: per-expert data/scale pointers fc1_weight_data_for_ptrs = [w._rowwise_data for w in grouped_fc1_weight] - if use_nvfp4: - fc1_weight_data_for_ptrs = [ - _nvfp4_logical_data_view(data) for data in fc1_weight_data_for_ptrs - ] fc1_b_ptrs, fc1_sfb_ptrs, _fc1_sw = tex.get_device_pointer_for_data_and_scales( fc1_weight_data_for_ptrs, [w._rowwise_scale_inv for w in grouped_fc1_weight], @@ -505,7 +498,6 @@ def fuser_forward( fc2_in = fc2_in.view(in_shape[0], fc2_weight_shape[1]).contiguous() fc2_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) fc2_input_quantizer.optimize_for_gemm = True - _enable_nvfp4_rht_for_group_quantize(fc2_input_quantizer) grouped_fc2_x = _group_quantize_for_grouped_mlp( fc2_in, fc2_input_quantizer, @@ -738,7 +730,7 @@ def fuser_forward( return fc2_out, [(), (), ()] -class ForwardGroupedMLP_CuTeGEMMGLU_MXFP8(_ForwardGroupedMLP_CuTeGEMMBase_MXFP8): +class ForwardGroupedMLP_CuTeGEMMGLU_MXFP8(_ForwardGroupedMLP_CuTeGEMMBase): """Fused op for MXFP8 GroupedLinear + scaled GLU + GroupedLinear.""" @classmethod @@ -750,7 +742,7 @@ def grouped_gemm_activation_kernel(cls) -> Callable: return grouped_gemm_glu_wrapper_sm100 -class ForwardGroupedMLP_CuTeGEMMUnary_MXFP8(_ForwardGroupedMLP_CuTeGEMMBase_MXFP8): +class ForwardGroupedMLP_CuTeGEMMUnary_MXFP8(_ForwardGroupedMLP_CuTeGEMMBase): """Fused op for MXFP8 GroupedLinear + scaled unary activation + GroupedLinear.""" @classmethod From d82c25bde30b1cdebf9aabd59ac4a663ba1a1647 Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Thu, 28 May 2026 14:55:51 -0700 Subject: [PATCH 07/14] Add NVFP4 RHT grouped MLP coverage Route the NVFP4 RHT grouped MLP test through the shared recipe helpers, compare the fused path against a TE unfused reference, and keep plain NVFP4 on the non-RHT fallback path. Signed-off-by: Siddhartha Raman S --- tests/pytorch/test_fusible_ops.py | 291 +++++++++++++----- tests/pytorch/utils.py | 22 +- .../common/gemm/cublaslt_grouped_gemm.cu | 10 +- transformer_engine/pytorch/ops/_common.py | 3 + 4 files changed, 223 insertions(+), 103 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 69d793a287..1a656191d9 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -80,6 +80,9 @@ if nvfp4_available: _quantization_list.append("nvfp4") _quantization_list.append("nvfp4_4over6") +_grouped_mlp_quantization_list = list(_quantization_list) +if nvfp4_available: + _grouped_mlp_quantization_list.append("nvfp4_rht") @pytest.fixture(autouse=True, scope="function") @@ -109,7 +112,10 @@ def maybe_skip_quantization( pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and not nvfp4_available: + if ( + quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") + and not nvfp4_available + ): pytest.skip(reason_for_no_nvfp4) # Check dims @@ -122,14 +128,14 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") - elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") # Check dtype if dtype is not None: if ( - quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") + quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") and dtype != torch.bfloat16 ): pytest.skip("NVFP4 quantization is only supported with BF16 data") @@ -187,10 +193,11 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif quantization in ("nvfp4", "nvfp4_row_scaled"): + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_rht"): + with_rht = quantization == "nvfp4_rht" test = NVFP4Quantizer( - with_rht=False, - with_post_rht_amax=False, + with_rht=with_rht, + with_post_rht_amax=with_rht, with_2d_quantization=False, stochastic_rounding=False, with_random_sign_mask=False, @@ -3685,7 +3692,7 @@ def test_layernorm_mlp( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantization", _grouped_mlp_quantization_list) @pytest.mark.parametrize("single_grouped_weight", (False, True)) @pytest.mark.parametrize("single_grouped_bias", (False, True)) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) @@ -3755,22 +3762,14 @@ def test_grouped_mlp( pytest.skip("NVFP4 4over6 grouped quantization is not supported") if ( with_quantization - and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") + and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") and activation.startswith("scaled_clamped_qgeglu") and bias ): # TODO: ksivaman: Need to debug numerics for this case. pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size - use_nvfp4_rht_recipe = ( - quantization == "nvfp4" - and activation_is_glu - and glu_interleave_size == 32 - and dtype in (torch.bfloat16, torch.float16) - and te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported() - and te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported() - ) - + use_nvfp4_rht_recipe = quantization == "nvfp4_rht" # Activation parameters for clamped QGeGLU variants if activation == "scaled_clamped_qgeglu_custom": geglu_limit = 5.0 @@ -3891,75 +3890,89 @@ def test_grouped_mlp( y_ref.backward(dy_ref) # Construct operations - recipe_kwargs = {} - if use_nvfp4_rht_recipe: - recipe_kwargs["disable_rht"] = False - recipe = make_recipe(quantization, **recipe_kwargs) - if activation == "scaled_clamped_qgeglu_custom": - scaled_act = te_ops.ScaledClampedQGeGLU( - glu_interleave_size=glu_interleave_size, - limit=geglu_limit, - alpha=geglu_alpha, - glu_linear_offset=geglu_offset, - ) - with te.quantized_model_init(enabled=with_quantization, recipe=recipe): - fc1 = te_ops.GroupedLinear( - group_size, - hidden_size, - fc1_out_features, - bias=bias, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - single_grouped_bias=single_grouped_bias, - accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, - ) + recipe = make_recipe(quantization) - fc2 = te_ops.GroupedLinear( - group_size, - hidden_size, - hidden_size, - bias=bias, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - single_grouped_bias=single_grouped_bias, - accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, - scale_bias=bias, - ) - module = te_ops.Sequential( - fc1, - scaled_act, - fc2, - ) + def _make_scaled_act(): + if activation == "scaled_swiglu": + return te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_clamped_qgeglu_custom": + return te_ops.ScaledClampedQGeGLU( + glu_interleave_size=glu_interleave_size, + limit=geglu_limit, + alpha=geglu_alpha, + glu_linear_offset=geglu_offset, + ) + if activation.startswith("scaled_clamped_qgeglu"): + return te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_srelu": + return te_ops.ScaledSReLU() + raise ValueError(f"Unexpected grouped MLP activation ({activation})") + + def _make_module(): + with te.quantized_model_init(enabled=with_quantization, recipe=recipe): + fc1_op = te_ops.GroupedLinear( + group_size, + hidden_size, + fc1_out_features, + bias=bias, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + ) + + fc2_op = te_ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + scale_bias=bias, + ) + return te_ops.Sequential(fc1_op, _make_scaled_act(), fc2_op), fc1_op, fc2_op + + module, fc1, fc2 = _make_module() + if use_nvfp4_rht_recipe: + module_ref, fc1_ref, fc2_ref = _make_module() + else: + module_ref, fc1_ref, fc2_ref = None, None, None # Copy weights with torch.no_grad(): - if single_grouped_weight: - fc1_weights = fc1.weight.quantized_tensors - if fc1_weights is None: - fc1_weights = fc1.weight.split_into_quantized_tensors() - fc2_weights = fc2.weight.quantized_tensors - if fc2_weights is None: - fc2_weights = fc2.weight.split_into_quantized_tensors() - for group_idx in range(group_size): + target_modules = [(fc1, fc2)] + if use_nvfp4_rht_recipe: + target_modules.append((fc1_ref, fc2_ref)) + for fc1_target, fc2_target in target_modules: if single_grouped_weight: - fc1_weights[group_idx].copy_(fc1_ws_test[group_idx]) - fc2_weights[group_idx].copy_(fc2_ws_test[group_idx]) - else: - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) - if bias: - if single_grouped_bias: - fc1_bparts = fc1.bias.split_into_quantized_tensors() - fc2_bparts = fc2.bias.split_into_quantized_tensors() - fc1_bparts[group_idx].reshape(-1).copy_(fc1_bs_test[group_idx]) - fc2_bparts[group_idx].reshape(-1).copy_(fc2_bs_test[group_idx]) + fc1_weights = fc1_target.weight.quantized_tensors + if fc1_weights is None: + fc1_weights = fc1_target.weight.split_into_quantized_tensors() + fc2_weights = fc2_target.weight.quantized_tensors + if fc2_weights is None: + fc2_weights = fc2_target.weight.split_into_quantized_tensors() + for group_idx in range(group_size): + if single_grouped_weight: + fc1_weights[group_idx].copy_(fc1_ws_test[group_idx]) + fc2_weights[group_idx].copy_(fc2_ws_test[group_idx]) else: - getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) - getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + getattr(fc1_target, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) + getattr(fc2_target, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) + if bias: + if single_grouped_bias: + fc1_bparts = fc1_target.bias.split_into_quantized_tensors() + fc2_bparts = fc2_target.bias.split_into_quantized_tensors() + fc1_bparts[group_idx].reshape(-1).copy_(fc1_bs_test[group_idx]) + fc2_bparts[group_idx].reshape(-1).copy_(fc2_bs_test[group_idx]) + else: + getattr(fc1_target, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) + getattr(fc2_target, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) if accumulate_into_main_grad: # 0.5 sentinel lets us reconstruct ``expected = ref_grad + 0.5`` # below and detect a missed accumulation. @@ -3975,8 +3988,63 @@ def test_grouped_mlp( fill_value=main_grad_sentinel, overwrite_main_grad=False, ) + if use_nvfp4_rht_recipe: + fc1_ref, fc2_ref = module_ref[0], module_ref[2] + if single_grouped_weight: + ref_weight_params_for_main_grad = [fc1_ref.weight, fc2_ref.weight] + else: + ref_weight_params_for_main_grad = [ + getattr(fc, f"weight{i}") + for fc in (fc1_ref, fc2_ref) + for i in range(group_size) + ] + MegatronTrainingHelper.init_main_grad_buffers( + ref_weight_params_for_main_grad, + fill_value=main_grad_sentinel, + overwrite_main_grad=False, + ) del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test + y_ref_te = None + x_ref_te = None + probs_ref_te = None + if use_nvfp4_rht_recipe: + x_ref_te = x_test.detach().clone().requires_grad_(x_test.requires_grad) + probs_ref_te = probs_test.detach().clone().requires_grad_(probs_test.requires_grad) + dy_ref_te = dy_test.detach().clone() + + def _clear_grouped_mlp_support_caches() -> None: + for cls in ( + te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, + ): + cache_clear = getattr(cls.is_supported, "cache_clear", None) + if cache_clear is not None: + cache_clear() + + old_fused_grouped_mlp = os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP") + try: + os.environ["NVTE_CUTEDSL_FUSED_GROUPED_MLP"] = "0" + _clear_grouped_mlp_support_caches() + with te.autocast(enabled=with_quantization, recipe=recipe): + fc2_extra_ref = (split_sizes, probs_ref_te) if bias else (split_sizes,) + y_ref_te = module_ref( + x_ref_te, + split_sizes, + probs_ref_te, + *fc2_extra_ref, + ) + y_ref_te.backward(dy_ref_te) + if delay_wgrad_compute: + module_ref[0].backward_dw() + module_ref[2].backward_dw() + finally: + if old_fused_grouped_mlp is None: + os.environ.pop("NVTE_CUTEDSL_FUSED_GROUPED_MLP", None) + else: + os.environ["NVTE_CUTEDSL_FUSED_GROUPED_MLP"] = old_fused_grouped_mlp + _clear_grouped_mlp_support_caches() + # Fuse ops and perform forward and backward pass with te.autocast(enabled=with_quantization, recipe=recipe): fc2_extra = (split_sizes, probs_test) if bias else (split_sizes,) @@ -3987,9 +4055,7 @@ def test_grouped_mlp( fc2.backward_dw() # Check for expected fusions - if use_nvfp4_rht_recipe: - return - if ( + expected_grouped_mlp_fusion = ( quantization == "mxfp8" and dtype in (torch.bfloat16, torch.float16) and ( @@ -3997,7 +4063,8 @@ def test_grouped_mlp( or (activation_is_glu and glu_interleave_size == 32) ) and _cudnn_frontend_version_supported() - ): + ) + if expected_grouped_mlp_fusion: if activation_is_glu: forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8 backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8 @@ -4021,9 +4088,63 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): tols = {"rtol": 0.25, "atol": 0.5} + if use_nvfp4_rht_recipe: + fc1_ref, fc2_ref = module_ref[0], module_ref[2] + assert_close(y_test, y_ref_te, **tols) + assert_close_grads(x_test, x_ref_te, **tols) + assert_close_grads(probs_test, probs_ref_te, **tols) + for group_idx in range(group_size): + if bias: + if single_grouped_bias: + assert_close( + fc2.bias.grad[group_idx], + fc2_ref.bias.grad[group_idx], + **tols, + ) + assert_close( + fc1.bias.grad[group_idx], + fc1_ref.bias.grad[group_idx], + **tols, + ) + else: + assert_close_grads( + getattr(fc2, f"bias{group_idx}"), + getattr(fc2_ref, f"bias{group_idx}"), + **tols, + ) + assert_close_grads( + getattr(fc1, f"bias{group_idx}"), + getattr(fc1_ref, f"bias{group_idx}"), + **tols, + ) + if not single_grouped_weight and not accumulate_into_main_grad: + assert_close_grads( + getattr(fc2, f"weight{group_idx}"), + getattr(fc2_ref, f"weight{group_idx}"), + **tols, + ) + assert_close_grads( + getattr(fc1, f"weight{group_idx}"), + getattr(fc1_ref, f"weight{group_idx}"), + **tols, + ) + if accumulate_into_main_grad: + expected_main_grads = [ + ref_weight.main_grad for ref_weight in ref_weight_params_for_main_grad + ] + MegatronTrainingHelper.verify_main_grad_accumulation( + weight_params_for_main_grad, + expected_main_grads=expected_main_grads, + **tols, + ) + elif single_grouped_weight: + assert_close(fc1.weight.grad, fc1_ref.weight.grad, **tols) + assert_close(fc2.weight.grad, fc2_ref.weight.grad, **tols) + return + # Check values assert_close(y_test, y_ref, **tols) assert_close_grads(x_test, x_ref, **tols) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 19cc118a90..84489f30c1 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -118,7 +118,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -145,10 +145,10 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling(**recipe_kwargs) - if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): use_4over6 = name == "nvfp4_4over6" kwargs = { - "disable_rht": True, + "disable_rht": name != "nvfp4_rht", "disable_stochastic_rounding": True, "disable_2d_quantization": not use_4over6, "row_scaled_activation": name == "nvfp4_row_scaled", @@ -163,12 +163,16 @@ def recipe_id(recipe: Optional[Recipe]) -> str: """Readable pytest id for a quantization recipe.""" if not isinstance(recipe, Recipe): return "None" - if recipe.nvfp4() and recipe.row_scaled_activation and recipe.nvfp4_4over6 != "none": - return "NVFP4RowScaled4Over6BlockScaling" - if recipe.nvfp4() and recipe.nvfp4_4over6 != "none": - return "NVFP44Over6BlockScaling" - if recipe.nvfp4() and recipe.row_scaled_activation: - return "NVFP4RowScaledBlockScaling" + if recipe.nvfp4(): + nvfp4_features = [] + if recipe.row_scaled_activation: + nvfp4_features.append("RowScaled") + if recipe.nvfp4_4over6 != "none": + nvfp4_features.append("4Over6") + if not recipe.disable_rht: + nvfp4_features.append("RHT") + if nvfp4_features: + return f"NVFP4{''.join(nvfp4_features)}BlockScaling" return type(recipe).__name__ diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 01c983ad3d..f064af2478 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -1751,16 +1751,8 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num NVTE_CHECK(A_list_info.all_col, "Grouped GEMM: A_list is missing column-wise data"); A_sel.dtype = A_list_info.col_dtype; } - // GroupedTensor metadata stores the original logical shape, so columnwise - // storage usually needs storage_transposed. Discrete NVFP4 A tensors with - // logical transa=false expose columnwise data with the transposed logical - // shape already, so treating it as storage_transposed would undo the layout - // needed by cuBLAS. - const bool nvfp4_discrete_a_columnwise = nvfp4 && !static_cast(transa); - const bool a_list_storage_transposed = - nvfp4_discrete_a_columnwise ? false : choice.storage_transposed; a_multi_tensor_args = build_grouped_gemm_multi_inputA_args( - A_list, num_a_tensors, choice.use_rowwise, a_list_storage_transposed, &avg_first_dim, + A_list, num_a_tensors, choice.use_rowwise, choice.storage_transposed, &avg_first_dim, &avg_last_dim, "A"); // Discrete A_list: per-tensor pointers come from `a_multi_tensor_args` (data/scale/amax). diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 9ca630811f..18c42b5abf 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -424,6 +424,9 @@ def fuse_grouped_mlp_ops( return ops if recipe is None or not (recipe.mxfp8() or recipe.nvfp4()): return ops + # NVFP4 fused grouped MLP uses graph-safe grouped quantize, which currently requires RHT. + if recipe.nvfp4() and recipe.disable_rht: + return ops if activation_op_types is None: activation_op_types = (ScaledSwiGLU, ScaledClampedQGeGLU) From 93ec88365ccd6d47f34f40eb738f8814b7a0b123 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 May 2026 20:33:30 +0000 Subject: [PATCH 08/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/ops/fused/backward_grouped_mlp.py | 8 ++------ .../pytorch/ops/fused/forward_grouped_mlp.py | 10 ++-------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index faa4866665..bdedb7b242 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -652,9 +652,7 @@ def fuser_backward( device, ) swizzle_type = ( - "uniform_nvfp4_swizzle" - if use_nvfp4 - else "uniform_mxfp8_columnwise_swizzle" + "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_columnwise_swizzle" ) fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( swizzle_type, @@ -916,9 +914,7 @@ def fuser_backward( device, ) swizzle_type = ( - "uniform_nvfp4_swizzle" - if use_nvfp4 - else "uniform_mxfp8_columnwise_swizzle" + "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_columnwise_swizzle" ) fc1_sfb_ptrs, _fc1_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( swizzle_type, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 5629e8a1e8..41f7c72ca5 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -452,11 +452,7 @@ def fuser_forward( [w._rowwise_data for w in grouped_fc1_weight], device, ) - swizzle_type = ( - "uniform_nvfp4_swizzle" - if use_nvfp4 - else "uniform_mxfp8_rowwise_swizzle" - ) + swizzle_type = "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_rowwise_swizzle" fc1_sfb_ptrs, _fc1_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( swizzle_type, [w._rowwise_scale_inv for w in grouped_fc1_weight], @@ -633,9 +629,7 @@ def fuser_forward( device, ) swizzle_type = ( - "uniform_nvfp4_swizzle" - if use_nvfp4 - else "uniform_mxfp8_rowwise_swizzle" + "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_rowwise_swizzle" ) fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( swizzle_type, From 3947a88003db04538ad0ca1d99242f77473992c9 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 29 May 2026 21:28:31 +0000 Subject: [PATCH 09/14] [PyTorch] Drop `_MXFP8` suffix from fused grouped MLP op classes These fused ops now support both MXFP8 and NVFP4, so the recipe-specific suffix is misleading. Rename the four `CuTeGEMM` GLU/Unary forward/backward classes and update callsites and tests. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 30 +++++++++---------- .../pytorch/ops/fused/__init__.py | 8 ++--- .../pytorch/ops/fused/backward_grouped_mlp.py | 16 +++++----- .../pytorch/ops/fused/forward_grouped_mlp.py | 16 +++++----- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1a656191d9..6dc828a37e 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -4015,8 +4015,8 @@ def _make_module(): def _clear_grouped_mlp_support_caches() -> None: for cls in ( - te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, + te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, ): cache_clear = getattr(cls.is_supported, "cache_clear", None) if cache_clear is not None: @@ -4066,11 +4066,11 @@ def _clear_grouped_mlp_support_caches() -> None: ) if expected_grouped_mlp_fusion: if activation_is_glu: - forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8 - backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8 + forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU + backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU else: - forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMUnary_MXFP8 - backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8 + forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMUnary + backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary if forward_cls.is_supported(): forward_ops = module._module_groups[0]._forward_ops assert len(forward_ops) == 1 @@ -4222,9 +4222,9 @@ def test_grouped_mlp_single_weight_numerics( ) -> None: """single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP.""" - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") - if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported(): + if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") split_sizes = [split_alignment * (i + 1) for i in range(group_size)] @@ -4326,12 +4326,12 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: assert len(forward_ops) == 1 assert isinstance( forward_ops[0][0], - te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, + te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, ) assert len(backward_ops) == 1 assert isinstance( backward_ops[0][0], - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, ) if single_grouped_weight: @@ -4444,9 +4444,9 @@ def test_grouped_mlp_overwrite_main_grad( that read ``.grad`` don't see stale bytes from the cached dummy). """ - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") - if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported(): + if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") recipe = make_recipe("mxfp8") @@ -4578,7 +4578,7 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8( ) -> None: """Grouped MLP forward+backward should be CUDA graph capturable (MXFP8).""" - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): pytest.skip("MXFP8 fused grouped MLP is not supported on this system") if dtype not in (torch.bfloat16, torch.float16): pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") @@ -4720,12 +4720,12 @@ def train_step( assert len(forward_ops) == 1 assert isinstance( forward_ops[0][0], - te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, + te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, ) assert len(backward_ops) == 1 assert isinstance( backward_ops[0][0], - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, ) fresh_x = torch.randn_like(static_x) diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index b29e35814d..78f9d880ba 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -32,10 +32,10 @@ # Import experimental fusions # Note: Registration logic is non-trivial, so submodule handles it internally. from .forward_grouped_mlp import ( # pylint: disable=wrong-import-position - ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, - ForwardGroupedMLP_CuTeGEMMUnary_MXFP8, + ForwardGroupedMLP_CuTeGEMMGLU, + ForwardGroupedMLP_CuTeGEMMUnary, ) from .backward_grouped_mlp import ( # pylint: disable=wrong-import-position - BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, - BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8, + BackwardGroupedMLP_CuTeGEMMDGLU, + BackwardGroupedMLP_CuTeGEMMDUnary, ) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index bdedb7b242..c3882aca4b 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -970,8 +970,8 @@ def fuser_backward( ) -class BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8(_BackwardGroupedMLP_CuTeGEMMDBase): - """Fused backward op for GroupedLinear + scaled GLU + GroupedLinear.""" +class BackwardGroupedMLP_CuTeGEMMDGLU(_BackwardGroupedMLP_CuTeGEMMDBase): + """Fused backward op for block-scaled GroupedLinear + scaled GLU + GroupedLinear.""" @classmethod @functools.lru_cache(maxsize=None) @@ -982,8 +982,8 @@ def grouped_gemm_dactivation_kernel(cls) -> Callable: return grouped_gemm_dglu_wrapper_sm100 -class BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8(_BackwardGroupedMLP_CuTeGEMMDBase): - """Fused backward op for GroupedLinear + scaled unary activation + GroupedLinear.""" +class BackwardGroupedMLP_CuTeGEMMDUnary(_BackwardGroupedMLP_CuTeGEMMDBase): + """Fused backward op for block-scaled GroupedLinear + scaled unary activation + GroupedLinear.""" @classmethod @functools.lru_cache(maxsize=None) @@ -1025,7 +1025,7 @@ def fuse_backward_ops( return fuse_grouped_mlp_ops( ops, recipe=recipe, - fused_op_cls=BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, + fused_op_cls=BackwardGroupedMLP_CuTeGEMMDGLU, ) @@ -1042,13 +1042,13 @@ def fuse_backward_srelu_ops( return fuse_grouped_mlp_ops( ops, recipe=recipe, - fused_op_cls=BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8, + fused_op_cls=BackwardGroupedMLP_CuTeGEMMDUnary, activation_op_types=(ScaledSReLU,), ) # Register fusion if available -if BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported(): +if BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): register_backward_fusion(fuse_backward_ops, prepend=True) -if BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8.is_supported(): +if BackwardGroupedMLP_CuTeGEMMDUnary.is_supported(): register_backward_fusion(fuse_backward_srelu_ops, prepend=True) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 41f7c72ca5..2596d4d1d5 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -733,8 +733,8 @@ def fuser_forward( return fc2_out, [(), (), ()] -class ForwardGroupedMLP_CuTeGEMMGLU_MXFP8(_ForwardGroupedMLP_CuTeGEMMBase): - """Fused op for MXFP8 GroupedLinear + scaled GLU + GroupedLinear.""" +class ForwardGroupedMLP_CuTeGEMMGLU(_ForwardGroupedMLP_CuTeGEMMBase): + """Fused op for block-scaled GroupedLinear + scaled GLU + GroupedLinear.""" @classmethod @functools.lru_cache(maxsize=None) @@ -745,8 +745,8 @@ def grouped_gemm_activation_kernel(cls) -> Callable: return grouped_gemm_glu_wrapper_sm100 -class ForwardGroupedMLP_CuTeGEMMUnary_MXFP8(_ForwardGroupedMLP_CuTeGEMMBase): - """Fused op for MXFP8 GroupedLinear + scaled unary activation + GroupedLinear.""" +class ForwardGroupedMLP_CuTeGEMMUnary(_ForwardGroupedMLP_CuTeGEMMBase): + """Fused op for block-scaled GroupedLinear + scaled unary activation + GroupedLinear.""" @classmethod @functools.lru_cache(maxsize=None) @@ -788,7 +788,7 @@ def fuse_forward_ops( return fuse_grouped_mlp_ops( ops, recipe=recipe, - fused_op_cls=ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, + fused_op_cls=ForwardGroupedMLP_CuTeGEMMGLU, ) @@ -805,13 +805,13 @@ def fuse_forward_srelu_ops( return fuse_grouped_mlp_ops( ops, recipe=recipe, - fused_op_cls=ForwardGroupedMLP_CuTeGEMMUnary_MXFP8, + fused_op_cls=ForwardGroupedMLP_CuTeGEMMUnary, activation_op_types=(ScaledSReLU,), ) # Register fusion if available -if ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): +if ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): register_forward_fusion(fuse_forward_ops, prepend=True) -if ForwardGroupedMLP_CuTeGEMMUnary_MXFP8.is_supported(): +if ForwardGroupedMLP_CuTeGEMMUnary.is_supported(): register_forward_fusion(fuse_forward_srelu_ops, prepend=True) From 820220c7a0a17721ffb761a7b9aba15bd9e305bb Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 29 May 2026 21:45:45 +0000 Subject: [PATCH 10/14] [PyTorch] Add `ceil_div` utility and use it in fused grouped MLP ops Adds `ceil_div` next to `round_up_to_nearest_multiple` in `pytorch/utils.py` and replaces the `(x + d - 1) // d` patterns in the fused grouped MLP ops with it. Also fixes asymmetric floor-divs in the NVFP4 scale-view paths (`data_(in_)k // k_sf_divisor`) that would underestimate the scale-block count for padded layouts; the MXFP8 branches already used ceil-div for the same dimension. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Tim Moon --- .../pytorch/ops/fused/backward_grouped_mlp.py | 32 +++++++++++-------- .../pytorch/ops/fused/forward_grouped_mlp.py | 27 +++++++++------- transformer_engine/pytorch/utils.py | 9 +++++- 3 files changed, 43 insertions(+), 25 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index c3882aca4b..26ee0cbcf5 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -17,7 +17,13 @@ from ...tensor import NVFP4Quantizer, NVFP4Tensor from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer -from ...utils import clear_tensor_data, get_cached_ones_tensor, get_device_compute_capability +from ...utils import ( + ceil_div, + clear_tensor_data, + get_cached_ones_tensor, + get_device_compute_capability, + round_up_to_nearest_multiple, +) from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU from ..fuser import register_backward_fusion @@ -96,8 +102,8 @@ def _cudnn_compute_wgrad( fp8_dtype = torch.float8_e4m3fn - sfa_leading_dim = ((out_features + 127) // 128) * 128 - sfb_leading_dim = ((in_features + 127) // 128) * 128 + sfa_leading_dim = round_up_to_nearest_multiple(out_features, 128) + sfb_leading_dim = round_up_to_nearest_multiple(in_features, 128) if total_tokens == 0: # A workaround for the case with zero-token experts. @@ -533,8 +539,8 @@ def fuser_backward( if use_nvfp4 and with_gemm_swizzled_scales: fc2_dy_scales = fc2_dy_scales.view( 1, - (out_shape[0] + 127) // 128, - data_k // k_sf_divisor, + ceil_div(out_shape[0], 128), + ceil_div(data_k, k_sf_divisor), 32, 4, 4, @@ -543,18 +549,18 @@ def fuser_backward( elif use_nvfp4: fc2_dy_scales = fc2_dy_scales.view( 1, - (out_shape[0] + 127) // 128, + ceil_div(out_shape[0], 128), 4, 32, - data_k // k_sf_divisor, + ceil_div(data_k, k_sf_divisor), 4, ) fc2_dy_scales = fc2_dy_scales.permute(3, 2, 1, 5, 4, 0) else: fc2_dy_scales = fc2_dy_scales.view( 1, - (out_shape[0] + 127) // 128, - (out_shape[1] + k_sf_divisor - 1) // k_sf_divisor, + ceil_div(out_shape[0], 128), + ceil_div(out_shape[1], k_sf_divisor), 32, 4, 4, @@ -632,8 +638,8 @@ def fuser_backward( fc2_w_scales = fc2_weight_for_gemm.columnwise_scale_inv.view(dtype=scale_view_dtype) fc2_w_scales = fc2_w_scales.view( num_groups, - (fc2_weight_shape[1] + k_sf_divisor - 1) // k_sf_divisor, - (fc2_weight_shape[0] + 127) // 128, + ceil_div(fc2_weight_shape[1], k_sf_divisor), + ceil_div(fc2_weight_shape[0], 128), 32, 4, 4, @@ -898,8 +904,8 @@ def fuser_backward( ) fc1_w_scales = fc1_w_scales.view( num_groups, - (fc1_weight_shape[1] + 127) // 128, - (fc1_weight_shape[0] + 127) // 128, + ceil_div(fc1_weight_shape[1], 128), + ceil_div(fc1_weight_shape[0], 128), MXFP8_BLOCK_SCALING_SIZE, 4, 4, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 2596d4d1d5..4aa33820c0 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -16,7 +16,12 @@ from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor from ...quantization import Recipe from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer -from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor +from ...utils import ( + ceil_div, + get_cached_ones_tensor, + get_device_compute_capability, + mark_grouped_tensor, +) from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE @@ -343,8 +348,8 @@ def fuser_forward( if use_nvfp4 and with_gemm_swizzled_scales: fc1_x_scales = fc1_x_scales.view( 1, - (in_shape[0] + 127) // 128, - data_in_k // k_sf_divisor, + ceil_div(in_shape[0], 128), + ceil_div(data_in_k, k_sf_divisor), 32, 4, 4, @@ -353,18 +358,18 @@ def fuser_forward( elif use_nvfp4: fc1_x_scales = fc1_x_scales.view( 1, - (in_shape[0] + 127) // 128, + ceil_div(in_shape[0], 128), 4, 32, - data_in_k // k_sf_divisor, + ceil_div(data_in_k, k_sf_divisor), 4, ) fc1_x_scales = fc1_x_scales.permute(3, 2, 1, 5, 4, 0) else: fc1_x_scales = fc1_x_scales.view( 1, - (in_shape[0] + 127) // 128, - (in_shape[1] + k_sf_divisor - 1) // k_sf_divisor, + ceil_div(in_shape[0], 128), + ceil_div(in_shape[1], k_sf_divisor), 32, 4, 4, @@ -436,8 +441,8 @@ def fuser_forward( fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=scale_view_dtype) fc1_w_scales = fc1_w_scales.view( num_groups, - (fc1_weight_shape[0] + 127) // 128, - (fc1_weight_shape[1] + k_sf_divisor - 1) // k_sf_divisor, + ceil_div(fc1_weight_shape[0], 128), + ceil_div(fc1_weight_shape[1], k_sf_divisor), 32, 4, 4, @@ -614,8 +619,8 @@ def fuser_forward( fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu) fc2_w_scales = fc2_w_scales.view( num_groups, - (fc2_weight_shape[0] + 127) // 128, - (fc2_weight_shape[1] + 127) // 128, + ceil_div(fc2_weight_shape[0], 128), + ceil_div(fc2_weight_shape[1], 128), MXFP8_BLOCK_SCALING_SIZE, 4, 4, diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 250daec67f..fd8f817b33 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -626,8 +626,15 @@ def get_sm_count() -> int: return torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count +def ceil_div(numerator, denominator): + """Integer ceiling division: ``ceil(numerator / denominator)``.""" + if denominator == 0: + raise ValueError("denominator cannot be zero.") + return (numerator + denominator - 1) // denominator + + def round_up_to_nearest_multiple(value, multiple): - """Round up `value` to the next mutiple of `multiple`""" + """Round up `value` to the next multiple of `multiple`""" if multiple == 0: raise ValueError("multiple cannot be zero.") return ((value + multiple - 1) // multiple) * multiple From 2ba4e6cbc516b2f50875a61e313660799b112c61 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 29 May 2026 22:18:21 +0000 Subject: [PATCH 11/14] [PyTorch] Drop defensive `getattr` defaults in fused grouped MLP ops Functions like `_group_quantize_for_grouped_mlp` already constrain their inputs (e.g. an NVFP4Quantizer always yields an NVFP4Tensor), so the `getattr(..., default)` fallbacks for `_rowwise_data`/`_with_gemm_swizzled_scales`/etc. were dead code that obscured intent. Replace those with direct attribute access, drop the dead double-`getattr` for the non-underscored public name that doesn't exist, and add brief comments on the type contract. The `getattr` calls that remain are legitimate (polymorphic inputs, dynamic attribute names, user-stamped optional flags on `torch.nn.Parameter`). Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/_common.py | 41 +++++++++++-------- .../pytorch/ops/fused/backward_grouped_mlp.py | 10 ++--- .../pytorch/ops/fused/forward_grouped_mlp.py | 16 +++----- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 18c42b5abf..0b9453f679 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -5,6 +5,7 @@ """Helper functions used in fusible operations.""" from __future__ import annotations +from collections.abc import Iterable import functools import math from importlib.metadata import PackageNotFoundError, version as get_pkg_version @@ -17,7 +18,7 @@ from transformer_engine_torch import FP8TensorMeta from ..torch_version import torch_version from ..quantization import FP8GlobalStateManager -from ..tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer +from ..tensor import NVFP4Quantizer, NVFP4Tensor, NVFP4TensorStorage, Quantizer from ..tensor.float8_tensor import Float8Tensor from ..tensor.grouped_tensor import GroupedTensor from ..quantized_tensor import QuantizedTensorStorage @@ -68,22 +69,28 @@ def _group_quantize_for_grouped_mlp( *, tensor_offsets: Optional[torch.Tensor] = None, ) -> GroupedTensor: - """Quantize into grouped storage, using regular quantize for one-group NVFP4.""" + """Quantize into grouped storage.""" + + # Typical case: group-quantize if num_groups != 1 or not isinstance(quantizer, NVFP4Quantizer): return tex.group_quantize(tensor, quantizer, num_groups, split_sizes) + # -------------------------------------------------- + # Special case: single-tensor NVFP4 quantize + # -------------------------------------------------- + quantized = tex.quantize(tensor, quantizer) - with_gemm_swizzled_scales = getattr(quantized, "_with_gemm_swizzled_scales", False) - if getattr(quantizer, "optimize_for_gemm", False): + with_gemm_swizzled_scales = quantized._with_gemm_swizzled_scales + if quantizer.optimize_for_gemm: tex.swizzle_scales_for_gemm_(quantized) with_gemm_swizzled_scales = True - rowwise_data = getattr(quantized, "_rowwise_data", None) - rowwise_scale = getattr(quantized, "_rowwise_scale_inv", None) - columnwise_data = getattr(quantized, "_columnwise_data", None) - columnwise_scale = getattr(quantized, "_columnwise_scale_inv", None) - amax = getattr(quantized, "_amax_rowwise", None) - columnwise_amax = getattr(quantized, "_amax_columnwise", None) + rowwise_data = quantized._rowwise_data + rowwise_scale = quantized._rowwise_scale_inv + columnwise_data = quantized._columnwise_data + columnwise_scale = quantized._columnwise_scale_inv + amax = quantized._amax_rowwise + columnwise_amax = quantized._amax_columnwise if split_sizes is None: split_sizes = torch.full((1,), tensor.shape[0], dtype=torch.int64, device=tensor.device) @@ -123,7 +130,11 @@ def _group_quantize_for_grouped_mlp( ) -def _nvfp4_amax(tensors: Any, *, columnwise: bool) -> torch.Tensor: +def _nvfp4_amax( + tensors: GroupedTensor | Iterable[NVFP4TensorStorage], + *, + columnwise: bool, +) -> torch.Tensor: """Get one NVFP4 amax value per group.""" grouped_attr = "columnwise_amax" if columnwise else "amax" tensor_attr = "_amax_columnwise" if columnwise else "_amax_rowwise" @@ -134,7 +145,7 @@ def _nvfp4_amax(tensors: Any, *, columnwise: bool) -> torch.Tensor: raise RuntimeError(f"NVFP4 GroupedTensor is missing {grouped_attr}.") return amax.view(-1) - amaxes = [getattr(tensor, tensor_attr, None) for tensor in tensors] + amaxes = [getattr(tensor, tensor_attr) for tensor in tensors] if any(amax is None for amax in amaxes): raise RuntimeError(f"NVFP4 tensor list is missing {tensor_attr}.") return torch.cat([amax.view(-1) for amax in amaxes], dim=0) @@ -186,11 +197,7 @@ def _nvfp4_single_tensor_from_grouped( fp4_dtype=fp4_dtype or quantizer.dtype, quantizer=quantizer, requires_grad=False, - with_gemm_swizzled_scales=getattr( - grouped, - "_with_gemm_swizzled_scales", - getattr(grouped, "with_gemm_swizzled_scales", quantizer.optimize_for_gemm), - ), + with_gemm_swizzled_scales=grouped._with_gemm_swizzled_scales, ) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 26ee0cbcf5..792b6d7811 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -264,8 +264,8 @@ def _compute_grad_params( num_groups == 1 and isinstance(grouped_x, GroupedTensor) and isinstance(grouped_dy, GroupedTensor) - and isinstance(getattr(grouped_x, "quantizer", None), NVFP4Quantizer) - and isinstance(getattr(grouped_dy, "quantizer", None), NVFP4Quantizer) + and isinstance(grouped_x.quantizer, NVFP4Quantizer) + and isinstance(grouped_dy.quantizer, NVFP4Quantizer) ): gemm_fn = functools.partial( _nvfp4_single_group_wgrad_gemm, @@ -531,11 +531,7 @@ def fuser_backward( fc2_dy_data = fc2_dy_data.unsqueeze(0).permute(1, 2, 0) fc2_dy_scales = grouped_fc2_dy.scale_inv fc2_dy_scales = fc2_dy_scales.view(dtype=scale_view_dtype) - with_gemm_swizzled_scales = getattr( - grouped_fc2_dy, - "_with_gemm_swizzled_scales", - getattr(grouped_fc2_dy, "with_gemm_swizzled_scales", False), - ) + with_gemm_swizzled_scales = grouped_fc2_dy._with_gemm_swizzled_scales if use_nvfp4 and with_gemm_swizzled_scales: fc2_dy_scales = fc2_dy_scales.view( 1, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 4aa33820c0..f4f2108578 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -286,12 +286,12 @@ def fuser_forward( grouped_fc2_weight = quantized_fc2_weights # Some wrapper-copy paths may drop grouped storage metadata; enforce defaults. - if getattr(grouped_fc1_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( - grouped_fc1_weight, GroupedTensor + if isinstance(grouped_fc1_weight, GroupedTensor) and not hasattr( + grouped_fc1_weight, "_with_gemm_swizzled_scales" ): grouped_fc1_weight._with_gemm_swizzled_scales = False - if getattr(grouped_fc2_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( - grouped_fc2_weight, GroupedTensor + if isinstance(grouped_fc2_weight, GroupedTensor) and not hasattr( + grouped_fc2_weight, "_with_gemm_swizzled_scales" ): grouped_fc2_weight._with_gemm_swizzled_scales = False @@ -340,11 +340,7 @@ def fuser_forward( fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) fc1_x_scales = grouped_fc1_x.scale_inv fc1_x_scales = fc1_x_scales.view(dtype=scale_view_dtype) - with_gemm_swizzled_scales = getattr( - grouped_fc1_x, - "_with_gemm_swizzled_scales", - getattr(grouped_fc1_x, "with_gemm_swizzled_scales", False), - ) + with_gemm_swizzled_scales = grouped_fc1_x._with_gemm_swizzled_scales if use_nvfp4 and with_gemm_swizzled_scales: fc1_x_scales = fc1_x_scales.view( 1, @@ -522,7 +518,7 @@ def fuser_forward( fc2_x_single = _nvfp4_single_tensor_from_grouped( grouped_fc2_x, fc2_input_quantizer, - fp4_dtype=getattr(fc2_w_single, "_fp4_dtype", fc2_input_quantizer.dtype), + fp4_dtype=fc2_w_single._fp4_dtype, ) general_gemm( fc2_w_single, From a13f642ce2fc86af78e915b1c678d72bd3b96b1f Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Fri, 29 May 2026 16:26:07 -0700 Subject: [PATCH 12/14] Fix NVFP4 RHT grouped MLP reference Use explicit grouped linear quantizer roles and keep the NVFP4 RHT grouped MLP reference in plain PyTorch, with RHT applied only to reference wgrad. Signed-off-by: Siddhartha Raman S --- tests/pytorch/test_fusible_ops.py | 260 +++++++----------- .../pytorch/ops/basic/grouped_linear.py | 21 +- 2 files changed, 121 insertions(+), 160 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 6dc828a37e..6038146c33 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -44,7 +44,6 @@ is_bf16_available, ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor -from transformer_engine.pytorch.cpp_extensions.gemm import general_grouped_gemm_for_grouped_tensor import transformer_engine_torch as tex # Import utility functions @@ -194,7 +193,10 @@ def make_reference_and_test_tensors( elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_rht"): - with_rht = quantization == "nvfp4_rht" + tensor_type = "input" + if quantizer_role is not None: + tensor_type = quantizer_role.tensor_type + with_rht = quantization == "nvfp4_rht" and tensor_type != "weight" test = NVFP4Quantizer( with_rht=with_rht, with_post_rht_amax=with_rht, @@ -3760,6 +3762,10 @@ def test_grouped_mlp( pytest.skip("Unary activations do not use GLU interleaving") if quantization == "nvfp4_4over6": pytest.skip("NVFP4 4over6 grouped quantization is not supported") + if quantization == "nvfp4_rht" and ( + activation != "scaled_swiglu" or bias or glu_interleave_size != 32 + ): + pytest.skip("NVFP4 RHT grouped MLP coverage is limited to fused no-bias SwiGLU") if ( with_quantization and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") @@ -3769,7 +3775,6 @@ def test_grouped_mlp( # TODO: ksivaman: Need to debug numerics for this case. pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size - use_nvfp4_rht_recipe = quantization == "nvfp4_rht" # Activation parameters for clamped QGeGLU variants if activation == "scaled_clamped_qgeglu_custom": geglu_limit = 5.0 @@ -3852,13 +3857,7 @@ def test_grouped_mlp( fc2_ws_test.append(fc2_w_test) fc2_bs_test.append(fc2_b_test) - # Reference implementation - xs = torch.split(x_ref, split_sizes.tolist()) - probs = torch.split(probs_ref, split_sizes.tolist()) - ys = [] - for group_idx in range(group_size): - x = xs[group_idx] - x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) + def _apply_activation(x: torch.Tensor) -> torch.Tensor: if activation_is_glu and glu_interleave_size is not None: x = x.reshape( -1, @@ -3870,24 +3869,84 @@ def test_grouped_mlp( x = x.reshape(-1, 2 * hidden_size) if activation == "scaled_swiglu": x1, x2 = x.chunk(2, dim=-1) - x = torch.nn.functional.silu(x1) * x2 - elif activation.startswith("scaled_clamped_qgeglu"): + return torch.nn.functional.silu(x1) * x2 + if activation.startswith("scaled_clamped_qgeglu"): x1, x2 = x.chunk(2, dim=-1) lim = torch.tensor(geglu_limit, device=x1.device, dtype=x1.dtype) x1c = torch.minimum(x1, lim) x2c = torch.clamp(x2, -lim, lim) - x = (x2c + geglu_offset) * (x1c * torch.sigmoid(geglu_alpha * x1c)) - elif activation == "scaled_srelu": - x = torch.nn.functional.relu(x).square() - else: - raise ValueError(f"Unexpected grouped MLP activation ({activation})") - x = x * probs[group_idx].unsqueeze(-1) - x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx]) + return (x2c + geglu_offset) * (x1c * torch.sigmoid(geglu_alpha * x1c)) + if activation == "scaled_srelu": + return torch.nn.functional.relu(x).square() + raise ValueError(f"Unexpected grouped MLP activation ({activation})") + + def _apply_nvfp4_wgrad_rht(x: torch.Tensor) -> torch.Tensor: + if x.numel() == 0: + return x + rht_dim = 16 + if x.size(-1) % rht_dim != 0: + raise ValueError( + "NVFP4 RHT reference expects the wgrad K dimension to be 16-aligned" + ) + h = torch.ones((1, 1), device=x.device, dtype=x.dtype) + while h.size(0) < rht_dim: + h = torch.cat( + ( + torch.cat((h, h), dim=1), + torch.cat((h, -h), dim=1), + ), + dim=0, + ) + signs = torch.tensor( + [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1], + device=x.device, + dtype=x.dtype, + ) + rht = (signs[:, None] * h) * (1 / math.sqrt(rht_dim)) + return (x.contiguous().view(-1, rht_dim) @ rht).view_as(x) + + def _nvfp4_rht_wgrad(x: torch.Tensor, dy: torch.Tensor) -> torch.Tensor: + x_t = _apply_nvfp4_wgrad_rht(x.transpose(0, 1).contiguous()) + dy_t = _apply_nvfp4_wgrad_rht(dy.transpose(0, 1).contiguous()) + return dy_t @ x_t.transpose(0, 1) + + # Reference implementation + xs = torch.split(x_ref, split_sizes.tolist()) + dys = torch.split(dy_ref, split_sizes.tolist()) + probs = torch.split(probs_ref, split_sizes.tolist()) + ys = [] + fc1_inputs = [] + fc1_outputs = [] + fc2_inputs = [] + for group_idx in range(group_size): + x = xs[group_idx] + fc1_out = torch.nn.functional.linear( + x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx] + ) + if quantization == "nvfp4_rht": + fc1_out.retain_grad() + fc1_inputs.append(x) + fc1_outputs.append(fc1_out) + fc2_in = _apply_activation(fc1_out) + fc2_in = fc2_in * probs[group_idx].unsqueeze(-1) + if quantization == "nvfp4_rht": + fc2_inputs.append(fc2_in) + y = torch.nn.functional.linear(fc2_in, fc2_ws_ref[group_idx]) if bias: - x = x + fc2_bs_ref[group_idx] * probs[group_idx].unsqueeze(-1) - ys.append(x) + y = y + fc2_bs_ref[group_idx] * probs[group_idx].unsqueeze(-1) + ys.append(y) y_ref = torch.cat(ys) y_ref.backward(dy_ref) + if quantization == "nvfp4_rht": + for group_idx in range(group_size): + fc1_dy = fc1_outputs[group_idx].grad + if fc1_dy is None: + fc1_dy = torch.zeros_like(fc1_outputs[group_idx]) + fc1_ws_ref[group_idx].grad = _nvfp4_rht_wgrad(fc1_inputs[group_idx], fc1_dy) + fc2_ws_ref[group_idx].grad = _nvfp4_rht_wgrad( + fc2_inputs[group_idx], + dys[group_idx], + ) # Construct operations recipe = make_recipe(quantization) @@ -3939,40 +3998,32 @@ def _make_module(): return te_ops.Sequential(fc1_op, _make_scaled_act(), fc2_op), fc1_op, fc2_op module, fc1, fc2 = _make_module() - if use_nvfp4_rht_recipe: - module_ref, fc1_ref, fc2_ref = _make_module() - else: - module_ref, fc1_ref, fc2_ref = None, None, None # Copy weights with torch.no_grad(): - target_modules = [(fc1, fc2)] - if use_nvfp4_rht_recipe: - target_modules.append((fc1_ref, fc2_ref)) - for fc1_target, fc2_target in target_modules: + if single_grouped_weight: + fc1_weights = fc1.weight.quantized_tensors + if fc1_weights is None: + fc1_weights = fc1.weight.split_into_quantized_tensors() + fc2_weights = fc2.weight.quantized_tensors + if fc2_weights is None: + fc2_weights = fc2.weight.split_into_quantized_tensors() + for group_idx in range(group_size): if single_grouped_weight: - fc1_weights = fc1_target.weight.quantized_tensors - if fc1_weights is None: - fc1_weights = fc1_target.weight.split_into_quantized_tensors() - fc2_weights = fc2_target.weight.quantized_tensors - if fc2_weights is None: - fc2_weights = fc2_target.weight.split_into_quantized_tensors() - for group_idx in range(group_size): - if single_grouped_weight: - fc1_weights[group_idx].copy_(fc1_ws_test[group_idx]) - fc2_weights[group_idx].copy_(fc2_ws_test[group_idx]) + fc1_weights[group_idx].copy_(fc1_ws_test[group_idx]) + fc2_weights[group_idx].copy_(fc2_ws_test[group_idx]) + else: + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) + if bias: + if single_grouped_bias: + fc1_bparts = fc1.bias.split_into_quantized_tensors() + fc2_bparts = fc2.bias.split_into_quantized_tensors() + fc1_bparts[group_idx].reshape(-1).copy_(fc1_bs_test[group_idx]) + fc2_bparts[group_idx].reshape(-1).copy_(fc2_bs_test[group_idx]) else: - getattr(fc1_target, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) - getattr(fc2_target, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) - if bias: - if single_grouped_bias: - fc1_bparts = fc1_target.bias.split_into_quantized_tensors() - fc2_bparts = fc2_target.bias.split_into_quantized_tensors() - fc1_bparts[group_idx].reshape(-1).copy_(fc1_bs_test[group_idx]) - fc2_bparts[group_idx].reshape(-1).copy_(fc2_bs_test[group_idx]) - else: - getattr(fc1_target, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) - getattr(fc2_target, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) + getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) if accumulate_into_main_grad: # 0.5 sentinel lets us reconstruct ``expected = ref_grad + 0.5`` # below and detect a missed accumulation. @@ -3988,63 +4039,8 @@ def _make_module(): fill_value=main_grad_sentinel, overwrite_main_grad=False, ) - if use_nvfp4_rht_recipe: - fc1_ref, fc2_ref = module_ref[0], module_ref[2] - if single_grouped_weight: - ref_weight_params_for_main_grad = [fc1_ref.weight, fc2_ref.weight] - else: - ref_weight_params_for_main_grad = [ - getattr(fc, f"weight{i}") - for fc in (fc1_ref, fc2_ref) - for i in range(group_size) - ] - MegatronTrainingHelper.init_main_grad_buffers( - ref_weight_params_for_main_grad, - fill_value=main_grad_sentinel, - overwrite_main_grad=False, - ) del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test - y_ref_te = None - x_ref_te = None - probs_ref_te = None - if use_nvfp4_rht_recipe: - x_ref_te = x_test.detach().clone().requires_grad_(x_test.requires_grad) - probs_ref_te = probs_test.detach().clone().requires_grad_(probs_test.requires_grad) - dy_ref_te = dy_test.detach().clone() - - def _clear_grouped_mlp_support_caches() -> None: - for cls in ( - te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, - ): - cache_clear = getattr(cls.is_supported, "cache_clear", None) - if cache_clear is not None: - cache_clear() - - old_fused_grouped_mlp = os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP") - try: - os.environ["NVTE_CUTEDSL_FUSED_GROUPED_MLP"] = "0" - _clear_grouped_mlp_support_caches() - with te.autocast(enabled=with_quantization, recipe=recipe): - fc2_extra_ref = (split_sizes, probs_ref_te) if bias else (split_sizes,) - y_ref_te = module_ref( - x_ref_te, - split_sizes, - probs_ref_te, - *fc2_extra_ref, - ) - y_ref_te.backward(dy_ref_te) - if delay_wgrad_compute: - module_ref[0].backward_dw() - module_ref[2].backward_dw() - finally: - if old_fused_grouped_mlp is None: - os.environ.pop("NVTE_CUTEDSL_FUSED_GROUPED_MLP", None) - else: - os.environ["NVTE_CUTEDSL_FUSED_GROUPED_MLP"] = old_fused_grouped_mlp - _clear_grouped_mlp_support_caches() - # Fuse ops and perform forward and backward pass with te.autocast(enabled=with_quantization, recipe=recipe): fc2_extra = (split_sizes, probs_test) if bias else (split_sizes,) @@ -4091,60 +4087,6 @@ def _clear_grouped_mlp_support_caches() -> None: if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): tols = {"rtol": 0.25, "atol": 0.5} - if use_nvfp4_rht_recipe: - fc1_ref, fc2_ref = module_ref[0], module_ref[2] - assert_close(y_test, y_ref_te, **tols) - assert_close_grads(x_test, x_ref_te, **tols) - assert_close_grads(probs_test, probs_ref_te, **tols) - for group_idx in range(group_size): - if bias: - if single_grouped_bias: - assert_close( - fc2.bias.grad[group_idx], - fc2_ref.bias.grad[group_idx], - **tols, - ) - assert_close( - fc1.bias.grad[group_idx], - fc1_ref.bias.grad[group_idx], - **tols, - ) - else: - assert_close_grads( - getattr(fc2, f"bias{group_idx}"), - getattr(fc2_ref, f"bias{group_idx}"), - **tols, - ) - assert_close_grads( - getattr(fc1, f"bias{group_idx}"), - getattr(fc1_ref, f"bias{group_idx}"), - **tols, - ) - if not single_grouped_weight and not accumulate_into_main_grad: - assert_close_grads( - getattr(fc2, f"weight{group_idx}"), - getattr(fc2_ref, f"weight{group_idx}"), - **tols, - ) - assert_close_grads( - getattr(fc1, f"weight{group_idx}"), - getattr(fc1_ref, f"weight{group_idx}"), - **tols, - ) - if accumulate_into_main_grad: - expected_main_grads = [ - ref_weight.main_grad for ref_weight in ref_weight_params_for_main_grad - ] - MegatronTrainingHelper.verify_main_grad_accumulation( - weight_params_for_main_grad, - expected_main_grads=expected_main_grads, - **tols, - ) - elif single_grouped_weight: - assert_close(fc1.weight.grad, fc1_ref.weight.grad, **tols) - assert_close(fc2.weight.grad, fc2_ref.weight.grad, **tols) - return - # Check values assert_close(y_test, y_ref, **tols) assert_close_grads(x_test, x_ref, **tols) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index dc15bc63b8..e9787f96b2 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -22,7 +22,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ...quantization import FP8GlobalStateManager, Recipe +from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe from ...quantized_tensor import QuantizedTensorStorage from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer from ...utils import ( @@ -291,6 +291,25 @@ def num_quantizers(self, mode: str) -> int: return self.num_groups return 0 + def get_quantizer_roles(self, mode: str) -> Optional[list[QuantizerRole]]: + name = getattr(self, "name", "") or "" + if mode == "forward": + roles = [] + for _ in range(self.num_groups): + roles.extend( + [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + ] + ) + return roles + if mode == "backward": + return [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name) + for _ in range(self.num_groups) + ] + return None + @property def has_bias(self) -> bool: """Whether an additive bias is being applied""" From 8c2404d187b8c307e6f23985b97708ada22ccc8f Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Fri, 29 May 2026 16:54:16 -0700 Subject: [PATCH 13/14] Simplify NVFP4 RHT grouped MLP reference Signed-off-by: Siddhartha Raman S --- tests/pytorch/test_fusible_ops.py | 50 ------------------------------- 1 file changed, 50 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 6038146c33..df607bd6dc 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3880,73 +3880,23 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.relu(x).square() raise ValueError(f"Unexpected grouped MLP activation ({activation})") - def _apply_nvfp4_wgrad_rht(x: torch.Tensor) -> torch.Tensor: - if x.numel() == 0: - return x - rht_dim = 16 - if x.size(-1) % rht_dim != 0: - raise ValueError( - "NVFP4 RHT reference expects the wgrad K dimension to be 16-aligned" - ) - h = torch.ones((1, 1), device=x.device, dtype=x.dtype) - while h.size(0) < rht_dim: - h = torch.cat( - ( - torch.cat((h, h), dim=1), - torch.cat((h, -h), dim=1), - ), - dim=0, - ) - signs = torch.tensor( - [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1], - device=x.device, - dtype=x.dtype, - ) - rht = (signs[:, None] * h) * (1 / math.sqrt(rht_dim)) - return (x.contiguous().view(-1, rht_dim) @ rht).view_as(x) - - def _nvfp4_rht_wgrad(x: torch.Tensor, dy: torch.Tensor) -> torch.Tensor: - x_t = _apply_nvfp4_wgrad_rht(x.transpose(0, 1).contiguous()) - dy_t = _apply_nvfp4_wgrad_rht(dy.transpose(0, 1).contiguous()) - return dy_t @ x_t.transpose(0, 1) - # Reference implementation xs = torch.split(x_ref, split_sizes.tolist()) - dys = torch.split(dy_ref, split_sizes.tolist()) probs = torch.split(probs_ref, split_sizes.tolist()) ys = [] - fc1_inputs = [] - fc1_outputs = [] - fc2_inputs = [] for group_idx in range(group_size): x = xs[group_idx] fc1_out = torch.nn.functional.linear( x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx] ) - if quantization == "nvfp4_rht": - fc1_out.retain_grad() - fc1_inputs.append(x) - fc1_outputs.append(fc1_out) fc2_in = _apply_activation(fc1_out) fc2_in = fc2_in * probs[group_idx].unsqueeze(-1) - if quantization == "nvfp4_rht": - fc2_inputs.append(fc2_in) y = torch.nn.functional.linear(fc2_in, fc2_ws_ref[group_idx]) if bias: y = y + fc2_bs_ref[group_idx] * probs[group_idx].unsqueeze(-1) ys.append(y) y_ref = torch.cat(ys) y_ref.backward(dy_ref) - if quantization == "nvfp4_rht": - for group_idx in range(group_size): - fc1_dy = fc1_outputs[group_idx].grad - if fc1_dy is None: - fc1_dy = torch.zeros_like(fc1_outputs[group_idx]) - fc1_ws_ref[group_idx].grad = _nvfp4_rht_wgrad(fc1_inputs[group_idx], fc1_dy) - fc2_ws_ref[group_idx].grad = _nvfp4_rht_wgrad( - fc2_inputs[group_idx], - dys[group_idx], - ) # Construct operations recipe = make_recipe(quantization) From 63d28eaa28c0a83de191f8b3d12e9549691516c7 Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Fri, 29 May 2026 17:12:12 -0700 Subject: [PATCH 14/14] Fix PyTorch ops lint Signed-off-by: Siddhartha Raman S --- transformer_engine/pytorch/ops/_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 0b9453f679..87911d76f4 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -9,7 +9,7 @@ import functools import math from importlib.metadata import PackageNotFoundError, version as get_pkg_version -from typing import Any, Optional +from typing import Optional import torch from packaging.version import Version as PkgVersion