diff --git a/include/infinicore/adaptor/aten_adaptor.hpp b/include/infinicore/adaptor/aten_adaptor.hpp new file mode 100644 index 000000000..70cb98e18 --- /dev/null +++ b/include/infinicore/adaptor/aten_adaptor.hpp @@ -0,0 +1,48 @@ +#ifdef ENABLE_ATEN +#pragma once +#include "../context/context.hpp" +#include "../tensor.hpp" + +#include + +#ifdef ENABLE_NVIDIA_API +#include +#include +#endif + +namespace infinicore::adaptor { +inline at::ScalarType to_at_dtype(DataType dtype) { + switch (dtype) { + case DataType::F32: + return at::kFloat; + case DataType::F16: + return at::kHalf; + case DataType::BF16: + return at::kBFloat16; + case DataType::I32: + return at::kInt; + case DataType::I64: + return at::kLong; + default: + throw std::runtime_error("Unsupported dtype for ATen"); + } +} + +inline at::Device to_at_device(const Device &device) { + if (device.getType() == Device::Type::NVIDIA) { + return at::Device(at::kCUDA, device.getIndex()); + } else if (device.getType() == Device::Type::CPU) { + return at::Device(at::kCPU); + } else { + throw std::runtime_error("Unsupported device type for ATen"); + } +} + +at::Tensor to_aten_tensor(const infinicore::Tensor &t); + +#ifdef ENABLE_NVIDIA_API +c10::cuda::CUDAStream get_cuda_stream(); +#endif +} // namespace infinicore::adaptor + +#endif // ENABLE_ATEN diff --git a/include/infinicore/adaptor/flash_attention_adaptor.hpp b/include/infinicore/adaptor/flash_attention_adaptor.hpp new file mode 100644 index 000000000..8a9e152fd --- /dev/null +++ b/include/infinicore/adaptor/flash_attention_adaptor.hpp @@ -0,0 +1,114 @@ +#ifdef ENABLE_FLASH_ATTN +#pragma once +#include "aten_adaptor.hpp" + +namespace flash { +std::vector +mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + std::optional &out_, // 
batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_); + +std::vector +mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
+ std::optional &leftpad_k_, // batch_size + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_); + +std::vector +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + std::optional gen_, + std::optional &rng_state); + +std::vector +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, 
total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + std::optional gen_, + std::optional &rng_state); + +std::vector +mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
+ std::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size + std::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size + std::optional &seqlens_k_, // batch_size + std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional &cache_batch_idx_, // indices to index into the KV cache + std::optional &leftpad_k_, // batch_size + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + int num_splits); + +} // namespace flash +#endif // ENABLE_FLASH_ATTN diff --git a/include/infinicore/ops/mha_varlen.hpp b/include/infinicore/ops/mha_varlen.hpp new file mode 100644 index 000000000..fc35a11df --- /dev/null +++ b/include/infinicore/ops/mha_varlen.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS( + MultiheadAttentionVarlen, + Tensor, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + int, + int, + std::optional, + float); + +Tensor mha_varlen(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_k, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale); + +void mha_varlen_(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_k, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale); 
# =========================================================================
# python/infinicore/ops/mha_varlen.py
# Re-exported from the package root (see python/infinicore/__init__.py, which
# adds `from infinicore.ops.mha_varlen import mha_varlen` and lists
# "mha_varlen" in __all__).
# =========================================================================

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def mha_varlen(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cum_seqlens_q: Tensor,
    cum_seqlens_k: Tensor,
    block_table: Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    alibi_slopes: Tensor | None = None,
    scale: float = 1.0,
    *,
    out: Tensor | None = None,
):
    """Variable-length (packed) multi-head attention.

    Args:
        q: Packed query tokens.
        k, v: Key/value tensors, addressed through ``block_table``.
        cum_seqlens_q: Cumulative query lengths, shape ``(batch + 1,)``.
        cum_seqlens_k: Cumulative key lengths, shape ``(batch + 1,)``.
        block_table: Per-sequence KV block indices.
        max_seqlen_q: Maximum query length in the batch.
        max_seqlen_k: Maximum key length in the batch.
        alibi_slopes: Optional ALiBi slopes; ``None`` disables ALiBi.
        scale: Softmax scaling factor.
        out: Optional pre-allocated output tensor (in-place variant).

    Returns:
        The attention output; ``out`` itself when it was supplied.
    """
    # Unwrap the optional once so both call paths share identical handling.
    slopes = alibi_slopes._underlying if alibi_slopes is not None else None

    if out is None:
        return Tensor(
            _infinicore.mha_varlen(
                q._underlying,
                k._underlying,
                v._underlying,
                cum_seqlens_q._underlying,
                cum_seqlens_k._underlying,
                block_table._underlying,
                max_seqlen_q,
                max_seqlen_k,
                slopes,
                scale,
            )
        )

    _infinicore.mha_varlen_(
        out._underlying,
        q._underlying,
        k._underlying,
        v._underlying,
        cum_seqlens_q._underlying,
        cum_seqlens_k._underlying,
        block_table._underlying,
        max_seqlen_q,
        max_seqlen_k,
        slopes,
        scale,
    )

    return out
000000000..2edbe3f8f --- /dev/null +++ b/src/infinicore/adaptor/aten_adaptor.cc @@ -0,0 +1,44 @@ +#ifdef ENABLE_ATEN +#include "infinicore/adaptor/aten_adaptor.hpp" + +namespace infinicore::adaptor { + +at::Tensor to_aten_tensor(const infinicore::Tensor &t) { + void *data_ptr = (void *)(t->data()); + + auto sizes = std::vector( + t->shape().begin(), + t->shape().end()); + + auto strides = t->strides(); + + auto dtype = to_at_dtype(t->dtype()); + auto device = to_at_device(t->device()); + + auto deleter_ = [](void * /*unused*/) mutable { + + }; + + at::TensorOptions options = at::TensorOptions() + .dtype(dtype) + .device(device) + .requires_grad(false); + + return at::from_blob( + data_ptr, + sizes, + strides, + deleter_, + options); +} + +#ifdef ENABLE_NVIDIA_API +c10::cuda::CUDAStream get_cuda_stream() { + return c10::cuda::getStreamFromExternal( + cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex()); +} +#endif + +} // namespace infinicore::adaptor + +#endif // ENABLE_ATEN diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen.cc new file mode 100644 index 000000000..2b87f16ce --- /dev/null +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen.cc @@ -0,0 +1,70 @@ +#include "infinicore/ops/mha_varlen.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MultiheadAttentionVarlen); + +MultiheadAttentionVarlen::MultiheadAttentionVarlen(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_kv, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, 
max_seqlen_q, max_seqlen_k, alibi_slopes, scale); +} + +void MultiheadAttentionVarlen::execute(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_kv, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN( + MultiheadAttentionVarlen, + out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale); +} + +Tensor mha_varlen( + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_kv, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale) { + auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); + mha_varlen_(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale); + return out; +} + +void mha_varlen_(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_kv, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale) { + MultiheadAttentionVarlen::execute(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc new file mode 100644 index 000000000..aff085898 --- /dev/null +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc @@ -0,0 +1,95 @@ +#include "infinicore/ops/mha_varlen.hpp" + +#include "infinicore/adaptor/flash_attention_adaptor.hpp" + +#include + +namespace infinicore::op::mha_varlen_impl::flashattn { + +struct PlannedMeta { + graph::GraphTensor out, q, k, v, cum_seqlens_q, 
cum_seqlens_k, block_table; + int max_seqlen_q, max_seqlen_k; + std::optional alibi_slopes; + float scale; +}; + +void *plan(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_k, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale) { + + return new PlannedMeta{ + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(cum_seqlens_q), + graph::GraphTensor(cum_seqlens_k), + graph::GraphTensor(block_table), + max_seqlen_q, + max_seqlen_k, + alibi_slopes ? std::optional(graph::GraphTensor(*alibi_slopes)) : std::nullopt, + scale}; +} + +void run(void *planned_meta) { +#ifdef ENABLE_FLASH_ATTN + c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); + auto *p = reinterpret_cast(planned_meta); + + auto q = infinicore::adaptor::to_aten_tensor(p->q); + auto k = infinicore::adaptor::to_aten_tensor(p->k); + auto v = infinicore::adaptor::to_aten_tensor(p->v); + auto out = std::optional(infinicore::adaptor::to_aten_tensor(p->out)); + auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q); + auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k); + std::optional seqused_k = std::nullopt; + std::optional leftpad_k = std::nullopt; + auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); + auto max_seqlen_q = p->max_seqlen_q; + auto max_seqlen_k = p->max_seqlen_k; + auto alibi_slopes = p->alibi_slopes ? 
std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt; + auto scale = p->scale; + + flash::mha_varlen_fwd( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_kv, + seqused_k, + leftpad_k, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + 0.0, + scale, + false, + true, + -1, + -1, + 0.0, + false, + std::nullopt); +#else + throw std::runtime_error("FlashAttention is not enabled in this build"); +#endif +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(MultiheadAttentionVarlen, &plan, &run, &cleanup); + +} // namespace infinicore::op::mha_varlen_impl::flashattn diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index d9fc5b084..b781fa843 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -12,6 +12,7 @@ #include "ops/linear.hpp" #include "ops/linear_w8a8i8.hpp" #include "ops/matmul.hpp" +#include "ops/mha_varlen.hpp" #include "ops/mul.hpp" #include "ops/paged_attention.hpp" #include "ops/paged_attention_prefill.hpp" @@ -38,6 +39,7 @@ inline void bind(py::module &m) { bind_linear(m); bind_matmul(m); bind_mul(m); + bind_mha_varlen(m); bind_paged_attention(m); bind_paged_attention_prefill(m); bind_paged_caching(m); diff --git a/src/infinicore/pybind11/ops/mha_varlen.hpp b/src/infinicore/pybind11/ops/mha_varlen.hpp new file mode 100644 index 000000000..f9004d9a4 --- /dev/null +++ b/src/infinicore/pybind11/ops/mha_varlen.hpp @@ -0,0 +1,102 @@ +#pragma once + +#include + +#include "infinicore/ops/mha_varlen.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +Tensor py_mha_varlen(Tensor q, + Tensor k, + Tensor v, + Tensor cum_seqlens_q, + Tensor cum_seqlens_k, + Tensor block_table, + int max_seqlen_q, + int max_seqlen_k, + pybind11::object alibi_slopes, + float scale) { + std::optional alibi_slopes_tensor = std::nullopt; + if 
(!alibi_slopes.is_none()) { + alibi_slopes_tensor = alibi_slopes.cast(); + } + + return op::mha_varlen( + q, + k, + v, + cum_seqlens_q, + cum_seqlens_k, + block_table, + max_seqlen_q, + max_seqlen_k, + alibi_slopes_tensor, + scale); +} + +void py_mha_varlen_(Tensor out, + Tensor q, + Tensor k, + Tensor v, + Tensor cum_seqlens_q, + Tensor cum_seqlens_k, + Tensor block_table, + int max_seqlen_q, + int max_seqlen_k, + pybind11::object alibi_slopes, + float scale) { + std::optional alibi_slopes_tensor = std::nullopt; + if (!alibi_slopes.is_none()) { + alibi_slopes_tensor = alibi_slopes.cast(); + } + + op::mha_varlen_( + out, + q, + k, + v, + cum_seqlens_q, + cum_seqlens_k, + block_table, + max_seqlen_q, + max_seqlen_k, + alibi_slopes_tensor, + scale); +} + +inline void bind_mha_varlen(py::module &m) { + m.def( + "mha_varlen", + &ops::py_mha_varlen, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("cum_seqlens_q"), + py::arg("cum_seqlens_k"), + py::arg("block_table"), + py::arg("max_seqlen_q"), + py::arg("max_seqlen_k"), + py::arg("alibi_slopes"), + py::arg("scale"), + R"doc(Variable-length multi-head attention.)doc"); + + m.def( + "mha_varlen_", + &ops::py_mha_varlen_, + py::arg("out"), + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("cum_seqlens_q"), + py::arg("cum_seqlens_k"), + py::arg("block_table"), + py::arg("max_seqlen_q"), + py::arg("max_seqlen_k"), + py::arg("alibi_slopes"), + py::arg("scale"), + R"doc(In-place variable-length multi-head attention.)doc"); +} + +} // namespace infinicore::ops diff --git a/test/infinicore/ops/mha_varlen.py b/test/infinicore/ops/mha_varlen.py new file mode 100644 index 000000000..942595782 --- /dev/null +++ b/test/infinicore/ops/mha_varlen.py @@ -0,0 +1,270 @@ +import os +import sys + +import torch + +import infinicore + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorInitializer, + TensorSpec, + TestCase, +) + +# 
# =========================================================================
# test/infinicore/ops/mha_varlen.py
# =========================================================================
"""Multi-turn test for infinicore.mha_varlen against a pure-torch paged
reference: a small block cache is grown across "rounds" of requests and the
packed varlen attention output is compared to a per-sequence softmax loop."""

import os
import sys

import torch

import infinicore

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from framework import (
    BaseOperatorTest,
    GenericTestRunner,
    TensorInitializer,
    TensorSpec,
    TestCase,
)

# Test Cases: (num_heads, num_kv_heads, head_size, block_size, [request_batch])
# Each request_batch entry is a tuple of per-sequence new-token counts for one
# round; successive rounds accumulate KV history for the same sequences.
_TEST_CASES_DATA = [
    (1, 1, 128, 256, [(250,), (7,)]),
    (4, 4, 128, 256, [(250,), (7,)]),
    (1, 1, 128, 256, [(260, 73), (1, 1)]),
    (8, 2, 128, 256, [(250,), (7,)]),
    (8, 2, 128, 256, [(260, 73), (1, 1)]),
]

_MAX_SEQUENCE_LENGTH = 8192

_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
    infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2},
}

_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16]


class SimpleCacheManager:
    """Minimal paged-KV block allocator: grows each request's block list as
    tokens arrive and tracks its total logical length."""

    def __init__(self, num_blocks, block_size):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.request_to_blocks = {}  # request_id -> ordered block ids
        self.request_to_len = {}     # request_id -> total tokens stored

    def allocate_slots(self, request_id, num_new_tokens):
        """Extend `request_id` by `num_new_tokens`; return (blocks, new_len)."""
        if request_id not in self.request_to_len:
            self.request_to_len[request_id] = 0
            self.request_to_blocks[request_id] = []

        start_pos = self.request_to_len[request_id]
        new_total_len = start_pos + num_new_tokens
        needed_blocks = (new_total_len + self.block_size - 1) // self.block_size
        added_blocks = needed_blocks - len(self.request_to_blocks[request_id])

        for _ in range(added_blocks):
            self.request_to_blocks[request_id].append(self.free_blocks.pop(0))

        self.request_to_len[request_id] = new_total_len
        return self.request_to_blocks[request_id], new_total_len


def parse_test_cases():
    """Expand _TEST_CASES_DATA into framework TestCases, one per round/dtype."""
    test_cases = []

    for (
        num_heads,
        num_kv_heads,
        head_size,
        block_size,
        request_batches,
    ) in _TEST_CASES_DATA:
        scale = head_size**-0.5
        num_blocks = 512
        manager = SimpleCacheManager(num_blocks, block_size)
        num_seqs = len(request_batches[0])
        kv_lens = torch.zeros(num_seqs, dtype=torch.int32)

        # Persistent paged cache: (num_blocks, num_kv_heads, block_size, head_size);
        # filled incrementally as rounds append tokens.
        persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
        persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))

        for r, req in enumerate(request_batches):
            assert len(req) == num_seqs, "All requests should have the same length"
            q_lens = torch.tensor(req, dtype=torch.int32)
            kv_lens = kv_lens + q_lens  # history + this round's new tokens
            total_q_tokens = q_lens.sum().item()
            cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int32)
            cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
            cum_seqlens_k = torch.zeros(num_seqs + 1, dtype=torch.int32)
            cum_seqlens_k[1:] = torch.cumsum(kv_lens, dim=0)

            query_base = torch.randn((total_q_tokens, num_heads, head_size))

            round_block_tables_list = []
            for i in range(num_seqs):
                p_blocks, total_len = manager.allocate_slots(i, q_lens[i].item())
                round_block_tables_list.append(p_blocks)

                # h_len = tokens already cached before this round.
                h_len = kv_lens[i].item() - q_lens[i].item()

                # Write this round's new K/V into their paged slots.
                for t in range(q_lens[i].item()):
                    logical_pos = h_len + t
                    b_id = p_blocks[logical_pos // block_size]
                    off = logical_pos % block_size
                    persistent_k[b_id, :, off, :] = torch.randn(num_kv_heads, head_size)
                    persistent_v[b_id, :, off, :] = torch.randn(num_kv_heads, head_size)

            # Pad every block table to the longest one (block 0 as filler).
            max_blks = max(len(t) for t in round_block_tables_list)
            padded_tables = torch.tensor(
                [t + [0] * (max_blks - len(t)) for t in round_block_tables_list]
            )

            for dtype in _TENSOR_DTYPES:
                tolerance = _TOLERANCE_MAP.get(dtype)

                test_cases.append(
                    TestCase(
                        inputs=[
                            TensorSpec.from_tensor(
                                query_base.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=query_base.clone(),
                                dtype=dtype,
                            ),
                            TensorSpec.from_tensor(
                                persistent_k.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=persistent_k.clone(),
                                dtype=dtype,
                            ),
                            TensorSpec.from_tensor(
                                persistent_v.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=persistent_v.clone(),
                                dtype=dtype,
                            ),
                            TensorSpec.from_tensor(
                                padded_tables.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=padded_tables.clone(),
                                dtype=infinicore.int32,
                            ),
                            TensorSpec.from_tensor(
                                cum_seqlens_q.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=cum_seqlens_q.clone(),
                                dtype=infinicore.int32,
                            ),
                            TensorSpec.from_tensor(
                                cum_seqlens_k.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=cum_seqlens_k.clone(),
                                dtype=infinicore.int32,
                            ),
                        ],
                        kwargs={
                            "scale": scale,
                            "max_seqlen_q": _MAX_SEQUENCE_LENGTH,
                            "max_seqlen_k": _MAX_SEQUENCE_LENGTH,
                        },
                        tolerance=tolerance,
                        description=f"MHA_Varlen_Round_{r}_{str(dtype).split('.')[-1]}",
                    )
                )

    return test_cases


def ref_paged_attention_multi_turn(
    query, k_cache, v_cache, block_tables, cum_seqlens_q, cum_seqlens_k, scale
):
    """Reference: gather each sequence's K/V from the paged cache and run a
    plain softmax attention with a causal mask over the suffix tokens."""
    output = torch.zeros_like(query)
    num_seqs = len(cum_seqlens_q) - 1
    block_size = k_cache.shape[2]

    for i in range(num_seqs):
        q_start, q_end = cum_seqlens_q[i].item(), cum_seqlens_q[i + 1].item()
        cur_q = query[q_start:q_end]
        q_len = q_end - q_start
        # h_len: cached history tokens that precede this round's queries.
        h_len = (cum_seqlens_k[i + 1].item() - cum_seqlens_k[i].item()) - q_len
        total_len = h_len + q_len

        table = block_tables[i]
        keys, values = [], []
        for j in range(total_len):
            b_id = table[j // block_size].item()
            off = j % block_size
            keys.append(k_cache[b_id, :, off, :])
            values.append(v_cache[b_id, :, off, :])

        K = torch.stack(keys, dim=0)
        V = torch.stack(values, dim=0)

        # Grouped-query attention: repeat KV heads up to the query head count.
        q_heads = cur_q.shape[1]
        kv_heads = K.shape[1]

        assert q_heads % kv_heads == 0
        group_size = q_heads // kv_heads
        if group_size > 1:
            K = K.repeat_interleave(group_size, dim=1)
            V = V.repeat_interleave(group_size, dim=1)

        scores = torch.einsum("qhd,khd->hqk", cur_q.float(), K.float()) * scale
        # Causal mask: query t may attend to all history plus positions <= t.
        mask = torch.full((q_len, total_len), float("-inf"), device=query.device)
        for t in range(q_len):
            mask[t, : h_len + t + 1] = 0.0

        attn = torch.softmax(scores + mask.unsqueeze(0), dim=-1).to(query.dtype)
        output[q_start:q_end] = torch.einsum("hqk,khd->qhd", attn, V)
    return output


class OpTest(BaseOperatorTest):
    """MHA varlen operator test driven by the generic framework runner."""

    def __init__(self):
        # NOTE(review): registered under "PagedAttentionPrefill" although this
        # file tests mha_varlen -- looks like a copy-paste of the sibling test;
        # confirm the framework key before renaming.
        super().__init__("PagedAttentionPrefill")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(
        self,
        query,
        k_cache,
        v_cache,
        block_tables,
        cum_seqlens_q,
        cum_seqlens_k,
        scale=1.0,
        max_seqlen_q=0,
        max_seqlen_k=0,
    ):
        # max_seqlen_* are unused by the reference; accepted for kwargs parity.
        return ref_paged_attention_multi_turn(
            query, k_cache, v_cache, block_tables, cum_seqlens_q, cum_seqlens_k, scale
        )

    def infinicore_operator(
        self,
        query,
        k_cache,
        v_cache,
        block_tables,
        cum_seqlens_q,
        cum_seqlens_k,
        scale=1.0,
        max_seqlen_q=0,
        max_seqlen_k=0,
    ):
        # The kernel expects (num_blocks, block_size, num_kv_heads, head_size);
        # the cache is stored head-major, hence the permute.
        out = infinicore.mha_varlen(
            query,
            k_cache.permute([0, 2, 1, 3]),
            v_cache.permute([0, 2, 1, 3]),
            cum_seqlens_q,
            cum_seqlens_k,
            block_tables,
            max_seqlen_q,
            max_seqlen_k,
            alibi_slopes=None,
            scale=scale,
        )
        infinicore.sync_stream()
        return out


def main():
    """Main entry point"""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()

# =========================================================================
# xmake.lua (options section added by this change; lua, kept with this span)
# =========================================================================
# -- ATen
# option("aten")
#     set_default(false)
#     set_showmenu(true)
#     set_description("Whether to link the ATen and torch libraries")  -- typo fix: "Wether"
# option_end()
#
# -- Flash-Attn
# option("flash-attn")
#     set_default(nil)
#     set_showmenu(true)
#     set_description("Path to flash-attention repo. If not set, flash-attention will not be used.")  -- grammar fix
# option_end()
#
# if has_config("aten") then
#     add_defines("ENABLE_ATEN")
#     -- flash-attn requires the ATen adaptor, so it is only honoured with aten=true
#     if get_config("flash-attn") ~= nil then
#         add_defines("ENABLE_FLASH_ATTN")
#     end
# end
#
# target("infinirt") additionally gains, on non-Windows:
#     add_ldflags("-fPIC", {force = true})
-- =========================================================================
-- xmake.lua : target("infinicore_cpp_api") additions
-- =========================================================================

    if get_config("flash-attn") ~= nil then
        -- Ship the built flash-attn shared object alongside the library.
        add_installfiles("(builddir)/$(plat)/$(arch)/$(mode)/flash-attn*.so", {prefixdir = "lib"})
        if has_config("nv-gpu") then
            add_deps("flash-attn-nvidia")
        end
    end

    before_build(function (target)
        if has_config("aten") then
            -- Locate the installed torch wheel and build/link against its
            -- bundled headers and libraries.
            local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()

            target:add(
                "includedirs",
                path.join(TORCH_DIR, "include"),
                path.join(TORCH_DIR, "include/torch/csrc/api/include"),
                {public = true})

            target:add(
                "linkdirs",
                path.join(TORCH_DIR, "lib"),
                {public = true})

            -- NOTE(review): torch_cuda/c10_cuda are linked unconditionally;
            -- a CPU-only torch wheel will fail to link here -- confirm intended.
            target:add(
                "links",
                "torch", "c10", "torch_cuda", "c10_cuda",
                {public = true})
        end
    end)

    -- Add InfiniCore C++ source files (needed for RoPE and other nn modules)
    add_files("src/infinicore/*.cc")
    add_files("src/infinicore/adaptor/*.cc")
    add_files("src/infinicore/context/*.cc")
    add_files("src/infinicore/context/*/*.cc")
    add_files("src/infinicore/tensor/*.cc")

-- =========================================================================
-- xmake/nvidia.lua additions
-- =========================================================================

local FLASH_ATTN_ROOT = get_config("flash-attn")

-- NOTE(review): INFINI_ROOT is declared but not referenced in this hunk --
-- confirm a later use, otherwise drop it.
local INFINI_ROOT = os.getenv("INFINI_ROOT")
    or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")

target("flash-attn-nvidia")
    set_kind("shared")
    set_default(false)
    set_policy("build.cuda.devlink", true)
    set_toolchains("cuda")
    add_links("cudart")
    add_cugencodes("native")

    before_build(function (target)
        if FLASH_ATTN_ROOT ~= nil then
            local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
            local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
            local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim()
            local LIB_PYTHON = os.iorunv("python", {"-c", "import sysconfig, os; print(sysconfig.get_config_var('LDLIBRARY'))"}):trim()

            -- Include dirs
            target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false})
            target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false})
            target:add("includedirs", TORCH_DIR .. "/include", {public = false})
            target:add("includedirs", PYTHON_INCLUDE, {public = false})
            target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false})
            target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn", {public = false})

            -- Link libraries
            target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR)
            -- NOTE(review): LIB_PYTHON is a full filename such as
            -- "libpython3.11.so"; passing it to add("links") relies on xmake
            -- accepting full library names -- confirm, else strip prefix/suffix.
            target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python", LIB_PYTHON)
        end
    end)

    if FLASH_ATTN_ROOT ~= nil then
        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp")
        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu")
        -- Link options: fail the build on unresolved symbols.
        add_ldflags("-Wl,--no-undefined", {force = true})
        -- Compile options
        add_cxflags("-fPIC", {force = true})
        add_cuflags("-Xcompiler=-fPIC")
        add_cuflags("--forward-unknown-to-host-compiler --expt-relaxed-constexpr --use_fast_math", {force = true})
        set_values("cuda.rdc", false)
    end

    -- Installation is handled by infinicore_cpp_api's add_installfiles.
    on_install(function (target) end)

target_end()