diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 1fb0108068..bbc45cea06 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -444,6 +444,56 @@ def _check_configs(self): "is either BSHD_BSHD_BSHD or THD_THD_THD" ) + # D=256 bprop on SM10.x uses cuDNN's dedicated SDPA bprop kernel + # (cuDNN FE 1.24 / BE 9.23+). FE forces this path onto the deterministic algorithm path, + # which rejects dBias, dropout, and ALiBi. It supports vanilla softmax only and allows SWA + # together with a causal mask only. + compute_capability = get_device_compute_capability(0) + is_sm10x = 100 <= compute_capability < 110 + if self.is_training and is_sm10x and (self.head_dim_qk == 256 or self.head_dim_v == 256): + if self.head_dim_qk != 256 or self.head_dim_v != 256: + pytest.skip( + "D=256 BWD on Blackwell only supports d_qk == d_v == 256;" + f" got d_qk={self.head_dim_qk}, d_v={self.head_dim_v}." + ) + cudnn_version = get_cudnn_version() + if cudnn_version < 92300: + pytest.skip( + "D=256 BWD on Blackwell requires cuDNN 9.23 or newer;" + f" got cuDNN {cudnn_version}." + ) + # Non-learnable bias is fine (bias is allowed as an input); only dBias is + # unsupported. The JAX runner asks for dBias iff the bias shape is [1, h, s, s] + # (see test_backward), so gate on that. + unsupported = None + if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: + unsupported = "pre-scale bias" + elif self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: + unsupported = ( + "bias gradients (dBias); frozen/non-learnable bias inputs" + " (i.e. non-1HSS bias shapes) are supported" + ) + elif self.dropout_prob != 0.0: + unsupported = "dropout" + elif self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: + unsupported = "non-vanilla softmax" + if unsupported is not None: + pytest.skip( + "D=256 BWD on Blackwell uses the deterministic SM100 D=256 SDPA BWD" + f" kernel which does not support {unsupported}." 
+ ) + if self.window_size is not None and self.window_size != (-1, -1): + if not self.attn_mask_type.is_causal(): + pytest.skip( + "D=256 BWD on Blackwell uses the SM100 D=256 SDPA BWD kernel" + " which requires window_size=(-1, -1) for non-causal masks." + ) + if self.window_size[1] not in (-1, 0): + pytest.skip( + "D=256 BWD on Blackwell only supports right window -1 or 0" + " for causal masks." + ) + self.backend = FusedAttnHelper( self.is_training, self.dtype, @@ -1474,6 +1524,21 @@ def test_backward( QKVLayout.THD_THD_THD, id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE", ), + # D=256 deterministic backward on the SM100 dedicated SDPA bprop kernel + # (cuDNN FE 1.24 / BE 9.23+). Unsupported configs (e.g. dBias, non-256 head dims) + # are skipped by FusedAttnRunner._check_configs. + pytest.param( + 4, + 128, + 128, + 16, + 16, + 256, + 256, + jnp.float16, + QKVLayout.BSHD_BS2HD, + id="4-128-128-16-16-256-256-FP16-SELF-KV_PACKED", + ), ], ) @pytest.mark.parametrize( diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 401bd6f01d..b5201e81b9 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -377,6 +377,38 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) +# cuDNN FusedAttention D=256 bprop is supported on sm10x from cuDNN 9.23 (FE 1.24), +# via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only, +# vanilla softmax only, no dropout, no ALiBi, and, for non-causal masks, full-window attention +# only (sliding-window attention is allowed only together with a causal mask). 
+model_configs_fused_hdim256 = { + # test: ModelConfig(b, sq, hq, dqk) -> head_dim_v defaults to head_dim_qk (256) + "fused_hd256_no_mask": ModelConfig(2, 512, 16, 256), + "fused_hd256_padding": ModelConfig(2, 512, 16, 256, attn_mask_type="padding"), + # SWA is allowed only together with a causal mask on the D=256 bprop kernel. + "fused_hd256_causal_swa": ModelConfig( + 2, 1024, 16, 256, attn_mask_type="causal", window_size=(128, 0) + ), + # GQA variant (num_gqa_groups < num_heads). + "fused_hd256_padding_causal_gqa": ModelConfig( + 2, 1024, 16, 256, num_gqa_groups=4, attn_mask_type="padding_causal" + ), +} + + +@pytest.mark.skipif(get_cudnn_version() < (9, 23, 0), reason="cuDNN 9.23+ is required.") +@pytest.mark.skipif( + device_compute_capability not in ((10, 0), (10, 3)), + reason="cuDNN FusedAttention head_dim=256 backward is Blackwell server (SM100/SM103) only.", +) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model_configs", [model_configs_fused_hdim256]) +@pytest.mark.parametrize("model", model_configs_fused_hdim256.keys()) +def test_dpa_fused_attn_hdim256(dtype, model_configs, model): + """Test DotProductAttention with cuDNN FusedAttention: head_dim=256 backward on Blackwell""" + test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) + + model_configs_fa4_mla = { # test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv) "fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64), diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index d2eb1a831c..d70ccf88c9 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -315,6 +315,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (head_dim_qk <= 256 && head_dim_v <= 256 && ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || + 
// 9.23: d_qk = d_v = 256 + SM10x (cuDNN FE 1.24 / BE 9.23+) + bprop + non-paged + (head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 && + sm_arch_ < 110 && cudnn_runtime_version >= 92300 && + layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && + // The FE forces this path onto the deterministic bprop algorithm, which on + // Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only). + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 && + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX && + // Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks. + ((window_size_left == -1 && window_size_right == -1) || + ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + (window_size_right == -1 || window_size_right == 0)))) || // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||