@@ -1102,6 +1102,9 @@ void fused_attn_arbitrary_seqlen_fwd(
     devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
   }
 
+  const int device_id = cuda::current_device();
+  const int sm_arch_ = cuda::sm_arch(device_id);
+
   void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
   void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
   void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
@@ -1128,15 +1131,17 @@ void fused_attn_arbitrary_seqlen_fwd(
     if (return_max_logit) {
       Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
       output_Max->data.dptr = nullptr;
-      if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
+      if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
+          !(sm_arch_ >= 120)) {
         output_Max->data.shape = {num_tokens_q, num_attn_heads, 1};
       } else {
         output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
       }
       output_Max->data.dtype = DType::kFloat32;
       Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
       output_Sum_Exp->data.dptr = nullptr;
-      if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
+      if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
+          !(sm_arch_ >= 120)) {
         output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1};
       } else {
         output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
@@ -1145,7 +1150,8 @@ void fused_attn_arbitrary_seqlen_fwd(
     } else {
       Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
       output_S->data.dptr = nullptr;
-      if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
+      if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
+          !(sm_arch_ >= 120)) {
         output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
       } else {
         output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
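
For reference, a minimal standalone sketch of the softmax-stats shape rule introduced by the hunks above, written in Python for illustration only (the function name and the string-valued q_format are hypothetical): packed per-token THD stats require cuDNN >= 9.6.0 and, with this change, are additionally skipped on sm120 and newer, which fall back to the padded per-sequence layout.

# Hypothetical sketch only -- not Transformer Engine code; mirrors the C++ logic above.
def pick_softmax_stats_shape(q_format, cudnn_runtime_version, sm_arch,
                             num_tokens_q, batch, num_attn_heads, max_seqlen_q):
    # Packed per-token stats: THD layout, cuDNN >= 9.6.0, and (new) not sm120+.
    if q_format == "thd" and cudnn_runtime_version >= 90600 and not sm_arch >= 120:
        return (num_tokens_q, num_attn_heads, 1)
    # Otherwise fall back to the padded per-sequence layout.
    return (batch, num_attn_heads, max_seqlen_q, 1)

# e.g. on sm120 the padded layout is chosen even for THD with a recent cuDNN:
assert pick_softmax_stats_shape("thd", 91300, 120, 4096, 8, 16, 512) == (8, 16, 512, 1)
assert pick_softmax_stats_shape("thd", 91300, 90, 4096, 8, 16, 512) == (4096, 16, 1)
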
transformer_engine/common/transformer_engine.cpp (1 addition, 0 deletions)
@@ -1130,6 +1130,7 @@ int nvte_is_non_tn_fp8_gemm_supported() {
   static std::vector<int> cache(num_devices, -1);
   static std::vector<std::once_flag> flags(num_devices);
   int device_id = transformer_engine::cuda::current_device();
+  // TODO(KL): Check whether this condition is now supported.
   std::call_once(flags[device_id], [&]() {
     int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id);
     // Note: this is temporary restriction and should be lifted in the future.
@@ -610,6 +610,7 @@ def get_attention_backend(
             qkv_layout,
         )
         use_fused_attention = False
+    # TODO(KL): Check whether this condition is now supported.
     if (
         device_compute_capability == (12, 0)
         and (head_dim_qk > 128 or head_dim_qk % 8 != 0)
@@ -690,11 +691,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
)
use_flash_attention = False
if device_compute_capability == (12, 0):
if device_compute_capability == (12, 0) and cudnn_version < (9, 18, 1):
if use_fused_attention:
logger.debug(
"Disabling FusedAttention as qkv_format = thd is"
" not supported for compute capability = sm120"
" not supported for compute capability = sm120 and cuDNN version < 9.18.1"
)
use_fused_attention = False

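
As a reading aid, a minimal sketch of the behavioral change in the last hunk, assuming (as the log message implies, though the surrounding qkv_format == "thd" check is not visible in the hunk) that this branch applies to THD inputs: fused attention on sm120 is now disabled only when cuDNN is older than 9.18.1. The helper name below is hypothetical.

# Hypothetical sketch only -- not Transformer Engine code.
def thd_fused_attn_disabled_on_sm120(device_compute_capability, cudnn_version):
    # Before this change the gate was `device_compute_capability == (12, 0)` alone;
    # cuDNN 9.18.1+ is assumed to add THD support on sm120, so the gate is relaxed.
    return device_compute_capability == (12, 0) and cudnn_version < (9, 18, 1)

assert thd_fused_attn_disabled_on_sm120((12, 0), (9, 17, 0))      # still disabled
assert not thd_fused_attn_disabled_on_sm120((12, 0), (9, 18, 1))  # now allowed
assert not thd_fused_attn_disabled_on_sm120((9, 0), (9, 17, 0))   # other archs unaffected
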