From fa2e2cb7aa161a1c65304ba7c35007c769659f48 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 20 Feb 2026 18:29:43 +0000 Subject: [PATCH 1/4] Enable sm120 support for fused attn if cuDNN is 9.18.1+ Signed-off-by: Kshitij Lakhani --- .../pytorch/attention/dot_product_attention/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..9646fed07e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -610,6 +610,7 @@ def get_attention_backend( qkv_layout, ) use_fused_attention = False + #TODO: KL check if this condition is now supported or not ? if ( device_compute_capability == (12, 0) and (head_dim_qk > 128 or head_dim_qk % 8 != 0) @@ -690,11 +691,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False - if device_compute_capability == (12, 0): + if device_compute_capability == (12, 0) and cudnn_version < (9, 18, 1): if use_fused_attention: logger.debug( "Disabling FusedAttention as qkv_format = thd is" - " not supported for compute capability = sm120" + " not supported for compute capability = sm120 and cuDNN version < 9.18.1" ) use_fused_attention = False From bea8bbbdf061cf3075cb645c740b98333326bd6a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:42:10 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/dot_product_attention/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9646fed07e..82d1d1b2a6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -610,7 +610,7 @@ def get_attention_backend( qkv_layout, ) use_fused_attention = False - #TODO: KL check if this condition is now supported or not ? + # TODO: KL check if this condition is now supported or not ? if ( device_compute_capability == (12, 0) and (head_dim_qk > 128 or head_dim_qk % 8 != 0) From b2f5864b19ab5236aa4ffebb24fb081dbe187ab8 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Mon, 2 Mar 2026 15:25:48 -0800 Subject: [PATCH 3/4] Force intermediate tensors such as S, Sum_Exp, and Max to be BHS1 shape instead of TH1 for sm120 Signed-off-by: Kshitij Lakhani --- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 9 ++++++--- transformer_engine/common/transformer_engine.cpp | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index eb2ebcff39..b5a12df803 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1102,6 +1102,9 @@ void fused_attn_arbitrary_seqlen_fwd( devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; } + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); + void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; @@ -1128,7 +1131,7 @@ void fused_attn_arbitrary_seqlen_fwd( if (return_max_logit) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Max->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { output_Max->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1136,7 +1139,7 @@ void fused_attn_arbitrary_seqlen_fwd( output_Max->data.dtype = DType::kFloat32; Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Sum_Exp->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1145,7 +1148,7 @@ void fused_attn_arbitrary_seqlen_fwd( } else { Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { output_S->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index cd02074fbd..763b3a1673 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1130,6 +1130,7 @@ int nvte_is_non_tn_fp8_gemm_supported() { static std::vector cache(num_devices, -1); static std::vector flags(num_devices); int device_id = transformer_engine::cuda::current_device(); + // TODO: KL check if this condition is now supported or not ? std::call_once(flags[device_id], [&]() { int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id); // Note: this is temporary restriction and should be lifted in the future. From 076420d69f7f815844547cf958a318dd4e72a02c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 23:32:31 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index b5a12df803..7c55654228 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1131,7 +1131,8 @@ void fused_attn_arbitrary_seqlen_fwd( if (return_max_logit) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Max->data.dptr = nullptr; - if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_Max->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1139,7 +1140,8 @@ void fused_attn_arbitrary_seqlen_fwd( output_Max->data.dtype = DType::kFloat32; Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Sum_Exp->data.dptr = nullptr; - if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1148,7 +1150,8 @@ void fused_attn_arbitrary_seqlen_fwd( } else { Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; - if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_S->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};