diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index eb2ebcff39..7c55654228 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1102,6 +1102,9 @@ void fused_attn_arbitrary_seqlen_fwd( devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; } + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); + void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; @@ -1128,7 +1131,8 @@ void fused_attn_arbitrary_seqlen_fwd( if (return_max_logit) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Max->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_Max->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1136,7 +1140,8 @@ void fused_attn_arbitrary_seqlen_fwd( output_Max->data.dtype = DType::kFloat32; Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Sum_Exp->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1145,7 +1150,8 @@ void fused_attn_arbitrary_seqlen_fwd( } else { Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; - if (q_format == 
NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_S->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index cd02074fbd..763b3a1673 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1130,6 +1130,7 @@ int nvte_is_non_tn_fp8_gemm_supported() { static std::vector cache(num_devices, -1); static std::vector flags(num_devices); int device_id = transformer_engine::cuda::current_device(); + // TODO(KL): Check whether this restriction is still required on current hardware/cuDNN. std::call_once(flags[device_id], [&]() { int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id); // Note: this is temporary restriction and should be lifted in the future. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..82d1d1b2a6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -610,6 +610,7 @@ def get_attention_backend( qkv_layout, ) use_fused_attention = False + # TODO(KL): Check whether this restriction is still required on current hardware/cuDNN. if ( device_compute_capability == (12, 0) and (head_dim_qk > 128 or head_dim_qk % 8 != 0) @@ -690,11 +691,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "padding between sequences, i.e. 
[a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False - if device_compute_capability == (12, 0): + if device_compute_capability == (12, 0) and cudnn_version < (9, 18, 1): if use_fused_attention: logger.debug( "Disabling FusedAttention as qkv_format = thd is" - " not supported for compute capability = sm120" + " not supported for compute capability = sm120 and cuDNN version < 9.18.1" ) use_fused_attention = False