fix(pytorch): offload guided decoding CPU ops to thread pool to prevent event loop blocking #4590
windreamer wants to merge 1 commit into
Pull request overview
Offloads CPU-bound xgrammar guided-decoding operations (fill_bitmap and accept_token) onto a thread pool via asyncio.to_thread, so they no longer block the PyTorch engine's asyncio event loop. Also adds a benchmark script and documentation for measuring guided decoding overhead.
Changes:
- Extract _fill_guided_bitmask / _accept_guided_tokens as sync helpers in FusedLogitsProcessor and call them from the event loop via asyncio.to_thread.
- Move the accept_token step out of FusedLogitsProcessor.sampling() into BaseModelAgent.async_sampling_logits, with next_token_ids.cpu() materialized on the caller side.
- Add benchmark/benchmark_guided.py and a README section for guided-decoding TPOT/overhead benchmarking.
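The first change can be pictured with a minimal sketch. It is not the lmdeploy implementation: GuidedProcessorSketch and the arithmetic inside the helper are hypothetical stand-ins, and only asyncio.to_thread is taken from the description above.

```python
import asyncio


class GuidedProcessorSketch:
    """Hypothetical stand-in for the logits processor; not lmdeploy code."""

    def _fill_guided_bitmask(self, batch_size: int) -> list[int]:
        # Synchronous helper. The arithmetic below is a placeholder for the
        # CPU-bound grammar work, not a real xgrammar call.
        return [sum(range(50_000)) % 2 for _ in range(batch_size)]

    async def __call__(self, batch_size: int) -> list[int]:
        # Run the sync helper in a worker thread so the event loop can keep
        # scheduling other coroutines while it executes.
        return await asyncio.to_thread(self._fill_guided_bitmask, batch_size)


async def main() -> list[int]:
    processor = GuidedProcessorSketch()
    return await processor(4)


result = asyncio.run(main())
```

Keeping the helper synchronous means it can be tested directly, while the async entry point stays a thin wrapper.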
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| lmdeploy/pytorch/engine/logits_process.py | Splits guided-decoding CPU work into sync helpers; removes accept_token from sampling(). |
| lmdeploy/pytorch/engine/model_agent/agent.py | Invokes _accept_guided_tokens via asyncio.to_thread after sampling. |
| benchmark/benchmark_guided.py | New benchmark comparing guided vs. baseline runs, with TPOT/ITL focus. |
| benchmark/README.md | Documents the new benchmark and its metrics. |
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
Comments suppressed due to low confidence (4)
benchmark/benchmark_guided.py:288
extract_metrics accepts a num_requests parameter that is never used inside the function body. Either remove the parameter (and the corresponding len(requests) argument at the call sites in main) or use it (e.g., to report a success rate). Carrying unused arguments around makes the API misleading.
def extract_metrics(profiler: Profiler, actual_output_lens: list[int], num_requests: int) -> dict:
    return {
        'elapsed': profiler.elapsed_time,
        'total_input': profiler.total_input,
        'total_output': profiler.total_output,
        'success': profiler.success,
        'avg_output_len': float(np.mean(actual_output_lens)) if actual_output_lens else 0.0,
        'median_output_len': float(np.median(actual_output_lens)) if actual_output_lens else 0.0,
        'output_throughput': profiler.output_throughput,
        'input_throughput': profiler.input_throughput,
        'rps': profiler.rps,
        'e2e_mean': profiler.e2e_mean,
        'e2e_p50': profiler.e2e_stat[0],
        'e2e_p99': profiler.e2e_stat[-1],
        'tpot_mean': profiler.tpot_mean,
        'tpot_p99': profiler.tpot_stat[-1],
        'ttft_mean': getattr(profiler, 'ttft_mean', float('inf')),
        'ttft_p99': getattr(profiler, 'ttft_stat', (float('inf'),))[-1],
        'itl_mean': getattr(profiler, 'itls_mean', float('inf')),
        'itl_p99': getattr(profiler, 'itls_stat', (float('inf'),))[-1],
    }
benchmark/benchmark_guided.py:262
Failed/cancelled requests are appended with an output length of 0, which then pollutes the avg_output_len / median_output_len reported in the comparison table (and skews any downstream interpretation of guided vs. baseline output length). It would be more honest to either skip non-successful sessions when computing averages, or use the actual generated token count (s.ns[-1]) regardless of success and report the number of failures separately.
actual_output_lens = []
for i, s in enumerate(sess):
    actual_output_lens.append(s.ns[-1] if s.status == Session.SUCCESS else 0)
    if s.status != Session.SUCCESS:
        logger.warning(f'Request {i}: {s.ns[-1]}/{s.req_output_len} tokens, finish != length/stop')
return actual_output_lens
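One way to act on this suggestion is sketched below, assuming the first option (skip non-successful sessions and count failures separately). SessionSketch, its fields, the SUCCESS constant, and the num_failed key are hypothetical names for illustration and do not mirror the benchmark's Session class.

```python
from dataclasses import dataclass
from statistics import mean, median

SUCCESS = 'success'  # hypothetical status value, not Session.SUCCESS


@dataclass
class SessionSketch:
    status: str
    generated: int  # stands in for s.ns[-1], the tokens actually generated


def summarize_output_lens(sessions: list[SessionSketch]) -> dict:
    # Only successful sessions contribute to the length statistics.
    ok_lens = [s.generated for s in sessions if s.status == SUCCESS]
    return {
        'avg_output_len': float(mean(ok_lens)) if ok_lens else 0.0,
        'median_output_len': float(median(ok_lens)) if ok_lens else 0.0,
        # Failures are reported on their own instead of being folded in as zeros.
        'num_failed': len(sessions) - len(ok_lens),
    }


summary = summarize_output_lens([
    SessionSketch(SUCCESS, 100),
    SessionSketch('failed', 7),
    SessionSketch(SUCCESS, 200),
])
```

With the failed request counted as 0, the average over the same three sessions would drop to 100.0; excluding it keeps the average at 150.0 and surfaces the failure in num_failed.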
benchmark/benchmark_guided.py:187
sample_random_requests iterates for i in range(num_prompts) and indexes dataset[i], but dataset (filtered ShareGPT conversations) may contain fewer than num_prompts entries, in which case this will raise IndexError. Consider guarding with min(num_prompts, len(dataset)) or sampling with replacement.
requests: list[tuple[str, int, int]] = []
for i in range(num_prompts):
    prompt = dataset[i][0]
    prompt_token_ids = tokenizer.encode(prompt)
    prompt_len = len(prompt_token_ids)
    if prompt_len > input_lens[i]:
        input_ids = prompt_token_ids[:input_lens[i]]
    else:
        ratio = (input_lens[i] + prompt_len - 1) // prompt_len
        input_ids = (prompt_token_ids * ratio)[:input_lens[i]]
    prompt = tokenizer.decode(input_ids)
    requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
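The guard could look roughly like the sketch below. pick_prompts and its cycle flag are hypothetical; cycling through the dataset is shown as a deterministic alternative to the sampling with replacement mentioned above, not as what the benchmark does.

```python
def pick_prompts(dataset: list[tuple], num_prompts: int, cycle: bool = False) -> list[str]:
    """Return up to num_prompts prompts without indexing past the dataset."""
    if not dataset:
        raise ValueError('dataset is empty')
    if cycle:
        # Wrap around so exactly num_prompts prompts are returned.
        return [dataset[i % len(dataset)][0] for i in range(num_prompts)]
    # Otherwise clamp to the number of entries that actually exist.
    return [dataset[i][0] for i in range(min(num_prompts, len(dataset)))]


small_dataset = [('prompt a',), ('prompt b',)]
clamped = pick_prompts(small_dataset, 5)
cycled = pick_prompts(small_dataset, 5, cycle=True)
```

Clamping changes the number of requests actually issued, so the caller would also need to size input_lens and output_lens from the returned list rather than from num_prompts.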
benchmark/benchmark_guided.py:636
import csv as _csv is executed inside the per-row loop. Move the import to module top-level (alongside json, os, etc.); re-importing each iteration is wasteful and inconsistent with the rest of the file's import style.
with open(args.csv, 'a') as f:
    import csv as _csv
    writer = _csv.DictWriter(f, fieldnames=row.keys())
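A sketch of the same write with the import hoisted to module level follows. append_row is a hypothetical helper; writing a header only when the file is new or empty, and passing newline='' as the csv module documentation recommends, are assumptions added for illustration and are not part of the reviewed code.

```python
import csv
import os
import tempfile


def append_row(path: str, row: dict) -> None:
    # Decide before opening whether the file still needs a header line.
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(row.keys()))
        if needs_header:
            writer.writeheader()
        writer.writerow(row)


# Usage with made-up values in a temporary directory.
csv_path = os.path.join(tempfile.mkdtemp(), 'results.csv')
append_row(csv_path, {'mode': 'baseline', 'tpot_mean': 1.0})
append_row(csv_path, {'mode': 'guided', 'tpot_mean': 1.5})

with open(csv_path, newline='') as f:
    rows = list(csv.DictReader(f))
```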
… loop blocking
- Extract _fill_guided_bitmask and _accept_guided_tokens as sync methods
- Wrap both with asyncio.to_thread to prevent CPU-bound xgrammar ops (fill_bitmap, accept_token) from blocking the asyncio event loop
- Move result.cpu() to caller side (agent.py) instead of storing as member
- Keep _wait_stream_once intact (confirmed not the root cause)
Motivation
When using guided decoding (JSON schema, regex, etc.) with the PyTorch engine, the asyncio event loop gets blocked by CPU-bound xgrammar operations (fill_bitmap and accept_token). Since these are synchronous CPU-intensive calls executed directly inside async methods, they prevent the event loop from processing other concurrent requests, causing stalls under multi-request workloads.
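The stall described above can be reproduced in isolation. In this sketch a heartbeat coroutine records timestamps while a blocking step runs either inline or through asyncio.to_thread. blocking_step, heartbeat, and measure are hypothetical names, and time.sleep stands in for a blocking call that releases the GIL; pure-Python CPU work in a thread would still contend for the GIL, so the benefit in practice depends on how the real operation behaves.

```python
import asyncio
import time


def blocking_step(duration: float = 0.2) -> None:
    # Stand-in for a synchronous call made from inside an async method.
    time.sleep(duration)


async def heartbeat(ticks: list[float], stop: asyncio.Event) -> None:
    # Represents other concurrent requests that need the event loop.
    while not stop.is_set():
        ticks.append(time.perf_counter())
        await asyncio.sleep(0.01)


async def measure(offload: bool) -> float:
    ticks: list[float] = []
    stop = asyncio.Event()
    task = asyncio.create_task(heartbeat(ticks, stop))
    await asyncio.sleep(0.05)
    if offload:
        await asyncio.to_thread(blocking_step)  # loop keeps running
    else:
        blocking_step()  # loop is frozen for the whole call
    await asyncio.sleep(0.05)
    stop.set()
    await task
    # Largest gap between consecutive heartbeats, in seconds.
    return max(b - a for a, b in zip(ticks, ticks[1:]))


inline_gap = asyncio.run(measure(offload=False))
offloaded_gap = asyncio.run(measure(offload=True))
print(f'inline: {inline_gap:.3f}s, offloaded: {offloaded_gap:.3f}s')
```

The inline run shows a heartbeat gap at least as long as the blocking call, while the offloaded run keeps gaps close to the heartbeat interval.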
Modification
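- Extract _fill_guided_bitmask and _accept_guided_tokens as sync methods from the __call__ and sampling paths respectively.
- _fill_guided_bitmask (called in __call__) and _accept_guided_tokens (called in agent.py after sampling) are now executed via asyncio.to_thread, offloading CPU-bound xgrammar ops to a thread pool so the event loop remains responsive.
Checklist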