Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions include/infinicore/ops/mha_kvcache.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

// Flash Attention KV-cache decode op.
//
// Wraps flash::mha_fwd_kvcache for single-step (decode) attention over a
// paged KV cache.
//
// Tensor shapes:
//   out         : [batch_size, seqlen_q, num_heads, head_size]
//   q           : [batch_size, seqlen_q, num_heads, head_size]
//   k_cache     : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
//   v_cache     : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
//   seqlens_k   : [batch_size] int32 — total KV length per request
//   block_table : [batch_size, max_num_blocks_per_seq] int32

INFINICORE_GRAPH_OP_CLASS(
    MhaKVCache,
    Tensor,                // out
    const Tensor &,        // q
    const Tensor &,        // k_cache
    const Tensor &,        // v_cache
    const Tensor &,        // seqlens_k
    const Tensor &,        // block_table
    std::optional<Tensor>, // alibi_slopes
    float);                // scale

// Allocating variant: runs the op and returns a freshly allocated output
// tensor shaped/typed like `q`. `alibi_slopes` of std::nullopt disables
// ALiBi; `scale` multiplies the attention scores (typically
// 1/sqrt(head_size)).
Tensor mha_kvcache(const Tensor &q,
                   const Tensor &k_cache,
                   const Tensor &v_cache,
                   const Tensor &seqlens_k,
                   const Tensor &block_table,
                   std::optional<Tensor> alibi_slopes,
                   float scale);

// In-place variant: writes the attention result into the caller-provided
// `out` tensor ([batch_size, seqlen_q, num_heads, head_size]).
void mha_kvcache_(Tensor out,
                  const Tensor &q,
                  const Tensor &k_cache,
                  const Tensor &v_cache,
                  const Tensor &seqlens_k,
                  const Tensor &block_table,
                  std::optional<Tensor> alibi_slopes,
                  float scale);

} // namespace infinicore::op
14 changes: 8 additions & 6 deletions python/infinicore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from infinicore.ops.equal import equal
from infinicore.ops.kv_caching import kv_caching
from infinicore.ops.matmul import matmul
from infinicore.ops.mha_kvcache import mha_kvcache
from infinicore.ops.mha_varlen import mha_varlen
from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow
Expand Down Expand Up @@ -131,16 +132,15 @@
"long",
"short",
"uint8",
# Operations.
"addcmul",
"atanh",
"binary_cross_entropy_with_logits",
"cdist",
"reciprocal",
# Operators.
"add",
"addcmul",
"add_rms_norm",
"add_rms_norm_",
"atanh",
"attention",
"binary_cross_entropy_with_logits",
"cdist",
"kv_caching",
"matmul",
"equal",
Expand All @@ -156,11 +156,13 @@
"from_list",
"from_numpy",
"from_torch",
"mha_kvcache",
"mha_varlen",
"paged_caching",
"paged_attention",
"paged_attention_prefill",
"ones",
"reciprocal",
"strided_empty",
"strided_from_blob",
"zeros",
Expand Down
67 changes: 67 additions & 0 deletions python/infinicore/ops/mha_kvcache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Optional

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def mha_kvcache(
    q: Tensor,
    k_cache: Tensor,
    v_cache: Tensor,
    seqlens_k: Tensor,
    block_table: Tensor,
    alibi_slopes: Optional[Tensor] = None,
    scale: float = 1.0,
    *,
    out: Optional[Tensor] = None,
) -> Tensor:
    """Single-step (decode) flash attention over a paged KV cache.

    The KV cache is organized into fixed-size blocks; ``block_table`` maps
    each sequence's logical positions to physical blocks and ``seqlens_k``
    gives the current total length of each sequence in the cache.

    Args:
        q: Queries of shape [batch_size, seqlen_q, num_heads, head_size].
        k_cache: Paged key cache, [num_blocks, block_size, num_heads_k, head_size].
        v_cache: Paged value cache, [num_blocks, block_size, num_heads_k, head_size].
        seqlens_k: Total KV length per request, [batch_size] (int32).
        block_table: Block mapping table, [batch_size, max_num_blocks_per_seq] (int32).
        alibi_slopes: Optional ALiBi slopes; ``None`` disables ALiBi.
        scale: Attention-score scale (typically ``1.0 / sqrt(head_size)``).
        out: Optional destination tensor; when given, the result is written
            into it in place.

    Returns:
        Tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        (``out`` itself when it was provided).
    """
    slopes = alibi_slopes._underlying if alibi_slopes is not None else None

    if out is not None:
        _infinicore.mha_kvcache_(
            out._underlying,
            q._underlying,
            k_cache._underlying,
            v_cache._underlying,
            seqlens_k._underlying,
            block_table._underlying,
            slopes,
            scale,
        )
        return out

    return Tensor(
        _infinicore.mha_kvcache(
            q._underlying,
            k_cache._underlying,
            v_cache._underlying,
            seqlens_k._underlying,
            block_table._underlying,
            slopes,
            scale,
        )
    )
58 changes: 58 additions & 0 deletions src/infinicore/ops/mha_kvcache/mha_kvcache.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include "infinicore/ops/mha_kvcache.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MhaKVCache);

// Validates that all tensor operands live on the same device, then
// dispatches to the backend implementation registered for that device type.
MhaKVCache::MhaKVCache(Tensor out,
                       const Tensor &q,
                       const Tensor &k_cache,
                       const Tensor &v_cache,
                       const Tensor &seqlens_k,
                       const Tensor &block_table,
                       std::optional<Tensor> alibi_slopes,
                       float scale) {
    // alibi_slopes is deliberately excluded: it is optional and may be absent.
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, seqlens_k, block_table);
    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
                                 out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
}

// Entry point used by the free functions below. Per the macro's name, this
// either records the op into the active graph or runs it immediately.
void MhaKVCache::execute(Tensor out,
                         const Tensor &q,
                         const Tensor &k_cache,
                         const Tensor &v_cache,
                         const Tensor &seqlens_k,
                         const Tensor &block_table,
                         std::optional<Tensor> alibi_slopes,
                         float scale) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(
        MhaKVCache,
        out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
}

// In-place public API: computes attention for `q` against the paged KV cache
// and writes the result into the caller-provided `out`. Thin delegate to the
// op class.
void mha_kvcache_(Tensor out,
                  const Tensor &q,
                  const Tensor &k_cache,
                  const Tensor &v_cache,
                  const Tensor &seqlens_k,
                  const Tensor &block_table,
                  std::optional<Tensor> alibi_slopes,
                  float scale) {
    MhaKVCache::execute(out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
}

// Allocating public API: creates the output tensor and runs the in-place op.
Tensor mha_kvcache(const Tensor &q,
                   const Tensor &k_cache,
                   const Tensor &v_cache,
                   const Tensor &seqlens_k,
                   const Tensor &block_table,
                   std::optional<Tensor> alibi_slopes,
                   float scale) {
    // The result has the same shape/dtype/device as q:
    // [batch_size, seqlen_q, num_heads, head_size].
    Tensor result = Tensor::empty(q->shape(), q->dtype(), q->device());
    mha_kvcache_(result, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
    return result;
}

} // namespace infinicore::op
100 changes: 100 additions & 0 deletions src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#include "infinicore/ops/mha_kvcache.hpp"

#include "infinicore/adaptor/flash_attention_adaptor.hpp"

#include <stdexcept>

namespace infinicore::op::mha_kvcache_impl::flashattn {

// Arguments captured at plan time so `run` can replay the op later without
// re-receiving them.
struct PlannedMeta {
    graph::GraphTensor out, q, k_cache, v_cache, seqlens_k, block_table;
    std::optional<graph::GraphTensor> alibi_slopes; // nullopt => ALiBi disabled
    float scale;                                    // attention-score multiplier
};

// Capture every operand as a graph tensor and stash it, together with the
// scalar scale, in a heap-allocated PlannedMeta owned by the graph runtime
// (released later via `cleanup`).
void *plan(Tensor out,
           const Tensor &q,
           const Tensor &k_cache,
           const Tensor &v_cache,
           const Tensor &seqlens_k,
           const Tensor &block_table,
           std::optional<Tensor> alibi_slopes,
           float scale) {
    std::optional<graph::GraphTensor> captured_slopes;
    if (alibi_slopes.has_value()) {
        captured_slopes.emplace(*alibi_slopes);
    }
    auto *meta = new PlannedMeta{graph::GraphTensor(out),
                                 graph::GraphTensor(q),
                                 graph::GraphTensor(k_cache),
                                 graph::GraphTensor(v_cache),
                                 graph::GraphTensor(seqlens_k),
                                 graph::GraphTensor(block_table),
                                 std::move(captured_slopes),
                                 scale};
    return meta;
}

// Executes the planned op by converting the captured graph tensors to ATen
// tensors and invoking flash::mha_fwd_kvcache on the current CUDA stream.
// Throws if the build has FlashAttention disabled.
void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
    // Make the adaptor's CUDA stream current for the duration of the call.
    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
    auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);

    auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
    auto q = infinicore::adaptor::to_aten_tensor(p->q);
    auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
    auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
    auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
    auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
    auto alibi_slopes = p->alibi_slopes
                            ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
                            : std::nullopt;

    // Optional features this op does not use: appended K/V, rotary
    // embeddings, cache batch indexing, and left padding.
    std::optional<const at::Tensor> k_new = std::nullopt;
    std::optional<const at::Tensor> v_new = std::nullopt;
    std::optional<const at::Tensor> rotary_cos = std::nullopt;
    std::optional<const at::Tensor> rotary_sin = std::nullopt;
    std::optional<const at::Tensor> cache_batch_idx = std::nullopt;
    std::optional<const at::Tensor> leftpad_k = std::nullopt;

    // When seqlen_q == 1, num_heads > num_heads_k (GQA/MQA), head_size is a
    // multiple of 8 and no ALiBi is requested, let the kernel allocate its
    // own output instead of writing into ours.
    // NOTE(review): this presumably matches FlashAttention's internal
    // "seqlenq_ngroups_swapped" fast path, which reshapes q/out — confirm
    // against the flash::mha_fwd_kvcache implementation.
    const bool use_dynamic_out = q.dim() == 4 && k_cache.dim() == 4
                                 && q.size(1) == 1 && q.size(2) > k_cache.size(2)
                                 && q.size(3) % 8 == 0 && !alibi_slopes.has_value();

    auto out = use_dynamic_out ? std::optional<at::Tensor>(std::nullopt)
                               : std::optional<at::Tensor>(out_tensor);

    // NOTE(review): the trailing positional scalars presumably map to
    // softmax_scale, is_causal=true, window_size_left=-1,
    // window_size_right=-1, softcap=0.0, rotary_interleaved=false,
    // num_splits=0 — verify against flash::mha_fwd_kvcache's signature.
    auto result = flash::mha_fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k_new,
        v_new,
        seqlens_k,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        leftpad_k,
        block_table,
        alibi_slopes,
        out,
        p->scale,
        true,
        -1,
        -1,
        0.0f,
        false,
        0);

    if (use_dynamic_out) {
        // The kernel allocated its own output; copy it back into the
        // caller's preallocated tensor.
        out_tensor.copy_(result[0]);
    }
#else
    throw std::runtime_error("FlashAttention is not enabled in this build");
#endif
}

// Releases the PlannedMeta created by `plan` and nulls the caller's slot so
// a repeated cleanup is harmless.
void cleanup(void **planned_meta_ptr) {
    auto **slot = reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    delete *slot;
    *slot = nullptr;
}

INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(MhaKVCache, &plan, &run, &cleanup);

} // namespace infinicore::op::mha_kvcache_impl::flashattn
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "ops/linear.hpp"
#include "ops/linear_w8a8i8.hpp"
#include "ops/matmul.hpp"
#include "ops/mha_kvcache.hpp"
#include "ops/mha_varlen.hpp"
#include "ops/mul.hpp"
#include "ops/paged_attention.hpp"
Expand Down Expand Up @@ -54,6 +55,7 @@ inline void bind(py::module &m) {
bind_linear(m);
bind_matmul(m);
bind_mul(m);
bind_mha_kvcache(m);
bind_mha_varlen(m);
bind_hardswish(m);
bind_hardtanh(m);
Expand Down
Loading