From 665f383b49e4ab79901acb091e6bb5396964142b Mon Sep 17 00:00:00 2001 From: suss <1152623206@qq.com> Date: Sun, 8 Mar 2026 16:36:05 +0000 Subject: [PATCH 1/3] issue/1065 - add mha_kvcache --- include/infinicore/ops/mha_kvcache.hpp | 51 +++++++++++ src/infinicore/ops/mha_kvcache/mha_kvcache.cc | 58 +++++++++++++ .../ops/mha_kvcache/mha_kvcache_flashattn.cc | 85 +++++++++++++++++++ 3 files changed, 194 insertions(+) create mode 100644 include/infinicore/ops/mha_kvcache.hpp create mode 100644 src/infinicore/ops/mha_kvcache/mha_kvcache.cc create mode 100644 src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc diff --git a/include/infinicore/ops/mha_kvcache.hpp b/include/infinicore/ops/mha_kvcache.hpp new file mode 100644 index 000000000..2769e4e39 --- /dev/null +++ b/include/infinicore/ops/mha_kvcache.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { + +// Flash Attention KV-cache decode op. +// +// Wraps flash::mha_fwd_kvcache for single-step (decode) attention over a +// paged KV cache. +// +// Tensor shapes: +// out : [batch_size, seqlen_q, num_heads, head_size] +// q : [batch_size, seqlen_q, num_heads, head_size] +// k_cache : [num_blocks, block_size, num_heads_k, head_size] (paged layout) +// v_cache : [num_blocks, block_size, num_heads_k, head_size] (paged layout) +// seqlens_k : [batch_size] int32 — total KV length per request +// block_table : [batch_size, max_num_blocks_per_seq] int32 + +INFINICORE_GRAPH_OP_CLASS( + MhaKVCache, + Tensor, // out + const Tensor &, // q + const Tensor &, // k_cache + const Tensor &, // v_cache + const Tensor &, // seqlens_k + const Tensor &, // block_table + std::optional, // alibi_slopes + float); // scale + +Tensor mha_kvcache(const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &seqlens_k, + const Tensor &block_table, + std::optional alibi_slopes, + float scale); + +void mha_kvcache_(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &seqlens_k, + const Tensor &block_table, + std::optional alibi_slopes, + float scale); + +} // namespace infinicore::op diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache.cc new file mode 100644 index 000000000..0c5b3ae8c --- /dev/null +++ b/src/infinicore/ops/mha_kvcache/mha_kvcache.cc @@ -0,0 +1,58 @@ +#include "infinicore/ops/mha_kvcache.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MhaKVCache); + +MhaKVCache::MhaKVCache(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &seqlens_k, + const Tensor &block_table, + std::optional alibi_slopes, + float scale) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, seqlens_k, block_table); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale); +} + +void MhaKVCache::execute(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &seqlens_k, + const Tensor &block_table, + std::optional alibi_slopes, + float scale) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN( + MhaKVCache, + out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale); +} + +void mha_kvcache_(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &seqlens_k, + const Tensor &block_table, + std::optional alibi_slopes, + float scale) { + MhaKVCache::execute(out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale); +} + +Tensor mha_kvcache(const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &seqlens_k, + const Tensor &block_table, + std::optional alibi_slopes, + float scale) { + // Output shape matches q: [batch_size, seqlen_q, num_heads, head_size] + auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); + mha_kvcache_(out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale); + return out; +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc new file mode 100644 index 000000000..d74fdbb00 --- /dev/null +++ b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc @@ -0,0 +1,85 @@ +#include "infinicore/ops/mha_kvcache.hpp" + +#include "infinicore/adaptor/flash_attention_adaptor.hpp" + +namespace infinicore::op::mha_kvcache_impl::flashattn { + +struct PlannedMeta { + graph::GraphTensor out, q, k_cache, v_cache, seqlens_k, block_table; + std::optional alibi_slopes; + float scale; +}; + +void *plan(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &seqlens_k, + const Tensor &block_table, + std::optional alibi_slopes, + float scale) { + return new PlannedMeta{ + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), + graph::GraphTensor(seqlens_k), + graph::GraphTensor(block_table), + alibi_slopes ? std::optional(graph::GraphTensor(*alibi_slopes)) : std::nullopt, + scale}; +} + +void run(void *planned_meta) { + c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); + auto *p = reinterpret_cast(planned_meta); + + auto out = std::optional(infinicore::adaptor::to_aten_tensor(p->out)); + auto q = infinicore::adaptor::to_aten_tensor(p->q); + auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache); + auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache); + auto seqlens_k = std::optional(infinicore::adaptor::to_aten_tensor(p->seqlens_k)); + auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); + auto alibi_slopes = p->alibi_slopes + ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) + : std::nullopt; + + // No new KV tokens to append (pure decode, KV already written to cache). + std::optional k_new = std::nullopt; + std::optional v_new = std::nullopt; + std::optional rotary_cos = std::nullopt; + std::optional rotary_sin = std::nullopt; + std::optional cache_batch_idx = std::nullopt; + std::optional leftpad_k = std::nullopt; + + flash::mha_fwd_kvcache( + q, + k_cache, + v_cache, + k_new, + v_new, + seqlens_k, + rotary_cos, + rotary_sin, + cache_batch_idx, + leftpad_k, + block_table, + alibi_slopes, + out, + p->scale, + true, // is_causal + -1, // window_size_left (-1 = no sliding window) + -1, // window_size_right + 0.0f, // softcap + false, // is_rotary_interleaved + 0 // num_splits (0 = auto) + ); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(MhaKVCache, &plan, &run, &cleanup); + +} // namespace infinicore::op::mha_kvcache_impl::flashattn From 456ee3e162b05a07b52f420457d7ead7743d4640 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 9 Mar 2026 17:52:11 +0800 Subject: [PATCH 2/3] issue/1065 - add infinicore packaging for mha kvcache --- python/infinicore/__init__.py | 14 +- python/infinicore/ops/mha_kvcache.py | 67 +++++++++ .../ops/mha_kvcache/mha_kvcache_flashattn.cc | 40 +++--- src/infinicore/pybind11/ops.hpp | 2 + src/infinicore/pybind11/ops/mha_kvcache.hpp | 127 ++++++++++++++++++ 5 files changed, 227 insertions(+), 23 deletions(-) create mode 100644 python/infinicore/ops/mha_kvcache.py create mode 100644 src/infinicore/pybind11/ops/mha_kvcache.hpp diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 7d061bd18..b28375e95 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -61,6 +61,7 @@ from infinicore.ops.equal import equal from infinicore.ops.kv_caching import kv_caching from infinicore.ops.matmul import matmul +from infinicore.ops.mha_kvcache import mha_kvcache from infinicore.ops.mha_varlen import mha_varlen from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow @@ -131,16 +132,15 @@ "long", "short", "uint8", - # Operations. - "addcmul", - "atanh", - "binary_cross_entropy_with_logits", - "cdist", - "reciprocal", + # Operators. "add", + "addcmul", "add_rms_norm", "add_rms_norm_", + "atanh", "attention", + "binary_cross_entropy_with_logits", + "cdist", "kv_caching", "matmul", "equal", @@ -156,11 +156,13 @@ "from_list", "from_numpy", "from_torch", + "mha_kvcache", "mha_varlen", "paged_caching", "paged_attention", "paged_attention_prefill", "ones", + "reciprocal", "strided_empty", "strided_from_blob", "zeros", diff --git a/python/infinicore/ops/mha_kvcache.py b/python/infinicore/ops/mha_kvcache.py new file mode 100644 index 000000000..8ced8df16 --- /dev/null +++ b/python/infinicore/ops/mha_kvcache.py @@ -0,0 +1,67 @@ +from typing import Optional + +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def mha_kvcache( + q: Tensor, + k_cache: Tensor, + v_cache: Tensor, + seqlens_k: Tensor, + block_table: Tensor, + alibi_slopes: Optional[Tensor] = None, + scale: float = 1.0, + *, + out: Optional[Tensor] = None, +) -> Tensor: + """ + Flash attention KV-cache decode for single-step attention over a paged KV cache. + + This function performs attention decoding using a paged KV cache layout, + which is efficient for inference with large sequence lengths. + + Args: + q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] + k_cache: Key cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout) + v_cache: Value cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout) + seqlens_k: Total KV length per request of shape [batch_size] (int32) + block_table: Block mapping table of shape [batch_size, max_num_blocks_per_seq] (int32) + alibi_slopes: Optional ALiBi slopes tensor, if None then ALiBi is disabled + scale: Scaling factor for attention scores (typically 1.0/sqrt(head_size)) + out: Optional output tensor. If provided, the operation will be performed in-place. + + Returns: + Output tensor of shape [batch_size, seqlen_q, num_heads, head_size] + + Note: + The KV cache uses a paged layout where: + - k_cache and v_cache are organized into fixed-size blocks + - block_table maps logical positions to physical blocks for each sequence + - seqlens_k indicates the current total length of each sequence in the cache + """ + if out is None: + return Tensor( + _infinicore.mha_kvcache( + q._underlying, + k_cache._underlying, + v_cache._underlying, + seqlens_k._underlying, + block_table._underlying, + alibi_slopes._underlying if alibi_slopes is not None else None, + scale, + ) + ) + + _infinicore.mha_kvcache_( + out._underlying, + q._underlying, + k_cache._underlying, + v_cache._underlying, + seqlens_k._underlying, + block_table._underlying, + alibi_slopes._underlying if alibi_slopes is not None else None, + scale, + ) + + return out diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc index d74fdbb00..77e6f37a3 100644 --- a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc +++ b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc @@ -2,6 +2,8 @@ #include "infinicore/adaptor/flash_attention_adaptor.hpp" +#include + namespace infinicore::op::mha_kvcache_impl::flashattn { struct PlannedMeta { @@ -30,26 +32,27 @@ void *plan(Tensor out, } void run(void *planned_meta) { +#ifdef ENABLE_FLASH_ATTN c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); auto *p = reinterpret_cast(planned_meta); - auto out = std::optional(infinicore::adaptor::to_aten_tensor(p->out)); - auto q = infinicore::adaptor::to_aten_tensor(p->q); - auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache); - auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache); + auto out = std::optional(infinicore::adaptor::to_aten_tensor(p->out)); + auto q = infinicore::adaptor::to_aten_tensor(p->q); + auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache); + auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache); auto seqlens_k = std::optional(infinicore::adaptor::to_aten_tensor(p->seqlens_k)); auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); auto alibi_slopes = p->alibi_slopes - ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) - : std::nullopt; + ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) + : std::nullopt; // No new KV tokens to append (pure decode, KV already written to cache). - std::optional k_new = std::nullopt; - std::optional v_new = std::nullopt; - std::optional rotary_cos = std::nullopt; - std::optional rotary_sin = std::nullopt; + std::optional k_new = std::nullopt; + std::optional v_new = std::nullopt; + std::optional rotary_cos = std::nullopt; + std::optional rotary_sin = std::nullopt; std::optional cache_batch_idx = std::nullopt; - std::optional leftpad_k = std::nullopt; + std::optional leftpad_k = std::nullopt; flash::mha_fwd_kvcache( q, @@ -66,13 +69,16 @@ void run(void *planned_meta) { alibi_slopes, out, p->scale, - true, // is_causal - -1, // window_size_left (-1 = no sliding window) - -1, // window_size_right - 0.0f, // softcap - false, // is_rotary_interleaved - 0 // num_splits (0 = auto) + true, // is_causal + -1, // window_size_left (-1 = no sliding window) + -1, // window_size_right + 0.0f, // softcap + false, // is_rotary_interleaved + 0 // num_splits (0 = auto) ); +#else + throw std::runtime_error("FlashAttention is not enabled in this build"); +#endif } void cleanup(void **planned_meta_ptr) { diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index e49c6da92..3ca57881e 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -22,6 +22,7 @@ #include "ops/linear.hpp" #include "ops/linear_w8a8i8.hpp" #include "ops/matmul.hpp" +#include "ops/mha_kvcache.hpp" #include "ops/mha_varlen.hpp" #include "ops/mul.hpp" #include "ops/paged_attention.hpp" @@ -54,6 +55,7 @@ inline void bind(py::module &m) { bind_linear(m); bind_matmul(m); bind_mul(m); + bind_mha_kvcache(m); bind_mha_varlen(m); bind_hardswish(m); bind_hardtanh(m); diff --git a/src/infinicore/pybind11/ops/mha_kvcache.hpp b/src/infinicore/pybind11/ops/mha_kvcache.hpp new file mode 100644 index 000000000..38934233e --- /dev/null +++ b/src/infinicore/pybind11/ops/mha_kvcache.hpp @@ -0,0 +1,127 @@ +#pragma once + +#include + +#include "infinicore/ops/mha_kvcache.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +Tensor py_mha_kvcache(Tensor q, + Tensor k_cache, + Tensor v_cache, + Tensor seqlens_k, + Tensor block_table, + pybind11::object alibi_slopes, + float scale) { + std::optional alibi_slopes_tensor = std::nullopt; + if (!alibi_slopes.is_none()) { + alibi_slopes_tensor = alibi_slopes.cast(); + } + + return op::mha_kvcache( + q, + k_cache, + v_cache, + seqlens_k, + block_table, + alibi_slopes_tensor, + scale); +} + +void py_mha_kvcache_(Tensor out, + Tensor q, + Tensor k_cache, + Tensor v_cache, + Tensor seqlens_k, + Tensor block_table, + pybind11::object alibi_slopes, + float scale) { + std::optional alibi_slopes_tensor = std::nullopt; + if (!alibi_slopes.is_none()) { + alibi_slopes_tensor = alibi_slopes.cast(); + } + + op::mha_kvcache_( + out, + q, + k_cache, + v_cache, + seqlens_k, + block_table, + alibi_slopes_tensor, + scale); +} + +inline void bind_mha_kvcache(py::module &m) { + m.def( + "mha_kvcache", + &ops::py_mha_kvcache, + py::arg("q"), + py::arg("k_cache"), + py::arg("v_cache"), + py::arg("seqlens_k"), + py::arg("block_table"), + py::arg("alibi_slopes"), + py::arg("scale"), + R"doc(Flash attention KV-cache decode for single-step attention over a paged KV cache. + +Parameters +---------- +q : Tensor + Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] +k_cache : Tensor + Key cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout) +v_cache : Tensor + Value cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout) +seqlens_k : Tensor + Total KV length per request of shape [batch_size] (int32) +block_table : Tensor + Block mapping table of shape [batch_size, max_num_blocks_per_seq] (int32) +alibi_slopes : Optional[Tensor] + ALiBi slopes tensor, if None then ALiBi is disabled +scale : float + Scaling factor for attention scores (typically 1.0/sqrt(head_size)) + +Returns +------- +Tensor + Output tensor of shape [batch_size, seqlen_q, num_heads, head_size] +)doc"); + + m.def( + "mha_kvcache_", + &ops::py_mha_kvcache_, + py::arg("out"), + py::arg("q"), + py::arg("k_cache"), + py::arg("v_cache"), + py::arg("seqlens_k"), + py::arg("block_table"), + py::arg("alibi_slopes"), + py::arg("scale"), + R"doc(In-place flash attention KV-cache decode. + +Parameters +---------- +out : Tensor + Output tensor of shape [batch_size, seqlen_q, num_heads, head_size] +q : Tensor + Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] +k_cache : Tensor + Key cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout) +v_cache : Tensor + Value cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout) +seqlens_k : Tensor + Total KV length per request of shape [batch_size] (int32) +block_table : Tensor + Block mapping table of shape [batch_size, max_num_blocks_per_seq] (int32) +alibi_slopes : Optional[Tensor] + ALiBi slopes tensor, if None then ALiBi is disabled +scale : float + Scaling factor for attention scores (typically 1.0/sqrt(head_size)) +)doc"); +} + +} // namespace infinicore::ops From 0f90515c310b522d456374590435488d62d5a0a0 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 9 Mar 2026 19:56:44 +0800 Subject: [PATCH 3/3] issue/1065 - fix mha kv cache interface --- .../ops/mha_kvcache/mha_kvcache_flashattn.cc | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc index 77e6f37a3..24fcf7aea 100644 --- a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc +++ b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc @@ -36,7 +36,7 @@ void run(void *planned_meta) { c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); auto *p = reinterpret_cast(planned_meta); - auto out = std::optional(infinicore::adaptor::to_aten_tensor(p->out)); + auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out); auto q = infinicore::adaptor::to_aten_tensor(p->q); auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache); auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache); @@ -46,7 +46,6 @@ void run(void *planned_meta) { ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt; - // No new KV tokens to append (pure decode, KV already written to cache). std::optional k_new = std::nullopt; std::optional v_new = std::nullopt; std::optional rotary_cos = std::nullopt; @@ -54,7 +53,14 @@ void run(void *planned_meta) { std::optional cache_batch_idx = std::nullopt; std::optional leftpad_k = std::nullopt; - flash::mha_fwd_kvcache( + const bool use_dynamic_out = q.dim() == 4 && k_cache.dim() == 4 + && q.size(1) == 1 && q.size(2) > k_cache.size(2) + && q.size(3) % 8 == 0 && !alibi_slopes.has_value(); + + auto out = use_dynamic_out ? std::optional(std::nullopt) + : std::optional(out_tensor); + + auto result = flash::mha_fwd_kvcache( q, k_cache, v_cache, @@ -69,13 +75,16 @@ void run(void *planned_meta) { alibi_slopes, out, p->scale, - true, // is_causal - -1, // window_size_left (-1 = no sliding window) - -1, // window_size_right - 0.0f, // softcap - false, // is_rotary_interleaved - 0 // num_splits (0 = auto) - ); + true, + -1, + -1, + 0.0f, + false, + 0); + + if (use_dynamic_out) { + out_tensor.copy_(result[0]); + } #else throw std::runtime_error("FlashAttention is not enabled in this build"); #endif