From 930c8a0e5d1b85633970321e7cc11373bd1c3a56 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Thu, 5 Mar 2026 06:25:14 +0000 Subject: [PATCH 1/5] issue/1033 support flash_attn lib with aten adaptor --- include/infinicore/adaptor/aten_adaptor.hpp | 35 +++ .../adaptor/flash_attention_adaptor.hpp | 112 ++++++++ include/infinicore/ops/mha_varlen.hpp | 46 +++ python/infinicore/__init__.py | 2 + python/infinicore/ops/mha_varlen.py | 49 ++++ src/infinicore/adaptor/aten_adaptor.cc | 34 +++ .../multi_head_attention_varlen/mha_varlen.cc | 70 +++++ .../mha_varlen_flashattn.cc | 88 ++++++ src/infinicore/pybind11/ops.hpp | 2 + src/infinicore/pybind11/ops/mha_varlen.hpp | 102 +++++++ test/infinicore/ops/mha_varlen.py | 262 ++++++++++++++++++ xmake.lua | 27 ++ 12 files changed, 829 insertions(+) create mode 100644 include/infinicore/adaptor/aten_adaptor.hpp create mode 100644 include/infinicore/adaptor/flash_attention_adaptor.hpp create mode 100644 include/infinicore/ops/mha_varlen.hpp create mode 100644 python/infinicore/ops/mha_varlen.py create mode 100644 src/infinicore/adaptor/aten_adaptor.cc create mode 100644 src/infinicore/ops/multi_head_attention_varlen/mha_varlen.cc create mode 100644 src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc create mode 100644 src/infinicore/pybind11/ops/mha_varlen.hpp create mode 100644 test/infinicore/ops/mha_varlen.py diff --git a/include/infinicore/adaptor/aten_adaptor.hpp b/include/infinicore/adaptor/aten_adaptor.hpp new file mode 100644 index 000000000..17e128d94 --- /dev/null +++ b/include/infinicore/adaptor/aten_adaptor.hpp @@ -0,0 +1,35 @@ +#pragma once +#include "../tensor.hpp" + +#include + +namespace infinicore::adaptor { +inline at::ScalarType to_at_dtype(DataType dtype) { + switch (dtype) { + case DataType::F32: + return at::kFloat; + case DataType::F16: + return at::kHalf; + case DataType::BF16: + return at::kBFloat16; + case DataType::I32: + return at::kInt; + case DataType::I64: + return at::kLong; + 
default: + throw std::runtime_error("Unsupported dtype for ATen"); + } +} + +inline at::Device to_at_device(const Device &device) { + if (device.getType() == Device::Type::NVIDIA) { + return at::Device(at::kCUDA, device.getIndex()); + } else if (device.getType() == Device::Type::CPU) { + return at::Device(at::kCPU); + } else { + throw std::runtime_error("Unsupported device type for ATen"); + } +} + +at::Tensor to_aten_tensor(const infinicore::Tensor &t); +} // namespace infinicore::adaptor \ No newline at end of file diff --git a/include/infinicore/adaptor/flash_attention_adaptor.hpp b/include/infinicore/adaptor/flash_attention_adaptor.hpp new file mode 100644 index 000000000..13a2bff17 --- /dev/null +++ b/include/infinicore/adaptor/flash_attention_adaptor.hpp @@ -0,0 +1,112 @@ +#pragma once +#include "aten_adaptor.hpp" + +namespace flash { +std::vector +mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_); + +std::vector +mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
+ std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &leftpad_k_, // batch_size + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_); + +std::vector +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + std::optional gen_, + std::optional &rng_state); + +std::vector +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + 
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + std::optional gen_, + std::optional &rng_state); + +std::vector +mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
+ std::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size + std::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size + std::optional &seqlens_k_, // batch_size + std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional &cache_batch_idx_, // indices to index into the KV cache + std::optional &leftpad_k_, // batch_size + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + int num_splits); + +} // namespace flash \ No newline at end of file diff --git a/include/infinicore/ops/mha_varlen.hpp b/include/infinicore/ops/mha_varlen.hpp new file mode 100644 index 000000000..fc35a11df --- /dev/null +++ b/include/infinicore/ops/mha_varlen.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS( + MultiheadAttentionVarlen, + Tensor, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + int, + int, + std::optional, + float); + +Tensor mha_varlen(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_k, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale); + +void mha_varlen_(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_k, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale); 
+ +} // namespace infinicore::op diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 54488f3c2..2f0ef56ea 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -52,6 +52,7 @@ from infinicore.ops.attention import attention from infinicore.ops.kv_caching import kv_caching from infinicore.ops.matmul import matmul +from infinicore.ops.mha_varlen import mha_varlen from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow from infinicore.ops.paged_attention import paged_attention @@ -134,6 +135,7 @@ "from_list", "from_numpy", "from_torch", + "mha_varlen", "paged_caching", "paged_attention", "paged_attention_prefill", diff --git a/python/infinicore/ops/mha_varlen.py b/python/infinicore/ops/mha_varlen.py new file mode 100644 index 000000000..48dc5f9c1 --- /dev/null +++ b/python/infinicore/ops/mha_varlen.py @@ -0,0 +1,49 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def mha_varlen( + q: Tensor, + k: Tensor, + v: Tensor, + cum_seqlens_q: Tensor, + cum_seqlens_k: Tensor, + block_table: Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + alibi_slopes: Tensor | None = None, + scale: float = 1.0, + *, + out: Tensor | None = None, +): + if out is None: + return Tensor( + _infinicore.mha_varlen( + q._underlying, + k._underlying, + v._underlying, + cum_seqlens_q._underlying, + cum_seqlens_k._underlying, + block_table._underlying, + max_seqlen_q, + max_seqlen_k, + alibi_slopes._underlying if alibi_slopes is not None else None, + scale, + ) + ) + + _infinicore.mha_varlen_( + out._underlying, + q._underlying, + k._underlying, + v._underlying, + cum_seqlens_q._underlying, + cum_seqlens_k._underlying, + block_table._underlying, + max_seqlen_q, + max_seqlen_k, + alibi_slopes._underlying if alibi_slopes is not None else None, + scale, + ) + + return out diff --git a/src/infinicore/adaptor/aten_adaptor.cc b/src/infinicore/adaptor/aten_adaptor.cc new file mode 100644 index 
000000000..e2f226bd4 --- /dev/null +++ b/src/infinicore/adaptor/aten_adaptor.cc @@ -0,0 +1,34 @@ +#include "infinicore/adaptor/aten_adaptor.hpp" + +namespace infinicore::adaptor { + + +at::Tensor to_aten_tensor(const infinicore::Tensor &t) { + void *data_ptr = (void *)(t->data()); + + auto sizes = std::vector( + t->shape().begin(), + t->shape().end()); + + auto strides = t->strides(); + + auto dtype = to_at_dtype(t->dtype()); + auto device = to_at_device(t->device()); + + auto deleter_ = [](void * /*unused*/) mutable { + + }; + + at::TensorOptions options = at::TensorOptions() + .dtype(dtype) + .device(device) + .requires_grad(false); + + return at::from_blob( + data_ptr, + sizes, + strides, + deleter_, + options); +} +} // namespace infinicore::adaptor \ No newline at end of file diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen.cc new file mode 100644 index 000000000..2b87f16ce --- /dev/null +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen.cc @@ -0,0 +1,70 @@ +#include "infinicore/ops/mha_varlen.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MultiheadAttentionVarlen); + +MultiheadAttentionVarlen::MultiheadAttentionVarlen(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_kv, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale); +} + +void MultiheadAttentionVarlen::execute(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_kv, + const 
Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN( + MultiheadAttentionVarlen, + out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale); +} + +Tensor mha_varlen( + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_kv, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale) { + auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); + mha_varlen_(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale); + return out; +} + +void mha_varlen_(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_kv, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale) { + MultiheadAttentionVarlen::execute(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc new file mode 100644 index 000000000..3c8b627ad --- /dev/null +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc @@ -0,0 +1,88 @@ +#include "infinicore/ops/mha_varlen.hpp" + +#include "infinicore/adaptor/flash_attention_adaptor.hpp" + +namespace infinicore::op::mha_varlen_impl::flashattn { + +struct PlannedMeta { + graph::GraphTensor out, q, k, v, cum_seqlens_q, cum_seqlens_k, block_table; + int max_seqlen_q, max_seqlen_k; + std::optional alibi_slopes; + float scale; +}; + +void *plan(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_k, 
+ const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale) { + + return new PlannedMeta{ + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(cum_seqlens_q), + graph::GraphTensor(cum_seqlens_k), + graph::GraphTensor(block_table), + max_seqlen_q, + max_seqlen_k, + alibi_slopes ? std::optional(graph::GraphTensor(*alibi_slopes)) : std::nullopt, + scale}; +} + +void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + + auto q = infinicore::adaptor::to_aten_tensor(p->q); + auto k = infinicore::adaptor::to_aten_tensor(p->k); + auto v = infinicore::adaptor::to_aten_tensor(p->v); + auto out = std::optional(infinicore::adaptor::to_aten_tensor(p->out)); + auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q); + auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k); + std::optional seqused_k = std::nullopt; + std::optional leftpad_k = std::nullopt; + auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); + auto max_seqlen_q = p->max_seqlen_q; + auto max_seqlen_k = p->max_seqlen_k; + auto alibi_slopes = p->alibi_slopes ? 
std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt; + auto scale = p->scale; + + flash::mha_varlen_fwd( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_kv, + seqused_k, + leftpad_k, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + 0.0, + scale, + false, + true, + -1, + -1, + 0.0, + false, + std::nullopt); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(MultiheadAttentionVarlen, &plan, &run, &cleanup); + +} // namespace infinicore::op::mha_varlen_impl::flashattn diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index d9fc5b084..7b0a9bc62 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -13,6 +13,7 @@ #include "ops/linear_w8a8i8.hpp" #include "ops/matmul.hpp" #include "ops/mul.hpp" +#include "ops/mha_varlen.hpp" #include "ops/paged_attention.hpp" #include "ops/paged_attention_prefill.hpp" #include "ops/paged_caching.hpp" @@ -38,6 +39,7 @@ inline void bind(py::module &m) { bind_linear(m); bind_matmul(m); bind_mul(m); + bind_mha_varlen(m); bind_paged_attention(m); bind_paged_attention_prefill(m); bind_paged_caching(m); diff --git a/src/infinicore/pybind11/ops/mha_varlen.hpp b/src/infinicore/pybind11/ops/mha_varlen.hpp new file mode 100644 index 000000000..f9004d9a4 --- /dev/null +++ b/src/infinicore/pybind11/ops/mha_varlen.hpp @@ -0,0 +1,102 @@ +#pragma once + +#include + +#include "infinicore/ops/mha_varlen.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +Tensor py_mha_varlen(Tensor q, + Tensor k, + Tensor v, + Tensor cum_seqlens_q, + Tensor cum_seqlens_k, + Tensor block_table, + int max_seqlen_q, + int max_seqlen_k, + pybind11::object alibi_slopes, + float scale) { + std::optional alibi_slopes_tensor = std::nullopt; + if (!alibi_slopes.is_none()) { + alibi_slopes_tensor = alibi_slopes.cast(); + } + + return 
op::mha_varlen( + q, + k, + v, + cum_seqlens_q, + cum_seqlens_k, + block_table, + max_seqlen_q, + max_seqlen_k, + alibi_slopes_tensor, + scale); +} + +void py_mha_varlen_(Tensor out, + Tensor q, + Tensor k, + Tensor v, + Tensor cum_seqlens_q, + Tensor cum_seqlens_k, + Tensor block_table, + int max_seqlen_q, + int max_seqlen_k, + pybind11::object alibi_slopes, + float scale) { + std::optional alibi_slopes_tensor = std::nullopt; + if (!alibi_slopes.is_none()) { + alibi_slopes_tensor = alibi_slopes.cast(); + } + + op::mha_varlen_( + out, + q, + k, + v, + cum_seqlens_q, + cum_seqlens_k, + block_table, + max_seqlen_q, + max_seqlen_k, + alibi_slopes_tensor, + scale); +} + +inline void bind_mha_varlen(py::module &m) { + m.def( + "mha_varlen", + &ops::py_mha_varlen, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("cum_seqlens_q"), + py::arg("cum_seqlens_k"), + py::arg("block_table"), + py::arg("max_seqlen_q"), + py::arg("max_seqlen_k"), + py::arg("alibi_slopes"), + py::arg("scale"), + R"doc(Variable-length multi-head attention.)doc"); + + m.def( + "mha_varlen_", + &ops::py_mha_varlen_, + py::arg("out"), + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("cum_seqlens_q"), + py::arg("cum_seqlens_k"), + py::arg("block_table"), + py::arg("max_seqlen_q"), + py::arg("max_seqlen_k"), + py::arg("alibi_slopes"), + py::arg("scale"), + R"doc(In-place variable-length multi-head attention.)doc"); +} + +} // namespace infinicore::ops diff --git a/test/infinicore/ops/mha_varlen.py b/test/infinicore/ops/mha_varlen.py new file mode 100644 index 000000000..aba85e705 --- /dev/null +++ b/test/infinicore/ops/mha_varlen.py @@ -0,0 +1,262 @@ +import os +import sys + +import infinicore +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorInitializer, + TensorSpec, + TestCase, +) + +# Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, 
num_rounds) +_TEST_CASES_DATA = [ + (1, 1, 1, 128, 256, 16, 1), + (1, 4, 4, 128, 256, 16, 4), + (2, 8, 8, 128, 256, 16, 2), +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2}, +} + +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16] + + +class SimpleCacheManager: + def __init__(self, num_blocks, block_size): + self.num_blocks = num_blocks + self.block_size = block_size + self.free_blocks = list(range(num_blocks)) + self.request_to_blocks = {} + self.request_to_len = {} + + def allocate_slots(self, request_id, num_new_tokens): + if request_id not in self.request_to_len: + self.request_to_len[request_id] = 0 + self.request_to_blocks[request_id] = [] + + start_pos = self.request_to_len[request_id] + new_total_len = start_pos + num_new_tokens + needed_blocks = (new_total_len + self.block_size - 1) // self.block_size + added_blocks = needed_blocks - len(self.request_to_blocks[request_id]) + + for _ in range(added_blocks): + self.request_to_blocks[request_id].append(self.free_blocks.pop(0)) + + self.request_to_len[request_id] = new_total_len + return self.request_to_blocks[request_id], new_total_len + + +def parse_test_cases(): + test_cases = [] + + for ( + num_seqs, + num_heads, + num_kv_heads, + head_size, + block_size, + max_step_len, + num_rounds, + ) in _TEST_CASES_DATA: + scale = head_size**-0.5 + num_blocks = 512 + manager = SimpleCacheManager(num_blocks, block_size) + kv_lens = torch.zeros(num_seqs, dtype=torch.int32) + + persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size)) + persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size)) + + for r in range(num_rounds): + q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int32) + kv_lens = kv_lens + q_lens + total_q_tokens = q_lens.sum().item() + cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int32) + cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0) + cum_seqlens_k 
= torch.zeros(num_seqs + 1, dtype=torch.int32) + cum_seqlens_k[1:] = torch.cumsum(kv_lens, dim=0) + + query_base = torch.randn((total_q_tokens, num_heads, head_size)) + + round_block_tables_list = [] + for i in range(num_seqs): + p_blocks, total_len = manager.allocate_slots(i, q_lens[i].item()) + round_block_tables_list.append(p_blocks) + + h_len = kv_lens[i].item() - q_lens[i].item() + + for t in range(q_lens[i].item()): + logical_pos = h_len + t + b_id = p_blocks[logical_pos // block_size] + off = logical_pos % block_size + persistent_k[b_id, :, off, :] = torch.randn(num_kv_heads, head_size) + persistent_v[b_id, :, off, :] = torch.randn(num_kv_heads, head_size) + + max_blks = max(len(t) for t in round_block_tables_list) + padded_tables = torch.tensor( + [t + [0] * (max_blks - len(t)) for t in round_block_tables_list] + ) + + for dtype in _TENSOR_DTYPES: + tolerance = _TOLERANCE_MAP.get(dtype) + + test_cases.append( + TestCase( + inputs=[ + TensorSpec.from_tensor( + query_base.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=query_base.clone(), + dtype=dtype, + ), + TensorSpec.from_tensor( + persistent_k.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=persistent_k.clone(), + dtype=dtype, + ), + TensorSpec.from_tensor( + persistent_v.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=persistent_v.clone(), + dtype=dtype, + ), + TensorSpec.from_tensor( + padded_tables.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=padded_tables.clone(), + dtype=infinicore.int32, + ), + # TensorSpec.from_tensor( + # kv_lens.shape, + # init_mode=TensorInitializer.MANUAL, + # set_tensor=kv_lens.clone(), + # dtype=infinicore.int64, + # ), + TensorSpec.from_tensor( + cum_seqlens_q.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=cum_seqlens_q.clone(), + dtype=infinicore.int32, + ), + TensorSpec.from_tensor( + cum_seqlens_k.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=cum_seqlens_k.clone(), + dtype=infinicore.int32, + ), + ], + 
kwargs={ + "scale": scale, + "max_seqlen_q": max_step_len + num_rounds, + "max_seqlen_k": max_step_len + num_rounds, + }, + tolerance=tolerance, + description=f"MHA_Varlen_Round_{r}_{str(dtype).split('.')[-1]}", + ) + ) + + return test_cases + + +def ref_paged_attention_multi_turn( + query, k_cache, v_cache, block_tables, cum_seqlens_q, cum_seqlens_k, scale +): + output = torch.zeros_like(query) + num_seqs = len(cum_seqlens_q) - 1 + block_size = k_cache.shape[2] + + for i in range(num_seqs): + q_start, q_end = cum_seqlens_q[i].item(), cum_seqlens_q[i + 1].item() + cur_q = query[q_start:q_end] + q_len = q_end - q_start + h_len = (cum_seqlens_k[i + 1].item() - cum_seqlens_k[i].item()) - q_len + total_len = h_len + q_len + + table = block_tables[i] + keys, values = [], [] + for j in range(total_len): + b_id = table[j // block_size].item() + off = j % block_size + keys.append(k_cache[b_id, :, off, :]) + values.append(v_cache[b_id, :, off, :]) + + K = torch.stack(keys, dim=0) + V = torch.stack(values, dim=0) + + scores = torch.einsum("qhd,khd->hqk", cur_q.float(), K.float()) * scale + mask = torch.full((q_len, total_len), float("-inf"), device=query.device) + for t in range(q_len): + mask[t, : h_len + t + 1] = 0.0 + + attn = torch.softmax(scores + mask.unsqueeze(0), dim=-1).to(query.dtype) + output[q_start:q_end] = torch.einsum("hqk,khd->qhd", attn, V) + return output + + +class OpTest(BaseOperatorTest): + def __init__(self): + super().__init__("PagedAttentionPrefill") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator( + self, + query, + k_cache, + v_cache, + block_tables, + cum_seqlens_q, + cum_seqlens_k, + scale=1.0, + max_seqlen_q=0, + max_seqlen_k=0, + ): + return ref_paged_attention_multi_turn( + query, k_cache, v_cache, block_tables, cum_seqlens_q, cum_seqlens_k, scale + ) + + def infinicore_operator( + self, + query, + k_cache, + v_cache, + block_tables, + cum_seqlens_q, + cum_seqlens_k, + scale=1.0, + max_seqlen_q=0, + 
max_seqlen_k=0, + ): + out = infinicore.mha_varlen( + query, + k_cache.permute([0, 2, 1, 3]), + v_cache.permute([0, 2, 1, 3]), + cum_seqlens_q, + cum_seqlens_k, + block_tables, + max_seqlen_q, + max_seqlen_k, + alibi_slopes=None, + scale=scale, + ) + infinicore.sync_stream() + return out + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/xmake.lua b/xmake.lua index d68a7b465..1b8a57c12 100644 --- a/xmake.lua +++ b/xmake.lua @@ -439,8 +439,35 @@ target("infinicore_cpp_api") add_linkdirs(INFINI_ROOT.."/lib") add_links("infiniop", "infinirt", "infiniccl") + -- ============================== + -- LibTorch integration + -- ============================== + local LIBTORCH_ROOT = ("/home/panzezhong/.conda/envs/myenv/lib/python3.13/site-packages/torch") + + -- headers + add_includedirs( + path.join(LIBTORCH_ROOT, "include"), + path.join(LIBTORCH_ROOT, "include/torch/csrc/api/include"), + { public = true } + ) + + -- libraries + add_linkdirs(path.join(LIBTORCH_ROOT, "lib")) + + -- core ATen / Torch libs + add_links( + "torch", + "c10", + "torch_cuda", + "c10_cuda" + ) + -- Flash attention lib + add_linkdirs("/home/panzezhong/Projects/InfiniCore/third_party/flash-attention/csrc/build") + add_links("flash_attn") + -- Add InfiniCore C++ source files (needed for RoPE and other nn modules) add_files("src/infinicore/*.cc") + add_files("src/infinicore/adaptor/*.cc") add_files("src/infinicore/context/*.cc") add_files("src/infinicore/context/*/*.cc") add_files("src/infinicore/tensor/*.cc") From d2aa36ded1fa689476cee64e8cf6d5e202f1d6c4 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Thu, 5 Mar 2026 06:25:14 +0000 Subject: [PATCH 2/5] issue/1033 fix mha_varlen test --- test/infinicore/ops/mha_varlen.py | 44 ++++++++++++++++++------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/test/infinicore/ops/mha_varlen.py b/test/infinicore/ops/mha_varlen.py index 
aba85e705..942595782 100644 --- a/test/infinicore/ops/mha_varlen.py +++ b/test/infinicore/ops/mha_varlen.py @@ -1,9 +1,10 @@ import os import sys -import infinicore import torch +import infinicore + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from framework import ( @@ -14,13 +15,17 @@ TestCase, ) -# Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds) +# Test Cases: (num_heads, num_kv_heads, head_size, block_size, [request_batch]) _TEST_CASES_DATA = [ - (1, 1, 1, 128, 256, 16, 1), - (1, 4, 4, 128, 256, 16, 4), - (2, 8, 8, 128, 256, 16, 2), + (1, 1, 128, 256, [(250,), (7,)]), + (4, 4, 128, 256, [(250,), (7,)]), + (1, 1, 128, 256, [(260, 73), (1, 1)]), + (8, 2, 128, 256, [(250,), (7,)]), + (8, 2, 128, 256, [(260, 73), (1, 1)]), ] +_MAX_SEQUENCE_LENGTH = 8192 + _TOLERANCE_MAP = { infinicore.float16: {"atol": 1e-2, "rtol": 1e-2}, infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2}, @@ -58,24 +63,24 @@ def parse_test_cases(): test_cases = [] for ( - num_seqs, num_heads, num_kv_heads, head_size, block_size, - max_step_len, - num_rounds, + request_batches, ) in _TEST_CASES_DATA: scale = head_size**-0.5 num_blocks = 512 manager = SimpleCacheManager(num_blocks, block_size) + num_seqs = len(request_batches[0]) kv_lens = torch.zeros(num_seqs, dtype=torch.int32) persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size)) persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size)) - for r in range(num_rounds): - q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int32) + for r, req in enumerate(request_batches): + assert len(req) == num_seqs, "All requests should have the same length" + q_lens = torch.tensor(req, dtype=torch.int32) kv_lens = kv_lens + q_lens total_q_tokens = q_lens.sum().item() cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int32) @@ -134,12 +139,6 @@ def parse_test_cases(): set_tensor=padded_tables.clone(), dtype=infinicore.int32, ), - 
# TensorSpec.from_tensor( - # kv_lens.shape, - # init_mode=TensorInitializer.MANUAL, - # set_tensor=kv_lens.clone(), - # dtype=infinicore.int64, - # ), TensorSpec.from_tensor( cum_seqlens_q.shape, init_mode=TensorInitializer.MANUAL, @@ -155,8 +154,8 @@ def parse_test_cases(): ], kwargs={ "scale": scale, - "max_seqlen_q": max_step_len + num_rounds, - "max_seqlen_k": max_step_len + num_rounds, + "max_seqlen_q": _MAX_SEQUENCE_LENGTH, + "max_seqlen_k": _MAX_SEQUENCE_LENGTH, }, tolerance=tolerance, description=f"MHA_Varlen_Round_{r}_{str(dtype).split('.')[-1]}", @@ -191,6 +190,15 @@ def ref_paged_attention_multi_turn( K = torch.stack(keys, dim=0) V = torch.stack(values, dim=0) + q_heads = cur_q.shape[1] + kv_heads = K.shape[1] + + assert q_heads % kv_heads == 0 + group_size = q_heads // kv_heads + if group_size > 1: + K = K.repeat_interleave(group_size, dim=1) + V = V.repeat_interleave(group_size, dim=1) + scores = torch.einsum("qhd,khd->hqk", cur_q.float(), K.float()) * scale mask = torch.full((q_len, total_len), float("-inf"), device=query.device) for t in range(q_len): From 61cc09d7ed1f5fbb3c315e3ade68b3d6645d666d Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Thu, 5 Mar 2026 06:25:14 +0000 Subject: [PATCH 3/5] issue/1033 support stream guard --- include/infinicore/adaptor/aten_adaptor.hpp | 6 ++++++ src/infinicore/adaptor/aten_adaptor.cc | 6 +++++- .../ops/multi_head_attention_varlen/mha_varlen_flashattn.cc | 1 + 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/include/infinicore/adaptor/aten_adaptor.hpp b/include/infinicore/adaptor/aten_adaptor.hpp index 17e128d94..de9dcac46 100644 --- a/include/infinicore/adaptor/aten_adaptor.hpp +++ b/include/infinicore/adaptor/aten_adaptor.hpp @@ -1,8 +1,12 @@ #pragma once +#include "../context/context.hpp" #include "../tensor.hpp" #include +#include +#include + namespace infinicore::adaptor { inline at::ScalarType to_at_dtype(DataType dtype) { switch (dtype) { @@ -32,4 +36,6 @@ inline at::Device to_at_device(const 
Device &device) { } at::Tensor to_aten_tensor(const infinicore::Tensor &t); + +c10::cuda::CUDAStream get_cuda_stream(); } // namespace infinicore::adaptor \ No newline at end of file diff --git a/src/infinicore/adaptor/aten_adaptor.cc b/src/infinicore/adaptor/aten_adaptor.cc index e2f226bd4..553f007cd 100644 --- a/src/infinicore/adaptor/aten_adaptor.cc +++ b/src/infinicore/adaptor/aten_adaptor.cc @@ -2,7 +2,6 @@ namespace infinicore::adaptor { - at::Tensor to_aten_tensor(const infinicore::Tensor &t) { void *data_ptr = (void *)(t->data()); @@ -31,4 +30,9 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) { deleter_, options); } + +c10::cuda::CUDAStream get_cuda_stream() { + return c10::cuda::getStreamFromExternal( + cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex()); +} } // namespace infinicore::adaptor \ No newline at end of file diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc index 3c8b627ad..91194e42b 100644 --- a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc @@ -38,6 +38,7 @@ void *plan(Tensor out, } void run(void *planned_meta) { + c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); auto *p = reinterpret_cast(planned_meta); auto q = infinicore::adaptor::to_aten_tensor(p->q); From 06c3df5d056114c49861db26f49761ce488c71c9 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Thu, 5 Mar 2026 06:37:34 +0000 Subject: [PATCH 4/5] issue/1033 add flash-attn compile target --- include/infinicore/adaptor/aten_adaptor.hpp | 9 ++- .../adaptor/flash_attention_adaptor.hpp | 4 +- src/infinicore/adaptor/aten_adaptor.cc | 8 +- .../mha_varlen_flashattn.cc | 6 ++ src/infinicore/pybind11/ops.hpp | 2 +- xmake.lua | 81 +++++++++++++------ xmake/nvidia.lua | 54 +++++++++++++ 7 files changed, 135 insertions(+), 
29 deletions(-) diff --git a/include/infinicore/adaptor/aten_adaptor.hpp b/include/infinicore/adaptor/aten_adaptor.hpp index de9dcac46..70cb98e18 100644 --- a/include/infinicore/adaptor/aten_adaptor.hpp +++ b/include/infinicore/adaptor/aten_adaptor.hpp @@ -1,11 +1,14 @@ +#ifdef ENABLE_ATEN #pragma once #include "../context/context.hpp" #include "../tensor.hpp" #include +#ifdef ENABLE_NVIDIA_API #include #include +#endif namespace infinicore::adaptor { inline at::ScalarType to_at_dtype(DataType dtype) { @@ -37,5 +40,9 @@ inline at::Device to_at_device(const Device &device) { at::Tensor to_aten_tensor(const infinicore::Tensor &t); +#ifdef ENABLE_NVIDIA_API c10::cuda::CUDAStream get_cuda_stream(); -} // namespace infinicore::adaptor \ No newline at end of file +#endif +} // namespace infinicore::adaptor + +#endif // ENABLE_ATEN diff --git a/include/infinicore/adaptor/flash_attention_adaptor.hpp b/include/infinicore/adaptor/flash_attention_adaptor.hpp index 13a2bff17..8a9e152fd 100644 --- a/include/infinicore/adaptor/flash_attention_adaptor.hpp +++ b/include/infinicore/adaptor/flash_attention_adaptor.hpp @@ -1,3 +1,4 @@ +#ifdef ENABLE_FLASH_ATTN #pragma once #include "aten_adaptor.hpp" @@ -109,4 +110,5 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits); -} // namespace flash \ No newline at end of file +} // namespace flash +#endif // ENABLE_FLASH_ATTN diff --git a/src/infinicore/adaptor/aten_adaptor.cc b/src/infinicore/adaptor/aten_adaptor.cc index 553f007cd..2edbe3f8f 100644 --- a/src/infinicore/adaptor/aten_adaptor.cc +++ b/src/infinicore/adaptor/aten_adaptor.cc @@ -1,3 +1,4 @@ +#ifdef ENABLE_ATEN #include "infinicore/adaptor/aten_adaptor.hpp" namespace infinicore::adaptor { @@ -31,8 +32,13 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) { options); } +#ifdef ENABLE_NVIDIA_API c10::cuda::CUDAStream get_cuda_stream() { return 
c10::cuda::getStreamFromExternal( cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex()); } -} // namespace infinicore::adaptor \ No newline at end of file +#endif + +} // namespace infinicore::adaptor + +#endif // ENABLE_ATEN diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc index 91194e42b..aff085898 100644 --- a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc @@ -2,6 +2,8 @@ #include "infinicore/adaptor/flash_attention_adaptor.hpp" +#include + namespace infinicore::op::mha_varlen_impl::flashattn { struct PlannedMeta { @@ -38,6 +40,7 @@ void *plan(Tensor out, } void run(void *planned_meta) { +#ifdef ENABLE_FLASH_ATTN c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); auto *p = reinterpret_cast(planned_meta); @@ -77,6 +80,9 @@ void run(void *planned_meta) { 0.0, false, std::nullopt); +#else + throw std::runtime_error("FlashAttention is not enabled in this build"); +#endif } void cleanup(void **planned_meta_ptr) { diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 7b0a9bc62..b781fa843 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -12,8 +12,8 @@ #include "ops/linear.hpp" #include "ops/linear_w8a8i8.hpp" #include "ops/matmul.hpp" -#include "ops/mul.hpp" #include "ops/mha_varlen.hpp" +#include "ops/mul.hpp" #include "ops/paged_attention.hpp" #include "ops/paged_attention_prefill.hpp" #include "ops/paged_caching.hpp" diff --git a/xmake.lua b/xmake.lua index 1b8a57c12..4b3c9f6e0 100644 --- a/xmake.lua +++ b/xmake.lua @@ -226,6 +226,28 @@ if has_config("ninetoothed") then add_defines("ENABLE_NINETOOTHED") end +-- ATen +option("aten") + set_default(false) + set_showmenu(true) + set_description("Whether to link aten and torch libraries") 
+option_end() + +-- Flash-Attn +option("flash-attn") + set_default(nil) + set_showmenu(true) + set_description("Path to flash-attention repo. If not set, flash-attention will not be used.") +option_end() + +if has_config("aten") then + add_defines("ENABLE_ATEN") + if get_config("flash-attn") ~= nil then + add_defines("ENABLE_FLASH_ATTN") + end +end + + -- cuda graph option("graph") set_default(false) @@ -439,31 +461,40 @@ target("infinicore_cpp_api") add_linkdirs(INFINI_ROOT.."/lib") add_links("infiniop", "infinirt", "infiniccl") - -- ============================== - -- LibTorch integration - -- ============================== - local LIBTORCH_ROOT = ("/home/panzezhong/.conda/envs/myenv/lib/python3.13/site-packages/torch") - - -- headers - add_includedirs( - path.join(LIBTORCH_ROOT, "include"), - path.join(LIBTORCH_ROOT, "include/torch/csrc/api/include"), - { public = true } - ) - - -- libraries - add_linkdirs(path.join(LIBTORCH_ROOT, "lib")) - - -- core ATen / Torch libs - add_links( - "torch", - "c10", - "torch_cuda", - "c10_cuda" - ) - -- Flash attention lib - add_linkdirs("/home/panzezhong/Projects/InfiniCore/third_party/flash-attention/csrc/build") - add_links("flash_attn") + if get_config("flash-attn") ~= nil then + add_installfiles("(builddir)/$(plat)/$(arch)/$(mode)/flash-attn*.so", {prefixdir = "lib"}) + if has_config("nv-gpu") then + add_deps("flash-attn-nvidia") + end + end + + before_build(function (target) + if has_config("aten") then + local outdata = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim() + local TORCH_DIR = outdata + + target:add( + "includedirs", + path.join(TORCH_DIR, "include"), + path.join(TORCH_DIR, "include/torch/csrc/api/include"), + { public = true }) + + target:add( + "linkdirs", + path.join(TORCH_DIR, "lib"), + { public = true } + ) + target:add( + "links", + "torch", + "c10", + "torch_cuda", + "c10_cuda", + { public = true } + ) + end + + end) -- Add InfiniCore C++ source files (needed 
for RoPE and other nn modules) add_files("src/infinicore/*.cc") diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 648a12723..899fbdf3b 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -9,6 +9,10 @@ if CUTLASS_ROOT ~= nil then add_includedirs(CUTLASS_ROOT) end +local FLASH_ATTN_ROOT = get_config("flash-attn") + +local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini") + target("infiniop-nvidia") set_kind("static") add_deps("infini-utils") @@ -132,3 +136,53 @@ target("infiniccl-nvidia") set_languages("cxx17") target_end() + +target("flash-attn-nvidia") + set_kind("shared") + set_default(false) + set_policy("build.cuda.devlink", true) + set_toolchains("cuda") + add_links("cudart") + add_cugencodes("native") + + before_build(function (target) + if FLASH_ATTN_ROOT ~= nil then + local outdata = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim() + local TORCH_DIR = outdata + + local outdata = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim() + local PYTHON_INCLUDE = outdata + + local outdata = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim() + local PYTHON_LIB_DIR = outdata + -- Include dirs + target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false}) + target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false}) + target:add("includedirs", TORCH_DIR .. "/include", {public = false}) + target:add("includedirs", PYTHON_INCLUDE, {public = false}) + target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false}) + target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn", {public = false}) + + -- Link libraries + target:add("linkdirs", TORCH_DIR .. 
"/lib", PYTHON_LIB_DIR) + target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python", "python3") + + end + + end) + + if FLASH_ATTN_ROOT ~= nil then + add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp") + add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu") + -- Link options + add_ldflags("-Wl,--no-undefined", {force = true}) + -- Compile options + add_cxflags("-fPIC", {force = true}) + add_cuflags("-Xcompiler=-fPIC") + add_cuflags("--forward-unknown-to-host-compiler --expt-relaxed-constexpr --use_fast_math", {force = true}) + set_values("cuda.rdc", false) + end + + on_install(function (target) end) + +target_end() From d575e7a09c00455ff2469e2c90f5607fe472554d Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Thu, 5 Mar 2026 10:48:25 +0000 Subject: [PATCH 5/5] issue/1033 add fpic ldflag for infinirt and fix python link --- xmake.lua | 1 + xmake/nvidia.lua | 14 +++++--------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/xmake.lua b/xmake.lua index 4b3c9f6e0..ec0a6c632 100644 --- a/xmake.lua +++ b/xmake.lua @@ -336,6 +336,7 @@ target("infinirt") if not is_plat("windows") then add_cxflags("-fPIC") add_cxxflags("-fPIC") + add_ldflags("-fPIC", {force = true}) end set_installdir(os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. 
"/.infini")) add_files("src/infinirt/*.cc") diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 899fbdf3b..a77187ff3 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -147,14 +147,10 @@ target("flash-attn-nvidia") before_build(function (target) if FLASH_ATTN_ROOT ~= nil then - local outdata = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim() - local TORCH_DIR = outdata - - local outdata = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim() - local PYTHON_INCLUDE = outdata - - local outdata = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim() - local PYTHON_LIB_DIR = outdata + local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim() + local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim() + local PYTHON_LIB_DIR= os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim() + local LIB_PYTHON = os.iorunv("python", {"-c", "import sysconfig, os; print(sysconfig.get_config_var('LDLIBRARY'))"}):trim() -- Include dirs target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false}) target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false}) @@ -165,7 +161,7 @@ target("flash-attn-nvidia") -- Link libraries target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR) - target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python", "python3") + target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python", LIB_PYTHON) end