Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -962,8 +962,8 @@ void NgramMatch(const paddle::Tensor& token_ids_all,
const int max_ngram_size,
const int max_draft_tokens);

void HybridMtpNgram(const paddle::Tensor& input_ids,
const paddle::Tensor& input_ids_len,
void HybridMtpNgram(const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& pre_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& draft_token_num,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@
// Also copies tentative matched tokens to scratch buffers.
// ============================================================
__global__ void ngram_match_mixed_search_kernel(
const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
const int32_t *seq_lens_this_time,
const int64_t *max_dec_len,
int64_t *draft_tokens_copy,
int32_t *seq_lens_this_time_copy,
int64_t input_ids_stride,
int64_t max_model_len,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
Expand Down Expand Up @@ -69,8 +69,9 @@ __global__ void ngram_match_mixed_search_kernel(
if (draft_budget <= 0 || remaining_dec <= 0) return;
int max_draft_tokens = static_cast<int>(min(draft_budget, remaining_dec));

const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
const int64_t prompt_len = prompt_lens[batch_idx];
const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
const int64_t cur_input_ids_len = prompt_len;
const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
const int64_t cur_step_idx = step_idx[batch_idx];

Expand Down Expand Up @@ -228,16 +229,16 @@ static int sum_mixed_cpu(const int *value, int num) {
return sum_value;
}

static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
const int64_t *input_ids_len,
static void find_candidate_pred_tokens_mixed(const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t max_model_len,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
Expand Down Expand Up @@ -268,11 +269,12 @@ static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
int max_draft_tokens_query =
static_cast<int>(std::min(draft_budget, remaining_dec));

const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
const int64_t prompt_len = prompt_lens[batch_idx];
const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
const int64_t cur_step_idx = step_idx[batch_idx];
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
const int64_t cur_input_ids_len = prompt_len;
unprocessed_batch_size--;

auto sum_token_num = sum_mixed_cpu(seq_lens_this_time, batch_idx);
Expand Down Expand Up @@ -363,8 +365,8 @@ static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
// threshold enforcement + final token copy.
// ============================================================

void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
void HybridMtpNgram(const paddle::Tensor &token_ids_all,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
Expand All @@ -375,8 +377,7 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
const int max_ngram_size,
const int min_ngram_size,
const int max_draft_tokens) {
auto input_ids_shape = input_ids.shape();
const int64_t input_ids_stride = input_ids_shape[1];
const int64_t max_model_len = token_ids_all.shape()[1];

auto pre_ids_shape = pre_ids.shape();
const int64_t pre_ids_stride = pre_ids_shape[1];
Expand All @@ -392,8 +393,8 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
threshold = std::stoi(env_var);
}

if (input_ids.is_gpu()) {
auto stream = input_ids.stream();
if (token_ids_all.is_gpu()) {
auto stream = token_ids_all.stream();

// NOTE: GPU path does not pass seq_lens_decoder to kernels — the mixed
// variant uses ori_seq_len_this_time == 0 to skip inactive items. This
Expand All @@ -408,16 +409,16 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
auto draft_tokens_copy =
paddle::empty({max_batch_size, draft_tokens_stride},
paddle::DataType::INT64,
input_ids.place());
token_ids_all.place());

// Scratch copy of seq_lens_this_time (Phase 1 writes tentative counts)
auto seq_lens_this_time_copy = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, input_ids.place());
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());

// Save a copy of original seq_lens_this_time for Phase 2
// (Phase 1 reads from the original, Phase 2 needs ori values)
auto seq_lens_this_time_orig = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, input_ids.place());
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());
cudaMemcpyAsync(seq_lens_this_time_orig.data<int32_t>(),
seq_lens_this_time.data<int32_t>(),
max_batch_size * sizeof(int32_t),
Expand All @@ -434,16 +435,16 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
NGRAM_BLOCK_THREADS,
0,
stream>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
seq_lens_this_time.data<int32_t>(),
max_dec_len.data<int64_t>(),
draft_tokens_copy.data<int64_t>(),
seq_lens_this_time_copy.data<int32_t>(),
input_ids_stride,
max_model_len,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
Expand All @@ -464,16 +465,16 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
threshold);
} else {
find_candidate_pred_tokens_mixed(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
max_model_len,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
Expand All @@ -484,8 +485,8 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
}

PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
.Inputs({"input_ids",
"input_ids_len",
.Inputs({"token_ids_all",
"prompt_lens",
"pre_ids",
"step_idx",
"draft_token_num",
Expand Down
13 changes: 0 additions & 13 deletions fastdeploy/spec_decode/mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,16 +470,10 @@ def insert_tasks_v1(

input_ids = request.prompt_token_ids + request.output_token_ids

self.model_inputs["input_ids_len"][idx] = length - 1
async_set_value(self.model_inputs["pre_ids"][idx : idx + 1], -1)
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
idx : idx + 1, 1:length
]
# TODO: use token_all_ids replace with input_ids_cpu
if getattr(self, "hybrid_mode", False) and "input_ids_cpu" in self.model_inputs:
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length].cpu()
encoder_block_num = len(request.block_tables)
async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num)
async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
Expand Down Expand Up @@ -567,17 +561,13 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
request = req_dicts[i]
idx = request.idx
length = len(request.prompt_token_ids)
self.model_inputs.input_ids_len[idx] = length - 1

if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
length = len(request.prompt_token_ids)
if length > 1:
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
request.prompt_token_ids
)[1:]
self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
prefill_token_num = self.max_draft_token_num + 1
self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
Expand Down Expand Up @@ -606,9 +596,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
request.prompt_token_ids
)[1:]
self.model_inputs["pre_ids"][idx : idx + 1] = -1
self.model_inputs["step_idx"][idx : idx + 1] = 0
if self.cache_config.enable_chunked_prefill:
Expand Down
5 changes: 2 additions & 3 deletions fastdeploy/spec_decode/mtp_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,9 @@ def _update_status(self):
)

def _extend_draft_token_with_ngram_match(self):
# TODO: replace with gpu tensor
hybrid_mtp_ngram(
self.model_inputs["input_ids_cpu"].cuda(),
self.model_inputs["input_ids_len"].cuda(),
self.model_inputs["token_ids_all"],
self.model_inputs["prompt_lens"],
self.model_inputs["pre_ids"],
self.model_inputs["step_idx"],
self.target_model_inputs["actual_draft_token_num"],
Expand Down
17 changes: 2 additions & 15 deletions fastdeploy/worker/input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,12 +758,6 @@ def init_share_inputs(self):

self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"])
self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"])
self.input_ids_cpu = paddle.full(
shape=[self.scheduler_config.max_num_seqs, self.model_config.max_model_len],
fill_value=-1,
dtype="int64",
device="cpu",
)
self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"])

self.seq_lens_encoder = paddle.clone(self.target_model_input_batch["seq_lens_encoder"])
Expand All @@ -776,7 +770,7 @@ def init_share_inputs(self):
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
if "token_ids_all" in self.target_model_input_batch:
self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
self.token_ids_all = self.target_model_input_batch["token_ids_all"]

This comment was marked as outdated.

This comment was marked as outdated.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. `set_value_by_flags_and_idx` 是非 spec 路径下更新 `token_ids_all` 的地方;spec 路径下是 `unified_update_model_status`,在 postprocess 中调用,更新写入的是本轮验证后的 accept tokens。
  2. mtp proposer 读 `token_ids_all` 是在每轮 insert task 时,这一步与 postprocess 并不重叠。
  3. overlap schedule 重叠的是本轮的 target model forward + postprocess 过程和上一轮 `cached_output` 的 CPU 端保存过程,保存过程中不涉及 `token_ids_all` 的写入。

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问:`token_ids_all` 从 `clone` 改为直接引用,需确认共享写安全

```python
# 改前(安全)
self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
# 改后(共享引用)
self.token_ids_all = self.target_model_input_batch["token_ids_all"]
```

改为共享引用后,`ProposerInputBatch.token_ids_all` 与 target model inputs 指向同一张量。如果 spec_decode 路径中对 `token_ids_all` 有 in-place 写操作(如 `fill` / `index_put`),会直接修改 target model 的输入。

请确认:

  1. `_extend_draft_token_with_ngram_match` 中调用的 `hybrid_mtp_ngram` 仅读取 `token_ids_all`,不会写入
  2. `reset_model_inputs` 中同样改为引用(diff 第 970 行),是否所有 reset 场景下 target 的 `token_ids_all` 都已更新完毕再被引用?

如确认只读,建议在 `init_share_inputs` 处添加注释说明共享语义:`# token_ids_all is shared (read-only) with target model inputs`

# TODO: delete pre_ids in mtp
self.pre_ids = paddle.full(
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len],
Expand Down Expand Up @@ -886,7 +880,6 @@ def init_share_inputs(self):
self.last_seq_lens_this_time = paddle.full_like(
self.target_model_input_batch["seq_lens_this_time"], fill_value=-1, dtype="int32"
)
self.input_ids_len = paddle.zeros(shape=[self.scheduler_config.max_num_seqs, 1], dtype="int64", device="cpu")
self.temp_scaled_logprobs = self.target_model_input_batch["temp_scaled_logprobs"]
self.top_p_normalized_logprobs = self.target_model_input_batch["top_p_normalized_logprobs"]
self.accept_num = self.target_model_input_batch["accept_num"]
Expand Down Expand Up @@ -936,14 +929,12 @@ def swap_data(tensor, idx1, idx2):
self.index_to_batch_id[i1], self.index_to_batch_id[i2] = self.index_to_batch_id[i2], self.index_to_batch_id[i1]
swap_data(self.block_tables, i1, i2)
swap_data(self.input_ids, i1, i2)
swap_data(self.input_ids_cpu, i1, i2)
swap_data(self.seq_lens_this_time_buffer, i1, i2)
swap_data(self.seq_lens_encoder, i1, i2)
swap_data(self.seq_lens_decoder, i1, i2)
swap_data(self.step_idx, i1, i2)
swap_data(self.pre_ids, i1, i2)
swap_data(self.encoder_block_lens, i1, i2)
swap_data(self.input_ids_len, i1, i2)
swap_data(self.mask_rollback, i1, i2)
swap_data(self.recompute_token_num, i1, i2)
if self.enable_mm:
Expand All @@ -966,7 +957,6 @@ def reset_model_inputs(self) -> None:
# Clone the target model inputs to restore initial values
self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"])
self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"])
fill_paddle_tensor(self, "input_ids_cpu", -1)
# acceptance rate decline when reset seq_lens_this_time
# self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"])

Expand All @@ -980,7 +970,7 @@ def reset_model_inputs(self) -> None:
self.index_to_batch_id = {}
if current_platform.is_cuda():
if "token_ids_all" in self.target_model_input_batch:
self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
self.token_ids_all = self.target_model_input_batch["token_ids_all"]
# TODO: delete pre_ids in mtp
self.pre_ids = paddle.full(
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len],
Expand Down Expand Up @@ -1062,9 +1052,6 @@ def reset_model_inputs(self) -> None:
if self.num_model_steps > 1:
fill_paddle_tensor(self, "last_seq_lens_this_time", -1)

# Reset input IDs length
fill_paddle_tensor(self, "input_ids_len", 0)

# Reset various scores and flags
self.temp_scaled_logprobs = self.target_model_input_batch["temp_scaled_logprobs"]
self.top_p_normalized_logprobs = self.target_model_input_batch["top_p_normalized_logprobs"]
Expand Down
Loading
Loading