diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 243fcac882..31c7041897 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1865,7 +1865,7 @@ def test_mha_fp8_vs_f16( ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe - available_backends, _, fused_attn_backends = get_available_attention_backends( + available_backends, _, _ = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_format.replace("hd", "h3d"), @@ -1875,20 +1875,18 @@ def test_mha_fp8_vs_f16( deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_format.replace("hd", "h3d"), + is_training=is_training, + deterministic=_deterministic, + ) + _, fused_attn_supported_f16, _ = available_backends if flash_attn_supported + fused_attn_supported_fp8 < 1: pytest.skip("No FP8 attention backend available.") - fused_attn_supported_f16 = False - if not fp8_dpa_bwd: - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_format.replace("hd", "h3d"), - is_training=is_training, - deterministic=_deterministic, - ) - _, fused_attn_supported_f16, _ = available_backends - if not fused_attn_supported_f16: - pytest.skip("No attention backend available.") + if not fused_attn_supported_f16: + pytest.skip("No reference backend available.") if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" @@ -2118,7 +2116,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe - available_backends, _, fused_attn_backends = get_available_attention_backends( + available_backends, _, _ = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_layout, @@ -2127,20 +2125,19 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal is_training=is_training, deterministic=_deterministic, ) - flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - if flash_attn_supported + fused_attn_supported < 1: + flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends + available_backends, _, _ = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + is_training=is_training, + deterministic=_deterministic, + ) + _, fused_attn_supported_f16, _ = available_backends + if flash_attn_supported + fused_attn_supported_fp8 < 1: pytest.skip("No FP8 attention backend available.") - if not fp8_dpa_bwd: - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_layout, - is_training=is_training, - deterministic=_deterministic, - ) - _, fused_attn_supported, _ = available_backends - if not fused_attn_supported: - pytest.skip("No attention backend available.") + if not fused_attn_supported_f16: + pytest.skip("No reference backend available.") if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: pytest.skip("qkv_layout not applicable for MQA/GQA") @@ -2164,30 +2161,32 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal dtype, config, True, qkv_layout, is_training, fp8_recipe ) - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_UNFUSED_ATTN"] = "0" - _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") - fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training, fp8_recipe - ) - - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_UNFUSED_ATTN"] = "0" - if config.dropout_p == 0.0: - # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") - fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, qkv_layout, is_training, fp8_recipe + if fused_attn_supported_fp8: + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training, fp8_recipe ) + if fused_attn_supported_f16: + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + if config.dropout_p == 0.0: + # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") + fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + dtype, config, False, qkv_layout, is_training, fp8_recipe + ) + atol = 5e-1 rtol = 5e-2 rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] - if flash_attn_supported: + if flash_attn_supported and fused_attn_supported_f16: logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( @@ -2200,7 +2199,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal rmse_tol, True, ) - if unfused_attn_supported: + if unfused_attn_supported and fused_attn_supported_f16: logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( @@ -2226,37 +2225,38 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal rmse_tol, True, ) - if config.dropout_p != 0.0: - # test cuDNN FP8 dropout - assert torch.all( - fused_attn_fwd_fp8 == 1 - ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." - else: - logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) - logging.debug("========== {:^25s} ==========".format("forward output")) - compare_and_assert( - fused_attn_fwd_fp8, - fused_attn_fwd_f16, - "fused_attn_fwd_fp8", - "fused_attn_fwd_f16", - atol, - rtol, - rmse_tol, - True, - ) - if is_training: - for i, _ in enumerate(fused_attn_bwd_f16): - logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - compare_and_assert( - fused_attn_bwd_fp8[i], - fused_attn_bwd_f16[i], - f"fused_attn_bwd_fp8[{i}]", - f"fused_attn_bwd_f16[{i}]", - atol, - rtol, - rmse_tol, - True, - ) + if fused_attn_supported_fp8 and fused_attn_supported_f16: + if config.dropout_p != 0.0: + # test cuDNN FP8 dropout + assert torch.all( + fused_attn_fwd_fp8 == 1 + ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." + else: + logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( + fused_attn_fwd_fp8, + fused_attn_fwd_f16, + "fused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, + True, + ) + if is_training: + for i, _ in enumerate(fused_attn_bwd_f16): + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + compare_and_assert( + fused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"fused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, + True, + ) os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0"