Skip to content
65 changes: 65 additions & 0 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,56 @@ def _check_configs(self):
"is either BSHD_BSHD_BSHD or THD_THD_THD"
)

# D=256 bprop on SM10.x uses cuDNN's dedicated SDPA bprop kernel
# (cuDNN FE 1.24 / BE 9.23+). FE forces this path onto the deterministic algorithm path,
# which rejects dBias, dropout, and ALiBi. It supports vanilla softmax only and allows SWA
# together with a causal mask only.
compute_capability = get_device_compute_capability(0)
is_sm10x = 100 <= compute_capability < 110
if self.is_training and is_sm10x and (self.head_dim_qk == 256 or self.head_dim_v == 256):
if self.head_dim_qk != 256 or self.head_dim_v != 256:
pytest.skip(
"D=256 BWD on Blackwell only supports d_qk == d_v == 256;"
f" got d_qk={self.head_dim_qk}, d_v={self.head_dim_v}."
)
cudnn_version = get_cudnn_version()
if cudnn_version < 92300:
pytest.skip(
"D=256 BWD on Blackwell requires cuDNN 9.23 or newer;"
f" got cuDNN {cudnn_version}."
)
# Non-learnable bias is fine (bias is allowed as an input); only dBias is
# unsupported. The JAX runner asks for dBias iff the bias shape is [1, h, s, s]
# (see test_backward), so gate on that.
unsupported = None
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
unsupported = "pre-scale bias"
elif self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
unsupported = (
"bias gradients (dBias); frozen/non-learnable bias inputs"
" (i.e. non-1HSS bias shapes) are supported"
)
Comment on lines +465 to +475
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 JAX skip logic diverges from C++ backend gate for non-1HSS bias

The comment says "frozen/non-learnable bias inputs (i.e. non-1HSS bias shapes) are supported" and the skip block deliberately allows those configs to proceed. However, the C++ gate in fused_attn.cpp requires bias_type == NVTE_NO_BIAS for the new D=256 BWD path, meaning any config with attn_bias_type != NO_BIAS && bias_shape != _1HSS will silently fall back to a different backend rather than exercising the new kernel. The test will not fail, but it also will not validate the D=256 BWD path for those configs, and the inline comment creates a misleading expectation that such configs are actually routed through it.

elif self.dropout_prob != 0.0:
unsupported = "dropout"
elif self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
unsupported = "non-vanilla softmax"
if unsupported is not None:
pytest.skip(
"D=256 BWD on Blackwell uses the deterministic SM100 D=256 SDPA BWD"
f" kernel which does not support {unsupported}."
)
if self.window_size is not None and self.window_size != (-1, -1):
if not self.attn_mask_type.is_causal():
pytest.skip(
"D=256 BWD on Blackwell uses the SM100 D=256 SDPA BWD kernel"
" which requires window_size=(-1, -1) for non-causal masks."
)
if self.window_size[1] not in (-1, 0):
pytest.skip(
"D=256 BWD on Blackwell only supports right window -1 or 0"
" for causal masks."
)

self.backend = FusedAttnHelper(
self.is_training,
self.dtype,
Expand Down Expand Up @@ -1474,6 +1524,21 @@ def test_backward(
QKVLayout.THD_THD_THD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE",
),
# D=256 deterministic backward on the SM100 dedicated SDPA bprop kernel
# (cuDNN FE 1.24 / BE 9.23+). Unsupported configs (e.g. dBias, non-256 head dims)
# are skipped by FusedAttnRunner._check_configs.
pytest.param(
4,
128,
128,
16,
16,
256,
256,
jnp.float16,
QKVLayout.BSHD_BS2HD,
id="4-128-128-16-16-256-256-FP16-SELF-KV_PACKED",
),
],
)
@pytest.mark.parametrize(
Expand Down
32 changes: 32 additions & 0 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,38 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


# cuDNN FusedAttention D=256 bprop is supported on sm10x from cuDNN 9.23 (FE 1.24),
# via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only,
# vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.
# (for non-causal masks) full-window attention.
Comment on lines +382 to +383
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Duplicated comment fragment

The comment block ends with a repeated phrase: line 383 (# (for non-causal masks) full-window attention.) is a verbatim fragment of line 382, left over from editing. It should be removed.

Suggested change
# vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.
# (for non-causal masks) full-window attention.
# vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

model_configs_fused_hdim256 = {
# test: ModelConfig(b, sq, hq, dqk) -> head_dim_v defaults to head_dim_qk (256)
"fused_hd256_no_mask": ModelConfig(2, 512, 16, 256),
"fused_hd256_padding": ModelConfig(2, 512, 16, 256, attn_mask_type="padding"),
# SWA is allowed only together with a causal mask on the D=256 bprop kernel.
"fused_hd256_causal_swa": ModelConfig(
2, 1024, 16, 256, attn_mask_type="causal", window_size=(128, 0)
),
# GQA variant (num_gqa_groups < num_heads).
"fused_hd256_padding_causal_gqa": ModelConfig(
2, 1024, 16, 256, num_gqa_groups=4, attn_mask_type="padding_causal"
),
}


@pytest.mark.skipif(get_cudnn_version() < (9, 23, 0), reason="cuDNN 9.23+ is required.")
@pytest.mark.skipif(
device_compute_capability not in ((10, 0), (10, 3)),
reason="cuDNN FusedAttention head_dim=256 backward is Blackwell server (SM100/SM103) only.",
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_fused_hdim256])
@pytest.mark.parametrize("model", model_configs_fused_hdim256.keys())
def test_dpa_fused_attn_hdim256(dtype, model_configs, model):
"""Test DotProductAttention with cuDNN FusedAttention: head_dim=256 backward on Blackwell"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_fa4_mla = {
# test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv)
"fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64),
Expand Down
15 changes: 15 additions & 0 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(head_dim_qk <= 256 && head_dim_v <= 256 &&
((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
(is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
// 9.23: d_qk = d_v = 256 + SM10x (cuDNN FE 1.24 / BE 9.23+) + bprop + non-paged
(head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 &&
sm_arch_ < 110 && cudnn_runtime_version >= 92300 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD &&
// The FE forces this path onto the deterministic bprop algorithm, which on
// Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only).
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 &&
softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX &&
// Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks.
((window_size_left == -1 && window_size_right == -1) ||
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
(window_size_right == -1 || window_size_right == 0)))) ||
// 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
(!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
Expand Down
Loading