From 8b6b12544aeaacd1513dd6203d54ac93092ecf62 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 22 May 2026 21:35:06 -0700 Subject: [PATCH 1/2] Use cuDNN for row-scaled NVFP4 grouped GEMM Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 129 ++++++++++- .../pytorch/cpp_extensions/gemm.py | 204 +++++++++++++++++- 2 files changed, 331 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index bd4d029729..a7988cd9d7 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -329,12 +329,139 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( single_output=single_output, ) + if single_output: + grouped_slices = torch.split(grouped_out, m_splits, dim=0) + else: + grouped_slices = grouped_out + uses_cudnn_grouped_path = ( + out_dtype in (torch.bfloat16, torch.float16) + and not use_4over6 + and all(m % 256 == 0 for m in m_splits) + and k % 128 == 0 + and n % 128 == 0 + ) + atol = 0.5 if uses_cudnn_grouped_path else 0.0 + rtol = 0.25 if uses_cudnn_grouped_path else 0.0 + for grouped, ref in zip(grouped_slices, expected): + torch.testing.assert_close(grouped, ref, atol=atol, rtol=rtol) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "use_bias, single_output", + [(False, False), (True, True)], + ids=["no_bias_list_output", "bias_single_output"], +) +def test_nvfp4_row_scaled_grouped_gemm_uses_cudnn_quant_wrapper( + use_bias: bool, + single_output: bool, + monkeypatch, +): + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Requires SM100+ for cuDNN grouped GEMM quant kernel.") + + try: + import cudnn + except ImportError as exc: + pytest.skip(f"cudnn frontend unavailable: {exc}") + if not hasattr(cudnn, "grouped_gemm_quant_wrapper_sm100"): + pytest.skip("grouped_gemm_quant_wrapper_sm100 unavailable") + + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + dtype = torch.bfloat16 + m_splits = [256, 512] + k = 128 + n = 128 + torch.manual_seed(29) + torch.cuda.manual_seed(29) + + x_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=False, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + row_scaled_nvfp4=True, + ) + w_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + x_nvfp4 = [] + w_nvfp4 = [] + bias = [] + expected = [] + for m in m_splits: + x = torch.randn((m, k), dtype=dtype, device=device) + w = torch.randn((n, k), dtype=dtype, device=device) + x_nvfp4.append( + x_quantizer.update_quantized( + x, + x_quantizer.make_empty(x.shape, dtype=dtype, device=device), + ) + ) + w_nvfp4.append( + w_quantizer.update_quantized( + w, + w_quantizer.make_empty(w.shape, dtype=dtype, device=device), + ) + ) + bias.append(torch.randn(n, dtype=torch.bfloat16, device=device) if use_bias else None) + expected.append( + general_gemm( + w_nvfp4[-1], + x_nvfp4[-1], + out_dtype=dtype, + layout="TN", + bias=bias[-1], + )[0] + ) + + calls = [] + original_wrapper = cudnn.grouped_gemm_quant_wrapper_sm100 + + def traced_wrapper(*args, **kwargs): + calls.append(kwargs) + return original_wrapper(*args, **kwargs) + + monkeypatch.setattr(cudnn, "grouped_gemm_quant_wrapper_sm100", traced_wrapper) + if single_output: + out = [torch.empty((sum(m_splits), n), dtype=dtype, device=device)] + else: + out = [torch.empty((m, n), dtype=dtype, device=device) for m in m_splits] + grouped_out, _, _ = general_grouped_gemm( + w_nvfp4, + x_nvfp4, + out, + quantization_params=[None] * len(m_splits), + out_dtype=dtype, + layout="TN", + m_splits=m_splits, + bias=bias, + use_bias=use_bias, + single_output=single_output, + ) + + assert len(calls) == 1 + assert calls[0]["sf_vec_size"] == 16 + assert calls[0]["row_scale_tensor"].shape == (sum(m_splits),) + assert calls[0]["b_major"] == "k" + assert (calls[0]["bias_tensor"] is not None) == use_bias if single_output: grouped_slices = torch.split(grouped_out, m_splits, dim=0) else: grouped_slices = grouped_out for grouped, ref in zip(grouped_slices, expected): - torch.testing.assert_close(grouped, ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(grouped, ref, atol=0.5, rtol=0.25) def check_nvfp4_row_scaled_gemm_matches_emulated( diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index edf2c1e1c2..a27706fbb4 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -10,7 +10,7 @@ import functools import torch import transformer_engine_torch as tex -from ..constants import TE_DType +from ..constants import NVFP4_BLOCK_SCALING_SIZE, TE_DType from ..utils import get_sm_count, _empty_tensor from ..quantized_tensor import Quantizer @@ -103,6 +103,190 @@ def _nvfp4_row_scaled_gemm_inputs( ) +def _ceil_div(a: int, b: int) -> int: + """Integer ceil division.""" + return (a + b - 1) // b + + +def _nvfp4_cudnn_scale_layout(scale: torch.Tensor, m: int, k: int) -> torch.Tensor: + """Pack compact NVFP4 scales into the cuDNN CuTe layout.""" + m_tiles = _ceil_div(m, 128) + sf_k = _ceil_div(k, NVFP4_BLOCK_SCALING_SIZE) + k_tiles = _ceil_div(sf_k, 4) + compact = scale.view(torch.float8_e4m3fn)[: m_tiles * 128, : k_tiles * 4].contiguous() + logical_layout = compact.view(1, m_tiles, 4, 32, k_tiles, 4).permute(3, 2, 1, 5, 4, 0) + base = torch.empty( + (1, m_tiles, k_tiles, 32, 4, 4), + dtype=torch.float8_e4m3fn, + device=scale.device, + ) + cudnn_layout = base.permute(3, 4, 1, 5, 2, 0) + cudnn_layout.copy_(logical_layout) + return cudnn_layout + + +def _nvfp4_rowwise_data_logical_view(tensor: NVFP4TensorStorage) -> torch.Tensor: + """Return a logical FP4 rowwise data view with the packed buffer's data pointer.""" + packed = tensor._rowwise_data + rows = int(tensor.size(0)) + cols = int(tensor.size(1)) + return torch.as_strided(packed, (rows, cols), (packed.stride(0), 0)) + + +def _try_cudnn_grouped_gemm_quant_for_row_scaled_nvfp4( + A: List[torch.Tensor], + B: List[torch.Tensor], + out: List[torch.Tensor], + *, + transa: bool, + transb: bool, + m_splits: Optional[List[int]], + bias: List[torch.Tensor], + use_bias: bool, + single_output: bool, + accumulate: bool, + gelu: bool, + grad: bool, + use_split_accumulator: bool, +) -> Optional[torch.Tensor]: + """Use cuDNN grouped GEMM quant for supported row-scaled NVFP4 grouped GEMMs. + + Returns ``None`` when the inputs are outside the currently supported cuDNN + path so callers can fall back to the existing per-GEMM implementation. + """ + if grad or gelu or accumulate or use_split_accumulator: + return None + if not transa or transb: + return None + if not out or out[0].dtype not in (torch.bfloat16, torch.float16): + return None + if not all(isinstance(tensor, NVFP4TensorStorage) for tensor in A + B): + return None + if any(_is_nvfp4_row_scaled_tensor(tensor) for tensor in A): + return None + if not all(_is_nvfp4_row_scaled_tensor(tensor) for tensor in B): + return None + if any(getattr(tensor, "_nvfp4_use_4over6", False) for tensor in A + B): + return None + + num_gemms = len(A) + m_splits_list = ( + list(m_splits) if m_splits is not None else [int(tensor.size(0)) for tensor in B] + ) + if len(m_splits_list) != num_gemms: + return None + if any(m % 256 != 0 for m in m_splits_list): + return None + + k = int(B[0].size(1)) + n = int(A[0].size(0)) + if k % 128 != 0 or n % 128 != 0: + return None + if any(tuple(tensor.size()) != (n, k) for tensor in A): + return None + if any( + int(tensor.size(0)) != m or int(tensor.size(1)) != k for tensor, m in zip(B, m_splits_list) + ): + return None + + try: + from cudnn import ( + grouped_gemm_quant_wrapper_sm100, + ) # pylint: disable=import-outside-toplevel + except ImportError: + return None + + device = B[0]._rowwise_data.device + total_m = sum(m_splits_list) + + a_data = torch.cat( + [tensor._rowwise_data.view(m, k // 2) for tensor, m in zip(B, m_splits_list)], + dim=0, + ) + a_tensor = a_data.view(torch.float4_e2m1fn_x2).unsqueeze(0).permute(1, 2, 0) + + sf_cols = _ceil_div(_ceil_div(k, NVFP4_BLOCK_SCALING_SIZE), 4) * 4 + sfa_compact = torch.cat( + [ + tensor._rowwise_scale_inv.view(torch.float8_e4m3fn)[:m, :sf_cols] + for tensor, m in zip(B, m_splits_list) + ], + dim=0, + ) + sfa_tensor = _nvfp4_cudnn_scale_layout(sfa_compact, total_m, k) + + sfb_tensors = [_nvfp4_cudnn_scale_layout(tensor._rowwise_scale_inv, n, k) for tensor in A] + b_ptrs, sfb_ptrs, _sfb_keepalive = tex.get_device_pointer_for_data_and_scales( + [_nvfp4_rowwise_data_logical_view(tensor) for tensor in A], + sfb_tensors, + False, + True, + A[0]._fp4_dtype, + ) + + row_scale = [] + for weight, activation in zip(A, B): + weight_amax = weight._amax_rowwise if transa else weight._amax_columnwise + if weight_amax is None or activation._amax_rowwise is None: + return None + if weight_amax.numel() != 1: + return None + activation_decode_scale = activation._amax_rowwise / ( + float(activation._nvfp4_e4m3_max) * 6.0 + ) + weight_decode_scale = weight_amax / (float(weight._nvfp4_e4m3_max) * 6.0) + row_scale.append((activation_decode_scale * weight_decode_scale).to(dtype=torch.float32)) + row_scale_tensor = torch.cat(row_scale).contiguous() + + bias_tensor = None + if use_bias: + if any(tensor.numel() == 0 for tensor in bias): + return None + bias_tensor = torch.stack(bias, dim=0).transpose(0, 1) + + padded_offsets = torch.tensor( + [sum(m_splits_list[: i + 1]) for i in range(num_gemms)], + dtype=torch.int32, + device=device, + ) + alpha_tensor = torch.ones(num_gemms, dtype=torch.float32, device=device) + prob_tensor = torch.ones(total_m, 1, 1, dtype=torch.float32, device=device) + + result = grouped_gemm_quant_wrapper_sm100( + a_tensor=a_tensor, + b_ptrs=b_ptrs, + sfa_tensor=sfa_tensor, + sfb_ptrs=sfb_ptrs, + padded_offsets=padded_offsets, + alpha_tensor=alpha_tensor, + bias_tensor=bias_tensor, + norm_const_tensor=None, + prob_tensor=prob_tensor, + row_scale_tensor=row_scale_tensor, + acc_dtype=torch.float32, + d_dtype=out[0].dtype, + cd_major="n", + sf_vec_size=NVFP4_BLOCK_SCALING_SIZE, + discrete_col_sfd=True, + b_dtype=torch.float4_e2m1fn_x2, + b_major="k", + n=n, + current_stream=torch.cuda.current_stream().cuda_stream, + use_dynamic_sched=True, + ) + d_tensor = result["d_tensor"].squeeze(-1) + + if single_output: + out[0].copy_(d_tensor) + return out[0] + + start = 0 + for output, m in zip(out, m_splits_list): + output.copy_(d_tensor[start : start + m]) + start += m + return out + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -329,6 +513,24 @@ def general_grouped_gemm( assert ( m_splits is not None ), "Row-scaled NVFP4 grouped GEMM requires m_splits with single output." + cudnn_out = _try_cudnn_grouped_gemm_quant_for_row_scaled_nvfp4( + A, + B, + out, + transa=transa, + transb=transb, + m_splits=m_splits, + bias=bias, + use_bias=use_bias, + single_output=single_output, + accumulate=accumulate, + gelu=gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + ) + if cudnn_out is not None: + return cudnn_out, grad_bias, gelu_input + out_init = out[0] if single_output else None if single_output: start_idx = 0 From a0032dd5b80a06036a950692f97fdf0d778531fa Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 25 May 2026 20:58:57 -0700 Subject: [PATCH 2/2] Require cuDNN for row-scaled NVFP4 grouped GEMM Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 68 ++++++---- .../pytorch/cpp_extensions/gemm.py | 124 +++++++----------- 2 files changed, 92 insertions(+), 100 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index a7988cd9d7..4c7e0e9421 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -258,6 +258,13 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( torch.cuda.manual_seed(23) num_gemms = len(m_splits) + uses_cudnn_grouped_path = ( + out_dtype in (torch.bfloat16, torch.float16) + and not use_4over6 + and all(m % 256 == 0 for m in m_splits) + and k % 128 == 0 + and n % 128 == 0 + ) x_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -301,49 +308,56 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( ) ) bias.append(torch.randn(n, dtype=torch.bfloat16, device=device) if use_bias else None) - expected.append( - general_gemm( - w_nvfp4[-1], - x_nvfp4[-1], - out_dtype=out_dtype, - layout="TN", - bias=bias[-1], - )[0] - ) + if uses_cudnn_grouped_path: + expected.append( + general_gemm( + w_nvfp4[-1], + x_nvfp4[-1], + out_dtype=out_dtype, + layout="TN", + bias=bias[-1], + )[0] + ) if single_output: out = [torch.empty((sum(m_splits), n), dtype=out_dtype, device=device)] else: out = [torch.empty((m, n), dtype=out_dtype, device=device) for m in m_splits] - grouped_out, _, _ = general_grouped_gemm( + grouped_gemm_args = ( w_nvfp4, x_nvfp4, out, - quantization_params=[None] * num_gemms, - out_dtype=out_dtype, - layout="TN", - m_splits=m_splits, - bias=bias, - use_bias=use_bias, - single_output=single_output, ) + grouped_gemm_kwargs = { + "quantization_params": [None] * num_gemms, + "out_dtype": out_dtype, + "layout": "TN", + "m_splits": m_splits, + "bias": bias, + "use_bias": use_bias, + "single_output": single_output, + } + if not uses_cudnn_grouped_path: + with pytest.raises((NotImplementedError, ValueError)): + general_grouped_gemm(*grouped_gemm_args, **grouped_gemm_kwargs) + return + + try: + import cudnn + except ImportError as exc: + pytest.skip(f"cudnn frontend unavailable: {exc}") + if not hasattr(cudnn, "grouped_gemm_quant_wrapper_sm100"): + pytest.skip("grouped_gemm_quant_wrapper_sm100 unavailable") + + grouped_out, _, _ = general_grouped_gemm(*grouped_gemm_args, **grouped_gemm_kwargs) if single_output: grouped_slices = torch.split(grouped_out, m_splits, dim=0) else: grouped_slices = grouped_out - uses_cudnn_grouped_path = ( - out_dtype in (torch.bfloat16, torch.float16) - and not use_4over6 - and all(m % 256 == 0 for m in m_splits) - and k % 128 == 0 - and n % 128 == 0 - ) - atol = 0.5 if uses_cudnn_grouped_path else 0.0 - rtol = 0.25 if uses_cudnn_grouped_path else 0.0 for grouped, ref in zip(grouped_slices, expected): - torch.testing.assert_close(grouped, ref, atol=atol, rtol=rtol) + torch.testing.assert_close(grouped, ref, atol=0.5, rtol=0.25) @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index a27706fbb4..13d535b7b2 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -133,7 +133,7 @@ def _nvfp4_rowwise_data_logical_view(tensor: NVFP4TensorStorage) -> torch.Tensor return torch.as_strided(packed, (rows, cols), (packed.stride(0), 0)) -def _try_cudnn_grouped_gemm_quant_for_row_scaled_nvfp4( +def _cudnn_grouped_gemm_quant_for_row_scaled_nvfp4( A: List[torch.Tensor], B: List[torch.Tensor], out: List[torch.Tensor], @@ -148,53 +148,57 @@ def _try_cudnn_grouped_gemm_quant_for_row_scaled_nvfp4( gelu: bool, grad: bool, use_split_accumulator: bool, -) -> Optional[torch.Tensor]: - """Use cuDNN grouped GEMM quant for supported row-scaled NVFP4 grouped GEMMs. - - Returns ``None`` when the inputs are outside the currently supported cuDNN - path so callers can fall back to the existing per-GEMM implementation. - """ +) -> torch.Tensor: + """Use cuDNN grouped GEMM quant for row-scaled NVFP4 grouped GEMMs.""" if grad or gelu or accumulate or use_split_accumulator: - return None + raise NotImplementedError( + "cuDNN row-scaled NVFP4 grouped GEMM supports fprop without GELU, " + "accumulation, or split accumulator only." + ) if not transa or transb: - return None + raise NotImplementedError("cuDNN row-scaled NVFP4 grouped GEMM supports TN layout only.") if not out or out[0].dtype not in (torch.bfloat16, torch.float16): - return None + raise NotImplementedError( + "cuDNN row-scaled NVFP4 grouped GEMM supports BF16/FP16 outputs only." + ) if not all(isinstance(tensor, NVFP4TensorStorage) for tensor in A + B): - return None + raise TypeError("cuDNN row-scaled NVFP4 grouped GEMM requires NVFP4 inputs.") if any(_is_nvfp4_row_scaled_tensor(tensor) for tensor in A): - return None + raise NotImplementedError( + "cuDNN row-scaled NVFP4 grouped GEMM does not support row-scaled A." + ) if not all(_is_nvfp4_row_scaled_tensor(tensor) for tensor in B): - return None + raise NotImplementedError("cuDNN row-scaled NVFP4 grouped GEMM requires row-scaled B.") if any(getattr(tensor, "_nvfp4_use_4over6", False) for tensor in A + B): - return None + raise NotImplementedError("cuDNN row-scaled NVFP4 grouped GEMM does not support 4over6.") num_gemms = len(A) m_splits_list = ( list(m_splits) if m_splits is not None else [int(tensor.size(0)) for tensor in B] ) if len(m_splits_list) != num_gemms: - return None + raise ValueError("m_splits length must match the number of grouped GEMMs.") if any(m % 256 != 0 for m in m_splits_list): - return None + raise NotImplementedError( + "cuDNN row-scaled NVFP4 grouped GEMM requires M multiples of 256." + ) k = int(B[0].size(1)) n = int(A[0].size(0)) if k % 128 != 0 or n % 128 != 0: - return None + raise NotImplementedError( + "cuDNN row-scaled NVFP4 grouped GEMM requires K and N multiples of 128." + ) if any(tuple(tensor.size()) != (n, k) for tensor in A): - return None + raise ValueError("All grouped GEMM A tensors must have the same (N, K) shape.") if any( int(tensor.size(0)) != m or int(tensor.size(1)) != k for tensor, m in zip(B, m_splits_list) ): - return None + raise ValueError("Grouped GEMM B tensor shapes must match m_splits and K.") - try: - from cudnn import ( - grouped_gemm_quant_wrapper_sm100, - ) # pylint: disable=import-outside-toplevel - except ImportError: - return None + import cudnn # pylint: disable=import-outside-toplevel + + grouped_gemm_quant_wrapper_sm100 = cudnn.grouped_gemm_quant_wrapper_sm100 device = B[0]._rowwise_data.device total_m = sum(m_splits_list) @@ -228,9 +232,11 @@ def _try_cudnn_grouped_gemm_quant_for_row_scaled_nvfp4( for weight, activation in zip(A, B): weight_amax = weight._amax_rowwise if transa else weight._amax_columnwise if weight_amax is None or activation._amax_rowwise is None: - return None + raise ValueError( + "Row-scaled NVFP4 grouped GEMM requires activation and weight amax metadata." + ) if weight_amax.numel() != 1: - return None + raise ValueError("Row-scaled NVFP4 grouped GEMM requires tensor-scaled weights.") activation_decode_scale = activation._amax_rowwise / ( float(activation._nvfp4_e4m3_max) * 6.0 ) @@ -241,7 +247,7 @@ def _try_cudnn_grouped_gemm_quant_for_row_scaled_nvfp4( bias_tensor = None if use_bias: if any(tensor.numel() == 0 for tensor in bias): - return None + raise ValueError("Bias tensors must be non-empty when use_bias=True.") bias_tensor = torch.stack(bias, dim=0).transpose(0, 1) padded_offsets = torch.tensor( @@ -513,53 +519,25 @@ def general_grouped_gemm( assert ( m_splits is not None ), "Row-scaled NVFP4 grouped GEMM requires m_splits with single output." - cudnn_out = _try_cudnn_grouped_gemm_quant_for_row_scaled_nvfp4( - A, - B, - out, - transa=transa, - transb=transb, - m_splits=m_splits, - bias=bias, - use_bias=use_bias, - single_output=single_output, - accumulate=accumulate, - gelu=gelu, - grad=grad, - use_split_accumulator=use_split_accumulator, - ) - if cudnn_out is not None: - return cudnn_out, grad_bias, gelu_input - - out_init = out[0] if single_output else None - if single_output: - start_idx = 0 - out_views = [] - for i in range(num_gemms): - size = m_splits[i] - out_views.append(out_init[start_idx : start_idx + size]) - start_idx += size - else: - out_views = out - for i in range(num_gemms): - if out_views[i].numel() == 0: - continue - general_gemm( - A[i], - B[i], - quantization_params=quantization_params[i], - out_dtype=out_views[i].dtype, - out=out_views[i], - gelu=gelu, + return ( + _cudnn_grouped_gemm_quant_for_row_scaled_nvfp4( + A, + B, + out, + transa=transa, + transb=transb, + m_splits=m_splits, + bias=bias, + use_bias=use_bias, + single_output=single_output, accumulate=accumulate, - layout=layout, - bias=bias[i] if use_bias else None, - use_split_accumulator=use_split_accumulator, + gelu=gelu, grad=grad, - ) - if single_output: - out = out_init - return out, grad_bias, gelu_input + use_split_accumulator=use_split_accumulator, + ), + grad_bias, + gelu_input, + ) if isinstance(quantization_params[0], DebugQuantizer): assert not gelu, "GELU not supported in debug mode"