From 7f426b26cbc91cb57896b48d5df3cb0657acd435 Mon Sep 17 00:00:00 2001
From: zhushuang <974198603@qq.com>
Date: Mon, 9 Mar 2026 14:35:18 +0800
Subject: [PATCH] issue/1061 - feat: use template to replace int64_t in
 paged_attention_prefill kernel for moore gpu

---
 .../moore/paged_attention_prefill_kernel.h    | 16 ++++----
 .../moore/paged_attention_prefill_moore.mu    | 39 ++++++++++++-------
 2 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_kernel.h b/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_kernel.h
index b4431b953..0569fd142 100644
--- a/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_kernel.h
+++ b/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_kernel.h
@@ -1,7 +1,9 @@
 #ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
 #define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
 namespace op::paged_attention_prefill::cuda {
-__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cum_seq_lens_q, size_t num_seqs) {
+
+template <typename Tindex>
+__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const Tindex *cum_seq_lens_q, size_t num_seqs) {
     size_t low = 0, high = num_seqs - 1;
     while (low <= high) {
         size_t mid = (low + high) >> 1;
@@ -48,12 +50,12 @@ __device__ __forceinline__ float blockReduceSum(float val) {
     return shared[0];
 }
 
-template <typename Tdata, typename Tcompute>
+template <typename Tindex, typename Tdata, typename Tcompute>
 __global__ void pagedAttentionPrefillKernel(
     Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
-    const int64_t *block_tables_,
-    const int64_t *total_kv_lens_,
-    const int64_t *cum_seq_lens_q_,
+    const Tindex *block_tables_,
+    const Tindex *total_kv_lens_,
+    const Tindex *cum_seq_lens_q_,
     const float *alibi_slopes_,
     const size_t num_heads, const size_t num_kv_heads, const float scale,
     const size_t max_num_blocks_per_seq, const size_t block_size,
@@ -75,7 +77,7 @@ __global__ void pagedAttentionPrefillKernel(
     __shared__ float sh_w;
     __shared__ float sh_inv_l;
     if (dim_idx == 0) {
-        sh_seq_idx = find_seq_id(global_token_idx, cum_seq_lens_q_, num_seqs);
+        sh_seq_idx = find_seq_id<Tindex>(global_token_idx, cum_seq_lens_q_, num_seqs);
         const size_t q_token_idx = global_token_idx - static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx]);
         const size_t total_kv_len = static_cast<size_t>(total_kv_lens_[sh_seq_idx]);
         const size_t q_len = static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx + 1] - cum_seq_lens_q_[sh_seq_idx]);
@@ -90,7 +92,7 @@ __global__ void pagedAttentionPrefillKernel(
     const size_t kv_head_idx = sh_kv_head_idx;
     const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
     Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
-    const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
+    const Tindex *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
     const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
     const float qv = static_cast<float>(q_vec[dim_idx]);
     Tcompute acc = 0.0f;
diff --git a/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu b/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu
index 8927b4d73..0ad2e1a51 100644
--- a/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu
+++ b/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu
@@ -8,12 +8,12 @@
 #include "paged_attention_prefill_kernel.h"
 #include "paged_attention_prefill_moore.h"
 
-template <typename Tdata, typename Tcompute>
+template <typename Tindex, typename Tdata, typename Tcompute>
 infiniStatus_t launchPagedAttentionPrefill(
     Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
-    const int64_t *block_tables,
-    const int64_t *seq_lens,
-    const int64_t *cum_seq_lens_q,
+    const Tindex *block_tables,
+    const Tindex *seq_lens,
+    const Tindex *cum_seq_lens_q,
     const float *alibi_slopes,
     const size_t num_heads,
     const size_t num_seqs,
@@ -36,7 +36,7 @@ infiniStatus_t launchPagedAttentionPrefill(
     dim3 grid(total_q_tokens, num_heads);
     dim3 block(head_size);
 
-    op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
+    op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tindex, Tdata, Tcompute>
         <<<grid, block, 0, stream>>>(
             out, q, k_cache, v_cache,
             block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
@@ -99,10 +99,10 @@ infiniStatus_t Descriptor::calculate(
 
     musaStream_t stream = (musaStream_t)stream_;
 
-#define LAUNCH_KERNEL(Tdata, Tcompute)                                                             \
-    launchPagedAttentionPrefill<Tdata, Tcompute>(                                                  \
+#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute)                                                   \
+    return launchPagedAttentionPrefill<Tindex, Tdata, Tcompute>(                                   \
         (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache,            \
-        (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
+        static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(seq_lens), static_cast<const Tindex *>(cum_seq_lens_q), \
         (const float *)alibi_slopes,                                                               \
         _info.num_heads, _info.num_seqs, _info.num_kv_heads,                                       \
         _info.scale, _info.max_num_blocks_per_seq,                                                 \
@@ -112,12 +112,26 @@ infiniStatus_t Descriptor::calculate(
         _info.q_stride, _info.q_head_stride,                                                       \
         stream)
 
-    if (_info.dtype == INFINI_DTYPE_F16) {
-        return LAUNCH_KERNEL(half, float);
-    } else if (_info.dtype == INFINI_DTYPE_BF16) {
-        return LAUNCH_KERNEL(__mt_bfloat16, float);
-    } else if (_info.dtype == INFINI_DTYPE_F32) {
-        return LAUNCH_KERNEL(float, float);
+#define DISPATCH_INDEX(Tindex)                             \
+    do {                                                   \
+        if (_info.dtype == INFINI_DTYPE_F16) {             \
+            DISPATCH_KERNEL(Tindex, half, float);          \
+        }                                                  \
+        if (_info.dtype == INFINI_DTYPE_BF16) {            \
+            DISPATCH_KERNEL(Tindex, __mt_bfloat16, float); \
+        }                                                  \
+        if (_info.dtype == INFINI_DTYPE_F32) {             \
+            DISPATCH_KERNEL(Tindex, float, float);         \
+        }                                                  \
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;             \
+    } while (false)
+
+    if (_info.index_dtype == INFINI_DTYPE_I64) {
+        DISPATCH_INDEX(int64_t);
+    } else if (_info.index_dtype == INFINI_DTYPE_I32) {
+        DISPATCH_INDEX(int32_t);
+    } else if (_info.index_dtype == INFINI_DTYPE_U32) {
+        DISPATCH_INDEX(uint32_t);
     }
 
     return INFINI_STATUS_BAD_TENSOR_DTYPE;