Skip to content

[Speculative Decoding]【Hackathon 10th Spring No.54】hybrid_mtp_ngram 端到端验证#7849

Open
NKNaN wants to merge 6 commits into
PaddlePaddle:developfrom
NKNaN:spec-mtp-ngram
Open

[Speculative Decoding]【Hackathon 10th Spring No.54】hybrid_mtp_ngram 端到端验证#7849
NKNaN wants to merge 6 commits into
PaddlePaddle:developfrom
NKNaN:spec-mtp-ngram

Conversation

@NKNaN
Copy link
Copy Markdown
Contributor

@NKNaN NKNaN commented May 19, 2026

Motivation

PaddlePaddle/community#1372

Modifications

  1. 算子接口(ngram_match_mixed.cu、cpp_extensions.cc):input_ids/input_ids_len 改为 token_ids_all/prompt_lens,pre_ids 暂保留,预计下一个pr去除。

  2. Python 调用消除拷贝(mtp_cuda.py):_extend_draft_token_with_ngram_match 中两次 .cuda() 替换为已在 GPU 的张量。

  3. 代码清理(mtp.py):删除 insert_tasks_v1 中的 .cpu() D→H 拷贝、input_ids_cpu/input_ids_len 写入。

  4. ProposerInputBatch 修改(input_batch.py):token_ids_all 从 clone 改为引用 target 张量;删除冗余字段 input_ids_cpu/input_ids_len 及其 swap/reset 中的维护。

  5. 新增 hybrid E2E 测试(test_ernie_21b_mtp_ngram.py):覆盖 overlap + cudagraph + logprob 。

Usage or Command

N/A

Accuracy Tests

tests/e2e/test_ernie_21b_mtp_ngram.py

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot
Copy link
Copy Markdown

paddle-bot Bot commented May 19, 2026

Thanks for your contribution!

@paddle-bot paddle-bot Bot added the contributor External developers label May 19, 2026
@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented May 19, 2026

Codecov Report

❌ Patch coverage is 50.00000% with 1 line in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@b2fc2c6). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/worker/input_batch.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7849   +/-   ##
==========================================
  Coverage           ?   63.62%           
==========================================
  Files              ?      462           
  Lines              ?    64385           
  Branches           ?     9873           
==========================================
  Hits               ?    40962           
  Misses             ?    20635           
  Partials           ?     2788           
Flag Coverage Δ
GPU 72.77% <50.00%> (?)
XPU 7.12% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

该 PR 围绕 hybrid_mtp_ngram(Hybrid MTP + Ngram)链路做端到端验证与代码清理:统一算子接口参数语义(从 input_ids/input_ids_len 迁移到 token_ids_all/prompt_lens),并消除 MTP hybrid 路径中不必要的 D2H/H2D 拷贝,最后补充 E2E 覆盖 overlap + cudagraph + logprob 场景。

Changes:

  • 更新 hybrid_mtp_ngram CUDA 算子接口与内部实现:prompt 搜索源改为 token_ids_all + prompt_lens
  • MTP hybrid 路径消除 input_ids_cpu/input_ids_len 相关 CPU 缓冲与 .cpu()/.cuda() 拷贝,并同步调整 ProposerInputBatch 初始化/重置逻辑。
  • 更新相关单测并新增 ERNIE 21B 的 hybrid MTP-Ngram E2E 测试用例。

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
tests/spec_decode/test_ngram_gpu_kernel.py 更新 CPU 参考实现与数据构造,适配 token_ids_all/prompt_lens 接口
tests/operators/test_hybrid_mtp_ngram.py 更新算子单测输入字段与注释,匹配新接口
tests/e2e/test_ernie_21b_mtp_ngram.py 新增 hybrid MTP-Ngram 的 E2E 覆盖(stream/non-stream、speculate_metrics、logprobs)
fastdeploy/worker/input_batch.py ProposerInputBatch 移除 input_ids_cpu/input_ids_len 维护,token_ids_all 改为直接引用目标 batch
fastdeploy/spec_decode/mtp.py 删除 insert/prepare 阶段对 input_ids_leninput_ids_cpu 的写入与 D2H 拷贝
fastdeploy/spec_decode/mtp_cuda.py hybrid ngram 扩展调用改为直接使用 GPU 上的 token_ids_all/prompt_lens
custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu CUDA/CPU 路径统一改用 token_ids_all/prompt_lens,更新内核参数含义
custom_ops/gpu_ops/cpp_extensions.cc 同步更新 HybridMtpNgram C++ 声明签名
Comments suppressed due to low confidence (1)

tests/e2e/test_ernie_21b_mtp_ngram.py:259

  • 这里对 speculate_metrics 做了严格的 dict 全等比较(==),如果服务端返回的浮点值存在舍入差异、或字段顺序/附加字段有微调,就会导致用例不稳定。若目的是回归关键行为,建议改为:对整数统计做精确比较;对 accept_ratio/average_accept_length 等浮点字段做容差比较;或通过 BaselineManager 管理可更新的基线数据。
    # Baseline comparison — exact match against the values captured in the reference environment.
    if BASELINE_SPECULATE_METRICS is not None:
        assert speculate_metrics == BASELINE_SPECULATE_METRICS, (
            f"speculate_metrics mismatch\n"
            f"got:      {json.dumps(speculate_metrics, indent=2)}\n"
            f"baseline: {json.dumps(BASELINE_SPECULATE_METRICS, indent=2)}"
        )

Comment thread tests/e2e/test_ernie_21b_mtp_ngram.py Outdated
PaddlePaddle-bot

This comment was marked as outdated.

@PaddlePaddle-bot
Copy link
Copy Markdown

PaddlePaddle-bot commented May 19, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-20 19:42:12

CI报告基于以下代码生成(30分钟更新一次):


1 任务总览

Required 任务有 2 个失败、0 个等待/运行中,当前仍阻塞合并;其中主单测任务失败来自 Diff 覆盖率未达阈值,另 1 个为 Approval 待人工审批。

总执行(rerun次数) 总任务 ✅ 通过 ❌ 失败 ⏳ 运行中 ⏸️ 等待中 跳过
42(0) 42 37 4 0 1 0

2 任务状态汇总

日志列说明:失败任务直接链接到对应 Job 日志;运行/等待中的 optional workflow 仅供参考。

2.1 Required任务 : 8/10 通过

必选任务阻塞合并,失败需优先处理。

状态 任务 耗时 根因 修复建议 日志 重跑
Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage 1h23m PR问题:Diff覆盖率50%,input_batch.py:973未覆盖 为reset_model_inputs分支补测试 Job -
Approval 11s 需要 Approval 请通过人工审批 Job -
其余 8 个必选任务通过 - - - - -

2.2 可选任务 — 29/32 通过

可选任务不阻塞合并,失败仅供参考。

状态 任务 耗时 日志 重跑
Run iluvatar Tests / run_iluvatar_cases 16m24s Job -
Trigger Jenkins for PR 7m33s Job -
⏸️ CI_HPU - - -
其余 29 个可选任务通过 - - -

3 失败详情(仅 required)

Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage — 覆盖率失败(置信度: 高)

Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage

  • 状态: ❌ 失败
  • 错误类型: 覆盖率失败
  • 置信度: 高
  • 根因摘要: Diff覆盖率50%,input_batch.py:973未覆盖
  • 分析器: ci_analyze_unittest_fastdeploy

失败用例: 无(单测通过,失败发生在覆盖率阈值校验阶段)

根因详情:
单测阶段已通过,但 Verify Code Coverage Threshold (80%) 步骤返回 COVERAGE_EXIT_CODE=9。覆盖率产物 diff_coverage.json 显示本 PR 的 Diff 覆盖率只有 50%,低于 80% 阈值;唯一未覆盖行是 fastdeploy/worker/input_batch.py:973,对应 ProposerInputBatch.reset_model_inputs() 中 CUDA 且存在 token_ids_all 的分支。

关键日志:

All tests passed
Coverage generation failed (exit code 9)
GPU Patch Coverage Details:
{"src_stats":{"fastdeploy/worker/input_batch.py":{"percent_covered":50.0,"violation_lines":[973],"covered_lines":[773],"violations":[[973,null]]}},"total_percent_covered":50,"num_changed_lines":461}
##[error]Process completed with exit code 9.

修复建议:

  1. tests/worker/test_gpu_model_runner.py 或新增 tests/worker/test_input_batch.py 中补充 ProposerInputBatch.reset_model_inputs() 的单测,构造含 token_ids_all / prompt_lenstarget_model_input_batch,覆盖 fastdeploy/worker/input_batch.py:973
  2. 若该行无法在单测环境稳定覆盖,可将 reset_model_inputs()token_ids_all 初始化逻辑抽成可测试的小函数,并对 CUDA/token_ids_all 分支做直接单测。

修复建议摘要: 为reset_model_inputs分支补测试

关联变更: fastdeploy/worker/input_batch.py:973token_ids_all 从 clone 改为引用 target 张量)
链接: 查看日志

Approval — 需要人工审批(置信度: 高)

该 Job 需要人工 Approval,完成审批后 CI 才会继续执行。

@freeliuzc
Copy link
Copy Markdown
Collaborator

代码整体实现没问题,缺少一份置信的性能以及接受率报告来佐证功能正确。
仿照 FastDeploy/benchmarks/README.md ,使用 filtered_sharedgpt_2000_input_1136_output_200_fd 数据集,对 non-spec/ngram/mtp(1步和3步)/mtp(3步)+hybrid 出一份性能报告,以及 speculate.log 里的接受率统计

NKNaN and others added 5 commits May 20, 2026 13:56
PaddlePaddle-bot

This comment was marked as outdated.

Copy link
Copy Markdown

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Paddle-CI-Agent | pr_review | 2026-05-20 17:28:33

📋 Review 摘要

PR 概述:将 HybridMtpNgram 算子接口从 input_ids/input_ids_len 重构为 token_ids_all/prompt_lens,消除 CPU↔GPU 拷贝开销,并补充 hybrid E2E 测试。
变更范围custom_ops/gpu_ops/speculate_decoding/fastdeploy/spec_decode/fastdeploy/worker/tests/
影响面 Tag[Speculative Decoding] [OP]

问题

级别 文件 概述
❓ 疑问 fastdeploy/worker/input_batch.py:773 token_ids_all 从 clone 改为共享引用,需确认只读安全性
🟡 建议 tests/e2e/test_ernie_21b_mtp_ngram.py:65 accept_ratio_per_head 5 个元素与 accepted_tokens_per_head 6 个元素维度不匹配

📝 PR 规范检查

标题含官方 Tag [Speculative Decoding],描述结构完整(Motivation / Modifications / Usage / Accuracy Tests / Checklist 均已填写且 Checklist 全部勾选),PR 规范合规。✓

总体评价

本次重构思路清晰,去除了 CPU 侧冗余拷贝和 D→H 数据搬移,性能方向正确;新增的 E2E 测试覆盖了 overlap + cudagraph + logprob 场景。主要需确认 token_ids_all 共享引用的只读安全性,以及 accept_ratio_per_head 维度不一致的设计意图。

self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
if "token_ids_all" in self.target_model_input_batch:
self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
self.token_ids_all = self.target_model_input_batch["token_ids_all"]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 token_ids_allclone 改为直接引用,需确认共享写安全

# 改前(安全)
self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
# 改后(共享引用)
self.token_ids_all = self.target_model_input_batch["token_ids_all"]

改为共享引用后,ProposerInputBatch.token_ids_all 与 target model inputs 指向同一张量。如果 spec_decode 路径中对 token_ids_all 有 in-place 写操作(如 fill / index_put),会直接修改 target model 的输入。

请确认:

  1. _extend_draft_token_with_ngram_matchhybrid_mtp_ngram 仅读取 token_ids_all,不会写入
  2. reset_model_inputs 中同样改为引用(diff 第 970 行),是否所有 reset 场景下 target 的 token_ids_all 都已更新完毕再被引用

如确认只读,建议在 init_share_inputs 处添加注释说明共享语义:# token_ids_all is shared (read-only) with target model inputs

BASELINE_SPECULATE_METRICS = _build_speculate_metrics_baseline(
accepted_tokens=100,
rejected_tokens=206,
accept_ratio=0.49,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 accept_ratio_per_head 长度与 accepted_tokens_per_head 不一致

accepted_tokens_per_head 有 6 个元素(对应 num_speculative_tokens=5 + 1 个 head-0),而 accept_ratio_per_head 仅有 5 个元素,两者维度不匹配。测试中只检查了 accepted_tokens_per_head 的长度(assert len(...) == 6),但未验证 accept_ratio_per_head 的长度。

建议确认 accept_ratio_per_head 的语义(是否 ratio 少一维是设计意图),并在 _build_speculate_metrics_baseline 注释中说明;如非有意,应补充第 6 个元素或在测试中增加长度断言。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

contributor External developers

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants