From 5f4207ed29533962902ab941a049bf2b5c183cff Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Thu, 26 Feb 2026 10:58:05 -0800
Subject: [PATCH 1/3] fix L3 FA fp8 tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 tests/pytorch/attention/test_attention.py | 112 +++++++++++-----------
 1 file changed, 58 insertions(+), 54 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 243fcac882..ecb6933365 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -2127,9 +2127,10 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
         is_training=is_training,
         deterministic=_deterministic,
     )
-    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
-    if flash_attn_supported + fused_attn_supported < 1:
+    flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
+    if flash_attn_supported + fused_attn_supported_fp8 < 1:
         pytest.skip("No FP8 attention backend available.")
+    fused_attn_supported_f16 = False
     if not fp8_dpa_bwd:
         available_backends, _, fused_attn_backends = get_available_attention_backends(
             config,
@@ -2138,8 +2139,8 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
             is_training=is_training,
             deterministic=_deterministic,
         )
-        _, fused_attn_supported, _ = available_backends
-        if not fused_attn_supported:
+        _, fused_attn_supported_f16, _ = available_backends
+        if not fused_attn_supported_f16:
             pytest.skip("No attention backend available.")
     if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
         pytest.skip("qkv_layout not applicable for MQA/GQA")
@@ -2164,30 +2165,32 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
         dtype, config, True, qkv_layout, is_training, fp8_recipe
     )
 
-    os.environ["NVTE_FLASH_ATTN"] = "0"
-    os.environ["NVTE_FUSED_ATTN"] = "1"
-    os.environ["NVTE_UNFUSED_ATTN"] = "0"
-    _attention_backends["backend_selection_requires_update"] = True
-    logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
-    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
-        dtype, config, True, qkv_layout, is_training, fp8_recipe
-    )
-
-    os.environ["NVTE_FLASH_ATTN"] = "0"
-    os.environ["NVTE_FUSED_ATTN"] = "1"
-    os.environ["NVTE_UNFUSED_ATTN"] = "0"
-    if config.dropout_p == 0.0:
-        # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
-        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
-        fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
-            dtype, config, False, qkv_layout, is_training, fp8_recipe
-        )
+    if fused_attn_supported_fp8:
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        os.environ["NVTE_FUSED_ATTN"] = "1"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
+        _attention_backends["backend_selection_requires_update"] = True
+        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
+        fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
+            dtype, config, True, qkv_layout, is_training, fp8_recipe
+        )
+
+    if fused_attn_supported_f16:
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        os.environ["NVTE_FUSED_ATTN"] = "1"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
+        if config.dropout_p == 0.0:
+            # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
+            logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
+            fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
+                dtype, config, False, qkv_layout, is_training, fp8_recipe
+            )
 
     atol = 5e-1
     rtol = 5e-2
     rmse_tol = 0.11
     bwd_names = ["dq", "dk", "dv"]
-    if flash_attn_supported:
+    if flash_attn_supported and fused_attn_supported_f16:
         logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
         logging.debug("========== {:^25s} ==========".format("forward output"))
         compare_and_assert(
@@ -2200,7 +2203,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
             rmse_tol,
             True,
         )
-    if unfused_attn_supported:
+    if unfused_attn_supported and fused_attn_supported_f16:
         logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:"))
         logging.debug("========== {:^25s} ==========".format("forward output"))
         compare_and_assert(
@@ -2226,37 +2229,38 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
                     rmse_tol,
                     True,
                 )
-    if config.dropout_p != 0.0:
-        # test cuDNN FP8 dropout
-        assert torch.all(
-            fused_attn_fwd_fp8 == 1
-        ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
-    else:
-        logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
-        logging.debug("========== {:^25s} ==========".format("forward output"))
-        compare_and_assert(
-            fused_attn_fwd_fp8,
-            fused_attn_fwd_f16,
-            "fused_attn_fwd_fp8",
-            "fused_attn_fwd_f16",
-            atol,
-            rtol,
-            rmse_tol,
-            True,
-        )
-        if is_training:
-            for i, _ in enumerate(fused_attn_bwd_f16):
-                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
-                compare_and_assert(
-                    fused_attn_bwd_fp8[i],
-                    fused_attn_bwd_f16[i],
-                    f"fused_attn_bwd_fp8[{i}]",
-                    f"fused_attn_bwd_f16[{i}]",
-                    atol,
-                    rtol,
-                    rmse_tol,
-                    True,
-                )
+    if fused_attn_supported_fp8 and fused_attn_supported_f16: 
+        if config.dropout_p != 0.0:
+            # test cuDNN FP8 dropout
+            assert torch.all(
+                fused_attn_fwd_fp8 == 1
+            ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
+        else:
+            logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
+            logging.debug("========== {:^25s} ==========".format("forward output"))
+            compare_and_assert(
+                fused_attn_fwd_fp8,
+                fused_attn_fwd_f16,
+                "fused_attn_fwd_fp8",
+                "fused_attn_fwd_f16",
+                atol,
+                rtol,
+                rmse_tol,
+                True,
+            )
+            if is_training:
+                for i, _ in enumerate(fused_attn_bwd_f16):
+                    logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
+                    compare_and_assert(
+                        fused_attn_bwd_fp8[i],
+                        fused_attn_bwd_f16[i],
+                        f"fused_attn_bwd_fp8[{i}]",
+                        f"fused_attn_bwd_f16[{i}]",
+                        atol,
+                        rtol,
+                        rmse_tol,
+                        True,
+                    )
 
     os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0"

From 1a409891ef18342b0fd1ae66b0ca10670451ceb1 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 26 Feb 2026 19:06:19 +0000
Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/pytorch/attention/test_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index ecb6933365..4dca4f1f75 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -2229,7 +2229,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
             rmse_tol,
             True,
         )
-    if fused_attn_supported_fp8 and fused_attn_supported_f16: 
+    if fused_attn_supported_fp8 and fused_attn_supported_f16:
         if config.dropout_p != 0.0:
             # test cuDNN FP8 dropout
             assert torch.all(

From 229b1effece54f35b23564e554babeb6d6fb16c7 Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Fri, 27 Feb 2026 12:23:34 -0800
Subject: [PATCH 3/3] fix skip logic based on reference backend

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 tests/pytorch/attention/test_attention.py | 48 +++++++++++------------
 1 file changed, 22 insertions(+), 26 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 4dca4f1f75..31c7041897 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -1865,7 +1865,7 @@ def test_mha_fp8_vs_f16(
     )
     fp8_meta = {}
     fp8_meta["recipe"] = fp8_recipe
-    available_backends, _, fused_attn_backends = get_available_attention_backends(
+    available_backends, _, _ = get_available_attention_backends(
         config,
         qkv_dtype=torch.float8_e4m3fn,
         qkv_layout=qkv_format.replace("hd", "h3d"),
@@ -1875,20 +1875,18 @@ def test_mha_fp8_vs_f16(
         deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=dtype,
+        qkv_layout=qkv_format.replace("hd", "h3d"),
+        is_training=is_training,
+        deterministic=_deterministic,
+    )
+    _, fused_attn_supported_f16, _ = available_backends
     if flash_attn_supported + fused_attn_supported_fp8 < 1:
         pytest.skip("No FP8 attention backend available.")
-    fused_attn_supported_f16 = False
-    if not fp8_dpa_bwd:
-        available_backends, _, fused_attn_backends = get_available_attention_backends(
-            config,
-            qkv_dtype=dtype,
-            qkv_layout=qkv_format.replace("hd", "h3d"),
-            is_training=is_training,
-            deterministic=_deterministic,
-        )
-        _, fused_attn_supported_f16, _ = available_backends
-        if not fused_attn_supported_f16:
-            pytest.skip("No attention backend available.")
+    if not fused_attn_supported_f16:
+        pytest.skip("No reference backend available.")
 
     if flash_attn_supported:
         os.environ["NVTE_FLASH_ATTN"] = "1"
@@ -2118,7 +2116,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
     )
     fp8_meta = {}
     fp8_meta["recipe"] = fp8_recipe
-    available_backends, _, fused_attn_backends = get_available_attention_backends(
+    available_backends, _, _ = get_available_attention_backends(
         config,
         qkv_dtype=torch.float8_e4m3fn,
         qkv_layout=qkv_layout,
@@ -2128,20 +2126,18 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
         deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
+    available_backends, _, _ = get_available_attention_backends(
+        config,
+        qkv_dtype=dtype,
+        qkv_layout=qkv_layout,
+        is_training=is_training,
+        deterministic=_deterministic,
+    )
+    _, fused_attn_supported_f16, _ = available_backends
     if flash_attn_supported + fused_attn_supported_fp8 < 1:
         pytest.skip("No FP8 attention backend available.")
-    fused_attn_supported_f16 = False
-    if not fp8_dpa_bwd:
-        available_backends, _, fused_attn_backends = get_available_attention_backends(
-            config,
-            qkv_dtype=dtype,
-            qkv_layout=qkv_layout,
-            is_training=is_training,
-            deterministic=_deterministic,
-        )
-        _, fused_attn_supported_f16, _ = available_backends
-        if not fused_attn_supported_f16:
-            pytest.skip("No attention backend available.")
+    if not fused_attn_supported_f16:
+        pytest.skip("No reference backend available.")
     if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
         pytest.skip("qkv_layout not applicable for MQA/GQA")
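
Note: taken together, the three patches converge on the following skip/compare
gating for the FP8-vs-F16 tests. The sketch below is illustrative only and is
not part of any patch: the boolean flags stand in for the availability tuples
that get_available_attention_backends returns in the tests, and the function
names (should_skip, comparisons_to_run) are hypothetical.

    from typing import List, Optional

    def should_skip(flash_fp8: bool, fused_fp8: bool, fused_f16: bool) -> Optional[str]:
        # After PATCH 3/3: both backend probes run unconditionally, and the
        # test skips when either the FP8 path under test or the F16 fused
        # reference is unavailable.
        if not (flash_fp8 or fused_fp8):
            return "No FP8 attention backend available."
        if not fused_f16:
            return "No reference backend available."
        return None  # run the test

    def comparisons_to_run(
        flash_fp8: bool, fused_fp8: bool, unfused_fp8: bool, fused_f16: bool
    ) -> List[str]:
        # After PATCH 1/3: each comparison requires both of its operands to
        # have an available backend, so no result tensor is read before it
        # has been produced.
        pairs = []
        if flash_fp8 and fused_f16:
            pairs.append("flash fp8 vs fused f16")
        if unfused_fp8 and fused_f16:
            pairs.append("unfused fp8 vs fused f16")
        if fused_fp8 and fused_f16:
            pairs.append("fused fp8 vs fused f16")
        return pairs

For example, should_skip(False, True, False) now skips with "No reference
backend available." instead of entering the test body without an F16 fused
reference to compare against.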