-
Notifications
You must be signed in to change notification settings - Fork 735
[JAX] [PyT] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x #3056
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7b2e9f8
a78c4c1
546efcf
226156c
d177ecf
fe1bbd0
793ee2f
45b332f
e317f99
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -377,6 +377,38 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model): | |||||||
| test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) | ||||||||
|
|
||||||||
|
|
||||||||
| # cuDNN FusedAttention D=256 bprop is supported on sm10x from cuDNN 9.23 (FE 1.24), | ||||||||
| # via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only, | ||||||||
| # vanilla softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window | ||||||||
| # attention only. | ||||||||
|
Comment on lines
+382
to
+383
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The comment block ends with a repeated phrase: line 383 (`# (for non-causal masks) full-window attention.`) duplicates the end of line 382.
Suggested change
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||||||||
| model_configs_fused_hdim256 = { | ||||||||
| # test: ModelConfig(b, sq, hq, dqk) -> head_dim_v defaults to head_dim_qk (256) | ||||||||
| "fused_hd256_no_mask": ModelConfig(2, 512, 16, 256), | ||||||||
| "fused_hd256_padding": ModelConfig(2, 512, 16, 256, attn_mask_type="padding"), | ||||||||
| # SWA is allowed only together with a causal mask on the D=256 bprop kernel. | ||||||||
| "fused_hd256_causal_swa": ModelConfig( | ||||||||
| 2, 1024, 16, 256, attn_mask_type="causal", window_size=(128, 0) | ||||||||
| ), | ||||||||
| # GQA variant (num_gqa_groups < num_heads). | ||||||||
| "fused_hd256_padding_causal_gqa": ModelConfig( | ||||||||
| 2, 1024, 16, 256, num_gqa_groups=4, attn_mask_type="padding_causal" | ||||||||
| ), | ||||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
| @pytest.mark.skipif(get_cudnn_version() < (9, 23, 0), reason="cuDNN 9.23+ is required.") | ||||||||
| @pytest.mark.skipif( | ||||||||
| device_compute_capability not in ((10, 0), (10, 3)), | ||||||||
| reason="cuDNN FusedAttention head_dim=256 backward is Blackwell server (SM100/SM103) only.", | ||||||||
| ) | ||||||||
| @pytest.mark.parametrize("dtype", param_types) | ||||||||
| @pytest.mark.parametrize("model_configs", [model_configs_fused_hdim256]) | ||||||||
| @pytest.mark.parametrize("model", model_configs_fused_hdim256.keys()) | ||||||||
| def test_dpa_fused_attn_hdim256(dtype, model_configs, model): | ||||||||
| """Test DotProductAttention with cuDNN FusedAttention: head_dim=256 backward on Blackwell""" | ||||||||
| test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) | ||||||||
|
|
||||||||
|
|
||||||||
| model_configs_fa4_mla = { | ||||||||
| # test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv) | ||||||||
| "fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64), | ||||||||
|
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment says "frozen/non-learnable bias inputs (i.e. non-1HSS bias shapes) are supported" and the skip block deliberately allows those configs to proceed. However, the C++ gate in
`fused_attn.cpp` requires `bias_type == NVTE_NO_BIAS` for the new D=256 BWD path, meaning any config with `attn_bias_type != NO_BIAS && bias_shape != _1HSS` will silently fall back to a different backend rather than exercising the new kernel. The test will not fail, but it also will not validate the D=256 BWD path for those configs, and the inline comment creates a misleading expectation that such configs are actually routed through it.