diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index fe76479b..3bad0c9a 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -4,6 +4,7 @@ #include "infinicore/nn/linear.hpp" #include "infinicore/nn/rope.hpp" #include "infinicore/ops.hpp" +#include "infinicore/ops/mha_kvcache.hpp" #include "infinicore/ops/mha_varlen.hpp" #include "infinicore/ops/mul.hpp" @@ -331,16 +332,35 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd scaling_); } } else { - infinicore::op::paged_attention_( - attn_output, - q_reshaped, - k_total, - v_total, - block_tables.value(), - total_sequence_lengths.value(), - std::nullopt, - scaling_); + if (attention_backend_ == backends::AttentionBackend::FlashAttn) { + // FA2 decode path: flash::mha_fwd_kvcache + // In paged-attn mode, seq_len = actual batch_size (one query token per sequence). + // q_reshaped: [seq_len, num_heads, head_dim] → [seq_len, 1, num_heads, head_dim] + // k/v cache: [num_blocks, num_kv_heads, block_size, head_dim] + // → permute {0,2,1,3} → [num_blocks, block_size, num_kv_heads, head_dim] + auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_}); + auto attn_out_4d = infinicore::op::mha_kvcache( + q_for_fa, + k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim] + v_total->permute({0, 2, 1, 3}), + total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence) + block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32 + std::nullopt, + scaling_); + attn_output = attn_out_4d->view({seq_len, num_attention_heads_, head_dim_}); + } else { + infinicore::op::paged_attention_( + attn_output, + q_reshaped, + k_total, + v_total, + block_tables.value(), + total_sequence_lengths.value(), + std::nullopt, + scaling_); + } } + // 7. Project output attn_output