From d0a3a09dd0115dc5b9184462d813ba38c0d176ff Mon Sep 17 00:00:00 2001 From: suss <1152623206@qq.com> Date: Sun, 8 Mar 2026 16:37:54 +0000 Subject: [PATCH 1/2] add mha_kvcache --- csrc/models/llama/llama_attention.cpp | 42 +++++++++++++++++++++------ 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index fe76479b..ec40143d 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -4,6 +4,7 @@ #include "infinicore/nn/linear.hpp" #include "infinicore/nn/rope.hpp" #include "infinicore/ops.hpp" +#include "infinicore/ops/mha_kvcache.hpp" #include "infinicore/ops/mha_varlen.hpp" #include "infinicore/ops/mul.hpp" @@ -331,16 +332,39 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd scaling_); } } else { - infinicore::op::paged_attention_( - attn_output, - q_reshaped, - k_total, - v_total, - block_tables.value(), - total_sequence_lengths.value(), - std::nullopt, - scaling_); + if (attention_backend_ == backends::AttentionBackend::FlashAttn) { + // FA2 decode path: flash::mha_fwd_kvcache + // In paged-attn mode, seq_len = actual batch_size (one query token per sequence). + // q_reshaped: [seq_len, num_heads, head_dim] → [seq_len, 1, num_heads, head_dim] + // k/v cache: [num_blocks, num_kv_heads, block_size, head_dim] + // → permute {0,2,1,3} → [num_blocks, block_size, num_kv_heads, head_dim] + auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_}); + auto attn_out_4d = infinicore::Tensor::empty( + {seq_len, 1, num_attention_heads_, head_dim_}, + q_reshaped->dtype(), q_reshaped->device()); + infinicore::op::mha_kvcache_( + attn_out_4d, + q_for_fa, + k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim] + v_total->permute({0, 2, 1, 3}), + total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence) + block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32 + std::nullopt, + scaling_); + attn_output = attn_out_4d->view({seq_len, num_attention_heads_, head_dim_}); + } else { + infinicore::op::paged_attention_( + attn_output, + q_reshaped, + k_total, + v_total, + block_tables.value(), + total_sequence_lengths.value(), + std::nullopt, + scaling_); + } } + // 7. Project output attn_output From b98aab17e1abe121df9ad245ab3cc23b8f339e95 Mon Sep 17 00:00:00 2001 From: suss <1152623206@qq.com> Date: Mon, 9 Mar 2026 11:58:44 +0000 Subject: [PATCH 2/2] repair gqa-api bug --- csrc/models/llama/llama_attention.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index ec40143d..3bad0c9a 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -339,11 +339,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd // k/v cache: [num_blocks, num_kv_heads, block_size, head_dim] // → permute {0,2,1,3} → [num_blocks, block_size, num_kv_heads, head_dim] auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_}); - auto attn_out_4d = infinicore::Tensor::empty( - {seq_len, 1, num_attention_heads_, head_dim_}, - q_reshaped->dtype(), q_reshaped->device()); - infinicore::op::mha_kvcache_( - attn_out_4d, + auto attn_out_4d = infinicore::op::mha_kvcache( q_for_fa, k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim] v_total->permute({0, 2, 1, 3}),