From 7b2e9f805223f0950027896c12bf8541f8439cfe Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 27 May 2026 17:41:35 -0700 Subject: [PATCH 1/9] Refine the support for D=256 on Blackwell server type GPUs Signed-off-by: Kshitij Lakhani --- .../common/fused_attn/fused_attn.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index d2eb1a831c..27dd11ab5e 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -315,6 +315,25 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (head_dim_qk <= 256 && head_dim_v <= 256 && ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || + // 9.23: d_qk = d_v = 256 + SM10.x + bprop + non-paged + // cuDNN's dedicated SM10.x D=256 SDPA backward kernel (cuDNN FE 1.24 / + // BE 9.23+) only supports d_qk == d_v == 256. + (head_dim_qk == 256 && head_dim_v == 256 && is_training && + sm_arch_ >= 100 && sm_arch_ < 110 && cudnn_runtime_version >= 92300 && + layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && + // The FE forces this path onto deterministic bprop, which then rejects alibi and + // dropout, and only supports vanilla softmax. + bias_type != NVTE_Bias_Type::NVTE_ALIBI && dropout == 0.0 && + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX && + // Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks. + // NOTE: SWA support for non causal would be available when cuDNN decides to redirect + // D=256 support to a CUDA backend instead of a Python OSS kernel + ((window_size_left == -1 && window_size_right == -1) || + ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + (window_size_right == -1 || window_size_right == 0)))) || // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || From a78c4c1710d9d40fdc99ed3e38f1e1789157e1f0 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 28 May 2026 11:38:21 -0700 Subject: [PATCH 2/9] Add deterministic tests for D=256 for sm10.x Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 160 +++++++++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 1fb0108068..658ae25dff 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -444,6 +444,59 @@ def _check_configs(self): "is either BSHD_BSHD_BSHD or THD_THD_THD" ) + # D=256 bprop on SM10.x uses cuDNN's dedicated SDPA bprop kernel + # (cuDNN FE 1.24 / BE 9.23+). FE forces this path onto the deterministic algorithm path, + # which rejects dBias, dropout, and ALiBi. It supports vanilla softmax only and allows SWA + # together with a causal mask only. + compute_capability = get_device_compute_capability(0) + is_sm10x = 100 <= compute_capability < 110 + if self.is_training and is_sm10x and (self.head_dim_qk == 256 or self.head_dim_v == 256): + if self.head_dim_qk != 256 or self.head_dim_v != 256: + pytest.skip( + "D=256 BWD on Blackwell only supports d_qk == d_v == 256;" + f" got d_qk={self.head_dim_qk}, d_v={self.head_dim_v}." + ) + cudnn_version = get_cudnn_version() + if cudnn_version < 92300: + pytest.skip( + "D=256 BWD on Blackwell requires cuDNN 9.23 or newer;" + f" got cuDNN {cudnn_version}." + ) + # Non-learnable bias is fine (bias is allowed as an input); only dBias is + # unsupported. The JAX runner asks for dBias iff the bias shape is [1, h, s, s] + # (see test_backward), so gate on that. + unsupported = None + if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: + unsupported = "pre-scale bias" + elif ( + self.attn_bias_type != AttnBiasType.NO_BIAS + and self.bias_shape == BiasShape._1HSS + ): + unsupported = ( + "bias gradients (dBias); frozen/non-learnable bias inputs" + " (i.e. non-1HSS bias shapes) are supported" + ) + elif self.dropout_prob != 0.0: + unsupported = "dropout" + elif self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: + unsupported = "non-vanilla softmax" + if unsupported is not None: + pytest.skip( + "D=256 BWD on Blackwell uses the deterministic SM100 D=256 SDPA BWD" + f" kernel which does not support {unsupported}." + ) + if self.window_size is not None and self.window_size != (-1, -1): + if not self.attn_mask_type.is_causal(): + pytest.skip( + "D=256 BWD on Blackwell uses the SM100 D=256 SDPA BWD kernel" + " which requires window_size=(-1, -1) for non-causal masks." + ) + if self.window_size[1] not in (-1, 0): + pytest.skip( + "D=256 BWD on Blackwell only supports right window -1 or 0" + " for causal masks." + ) + self.backend = FusedAttnHelper( self.is_training, self.dtype, @@ -1605,3 +1658,110 @@ def test_backward( swa, seq_desc_format, ) + + +@pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), + pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"), + pytest.param( + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT" + ), + ], +) +@pytest.mark.parametrize( + "softmax_type", + [ + pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"), + ], +) +@pytest.mark.parametrize( + "dropout_prob", + [ + pytest.param(0.0, id="DROP_0.0"), + ], +) +@pytest.mark.parametrize( + "swa", + [ + pytest.param(False, id="NO_SWA"), + ], +) +@pytest.mark.parametrize( + "seq_desc_format", + [ + pytest.param(SeqDescFormat.Seqlens, id="Seqlens"), + ], +) +@pytest.mark.skipif(not _deterministic, reason="Test determinism only") +class TestFusedAttnD256WithDeterminism: + """ + Fused attention D=256 deterministic backward tester. + """ + + @staticmethod + @pytest.mark.parametrize( + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", + [ + pytest.param( + 4, + 128, + 128, + 16, + 16, + 256, + 256, + jnp.float16, + QKVLayout.BSHD_BS2HD, + id="4-128-128-16-16-256-256-FP16-SELF-KV_PACKED", + ), + ], + ) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + ], + ) + def test_backward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ): + """ + Test D=256 backward with the supported deterministic SM100 bprop configuration. + """ + TestFusedAttn.test_backward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ) From 546efcfa2c8ff90305dab405edec81d3f01f9a06 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 28 May 2026 12:48:10 -0700 Subject: [PATCH 3/9] Test clean up Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 122 +++++------------------------------ 1 file changed, 15 insertions(+), 107 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 658ae25dff..0192fb05b4 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1527,6 +1527,21 @@ def test_backward( QKVLayout.THD_THD_THD, id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE", ), + # D=256 deterministic backward on the SM100 dedicated SDPA bprop kernel + # (cuDNN FE 1.24 / BE 9.23+). Unsupported configs (e.g. dBias, non-256 head dims) + # are skipped by FusedAttnRunner._check_configs. + pytest.param( + 4, + 128, + 128, + 16, + 16, + 256, + 256, + jnp.float16, + QKVLayout.BSHD_BS2HD, + id="4-128-128-16-16-256-256-FP16-SELF-KV_PACKED", + ), ], ) @pytest.mark.parametrize( @@ -1658,110 +1673,3 @@ def test_backward( swa, seq_desc_format, ) - - -@pytest.mark.parametrize( - "attn_mask_type", - [ - pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), - pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"), - pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), - pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"), - pytest.param( - AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT" - ), - ], -) -@pytest.mark.parametrize( - "softmax_type", - [ - pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"), - ], -) -@pytest.mark.parametrize( - "dropout_prob", - [ - pytest.param(0.0, id="DROP_0.0"), - ], -) -@pytest.mark.parametrize( - "swa", - [ - pytest.param(False, id="NO_SWA"), - ], -) -@pytest.mark.parametrize( - "seq_desc_format", - [ - pytest.param(SeqDescFormat.Seqlens, id="Seqlens"), - ], -) -@pytest.mark.skipif(not _deterministic, reason="Test determinism only") -class TestFusedAttnD256WithDeterminism: - """ - Fused attention D=256 deterministic backward tester. - """ - - @staticmethod - @pytest.mark.parametrize( - "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", - [ - pytest.param( - 4, - 128, - 128, - 16, - 16, - 256, - 256, - jnp.float16, - QKVLayout.BSHD_BS2HD, - id="4-128-128-16-16-256-256-FP16-SELF-KV_PACKED", - ), - ], - ) - @pytest.mark.parametrize( - "attn_bias_type, bias_shape", - [ - pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), - ], - ) - def test_backward( - b, - s_q, - s_kv, - h_q, - h_kv, - d_qk, - d_v, - attn_bias_type, - attn_mask_type, - softmax_type, - dropout_prob, - dtype, - qkv_layout, - bias_shape, - swa, - seq_desc_format, - ): - """ - Test D=256 backward with the supported deterministic SM100 bprop configuration. - """ - TestFusedAttn.test_backward( - b, - s_q, - s_kv, - h_q, - h_kv, - d_qk, - d_v, - attn_bias_type, - attn_mask_type, - softmax_type, - dropout_prob, - dtype, - qkv_layout, - bias_shape, - swa, - seq_desc_format, - ) From 226156c77193593f05029a934582f60c1cac8491 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 28 May 2026 13:28:03 -0700 Subject: [PATCH 4/9] Add PyT side tests for D=256 cuDNN fused attn on SM100-110 Signed-off-by: Kshitij Lakhani --- tests/pytorch/attention/test_attention.py | 38 +++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 401bd6f01d..242f03cd7f 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -377,6 +377,44 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) +# cuDNN FusedAttention head_dim=256 backward is supported on Blackwell server GPUs +# (SM100/SM103) from cuDNN 9.23 (FE 1.24), via the dedicated deterministic SDPA bprop +# kernel. It requires d_qk == d_v == 256, vanilla softmax, no dropout, no ALiBi, and +# (for non-causal masks) full-window attention. See nvte_get_fused_attn_backend in +# transformer_engine/common/fused_attn/fused_attn.cpp. These configs use d_qk == d_v == 256 +# with s_q == s_kv > 1 so the training (forward + backward) FusedAttention route is exercised +# and compared against the reference backends. +model_configs_fused_hdim256 = { + # test: ModelConfig(b, sq, hq, dqk) -> head_dim_v defaults to head_dim_qk (256) + "fused_hd256_no_mask": ModelConfig(2, 512, 16, 256), + "fused_hd256_causal": ModelConfig(2, 512, 16, 256, attn_mask_type="causal"), + "fused_hd256_padding": ModelConfig(2, 512, 16, 256, attn_mask_type="padding"), + "fused_hd256_padding_causal": ModelConfig(2, 512, 16, 256, attn_mask_type="padding_causal"), + "fused_hd256_padding_causal_br": ModelConfig( + 2, 512, 16, 256, attn_mask_type="padding_causal_bottom_right" + ), + # SWA is allowed only together with a causal mask on the D=256 bprop kernel. + "fused_hd256_causal_swa": ModelConfig( + 2, 512, 16, 256, attn_mask_type="causal", window_size=(128, 0) + ), + # GQA variant (num_gqa_groups < num_heads). + "fused_hd256_gqa": ModelConfig(2, 512, 16, 256, num_gqa_groups=4, attn_mask_type="causal"), +} + + +@pytest.mark.skipif(get_cudnn_version() < (9, 23, 0), reason="cuDNN 9.23+ is required.") +@pytest.mark.skipif( + device_compute_capability not in ((10, 0), (10, 3)), + reason="cuDNN FusedAttention head_dim=256 backward is Blackwell server (SM100/SM103) only.", +) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model_configs", [model_configs_fused_hdim256]) +@pytest.mark.parametrize("model", model_configs_fused_hdim256.keys()) +def test_dpa_fused_attn_hdim256(dtype, model_configs, model): + """Test DotProductAttention with cuDNN FusedAttention: head_dim=256 backward on Blackwell""" + test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) + + model_configs_fa4_mla = { # test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv) "fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64), From d177ecfbe9f32627b0f9943db248357e16d32d7f Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 28 May 2026 16:04:04 -0700 Subject: [PATCH 5/9] Refine the PyT D=256 tests Signed-off-by: Kshitij Lakhani --- tests/pytorch/attention/test_attention.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 242f03cd7f..0f0d6e87ff 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -387,18 +387,13 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model): model_configs_fused_hdim256 = { # test: ModelConfig(b, sq, hq, dqk) -> head_dim_v defaults to head_dim_qk (256) "fused_hd256_no_mask": ModelConfig(2, 512, 16, 256), - "fused_hd256_causal": ModelConfig(2, 512, 16, 256, attn_mask_type="causal"), "fused_hd256_padding": ModelConfig(2, 512, 16, 256, attn_mask_type="padding"), - "fused_hd256_padding_causal": ModelConfig(2, 512, 16, 256, attn_mask_type="padding_causal"), - "fused_hd256_padding_causal_br": ModelConfig( - 2, 512, 16, 256, attn_mask_type="padding_causal_bottom_right" - ), # SWA is allowed only together with a causal mask on the D=256 bprop kernel. "fused_hd256_causal_swa": ModelConfig( - 2, 512, 16, 256, attn_mask_type="causal", window_size=(128, 0) + 2, 1024, 16, 256, attn_mask_type="causal", window_size=(128, 0) ), # GQA variant (num_gqa_groups < num_heads). - "fused_hd256_gqa": ModelConfig(2, 512, 16, 256, num_gqa_groups=4, attn_mask_type="causal"), + "fused_hd256_padding_causal_gqa": ModelConfig(2, 1024, 16, 256, num_gqa_groups=4, attn_mask_type="padding_causal"), } From fe1bbd06049d08896b1158a95ea42a8436531a8f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 May 2026 23:06:28 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_fused_attn.py | 5 +---- tests/pytorch/attention/test_attention.py | 4 +++- transformer_engine/common/fused_attn/fused_attn.cpp | 8 ++++---- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 0192fb05b4..bbc45cea06 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -468,10 +468,7 @@ def _check_configs(self): unsupported = None if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: unsupported = "pre-scale bias" - elif ( - self.attn_bias_type != AttnBiasType.NO_BIAS - and self.bias_shape == BiasShape._1HSS - ): + elif self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: unsupported = ( "bias gradients (dBias); frozen/non-learnable bias inputs" " (i.e. non-1HSS bias shapes) are supported" diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 0f0d6e87ff..f2a8c9cf76 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -393,7 +393,9 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model): 2, 1024, 16, 256, attn_mask_type="causal", window_size=(128, 0) ), # GQA variant (num_gqa_groups < num_heads). - "fused_hd256_padding_causal_gqa": ModelConfig(2, 1024, 16, 256, num_gqa_groups=4, attn_mask_type="padding_causal"), + "fused_hd256_padding_causal_gqa": ModelConfig( + 2, 1024, 16, 256, num_gqa_groups=4, attn_mask_type="padding_causal" + ), } diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 27dd11ab5e..5e7517078a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -318,16 +318,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.23: d_qk = d_v = 256 + SM10.x + bprop + non-paged // cuDNN's dedicated SM10.x D=256 SDPA backward kernel (cuDNN FE 1.24 / // BE 9.23+) only supports d_qk == d_v == 256. - (head_dim_qk == 256 && head_dim_v == 256 && is_training && - sm_arch_ >= 100 && sm_arch_ < 110 && cudnn_runtime_version >= 92300 && + (head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 && + sm_arch_ < 110 && cudnn_runtime_version >= 92300 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && // The FE forces this path onto deterministic bprop, which then rejects alibi and // dropout, and only supports vanilla softmax. bias_type != NVTE_Bias_Type::NVTE_ALIBI && dropout == 0.0 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX && // Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks. - // NOTE: SWA support for non causal would be available when cuDNN decides to redirect - // D=256 support to a CUDA backend instead of a Python OSS kernel + // NOTE: SWA support for non causal would be available when cuDNN decides to redirect + // D=256 support to a CUDA backend instead of a Python OSS kernel ((window_size_left == -1 && window_size_right == -1) || ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || From 793ee2f79b7ee3ad4223d16116627ccaf1c96907 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 29 May 2026 15:17:27 -0700 Subject: [PATCH 7/9] Fix the filtering condition for bias type for D=256 on sm10x for cudnn fused attn Signed-off-by: Kshitij Lakhani --- transformer_engine/common/fused_attn/fused_attn.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 5e7517078a..a5a8145975 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -321,9 +321,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 && sm_arch_ < 110 && cudnn_runtime_version >= 92300 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && - // The FE forces this path onto deterministic bprop, which then rejects alibi and - // dropout, and only supports vanilla softmax. - bias_type != NVTE_Bias_Type::NVTE_ALIBI && dropout == 0.0 && + // The FE forces this path onto the deterministic bprop algorithm, which on + // Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only). + // Require NO_BIAS: a learnable pre/post-scale bias would request dBias in the bprop + // graph and fail validation. The selector has no visibility into the bias shape, so + // gate on the bias type to avoid over-selecting this backend. + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX && // Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks. // NOTE: SWA support for non causal would be available when cuDNN decides to redirect From 45b332f5827c810f6e80de6a8b3477ad26707d0b Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 29 May 2026 15:53:16 -0700 Subject: [PATCH 8/9] Code clean up Signed-off-by: Kshitij Lakhani --- tests/pytorch/attention/test_attention.py | 11 ++++------- transformer_engine/common/fused_attn/fused_attn.cpp | 9 +-------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index f2a8c9cf76..2a175a2ad2 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -377,13 +377,10 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) -# cuDNN FusedAttention head_dim=256 backward is supported on Blackwell server GPUs -# (SM100/SM103) from cuDNN 9.23 (FE 1.24), via the dedicated deterministic SDPA bprop -# kernel. It requires d_qk == d_v == 256, vanilla softmax, no dropout, no ALiBi, and -# (for non-causal masks) full-window attention. See nvte_get_fused_attn_backend in -# transformer_engine/common/fused_attn/fused_attn.cpp. These configs use d_qk == d_v == 256 -# with s_q == s_kv > 1 so the training (forward + backward) FusedAttention route is exercised -# and compared against the reference backends. +# cuDNN FusedAttention D=256 bprop is supported on sm10x from cuDNN 9.23 (FE 1.24), +# via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only, +# vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only. +# (for non-causal masks) full-window attention. model_configs_fused_hdim256 = { # test: ModelConfig(b, sq, hq, dqk) -> head_dim_v defaults to head_dim_qk (256) "fused_hd256_no_mask": ModelConfig(2, 512, 16, 256), diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index a5a8145975..d70ccf88c9 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -315,22 +315,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (head_dim_qk <= 256 && head_dim_v <= 256 && ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || - // 9.23: d_qk = d_v = 256 + SM10.x + bprop + non-paged - // cuDNN's dedicated SM10.x D=256 SDPA backward kernel (cuDNN FE 1.24 / - // BE 9.23+) only supports d_qk == d_v == 256. + // 9.23: d_qk = d_v = 256 + SM10x (cuDNN FE 1.24 / BE 9.23+) + bprop + non-paged (head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 && sm_arch_ < 110 && cudnn_runtime_version >= 92300 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && // The FE forces this path onto the deterministic bprop algorithm, which on // Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only). - // Require NO_BIAS: a learnable pre/post-scale bias would request dBias in the bprop - // graph and fail validation. The selector has no visibility into the bias shape, so - // gate on the bias type to avoid over-selecting this backend. bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX && // Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks. - // NOTE: SWA support for non causal would be available when cuDNN decides to redirect - // D=256 support to a CUDA backend instead of a Python OSS kernel ((window_size_left == -1 && window_size_right == -1) || ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || From e317f9940a5f4a53e3f1c12213ded142a27f3b09 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 May 2026 22:54:12 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 2a175a2ad2..b5201e81b9 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -378,7 +378,7 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model): # cuDNN FusedAttention D=256 bprop is supported on sm10x from cuDNN 9.23 (FE 1.24), -# via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only, +# via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only, # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only. # (for non-causal masks) full-window attention. model_configs_fused_hdim256 = {