From 74189c27672cd850d4b6370efcc2e8611e18da11 Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Mon, 20 Apr 2026 20:23:52 +0800 Subject: [PATCH 01/20] add vllm test script --- .gitignore | 1 + test/benchmark/service/benchmark_sharegpt.py | 29 +- test/speculative/bench_throughput.sh | 2 +- .../run_vllm_speculative_baseline.sh | 298 ++++++++++++++++++ 4 files changed, 325 insertions(+), 5 deletions(-) create mode 100755 test/speculative/run_vllm_speculative_baseline.sh diff --git a/.gitignore b/.gitignore index d572eac42..b1717ce67 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__/ .pyc +.codex build dist *.egg-info diff --git a/test/benchmark/service/benchmark_sharegpt.py b/test/benchmark/service/benchmark_sharegpt.py index b056e69bd..9337e6527 100644 --- a/test/benchmark/service/benchmark_sharegpt.py +++ b/test/benchmark/service/benchmark_sharegpt.py @@ -215,7 +215,7 @@ async def send_request( "top_k": 1, "top_p": 1.0, "temperature": 0, - "stream": True, + # "stream": True, "ignore_eos": True, "max_tokens": output_len, } @@ -224,20 +224,41 @@ async def send_request( async with aiohttp.ClientSession(timeout=timeout) as session: async with session.post(url, headers=headers, json=data) as response: + response.raise_for_status() chunks = [] text = "" start_time = time.time() is_first = True + sse_buffer = "" async for chunk, _ in response.content.iter_chunks(): now_time = time.time() delta_time = now_time - start_time if is_first: is_first = False ttft = delta_time - text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") - if delta_time < 0.005: - receive_n += 1 chunks.append(delta_time) + # OpenAI-compatible stream is SSE; one TCP chunk may contain + # partial/multiple events. Parse by complete lines safely. + sse_buffer += chunk.decode("utf-8", errors="ignore") + while "\n" in sse_buffer: + line, sse_buffer = sse_buffer.split("\n", 1) + line = line.strip() + if not line or not line.startswith("data:"): + continue + payload = line[5:].strip() + if payload == "[DONE]": + break + if not payload: + continue + try: + event = json.loads(payload) + except json.JSONDecodeError: + # In rare cases malformed/partial payload slips in; + # skip and continue to keep benchmark running. + continue + text += event.get("choices", [{}])[0].get("delta", {}).get("content", "") + if delta_time < 0.005: + receive_n += 1 start_time = now_time # print("messages", messages) # print("text", text) diff --git a/test/speculative/bench_throughput.sh b/test/speculative/bench_throughput.sh index 8e14f8189..4cfa90bcb 100644 --- a/test/speculative/bench_throughput.sh +++ b/test/speculative/bench_throughput.sh @@ -2,7 +2,7 @@ # 默认值 PORT=8088 NUM_PROMPTS=1000 -TOKENIZER="/mtc/models/qwen3-8b" +TOKENIZER="/mtc/models/qwen3-32b" DATASET="/data/nvme0/chenjunyi/project/lightllm/datasets/gsm8k.json" HISTORY_TURNS=1 CONCURRENCY=128 diff --git a/test/speculative/run_vllm_speculative_baseline.sh b/test/speculative/run_vllm_speculative_baseline.sh new file mode 100755 index 000000000..f2027f20c --- /dev/null +++ b/test/speculative/run_vllm_speculative_baseline.sh @@ -0,0 +1,298 @@ +#!/bin/bash + +# ============================================================================= +# vLLM Speculative Decoding Baseline Experiment Script +# Function: Run vLLM default draft-model speculative decoding baseline for +# different mtp steps (mapped to num_speculative_tokens), and collect +# throughput/latency metrics with the same benchmark script. +# ============================================================================= + +set -euo pipefail + +# Keep default GPU visibility aligned with existing LightLLM experiment scripts. +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-1,2,3,4,6}" +# Reduce allocator fragmentation risk during model warmup. +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# ============================================================================= +# Configurable Parameters +# ============================================================================= +PROJECT_DIR="/data/nvme0/chenjunyi/project/lightllm" +BENCH_PY_SCRIPT="${PROJECT_DIR}/test/benchmark/service/benchmark_sharegpt.py" +DATASET="${PROJECT_DIR}/datasets/gsm8k.json" + +# Keep defaults close to existing LightLLM qwen3-32b setup. +MODEL_DIR="/mtc/models/qwen3-32b" +DRAFT_MODEL_DIR="/mtc/models/qwen3-32b-eagle3" +TOKENIZER="/mtc/models/qwen3-32b" + +SAMPLES=1000 +CONCURRENCY=256 +PORT=8088 +TP=4 +MAX_MODEL_LEN=16384 +MAX_NUM_BATCHED_TOKENS=200000 +MAX_NUM_SEQS=256 +GPU_MEMORY_UTILIZATION=0.6 +MAX_CUDAGRAPH_CAPTURE_SIZE=256 +ATTENTION_BACKEND="FLASH_ATTN" +DISABLE_CUSTOM_ALL_REDUCE=1 +MTP_STEPS=(5) + +RESULTS_DIR="${PROJECT_DIR}/experiment_results" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +DATASET_NAME=$(basename "${DATASET}" .json) +EXPERIMENT_SUBDIR="${RESULTS_DIR}/${DATASET_NAME}_${TIMESTAMP}_vllm_spec_default" +RESULTS_FILE="${EXPERIMENT_SUBDIR}/results.csv" + +usage() { + echo "Usage: $0 [options]" + echo "" + echo "Options:" + echo " --model-dir PATH Main model path (default: ${MODEL_DIR})" + echo " --draft-model-dir PATH Draft model path (default: ${DRAFT_MODEL_DIR})" + echo " --dataset PATH Dataset path (default: ${DATASET})" + echo " --tokenizer PATH Tokenizer path (default: ${TOKENIZER})" + echo " --samples NUM Number of prompts (default: ${SAMPLES})" + echo " --concurrency NUM Concurrency (default: ${CONCURRENCY})" + echo " --port PORT Service port (default: ${PORT})" + echo " --tp NUM Tensor parallel size (default: ${TP})" + echo " --mtp-steps LIST Comma-separated mtp steps (default: 5)" + echo " --num-speculative-tokens NUM Backward-compatible alias, equals one mtp step" + echo " --max-model-len NUM vLLM max model len (default: ${MAX_MODEL_LEN})" + echo " --max-num-batched-tokens NUM vLLM max batched tokens (default: ${MAX_NUM_BATCHED_TOKENS})" + echo " --max-num-seqs NUM vLLM max number of concurrent seqs (default: ${MAX_NUM_SEQS})" + echo " --max-cudagraph-capture-size NUM vLLM max cudagraph capture size (default: ${MAX_CUDAGRAPH_CAPTURE_SIZE})" + echo " --gpu-memory-utilization F GPU memory utilization (default: ${GPU_MEMORY_UTILIZATION})" + echo " --attention-backend NAME vLLM attention backend (default: ${ATTENTION_BACKEND})" + echo " --enable-custom-all-reduce Enable custom all-reduce (default: disabled)" + echo " --results-dir DIR Results base dir (default: ${RESULTS_DIR})" + echo " --help Show this help" + exit 1 +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --model-dir) + MODEL_DIR="$2" + shift 2 + ;; + --draft-model-dir) + DRAFT_MODEL_DIR="$2" + shift 2 + ;; + --dataset) + DATASET="$2" + shift 2 + ;; + --tokenizer) + TOKENIZER="$2" + shift 2 + ;; + --samples) + SAMPLES="$2" + shift 2 + ;; + --concurrency) + CONCURRENCY="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --tp) + TP="$2" + shift 2 + ;; + --mtp-steps) + IFS=',' read -ra MTP_STEPS <<< "$2" + shift 2 + ;; + --num-speculative-tokens) + MTP_STEPS=("$2") + shift 2 + ;; + --max-model-len) + MAX_MODEL_LEN="$2" + shift 2 + ;; + --max-num-batched-tokens) + MAX_NUM_BATCHED_TOKENS="$2" + shift 2 + ;; + --max-num-seqs) + MAX_NUM_SEQS="$2" + shift 2 + ;; + --max-cudagraph-capture-size) + MAX_CUDAGRAPH_CAPTURE_SIZE="$2" + shift 2 + ;; + --gpu-memory-utilization) + GPU_MEMORY_UTILIZATION="$2" + shift 2 + ;; + --attention-backend) + ATTENTION_BACKEND="$2" + shift 2 + ;; + --enable-custom-all-reduce) + DISABLE_CUSTOM_ALL_REDUCE=0 + shift 1 + ;; + --results-dir) + RESULTS_DIR="$2" + shift 2 + ;; + --help) + usage + ;; + *) + echo "Unknown argument: $1" + usage + ;; + esac +done + +# Recompute result paths in case dataset/results-dir was overridden. +DATASET_NAME=$(basename "${DATASET}" .json) +EXPERIMENT_SUBDIR="${RESULTS_DIR}/${DATASET_NAME}_${TIMESTAMP}_vllm_spec_default" +RESULTS_FILE="${EXPERIMENT_SUBDIR}/results.csv" + +mkdir -p "${EXPERIMENT_SUBDIR}" + +echo "timestamp,engine,mode,mtp_step,dataset,samples,concurrency,throughput,avg_latency,avg_ttft,avg_inter_token_latency" > "${RESULTS_FILE}" + +wait_for_server() { + local max_attempts=600 + local attempt=0 + echo "Waiting for vLLM server to start..." + while [[ ${attempt} -lt ${max_attempts} ]]; do + if curl -s "http://localhost:${PORT}/health" > /dev/null 2>&1; then + echo "vLLM server started" + return 0 + fi + sleep 2 + attempt=$((attempt + 1)) + done + echo "vLLM server startup timeout" + return 1 +} + +extract_benchmark_metrics() { + local log_file="$1" + local throughput="" + local avg_latency="" + local avg_ttft="" + local avg_inter_token_latency="" + + throughput=$(grep -oP 'Throughput: \K[\d.]+' "$log_file" | tail -1) + avg_latency=$(grep -oP 'Average latency: \K[\d.]+' "$log_file" | tail -1) + avg_ttft=$(grep -oP 'Average time to first token: \K[\d.]+' "$log_file" | tail -1) + avg_inter_token_latency=$(grep -oP 'Average inter-token latency: \K[\d.]+' "$log_file" | tail -1) + + echo "${throughput:-NA},${avg_latency:-NA},${avg_ttft:-NA},${avg_inter_token_latency:-NA}" +} + +kill_vllm() { + echo "Stopping vLLM server..." + pkill -9 -f "vllm serve" 2>/dev/null || true + pkill -9 -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 1 + echo "vLLM server stopped" +} + +trap 'kill_vllm' EXIT + +echo "==============================================" +echo "vLLM Speculative Baseline Started" +echo "==============================================" +echo "Model: ${MODEL_DIR}" +echo "Draft model: ${DRAFT_MODEL_DIR}" +echo "Tokenizer: ${TOKENIZER}" +echo "Dataset: ${DATASET}" +echo "Samples: ${SAMPLES}" +echo "Concurrency: ${CONCURRENCY}" +echo "TP: ${TP}" +echo "Port: ${PORT}" +echo "Max model len: ${MAX_MODEL_LEN}" +echo "Max batched tokens: ${MAX_NUM_BATCHED_TOKENS}" +echo "Max num seqs: ${MAX_NUM_SEQS}" +echo "Max cudagraph capture size: ${MAX_CUDAGRAPH_CAPTURE_SIZE}" +echo "GPU memory utilization: ${GPU_MEMORY_UTILIZATION}" +echo "Attention backend: ${ATTENTION_BACKEND}" +echo "Disable custom all reduce: ${DISABLE_CUSTOM_ALL_REDUCE}" +echo "MTP steps: ${MTP_STEPS[*]}" +echo "Results directory: ${EXPERIMENT_SUBDIR}" +echo "==============================================" + +for MTP_STEP in "${MTP_STEPS[@]}"; do + echo "" + echo "--- Running mtp step: ${MTP_STEP} ---" + + LOG_FILE="${EXPERIMENT_SUBDIR}/log_vllm_spec_default_step${MTP_STEP}_${TIMESTAMP}.txt" + BENCH_LOG="${EXPERIMENT_SUBDIR}/bench_vllm_spec_default_step${MTP_STEP}_${TIMESTAMP}.txt" + + SPECULATIVE_CONFIG=$(printf '{"model": "%s", "num_speculative_tokens": %s, "method": "draft_model"}' \ + "${DRAFT_MODEL_DIR}" "${MTP_STEP}") + CUSTOM_ALL_REDUCE_FLAG="" + if [[ "${DISABLE_CUSTOM_ALL_REDUCE}" == "1" ]]; then + CUSTOM_ALL_REDUCE_FLAG="--disable-custom-all-reduce" + fi + + kill_vllm + + echo "Starting vLLM server with speculative_config=${SPECULATIVE_CONFIG}" + ( + vllm serve "${MODEL_DIR}" \ + --host 0.0.0.0 \ + --port "${PORT}" \ + --served-model-name DeepSeek-R1 \ + -tp "${TP}" \ + --max_model_len "${MAX_MODEL_LEN}" \ + --max_num_batched_tokens "${MAX_NUM_BATCHED_TOKENS}" \ + --max_num_seqs "${MAX_NUM_SEQS}" \ + --max-cudagraph-capture-size "${MAX_CUDAGRAPH_CAPTURE_SIZE}" \ + --attention-backend "${ATTENTION_BACKEND}" \ + ${CUSTOM_ALL_REDUCE_FLAG} \ + --speculative_config "${SPECULATIVE_CONFIG}" + ) > "${LOG_FILE}" 2>&1 & + + SERVER_PID=$! + echo "vLLM PID: ${SERVER_PID}" + + if ! wait_for_server; then + echo "vLLM server failed to start for mtp step ${MTP_STEP}. Check log: ${LOG_FILE}" + RESULT_LINE="${TIMESTAMP},vllm,speculative_draft_model_default,${MTP_STEP},${DATASET},${SAMPLES},${CONCURRENCY},NA,NA,NA,NA" + echo "${RESULT_LINE}" >> "${RESULTS_FILE}" + continue + fi + + sleep 5 + + echo "Running benchmark with benchmark_sharegpt.py (OpenAI API mode)..." + python "${BENCH_PY_SCRIPT}" \ + --use_openai_api \ + --port "${PORT}" \ + --num-prompts "${SAMPLES}" \ + --tokenizer "${TOKENIZER}" \ + --dataset "${DATASET}" \ + --history-turns 1 \ + --concurrency "${CONCURRENCY}" 2>&1 | tee "${BENCH_LOG}" + + cat "${BENCH_LOG}" >> "${LOG_FILE}" + + BENCH_METRICS=$(extract_benchmark_metrics "${LOG_FILE}") + RESULT_LINE="${TIMESTAMP},vllm,speculative_draft_model_default,${MTP_STEP},${DATASET},${SAMPLES},${CONCURRENCY},${BENCH_METRICS}" + echo "${RESULT_LINE}" >> "${RESULTS_FILE}" + + echo "Completed mtp step ${MTP_STEP}: ${RESULT_LINE}" +done + +echo "" +echo "==============================================" +echo "All Experiments Completed" +echo "==============================================" +echo "Results file: ${RESULTS_FILE}" +cat "${RESULTS_FILE}" From df549356461870e5b93c310240818aeb5b1d166b Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Thu, 7 May 2026 21:10:43 +0800 Subject: [PATCH 02/20] remove 200000 token limit in test script --- test/speculative/qwen3-32b/dynamic_triton.sh | 6 ++++-- test/speculative/qwen3-32b/no_mtp_fa3.sh | 11 +++++++++++ test/speculative/qwen3-32b/static_fa3.sh | 6 ++++-- test/speculative/qwen3-32b/static_triton.sh | 6 ++++-- 4 files changed, 23 insertions(+), 6 deletions(-) create mode 100644 test/speculative/qwen3-32b/no_mtp_fa3.sh diff --git a/test/speculative/qwen3-32b/dynamic_triton.sh b/test/speculative/qwen3-32b/dynamic_triton.sh index 39145e5f5..ca70bf15e 100644 --- a/test/speculative/qwen3-32b/dynamic_triton.sh +++ b/test/speculative/qwen3-32b/dynamic_triton.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --disable_dynamic_prompt_cache \ diff --git a/test/speculative/qwen3-32b/no_mtp_fa3.sh b/test/speculative/qwen3-32b/no_mtp_fa3.sh new file mode 100644 index 000000000..c17562721 --- /dev/null +++ b/test/speculative/qwen3-32b/no_mtp_fa3.sh @@ -0,0 +1,11 @@ +MODEL_DIR=/mtc/models/qwen3-32b +DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 + +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ +--model_dir ${MODEL_DIR} \ +--disable_dynamic_prompt_cache \ +--graph_grow_step_size 1 \ +--llm_decode_att_backend triton \ No newline at end of file diff --git a/test/speculative/qwen3-32b/static_fa3.sh b/test/speculative/qwen3-32b/static_fa3.sh index c9712116e..44c67e03b 100644 --- a/test/speculative/qwen3-32b/static_fa3.sh +++ b/test/speculative/qwen3-32b/static_fa3.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --mtp_draft_model_dir ${DRAFT_MODEL_DIR} \ diff --git a/test/speculative/qwen3-32b/static_triton.sh b/test/speculative/qwen3-32b/static_triton.sh index 453c5678e..71964c9af 100644 --- a/test/speculative/qwen3-32b/static_triton.sh +++ b/test/speculative/qwen3-32b/static_triton.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --disable_dynamic_prompt_cache \ From db0a2ca692c2525b8c5ca8b8a02ed635b2e1470e Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Thu, 14 May 2026 20:18:05 +0800 Subject: [PATCH 03/20] first commit for ema --- .../mode_backend/chunked_prefill/impl.py | 91 ++++---------- .../mode_backend/dynamic_mtp_planner.py | 112 ++++++++++++++++++ .../mode_backend/generic_pre_process.py | 40 +++++-- 3 files changed, 166 insertions(+), 77 deletions(-) create mode 100644 lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index c41dbb6d9..6b1578ae7 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -25,6 +25,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify +from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner from .control_state import ControlState logger = init_logger(__name__) @@ -45,6 +46,9 @@ def __init__(self) -> None: self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla self.enable_dynamic_mtp = enable_dynamic_mtp_verify() + self.dynamic_mtp_planner = ( + DynamicMTPPlanner(max_mtp_step=get_env_start_args().mtp_step) if self.enable_dynamic_mtp else None + ) else: self.prefill = self.prefill_normal self.decode = self.decode_normal @@ -233,7 +237,15 @@ def decode_mtp( """ MTP解码的通用流程,整合eagle和vanilla的共同逻辑 """ - model_input, run_reqs = prepare_decode_inputs(decode_reqs) + mtp_plan = None + if self.enable_dynamic_mtp: + mtp_plan = self.dynamic_mtp_planner.build_plan(decode_reqs) + model_input, run_reqs = prepare_decode_inputs( + decode_reqs, + mtp_decode_indexes=mtp_plan.selected_mtp_indexes, + ) + else: + model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index @@ -260,32 +272,12 @@ def decode_mtp( verify_event = torch.cuda.Event() verify_event.record() - if self.enable_dynamic_mtp: - all_next_token_ids, additional_mem_indexes_cpu, draft_probs_list = self._draft_decode_func( - main_model_input=model_input, - main_model_output=model_output, - next_token_ids=next_token_ids, - b_req_mtp_start_loc=b_req_mtp_start_loc, - ) - else: - all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( - main_model_input=model_input, - main_model_output=model_output, - next_token_ids=next_token_ids, - b_req_mtp_start_loc=b_req_mtp_start_loc, - ) - - # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size - if self.enable_dynamic_mtp: - draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) - dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) - # 异步拷贝回 CPU Pin Memory - dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( - key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu - ) - - dynamic_mtp_event = torch.cuda.Event() - dynamic_mtp_event.record() + all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( + main_model_input=model_input, + main_model_output=model_output, + next_token_ids=next_token_ids, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, @@ -320,11 +312,6 @@ def decode_mtp( verify_event.synchronize() accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] - if self.enable_dynamic_mtp: - dynamic_mtp_event.synchronize() - self._update_dynamic_mtp_size_cpu_part( - run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu - ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 @@ -337,6 +324,9 @@ def decode_mtp( need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0) self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu) + if self.enable_dynamic_mtp: + self.dynamic_mtp_planner.update(decode_reqs, mtp_accept_len_cpu) + select_mask = torch.tensor(accepted_index_cpu, dtype=torch.bool, device="cpu") self._post_handle( run_reqs=verify_ok_reqs, @@ -355,28 +345,6 @@ def decode_mtp( event_pack.notify_pre_post_handle() return - def _compute_dynamic_mtp_size_gpu_part( - self, - draft_probs_tensor: torch.Tensor, - ) -> torch.Tensor: - rand_vals = torch.rand_like(draft_probs_tensor) - accepted_mask = draft_probs_tensor > rand_vals - valid_steps = torch.cumprod(accepted_mask.to(torch.int32), dim=0) - dynamic_mtp_sizes = valid_steps.sum(dim=0) - return dynamic_mtp_sizes - - def _update_dynamic_mtp_size_cpu_part( - self, - run_reqs: List[InferReq], - dynamic_sizes_cpu: torch.Tensor, - accepted_index_cpu: torch.Tensor, - ): - assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] - for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): - if int(accepted) == 1: - req.current_mtp_step = int(new_size) - assert req.current_mtp_step <= req.mtp_step - def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 draft_model_input = model_input @@ -442,9 +410,6 @@ def _draft_decode_eagle( all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - # 用于收集每个 step 的 probs - draft_probs_list = [] if self.enable_dynamic_mtp else None - # process the draft model output for _step in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids @@ -453,12 +418,7 @@ def _draft_decode_eagle( draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - # 收集 probs(如果需要) - if self.enable_dynamic_mtp: - draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) - draft_probs_list.append(draft_probs) - else: - draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) + draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) draft_model_input.b_seq_len += 1 draft_model_input.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs] @@ -478,7 +438,4 @@ def _draft_decode_eagle( all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] - if self.enable_dynamic_mtp: - return all_next_token_ids, eagle_mem_indexes_cpu, draft_probs_list - else: - return all_next_token_ids, eagle_mem_indexes_cpu + return all_next_token_ids, eagle_mem_indexes_cpu diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py new file mode 100644 index 000000000..eb2a51162 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import dataclasses +import math +import threading +from typing import List, Sequence, TYPE_CHECKING + +if TYPE_CHECKING: + from lightllm.server.router.model_infer.infer_batch import InferReq + + +# Development-time knobs. Keep these local while the dynamic MTP planner is being +# tuned; move the stable subset to StartArgs once the policy settles. +EMA_ALPHA = 0.2 +BUDGET_SCALE = 1.0 +MIN_STEP = 1 +MAX_STEP = None + + +@dataclasses.dataclass +class MTPPlan: + planned_steps: List[int] + selected_mtp_indexes: List[List[int]] + budget: int + estimated_step: int + b_req_mtp_start_loc: List[int] + + +class DynamicMTPPlanner: + """ + Plans a uniform dynamic MTP verification length from historical acceptance. + + The plan is intentionally based on already available history so decode + preprocessing does not have to wait for the current draft pass to finish. + """ + + def __init__( + self, + max_mtp_step: int, + ema_alpha: float = EMA_ALPHA, + budget_scale: float = BUDGET_SCALE, + min_step: int = MIN_STEP, + max_step: int = None, + ) -> None: + assert max_mtp_step >= 0 + assert 0.0 < ema_alpha <= 1.0 + assert budget_scale > 0.0 + self.max_mtp_step = max_mtp_step + self.ema_alpha = ema_alpha + self.budget_scale = budget_scale + self.min_step = max(0, min(min_step, max_mtp_step)) + if max_step is None: + max_step = max_mtp_step if MAX_STEP is None else MAX_STEP + self.max_step = max(self.min_step, min(max_step, max_mtp_step)) + self._lock = threading.Lock() + self._ema_max_accept_step = float(self.max_step) + + def build_plan(self, reqs: Sequence[InferReq]) -> MTPPlan: + req_num = len(reqs) + if req_num == 0: + return MTPPlan( + planned_steps=[], + selected_mtp_indexes=[], + budget=0, + estimated_step=0, + b_req_mtp_start_loc=[], + ) + + with self._lock: + slot_limit = int(math.ceil(self._ema_max_accept_step * self.budget_scale)) + + slot_limit = min(max(slot_limit, self.min_step), self.max_step) + planned_steps = [slot_limit for _ in reqs] + + selected_mtp_indexes = [list(range(1, step + 1)) for step in planned_steps] + + start_locs = [] + cur_loc = 0 + for selected_indexes in selected_mtp_indexes: + start_locs.append(cur_loc) + cur_loc += 1 + len(selected_indexes) + + for req, step in zip(reqs, planned_steps): + req.current_mtp_step = step + + return MTPPlan( + planned_steps=planned_steps, + selected_mtp_indexes=selected_mtp_indexes, + budget=sum(planned_steps), + estimated_step=slot_limit, + b_req_mtp_start_loc=start_locs, + ) + + def update( + self, + reqs: Sequence[InferReq], + mtp_accept_len_cpu, + ) -> None: + if not reqs: + return + + accept_len_np = mtp_accept_len_cpu.numpy() + max_accept_step = 0 + for req_index in range(len(reqs)): + accept_len = int(accept_len_np[req_index]) + accept_step = max(0, accept_len - 1) + max_accept_step = max(max_accept_step, min(accept_step, self.max_step)) + + with self._lock: + self._ema_max_accept_step = ( + self.ema_alpha * max_accept_step + (1.0 - self.ema_alpha) * self._ema_max_accept_step + ) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 3d9d8815e..c4ae24faf 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -1,6 +1,6 @@ import torch import numpy as np -from typing import List, Tuple +from typing import List, Optional, Sequence, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.common.basemodel.batch_objs import ModelInput @@ -94,7 +94,17 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> return model_input, run_reqs -def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]: +def prepare_decode_inputs( + req_objs: List[InferReq], + mtp_decode_steps: Optional[Sequence[int]] = None, + mtp_decode_indexes: Optional[Sequence[Sequence[int]]] = None, +) -> Tuple[ModelInput, List[InferReq]]: + if mtp_decode_steps is not None: + assert len(mtp_decode_steps) == len(req_objs) + if mtp_decode_indexes is not None: + assert mtp_decode_steps is None + assert len(mtp_decode_indexes) == len(req_objs) + run_reqs: List[InferReq] = [] total_token_num = 0 b_req_idx = [] @@ -102,7 +112,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_seq_len = [] b_q_seq_len = [] multimodal_params = [] - for req in req_objs: + for req_index, req in enumerate(req_objs): run_reqs.append(req) b_req_idx.append(req.req_idx) seq_len = req.get_cur_total_len() @@ -113,15 +123,25 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) # process the draft tokens. - # 动态 MTP 模式:使用动态 current_mtp_step 构建 batch - # 非动态 MTP 模式:current_mtp_step 为固定的 mtp_step - for step in range(req.current_mtp_step): + # 动态 MTP planner 可以提前给出本轮要填充进验证槽位的 draft index。 + # 当前 planner 使用连续 prefix index;后续非连续选择可在该接口后接 compact kernel。 + if mtp_decode_indexes is not None: + decode_indexes = [int(index) for index in mtp_decode_indexes[req_index]] + assert decode_indexes == list(range(1, len(decode_indexes) + 1)), ( + "Current MTP verify path requires contiguous prefix draft indexes. " + "Non-prefix indexes need a compact/remap kernel before decode." + ) + else: + decode_step = req.current_mtp_step if mtp_decode_steps is None else int(mtp_decode_steps[req_index]) + decode_indexes = range(1, decode_step + 1) + + for mtp_index in decode_indexes: run_reqs.append(req) b_req_idx.append(req.req_idx) - seq_len += 1 - b_seq_len.append(seq_len) - total_token_num += seq_len - b_mtp_index.append(step + 1) + mtp_seq_len = seq_len + int(mtp_index) + b_seq_len.append(mtp_seq_len) + total_token_num += mtp_seq_len + b_mtp_index.append(int(mtp_index)) multimodal_params.append(req.multimodal_params) b_q_seq_len.append(1) From c6e0a7270663585aa5a9ce22eac5ae0ab4fb54fe Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Sat, 16 May 2026 13:29:44 +0800 Subject: [PATCH 04/20] Revert "first commit for ema" This reverts commit db0a2ca692c2525b8c5ca8b8a02ed635b2e1470e. --- .../mode_backend/chunked_prefill/impl.py | 91 ++++++++++---- .../mode_backend/dynamic_mtp_planner.py | 112 ------------------ .../mode_backend/generic_pre_process.py | 40 ++----- 3 files changed, 77 insertions(+), 166 deletions(-) delete mode 100644 lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 6b1578ae7..c41dbb6d9 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -25,7 +25,6 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify -from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner from .control_state import ControlState logger = init_logger(__name__) @@ -46,9 +45,6 @@ def __init__(self) -> None: self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla self.enable_dynamic_mtp = enable_dynamic_mtp_verify() - self.dynamic_mtp_planner = ( - DynamicMTPPlanner(max_mtp_step=get_env_start_args().mtp_step) if self.enable_dynamic_mtp else None - ) else: self.prefill = self.prefill_normal self.decode = self.decode_normal @@ -237,15 +233,7 @@ def decode_mtp( """ MTP解码的通用流程,整合eagle和vanilla的共同逻辑 """ - mtp_plan = None - if self.enable_dynamic_mtp: - mtp_plan = self.dynamic_mtp_planner.build_plan(decode_reqs) - model_input, run_reqs = prepare_decode_inputs( - decode_reqs, - mtp_decode_indexes=mtp_plan.selected_mtp_indexes, - ) - else: - model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index @@ -272,12 +260,32 @@ def decode_mtp( verify_event = torch.cuda.Event() verify_event.record() - all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( - main_model_input=model_input, - main_model_output=model_output, - next_token_ids=next_token_ids, - b_req_mtp_start_loc=b_req_mtp_start_loc, - ) + if self.enable_dynamic_mtp: + all_next_token_ids, additional_mem_indexes_cpu, draft_probs_list = self._draft_decode_func( + main_model_input=model_input, + main_model_output=model_output, + next_token_ids=next_token_ids, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) + else: + all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( + main_model_input=model_input, + main_model_output=model_output, + next_token_ids=next_token_ids, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) + + # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size + if self.enable_dynamic_mtp: + draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) + dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) + # 异步拷贝回 CPU Pin Memory + dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( + key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu + ) + + dynamic_mtp_event = torch.cuda.Event() + dynamic_mtp_event.record() mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, @@ -312,6 +320,11 @@ def decode_mtp( verify_event.synchronize() accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] + if self.enable_dynamic_mtp: + dynamic_mtp_event.synchronize() + self._update_dynamic_mtp_size_cpu_part( + run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu + ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 @@ -324,9 +337,6 @@ def decode_mtp( need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0) self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu) - if self.enable_dynamic_mtp: - self.dynamic_mtp_planner.update(decode_reqs, mtp_accept_len_cpu) - select_mask = torch.tensor(accepted_index_cpu, dtype=torch.bool, device="cpu") self._post_handle( run_reqs=verify_ok_reqs, @@ -345,6 +355,28 @@ def decode_mtp( event_pack.notify_pre_post_handle() return + def _compute_dynamic_mtp_size_gpu_part( + self, + draft_probs_tensor: torch.Tensor, + ) -> torch.Tensor: + rand_vals = torch.rand_like(draft_probs_tensor) + accepted_mask = draft_probs_tensor > rand_vals + valid_steps = torch.cumprod(accepted_mask.to(torch.int32), dim=0) + dynamic_mtp_sizes = valid_steps.sum(dim=0) + return dynamic_mtp_sizes + + def _update_dynamic_mtp_size_cpu_part( + self, + run_reqs: List[InferReq], + dynamic_sizes_cpu: torch.Tensor, + accepted_index_cpu: torch.Tensor, + ): + assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] + for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): + if int(accepted) == 1: + req.current_mtp_step = int(new_size) + assert req.current_mtp_step <= req.mtp_step + def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 draft_model_input = model_input @@ -410,6 +442,9 @@ def _draft_decode_eagle( all_next_token_ids = [] all_next_token_ids.append(next_token_ids) + # 用于收集每个 step 的 probs + draft_probs_list = [] if self.enable_dynamic_mtp else None + # process the draft model output for _step in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids @@ -418,7 +453,12 @@ def _draft_decode_eagle( draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) + # 收集 probs(如果需要) + if self.enable_dynamic_mtp: + draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) + draft_probs_list.append(draft_probs) + else: + draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) draft_model_input.b_seq_len += 1 draft_model_input.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs] @@ -438,4 +478,7 @@ def _draft_decode_eagle( all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] - return all_next_token_ids, eagle_mem_indexes_cpu + if self.enable_dynamic_mtp: + return all_next_token_ids, eagle_mem_indexes_cpu, draft_probs_list + else: + return all_next_token_ids, eagle_mem_indexes_cpu diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py deleted file mode 100644 index eb2a51162..000000000 --- a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py +++ /dev/null @@ -1,112 +0,0 @@ -from __future__ import annotations - -import dataclasses -import math -import threading -from typing import List, Sequence, TYPE_CHECKING - -if TYPE_CHECKING: - from lightllm.server.router.model_infer.infer_batch import InferReq - - -# Development-time knobs. Keep these local while the dynamic MTP planner is being -# tuned; move the stable subset to StartArgs once the policy settles. -EMA_ALPHA = 0.2 -BUDGET_SCALE = 1.0 -MIN_STEP = 1 -MAX_STEP = None - - -@dataclasses.dataclass -class MTPPlan: - planned_steps: List[int] - selected_mtp_indexes: List[List[int]] - budget: int - estimated_step: int - b_req_mtp_start_loc: List[int] - - -class DynamicMTPPlanner: - """ - Plans a uniform dynamic MTP verification length from historical acceptance. - - The plan is intentionally based on already available history so decode - preprocessing does not have to wait for the current draft pass to finish. - """ - - def __init__( - self, - max_mtp_step: int, - ema_alpha: float = EMA_ALPHA, - budget_scale: float = BUDGET_SCALE, - min_step: int = MIN_STEP, - max_step: int = None, - ) -> None: - assert max_mtp_step >= 0 - assert 0.0 < ema_alpha <= 1.0 - assert budget_scale > 0.0 - self.max_mtp_step = max_mtp_step - self.ema_alpha = ema_alpha - self.budget_scale = budget_scale - self.min_step = max(0, min(min_step, max_mtp_step)) - if max_step is None: - max_step = max_mtp_step if MAX_STEP is None else MAX_STEP - self.max_step = max(self.min_step, min(max_step, max_mtp_step)) - self._lock = threading.Lock() - self._ema_max_accept_step = float(self.max_step) - - def build_plan(self, reqs: Sequence[InferReq]) -> MTPPlan: - req_num = len(reqs) - if req_num == 0: - return MTPPlan( - planned_steps=[], - selected_mtp_indexes=[], - budget=0, - estimated_step=0, - b_req_mtp_start_loc=[], - ) - - with self._lock: - slot_limit = int(math.ceil(self._ema_max_accept_step * self.budget_scale)) - - slot_limit = min(max(slot_limit, self.min_step), self.max_step) - planned_steps = [slot_limit for _ in reqs] - - selected_mtp_indexes = [list(range(1, step + 1)) for step in planned_steps] - - start_locs = [] - cur_loc = 0 - for selected_indexes in selected_mtp_indexes: - start_locs.append(cur_loc) - cur_loc += 1 + len(selected_indexes) - - for req, step in zip(reqs, planned_steps): - req.current_mtp_step = step - - return MTPPlan( - planned_steps=planned_steps, - selected_mtp_indexes=selected_mtp_indexes, - budget=sum(planned_steps), - estimated_step=slot_limit, - b_req_mtp_start_loc=start_locs, - ) - - def update( - self, - reqs: Sequence[InferReq], - mtp_accept_len_cpu, - ) -> None: - if not reqs: - return - - accept_len_np = mtp_accept_len_cpu.numpy() - max_accept_step = 0 - for req_index in range(len(reqs)): - accept_len = int(accept_len_np[req_index]) - accept_step = max(0, accept_len - 1) - max_accept_step = max(max_accept_step, min(accept_step, self.max_step)) - - with self._lock: - self._ema_max_accept_step = ( - self.ema_alpha * max_accept_step + (1.0 - self.ema_alpha) * self._ema_max_accept_step - ) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index c4ae24faf..3d9d8815e 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -1,6 +1,6 @@ import torch import numpy as np -from typing import List, Optional, Sequence, Tuple +from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.common.basemodel.batch_objs import ModelInput @@ -94,17 +94,7 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> return model_input, run_reqs -def prepare_decode_inputs( - req_objs: List[InferReq], - mtp_decode_steps: Optional[Sequence[int]] = None, - mtp_decode_indexes: Optional[Sequence[Sequence[int]]] = None, -) -> Tuple[ModelInput, List[InferReq]]: - if mtp_decode_steps is not None: - assert len(mtp_decode_steps) == len(req_objs) - if mtp_decode_indexes is not None: - assert mtp_decode_steps is None - assert len(mtp_decode_indexes) == len(req_objs) - +def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]: run_reqs: List[InferReq] = [] total_token_num = 0 b_req_idx = [] @@ -112,7 +102,7 @@ def prepare_decode_inputs( b_seq_len = [] b_q_seq_len = [] multimodal_params = [] - for req_index, req in enumerate(req_objs): + for req in req_objs: run_reqs.append(req) b_req_idx.append(req.req_idx) seq_len = req.get_cur_total_len() @@ -123,25 +113,15 @@ def prepare_decode_inputs( b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) # process the draft tokens. - # 动态 MTP planner 可以提前给出本轮要填充进验证槽位的 draft index。 - # 当前 planner 使用连续 prefix index;后续非连续选择可在该接口后接 compact kernel。 - if mtp_decode_indexes is not None: - decode_indexes = [int(index) for index in mtp_decode_indexes[req_index]] - assert decode_indexes == list(range(1, len(decode_indexes) + 1)), ( - "Current MTP verify path requires contiguous prefix draft indexes. " - "Non-prefix indexes need a compact/remap kernel before decode." - ) - else: - decode_step = req.current_mtp_step if mtp_decode_steps is None else int(mtp_decode_steps[req_index]) - decode_indexes = range(1, decode_step + 1) - - for mtp_index in decode_indexes: + # 动态 MTP 模式:使用动态 current_mtp_step 构建 batch + # 非动态 MTP 模式:current_mtp_step 为固定的 mtp_step + for step in range(req.current_mtp_step): run_reqs.append(req) b_req_idx.append(req.req_idx) - mtp_seq_len = seq_len + int(mtp_index) - b_seq_len.append(mtp_seq_len) - total_token_num += mtp_seq_len - b_mtp_index.append(int(mtp_index)) + seq_len += 1 + b_seq_len.append(seq_len) + total_token_num += seq_len + b_mtp_index.append(step + 1) multimodal_params.append(req.multimodal_params) b_q_seq_len.append(1) From 5317c5ae740023ebab69480e1069f657aeb45fa8 Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Sun, 17 May 2026 21:04:09 +0800 Subject: [PATCH 05/20] fix dynamic_mtp_planner --- .../model_infer/mode_backend/base_backend.py | 19 +- .../mode_backend/chunked_prefill/impl.py | 101 ++++--- .../mode_backend/dynamic_mtp_planner.py | 275 ++++++++++++++++++ 3 files changed, 340 insertions(+), 55 deletions(-) create mode 100644 lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 321055b4b..bde7fad52 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -775,14 +775,19 @@ def _update_mtp_accept_ratio( return - def _update_mtp_verify_token_num(self, decode_reqs: List[InferReq]): + def _update_mtp_verify_token_num( + self, + decode_reqs: List[InferReq], + verify_token_nums: Optional[List[int]] = None, + ): if self.is_master_in_dp: - for req in decode_reqs: - # 统计发送给主模型验证的 token 数量:1 个主 token + 当前 mtp_size 个 draft token - # 在静态 MTP 模式下,使用固定的 mtp_step;在动态 MTP 模式下,使用动态调整的 current_mtp_step - # current_mtp_step 在静态 MTP 模式下为 mtp_step,在动态 MTP 模式下会在推理过程中动态设置。 - assert req.current_mtp_step >= 0 - req.update_mtp_verify_token_num(verify_token_num=1 + req.current_mtp_step) + if verify_token_nums is None: + verify_token_nums = [1 + req.current_mtp_step for req in decode_reqs] + assert len(decode_reqs) == len(verify_token_nums) + for req, verify_token_num in zip(decode_reqs, verify_token_nums): + # 统计发送给主模型验证的 token 数量,动态 MTP 模式由 planner 传入实际裁剪后的行数。 + assert verify_token_num >= 1 + req.update_mtp_verify_token_num(verify_token_num=verify_token_num) return def _gen_argmax_token_ids(self, model_output: ModelOutput): diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index c41dbb6d9..71b22c08e 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -25,6 +25,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify +from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner from .control_state import ControlState logger = init_logger(__name__) @@ -45,6 +46,9 @@ def __init__(self) -> None: self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla self.enable_dynamic_mtp = enable_dynamic_mtp_verify() + self.dynamic_mtp_planner = ( + DynamicMTPPlanner(mtp_step=get_env_start_args().mtp_step) if self.enable_dynamic_mtp else None + ) else: self.prefill = self.prefill_normal self.decode = self.decode_normal @@ -233,9 +237,25 @@ def decode_mtp( """ MTP解码的通用流程,整合eagle和vanilla的共同逻辑 """ + if self.enable_dynamic_mtp: + # 让通用 pre-process 始终构建最大候选池,动态策略只在 forward 前裁剪。 + for req in decode_reqs: + req.current_mtp_step = req.mtp_step model_input, run_reqs = prepare_decode_inputs(decode_reqs) + dynamic_mtp_plan = None with torch.cuda.stream(g_infer_context.get_overlap_stream()): + if self.enable_dynamic_mtp: + model_input, run_reqs, dynamic_mtp_plan = self.dynamic_mtp_planner.trim_before_forward( + model_input=model_input, + run_reqs=run_reqs, + decode_reqs=decode_reqs, + ) + dynamic_mtp_start_event = None + if self.enable_dynamic_mtp: + dynamic_mtp_start_event = torch.cuda.Event(enable_timing=True) + dynamic_mtp_start_event.record() + b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) @@ -257,9 +277,10 @@ def decode_mtp( gpu_tensor=accepted_index, ) - verify_event = torch.cuda.Event() + verify_event = torch.cuda.Event(enable_timing=self.enable_dynamic_mtp) verify_event.record() + per_req_probs_cpu = None if self.enable_dynamic_mtp: all_next_token_ids, additional_mem_indexes_cpu, draft_probs_list = self._draft_decode_func( main_model_input=model_input, @@ -267,6 +288,13 @@ def decode_mtp( next_token_ids=next_token_ids, b_req_mtp_start_loc=b_req_mtp_start_loc, ) + draft_probs_tensor = torch.stack(draft_probs_list, dim=1) + request_start_rows = b_req_mtp_start_loc.to(torch.long) + per_req_probs = draft_probs_tensor[request_start_rows].mean(dim=1) + per_req_probs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( + key="dynamic_mtp_req_probs", + gpu_tensor=per_req_probs, + ) else: all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( main_model_input=model_input, @@ -275,18 +303,6 @@ def decode_mtp( b_req_mtp_start_loc=b_req_mtp_start_loc, ) - # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size - if self.enable_dynamic_mtp: - draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) - dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) - # 异步拷贝回 CPU Pin Memory - dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( - key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu - ) - - dynamic_mtp_event = torch.cuda.Event() - dynamic_mtp_event.record() - mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, b_req_mtp_start_loc=b_req_mtp_start_loc, @@ -310,26 +326,34 @@ def decode_mtp( gpu_tensor=mtp_accept_len, ) - sync_event = torch.cuda.Event() + sync_event = torch.cuda.Event(enable_timing=self.enable_dynamic_mtp) sync_event.record() # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() - self._update_mtp_verify_token_num(decode_reqs=decode_reqs) + self._update_mtp_verify_token_num( + decode_reqs=decode_reqs, + verify_token_nums=dynamic_mtp_plan.per_req_rows if dynamic_mtp_plan is not None else None, + ) verify_event.synchronize() + dynamic_mtp_elapsed_ms = None accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] - if self.enable_dynamic_mtp: - dynamic_mtp_event.synchronize() - self._update_dynamic_mtp_size_cpu_part( - run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu - ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() + if self.enable_dynamic_mtp: + dynamic_mtp_elapsed_ms = dynamic_mtp_start_event.elapsed_time(sync_event) + self.dynamic_mtp_planner.update_after_verify( + plan=dynamic_mtp_plan, + decode_reqs=decode_reqs, + mtp_accept_len_cpu=mtp_accept_len_cpu, + elapsed_ms=dynamic_mtp_elapsed_ms, + per_req_probs_cpu=per_req_probs_cpu, + ) # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] @@ -355,28 +379,6 @@ def decode_mtp( event_pack.notify_pre_post_handle() return - def _compute_dynamic_mtp_size_gpu_part( - self, - draft_probs_tensor: torch.Tensor, - ) -> torch.Tensor: - rand_vals = torch.rand_like(draft_probs_tensor) - accepted_mask = draft_probs_tensor > rand_vals - valid_steps = torch.cumprod(accepted_mask.to(torch.int32), dim=0) - dynamic_mtp_sizes = valid_steps.sum(dim=0) - return dynamic_mtp_sizes - - def _update_dynamic_mtp_size_cpu_part( - self, - run_reqs: List[InferReq], - dynamic_sizes_cpu: torch.Tensor, - accepted_index_cpu: torch.Tensor, - ): - assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] - for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): - if int(accepted) == 1: - req.current_mtp_step = int(new_size) - assert req.current_mtp_step <= req.mtp_step - def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 draft_model_input = model_input @@ -405,17 +407,24 @@ def _draft_decode_vanilla( draft_next_token_ids = next_token_ids all_next_token_ids = [] all_next_token_ids.append(next_token_ids) + draft_probs_list = [] if self.enable_dynamic_mtp else None # process the draft model output for draft_model_idx in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) + if self.enable_dynamic_mtp: + draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) + draft_probs_list.append(draft_probs) + else: + draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) all_next_token_ids.append(draft_next_token_ids) all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] + if self.enable_dynamic_mtp: + return all_next_token_ids, None, draft_probs_list return all_next_token_ids, None def _draft_decode_eagle( @@ -441,8 +450,6 @@ def _draft_decode_eagle( draft_next_token_ids = next_token_ids all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - - # 用于收集每个 step 的 probs draft_probs_list = [] if self.enable_dynamic_mtp else None # process the draft model output @@ -453,7 +460,6 @@ def _draft_decode_eagle( draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - # 收集 probs(如果需要) if self.enable_dynamic_mtp: draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) draft_probs_list.append(draft_probs) @@ -480,5 +486,4 @@ def _draft_decode_eagle( if self.enable_dynamic_mtp: return all_next_token_ids, eagle_mem_indexes_cpu, draft_probs_list - else: - return all_next_token_ids, eagle_mem_indexes_cpu + return all_next_token_ids, eagle_mem_indexes_cpu diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py new file mode 100644 index 000000000..b09a711bd --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py @@ -0,0 +1,275 @@ +import copy +import math +import random +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch + +from lightllm.common.basemodel.batch_objs import ModelInput +from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context +from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size + + +@dataclass +class DynamicMTPPlan: + req_num: int + original_batch_size: int + dynamic_batch_size: int + keep_indices: torch.Tensor + per_req_rows: List[int] + estimated_accept_mean: float + estimated_accept_std: float + + +class _EMAValue: + def __init__(self, decay: float, init_value: Optional[float] = None) -> None: + self.decay = decay + self.value = init_value + self.initialized = init_value is not None + + def update(self, new_value: float) -> float: + if not self.initialized: + self.value = new_value + self.initialized = True + else: + self.value = self.decay * self.value + (1.0 - self.decay) * new_value + return self.value + + def get(self, fallback: float) -> float: + return self.value if self.initialized else fallback + + +class DynamicMTPPlanner: + def __init__( + self, + mtp_step: int, + ema_decay: float = 0.9, + confidence_k: float = 1.0, + ) -> None: + self.mtp_step = mtp_step + self.max_rows_per_req = mtp_step + 1 + self.confidence_k = confidence_k + self.accept_mean = _EMAValue(ema_decay, init_value=float(self.max_rows_per_req)) + self.accept_second_moment = _EMAValue(ema_decay, init_value=float(self.max_rows_per_req**2)) + self.req_accept_mean: Dict[int, _EMAValue] = {} + self.req_prob: Dict[int, _EMAValue] = {} + self.latency_ms_by_batch_size: Dict[int, _EMAValue] = {} + self.accepted_token_speed = _EMAValue(ema_decay) + self.verify_row_speed = _EMAValue(ema_decay) + self.actual_speedup = _EMAValue(ema_decay) + self.single_token_speed_by_req_num: Dict[int, _EMAValue] = {} + self.last_plan: Optional[DynamicMTPPlan] = None + + def trim_before_forward( + self, + model_input: ModelInput, + run_reqs: List[InferReq], + decode_reqs: List[InferReq], + ): + plan = self._build_plan(model_input=model_input, decode_reqs=decode_reqs) + self.last_plan = plan + if plan.dynamic_batch_size == plan.original_batch_size: + return model_input, run_reqs, plan + + pruned_indices = self._invert_indices(plan.keep_indices, plan.original_batch_size) + if pruned_indices.numel() > 0: + pruned_mem_indexes = model_input.mem_indexes_cpu[pruned_indices] + g_infer_state_lock.acquire() + g_infer_context.req_manager.mem_manager.free(pruned_mem_indexes) + g_infer_state_lock.release() + + trimmed_input = copy.copy(model_input) + keep_indices = plan.keep_indices + keep_list = keep_indices.tolist() + + trimmed_input.batch_size = plan.dynamic_batch_size + trimmed_input.b_req_idx = model_input.b_req_idx[keep_indices].contiguous() + trimmed_input.b_mtp_index = model_input.b_mtp_index[keep_indices].contiguous() + trimmed_input.b_seq_len = model_input.b_seq_len[keep_indices].contiguous() + trimmed_input.mem_indexes_cpu = model_input.mem_indexes_cpu[keep_indices].contiguous() + trimmed_input.mem_indexes = None + trimmed_input.total_token_num = int(trimmed_input.b_seq_len.sum().item()) + trimmed_input.max_kv_seq_len = int(trimmed_input.b_seq_len.max().item()) + trimmed_input.multimodal_params = [model_input.multimodal_params[index] for index in keep_list] + trimmed_run_reqs = [run_reqs[index] for index in keep_list] + trimmed_input.b_mark_shared_group = self._build_mtp_shared_group_infos(trimmed_run_reqs) + trimmed_input.check_input() + return trimmed_input, trimmed_run_reqs, plan + + def update_after_verify( + self, + plan: DynamicMTPPlan, + decode_reqs: List[InferReq], + mtp_accept_len_cpu: torch.Tensor, + elapsed_ms: float, + per_req_probs_cpu: Optional[torch.Tensor] = None, + ) -> None: + if plan is None: + return + + accept_lens = [int(value) for value in mtp_accept_len_cpu.numpy()] + if not accept_lens: + return + + batch_mean = sum(accept_lens) / len(accept_lens) + batch_second = sum(value * value for value in accept_lens) / len(accept_lens) + self.accept_mean.update(batch_mean) + self.accept_second_moment.update(batch_second) + + for req, accept_len in zip(decode_reqs, accept_lens): + req_ema = self.req_accept_mean.get(req.req_id) + if req_ema is None: + req_ema = _EMAValue(self.accept_mean.decay, init_value=batch_mean) + self.req_accept_mean[req.req_id] = req_ema + req_ema.update(float(accept_len)) + + if per_req_probs_cpu is not None: + for req, req_prob in zip(decode_reqs, per_req_probs_cpu.numpy()): + req_prob_ema = self.req_prob.get(req.req_id) + if req_prob_ema is None: + req_prob_ema = _EMAValue(self.accept_mean.decay) + self.req_prob[req.req_id] = req_prob_ema + req_prob_ema.update(self._clip_prob(float(req_prob))) + + if elapsed_ms > 0: + current_output_speed = sum(accept_lens) / elapsed_ms + latency_ema = self.latency_ms_by_batch_size.get(plan.dynamic_batch_size) + if latency_ema is None: + latency_ema = _EMAValue(self.accept_mean.decay) + self.latency_ms_by_batch_size[plan.dynamic_batch_size] = latency_ema + latency_ema.update(elapsed_ms) + self.accepted_token_speed.update(current_output_speed) + self.verify_row_speed.update(plan.dynamic_batch_size / elapsed_ms) + if plan.dynamic_batch_size == plan.req_num: + single_token_speed = self.single_token_speed_by_req_num.get(plan.req_num) + if single_token_speed is None: + single_token_speed = _EMAValue(self.accept_mean.decay) + self.single_token_speed_by_req_num[plan.req_num] = single_token_speed + single_token_speed.update(plan.req_num / elapsed_ms) + baseline_speed = self.single_token_speed_by_req_num.get(plan.req_num) + if baseline_speed is not None and baseline_speed.get(0.0) > 0: + self.actual_speedup.update(current_output_speed / baseline_speed.get(0.0)) + + def get_stats_snapshot(self) -> Dict[str, float]: + return { + "accept_mean": self.accept_mean.get(float(self.max_rows_per_req)), + "accept_second_moment": self.accept_second_moment.get(float(self.max_rows_per_req**2)), + "accepted_token_speed": self.accepted_token_speed.get(0.0), + "verify_row_speed": self.verify_row_speed.get(0.0), + "actual_speedup": self.actual_speedup.get(0.0), + } + + def _build_plan(self, model_input: ModelInput, decode_reqs: List[InferReq]) -> DynamicMTPPlan: + req_num = len(decode_reqs) + original_batch_size = model_input.batch_size + if req_num == 0 or self.mtp_step == 0: + keep_indices = torch.arange(original_batch_size, dtype=torch.long, device="cpu") + return DynamicMTPPlan(req_num, original_batch_size, original_batch_size, keep_indices, [], 1.0, 0.0) + + mean = self.accept_mean.get(float(self.max_rows_per_req)) + second = self.accept_second_moment.get(float(self.max_rows_per_req**2)) + variance = max(0.0, second - mean * mean) + std = math.sqrt(variance) + budget = math.ceil(req_num * mean + self.confidence_k * math.sqrt(req_num) * std) + dynamic_batch_size = max(req_num, min(original_batch_size, budget)) + + per_req_rows = self._allocate_rows(decode_reqs=decode_reqs, dynamic_batch_size=dynamic_batch_size) + keep_indices, per_req_rows = self._build_keep_indices(model_input=model_input, per_req_rows=per_req_rows) + dynamic_batch_size = int(keep_indices.numel()) + + return DynamicMTPPlan( + req_num=req_num, + original_batch_size=original_batch_size, + dynamic_batch_size=dynamic_batch_size, + keep_indices=keep_indices, + per_req_rows=per_req_rows, + estimated_accept_mean=mean, + estimated_accept_std=std, + ) + + def _allocate_rows(self, decode_reqs: List[InferReq], dynamic_batch_size: int) -> List[int]: + req_num = len(decode_reqs) + per_req_rows = [1 for _ in range(req_num)] + remaining = dynamic_batch_size - req_num + if remaining <= 0: + return per_req_rows + + req_order = sorted( + range(req_num), + key=lambda index: self._req_prob(decode_reqs[index]), + reverse=True, + ) + + for req_index in req_order: + req_prob = self._req_prob(decode_reqs[req_index]) + for _ in range(self.mtp_step): + if remaining <= 0: + break + if random.random() >= req_prob: + break + per_req_rows[req_index] += 1 + remaining -= 1 + if remaining <= 0: + break + return per_req_rows + + def _req_prob(self, req: InferReq) -> float: + req_prob = self.req_prob.get(req.req_id) + if req_prob is not None: + return self._clip_prob(req_prob.get(1.0)) + fallback = self.accept_mean.get(float(self.max_rows_per_req)) / float(self.max_rows_per_req) + return self._clip_prob(fallback) + + def _clip_prob(self, value: float) -> float: + return min(1.0, max(0.0, value)) + + def _build_keep_indices(self, model_input: ModelInput, per_req_rows: List[int]): + keep_indices = [] + effective_per_req_rows = [0 for _ in per_req_rows] + req_index = -1 + cur_req_kept = 0 + cur_req_target = 0 + for index, mtp_index in enumerate(model_input.b_mtp_index.tolist()): + if mtp_index == 0: + req_index += 1 + cur_req_kept = 0 + cur_req_target = per_req_rows[req_index] + if cur_req_kept < cur_req_target: + keep_indices.append(index) + cur_req_kept += 1 + effective_per_req_rows[req_index] += 1 + return torch.tensor(keep_indices, dtype=torch.long, device="cpu"), effective_per_req_rows + + def _invert_indices(self, keep_indices: torch.Tensor, total_size: int) -> torch.Tensor: + keep_mask = torch.zeros((total_size,), dtype=torch.bool, device="cpu") + keep_mask[keep_indices] = True + return torch.nonzero(~keep_mask, as_tuple=False).view(-1) + + def _build_mtp_shared_group_infos(self, run_reqs: List[InferReq]) -> torch.Tensor: + max_batch_shared_group_size = get_diverse_max_batch_shared_group_size() + req_ids = [req.req_id for req in run_reqs] + b_mark_shared_group = [] + current_group = [] + for req_id in req_ids: + if not current_group: + current_group.append(req_id) + elif req_id == current_group[-1]: + current_group.append(req_id) + else: + b_mark_shared_group.extend([0 for _ in range(len(current_group))]) + b_mark_shared_group[-1] = len(current_group) + current_group.clear() + current_group.append(req_id) + + if len(current_group) == max_batch_shared_group_size: + b_mark_shared_group.extend([0 for _ in range(len(current_group))]) + b_mark_shared_group[-1] = len(current_group) + current_group.clear() + + if current_group: + b_mark_shared_group.extend([0 for _ in range(len(current_group))]) + b_mark_shared_group[-1] = len(current_group) + + return torch.tensor(b_mark_shared_group, dtype=torch.int32, device="cpu") From 7e67469c4fb497f504863f2a073dd83ad8da2dc6 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 07:08:02 +0000 Subject: [PATCH 06/20] reback --- .../mode_backend/chunked_prefill/impl.py | 101 +++++++++--------- 1 file changed, 48 insertions(+), 53 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 71b22c08e..c41dbb6d9 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -25,7 +25,6 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify -from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner from .control_state import ControlState logger = init_logger(__name__) @@ -46,9 +45,6 @@ def __init__(self) -> None: self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla self.enable_dynamic_mtp = enable_dynamic_mtp_verify() - self.dynamic_mtp_planner = ( - DynamicMTPPlanner(mtp_step=get_env_start_args().mtp_step) if self.enable_dynamic_mtp else None - ) else: self.prefill = self.prefill_normal self.decode = self.decode_normal @@ -237,25 +233,9 @@ def decode_mtp( """ MTP解码的通用流程,整合eagle和vanilla的共同逻辑 """ - if self.enable_dynamic_mtp: - # 让通用 pre-process 始终构建最大候选池,动态策略只在 forward 前裁剪。 - for req in decode_reqs: - req.current_mtp_step = req.mtp_step model_input, run_reqs = prepare_decode_inputs(decode_reqs) - dynamic_mtp_plan = None with torch.cuda.stream(g_infer_context.get_overlap_stream()): - if self.enable_dynamic_mtp: - model_input, run_reqs, dynamic_mtp_plan = self.dynamic_mtp_planner.trim_before_forward( - model_input=model_input, - run_reqs=run_reqs, - decode_reqs=decode_reqs, - ) - dynamic_mtp_start_event = None - if self.enable_dynamic_mtp: - dynamic_mtp_start_event = torch.cuda.Event(enable_timing=True) - dynamic_mtp_start_event.record() - b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) @@ -277,10 +257,9 @@ def decode_mtp( gpu_tensor=accepted_index, ) - verify_event = torch.cuda.Event(enable_timing=self.enable_dynamic_mtp) + verify_event = torch.cuda.Event() verify_event.record() - per_req_probs_cpu = None if self.enable_dynamic_mtp: all_next_token_ids, additional_mem_indexes_cpu, draft_probs_list = self._draft_decode_func( main_model_input=model_input, @@ -288,13 +267,6 @@ def decode_mtp( next_token_ids=next_token_ids, b_req_mtp_start_loc=b_req_mtp_start_loc, ) - draft_probs_tensor = torch.stack(draft_probs_list, dim=1) - request_start_rows = b_req_mtp_start_loc.to(torch.long) - per_req_probs = draft_probs_tensor[request_start_rows].mean(dim=1) - per_req_probs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( - key="dynamic_mtp_req_probs", - gpu_tensor=per_req_probs, - ) else: all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( main_model_input=model_input, @@ -303,6 +275,18 @@ def decode_mtp( b_req_mtp_start_loc=b_req_mtp_start_loc, ) + # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size + if self.enable_dynamic_mtp: + draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) + dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) + # 异步拷贝回 CPU Pin Memory + dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( + key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu + ) + + dynamic_mtp_event = torch.cuda.Event() + dynamic_mtp_event.record() + mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, b_req_mtp_start_loc=b_req_mtp_start_loc, @@ -326,34 +310,26 @@ def decode_mtp( gpu_tensor=mtp_accept_len, ) - sync_event = torch.cuda.Event(enable_timing=self.enable_dynamic_mtp) + sync_event = torch.cuda.Event() sync_event.record() # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() - self._update_mtp_verify_token_num( - decode_reqs=decode_reqs, - verify_token_nums=dynamic_mtp_plan.per_req_rows if dynamic_mtp_plan is not None else None, - ) + self._update_mtp_verify_token_num(decode_reqs=decode_reqs) verify_event.synchronize() - dynamic_mtp_elapsed_ms = None accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] + if self.enable_dynamic_mtp: + dynamic_mtp_event.synchronize() + self._update_dynamic_mtp_size_cpu_part( + run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu + ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() - if self.enable_dynamic_mtp: - dynamic_mtp_elapsed_ms = dynamic_mtp_start_event.elapsed_time(sync_event) - self.dynamic_mtp_planner.update_after_verify( - plan=dynamic_mtp_plan, - decode_reqs=decode_reqs, - mtp_accept_len_cpu=mtp_accept_len_cpu, - elapsed_ms=dynamic_mtp_elapsed_ms, - per_req_probs_cpu=per_req_probs_cpu, - ) # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] @@ -379,6 +355,28 @@ def decode_mtp( event_pack.notify_pre_post_handle() return + def _compute_dynamic_mtp_size_gpu_part( + self, + draft_probs_tensor: torch.Tensor, + ) -> torch.Tensor: + rand_vals = torch.rand_like(draft_probs_tensor) + accepted_mask = draft_probs_tensor > rand_vals + valid_steps = torch.cumprod(accepted_mask.to(torch.int32), dim=0) + dynamic_mtp_sizes = valid_steps.sum(dim=0) + return dynamic_mtp_sizes + + def _update_dynamic_mtp_size_cpu_part( + self, + run_reqs: List[InferReq], + dynamic_sizes_cpu: torch.Tensor, + accepted_index_cpu: torch.Tensor, + ): + assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] + for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): + if int(accepted) == 1: + req.current_mtp_step = int(new_size) + assert req.current_mtp_step <= req.mtp_step + def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 draft_model_input = model_input @@ -407,24 +405,17 @@ def _draft_decode_vanilla( draft_next_token_ids = next_token_ids all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - draft_probs_list = [] if self.enable_dynamic_mtp else None # process the draft model output for draft_model_idx in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - if self.enable_dynamic_mtp: - draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) - draft_probs_list.append(draft_probs) - else: - draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) + draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) all_next_token_ids.append(draft_next_token_ids) all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] - if self.enable_dynamic_mtp: - return all_next_token_ids, None, draft_probs_list return all_next_token_ids, None def _draft_decode_eagle( @@ -450,6 +441,8 @@ def _draft_decode_eagle( draft_next_token_ids = next_token_ids all_next_token_ids = [] all_next_token_ids.append(next_token_ids) + + # 用于收集每个 step 的 probs draft_probs_list = [] if self.enable_dynamic_mtp else None # process the draft model output @@ -460,6 +453,7 @@ def _draft_decode_eagle( draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) + # 收集 probs(如果需要) if self.enable_dynamic_mtp: draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) draft_probs_list.append(draft_probs) @@ -486,4 +480,5 @@ def _draft_decode_eagle( if self.enable_dynamic_mtp: return all_next_token_ids, eagle_mem_indexes_cpu, draft_probs_list - return all_next_token_ids, eagle_mem_indexes_cpu + else: + return all_next_token_ids, eagle_mem_indexes_cpu From b39600c3a07cebc76e0e545870306e89dc28c998 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 07:11:25 +0000 Subject: [PATCH 07/20] add static mtp_step --- .../router/model_infer/mode_backend/generic_pre_process.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 3d9d8815e..c50e0b7c3 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -113,9 +113,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) # process the draft tokens. - # 动态 MTP 模式:使用动态 current_mtp_step 构建 batch - # 非动态 MTP 模式:current_mtp_step 为固定的 mtp_step - for step in range(req.current_mtp_step): + for step in range(req.mtp_step): run_reqs.append(req) b_req_idx.append(req.req_idx) seq_len += 1 From 783be5a7f7b73b40b4c1c57425bb24b7f9ec844a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 07:41:22 +0000 Subject: [PATCH 08/20] fix --- .../model_infer/mode_backend/base_backend.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index bde7fad52..321055b4b 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -775,19 +775,14 @@ def _update_mtp_accept_ratio( return - def _update_mtp_verify_token_num( - self, - decode_reqs: List[InferReq], - verify_token_nums: Optional[List[int]] = None, - ): + def _update_mtp_verify_token_num(self, decode_reqs: List[InferReq]): if self.is_master_in_dp: - if verify_token_nums is None: - verify_token_nums = [1 + req.current_mtp_step for req in decode_reqs] - assert len(decode_reqs) == len(verify_token_nums) - for req, verify_token_num in zip(decode_reqs, verify_token_nums): - # 统计发送给主模型验证的 token 数量,动态 MTP 模式由 planner 传入实际裁剪后的行数。 - assert verify_token_num >= 1 - req.update_mtp_verify_token_num(verify_token_num=verify_token_num) + for req in decode_reqs: + # 统计发送给主模型验证的 token 数量:1 个主 token + 当前 mtp_size 个 draft token + # 在静态 MTP 模式下,使用固定的 mtp_step;在动态 MTP 模式下,使用动态调整的 current_mtp_step + # current_mtp_step 在静态 MTP 模式下为 mtp_step,在动态 MTP 模式下会在推理过程中动态设置。 + assert req.current_mtp_step >= 0 + req.update_mtp_verify_token_num(verify_token_num=1 + req.current_mtp_step) return def _gen_argmax_token_ids(self, model_output: ModelOutput): From 0054d08fb9904c24d6467951014ecae6b86ac52f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 08:18:06 +0000 Subject: [PATCH 09/20] fix --- .../basemodel/triton_kernel/mtp_utils.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/lightllm/common/basemodel/triton_kernel/mtp_utils.py b/lightllm/common/basemodel/triton_kernel/mtp_utils.py index 2d70a68c0..843a77d96 100644 --- a/lightllm/common/basemodel/triton_kernel/mtp_utils.py +++ b/lightllm/common/basemodel/triton_kernel/mtp_utils.py @@ -1,3 +1,4 @@ +from typing import Optional import triton import triton.language as tl import torch @@ -93,10 +94,15 @@ def _fwd_kernel_mtp_scatter_next_token_ids( req_to_next_token_ids_stride, all_next_token_ids, all_next_token_ids_stride, + req_to_next_token_probs, + req_to_next_token_probs_stride, + all_next_token_probs, + all_next_token_probs_stride, mtp_accept_len, b_req_mtp_start_loc, b_req_idx, mtp_step, + HAS_HAS_NEXT_TOKEN_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): @@ -106,6 +112,17 @@ def _fwd_kernel_mtp_scatter_next_token_ids( cur_req_idx = tl.load(b_req_idx + req_start_loc) offset = tl.arange(0, BLOCK_SIZE) + if HAS_HAS_NEXT_TOKEN_PROBS: + cur_next_token_probs = tl.load( + all_next_token_probs + (req_start_loc + accept_len - 1) * all_next_token_probs_stride + offset, + mask=offset < mtp_step, + other=0.0, + ) + tl.store( + req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride + offset, + cur_next_token_probs, + mask=offset < mtp_step, + ) scatter_next_token_ids = tl.load( all_next_token_ids + (req_start_loc + accept_len - 1) * all_next_token_ids_stride + offset, mask=offset < mtp_step, @@ -125,12 +142,20 @@ def mtp_scatter_next_token_ids( all_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, mtp_accept_len: torch.Tensor, + req_to_next_token_probs: Optional[torch.Tensor] = None, + all_next_token_probs: Optional[torch.Tensor] = None, ): max_mtp_step = req_to_next_token_ids.shape[1] BLOCK_SIZE = 16 assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must be less than {BLOCK_SIZE}" num_reqs = b_req_mtp_start_loc.shape[0] mtp_step = all_next_token_ids.shape[1] + if req_to_next_token_probs is not None: + assert all_next_token_probs is not None + assert all_next_token_probs.shape == all_next_token_ids.shape + + HAS_HAS_NEXT_TOKEN_PROBS = req_to_next_token_probs is not None + grid = (num_reqs,) num_warps = 1 _fwd_kernel_mtp_scatter_next_token_ids[grid]( @@ -138,10 +163,15 @@ def mtp_scatter_next_token_ids( req_to_next_token_ids_stride=req_to_next_token_ids.stride(0), all_next_token_ids=all_next_token_ids, all_next_token_ids_stride=all_next_token_ids.stride(0), + req_to_next_token_probs=req_to_next_token_probs, + req_to_next_token_probs_stride=req_to_next_token_probs.stride(0), + all_next_token_probs=all_next_token_probs, + all_next_token_probs_stride=all_next_token_probs.stride(0), mtp_accept_len=mtp_accept_len, b_req_mtp_start_loc=b_req_mtp_start_loc, b_req_idx=b_req_idx, mtp_step=mtp_step, + HAS_HAS_NEXT_TOKEN_PROBS=HAS_HAS_NEXT_TOKEN_PROBS, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_stages=1, From ccc0f78779e424221122fa46f812db7b0ae00824 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 08:23:12 +0000 Subject: [PATCH 10/20] fix req_manager --- lightllm/common/req_manager.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 3a4e2b631..278b3509c 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -6,7 +6,7 @@ from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, enable_dynamic_mtp_verify from lightllm.utils.config_utils import get_vocab_size from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager @@ -116,6 +116,15 @@ def __init__(self, max_request_num): dtype=torch.int64, device="cuda", ) + if enable_dynamic_mtp_verify(): + self.req_to_next_token_probs = torch.zeros( + (max_request_num + 1, 16), + dtype=torch.float32, + device="cuda", + ) + else: + self.req_to_next_token_probs = None + self.req_to_exponential_decay_length_penalty = torch.zeros( max_request_num + 1, dtype=torch.float32, device="cuda" ) @@ -137,6 +146,9 @@ def init_req_sampling_params(self, req): shm_param = req.sampling_param.shm_param self.req_to_next_token_ids[req.req_idx][0:1].fill_(req.get_last_gen_token()) + if enable_dynamic_mtp_verify(): + self.req_to_next_token_probs[req.req_idx].fill_(0.0) + self.req_to_next_token_probs[req.req_idx][0:1].fill_(1.0) self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty) self.req_to_frequency_penalty[req.req_idx].fill_(shm_param.frequency_penalty) self.req_to_repetition_penalty[req.req_idx].fill_(shm_param.repetition_penalty) From 326515ce99139ea812ca7156f258b23ee666f057 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 08:55:42 +0000 Subject: [PATCH 11/20] fix --- .../model_infer/mode_backend/base_backend.py | 19 ++++-- .../mode_backend/chunked_prefill/impl.py | 64 +++++++++++++------ 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 321055b4b..4203b093f 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -4,6 +4,7 @@ import time import threading import torch.distributed as dist +import collections from typing import List, Tuple, Callable, Optional from transformers.configuration_utils import PretrainedConfig from lightllm.utils.infer_utils import set_random_seed @@ -775,14 +776,18 @@ def _update_mtp_accept_ratio( return - def _update_mtp_verify_token_num(self, decode_reqs: List[InferReq]): + def _update_mtp_verify_token_num( + self, decode_reqs: List[InferReq], dynamic_mtp_run_reqs: Optional[List[InferReq]] = None + ): if self.is_master_in_dp: - for req in decode_reqs: - # 统计发送给主模型验证的 token 数量:1 个主 token + 当前 mtp_size 个 draft token - # 在静态 MTP 模式下,使用固定的 mtp_step;在动态 MTP 模式下,使用动态调整的 current_mtp_step - # current_mtp_step 在静态 MTP 模式下为 mtp_step,在动态 MTP 模式下会在推理过程中动态设置。 - assert req.current_mtp_step >= 0 - req.update_mtp_verify_token_num(verify_token_num=1 + req.current_mtp_step) + if dynamic_mtp_run_reqs is None: + for req in decode_reqs: + assert req.mtp_step > 0 + req.update_mtp_verify_token_num(verify_token_num=1 + req.mtp_step) + else: + counter = collections.Counter([req.req_idx for req in dynamic_mtp_run_reqs]) + for req in decode_reqs: + req.update_mtp_verify_token_num(verify_token_num=1 + counter[req.req_idx] - 1) return def _gen_argmax_token_ids(self, model_output: ModelOutput): diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index c41dbb6d9..d68d34918 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -236,17 +236,27 @@ def decode_mtp( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): - b_mtp_index_cpu = model_input.b_mtp_index + + if self.enable_dynamic_mtp: + # 根据当前的 batch size 和 dynamic_batch_size 计算出需要裁剪的 batch size 的model_input + dynamic_batch_size = 10 # TODO: 需要根据实际情况计算出 dynamic_batch_size + trans_to_dynamic_model_input = None # TODO: 需要根据实际情况实现 trans_to_dynamic_model_input + model_input, selected_run_reqs = trans_to_dynamic_model_input(model_input, dynamic_batch_size) + # selected_run_reqs 是一个 gpu tensor, 类型为 int, 0, 表示没有选中, 1 表示选中。 + + selected_run_reqs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( + key="selected_run_reqs", + gpu_tensor=selected_run_reqs, + ) + trans_dynamic_model_input_event = torch.cuda.Event() + trans_dynamic_model_input_event.record() + model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids - b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] - b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list( - key="b_req_mtp_start_loc", - data=b_req_mtp_start_loc, - dtype=torch.int32, - ).cuda(non_blocking=True) - + get_b_req_mtp_start_loc = None # TODO: 需要根据实际情况实现 get_b_req_mtp_start_loc + b_req_mtp_start_loc = get_b_req_mtp_start_loc(model_input.b_mtp_index, req_num=len(decode_reqs)) + # b_req_mtp_start_loc 是一个 gpu tensor, 类型为 int, 表示每个请求的 mtp_start_loc, shape 为 len(decode_reqs) mtp_accept_len, accepted_index = self._verify_mtp_v2( new_next_token_ids=next_token_ids, b_req_idx=model_input.b_req_idx, @@ -277,15 +287,20 @@ def decode_mtp( # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size if self.enable_dynamic_mtp: - draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) + draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view( + self.mtp_step, model_input.b_mtp_index.shape[0] + ) dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) # 异步拷贝回 CPU Pin Memory dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu ) - - dynamic_mtp_event = torch.cuda.Event() - dynamic_mtp_event.record() + dynamic_sizes_cpu # TODO, use to update statcis. + draft_probs_list = [e.view(-1, 1) for e in draft_probs_list] + draft_probs_list = [torch.ones_like(draft_probs_list[-1])] + draft_probs_list + all_next_token_probs = torch.cat(draft_probs_list, dim=-1) # [batch_size, mtp_step + 1] + else: + all_next_token_probs = None mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, @@ -293,6 +308,8 @@ def decode_mtp( all_next_token_ids=all_next_token_ids, b_req_idx=model_input.b_req_idx, mtp_accept_len=mtp_accept_len, + req_to_next_token_probs=self.model.req_manager.req_sampling_params_manager.req_to_next_token_probs, + all_next_token_probs=all_next_token_probs, ) next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem( @@ -315,22 +332,33 @@ def decode_mtp( # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() - self._update_mtp_verify_token_num(decode_reqs=decode_reqs) + + if self.enable_dynamic_mtp: + trans_dynamic_model_input_event.synchronize() + selected_run_reqs_cpu_numpy = selected_run_reqs_cpu.numpy() + run_reqs = [run_reqs[i] for i in range(len(run_reqs)) if selected_run_reqs_cpu_numpy[i] == 1] + + if self.enable_dynamic_mtp: + self._update_mtp_verify_token_num(decode_reqs=decode_reqs, dynamic_mtp_run_reqs=run_reqs) + else: + self._update_mtp_verify_token_num(decode_reqs=decode_reqs) verify_event.synchronize() accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] - if self.enable_dynamic_mtp: - dynamic_mtp_event.synchronize() - self._update_dynamic_mtp_size_cpu_part( - run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu - ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() + if self.enable_dynamic_mtp: + # TODO: 更新动态 mtp step 步的相关信息到 planer中,进行相关的信息统计。便于分析。 + # self._update_dynamic_mtp_size_cpu_part( + # run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu + # ) + pass + # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] if additional_mem_indexes_cpu is not None: From 19fbb69fbc1940af71b6e2c89895137b7bd4ddfb Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 09:22:51 +0000 Subject: [PATCH 12/20] fix --- .../server/router/model_infer/infer_batch.py | 5 +---- .../mode_backend/chunked_prefill/impl.py | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 46608da13..7b931f153 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -435,10 +435,7 @@ def __init__( # mtp_step 用来记录一个请求 draft模型每步需要生成的token数量 # 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量 self.mtp_step: int = get_env_start_args().mtp_step - # current_mtp_step 用来记录当前的 MTP 验证长度(<= mtp_step) - # 在启用动态 MTP 验证时,每步会根据 prob 分布重新设置该值 - # 静态模式下为 mtp_step,动态模式下为动态计算的 MTP 验证长度 - self.current_mtp_step: int = self.mtp_step + if self.mtp_step > 0: self.decode_need_token_num = self._mtp_decode_need_token_num else: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index d68d34918..392c826b9 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -353,11 +353,12 @@ def decode_mtp( sync_event.synchronize() if self.enable_dynamic_mtp: - # TODO: 更新动态 mtp step 步的相关信息到 planer中,进行相关的信息统计。便于分析。 - # self._update_dynamic_mtp_size_cpu_part( - # run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu - # ) - pass + self._update_dynamic_mtp_size_cpu_part( + decode_reqs=decode_reqs, + run_reqs=run_reqs, + dynamic_sizes_cpu=dynamic_sizes_cpu, + accepted_index_cpu=accepted_index_cpu, + ) # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] @@ -395,15 +396,19 @@ def _compute_dynamic_mtp_size_gpu_part( def _update_dynamic_mtp_size_cpu_part( self, + decode_reqs: List[InferReq], run_reqs: List[InferReq], dynamic_sizes_cpu: torch.Tensor, accepted_index_cpu: torch.Tensor, ): + id_to_current_mtp_step = {} assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): if int(accepted) == 1: - req.current_mtp_step = int(new_size) - assert req.current_mtp_step <= req.mtp_step + assert int(new_size) <= req.mtp_step + id_to_current_mtp_step[req.req_idx] = int(new_size) + # TODO 将 id_to_current_mtp_step 的信息更新到 planner 中去 + pass def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 From 880578b42b0814d83d95a8394e7185f2ecda9fe8 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 19 May 2026 00:42:52 +0000 Subject: [PATCH 13/20] fix --- .../model_infer/mode_backend/chunked_prefill/impl.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 392c826b9..501a25c58 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -251,6 +251,9 @@ def decode_mtp( trans_dynamic_model_input_event = torch.cuda.Event() trans_dynamic_model_input_event.record() + start_time_event = torch.cuda.Event(enable_timing=True) + start_time_event.record() + model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids @@ -267,7 +270,7 @@ def decode_mtp( gpu_tensor=accepted_index, ) - verify_event = torch.cuda.Event() + verify_event = torch.cuda.Event(enable_timing=True) verify_event.record() if self.enable_dynamic_mtp: @@ -353,12 +356,18 @@ def decode_mtp( sync_event.synchronize() if self.enable_dynamic_mtp: + # 更新 动态verify token 数据到 planner 中去 self._update_dynamic_mtp_size_cpu_part( decode_reqs=decode_reqs, run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu, ) + # 更新单token的速度信息到 planner 中去 + per_token_cost_ms = start_time_event.elapsed_time(verify_event) / (mtp_accept_len_cpu.sum().item()) + # TODO 将 per_token_cost_ms 更新到 planner 中去 + per_token_cost_ms = per_token_cost_ms + 0.0 + pass # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] From 6320722d11fea970213a5f7cfc0e7ae05da1ce87 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 19 May 2026 02:01:35 +0000 Subject: [PATCH 14/20] fix --- .../mode_backend/dynamic_mtp_planner.py | 293 +++--------------- 1 file changed, 40 insertions(+), 253 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py index b09a711bd..7be523716 100644 --- a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py +++ b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py @@ -1,45 +1,5 @@ -import copy -import math -import random -from dataclasses import dataclass from typing import Dict, List, Optional -import torch - -from lightllm.common.basemodel.batch_objs import ModelInput -from lightllm.common.basemodel.infer_lock import g_infer_state_lock -from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context -from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size - - -@dataclass -class DynamicMTPPlan: - req_num: int - original_batch_size: int - dynamic_batch_size: int - keep_indices: torch.Tensor - per_req_rows: List[int] - estimated_accept_mean: float - estimated_accept_std: float - - -class _EMAValue: - def __init__(self, decay: float, init_value: Optional[float] = None) -> None: - self.decay = decay - self.value = init_value - self.initialized = init_value is not None - - def update(self, new_value: float) -> float: - if not self.initialized: - self.value = new_value - self.initialized = True - else: - self.value = self.decay * self.value + (1.0 - self.decay) * new_value - return self.value - - def get(self, fallback: float) -> float: - return self.value if self.initialized else fallback - class DynamicMTPPlanner: def __init__( @@ -49,227 +9,54 @@ def __init__( confidence_k: float = 1.0, ) -> None: self.mtp_step = mtp_step - self.max_rows_per_req = mtp_step + 1 self.confidence_k = confidence_k - self.accept_mean = _EMAValue(ema_decay, init_value=float(self.max_rows_per_req)) - self.accept_second_moment = _EMAValue(ema_decay, init_value=float(self.max_rows_per_req**2)) - self.req_accept_mean: Dict[int, _EMAValue] = {} - self.req_prob: Dict[int, _EMAValue] = {} - self.latency_ms_by_batch_size: Dict[int, _EMAValue] = {} - self.accepted_token_speed = _EMAValue(ema_decay) - self.verify_row_speed = _EMAValue(ema_decay) - self.actual_speedup = _EMAValue(ema_decay) - self.single_token_speed_by_req_num: Dict[int, _EMAValue] = {} - self.last_plan: Optional[DynamicMTPPlan] = None - - def trim_before_forward( - self, - model_input: ModelInput, - run_reqs: List[InferReq], - decode_reqs: List[InferReq], - ): - plan = self._build_plan(model_input=model_input, decode_reqs=decode_reqs) - self.last_plan = plan - if plan.dynamic_batch_size == plan.original_batch_size: - return model_input, run_reqs, plan - - pruned_indices = self._invert_indices(plan.keep_indices, plan.original_batch_size) - if pruned_indices.numel() > 0: - pruned_mem_indexes = model_input.mem_indexes_cpu[pruned_indices] - g_infer_state_lock.acquire() - g_infer_context.req_manager.mem_manager.free(pruned_mem_indexes) - g_infer_state_lock.release() - - trimmed_input = copy.copy(model_input) - keep_indices = plan.keep_indices - keep_list = keep_indices.tolist() - - trimmed_input.batch_size = plan.dynamic_batch_size - trimmed_input.b_req_idx = model_input.b_req_idx[keep_indices].contiguous() - trimmed_input.b_mtp_index = model_input.b_mtp_index[keep_indices].contiguous() - trimmed_input.b_seq_len = model_input.b_seq_len[keep_indices].contiguous() - trimmed_input.mem_indexes_cpu = model_input.mem_indexes_cpu[keep_indices].contiguous() - trimmed_input.mem_indexes = None - trimmed_input.total_token_num = int(trimmed_input.b_seq_len.sum().item()) - trimmed_input.max_kv_seq_len = int(trimmed_input.b_seq_len.max().item()) - trimmed_input.multimodal_params = [model_input.multimodal_params[index] for index in keep_list] - trimmed_run_reqs = [run_reqs[index] for index in keep_list] - trimmed_input.b_mark_shared_group = self._build_mtp_shared_group_infos(trimmed_run_reqs) - trimmed_input.check_input() - return trimmed_input, trimmed_run_reqs, plan - - def update_after_verify( - self, - plan: DynamicMTPPlan, - decode_reqs: List[InferReq], - mtp_accept_len_cpu: torch.Tensor, - elapsed_ms: float, - per_req_probs_cpu: Optional[torch.Tensor] = None, - ) -> None: - if plan is None: - return - - accept_lens = [int(value) for value in mtp_accept_len_cpu.numpy()] - if not accept_lens: - return - - batch_mean = sum(accept_lens) / len(accept_lens) - batch_second = sum(value * value for value in accept_lens) / len(accept_lens) - self.accept_mean.update(batch_mean) - self.accept_second_moment.update(batch_second) - - for req, accept_len in zip(decode_reqs, accept_lens): - req_ema = self.req_accept_mean.get(req.req_id) - if req_ema is None: - req_ema = _EMAValue(self.accept_mean.decay, init_value=batch_mean) - self.req_accept_mean[req.req_id] = req_ema - req_ema.update(float(accept_len)) - - if per_req_probs_cpu is not None: - for req, req_prob in zip(decode_reqs, per_req_probs_cpu.numpy()): - req_prob_ema = self.req_prob.get(req.req_id) - if req_prob_ema is None: - req_prob_ema = _EMAValue(self.accept_mean.decay) - self.req_prob[req.req_id] = req_prob_ema - req_prob_ema.update(self._clip_prob(float(req_prob))) - - if elapsed_ms > 0: - current_output_speed = sum(accept_lens) / elapsed_ms - latency_ema = self.latency_ms_by_batch_size.get(plan.dynamic_batch_size) - if latency_ema is None: - latency_ema = _EMAValue(self.accept_mean.decay) - self.latency_ms_by_batch_size[plan.dynamic_batch_size] = latency_ema - latency_ema.update(elapsed_ms) - self.accepted_token_speed.update(current_output_speed) - self.verify_row_speed.update(plan.dynamic_batch_size / elapsed_ms) - if plan.dynamic_batch_size == plan.req_num: - single_token_speed = self.single_token_speed_by_req_num.get(plan.req_num) - if single_token_speed is None: - single_token_speed = _EMAValue(self.accept_mean.decay) - self.single_token_speed_by_req_num[plan.req_num] = single_token_speed - single_token_speed.update(plan.req_num / elapsed_ms) - baseline_speed = self.single_token_speed_by_req_num.get(plan.req_num) - if baseline_speed is not None and baseline_speed.get(0.0) > 0: - self.actual_speedup.update(current_output_speed / baseline_speed.get(0.0)) - - def get_stats_snapshot(self) -> Dict[str, float]: - return { - "accept_mean": self.accept_mean.get(float(self.max_rows_per_req)), - "accept_second_moment": self.accept_second_moment.get(float(self.max_rows_per_req**2)), - "accepted_token_speed": self.accepted_token_speed.get(0.0), - "verify_row_speed": self.verify_row_speed.get(0.0), - "actual_speedup": self.actual_speedup.get(0.0), - } - - def _build_plan(self, model_input: ModelInput, decode_reqs: List[InferReq]) -> DynamicMTPPlan: - req_num = len(decode_reqs) - original_batch_size = model_input.batch_size - if req_num == 0 or self.mtp_step == 0: - keep_indices = torch.arange(original_batch_size, dtype=torch.long, device="cpu") - return DynamicMTPPlan(req_num, original_batch_size, original_batch_size, keep_indices, [], 1.0, 0.0) - - mean = self.accept_mean.get(float(self.max_rows_per_req)) - second = self.accept_second_moment.get(float(self.max_rows_per_req**2)) - variance = max(0.0, second - mean * mean) - std = math.sqrt(variance) - budget = math.ceil(req_num * mean + self.confidence_k * math.sqrt(req_num) * std) - dynamic_batch_size = max(req_num, min(original_batch_size, budget)) + # 记录每个请求的 accept_len 的 ema 值, 用于分布统计 + self.req_accept_len_ema = _EMAValue(ema_decay, init_value=float(self.mtp_step + 1)) + self.req_accept_len_second_moment_ema = _EMAValue(ema_decay, init_value=float((self.mtp_step + 1) ** 2)) - per_req_rows = self._allocate_rows(decode_reqs=decode_reqs, dynamic_batch_size=dynamic_batch_size) - keep_indices, per_req_rows = self._build_keep_indices(model_input=model_input, per_req_rows=per_req_rows) - dynamic_batch_size = int(keep_indices.numel()) + # 记录每个请求对应的单token速度记录信息 + self.req_num_to_speed_dict: Dict[int, List[_EMAValue]] = {} - return DynamicMTPPlan( - req_num=req_num, - original_batch_size=original_batch_size, - dynamic_batch_size=dynamic_batch_size, - keep_indices=keep_indices, - per_req_rows=per_req_rows, - estimated_accept_mean=mean, - estimated_accept_std=std, - ) + def update_req_accept_len_statics(self, accept_lens: List[int]) -> None: + for accept_len in accept_lens: + self.req_accept_len_ema.update(float(accept_len)) + self.req_accept_len_second_moment_ema.update(float(accept_len ** 2)) + return - def _allocate_rows(self, decode_reqs: List[InferReq], dynamic_batch_size: int) -> List[int]: - req_num = len(decode_reqs) - per_req_rows = [1 for _ in range(req_num)] - remaining = dynamic_batch_size - req_num - if remaining <= 0: - return per_req_rows + def update_req_num_speed_statics(self, req_num: int, dynamic_batch_size: int, per_token_cost_ms: float) -> None: + speed_ema_list = self._get_req_num_speed_ema_list(req_num) + index = dynamic_batch_size - req_num + ema_obj = speed_ema_list[index] + ema_obj.update(per_token_cost_ms) + return - req_order = sorted( - range(req_num), - key=lambda index: self._req_prob(decode_reqs[index]), - reverse=True, - ) + def get_dynamic_batch_size(self, req_num: int, original_batch_size: int) -> int: + assert req_num * (self.mtp_step + 1) == original_batch_size - for req_index in req_order: - req_prob = self._req_prob(decode_reqs[req_index]) - for _ in range(self.mtp_step): - if remaining <= 0: - break - if random.random() >= req_prob: - break - per_req_rows[req_index] += 1 - remaining -= 1 - if remaining <= 0: - break - return per_req_rows + mean = self.req_accept_len_ema.get() + return max(req_num, int(mean)) - def _req_prob(self, req: InferReq) -> float: - req_prob = self.req_prob.get(req.req_id) - if req_prob is not None: - return self._clip_prob(req_prob.get(1.0)) - fallback = self.accept_mean.get(float(self.max_rows_per_req)) / float(self.max_rows_per_req) - return self._clip_prob(fallback) + def _get_req_num_speed_ema_list(self, req_num: int) -> List["_EMAValue"]: + if req_num not in self.req_num_to_speed_dict: + self.req_num_to_speed_dict[req_num] = [ + _EMAValue(self.ema_decay, init_value=10000000.0) for _ in range(req_num * (self.mtp_step)) + ] + return self.req_num_to_speed_dict[req_num] - def _clip_prob(self, value: float) -> float: - return min(1.0, max(0.0, value)) - def _build_keep_indices(self, model_input: ModelInput, per_req_rows: List[int]): - keep_indices = [] - effective_per_req_rows = [0 for _ in per_req_rows] - req_index = -1 - cur_req_kept = 0 - cur_req_target = 0 - for index, mtp_index in enumerate(model_input.b_mtp_index.tolist()): - if mtp_index == 0: - req_index += 1 - cur_req_kept = 0 - cur_req_target = per_req_rows[req_index] - if cur_req_kept < cur_req_target: - keep_indices.append(index) - cur_req_kept += 1 - effective_per_req_rows[req_index] += 1 - return torch.tensor(keep_indices, dtype=torch.long, device="cpu"), effective_per_req_rows - - def _invert_indices(self, keep_indices: torch.Tensor, total_size: int) -> torch.Tensor: - keep_mask = torch.zeros((total_size,), dtype=torch.bool, device="cpu") - keep_mask[keep_indices] = True - return torch.nonzero(~keep_mask, as_tuple=False).view(-1) - - def _build_mtp_shared_group_infos(self, run_reqs: List[InferReq]) -> torch.Tensor: - max_batch_shared_group_size = get_diverse_max_batch_shared_group_size() - req_ids = [req.req_id for req in run_reqs] - b_mark_shared_group = [] - current_group = [] - for req_id in req_ids: - if not current_group: - current_group.append(req_id) - elif req_id == current_group[-1]: - current_group.append(req_id) - else: - b_mark_shared_group.extend([0 for _ in range(len(current_group))]) - b_mark_shared_group[-1] = len(current_group) - current_group.clear() - current_group.append(req_id) - - if len(current_group) == max_batch_shared_group_size: - b_mark_shared_group.extend([0 for _ in range(len(current_group))]) - b_mark_shared_group[-1] = len(current_group) - current_group.clear() +class _EMAValue: + def __init__(self, decay: float, init_value: float) -> None: + """ """ + assert decay > 0.0 and decay < 1.0 + self.decay = decay + self.current_decay = 0.0 + self.value = init_value - if current_group: - b_mark_shared_group.extend([0 for _ in range(len(current_group))]) - b_mark_shared_group[-1] = len(current_group) + def update(self, new_value: float) -> float: + self.value = self.current_decay * self.value + (1.0 - self.current_decay) * new_value + # 更新 current_decay 的值,使得 current_decay 逐渐逼近 decay 的值 + self.current_decay = min(self.decay, (self.decay + self.current_decay) / 2.0 + 0.001) + return self.value - return torch.tensor(b_mark_shared_group, dtype=torch.int32, device="cpu") + def get(self) -> float: + return self.value From 337c44163977ddae0396d0fd3ec1850a143998a1 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 19 May 2026 02:14:08 +0000 Subject: [PATCH 15/20] fix --- .../mode_backend/chunked_prefill/impl.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 501a25c58..3db771f91 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -26,6 +26,7 @@ from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify from .control_state import ControlState +from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner logger = init_logger(__name__) @@ -49,6 +50,9 @@ def __init__(self) -> None: self.prefill = self.prefill_normal self.decode = self.decode_normal + if self.enable_dynamic_mtp: + self.dynamic_mtp_planner = DynamicMTPPlanner(mtp_step=get_env_start_args().mtp_step) + self.classed_req_strict_prefill = False return @@ -238,8 +242,10 @@ def decode_mtp( with torch.cuda.stream(g_infer_context.get_overlap_stream()): if self.enable_dynamic_mtp: - # 根据当前的 batch size 和 dynamic_batch_size 计算出需要裁剪的 batch size 的model_input - dynamic_batch_size = 10 # TODO: 需要根据实际情况计算出 dynamic_batch_size + dynamic_batch_size = self.dynamic_mtp_planner.get_dynamic_batch_size( + req_num=len(decode_reqs), + original_batch_size=model_input.batch_size, + ) trans_to_dynamic_model_input = None # TODO: 需要根据实际情况实现 trans_to_dynamic_model_input model_input, selected_run_reqs = trans_to_dynamic_model_input(model_input, dynamic_batch_size) # selected_run_reqs 是一个 gpu tensor, 类型为 int, 0, 表示没有选中, 1 表示选中。 @@ -357,7 +363,7 @@ def decode_mtp( if self.enable_dynamic_mtp: # 更新 动态verify token 数据到 planner 中去 - self._update_dynamic_mtp_size_cpu_part( + self._update_dynamic_mtp_size_statics( decode_reqs=decode_reqs, run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, @@ -365,9 +371,12 @@ def decode_mtp( ) # 更新单token的速度信息到 planner 中去 per_token_cost_ms = start_time_event.elapsed_time(verify_event) / (mtp_accept_len_cpu.sum().item()) - # TODO 将 per_token_cost_ms 更新到 planner 中去 - per_token_cost_ms = per_token_cost_ms + 0.0 - pass + + self.dynamic_mtp_planner.update_req_num_speed_statics( + req_num=len(decode_reqs), + dynamic_batch_size=dynamic_batch_size, + per_token_cost_ms=per_token_cost_ms, + ) # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] @@ -403,7 +412,7 @@ def _compute_dynamic_mtp_size_gpu_part( dynamic_mtp_sizes = valid_steps.sum(dim=0) return dynamic_mtp_sizes - def _update_dynamic_mtp_size_cpu_part( + def _update_dynamic_mtp_size_statics( self, decode_reqs: List[InferReq], run_reqs: List[InferReq], @@ -416,8 +425,9 @@ def _update_dynamic_mtp_size_cpu_part( if int(accepted) == 1: assert int(new_size) <= req.mtp_step id_to_current_mtp_step[req.req_idx] = int(new_size) - # TODO 将 id_to_current_mtp_step 的信息更新到 planner 中去 - pass + + self.dynamic_mtp_planner.update_req_accept_len_statics(list(id_to_current_mtp_step.values())) + return def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 From e1f7b7736a194ae38d84458b4b74d008124b307a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 19 May 2026 03:00:54 +0000 Subject: [PATCH 16/20] fix --- .../mode_backend/dynamic_mtp_planner.py | 45 ++++++++++++++++--- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py index 7be523716..8b476741b 100644 --- a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py +++ b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py @@ -1,3 +1,6 @@ +import random +import math +import numpy as np from typing import Dict, List, Optional @@ -6,16 +9,25 @@ def __init__( self, mtp_step: int, ema_decay: float = 0.9, - confidence_k: float = 1.0, + use_random_mode: bool = True, + random_mode_iter_threshold: int = 100, ) -> None: self.mtp_step = mtp_step - self.confidence_k = confidence_k + self.ema_decay = ema_decay + # 记录每个请求的 accept_len 的 ema 值, 用于分布统计 - self.req_accept_len_ema = _EMAValue(ema_decay, init_value=float(self.mtp_step + 1)) - self.req_accept_len_second_moment_ema = _EMAValue(ema_decay, init_value=float((self.mtp_step + 1) ** 2)) + self.req_accept_len_ema = _EMAValue(self.ema_decay, init_value=float(self.mtp_step + 1)) + self.req_accept_len_second_moment_ema = _EMAValue(self.ema_decay, init_value=float((self.mtp_step + 1) ** 2)) + + # 每多少个请求采用随机的方式决定 dynamic_batch_size + self._iter = 0 + self._iter_threshold = random_mode_iter_threshold + self._use_random_mode = use_random_mode + random.seed(0) # 记录每个请求对应的单token速度记录信息 self.req_num_to_speed_dict: Dict[int, List[_EMAValue]] = {} + return def update_req_accept_len_statics(self, accept_lens: List[int]) -> None: for accept_len in accept_lens: @@ -32,14 +44,33 @@ def update_req_num_speed_statics(self, req_num: int, dynamic_batch_size: int, pe def get_dynamic_batch_size(self, req_num: int, original_batch_size: int) -> int: assert req_num * (self.mtp_step + 1) == original_batch_size + self._iter += 1 + # case 1 如果采用随机的方式决定 dynamic_batch_size + if self._use_random_mode and self._iter % self._iter_threshold == 0: + sigma = math.sqrt(self.req_accept_len_second_moment_ema.get() - self.req_accept_len_ema.get() ** 2) + max_batch_size = min(req_num * (self.mtp_step + 1), int(self.req_accept_len_ema.get() + 1 * sigma)) + return random.randint(req_num, max_batch_size) - mean = self.req_accept_len_ema.get() - return max(req_num, int(mean)) + # case 2 如果采用统计的方式决定 dynamic_batch_size, 利用统计的 ema 信息来决定 + ema_batch_size = max(req_num, int(self.req_accept_len_ema.get())) + sigma = math.sqrt(self.req_accept_len_second_moment_ema.get() - self.req_accept_len_ema.get() ** 2) + max_batch_size = min(req_num * (self.mtp_step + 1), int(self.req_accept_len_ema.get() + 1 * sigma)) + start = req_num - req_num + end = max_batch_size - req_num + ema_index = ema_batch_size - req_num + speed_ema_list = self._get_req_num_speed_ema_list(req_num=req_num) + speeds = [obj.get() for obj in speed_ema_list[start : (end + 1)]] + # 对于期望均值的位置,我们稍微降低下其数据,便于在 ema 统计的数值不够多和准确的时候,更好的选中期望位置, + # 以获得平均性能。 + speeds[ema_index] -= 0.001 + min_index = np.argmin(speeds) + min_cost_batch_size = min_index + req_num + return min_cost_batch_size def _get_req_num_speed_ema_list(self, req_num: int) -> List["_EMAValue"]: if req_num not in self.req_num_to_speed_dict: self.req_num_to_speed_dict[req_num] = [ - _EMAValue(self.ema_decay, init_value=10000000.0) for _ in range(req_num * (self.mtp_step)) + _EMAValue(decay=self.ema_decay, init_value=10000000.0) for _ in range(req_num * (self.mtp_step)) ] return self.req_num_to_speed_dict[req_num] From 4b42d64f0f987b5a86a56030e7b85d00f4d6f69b Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 19 May 2026 07:02:20 +0000 Subject: [PATCH 17/20] fix --- .../mode_backend/chunked_prefill/impl.py | 9 +++--- .../mode_backend/dynamic_mtp_planner.py | 30 +++++++++++-------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 3db771f91..da8a8ff68 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -364,7 +364,6 @@ def decode_mtp( if self.enable_dynamic_mtp: # 更新 动态verify token 数据到 planner 中去 self._update_dynamic_mtp_size_statics( - decode_reqs=decode_reqs, run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu, @@ -414,19 +413,19 @@ def _compute_dynamic_mtp_size_gpu_part( def _update_dynamic_mtp_size_statics( self, - decode_reqs: List[InferReq], run_reqs: List[InferReq], dynamic_sizes_cpu: torch.Tensor, accepted_index_cpu: torch.Tensor, ): - id_to_current_mtp_step = {} + id_to_verify_len = {} + assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): if int(accepted) == 1: assert int(new_size) <= req.mtp_step - id_to_current_mtp_step[req.req_idx] = int(new_size) + id_to_verify_len[req.req_idx] = int(new_size) + 1 - self.dynamic_mtp_planner.update_req_accept_len_statics(list(id_to_current_mtp_step.values())) + self.dynamic_mtp_planner.update_req_verify_len_statics(verify_lens=list(id_to_verify_len.values())) return def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py index 8b476741b..73a593085 100644 --- a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py +++ b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py @@ -15,9 +15,9 @@ def __init__( self.mtp_step = mtp_step self.ema_decay = ema_decay - # 记录每个请求的 accept_len 的 ema 值, 用于分布统计 - self.req_accept_len_ema = _EMAValue(self.ema_decay, init_value=float(self.mtp_step + 1)) - self.req_accept_len_second_moment_ema = _EMAValue(self.ema_decay, init_value=float((self.mtp_step + 1) ** 2)) + # 记录每个请求的 verify_len 的 ema 值, 用于分布统计 + self.req_verify_len_ema = _EMAValue(self.ema_decay, init_value=float(self.mtp_step + 1)) + self.req_verify_len_second_moment_ema = _EMAValue(self.ema_decay, init_value=float((self.mtp_step + 1) ** 2)) # 每多少个请求采用随机的方式决定 dynamic_batch_size self._iter = 0 @@ -29,10 +29,10 @@ def __init__( self.req_num_to_speed_dict: Dict[int, List[_EMAValue]] = {} return - def update_req_accept_len_statics(self, accept_lens: List[int]) -> None: - for accept_len in accept_lens: - self.req_accept_len_ema.update(float(accept_len)) - self.req_accept_len_second_moment_ema.update(float(accept_len ** 2)) + def update_req_verify_len_statics(self, verify_lens: List[int]) -> None: + for verify_len in verify_lens: + self.req_verify_len_ema.update(float(verify_len)) + self.req_verify_len_second_moment_ema.update(float(verify_len ** 2)) return def update_req_num_speed_statics(self, req_num: int, dynamic_batch_size: int, per_token_cost_ms: float) -> None: @@ -47,14 +47,17 @@ def get_dynamic_batch_size(self, req_num: int, original_batch_size: int) -> int: self._iter += 1 # case 1 如果采用随机的方式决定 dynamic_batch_size if self._use_random_mode and self._iter % self._iter_threshold == 0: - sigma = math.sqrt(self.req_accept_len_second_moment_ema.get() - self.req_accept_len_ema.get() ** 2) - max_batch_size = min(req_num * (self.mtp_step + 1), int(self.req_accept_len_ema.get() + 1 * sigma)) + sigma = self._get_verify_len_sigma() + max_batch_size = min(req_num * (self.mtp_step + 1), int(self.req_verify_len_ema.get() + 1 * sigma)) + max_batch_size = max(req_num, max_batch_size) return random.randint(req_num, max_batch_size) # case 2 如果采用统计的方式决定 dynamic_batch_size, 利用统计的 ema 信息来决定 - ema_batch_size = max(req_num, int(self.req_accept_len_ema.get())) - sigma = math.sqrt(self.req_accept_len_second_moment_ema.get() - self.req_accept_len_ema.get() ** 2) - max_batch_size = min(req_num * (self.mtp_step + 1), int(self.req_accept_len_ema.get() + 1 * sigma)) + ema_batch_size = max(req_num, int(self.req_verify_len_ema.get())) + sigma = self._get_verify_len_sigma() + max_batch_size = min(req_num * (self.mtp_step + 1), int(self.req_verify_len_ema.get() + 1 * sigma)) + max_batch_size = max(req_num, max_batch_size) + start = req_num - req_num end = max_batch_size - req_num ema_index = ema_batch_size - req_num @@ -74,6 +77,9 @@ def _get_req_num_speed_ema_list(self, req_num: int) -> List["_EMAValue"]: ] return self.req_num_to_speed_dict[req_num] + def _get_verify_len_sigma(self) -> float: + return math.sqrt(max(0.0, self.req_verify_len_second_moment_ema.get() - self.req_verify_len_ema.get() ** 2)) + class _EMAValue: def __init__(self, decay: float, init_value: float) -> None: From 74c6919aa94f907daf25ff1a9084ed2974eba45e Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 19 May 2026 07:14:31 +0000 Subject: [PATCH 18/20] fix --- .../router/model_infer/mode_backend/chunked_prefill/impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index da8a8ff68..5dbc69384 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -304,7 +304,7 @@ def decode_mtp( dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu ) - dynamic_sizes_cpu # TODO, use to update statcis. + draft_probs_list = [e.view(-1, 1) for e in draft_probs_list] draft_probs_list = [torch.ones_like(draft_probs_list[-1])] + draft_probs_list all_next_token_probs = torch.cat(draft_probs_list, dim=-1) # [batch_size, mtp_step + 1] From 8374f197cb5bdcf6665fccd914c60b8da88f724c Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Fri, 22 May 2026 16:53:50 +0800 Subject: [PATCH 19/20] save --- .../basemodel/triton_kernel/mtp_utils.py | 332 +++++++++++++++++- .../mode_backend/chunked_prefill/impl.py | 29 +- .../mode_backend/generic_post_process.py | 32 +- 3 files changed, 384 insertions(+), 9 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/mtp_utils.py b/lightllm/common/basemodel/triton_kernel/mtp_utils.py index 843a77d96..d3b1fe93d 100644 --- a/lightllm/common/basemodel/triton_kernel/mtp_utils.py +++ b/lightllm/common/basemodel/triton_kernel/mtp_utils.py @@ -1,8 +1,10 @@ -from typing import Optional +from typing import Optional, Tuple import triton import triton.language as tl import torch +from lightllm.common.basemodel.batch_objs import ModelInput + @triton.jit def _fwd_kernel_mtp_verify( @@ -178,6 +180,214 @@ def mtp_scatter_next_token_ids( ) +@triton.jit +def _fwd_kernel_sample_dynamic_mtp_steps( + req_to_next_token_probs, + req_to_next_token_probs_stride, + req_indices, + sampled_steps, + rand_seed, + MAX_MTP_STEP: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + cur_index = tl.program_id(0) + cur_req_idx = tl.load(req_indices + cur_index) + + sampled_step = 1 + prefix_ok = 1 + + for step in range(1, BLOCK_SIZE): + if step < MAX_MTP_STEP: + cur_prob = tl.load(req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride + step) + cur_rand = tl.rand(rand_seed, cur_index * BLOCK_SIZE + step) + cur_accept = (cur_prob > cur_rand) & (prefix_ok == 1) + sampled_step += tl.where(cur_accept, 1, 0) + prefix_ok = tl.where(cur_accept, 1, 0) + + tl.store(sampled_steps + cur_index, sampled_step) + return + + +def sample_dynamic_mtp_req_mask( + dynamic_batch_size: int, + b_req_idx: torch.Tensor, + b_mtp_index: torch.Tensor, + req_to_next_token_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert b_req_idx.shape == b_mtp_index.shape + assert req_to_next_token_probs.is_cuda + + b_req_idx_gpu = b_req_idx.cuda(non_blocking=True) if not b_req_idx.is_cuda else b_req_idx + b_mtp_index_gpu = b_mtp_index.cuda(non_blocking=True) if not b_mtp_index.is_cuda else b_mtp_index + + batch_size = b_req_idx_gpu.shape[0] + req_start_mask = b_mtp_index_gpu == 0 + num_reqs = int(req_start_mask.sum().item()) + assert num_reqs > 0 + assert dynamic_batch_size >= num_reqs + assert dynamic_batch_size <= batch_size + + req_indices_gpu = b_req_idx_gpu[req_start_mask].to(dtype=torch.int32) + + max_mtp_step = req_to_next_token_probs.shape[1] + BLOCK_SIZE = 16 + assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must be less than or equal to {BLOCK_SIZE}" + + sampled_steps = torch.empty((num_reqs,), dtype=torch.int32, device="cuda") + rand_seed = int(torch.randint(0, 2**31 - 1, (1,), device="cuda", dtype=torch.int64).item()) + + grid = (num_reqs,) + _fwd_kernel_sample_dynamic_mtp_steps[grid]( + req_to_next_token_probs=req_to_next_token_probs, + req_to_next_token_probs_stride=req_to_next_token_probs.stride(0), + req_indices=req_indices_gpu, + sampled_steps=sampled_steps, + rand_seed=rand_seed, + MAX_MTP_STEP=max_mtp_step, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=1, + num_stages=1, + ) + + req_order = torch.cumsum(req_start_mask.to(torch.int32), dim=0) - 1 + row_sampled_steps = sampled_steps[req_order.long()] + sampled_mask_gpu = (b_mtp_index_gpu < row_sampled_steps).to(torch.int32) + final_mask_gpu = select_dynamic_mtp_exec_mask( + dynamic_batch_size=dynamic_batch_size, + req_start_mask=req_start_mask, + sampled_mask_gpu=sampled_mask_gpu, + ) + + return final_mask_gpu, sampled_steps + + +def select_dynamic_mtp_exec_mask( + dynamic_batch_size: int, + req_start_mask: torch.Tensor, + sampled_mask_gpu: torch.Tensor, +) -> torch.Tensor: + assert req_start_mask.is_cuda + assert sampled_mask_gpu.is_cuda + assert req_start_mask.shape == sampled_mask_gpu.shape + + num_reqs = int(req_start_mask.sum().item()) + selected_count = int(sampled_mask_gpu.sum().item()) + if selected_count <= dynamic_batch_size: + return sampled_mask_gpu + + remaining_budget = dynamic_batch_size - num_reqs + if remaining_budget <= 0: + # 当前 token budget 分配逻辑保持最简单的顺序策略: + # 先保证每个请求至少保留 1 个 token(即 mtp_index == 0 的主请求行), + # 如果 budget 不足以支持更多 draft token,则该请求本轮不做 MTP 展开。 + return req_start_mask.to(torch.int32) + + # 当前 token budget 分配逻辑保持最简单的顺序策略: + # 在保证每个请求至少保留 1 个 token 之后, + # 对剩余的 draft token 按 batch 中从前到后的顺序依次分配 budget。 + extra_mask_gpu = sampled_mask_gpu & (~req_start_mask) + extra_rank = torch.cumsum(extra_mask_gpu.to(torch.int32), dim=0) + kept_extra_mask = extra_mask_gpu & (extra_rank <= remaining_budget) + return req_start_mask.to(torch.int32) | kept_extra_mask.to(torch.int32) + + +def _rebuild_trimmed_mtp_b_mark_shared_group_from_b_mtp_index(b_mtp_index: torch.Tensor) -> torch.Tensor: + assert b_mtp_index.is_cuda + batch_size = b_mtp_index.shape[0] + if batch_size == 0: + return torch.empty((0,), dtype=torch.int32, device=b_mtp_index.device) + + # prepare_decode_inputs 中已经保证了每个请求的 decode rows 是按 + # b_mtp_index = [0, 1, 2, ...] 的前缀顺序展开的;动态 trim 只会保留 + # 每个请求的前缀,因此 compact 之后每个请求块的末尾满足: + # 1. 是最后一个元素;或 + # 2. 下一个元素的 b_mtp_index 重新回到 0 + group_end_mask = torch.ones((batch_size,), dtype=torch.bool, device=b_mtp_index.device) + if batch_size > 1: + group_end_mask[:-1] = b_mtp_index[1:] == 0 + + b_mark_shared_group = torch.zeros((batch_size,), dtype=torch.int32, device=b_mtp_index.device) + b_mark_shared_group[group_end_mask] = (b_mtp_index[group_end_mask] + 1).to(torch.int32) + return b_mark_shared_group + + +def _trim_decode_model_input_inplace(model_input: ModelInput, selected_mask_gpu: torch.Tensor) -> ModelInput: + assert not model_input.is_prefill + assert selected_mask_gpu.is_cuda + assert model_input.b_req_idx.is_cuda + assert model_input.b_mtp_index.is_cuda + assert model_input.b_seq_len.is_cuda + + selected_index = torch.nonzero(selected_mask_gpu, as_tuple=False).flatten() + new_batch_size = selected_index.numel() + if new_batch_size == model_input.batch_size: + return model_input + selected_index_cpu = None + + if model_input.input_ids is not None: + assert model_input.input_ids.is_cuda + model_input.input_ids = torch.index_select(model_input.input_ids, dim=0, index=selected_index) + model_input.b_req_idx = torch.index_select(model_input.b_req_idx, dim=0, index=selected_index) + model_input.b_mtp_index = torch.index_select(model_input.b_mtp_index, dim=0, index=selected_index) + model_input.b_seq_len = torch.index_select(model_input.b_seq_len, dim=0, index=selected_index) + + if model_input.mem_indexes is not None: + assert model_input.mem_indexes.is_cuda + model_input.mem_indexes = torch.index_select(model_input.mem_indexes, dim=0, index=selected_index) + if model_input.mem_indexes_cpu is not None: + selected_index_cpu = selected_index.cpu() + model_input.mem_indexes_cpu = torch.index_select(model_input.mem_indexes_cpu, dim=0, index=selected_index_cpu) + if model_input.b_shared_seq_len is not None: + assert model_input.b_shared_seq_len.is_cuda + model_input.b_shared_seq_len = torch.index_select(model_input.b_shared_seq_len, dim=0, index=selected_index) + if model_input.mtp_draft_input_hiddens is not None: + assert model_input.mtp_draft_input_hiddens.is_cuda + model_input.mtp_draft_input_hiddens = torch.index_select( + model_input.mtp_draft_input_hiddens, dim=0, index=selected_index + ) + if model_input.multimodal_params is not None: + if selected_index_cpu is None: + selected_index_cpu = selected_index.cpu() + model_input.multimodal_params = [model_input.multimodal_params[i] for i in selected_index_cpu.tolist()] + + model_input.b_mark_shared_group = _rebuild_trimmed_mtp_b_mark_shared_group_from_b_mtp_index( + model_input.b_mtp_index + ) + model_input.batch_size = new_batch_size + # model_input.total_token_num = int(model_input.b_seq_len.sum().item()) + # model_input.max_kv_seq_len = int(model_input.b_seq_len.max().item()) if new_batch_size > 0 else 0 + + # TODO: if CPU mirrors become a measurable bottleneck, move mem_indexes_cpu/multimodal_params + # trimming out of the hot path or rebuild them from later-stage selected indices. + return model_input + + +def trim_dynamic_mtp_model_input( + model_input: ModelInput, + dynamic_batch_size: int, + req_to_next_token_ids: torch.Tensor, + req_to_next_token_probs: Optional[torch.Tensor] = None, +): + del req_to_next_token_ids + + if req_to_next_token_probs is None: + selected_mask = torch.ones((model_input.batch_size,), dtype=torch.int32, device="cuda") + return model_input, selected_mask + + assert not model_input.is_prefill, "trim_dynamic_mtp_model_input only supports decode inputs" + # Launch the host->device copies as early as possible, then do mask sampling and compaction on GPU. + model_input.to_cuda() + + selected_mask_gpu, _ = sample_dynamic_mtp_req_mask( + dynamic_batch_size=dynamic_batch_size, + b_req_idx=model_input.b_req_idx, + b_mtp_index=model_input.b_mtp_index, + req_to_next_token_probs=req_to_next_token_probs, + ) + model_input = _trim_decode_model_input_inplace(model_input=model_input, selected_mask_gpu=selected_mask_gpu) + return model_input, selected_mask_gpu + + @triton.jit def _fwd_kernel_gen_b_req_mtp_start_loc( b_mtp_index, @@ -238,6 +448,124 @@ def test_gen_b_req_mtp_start_loc(): print(b_req_mtp_start_loc, gt_output) +def test_sample_dynamic_mtp_req_mask(): + torch.manual_seed(1234) + + req_to_next_token_ids = torch.tensor( + [ + [100, 101, 102, 103, 0, 0], + [200, 201, 202, 203, 0, 0], + [300, 301, 302, 303, 0, 0], + ], + dtype=torch.int64, + device="cuda", + ) + req_to_next_token_probs = torch.tensor( + [ + [1.0, 0.95, 0.90, 0.10, 0.0, 0.0], + [1.0, 0.20, 0.80, 0.80, 0.0, 0.0], + [1.0, 0.99, 0.99, 0.99, 0.0, 0.0], + ], + dtype=torch.float32, + device="cuda", + ) + + # 3 个请求,每个请求展开成 4 个 decode row + b_req_idx = torch.tensor( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2], + dtype=torch.int32, + device="cuda", + ) + b_mtp_index = torch.tensor( + [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], + dtype=torch.int32, + device="cuda", + ) + + dynamic_batch_size = 8 + selected_mask, sampled_steps = sample_dynamic_mtp_req_mask( + dynamic_batch_size=dynamic_batch_size, + b_req_idx=b_req_idx, + b_mtp_index=b_mtp_index, + req_to_next_token_probs=req_to_next_token_probs, + ) + + print("==== test_sample_dynamic_mtp_req_mask ====") + print("req_to_next_token_ids:") + print(req_to_next_token_ids.cpu()) + print("req_to_next_token_probs:") + print(req_to_next_token_probs.cpu()) + print("b_req_idx:") + print(b_req_idx.cpu()) + print("b_mtp_index:") + print(b_mtp_index.cpu()) + print("sampled_steps per req:") + print(sampled_steps.cpu()) + print("selected_mask:") + print(selected_mask.cpu()) + print("selected rows [req_idx, mtp_index]:") + selected_pos = torch.where(selected_mask == 1)[0] + print(torch.stack([b_req_idx[selected_pos], b_mtp_index[selected_pos]], dim=-1).cpu()) + + +def test_trim_dynamic_mtp_model_input(): + torch.manual_seed(1234) + + model_input = ModelInput( + batch_size=12, + total_token_num=54, + max_q_seq_len=1, + max_kv_seq_len=6, + input_ids=None, + b_req_idx=torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.int32, device="cuda"), + b_mtp_index=torch.tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], dtype=torch.int32, device="cuda"), + b_seq_len=torch.tensor([3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6], dtype=torch.int32, device="cuda"), + b_shared_seq_len=None, + b_mark_shared_group=torch.tensor([0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], dtype=torch.int32, device="cuda"), + mem_indexes=torch.arange(12, dtype=torch.int32, device="cuda"), + mem_indexes_cpu=torch.arange(12, dtype=torch.int32, device="cpu"), + is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(12)], + ) + req_to_next_token_probs = torch.tensor( + [ + [1.0, 0.95, 0.90, 0.10, 0.0, 0.0], + [1.0, 0.20, 0.80, 0.80, 0.0, 0.0], + [1.0, 0.99, 0.99, 0.99, 0.0, 0.0], + ], + dtype=torch.float32, + device="cuda", + ) + + model_input, selected_mask = trim_dynamic_mtp_model_input( + model_input=model_input, + dynamic_batch_size=8, + req_to_next_token_ids=torch.empty((0,), dtype=torch.int64, device="cuda"), + req_to_next_token_probs=req_to_next_token_probs, + ) + + print("==== test_trim_dynamic_mtp_model_input ====") + print("selected_mask:") + print(selected_mask.cpu()) + print("batch_size:", model_input.batch_size) + print("total_token_num:", model_input.total_token_num) + print("max_kv_seq_len:", model_input.max_kv_seq_len) + print("b_req_idx:") + print(model_input.b_req_idx.cpu()) + print("b_mtp_index:") + print(model_input.b_mtp_index.cpu()) + print("b_seq_len:") + print(model_input.b_seq_len.cpu()) + print("b_mark_shared_group:") + print(model_input.b_mark_shared_group.cpu()) + print("mem_indexes:") + print(model_input.mem_indexes.cpu()) + print("mem_indexes_cpu:") + print(model_input.mem_indexes_cpu) + + if __name__ == "__main__": - test_mtp_verify() + # test_mtp_verify() # test_gen_b_req_mtp_start_loc() + test_sample_dynamic_mtp_req_mask() + # test_trim_dynamic_mtp_model_input() diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 5dbc69384..913dfd795 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -20,7 +20,9 @@ from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.common.basemodel.triton_kernel.mtp_utils import ( + gen_b_req_mtp_start_loc, mtp_scatter_next_token_ids, + trim_dynamic_mtp_model_input, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id @@ -238,6 +240,7 @@ def decode_mtp( MTP解码的通用流程,整合eagle和vanilla的共同逻辑 """ model_input, run_reqs = prepare_decode_inputs(decode_reqs) + origin_mem_indexes_cpu = model_input.mem_indexes_cpu with torch.cuda.stream(g_infer_context.get_overlap_stream()): @@ -245,9 +248,15 @@ def decode_mtp( dynamic_batch_size = self.dynamic_mtp_planner.get_dynamic_batch_size( req_num=len(decode_reqs), original_batch_size=model_input.batch_size, + ) + # TODO: 需要根据实际情况实现 trans_to_dynamic_model_input + trans_to_dynamic_model_input = trim_dynamic_mtp_model_input + model_input, selected_run_reqs = trans_to_dynamic_model_input( + model_input=model_input, + dynamic_batch_size=dynamic_batch_size, + req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, + req_to_next_token_probs=self.model.req_manager.req_sampling_params_manager.req_to_next_token_probs, ) - trans_to_dynamic_model_input = None # TODO: 需要根据实际情况实现 trans_to_dynamic_model_input - model_input, selected_run_reqs = trans_to_dynamic_model_input(model_input, dynamic_batch_size) # selected_run_reqs 是一个 gpu tensor, 类型为 int, 0, 表示没有选中, 1 表示选中。 selected_run_reqs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( @@ -261,10 +270,15 @@ def decode_mtp( start_time_event.record() model_output = self.model.forward(model_input) - next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) + next_token_ids, next_token_logprobs = sample( + model_output.logits, + run_reqs, + self.eos_id, + selected_run_reqs=selected_run_reqs if self.enable_dynamic_mtp else None, + ) # verify the next_token_ids - get_b_req_mtp_start_loc = None # TODO: 需要根据实际情况实现 get_b_req_mtp_start_loc - b_req_mtp_start_loc = get_b_req_mtp_start_loc(model_input.b_mtp_index, req_num=len(decode_reqs)) + # TODO: 需要根据实际情况实现 get_b_req_mtp_start_loc + b_req_mtp_start_loc = gen_b_req_mtp_start_loc(model_input.b_mtp_index, num_reqs=len(decode_reqs)) # b_req_mtp_start_loc 是一个 gpu tensor, 类型为 int, 表示每个请求的 mtp_start_loc, shape 为 len(decode_reqs) mtp_accept_len, accepted_index = self._verify_mtp_v2( new_next_token_ids=next_token_ids, @@ -379,6 +393,11 @@ def decode_mtp( # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] + if self.enable_dynamic_mtp: + selected_run_reqs_cpu_numpy = selected_run_reqs_cpu.numpy() + trim_free_mem_indexes = origin_mem_indexes_cpu[selected_run_reqs_cpu_numpy == 0] + if len(trim_free_mem_indexes) > 0: + need_free_mem_indexes = torch.cat([need_free_mem_indexes, trim_free_mem_indexes], dim=0) if additional_mem_indexes_cpu is not None: need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index f3ad03662..fddfa60f0 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,5 +1,5 @@ import torch -from typing import List, Tuple +from typing import List, Tuple, Optional from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context @@ -7,7 +7,35 @@ from lightllm.utils.envs_utils import get_env_start_args -def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): +def sample( + logits: torch.Tensor, + reqs: List[InferReq], + eos_id: List[int] = [2], + selected_run_reqs: Optional[torch.Tensor] = None, +): + if selected_run_reqs is not None: + assert selected_run_reqs.is_cuda + assert selected_run_reqs.numel() == len(reqs) + selected_index = torch.nonzero(selected_run_reqs.to(torch.bool), as_tuple=False).flatten() + assert logits.shape[0] == selected_index.numel(), f"{logits.shape[0]} vs {selected_index.numel()}" + + padded_logits = torch.zeros( + (len(reqs), logits.shape[1]), + dtype=logits.dtype, + device=logits.device, + ) + padded_logits.index_copy_(0, selected_index, logits) + + full_next_token_ids, full_next_token_logprobs = sample( + padded_logits, + reqs, + eos_id=eos_id, + selected_run_reqs=None, + ) + next_token_ids = torch.index_select(full_next_token_ids, dim=0, index=selected_index) + next_token_logprobs = torch.index_select(full_next_token_logprobs, dim=0, index=selected_index) + return next_token_ids, next_token_logprobs + ( b_req_idx, b_temperatures, From ce9941679a88852878b1fc6bb889a5269d6faa37 Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Sat, 23 May 2026 18:29:53 +0800 Subject: [PATCH 20/20] fix sync operation --- .../basemodel/triton_kernel/mtp_utils.py | 31 +++++++------------ .../mode_backend/chunked_prefill/impl.py | 16 +++++++--- .../mode_backend/dynamic_mtp_planner.py | 6 ++-- test/speculative/qwen3-32b/dynamic_triton.sh | 3 +- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/mtp_utils.py b/lightllm/common/basemodel/triton_kernel/mtp_utils.py index d3b1fe93d..eb4b70e70 100644 --- a/lightllm/common/basemodel/triton_kernel/mtp_utils.py +++ b/lightllm/common/basemodel/triton_kernel/mtp_utils.py @@ -215,19 +215,18 @@ def sample_dynamic_mtp_req_mask( req_to_next_token_probs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: assert b_req_idx.shape == b_mtp_index.shape + assert b_req_idx.is_cuda + assert b_mtp_index.is_cuda assert req_to_next_token_probs.is_cuda - b_req_idx_gpu = b_req_idx.cuda(non_blocking=True) if not b_req_idx.is_cuda else b_req_idx - b_mtp_index_gpu = b_mtp_index.cuda(non_blocking=True) if not b_mtp_index.is_cuda else b_mtp_index - - batch_size = b_req_idx_gpu.shape[0] - req_start_mask = b_mtp_index_gpu == 0 + batch_size = b_req_idx.shape[0] + req_start_mask = b_mtp_index == 0 num_reqs = int(req_start_mask.sum().item()) assert num_reqs > 0 assert dynamic_batch_size >= num_reqs assert dynamic_batch_size <= batch_size - req_indices_gpu = b_req_idx_gpu[req_start_mask].to(dtype=torch.int32) + req_indices_gpu = b_req_idx[req_start_mask].to(dtype=torch.int32) max_mtp_step = req_to_next_token_probs.shape[1] BLOCK_SIZE = 16 @@ -251,7 +250,7 @@ def sample_dynamic_mtp_req_mask( req_order = torch.cumsum(req_start_mask.to(torch.int32), dim=0) - 1 row_sampled_steps = sampled_steps[req_order.long()] - sampled_mask_gpu = (b_mtp_index_gpu < row_sampled_steps).to(torch.int32) + sampled_mask_gpu = (b_mtp_index < row_sampled_steps).to(torch.int32) final_mask_gpu = select_dynamic_mtp_exec_mask( dynamic_batch_size=dynamic_batch_size, req_start_mask=req_start_mask, @@ -317,12 +316,12 @@ def _trim_decode_model_input_inplace(model_input: ModelInput, selected_mask_gpu: assert model_input.b_req_idx.is_cuda assert model_input.b_mtp_index.is_cuda assert model_input.b_seq_len.is_cuda + assert model_input.mem_indexes is not None and model_input.mem_indexes.is_cuda selected_index = torch.nonzero(selected_mask_gpu, as_tuple=False).flatten() new_batch_size = selected_index.numel() if new_batch_size == model_input.batch_size: return model_input - selected_index_cpu = None if model_input.input_ids is not None: assert model_input.input_ids.is_cuda @@ -334,9 +333,6 @@ def _trim_decode_model_input_inplace(model_input: ModelInput, selected_mask_gpu: if model_input.mem_indexes is not None: assert model_input.mem_indexes.is_cuda model_input.mem_indexes = torch.index_select(model_input.mem_indexes, dim=0, index=selected_index) - if model_input.mem_indexes_cpu is not None: - selected_index_cpu = selected_index.cpu() - model_input.mem_indexes_cpu = torch.index_select(model_input.mem_indexes_cpu, dim=0, index=selected_index_cpu) if model_input.b_shared_seq_len is not None: assert model_input.b_shared_seq_len.is_cuda model_input.b_shared_seq_len = torch.index_select(model_input.b_shared_seq_len, dim=0, index=selected_index) @@ -345,20 +341,15 @@ def _trim_decode_model_input_inplace(model_input: ModelInput, selected_mask_gpu: model_input.mtp_draft_input_hiddens = torch.index_select( model_input.mtp_draft_input_hiddens, dim=0, index=selected_index ) + # ! 目前用不到multimodal_params,但是又会检查它的长度是否正确,因此先简单地 trim 掉,保持和其他参数一致的 batch_size if model_input.multimodal_params is not None: - if selected_index_cpu is None: - selected_index_cpu = selected_index.cpu() - model_input.multimodal_params = [model_input.multimodal_params[i] for i in selected_index_cpu.tolist()] + model_input.multimodal_params = model_input.multimodal_params[:new_batch_size] model_input.b_mark_shared_group = _rebuild_trimmed_mtp_b_mark_shared_group_from_b_mtp_index( model_input.b_mtp_index ) model_input.batch_size = new_batch_size - # model_input.total_token_num = int(model_input.b_seq_len.sum().item()) - # model_input.max_kv_seq_len = int(model_input.b_seq_len.max().item()) if new_batch_size > 0 else 0 - # TODO: if CPU mirrors become a measurable bottleneck, move mem_indexes_cpu/multimodal_params - # trimming out of the hot path or rebuild them from later-stage selected indices. return model_input @@ -375,8 +366,10 @@ def trim_dynamic_mtp_model_input( return model_input, selected_mask assert not model_input.is_prefill, "trim_dynamic_mtp_model_input only supports decode inputs" - # Launch the host->device copies as early as possible, then do mask sampling and compaction on GPU. + # Keep the trim path simple and deterministic: finish model_input.to_cuda() first, + # then do both sampling and compaction purely on GPU tensors. model_input.to_cuda() + torch.cuda.current_stream().synchronize() selected_mask_gpu, _ = sample_dynamic_mtp_req_mask( dynamic_batch_size=dynamic_batch_size, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 913dfd795..1bfa702e3 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -258,7 +258,6 @@ def decode_mtp( req_to_next_token_probs=self.model.req_manager.req_sampling_params_manager.req_to_next_token_probs, ) # selected_run_reqs 是一个 gpu tensor, 类型为 int, 0, 表示没有选中, 1 表示选中。 - selected_run_reqs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="selected_run_reqs", gpu_tensor=selected_run_reqs, @@ -391,13 +390,20 @@ def decode_mtp( per_token_cost_ms=per_token_cost_ms, ) - # 处理需要释放的内存索引 - need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] + # 处理需要释放的内存索引。动态 MTP trim 只裁了 GPU 侧 model_input, + # 因此这里统一基于原始的 origin_mem_indexes_cpu 计算: + # 1. 被选中但 verify 未通过的 index + # 2. 未被选中的 index if self.enable_dynamic_mtp: - selected_run_reqs_cpu_numpy = selected_run_reqs_cpu.numpy() - trim_free_mem_indexes = origin_mem_indexes_cpu[selected_run_reqs_cpu_numpy == 0] + selected_run_reqs_cpu_mask = selected_run_reqs_cpu.to(dtype=torch.bool) + selected_mem_indexes_cpu = origin_mem_indexes_cpu[selected_run_reqs_cpu_mask] + need_free_mem_indexes = selected_mem_indexes_cpu[accepted_index_cpu == 0] + + trim_free_mem_indexes = origin_mem_indexes_cpu[~selected_run_reqs_cpu_mask] if len(trim_free_mem_indexes) > 0: need_free_mem_indexes = torch.cat([need_free_mem_indexes, trim_free_mem_indexes], dim=0) + else: + need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] if additional_mem_indexes_cpu is not None: need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0) diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py index 73a593085..11c0c93d2 100644 --- a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py +++ b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py @@ -44,18 +44,18 @@ def update_req_num_speed_statics(self, req_num: int, dynamic_batch_size: int, pe def get_dynamic_batch_size(self, req_num: int, original_batch_size: int) -> int: assert req_num * (self.mtp_step + 1) == original_batch_size - self._iter += 1 # case 1 如果采用随机的方式决定 dynamic_batch_size if self._use_random_mode and self._iter % self._iter_threshold == 0: sigma = self._get_verify_len_sigma() - max_batch_size = min(req_num * (self.mtp_step + 1), int(self.req_verify_len_ema.get() + 1 * sigma)) + max_batch_size = min(req_num * (self.mtp_step + 1), req_num * int(self.req_verify_len_ema.get() + 1 * sigma)) max_batch_size = max(req_num, max_batch_size) return random.randint(req_num, max_batch_size) + self._iter += 1 # case 2 如果采用统计的方式决定 dynamic_batch_size, 利用统计的 ema 信息来决定 ema_batch_size = max(req_num, int(self.req_verify_len_ema.get())) sigma = self._get_verify_len_sigma() - max_batch_size = min(req_num * (self.mtp_step + 1), int(self.req_verify_len_ema.get() + 1 * sigma)) + max_batch_size = min(req_num * (self.mtp_step + 1), req_num * int(self.req_verify_len_ema.get() + 1 * sigma)) max_batch_size = max(req_num, max_batch_size) start = req_num - req_num diff --git a/test/speculative/qwen3-32b/dynamic_triton.sh b/test/speculative/qwen3-32b/dynamic_triton.sh index ca70bf15e..6af1c2ca4 100644 --- a/test/speculative/qwen3-32b/dynamic_triton.sh +++ b/test/speculative/qwen3-32b/dynamic_triton.sh @@ -27,7 +27,8 @@ LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m light --graph_grow_step_size 1 \ --mtp_step ${MTP_STEP} \ --llm_decode_att_backend triton \ ---mtp_dynamic_verify +--mtp_dynamic_verify \ +--disable_cudagraph # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ #--enable_decode_microbatch_overlap \