diff --git a/.gitignore b/.gitignore index d572eac42..b1717ce67 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__/ .pyc +.codex build dist *.egg-info diff --git a/lightllm/common/basemodel/triton_kernel/mtp_utils.py b/lightllm/common/basemodel/triton_kernel/mtp_utils.py index 2d70a68c0..eb4b70e70 100644 --- a/lightllm/common/basemodel/triton_kernel/mtp_utils.py +++ b/lightllm/common/basemodel/triton_kernel/mtp_utils.py @@ -1,7 +1,10 @@ +from typing import Optional, Tuple import triton import triton.language as tl import torch +from lightllm.common.basemodel.batch_objs import ModelInput + @triton.jit def _fwd_kernel_mtp_verify( @@ -93,10 +96,15 @@ def _fwd_kernel_mtp_scatter_next_token_ids( req_to_next_token_ids_stride, all_next_token_ids, all_next_token_ids_stride, + req_to_next_token_probs, + req_to_next_token_probs_stride, + all_next_token_probs, + all_next_token_probs_stride, mtp_accept_len, b_req_mtp_start_loc, b_req_idx, mtp_step, + HAS_HAS_NEXT_TOKEN_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): @@ -106,6 +114,17 @@ def _fwd_kernel_mtp_scatter_next_token_ids( cur_req_idx = tl.load(b_req_idx + req_start_loc) offset = tl.arange(0, BLOCK_SIZE) + if HAS_HAS_NEXT_TOKEN_PROBS: + cur_next_token_probs = tl.load( + all_next_token_probs + (req_start_loc + accept_len - 1) * all_next_token_probs_stride + offset, + mask=offset < mtp_step, + other=0.0, + ) + tl.store( + req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride + offset, + cur_next_token_probs, + mask=offset < mtp_step, + ) scatter_next_token_ids = tl.load( all_next_token_ids + (req_start_loc + accept_len - 1) * all_next_token_ids_stride + offset, mask=offset < mtp_step, @@ -125,12 +144,20 @@ def mtp_scatter_next_token_ids( all_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, mtp_accept_len: torch.Tensor, + req_to_next_token_probs: Optional[torch.Tensor] = None, + all_next_token_probs: Optional[torch.Tensor] = None, ): max_mtp_step = req_to_next_token_ids.shape[1] BLOCK_SIZE = 16 assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must be less than {BLOCK_SIZE}" num_reqs = b_req_mtp_start_loc.shape[0] mtp_step = all_next_token_ids.shape[1] + if req_to_next_token_probs is not None: + assert all_next_token_probs is not None + assert all_next_token_probs.shape == all_next_token_ids.shape + + HAS_HAS_NEXT_TOKEN_PROBS = req_to_next_token_probs is not None + grid = (num_reqs,) num_warps = 1 _fwd_kernel_mtp_scatter_next_token_ids[grid]( @@ -138,16 +165,222 @@ def mtp_scatter_next_token_ids( req_to_next_token_ids_stride=req_to_next_token_ids.stride(0), all_next_token_ids=all_next_token_ids, all_next_token_ids_stride=all_next_token_ids.stride(0), + req_to_next_token_probs=req_to_next_token_probs, + req_to_next_token_probs_stride=req_to_next_token_probs.stride(0), + all_next_token_probs=all_next_token_probs, + all_next_token_probs_stride=all_next_token_probs.stride(0), mtp_accept_len=mtp_accept_len, b_req_mtp_start_loc=b_req_mtp_start_loc, b_req_idx=b_req_idx, mtp_step=mtp_step, + HAS_HAS_NEXT_TOKEN_PROBS=HAS_HAS_NEXT_TOKEN_PROBS, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_stages=1, ) +@triton.jit +def _fwd_kernel_sample_dynamic_mtp_steps( + req_to_next_token_probs, + req_to_next_token_probs_stride, + req_indices, + sampled_steps, + rand_seed, + MAX_MTP_STEP: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + cur_index = tl.program_id(0) + cur_req_idx = tl.load(req_indices + cur_index) + + sampled_step = 1 + prefix_ok = 1 + + for step in range(1, BLOCK_SIZE): + if step < MAX_MTP_STEP: + cur_prob = tl.load(req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride + step) + cur_rand = tl.rand(rand_seed, cur_index * BLOCK_SIZE + step) + cur_accept = (cur_prob > cur_rand) & (prefix_ok == 1) + sampled_step += tl.where(cur_accept, 1, 0) + prefix_ok = tl.where(cur_accept, 1, 0) + + tl.store(sampled_steps + cur_index, sampled_step) + return + + +def sample_dynamic_mtp_req_mask( + dynamic_batch_size: int, + b_req_idx: torch.Tensor, + b_mtp_index: torch.Tensor, + req_to_next_token_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert b_req_idx.shape == b_mtp_index.shape + assert b_req_idx.is_cuda + assert b_mtp_index.is_cuda + assert req_to_next_token_probs.is_cuda + + batch_size = b_req_idx.shape[0] + req_start_mask = b_mtp_index == 0 + num_reqs = int(req_start_mask.sum().item()) + assert num_reqs > 0 + assert dynamic_batch_size >= num_reqs + assert dynamic_batch_size <= batch_size + + req_indices_gpu = b_req_idx[req_start_mask].to(dtype=torch.int32) + + max_mtp_step = req_to_next_token_probs.shape[1] + BLOCK_SIZE = 16 + assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must be less than or equal to {BLOCK_SIZE}" + + sampled_steps = torch.empty((num_reqs,), dtype=torch.int32, device="cuda") + rand_seed = int(torch.randint(0, 2**31 - 1, (1,), device="cuda", dtype=torch.int64).item()) + + grid = (num_reqs,) + _fwd_kernel_sample_dynamic_mtp_steps[grid]( + req_to_next_token_probs=req_to_next_token_probs, + req_to_next_token_probs_stride=req_to_next_token_probs.stride(0), + req_indices=req_indices_gpu, + sampled_steps=sampled_steps, + rand_seed=rand_seed, + MAX_MTP_STEP=max_mtp_step, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=1, + num_stages=1, + ) + + req_order = torch.cumsum(req_start_mask.to(torch.int32), dim=0) - 1 + row_sampled_steps = sampled_steps[req_order.long()] + sampled_mask_gpu = (b_mtp_index < row_sampled_steps).to(torch.int32) + final_mask_gpu = select_dynamic_mtp_exec_mask( + dynamic_batch_size=dynamic_batch_size, + req_start_mask=req_start_mask, + sampled_mask_gpu=sampled_mask_gpu, + ) + + return final_mask_gpu, sampled_steps + + +def select_dynamic_mtp_exec_mask( + dynamic_batch_size: int, + req_start_mask: torch.Tensor, + sampled_mask_gpu: torch.Tensor, +) -> torch.Tensor: + assert req_start_mask.is_cuda + assert sampled_mask_gpu.is_cuda + assert req_start_mask.shape == sampled_mask_gpu.shape + + num_reqs = int(req_start_mask.sum().item()) + selected_count = int(sampled_mask_gpu.sum().item()) + if selected_count <= dynamic_batch_size: + return sampled_mask_gpu + + remaining_budget = dynamic_batch_size - num_reqs + if remaining_budget <= 0: + # 当前 token budget 分配逻辑保持最简单的顺序策略: + # 先保证每个请求至少保留 1 个 token(即 mtp_index == 0 的主请求行), + # 如果 budget 不足以支持更多 draft token,则该请求本轮不做 MTP 展开。 + return req_start_mask.to(torch.int32) + + # 当前 token budget 分配逻辑保持最简单的顺序策略: + # 在保证每个请求至少保留 1 个 token 之后, + # 对剩余的 draft token 按 batch 中从前到后的顺序依次分配 budget。 + extra_mask_gpu = sampled_mask_gpu & (~req_start_mask) + extra_rank = torch.cumsum(extra_mask_gpu.to(torch.int32), dim=0) + kept_extra_mask = extra_mask_gpu & (extra_rank <= remaining_budget) + return req_start_mask.to(torch.int32) | kept_extra_mask.to(torch.int32) + + +def _rebuild_trimmed_mtp_b_mark_shared_group_from_b_mtp_index(b_mtp_index: torch.Tensor) -> torch.Tensor: + assert b_mtp_index.is_cuda + batch_size = b_mtp_index.shape[0] + if batch_size == 0: + return torch.empty((0,), dtype=torch.int32, device=b_mtp_index.device) + + # prepare_decode_inputs 中已经保证了每个请求的 decode rows 是按 + # b_mtp_index = [0, 1, 2, ...] 的前缀顺序展开的;动态 trim 只会保留 + # 每个请求的前缀,因此 compact 之后每个请求块的末尾满足: + # 1. 是最后一个元素;或 + # 2. 下一个元素的 b_mtp_index 重新回到 0 + group_end_mask = torch.ones((batch_size,), dtype=torch.bool, device=b_mtp_index.device) + if batch_size > 1: + group_end_mask[:-1] = b_mtp_index[1:] == 0 + + b_mark_shared_group = torch.zeros((batch_size,), dtype=torch.int32, device=b_mtp_index.device) + b_mark_shared_group[group_end_mask] = (b_mtp_index[group_end_mask] + 1).to(torch.int32) + return b_mark_shared_group + + +def _trim_decode_model_input_inplace(model_input: ModelInput, selected_mask_gpu: torch.Tensor) -> ModelInput: + assert not model_input.is_prefill + assert selected_mask_gpu.is_cuda + assert model_input.b_req_idx.is_cuda + assert model_input.b_mtp_index.is_cuda + assert model_input.b_seq_len.is_cuda + assert model_input.mem_indexes is not None and model_input.mem_indexes.is_cuda + + selected_index = torch.nonzero(selected_mask_gpu, as_tuple=False).flatten() + new_batch_size = selected_index.numel() + if new_batch_size == model_input.batch_size: + return model_input + + if model_input.input_ids is not None: + assert model_input.input_ids.is_cuda + model_input.input_ids = torch.index_select(model_input.input_ids, dim=0, index=selected_index) + model_input.b_req_idx = torch.index_select(model_input.b_req_idx, dim=0, index=selected_index) + model_input.b_mtp_index = torch.index_select(model_input.b_mtp_index, dim=0, index=selected_index) + model_input.b_seq_len = torch.index_select(model_input.b_seq_len, dim=0, index=selected_index) + + if model_input.mem_indexes is not None: + assert model_input.mem_indexes.is_cuda + model_input.mem_indexes = torch.index_select(model_input.mem_indexes, dim=0, index=selected_index) + if model_input.b_shared_seq_len is not None: + assert model_input.b_shared_seq_len.is_cuda + model_input.b_shared_seq_len = torch.index_select(model_input.b_shared_seq_len, dim=0, index=selected_index) + if model_input.mtp_draft_input_hiddens is not None: + assert model_input.mtp_draft_input_hiddens.is_cuda + model_input.mtp_draft_input_hiddens = torch.index_select( + model_input.mtp_draft_input_hiddens, dim=0, index=selected_index + ) + # ! 目前用不到multimodal_params,但是又会检查它的长度是否正确,因此先简单地 trim 掉,保持和其他参数一致的 batch_size + if model_input.multimodal_params is not None: + model_input.multimodal_params = model_input.multimodal_params[:new_batch_size] + + model_input.b_mark_shared_group = _rebuild_trimmed_mtp_b_mark_shared_group_from_b_mtp_index( + model_input.b_mtp_index + ) + model_input.batch_size = new_batch_size + + return model_input + + +def trim_dynamic_mtp_model_input( + model_input: ModelInput, + dynamic_batch_size: int, + req_to_next_token_ids: torch.Tensor, + req_to_next_token_probs: Optional[torch.Tensor] = None, +): + del req_to_next_token_ids + + if req_to_next_token_probs is None: + selected_mask = torch.ones((model_input.batch_size,), dtype=torch.int32, device="cuda") + return model_input, selected_mask + + assert not model_input.is_prefill, "trim_dynamic_mtp_model_input only supports decode inputs" + # Keep the trim path simple and deterministic: finish model_input.to_cuda() first, + # then do both sampling and compaction purely on GPU tensors. + model_input.to_cuda() + torch.cuda.current_stream().synchronize() + + selected_mask_gpu, _ = sample_dynamic_mtp_req_mask( + dynamic_batch_size=dynamic_batch_size, + b_req_idx=model_input.b_req_idx, + b_mtp_index=model_input.b_mtp_index, + req_to_next_token_probs=req_to_next_token_probs, + ) + model_input = _trim_decode_model_input_inplace(model_input=model_input, selected_mask_gpu=selected_mask_gpu) + return model_input, selected_mask_gpu + + @triton.jit def _fwd_kernel_gen_b_req_mtp_start_loc( b_mtp_index, @@ -208,6 +441,124 @@ def test_gen_b_req_mtp_start_loc(): print(b_req_mtp_start_loc, gt_output) +def test_sample_dynamic_mtp_req_mask(): + torch.manual_seed(1234) + + req_to_next_token_ids = torch.tensor( + [ + [100, 101, 102, 103, 0, 0], + [200, 201, 202, 203, 0, 0], + [300, 301, 302, 303, 0, 0], + ], + dtype=torch.int64, + device="cuda", + ) + req_to_next_token_probs = torch.tensor( + [ + [1.0, 0.95, 0.90, 0.10, 0.0, 0.0], + [1.0, 0.20, 0.80, 0.80, 0.0, 0.0], + [1.0, 0.99, 0.99, 0.99, 0.0, 0.0], + ], + dtype=torch.float32, + device="cuda", + ) + + # 3 个请求,每个请求展开成 4 个 decode row + b_req_idx = torch.tensor( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2], + dtype=torch.int32, + device="cuda", + ) + b_mtp_index = torch.tensor( + [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], + dtype=torch.int32, + device="cuda", + ) + + dynamic_batch_size = 8 + selected_mask, sampled_steps = sample_dynamic_mtp_req_mask( + dynamic_batch_size=dynamic_batch_size, + b_req_idx=b_req_idx, + b_mtp_index=b_mtp_index, + req_to_next_token_probs=req_to_next_token_probs, + ) + + print("==== test_sample_dynamic_mtp_req_mask ====") + print("req_to_next_token_ids:") + print(req_to_next_token_ids.cpu()) + print("req_to_next_token_probs:") + print(req_to_next_token_probs.cpu()) + print("b_req_idx:") + print(b_req_idx.cpu()) + print("b_mtp_index:") + print(b_mtp_index.cpu()) + print("sampled_steps per req:") + print(sampled_steps.cpu()) + print("selected_mask:") + print(selected_mask.cpu()) + print("selected rows [req_idx, mtp_index]:") + selected_pos = torch.where(selected_mask == 1)[0] + print(torch.stack([b_req_idx[selected_pos], b_mtp_index[selected_pos]], dim=-1).cpu()) + + +def test_trim_dynamic_mtp_model_input(): + torch.manual_seed(1234) + + model_input = ModelInput( + batch_size=12, + total_token_num=54, + max_q_seq_len=1, + max_kv_seq_len=6, + input_ids=None, + b_req_idx=torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.int32, device="cuda"), + b_mtp_index=torch.tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], dtype=torch.int32, device="cuda"), + b_seq_len=torch.tensor([3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6], dtype=torch.int32, device="cuda"), + b_shared_seq_len=None, + b_mark_shared_group=torch.tensor([0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], dtype=torch.int32, device="cuda"), + mem_indexes=torch.arange(12, dtype=torch.int32, device="cuda"), + mem_indexes_cpu=torch.arange(12, dtype=torch.int32, device="cpu"), + is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(12)], + ) + req_to_next_token_probs = torch.tensor( + [ + [1.0, 0.95, 0.90, 0.10, 0.0, 0.0], + [1.0, 0.20, 0.80, 0.80, 0.0, 0.0], + [1.0, 0.99, 0.99, 0.99, 0.0, 0.0], + ], + dtype=torch.float32, + device="cuda", + ) + + model_input, selected_mask = trim_dynamic_mtp_model_input( + model_input=model_input, + dynamic_batch_size=8, + req_to_next_token_ids=torch.empty((0,), dtype=torch.int64, device="cuda"), + req_to_next_token_probs=req_to_next_token_probs, + ) + + print("==== test_trim_dynamic_mtp_model_input ====") + print("selected_mask:") + print(selected_mask.cpu()) + print("batch_size:", model_input.batch_size) + print("total_token_num:", model_input.total_token_num) + print("max_kv_seq_len:", model_input.max_kv_seq_len) + print("b_req_idx:") + print(model_input.b_req_idx.cpu()) + print("b_mtp_index:") + print(model_input.b_mtp_index.cpu()) + print("b_seq_len:") + print(model_input.b_seq_len.cpu()) + print("b_mark_shared_group:") + print(model_input.b_mark_shared_group.cpu()) + print("mem_indexes:") + print(model_input.mem_indexes.cpu()) + print("mem_indexes_cpu:") + print(model_input.mem_indexes_cpu) + + if __name__ == "__main__": - test_mtp_verify() + # test_mtp_verify() # test_gen_b_req_mtp_start_loc() + test_sample_dynamic_mtp_req_mask() + # test_trim_dynamic_mtp_model_input() diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 3a4e2b631..278b3509c 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -6,7 +6,7 @@ from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, enable_dynamic_mtp_verify from lightllm.utils.config_utils import get_vocab_size from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager @@ -116,6 +116,15 @@ def __init__(self, max_request_num): dtype=torch.int64, device="cuda", ) + if enable_dynamic_mtp_verify(): + self.req_to_next_token_probs = torch.zeros( + (max_request_num + 1, 16), + dtype=torch.float32, + device="cuda", + ) + else: + self.req_to_next_token_probs = None + self.req_to_exponential_decay_length_penalty = torch.zeros( max_request_num + 1, dtype=torch.float32, device="cuda" ) @@ -137,6 +146,9 @@ def init_req_sampling_params(self, req): shm_param = req.sampling_param.shm_param self.req_to_next_token_ids[req.req_idx][0:1].fill_(req.get_last_gen_token()) + if enable_dynamic_mtp_verify(): + self.req_to_next_token_probs[req.req_idx].fill_(0.0) + self.req_to_next_token_probs[req.req_idx][0:1].fill_(1.0) self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty) self.req_to_frequency_penalty[req.req_idx].fill_(shm_param.frequency_penalty) self.req_to_repetition_penalty[req.req_idx].fill_(shm_param.repetition_penalty) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 46608da13..7b931f153 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -435,10 +435,7 @@ def __init__( # mtp_step 用来记录一个请求 draft模型每步需要生成的token数量 # 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量 self.mtp_step: int = get_env_start_args().mtp_step - # current_mtp_step 用来记录当前的 MTP 验证长度(<= mtp_step) - # 在启用动态 MTP 验证时,每步会根据 prob 分布重新设置该值 - # 静态模式下为 mtp_step,动态模式下为动态计算的 MTP 验证长度 - self.current_mtp_step: int = self.mtp_step + if self.mtp_step > 0: self.decode_need_token_num = self._mtp_decode_need_token_num else: diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 321055b4b..4203b093f 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -4,6 +4,7 @@ import time import threading import torch.distributed as dist +import collections from typing import List, Tuple, Callable, Optional from transformers.configuration_utils import PretrainedConfig from lightllm.utils.infer_utils import set_random_seed @@ -775,14 +776,18 @@ def _update_mtp_accept_ratio( return - def _update_mtp_verify_token_num(self, decode_reqs: List[InferReq]): + def _update_mtp_verify_token_num( + self, decode_reqs: List[InferReq], dynamic_mtp_run_reqs: Optional[List[InferReq]] = None + ): if self.is_master_in_dp: - for req in decode_reqs: - # 统计发送给主模型验证的 token 数量:1 个主 token + 当前 mtp_size 个 draft token - # 在静态 MTP 模式下,使用固定的 mtp_step;在动态 MTP 模式下,使用动态调整的 current_mtp_step - # current_mtp_step 在静态 MTP 模式下为 mtp_step,在动态 MTP 模式下会在推理过程中动态设置。 - assert req.current_mtp_step >= 0 - req.update_mtp_verify_token_num(verify_token_num=1 + req.current_mtp_step) + if dynamic_mtp_run_reqs is None: + for req in decode_reqs: + assert req.mtp_step > 0 + req.update_mtp_verify_token_num(verify_token_num=1 + req.mtp_step) + else: + counter = collections.Counter([req.req_idx for req in dynamic_mtp_run_reqs]) + for req in decode_reqs: + req.update_mtp_verify_token_num(verify_token_num=1 + counter[req.req_idx] - 1) return def _gen_argmax_token_ids(self, model_output: ModelOutput): diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index c41dbb6d9..1bfa702e3 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -20,12 +20,15 @@ from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.common.basemodel.triton_kernel.mtp_utils import ( + gen_b_req_mtp_start_loc, mtp_scatter_next_token_ids, + trim_dynamic_mtp_model_input, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify from .control_state import ControlState +from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner logger = init_logger(__name__) @@ -49,6 +52,9 @@ def __init__(self) -> None: self.prefill = self.prefill_normal self.decode = self.decode_normal + if self.enable_dynamic_mtp: + self.dynamic_mtp_planner = DynamicMTPPlanner(mtp_step=get_env_start_args().mtp_step) + self.classed_req_strict_prefill = False return @@ -234,19 +240,45 @@ def decode_mtp( MTP解码的通用流程,整合eagle和vanilla的共同逻辑 """ model_input, run_reqs = prepare_decode_inputs(decode_reqs) + origin_mem_indexes_cpu = model_input.mem_indexes_cpu with torch.cuda.stream(g_infer_context.get_overlap_stream()): - b_mtp_index_cpu = model_input.b_mtp_index + + if self.enable_dynamic_mtp: + dynamic_batch_size = self.dynamic_mtp_planner.get_dynamic_batch_size( + req_num=len(decode_reqs), + original_batch_size=model_input.batch_size, + ) + # TODO: 需要根据实际情况实现 trans_to_dynamic_model_input + trans_to_dynamic_model_input = trim_dynamic_mtp_model_input + model_input, selected_run_reqs = trans_to_dynamic_model_input( + model_input=model_input, + dynamic_batch_size=dynamic_batch_size, + req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, + req_to_next_token_probs=self.model.req_manager.req_sampling_params_manager.req_to_next_token_probs, + ) + # selected_run_reqs 是一个 gpu tensor, 类型为 int, 0, 表示没有选中, 1 表示选中。 + selected_run_reqs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( + key="selected_run_reqs", + gpu_tensor=selected_run_reqs, + ) + trans_dynamic_model_input_event = torch.cuda.Event() + trans_dynamic_model_input_event.record() + + start_time_event = torch.cuda.Event(enable_timing=True) + start_time_event.record() + model_output = self.model.forward(model_input) - next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) + next_token_ids, next_token_logprobs = sample( + model_output.logits, + run_reqs, + self.eos_id, + selected_run_reqs=selected_run_reqs if self.enable_dynamic_mtp else None, + ) # verify the next_token_ids - b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] - b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list( - key="b_req_mtp_start_loc", - data=b_req_mtp_start_loc, - dtype=torch.int32, - ).cuda(non_blocking=True) - + # TODO: 需要根据实际情况实现 get_b_req_mtp_start_loc + b_req_mtp_start_loc = gen_b_req_mtp_start_loc(model_input.b_mtp_index, num_reqs=len(decode_reqs)) + # b_req_mtp_start_loc 是一个 gpu tensor, 类型为 int, 表示每个请求的 mtp_start_loc, shape 为 len(decode_reqs) mtp_accept_len, accepted_index = self._verify_mtp_v2( new_next_token_ids=next_token_ids, b_req_idx=model_input.b_req_idx, @@ -257,7 +289,7 @@ def decode_mtp( gpu_tensor=accepted_index, ) - verify_event = torch.cuda.Event() + verify_event = torch.cuda.Event(enable_timing=True) verify_event.record() if self.enable_dynamic_mtp: @@ -277,15 +309,20 @@ def decode_mtp( # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size if self.enable_dynamic_mtp: - draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) + draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view( + self.mtp_step, model_input.b_mtp_index.shape[0] + ) dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) # 异步拷贝回 CPU Pin Memory dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu ) - dynamic_mtp_event = torch.cuda.Event() - dynamic_mtp_event.record() + draft_probs_list = [e.view(-1, 1) for e in draft_probs_list] + draft_probs_list = [torch.ones_like(draft_probs_list[-1])] + draft_probs_list + all_next_token_probs = torch.cat(draft_probs_list, dim=-1) # [batch_size, mtp_step + 1] + else: + all_next_token_probs = None mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, @@ -293,6 +330,8 @@ def decode_mtp( all_next_token_ids=all_next_token_ids, b_req_idx=model_input.b_req_idx, mtp_accept_len=mtp_accept_len, + req_to_next_token_probs=self.model.req_manager.req_sampling_params_manager.req_to_next_token_probs, + all_next_token_probs=all_next_token_probs, ) next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem( @@ -315,24 +354,56 @@ def decode_mtp( # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() - self._update_mtp_verify_token_num(decode_reqs=decode_reqs) + + if self.enable_dynamic_mtp: + trans_dynamic_model_input_event.synchronize() + selected_run_reqs_cpu_numpy = selected_run_reqs_cpu.numpy() + run_reqs = [run_reqs[i] for i in range(len(run_reqs)) if selected_run_reqs_cpu_numpy[i] == 1] + + if self.enable_dynamic_mtp: + self._update_mtp_verify_token_num(decode_reqs=decode_reqs, dynamic_mtp_run_reqs=run_reqs) + else: + self._update_mtp_verify_token_num(decode_reqs=decode_reqs) verify_event.synchronize() accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] - if self.enable_dynamic_mtp: - dynamic_mtp_event.synchronize() - self._update_dynamic_mtp_size_cpu_part( - run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu - ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() - # 处理需要释放的内存索引 - need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] + if self.enable_dynamic_mtp: + # 更新 动态verify token 数据到 planner 中去 + self._update_dynamic_mtp_size_statics( + run_reqs=run_reqs, + dynamic_sizes_cpu=dynamic_sizes_cpu, + accepted_index_cpu=accepted_index_cpu, + ) + # 更新单token的速度信息到 planner 中去 + per_token_cost_ms = start_time_event.elapsed_time(verify_event) / (mtp_accept_len_cpu.sum().item()) + + self.dynamic_mtp_planner.update_req_num_speed_statics( + req_num=len(decode_reqs), + dynamic_batch_size=dynamic_batch_size, + per_token_cost_ms=per_token_cost_ms, + ) + + # 处理需要释放的内存索引。动态 MTP trim 只裁了 GPU 侧 model_input, + # 因此这里统一基于原始的 origin_mem_indexes_cpu 计算: + # 1. 被选中但 verify 未通过的 index + # 2. 未被选中的 index + if self.enable_dynamic_mtp: + selected_run_reqs_cpu_mask = selected_run_reqs_cpu.to(dtype=torch.bool) + selected_mem_indexes_cpu = origin_mem_indexes_cpu[selected_run_reqs_cpu_mask] + need_free_mem_indexes = selected_mem_indexes_cpu[accepted_index_cpu == 0] + + trim_free_mem_indexes = origin_mem_indexes_cpu[~selected_run_reqs_cpu_mask] + if len(trim_free_mem_indexes) > 0: + need_free_mem_indexes = torch.cat([need_free_mem_indexes, trim_free_mem_indexes], dim=0) + else: + need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] if additional_mem_indexes_cpu is not None: need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0) @@ -365,17 +436,22 @@ def _compute_dynamic_mtp_size_gpu_part( dynamic_mtp_sizes = valid_steps.sum(dim=0) return dynamic_mtp_sizes - def _update_dynamic_mtp_size_cpu_part( + def _update_dynamic_mtp_size_statics( self, run_reqs: List[InferReq], dynamic_sizes_cpu: torch.Tensor, accepted_index_cpu: torch.Tensor, ): + id_to_verify_len = {} + assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): if int(accepted) == 1: - req.current_mtp_step = int(new_size) - assert req.current_mtp_step <= req.mtp_step + assert int(new_size) <= req.mtp_step + id_to_verify_len[req.req_idx] = int(new_size) + 1 + + self.dynamic_mtp_planner.update_req_verify_len_statics(verify_lens=list(id_to_verify_len.values())) + return def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py new file mode 100644 index 000000000..11c0c93d2 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py @@ -0,0 +1,99 @@ +import random +import math +import numpy as np +from typing import Dict, List, Optional + + +class DynamicMTPPlanner: + def __init__( + self, + mtp_step: int, + ema_decay: float = 0.9, + use_random_mode: bool = True, + random_mode_iter_threshold: int = 100, + ) -> None: + self.mtp_step = mtp_step + self.ema_decay = ema_decay + + # 记录每个请求的 verify_len 的 ema 值, 用于分布统计 + self.req_verify_len_ema = _EMAValue(self.ema_decay, init_value=float(self.mtp_step + 1)) + self.req_verify_len_second_moment_ema = _EMAValue(self.ema_decay, init_value=float((self.mtp_step + 1) ** 2)) + + # 每多少个请求采用随机的方式决定 dynamic_batch_size + self._iter = 0 + self._iter_threshold = random_mode_iter_threshold + self._use_random_mode = use_random_mode + random.seed(0) + + # 记录每个请求对应的单token速度记录信息 + self.req_num_to_speed_dict: Dict[int, List[_EMAValue]] = {} + return + + def update_req_verify_len_statics(self, verify_lens: List[int]) -> None: + for verify_len in verify_lens: + self.req_verify_len_ema.update(float(verify_len)) + self.req_verify_len_second_moment_ema.update(float(verify_len ** 2)) + return + + def update_req_num_speed_statics(self, req_num: int, dynamic_batch_size: int, per_token_cost_ms: float) -> None: + speed_ema_list = self._get_req_num_speed_ema_list(req_num) + index = dynamic_batch_size - req_num + ema_obj = speed_ema_list[index] + ema_obj.update(per_token_cost_ms) + return + + def get_dynamic_batch_size(self, req_num: int, original_batch_size: int) -> int: + assert req_num * (self.mtp_step + 1) == original_batch_size + # case 1 如果采用随机的方式决定 dynamic_batch_size + if self._use_random_mode and self._iter % self._iter_threshold == 0: + sigma = self._get_verify_len_sigma() + max_batch_size = min(req_num * (self.mtp_step + 1), req_num * int(self.req_verify_len_ema.get() + 1 * sigma)) + max_batch_size = max(req_num, max_batch_size) + return random.randint(req_num, max_batch_size) + self._iter += 1 + + # case 2 如果采用统计的方式决定 dynamic_batch_size, 利用统计的 ema 信息来决定 + ema_batch_size = max(req_num, int(self.req_verify_len_ema.get())) + sigma = self._get_verify_len_sigma() + max_batch_size = min(req_num * (self.mtp_step + 1), req_num * int(self.req_verify_len_ema.get() + 1 * sigma)) + max_batch_size = max(req_num, max_batch_size) + + start = req_num - req_num + end = max_batch_size - req_num + ema_index = ema_batch_size - req_num + speed_ema_list = self._get_req_num_speed_ema_list(req_num=req_num) + speeds = [obj.get() for obj in speed_ema_list[start : (end + 1)]] + # 对于期望均值的位置,我们稍微降低下其数据,便于在 ema 统计的数值不够多和准确的时候,更好的选中期望位置, + # 以获得平均性能。 + speeds[ema_index] -= 0.001 + min_index = np.argmin(speeds) + min_cost_batch_size = min_index + req_num + return min_cost_batch_size + + def _get_req_num_speed_ema_list(self, req_num: int) -> List["_EMAValue"]: + if req_num not in self.req_num_to_speed_dict: + self.req_num_to_speed_dict[req_num] = [ + _EMAValue(decay=self.ema_decay, init_value=10000000.0) for _ in range(req_num * (self.mtp_step)) + ] + return self.req_num_to_speed_dict[req_num] + + def _get_verify_len_sigma(self) -> float: + return math.sqrt(max(0.0, self.req_verify_len_second_moment_ema.get() - self.req_verify_len_ema.get() ** 2)) + + +class _EMAValue: + def __init__(self, decay: float, init_value: float) -> None: + """ """ + assert decay > 0.0 and decay < 1.0 + self.decay = decay + self.current_decay = 0.0 + self.value = init_value + + def update(self, new_value: float) -> float: + self.value = self.current_decay * self.value + (1.0 - self.current_decay) * new_value + # 更新 current_decay 的值,使得 current_decay 逐渐逼近 decay 的值 + self.current_decay = min(self.decay, (self.decay + self.current_decay) / 2.0 + 0.001) + return self.value + + def get(self) -> float: + return self.value diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index f3ad03662..fddfa60f0 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,5 +1,5 @@ import torch -from typing import List, Tuple +from typing import List, Tuple, Optional from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context @@ -7,7 +7,35 @@ from lightllm.utils.envs_utils import get_env_start_args -def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): +def sample( + logits: torch.Tensor, + reqs: List[InferReq], + eos_id: List[int] = [2], + selected_run_reqs: Optional[torch.Tensor] = None, +): + if selected_run_reqs is not None: + assert selected_run_reqs.is_cuda + assert selected_run_reqs.numel() == len(reqs) + selected_index = torch.nonzero(selected_run_reqs.to(torch.bool), as_tuple=False).flatten() + assert logits.shape[0] == selected_index.numel(), f"{logits.shape[0]} vs {selected_index.numel()}" + + padded_logits = torch.zeros( + (len(reqs), logits.shape[1]), + dtype=logits.dtype, + device=logits.device, + ) + padded_logits.index_copy_(0, selected_index, logits) + + full_next_token_ids, full_next_token_logprobs = sample( + padded_logits, + reqs, + eos_id=eos_id, + selected_run_reqs=None, + ) + next_token_ids = torch.index_select(full_next_token_ids, dim=0, index=selected_index) + next_token_logprobs = torch.index_select(full_next_token_logprobs, dim=0, index=selected_index) + return next_token_ids, next_token_logprobs + ( b_req_idx, b_temperatures, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 3d9d8815e..c50e0b7c3 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -113,9 +113,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) # process the draft tokens. - # 动态 MTP 模式:使用动态 current_mtp_step 构建 batch - # 非动态 MTP 模式:current_mtp_step 为固定的 mtp_step - for step in range(req.current_mtp_step): + for step in range(req.mtp_step): run_reqs.append(req) b_req_idx.append(req.req_idx) seq_len += 1 diff --git a/test/benchmark/service/benchmark_sharegpt.py b/test/benchmark/service/benchmark_sharegpt.py index b056e69bd..9337e6527 100644 --- a/test/benchmark/service/benchmark_sharegpt.py +++ b/test/benchmark/service/benchmark_sharegpt.py @@ -215,7 +215,7 @@ async def send_request( "top_k": 1, "top_p": 1.0, "temperature": 0, - "stream": True, + # "stream": True, "ignore_eos": True, "max_tokens": output_len, } @@ -224,20 +224,41 @@ async def send_request( async with aiohttp.ClientSession(timeout=timeout) as session: async with session.post(url, headers=headers, json=data) as response: + response.raise_for_status() chunks = [] text = "" start_time = time.time() is_first = True + sse_buffer = "" async for chunk, _ in response.content.iter_chunks(): now_time = time.time() delta_time = now_time - start_time if is_first: is_first = False ttft = delta_time - text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") - if delta_time < 0.005: - receive_n += 1 chunks.append(delta_time) + # OpenAI-compatible stream is SSE; one TCP chunk may contain + # partial/multiple events. Parse by complete lines safely. + sse_buffer += chunk.decode("utf-8", errors="ignore") + while "\n" in sse_buffer: + line, sse_buffer = sse_buffer.split("\n", 1) + line = line.strip() + if not line or not line.startswith("data:"): + continue + payload = line[5:].strip() + if payload == "[DONE]": + break + if not payload: + continue + try: + event = json.loads(payload) + except json.JSONDecodeError: + # In rare cases malformed/partial payload slips in; + # skip and continue to keep benchmark running. + continue + text += event.get("choices", [{}])[0].get("delta", {}).get("content", "") + if delta_time < 0.005: + receive_n += 1 start_time = now_time # print("messages", messages) # print("text", text) diff --git a/test/speculative/bench_throughput.sh b/test/speculative/bench_throughput.sh index 8e14f8189..4cfa90bcb 100644 --- a/test/speculative/bench_throughput.sh +++ b/test/speculative/bench_throughput.sh @@ -2,7 +2,7 @@ # 默认值 PORT=8088 NUM_PROMPTS=1000 -TOKENIZER="/mtc/models/qwen3-8b" +TOKENIZER="/mtc/models/qwen3-32b" DATASET="/data/nvme0/chenjunyi/project/lightllm/datasets/gsm8k.json" HISTORY_TURNS=1 CONCURRENCY=128 diff --git a/test/speculative/qwen3-32b/dynamic_triton.sh b/test/speculative/qwen3-32b/dynamic_triton.sh index 39145e5f5..6af1c2ca4 100644 --- a/test/speculative/qwen3-32b/dynamic_triton.sh +++ b/test/speculative/qwen3-32b/dynamic_triton.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --disable_dynamic_prompt_cache \ @@ -25,7 +27,8 @@ LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --graph_grow_step_size 1 \ --mtp_step ${MTP_STEP} \ --llm_decode_att_backend triton \ ---mtp_dynamic_verify +--mtp_dynamic_verify \ +--disable_cudagraph # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ #--enable_decode_microbatch_overlap \ diff --git a/test/speculative/qwen3-32b/no_mtp_fa3.sh b/test/speculative/qwen3-32b/no_mtp_fa3.sh new file mode 100644 index 000000000..c17562721 --- /dev/null +++ b/test/speculative/qwen3-32b/no_mtp_fa3.sh @@ -0,0 +1,11 @@ +MODEL_DIR=/mtc/models/qwen3-32b +DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 + +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ +--model_dir ${MODEL_DIR} \ +--disable_dynamic_prompt_cache \ +--graph_grow_step_size 1 \ +--llm_decode_att_backend triton \ No newline at end of file diff --git a/test/speculative/qwen3-32b/static_fa3.sh b/test/speculative/qwen3-32b/static_fa3.sh index c9712116e..44c67e03b 100644 --- a/test/speculative/qwen3-32b/static_fa3.sh +++ b/test/speculative/qwen3-32b/static_fa3.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --mtp_draft_model_dir ${DRAFT_MODEL_DIR} \ diff --git a/test/speculative/qwen3-32b/static_triton.sh b/test/speculative/qwen3-32b/static_triton.sh index 453c5678e..71964c9af 100644 --- a/test/speculative/qwen3-32b/static_triton.sh +++ b/test/speculative/qwen3-32b/static_triton.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --disable_dynamic_prompt_cache \ diff --git a/test/speculative/run_vllm_speculative_baseline.sh b/test/speculative/run_vllm_speculative_baseline.sh new file mode 100755 index 000000000..f2027f20c --- /dev/null +++ b/test/speculative/run_vllm_speculative_baseline.sh @@ -0,0 +1,298 @@ +#!/bin/bash + +# ============================================================================= +# vLLM Speculative Decoding Baseline Experiment Script +# Function: Run vLLM default draft-model speculative decoding baseline for +# different mtp steps (mapped to num_speculative_tokens), and collect +# throughput/latency metrics with the same benchmark script. +# ============================================================================= + +set -euo pipefail + +# Keep default GPU visibility aligned with existing LightLLM experiment scripts. +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-1,2,3,4,6}" +# Reduce allocator fragmentation risk during model warmup. +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# ============================================================================= +# Configurable Parameters +# ============================================================================= +PROJECT_DIR="/data/nvme0/chenjunyi/project/lightllm" +BENCH_PY_SCRIPT="${PROJECT_DIR}/test/benchmark/service/benchmark_sharegpt.py" +DATASET="${PROJECT_DIR}/datasets/gsm8k.json" + +# Keep defaults close to existing LightLLM qwen3-32b setup. +MODEL_DIR="/mtc/models/qwen3-32b" +DRAFT_MODEL_DIR="/mtc/models/qwen3-32b-eagle3" +TOKENIZER="/mtc/models/qwen3-32b" + +SAMPLES=1000 +CONCURRENCY=256 +PORT=8088 +TP=4 +MAX_MODEL_LEN=16384 +MAX_NUM_BATCHED_TOKENS=200000 +MAX_NUM_SEQS=256 +GPU_MEMORY_UTILIZATION=0.6 +MAX_CUDAGRAPH_CAPTURE_SIZE=256 +ATTENTION_BACKEND="FLASH_ATTN" +DISABLE_CUSTOM_ALL_REDUCE=1 +MTP_STEPS=(5) + +RESULTS_DIR="${PROJECT_DIR}/experiment_results" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +DATASET_NAME=$(basename "${DATASET}" .json) +EXPERIMENT_SUBDIR="${RESULTS_DIR}/${DATASET_NAME}_${TIMESTAMP}_vllm_spec_default" +RESULTS_FILE="${EXPERIMENT_SUBDIR}/results.csv" + +usage() { + echo "Usage: $0 [options]" + echo "" + echo "Options:" + echo " --model-dir PATH Main model path (default: ${MODEL_DIR})" + echo " --draft-model-dir PATH Draft model path (default: ${DRAFT_MODEL_DIR})" + echo " --dataset PATH Dataset path (default: ${DATASET})" + echo " --tokenizer PATH Tokenizer path (default: ${TOKENIZER})" + echo " --samples NUM Number of prompts (default: ${SAMPLES})" + echo " --concurrency NUM Concurrency (default: ${CONCURRENCY})" + echo " --port PORT Service port (default: ${PORT})" + echo " --tp NUM Tensor parallel size (default: ${TP})" + echo " --mtp-steps LIST Comma-separated mtp steps (default: 5)" + echo " --num-speculative-tokens NUM Backward-compatible alias, equals one mtp step" + echo " --max-model-len NUM vLLM max model len (default: ${MAX_MODEL_LEN})" + echo " --max-num-batched-tokens NUM vLLM max batched tokens (default: ${MAX_NUM_BATCHED_TOKENS})" + echo " --max-num-seqs NUM vLLM max number of concurrent seqs (default: ${MAX_NUM_SEQS})" + echo " --max-cudagraph-capture-size NUM vLLM max cudagraph capture size (default: ${MAX_CUDAGRAPH_CAPTURE_SIZE})" + echo " --gpu-memory-utilization F GPU memory utilization (default: ${GPU_MEMORY_UTILIZATION})" + echo " --attention-backend NAME vLLM attention backend (default: ${ATTENTION_BACKEND})" + echo " --enable-custom-all-reduce Enable custom all-reduce (default: disabled)" + echo " --results-dir DIR Results base dir (default: ${RESULTS_DIR})" + echo " --help Show this help" + exit 1 +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --model-dir) + MODEL_DIR="$2" + shift 2 + ;; + --draft-model-dir) + DRAFT_MODEL_DIR="$2" + shift 2 + ;; + --dataset) + DATASET="$2" + shift 2 + ;; + --tokenizer) + TOKENIZER="$2" + shift 2 + ;; + --samples) + SAMPLES="$2" + shift 2 + ;; + --concurrency) + CONCURRENCY="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --tp) + TP="$2" + shift 2 + ;; + --mtp-steps) + IFS=',' read -ra MTP_STEPS <<< "$2" + shift 2 + ;; + --num-speculative-tokens) + MTP_STEPS=("$2") + shift 2 + ;; + --max-model-len) + MAX_MODEL_LEN="$2" + shift 2 + ;; + --max-num-batched-tokens) + MAX_NUM_BATCHED_TOKENS="$2" + shift 2 + ;; + --max-num-seqs) + MAX_NUM_SEQS="$2" + shift 2 + ;; + --max-cudagraph-capture-size) + MAX_CUDAGRAPH_CAPTURE_SIZE="$2" + shift 2 + ;; + --gpu-memory-utilization) + GPU_MEMORY_UTILIZATION="$2" + shift 2 + ;; + --attention-backend) + ATTENTION_BACKEND="$2" + shift 2 + ;; + --enable-custom-all-reduce) + DISABLE_CUSTOM_ALL_REDUCE=0 + shift 1 + ;; + --results-dir) + RESULTS_DIR="$2" + shift 2 + ;; + --help) + usage + ;; + *) + echo "Unknown argument: $1" + usage + ;; + esac +done + +# Recompute result paths in case dataset/results-dir was overridden. +DATASET_NAME=$(basename "${DATASET}" .json) +EXPERIMENT_SUBDIR="${RESULTS_DIR}/${DATASET_NAME}_${TIMESTAMP}_vllm_spec_default" +RESULTS_FILE="${EXPERIMENT_SUBDIR}/results.csv" + +mkdir -p "${EXPERIMENT_SUBDIR}" + +echo "timestamp,engine,mode,mtp_step,dataset,samples,concurrency,throughput,avg_latency,avg_ttft,avg_inter_token_latency" > "${RESULTS_FILE}" + +wait_for_server() { + local max_attempts=600 + local attempt=0 + echo "Waiting for vLLM server to start..." + while [[ ${attempt} -lt ${max_attempts} ]]; do + if curl -s "http://localhost:${PORT}/health" > /dev/null 2>&1; then + echo "vLLM server started" + return 0 + fi + sleep 2 + attempt=$((attempt + 1)) + done + echo "vLLM server startup timeout" + return 1 +} + +extract_benchmark_metrics() { + local log_file="$1" + local throughput="" + local avg_latency="" + local avg_ttft="" + local avg_inter_token_latency="" + + throughput=$(grep -oP 'Throughput: \K[\d.]+' "$log_file" | tail -1) + avg_latency=$(grep -oP 'Average latency: \K[\d.]+' "$log_file" | tail -1) + avg_ttft=$(grep -oP 'Average time to first token: \K[\d.]+' "$log_file" | tail -1) + avg_inter_token_latency=$(grep -oP 'Average inter-token latency: \K[\d.]+' "$log_file" | tail -1) + + echo "${throughput:-NA},${avg_latency:-NA},${avg_ttft:-NA},${avg_inter_token_latency:-NA}" +} + +kill_vllm() { + echo "Stopping vLLM server..." + pkill -9 -f "vllm serve" 2>/dev/null || true + pkill -9 -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 1 + echo "vLLM server stopped" +} + +trap 'kill_vllm' EXIT + +echo "==============================================" +echo "vLLM Speculative Baseline Started" +echo "==============================================" +echo "Model: ${MODEL_DIR}" +echo "Draft model: ${DRAFT_MODEL_DIR}" +echo "Tokenizer: ${TOKENIZER}" +echo "Dataset: ${DATASET}" +echo "Samples: ${SAMPLES}" +echo "Concurrency: ${CONCURRENCY}" +echo "TP: ${TP}" +echo "Port: ${PORT}" +echo "Max model len: ${MAX_MODEL_LEN}" +echo "Max batched tokens: ${MAX_NUM_BATCHED_TOKENS}" +echo "Max num seqs: ${MAX_NUM_SEQS}" +echo "Max cudagraph capture size: ${MAX_CUDAGRAPH_CAPTURE_SIZE}" +echo "GPU memory utilization: ${GPU_MEMORY_UTILIZATION}" +echo "Attention backend: ${ATTENTION_BACKEND}" +echo "Disable custom all reduce: ${DISABLE_CUSTOM_ALL_REDUCE}" +echo "MTP steps: ${MTP_STEPS[*]}" +echo "Results directory: ${EXPERIMENT_SUBDIR}" +echo "==============================================" + +for MTP_STEP in "${MTP_STEPS[@]}"; do + echo "" + echo "--- Running mtp step: ${MTP_STEP} ---" + + LOG_FILE="${EXPERIMENT_SUBDIR}/log_vllm_spec_default_step${MTP_STEP}_${TIMESTAMP}.txt" + BENCH_LOG="${EXPERIMENT_SUBDIR}/bench_vllm_spec_default_step${MTP_STEP}_${TIMESTAMP}.txt" + + SPECULATIVE_CONFIG=$(printf '{"model": "%s", "num_speculative_tokens": %s, "method": "draft_model"}' \ + "${DRAFT_MODEL_DIR}" "${MTP_STEP}") + CUSTOM_ALL_REDUCE_FLAG="" + if [[ "${DISABLE_CUSTOM_ALL_REDUCE}" == "1" ]]; then + CUSTOM_ALL_REDUCE_FLAG="--disable-custom-all-reduce" + fi + + kill_vllm + + echo "Starting vLLM server with speculative_config=${SPECULATIVE_CONFIG}" + ( + vllm serve "${MODEL_DIR}" \ + --host 0.0.0.0 \ + --port "${PORT}" \ + --served-model-name DeepSeek-R1 \ + -tp "${TP}" \ + --max_model_len "${MAX_MODEL_LEN}" \ + --max_num_batched_tokens "${MAX_NUM_BATCHED_TOKENS}" \ + --max_num_seqs "${MAX_NUM_SEQS}" \ + --max-cudagraph-capture-size "${MAX_CUDAGRAPH_CAPTURE_SIZE}" \ + --attention-backend "${ATTENTION_BACKEND}" \ + ${CUSTOM_ALL_REDUCE_FLAG} \ + --speculative_config "${SPECULATIVE_CONFIG}" + ) > "${LOG_FILE}" 2>&1 & + + SERVER_PID=$! + echo "vLLM PID: ${SERVER_PID}" + + if ! wait_for_server; then + echo "vLLM server failed to start for mtp step ${MTP_STEP}. Check log: ${LOG_FILE}" + RESULT_LINE="${TIMESTAMP},vllm,speculative_draft_model_default,${MTP_STEP},${DATASET},${SAMPLES},${CONCURRENCY},NA,NA,NA,NA" + echo "${RESULT_LINE}" >> "${RESULTS_FILE}" + continue + fi + + sleep 5 + + echo "Running benchmark with benchmark_sharegpt.py (OpenAI API mode)..." + python "${BENCH_PY_SCRIPT}" \ + --use_openai_api \ + --port "${PORT}" \ + --num-prompts "${SAMPLES}" \ + --tokenizer "${TOKENIZER}" \ + --dataset "${DATASET}" \ + --history-turns 1 \ + --concurrency "${CONCURRENCY}" 2>&1 | tee "${BENCH_LOG}" + + cat "${BENCH_LOG}" >> "${LOG_FILE}" + + BENCH_METRICS=$(extract_benchmark_metrics "${LOG_FILE}") + RESULT_LINE="${TIMESTAMP},vllm,speculative_draft_model_default,${MTP_STEP},${DATASET},${SAMPLES},${CONCURRENCY},${BENCH_METRICS}" + echo "${RESULT_LINE}" >> "${RESULTS_FILE}" + + echo "Completed mtp step ${MTP_STEP}: ${RESULT_LINE}" +done + +echo "" +echo "==============================================" +echo "All Experiments Completed" +echo "==============================================" +echo "Results file: ${RESULTS_FILE}" +cat "${RESULTS_FILE}"