Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions include/infinicore/adaptor/aten_adaptor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifdef ENABLE_ATEN
#pragma once
#include "../context/context.hpp"
#include "../tensor.hpp"

#include <ATen/ATen.h>

#ifdef ENABLE_NVIDIA_API
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#endif

namespace infinicore::adaptor {
inline at::ScalarType to_at_dtype(DataType dtype) {
switch (dtype) {
case DataType::F32:
return at::kFloat;
case DataType::F16:
return at::kHalf;
case DataType::BF16:
return at::kBFloat16;
case DataType::I32:
return at::kInt;
case DataType::I64:
return at::kLong;
default:
throw std::runtime_error("Unsupported dtype for ATen");
}
}

inline at::Device to_at_device(const Device &device) {
if (device.getType() == Device::Type::NVIDIA) {
return at::Device(at::kCUDA, device.getIndex());
} else if (device.getType() == Device::Type::CPU) {
return at::Device(at::kCPU);
} else {
throw std::runtime_error("Unsupported device type for ATen");
}
}

at::Tensor to_aten_tensor(const infinicore::Tensor &t);

#ifdef ENABLE_NVIDIA_API
c10::cuda::CUDAStream get_cuda_stream();
#endif
} // namespace infinicore::adaptor

#endif // ENABLE_ATEN
114 changes: 114 additions & 0 deletions include/infinicore/adaptor/flash_attention_adaptor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#ifdef ENABLE_FLASH_ATTN
#pragma once
#include "aten_adaptor.hpp"

namespace flash {
std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
std::optional<const at::Tensor> &leftpad_k_, // batch_size
std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
std::optional<at::Generator> gen_,
std::optional<at::Tensor> &rng_state);

std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
std::optional<at::Generator> gen_,
std::optional<at::Tensor> &rng_state);

std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
std::optional<const at::Tensor> &seqlens_k_, // batch_size
std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
std::optional<const at::Tensor> &leftpad_k_, // batch_size
std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits);

} // namespace flash
#endif // ENABLE_FLASH_ATTN
46 changes: 46 additions & 0 deletions include/infinicore/ops/mha_varlen.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

INFINICORE_GRAPH_OP_CLASS(
MultiheadAttentionVarlen,
Tensor,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
int,
int,
std::optional<Tensor>,
float);

Tensor mha_varlen(const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_k,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale);

void mha_varlen_(Tensor out,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_k,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale);

} // namespace infinicore::op
2 changes: 2 additions & 0 deletions python/infinicore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from infinicore.ops.attention import attention
from infinicore.ops.kv_caching import kv_caching
from infinicore.ops.matmul import matmul
from infinicore.ops.mha_varlen import mha_varlen
from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow
from infinicore.ops.paged_attention import paged_attention
Expand Down Expand Up @@ -134,6 +135,7 @@
"from_list",
"from_numpy",
"from_torch",
"mha_varlen",
"paged_caching",
"paged_attention",
"paged_attention_prefill",
Expand Down
49 changes: 49 additions & 0 deletions python/infinicore/ops/mha_varlen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def mha_varlen(
q: Tensor,
k: Tensor,
v: Tensor,
cum_seqlens_q: Tensor,
cum_seqlens_k: Tensor,
block_table: Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
alibi_slopes: Tensor | None = None,
scale: float = 1.0,
*,
out: Tensor | None = None,
):
if out is None:
return Tensor(
_infinicore.mha_varlen(
q._underlying,
k._underlying,
v._underlying,
cum_seqlens_q._underlying,
cum_seqlens_k._underlying,
block_table._underlying,
max_seqlen_q,
max_seqlen_k,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)
)

_infinicore.mha_varlen_(
out._underlying,
q._underlying,
k._underlying,
v._underlying,
cum_seqlens_q._underlying,
cum_seqlens_k._underlying,
block_table._underlying,
max_seqlen_q,
max_seqlen_k,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)

return out
44 changes: 44 additions & 0 deletions src/infinicore/adaptor/aten_adaptor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#ifdef ENABLE_ATEN
#include "infinicore/adaptor/aten_adaptor.hpp"

namespace infinicore::adaptor {

at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
void *data_ptr = (void *)(t->data());

auto sizes = std::vector<int64_t>(
t->shape().begin(),
t->shape().end());

auto strides = t->strides();

auto dtype = to_at_dtype(t->dtype());
auto device = to_at_device(t->device());

auto deleter_ = [](void * /*unused*/) mutable {

};

at::TensorOptions options = at::TensorOptions()
.dtype(dtype)
.device(device)
.requires_grad(false);

return at::from_blob(
data_ptr,
sizes,
strides,
deleter_,
options);
}

#ifdef ENABLE_NVIDIA_API
c10::cuda::CUDAStream get_cuda_stream() {
return c10::cuda::getStreamFromExternal(
cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
}
#endif

} // namespace infinicore::adaptor

#endif // ENABLE_ATEN
70 changes: 70 additions & 0 deletions src/infinicore/ops/multi_head_attention_varlen/mha_varlen.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include "infinicore/ops/mha_varlen.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MultiheadAttentionVarlen);

MultiheadAttentionVarlen::MultiheadAttentionVarlen(Tensor out,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_kv,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table);
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}

void MultiheadAttentionVarlen::execute(Tensor out,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_kv,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(
MultiheadAttentionVarlen,
out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}

Tensor mha_varlen(
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_kv,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
mha_varlen_(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
return out;
}

void mha_varlen_(Tensor out,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_kv,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale) {
MultiheadAttentionVarlen::execute(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}

} // namespace infinicore::op
Loading