From 1c004f24f71c1e66287502f897eb75d2a2648951 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 26 May 2026 17:44:33 -0700 Subject: [PATCH 1/8] [PyTorch] Add opt-in MXFP8 support on sm120 and dedicated GEMM test Adds NVTE_ENABLE_MXFP8_SM120 environment variable to unblock MXFP8 testing on sm120 (compute capability 12.0) devices. Default behavior unchanged; MXFP8 remains gated off on sm120 without explicit opt-in since not all GEMM layouts are currently supported. Also adds tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py: a focused layout x shape x dtype matrix exercising MXFP8 single GEMM via the underlying general_gemm call directly. The TN layout is exercised across small/medium/transformer-sized shapes and BF16/FP32 outputs. NN and NT layouts on sm120 are marked strict-xfail; the suite will fail-on-XPASS once full layout support is added so the markers can be removed. --- tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py | 176 +++++++++++++++++++ transformer_engine/pytorch/quantization.py | 2 + 2 files changed, 178 insertions(+) create mode 100644 tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py diff --git a/tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py b/tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py new file mode 100644 index 0000000000..c00e6f1915 --- /dev/null +++ b/tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py @@ -0,0 +1,176 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Layout-by-layout exactness test for MXFP8 single GEMM. + +Each `te.Linear` forward + backward issues three cuBLAS GEMMs (see +`transformer_engine/pytorch/module/linear.py`): + + layout role A B Out + ------ ------- -------------------- --------------------- ------------------- + TN fwd (out_f, in_f) (batch, in_f) (batch, out_f) + NN dgrad (out_f, in_f) (batch, out_f) (batch, in_f) + NT wgrad (batch, in_f) (batch, out_f) (out_f, in_f) + +This test drives the underlying ``general_gemm`` call directly for each layout +with MXFP8-quantized operands, and compares the cuBLAS output to a +dequantized-operand reference matmul. + +Background: With cuBLAS 13.5.1.27 on sm_120, MXFP8 GEMM is only supported in +the TN layout; NN and NT return ``CUBLAS_STATUS_NOT_SUPPORTED`` from +``cublasLtMatmulAlgoGetHeuristic``. See ``Testing/cublas_logs/README.md`` and +``Testing/repro_mxfp8_layouts.cu`` in this repo for a layout-by-layout cuBLAS +reproducer. The cases for NN/NT are marked ``strict=True`` xfail on sm_120 so +that the test suite automatically flags an XPASS when cuBLAS adds support. +""" + +from __future__ import annotations + +import os + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import MXFP8Quantizer +from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm + + +_MXFP8_AVAILABLE, _MXFP8_SKIP_REASON = te.is_mxfp8_available(return_reason=True) + + +def _is_sm120() -> bool: + if not torch.cuda.is_available(): + return False + return torch.cuda.get_device_capability(0) == (12, 0) + + +_SM120_NON_TN_XFAIL = pytest.mark.xfail( + _is_sm120(), + strict=True, + reason=( + "MXFP8 NN/NT GEMM is not supported by cuBLAS on sm_120 (cublasLt 13.5.x" + " returns CUBLAS_STATUS_NOT_SUPPORTED from cublasLtMatmulAlgoGetHeuristic)." + " Remove this xfail when cuBLAS adds support." + ), +) + + +def _quantize_mxfp8(t: torch.Tensor): + """MXFP8 quantize with both row-wise and column-wise data populated.""" + return MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + )(t) + + +def _reference_for_layout( + layout: str, + w_q, + x_q, + dy_q, + out_dtype: torch.dtype, +) -> torch.Tensor: + """Reference matmul computed in fp32 from dequantized operands. + + Mirrors the three GEMMs done by ``te.Linear`` (see module docstring). + """ + if layout == "TN": # fwd: x @ w.T -> (batch, out_f) + a = w_q.dequantize(dtype=torch.float32) + b = x_q.dequantize(dtype=torch.float32) + return (b @ a.T).to(out_dtype) + if layout == "NN": # dgrad: dy @ w -> (batch, in_f) + a = w_q.dequantize(dtype=torch.float32) + b = dy_q.dequantize(dtype=torch.float32) + return (b @ a).to(out_dtype) + if layout == "NT": # wgrad: dy.T @ x -> (out_f, in_f) + a = x_q.dequantize(dtype=torch.float32) + b = dy_q.dequantize(dtype=torch.float32) + return (b.T @ a).to(out_dtype) + raise ValueError(f"Unknown layout {layout!r}") + + +def _shapes_for_layout(layout: str, w_q, x_q, dy_q): + if layout == "TN": + return w_q, x_q + if layout == "NN": + return w_q, dy_q + if layout == "NT": + return x_q, dy_q + raise ValueError(f"Unknown layout {layout!r}") + + +# Shape triples are (batch, in_features, out_features), all multiples of 32 as +# required by the MXFP8 quantizer (32-element scaling blocks). +_SHAPES = [ + (32, 32, 64), + (128, 128, 128), + (256, 1024, 512), + (2048, 2048, 8192), +] + + +@pytest.mark.skipif(not _MXFP8_AVAILABLE, reason=_MXFP8_SKIP_REASON) +@pytest.mark.parametrize( + "layout", + [ + pytest.param("TN", id="TN_fwd"), + pytest.param("NN", id="NN_dgrad", marks=_SM120_NON_TN_XFAIL), + pytest.param("NT", id="NT_wgrad", marks=_SM120_NON_TN_XFAIL), + ], +) +@pytest.mark.parametrize( + "batch, in_features, out_features", + _SHAPES, + ids=[f"b{b}_in{i}_out{o}" for (b, i, o) in _SHAPES], +) +@pytest.mark.parametrize( + "in_dtype", + [torch.bfloat16, torch.float16], + ids=["bf16", "fp16"], +) +@pytest.mark.parametrize( + "out_dtype", + [torch.bfloat16, torch.float32], + ids=["out_bf16", "out_fp32"], +) +def test_mxfp8_single_gemm_versus_reference( + layout: str, + batch: int, + in_features: int, + out_features: int, + in_dtype: torch.dtype, + out_dtype: torch.dtype, +): + """One cuBLAS MXFP8 GEMM per layout, compared against a dequantized reference.""" + device = "cuda" + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + w = torch.randn(out_features, in_features, dtype=in_dtype, device=device) + x = torch.randn(batch, in_features, dtype=in_dtype, device=device) + dy = torch.randn(batch, out_features, dtype=in_dtype, device=device) + + w_q = _quantize_mxfp8(w) + x_q = _quantize_mxfp8(x) + dy_q = _quantize_mxfp8(dy) + + A, B = _shapes_for_layout(layout, w_q, x_q, dy_q) + out, *_ = general_gemm(A, B, out_dtype=out_dtype, layout=layout) + + ref = _reference_for_layout(layout, w_q, x_q, dy_q, out_dtype) + assert tuple(out.shape) == tuple( + ref.shape + ), f"shape mismatch: cuBLAS {tuple(out.shape)} vs ref {tuple(ref.shape)}" + + # MXFP8 tolerance, aligned with tests/pytorch/utils.py::quantization_tols("mxfp8") + # which returns dtype_tols(kFloat8E4M3) == dict(rtol=0.125, atol=0.0675). + torch.testing.assert_close( + out.to(torch.float32), + ref.to(torch.float32), + atol=0.0675, + rtol=0.125, + ) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index e503b4b560..5aadf13d85 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -162,6 +162,8 @@ def _compute_fp8_support() -> Tuple[bool, str]: def _compute_mxfp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" if get_device_compute_capability() >= (12, 0): + if os.getenv("NVTE_ENABLE_MXFP8_SM120", "0") == "1": + return True, "" return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" From 42cf70c7857208d8782d85d33bea5855a7f419cd Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 27 May 2026 11:18:26 -0700 Subject: [PATCH 2/8] Add support check for sm120 and cublas version Signed-off-by: Kshitij Lakhani --- transformer_engine/pytorch/quantization.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 5aadf13d85..7fc9e55a6d 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -160,11 +160,15 @@ def _compute_fp8_support() -> Tuple[bool, str]: def _compute_mxfp8_support() -> Tuple[bool, str]: - """Return if fp8 support is available""" + """Return if MXFP8 support is available.""" if get_device_compute_capability() >= (12, 0): - if os.getenv("NVTE_ENABLE_MXFP8_SM120", "0") == "1": + cublaslt_version = tex.get_cublasLt_version() + if cublaslt_version >= 130600: return True, "" - return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." + return False, ( + "MXFP8 on sm_120 requires cuBLASLt >= 13.6.0.2 for NN/NT GEMM " + f"support (loaded cuBLASLt={cublaslt_version})." + ) if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" return False, "Device compute capability 10.0 or higher required for MXFP8 execution." From 045fa2db38ffa5cc3c3cb62f55155bc0a6fc354d Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 27 May 2026 11:31:00 -0700 Subject: [PATCH 3/8] Add support for NN and NT tests to run on sm120 if cuBLAS version is 13.6+ Signed-off-by: Kshitij Lakhani --- tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py | 32 +++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py b/tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py index c00e6f1915..4401ed24ca 100644 --- a/tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py +++ b/tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py @@ -17,18 +17,19 @@ with MXFP8-quantized operands, and compares the cuBLAS output to a dequantized-operand reference matmul. -Background: With cuBLAS 13.5.1.27 on sm_120, MXFP8 GEMM is only supported in -the TN layout; NN and NT return ``CUBLAS_STATUS_NOT_SUPPORTED`` from -``cublasLtMatmulAlgoGetHeuristic``. See ``Testing/cublas_logs/README.md`` and +Background: With cuBLAS 13.5.1.27 on sm_120, MXFP8 GEMM was only supported in +the TN layout; NN and NT returned ``CUBLAS_STATUS_NOT_SUPPORTED`` from +``cublasLtMatmulAlgoGetHeuristic``. cuBLAS 13.6.0.2 adds NN/NT support on +sm_120; both layouts then run end-to-end and match the dequantized reference +within MXFP8 tolerance. See ``Testing/cublas_logs/README.md`` and ``Testing/repro_mxfp8_layouts.cu`` in this repo for a layout-by-layout cuBLAS -reproducer. The cases for NN/NT are marked ``strict=True`` xfail on sm_120 so -that the test suite automatically flags an XPASS when cuBLAS adds support. +reproducer. NN/NT are marked ``strict=True`` xfail only when the loaded +cuBLASLt is below the version that adds support, so the suite automatically +flags an XPASS once cuBLAS is upgraded in-place. """ from __future__ import annotations -import os - import pytest import torch @@ -40,6 +41,9 @@ _MXFP8_AVAILABLE, _MXFP8_SKIP_REASON = te.is_mxfp8_available(return_reason=True) +CUBLASLT_MXFP8_FULL_LAYOUTS_SM120 = 130600 +_CUBLASLT_VERSION = tex.get_cublasLt_version() + def _is_sm120() -> bool: if not torch.cuda.is_available(): @@ -47,13 +51,19 @@ def _is_sm120() -> bool: return torch.cuda.get_device_capability(0) == (12, 0) +def _needs_sm120_non_tn_xfail() -> bool: + """NN/NT MXFP8 is unsupported by cuBLAS on sm_120 below cuBLASLt 13.6.0.2.""" + return _is_sm120() and _CUBLASLT_VERSION < CUBLASLT_MXFP8_FULL_LAYOUTS_SM120 + + _SM120_NON_TN_XFAIL = pytest.mark.xfail( - _is_sm120(), + _needs_sm120_non_tn_xfail(), strict=True, reason=( - "MXFP8 NN/NT GEMM is not supported by cuBLAS on sm_120 (cublasLt 13.5.x" - " returns CUBLAS_STATUS_NOT_SUPPORTED from cublasLtMatmulAlgoGetHeuristic)." - " Remove this xfail when cuBLAS adds support." + f"MXFP8 NN/NT GEMM is not supported by cuBLAS on sm_120 below cuBLASLt" + f" {CUBLASLT_MXFP8_FULL_LAYOUTS_SM120} (loaded={_CUBLASLT_VERSION});" + " cublasLtMatmulAlgoGetHeuristic returns CUBLAS_STATUS_NOT_SUPPORTED." + " Upgrade cuBLAS to ≥ 13.6.0.2 to unblock these layouts." ), ) From 52044a0575a9a10298fe2bc0eb162c1fe0ddfa6c Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 27 May 2026 17:43:54 -0700 Subject: [PATCH 4/8] Drop dedicated MXFP8 GEMM exact test Remove tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py. The TN/NN/NT MXFP8 GEMM code paths it was added to localize are already exercised end-to-end on sm_120 (with cuBLASLt >= 13.6.0.2) by the existing te.Linear / te.LayerNormLinear / te.GroupedLinear / te.TransformerLayer numerics tests in tests/pytorch/test_numerics.py via the MXFP8BlockScaling entry in fp8_recipes (each Linear forward + backward dispatches the three cuBLAS GEMMs as fwd=TN, dgrad=NN, wgrad=NT). The runtime _compute_mxfp8_support gate added in the earlier commits on this branch already module-skips MXFP8 below cuBLASLt 13.6.0.2 on sm_120, so the per-layout strict-xfail layer in this file is redundant. Out-of-tree triage material (Testing/repro_mxfp8_layouts.cu and the Testing/repro_mxfp8_layouts.py driver) remains available if a future cuBLAS regression needs layout-localized signal again. --- tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py | 186 ------------------- 1 file changed, 186 deletions(-) delete mode 100644 tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py diff --git a/tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py b/tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py deleted file mode 100644 index 4401ed24ca..0000000000 --- a/tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Layout-by-layout exactness test for MXFP8 single GEMM. - -Each `te.Linear` forward + backward issues three cuBLAS GEMMs (see -`transformer_engine/pytorch/module/linear.py`): - - layout role A B Out - ------ ------- -------------------- --------------------- ------------------- - TN fwd (out_f, in_f) (batch, in_f) (batch, out_f) - NN dgrad (out_f, in_f) (batch, out_f) (batch, in_f) - NT wgrad (batch, in_f) (batch, out_f) (out_f, in_f) - -This test drives the underlying ``general_gemm`` call directly for each layout -with MXFP8-quantized operands, and compares the cuBLAS output to a -dequantized-operand reference matmul. - -Background: With cuBLAS 13.5.1.27 on sm_120, MXFP8 GEMM was only supported in -the TN layout; NN and NT returned ``CUBLAS_STATUS_NOT_SUPPORTED`` from -``cublasLtMatmulAlgoGetHeuristic``. cuBLAS 13.6.0.2 adds NN/NT support on -sm_120; both layouts then run end-to-end and match the dequantized reference -within MXFP8 tolerance. See ``Testing/cublas_logs/README.md`` and -``Testing/repro_mxfp8_layouts.cu`` in this repo for a layout-by-layout cuBLAS -reproducer. NN/NT are marked ``strict=True`` xfail only when the loaded -cuBLASLt is below the version that adds support, so the suite automatically -flags an XPASS once cuBLAS is upgraded in-place. -""" - -from __future__ import annotations - -import pytest -import torch - -import transformer_engine.pytorch as te -import transformer_engine_torch as tex -from transformer_engine.pytorch import MXFP8Quantizer -from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm - - -_MXFP8_AVAILABLE, _MXFP8_SKIP_REASON = te.is_mxfp8_available(return_reason=True) - -CUBLASLT_MXFP8_FULL_LAYOUTS_SM120 = 130600 -_CUBLASLT_VERSION = tex.get_cublasLt_version() - - -def _is_sm120() -> bool: - if not torch.cuda.is_available(): - return False - return torch.cuda.get_device_capability(0) == (12, 0) - - -def _needs_sm120_non_tn_xfail() -> bool: - """NN/NT MXFP8 is unsupported by cuBLAS on sm_120 below cuBLASLt 13.6.0.2.""" - return _is_sm120() and _CUBLASLT_VERSION < CUBLASLT_MXFP8_FULL_LAYOUTS_SM120 - - -_SM120_NON_TN_XFAIL = pytest.mark.xfail( - _needs_sm120_non_tn_xfail(), - strict=True, - reason=( - f"MXFP8 NN/NT GEMM is not supported by cuBLAS on sm_120 below cuBLASLt" - f" {CUBLASLT_MXFP8_FULL_LAYOUTS_SM120} (loaded={_CUBLASLT_VERSION});" - " cublasLtMatmulAlgoGetHeuristic returns CUBLAS_STATUS_NOT_SUPPORTED." - " Upgrade cuBLAS to ≥ 13.6.0.2 to unblock these layouts." - ), -) - - -def _quantize_mxfp8(t: torch.Tensor): - """MXFP8 quantize with both row-wise and column-wise data populated.""" - return MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=True, - columnwise=True, - )(t) - - -def _reference_for_layout( - layout: str, - w_q, - x_q, - dy_q, - out_dtype: torch.dtype, -) -> torch.Tensor: - """Reference matmul computed in fp32 from dequantized operands. - - Mirrors the three GEMMs done by ``te.Linear`` (see module docstring). - """ - if layout == "TN": # fwd: x @ w.T -> (batch, out_f) - a = w_q.dequantize(dtype=torch.float32) - b = x_q.dequantize(dtype=torch.float32) - return (b @ a.T).to(out_dtype) - if layout == "NN": # dgrad: dy @ w -> (batch, in_f) - a = w_q.dequantize(dtype=torch.float32) - b = dy_q.dequantize(dtype=torch.float32) - return (b @ a).to(out_dtype) - if layout == "NT": # wgrad: dy.T @ x -> (out_f, in_f) - a = x_q.dequantize(dtype=torch.float32) - b = dy_q.dequantize(dtype=torch.float32) - return (b.T @ a).to(out_dtype) - raise ValueError(f"Unknown layout {layout!r}") - - -def _shapes_for_layout(layout: str, w_q, x_q, dy_q): - if layout == "TN": - return w_q, x_q - if layout == "NN": - return w_q, dy_q - if layout == "NT": - return x_q, dy_q - raise ValueError(f"Unknown layout {layout!r}") - - -# Shape triples are (batch, in_features, out_features), all multiples of 32 as -# required by the MXFP8 quantizer (32-element scaling blocks). -_SHAPES = [ - (32, 32, 64), - (128, 128, 128), - (256, 1024, 512), - (2048, 2048, 8192), -] - - -@pytest.mark.skipif(not _MXFP8_AVAILABLE, reason=_MXFP8_SKIP_REASON) -@pytest.mark.parametrize( - "layout", - [ - pytest.param("TN", id="TN_fwd"), - pytest.param("NN", id="NN_dgrad", marks=_SM120_NON_TN_XFAIL), - pytest.param("NT", id="NT_wgrad", marks=_SM120_NON_TN_XFAIL), - ], -) -@pytest.mark.parametrize( - "batch, in_features, out_features", - _SHAPES, - ids=[f"b{b}_in{i}_out{o}" for (b, i, o) in _SHAPES], -) -@pytest.mark.parametrize( - "in_dtype", - [torch.bfloat16, torch.float16], - ids=["bf16", "fp16"], -) -@pytest.mark.parametrize( - "out_dtype", - [torch.bfloat16, torch.float32], - ids=["out_bf16", "out_fp32"], -) -def test_mxfp8_single_gemm_versus_reference( - layout: str, - batch: int, - in_features: int, - out_features: int, - in_dtype: torch.dtype, - out_dtype: torch.dtype, -): - """One cuBLAS MXFP8 GEMM per layout, compared against a dequantized reference.""" - device = "cuda" - torch.manual_seed(0) - torch.cuda.manual_seed(0) - - w = torch.randn(out_features, in_features, dtype=in_dtype, device=device) - x = torch.randn(batch, in_features, dtype=in_dtype, device=device) - dy = torch.randn(batch, out_features, dtype=in_dtype, device=device) - - w_q = _quantize_mxfp8(w) - x_q = _quantize_mxfp8(x) - dy_q = _quantize_mxfp8(dy) - - A, B = _shapes_for_layout(layout, w_q, x_q, dy_q) - out, *_ = general_gemm(A, B, out_dtype=out_dtype, layout=layout) - - ref = _reference_for_layout(layout, w_q, x_q, dy_q, out_dtype) - assert tuple(out.shape) == tuple( - ref.shape - ), f"shape mismatch: cuBLAS {tuple(out.shape)} vs ref {tuple(ref.shape)}" - - # MXFP8 tolerance, aligned with tests/pytorch/utils.py::quantization_tols("mxfp8") - # which returns dtype_tols(kFloat8E4M3) == dict(rtol=0.125, atol=0.0675). - torch.testing.assert_close( - out.to(torch.float32), - ref.to(torch.float32), - atol=0.0675, - rtol=0.125, - ) From c6329fe8a79242fe24549246864fc327e59f7e15 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 28 May 2026 17:29:17 -0700 Subject: [PATCH 5/8] Guard MXFP8 grouped GEMM entrypoints with support check cuBLASLt 13.6.0.2 supports single-GEMM MXFP8 on sm_120 / sm_121 but not the grouped variant. Route general_grouped_gemm and general_grouped_gemm_for_grouped_tensor through check_mxfp8_grouped_gemm_support() and raise NotImplementedError when unsupported, instead of failing opaquely inside cuBLAS. Signed-off-by: Kshitij Lakhani --- .../pytorch/cpp_extensions/gemm.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index edf2c1e1c2..13c25bb986 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -13,9 +13,11 @@ from ..constants import TE_DType from ..utils import get_sm_count, _empty_tensor +from ..quantization import check_mxfp8_grouped_gemm_support from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..tensor.storage.grouped_tensor_storage import GroupedTensorStorage +from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm @@ -76,6 +78,44 @@ def _is_nvfp4_row_scaled_tensor(tensor: torch.Tensor) -> bool: return isinstance(tensor, NVFP4TensorStorage) and tensor._row_scaled_nvfp4 +def _is_mxfp8_storage(tensor) -> bool: + """Whether ``tensor`` is MXFP8-quantized storage (per-tensor or grouped).""" + if tensor is None: + return False + if isinstance(tensor, MXFP8TensorStorage): + return True + if isinstance(tensor, GroupedTensorStorage): + quantizer = getattr(tensor, "quantizer", None) + if quantizer is not None: + try: + recipe = quantizer._get_compatible_recipe() + except (AttributeError, NotImplementedError): + return False + return bool(recipe.mxfp8()) + return False + + +def _check_mxfp8_grouped_gemm_inputs(*tensor_iterables) -> None: + """Raise ``NotImplementedError`` if MXFP8 grouped GEMM is requested but + not supported on the current device / cuBLASLt combination. + + Accepts one or more iterables of tensors (e.g. ``A``, ``B`` from + :func:`general_grouped_gemm`) so callers don't have to flatten lists. + """ + for tensors in tensor_iterables: + if tensors is None: + continue + if isinstance(tensors, (list, tuple)): + has_mxfp8 = any(_is_mxfp8_storage(t) for t in tensors) + else: + has_mxfp8 = _is_mxfp8_storage(tensors) + if has_mxfp8: + supported, reason = check_mxfp8_grouped_gemm_support() + if not supported: + raise NotImplementedError(reason) + return + + def _nvfp4_row_scaled_gemm_inputs( A: NVFP4TensorStorage, B: NVFP4TensorStorage, @@ -294,6 +334,8 @@ def general_grouped_gemm( """ TN layout Grouped GEMM with fp8 inputs. """ + _check_mxfp8_grouped_gemm_inputs(A, B) + num_gemms = len(A) transa = layout[0] == "T" @@ -470,6 +512,8 @@ def general_grouped_gemm_for_grouped_tensor( The caller must ensure that GroupedTensor metadata is already compatible with the underlying GEMM implementation (e.g., aligned offsets and output metadata layout). """ + _check_mxfp8_grouped_gemm_inputs(A, B) + assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." if grad: raise NotImplementedError("grad is not supported for grouped_tensor GEMM yet.") From b1b484b3a8406eeb02e595c26c6eae19ae38257f Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 28 May 2026 17:37:32 -0700 Subject: [PATCH 6/8] Add separate gate for MXFP8 grouped GEMM support Introduce _compute_mxfp8_grouped_gemm_support / check_mxfp8_grouped_gemm_support and a public is_mxfp8_grouped_gemm_available helper so callers (te.GroupedLinear, general_grouped_gemm[_for_grouped_tensor], and grouped-GEMM tests) can gate on grouped MXFP8 separately from single-GEMM MXFP8. On sm_120 / sm_121, cuBLASLt 13.6.0.2 enables single MXFP8 GEMM (TN/NN/NT) but not the grouped variant; the new gate returns False there with a descriptive reason. Also widen the single-GEMM gate to sm_121 alongside sm_120. Signed-off-by: Kshitij Lakhani --- transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/quantization.py | 90 ++++++++++++++++++++-- 2 files changed, 86 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 7653d5992e..533d64fd4d 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -45,6 +45,7 @@ from transformer_engine.pytorch.quantization import quantized_model_init from transformer_engine.pytorch.quantization import is_fp8_available from transformer_engine.pytorch.quantization import is_mxfp8_available +from transformer_engine.pytorch.quantization import is_mxfp8_grouped_gemm_available from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available from transformer_engine.pytorch.quantization import is_nvfp4_available from transformer_engine.pytorch.quantization import get_default_recipe diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 7fc9e55a6d..b7c59ccff2 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -37,6 +37,7 @@ "quantized_model_init", "is_fp8_available", "is_mxfp8_available", + "is_mxfp8_grouped_gemm_available", "is_fp8_block_scaling_available", "is_nvfp4_available", "get_default_recipe", @@ -49,6 +50,7 @@ _FP8_SUPPORT: Optional[Tuple[bool, str]] = None _MXFP8_SUPPORT: Optional[Tuple[bool, str]] = None +_MXFP8_GROUPED_GEMM_SUPPORT: Optional[Tuple[bool, str]] = None _NVFP4_SUPPORT: Optional[Tuple[bool, str]] = None _FP8_BLOCK_SCALING_SUPPORT: Optional[Tuple[bool, str]] = None @@ -160,13 +162,17 @@ def _compute_fp8_support() -> Tuple[bool, str]: def _compute_mxfp8_support() -> Tuple[bool, str]: - """Return if MXFP8 support is available.""" - if get_device_compute_capability() >= (12, 0): + """Return if MXFP8 single-GEMM support is available. + + On sm_120 / sm_121 this covers the single-GEMM TN/NN/NT paths via cuBLASLt >= 13.6.0.2; + grouped MXFP8 GEMM is gated separately by :func:`_compute_mxfp8_grouped_gemm_support`. + """ + if get_device_compute_capability() in ((12, 0), (12, 1)): cublaslt_version = tex.get_cublasLt_version() if cublaslt_version >= 130600: return True, "" return False, ( - "MXFP8 on sm_120 requires cuBLASLt >= 13.6.0.2 for NN/NT GEMM " + "MXFP8 on sm_120 / sm_121 requires cuBLASLt >= 13.6.0.2 for NN/NT GEMM " f"support (loaded cuBLASLt={cublaslt_version})." ) if get_device_compute_capability() >= (10, 0): # blackwell and above @@ -174,6 +180,30 @@ def _compute_mxfp8_support() -> Tuple[bool, str]: return False, "Device compute capability 10.0 or higher required for MXFP8 execution." +def _compute_mxfp8_grouped_gemm_support() -> Tuple[bool, str]: + """Return if MXFP8 *grouped*-GEMM support is available. + + This is strictly a subset of single-GEMM MXFP8 support: it inherits the + requirements of :func:`_compute_mxfp8_support` and then additionally + requires that the loaded cuBLASLt implements grouped MXFP8 GEMM on the + current device. On sm_120 / sm_121 the cuBLASLt 13.6.0.2 release supports + single MXFP8 GEMM (TN/NN/NT) but does NOT yet implement grouped MXFP8 GEMM; + callers should treat that case as unsupported until a future cuBLAS adds it. + """ + base_ok, base_reason = _compute_mxfp8_support() + if not base_ok: + return False, base_reason + if get_device_compute_capability() in ((12, 0), (12, 1)): + cublaslt_version = tex.get_cublasLt_version() + return False, ( + "MXFP8 grouped GEMM is not yet supported on sm_120 / sm_121 by the loaded " + f"cuBLASLt={cublaslt_version} (single-GEMM MXFP8 is supported with " + "cuBLASLt >= 13.6.0.2). Use a non-grouped module or switch to a " + "non-MXFP8 recipe (e.g. Float8CurrentScaling) for grouped-GEMM workloads." + ) + return True, "" + + def _compute_nvfp4_support() -> Tuple[bool, str]: """Return if nvfp4 support is available""" if get_device_compute_capability() >= (10, 0): # blackwell and above @@ -202,13 +232,22 @@ def check_fp8_support() -> Tuple[bool, str]: @torch.compiler.assume_constant_result def check_mxfp8_support() -> Tuple[bool, str]: - """Return if MXFP8 support is available.""" + """Return if MXFP8 single-GEMM support is available.""" global _MXFP8_SUPPORT if _MXFP8_SUPPORT is None: _MXFP8_SUPPORT = _compute_mxfp8_support() return _MXFP8_SUPPORT +@torch.compiler.assume_constant_result +def check_mxfp8_grouped_gemm_support() -> Tuple[bool, str]: + """Return if MXFP8 grouped-GEMM support is available.""" + global _MXFP8_GROUPED_GEMM_SUPPORT + if _MXFP8_GROUPED_GEMM_SUPPORT is None: + _MXFP8_GROUPED_GEMM_SUPPORT = _compute_mxfp8_grouped_gemm_support() + return _MXFP8_GROUPED_GEMM_SUPPORT + + @torch.compiler.assume_constant_result def check_nvfp4_support() -> Tuple[bool, str]: """Return if NVFP4 support is available.""" @@ -329,7 +368,15 @@ def is_fp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str def is_mxfp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ - Determine if support is available for the MXFP8 recipe. + Determine if support is available for the MXFP8 recipe (single GEMM). + + This reports support for the single-GEMM MXFP8 dispatch (the common TN/NN/NT + fwd/dgrad/wgrad path used by ``te.Linear`` / ``te.LayerNormLinear`` / + ``te.LayerNormMLP`` / ``te.TransformerLayer``). Grouped MXFP8 GEMM + (e.g. ``te.GroupedLinear``, ``general_grouped_gemm``, + ``general_grouped_gemm_for_grouped_tensor``) is gated separately by + :func:`is_mxfp8_grouped_gemm_available` because it may be unsupported on + some device + cuBLASLt combinations even when single-GEMM MXFP8 is supported. Parameters ---------- @@ -345,6 +392,34 @@ def is_mxfp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, s return check_mxfp8_support()[0] +def is_mxfp8_grouped_gemm_available( + return_reason: bool = False, +) -> Union[bool, Tuple[bool, str]]: + """ + Determine if support is available for MXFP8 grouped GEMM. + + MXFP8 grouped GEMM is a strict superset of single-GEMM MXFP8 in terms of + requirements: the underlying cuBLASLt must implement the grouped MXFP8 + GEMM heuristic for the current device. Use this check to gate + ``te.GroupedLinear`` / ``general_grouped_gemm`` / + ``general_grouped_gemm_for_grouped_tensor`` dispatch and to skip + MXFP8 grouped-GEMM tests on devices where the underlying cuBLASLt + does not (yet) implement that path. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when required support is not available. The reason + will be an empty string if support is available. + + """ + if return_reason: + return check_mxfp8_grouped_gemm_support() + return check_mxfp8_grouped_gemm_support()[0] + + def is_fp8_block_scaling_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ Determine if support is available for the FP8 block scaling recipe. @@ -430,6 +505,11 @@ def is_mxfp8_available(cls) -> Tuple[bool, str]: """Return if MXFP8/current scaling support is available.""" return check_mxfp8_support() + @classmethod + def is_mxfp8_grouped_gemm_available(cls) -> Tuple[bool, str]: + """Return if MXFP8 grouped-GEMM support is available.""" + return check_mxfp8_grouped_gemm_support() + @classmethod def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: """Return if Float8 block scaling support is available.""" From 0370659a330dcacf0d249bb23edef9ed44965cf6 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 28 May 2026 17:40:32 -0700 Subject: [PATCH 7/8] Skip MXFP8 grouped-GEMM test cases when unsupported - Probe is_mxfp8_grouped_gemm_available in test_fusible_ops, test_numerics, and test_sanity, and pytest.skip MXFP8 grouped_linear / padding_grouped_linear / grouped_gemm cases (plus a maybe_skip_quantization_for_grouped_gemm helper in test_fusible_ops) with the gate's reason. Signed-off-by: Kshitij Lakhani --- tests/pytorch/test_fusible_ops.py | 17 +++++++++++++++++ tests/pytorch/test_numerics.py | 16 ++++++++++++++++ tests/pytorch/test_sanity.py | 5 +++++ 3 files changed, 38 insertions(+) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1ced32e1a5..ba5f699da3 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -61,6 +61,9 @@ # Check for supported quantization schemes fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = ( + te.is_mxfp8_grouped_gemm_available(return_reason=True) +) nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) # Supported data types @@ -89,6 +92,17 @@ def _reset_rng_states_per_test(): yield +def maybe_skip_quantization_for_grouped_gemm(quantization: Optional[str]) -> None: + """Skip MXFP8 grouped-GEMM cases on devices where they're not yet supported. + + cuBLASLt 13.6.0.2 supports single-GEMM MXFP8 on sm_120 but not grouped + MXFP8 GEMM; the grouped-MXFP8 dispatch in ``general_grouped_gemm`` / + ``general_grouped_gemm_for_grouped_tensor`` will refuse those inputs. + """ + if quantization == "mxfp8" and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) + + def maybe_skip_quantization( quantization: Optional[str], *, @@ -2111,6 +2125,7 @@ def test_grouped_linear( # Skip invalid configurations maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) + maybe_skip_quantization_for_grouped_gemm(quantization) if quantization is None and (quantized_compute or quantized_weight): pytest.skip("Quantization scheme is not specified") if quantization is not None and not (quantized_compute or quantized_weight): @@ -2316,6 +2331,7 @@ def test_grouped_linear_cuda_graph_safe( pytest.skip("quantized_weight requires a quantization recipe") if single_grouped_bias and not bias: pytest.skip("single_grouped_bias requires bias=True") + maybe_skip_quantization_for_grouped_gemm(quantization) # Split sizes (statically pinned for graph capture) split_sizes = [split_alignment * (i + 1) for i in range(group_size)] @@ -3741,6 +3757,7 @@ def test_grouped_mlp( raise ValueError(f"Unexpected grouped MLP activation ({activation})") activation_is_glu = is_glu_activation(scaled_act) maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization_for_grouped_gemm(quantization) if single_grouped_weight and quantization != "mxfp8": pytest.skip("single_grouped_weight is only supported for MXFP8 quantization") if single_grouped_bias and not bias: diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 5f82bfcba2..6d96e27ebf 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -41,6 +41,7 @@ get_device_compute_capability, is_fp8_available, is_mxfp8_available, + is_mxfp8_grouped_gemm_available, is_fp8_block_scaling_available, is_bf16_available, is_nvfp4_available, @@ -60,6 +61,9 @@ # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) +mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = ( + is_mxfp8_grouped_gemm_available(return_reason=True) +) fp8_block_scaling_available = is_fp8_block_scaling_available() nvfp4_available = is_nvfp4_available() @@ -1954,6 +1958,8 @@ def test_grouped_linear_accuracy( pytest.skip("FP8 parameters are not supported in debug mode.") if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: pytest.skip("Delayed wgrad compute is not supported in debug mode.") + if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) skip_unsupported_backward_override( "grouped_linear", recipe, getattr(recipe, "backward_override", None) ) @@ -2101,6 +2107,8 @@ def test_grouped_linear_accuracy_save_original_input( pytest.skip("DelayedScaling recipe is not supported with save_original_input") if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: pytest.skip("Delayed wgrad compute is not supported in debug mode.") + if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) skip_unsupported_backward_override( "grouped_linear", recipe, getattr(recipe, "backward_override", None) ) @@ -2310,6 +2318,8 @@ def test_padding_grouped_linear_accuracy( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) skip_unsupported_backward_override( "grouped_linear", recipe, getattr(recipe, "backward_override", None) ) @@ -2390,6 +2400,8 @@ def test_padding_grouped_linear_accuracy_save_original_input( pytest.skip("FP8 parameters are not supported in debug mode.") if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") + if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) skip_unsupported_backward_override( "grouped_linear", recipe, getattr(recipe, "backward_override", None) ) @@ -3095,6 +3107,8 @@ def test_grouped_gemm_grouped_tensor_zero_work(layout, accumulate, quant_type) - pytest.skip("bfloat16 is required for grouped GEMM test.") if quant_type == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if quant_type == "mxfp8" and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) z = 4 k, n = 256, 256 @@ -3263,6 +3277,8 @@ def test_grouped_gemm_grouped_tensor_mxfp8( pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") if dtype == torch.bfloat16 and not is_bf16_available(): pytest.skip("bfloat16 is required for grouped GEMM test.") + if not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) torch.manual_seed(0) z, m, k, n = shape diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 27eafbecdc..29a79f0cf2 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -44,6 +44,9 @@ fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = ( + te.is_mxfp8_grouped_gemm_available(return_reason=True) +) nvfp4_available, _ = te.is_nvfp4_available(return_reason=True) # Record initial RNG state from script run. @@ -607,6 +610,8 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.mxfp8() and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) if fp8_recipe.nvfp4(): if not getattr(fp8_recipe, "row_scaled_activation", False): pytest.skip("NVFP4 not supported for grouped linear") From 66df3eba33c644d97bfc17b848c479974f1bd500 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 May 2026 00:42:40 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fusible_ops.py | 4 ++-- tests/pytorch/test_numerics.py | 4 ++-- tests/pytorch/test_sanity.py | 4 ++-- transformer_engine/pytorch/quantization.py | 22 ++++++++++++++-------- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index ba5f699da3..9541e94fd0 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -61,8 +61,8 @@ # Check for supported quantization schemes fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) -mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = ( - te.is_mxfp8_grouped_gemm_available(return_reason=True) +mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = te.is_mxfp8_grouped_gemm_available( + return_reason=True ) nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 6d96e27ebf..ab2cbaf87d 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -61,8 +61,8 @@ # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) -mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = ( - is_mxfp8_grouped_gemm_available(return_reason=True) +mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = is_mxfp8_grouped_gemm_available( + return_reason=True ) fp8_block_scaling_available = is_fp8_block_scaling_available() nvfp4_available = is_nvfp4_available() diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 29a79f0cf2..2a4c672ce8 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -44,8 +44,8 @@ fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) -mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = ( - te.is_mxfp8_grouped_gemm_available(return_reason=True) +mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = te.is_mxfp8_grouped_gemm_available( + return_reason=True ) nvfp4_available, _ = te.is_nvfp4_available(return_reason=True) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index b7c59ccff2..6afba9b0f3 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -171,9 +171,12 @@ def _compute_mxfp8_support() -> Tuple[bool, str]: cublaslt_version = tex.get_cublasLt_version() if cublaslt_version >= 130600: return True, "" - return False, ( - "MXFP8 on sm_120 / sm_121 requires cuBLASLt >= 13.6.0.2 for NN/NT GEMM " - f"support (loaded cuBLASLt={cublaslt_version})." + return ( + False, + ( + "MXFP8 on sm_120 / sm_121 requires cuBLASLt >= 13.6.0.2 for NN/NT GEMM " + f"support (loaded cuBLASLt={cublaslt_version})." + ), ) if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" @@ -195,11 +198,14 @@ def _compute_mxfp8_grouped_gemm_support() -> Tuple[bool, str]: return False, base_reason if get_device_compute_capability() in ((12, 0), (12, 1)): cublaslt_version = tex.get_cublasLt_version() - return False, ( - "MXFP8 grouped GEMM is not yet supported on sm_120 / sm_121 by the loaded " - f"cuBLASLt={cublaslt_version} (single-GEMM MXFP8 is supported with " - "cuBLASLt >= 13.6.0.2). Use a non-grouped module or switch to a " - "non-MXFP8 recipe (e.g. Float8CurrentScaling) for grouped-GEMM workloads." + return ( + False, + ( + "MXFP8 grouped GEMM is not yet supported on sm_120 / sm_121 by the loaded " + f"cuBLASLt={cublaslt_version} (single-GEMM MXFP8 is supported with " + "cuBLASLt >= 13.6.0.2). Use a non-grouped module or switch to a " + "non-MXFP8 recipe (e.g. Float8CurrentScaling) for grouped-GEMM workloads." + ), ) return True, ""