fix(pytorch): offload guided decoding CPU ops to thread pool to prevent event loop blocking #4590

Open
windreamer wants to merge 1 commit into InternLM:main from windreamer:fix-guided-stall

Collaborator

windreamer commented May 18, 2026

Motivation

When using guided decoding (JSON schema, regex, etc.) with the PyTorch engine, the asyncio event loop gets blocked by CPU-bound xgrammar operations (fill_bitmap and accept_token). Since these are synchronous CPU-intensive calls executed directly inside async methods, they prevent the event loop from processing other concurrent requests, causing stalls under multi-request workloads.

Modification

  1. Extract sync helpers in logits_process.py: Pulled _fill_guided_bitmask and _accept_guided_tokens out as standalone synchronous methods from the __call__ and sampling paths respectively.
  2. Wrap with asyncio.to_thread: Both _fill_guided_bitmask (called during __call__) and _accept_guided_tokens (called in agent.py after sampling) are now executed via asyncio.to_thread, offloading CPU-bound xgrammar ops to a thread pool so the event loop remains responsive.
  3. Move accept_token to agent.py: The accept_token call was previously embedded in FusedLogitsProcessor.sampling(), which is a sync method with no access to the event loop. Moved the call to BaseModelAgent where it can be properly awaited via asyncio.to_thread. Also moved result.cpu() to the caller side so the tensor is explicitly transferred to CPU before entering the thread.
  4. Add guided decoding benchmark: Added benchmark/benchmark_guided.py and updated benchmark/README.md with documentation for measuring per-token latency overhead of guided decoding.
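
The offloading pattern described in items 1 to 3 can be sketched as follows. This is a minimal illustration under stated assumptions, not the code from this PR: fill_bitmask_sync is a hypothetical pure-Python stand-in for the xgrammar call, and guided_step is a hypothetical name for the awaiting caller. Only asyncio.to_thread itself is taken from the description above.

```python
import asyncio
import threading


def fill_bitmask_sync(allowed_token_ids, vocab_size):
    """Hypothetical stand-in for a synchronous, CPU-bound call such as fill_bitmap."""
    mask = [False] * vocab_size
    for token_id in allowed_token_ids:
        mask[token_id] = True
    # Also report which thread did the work, so the offload is observable.
    return mask, threading.current_thread().name


async def guided_step(allowed_token_ids, vocab_size):
    # Run the sync helper on a worker thread; while it runs, the event loop
    # is free to schedule other coroutines.
    return await asyncio.to_thread(fill_bitmask_sync, allowed_token_ids, vocab_size)


async def main():
    loop_thread = threading.current_thread().name
    mask, worker_thread = await guided_step([1, 3], 5)
    return mask, worker_thread != loop_thread


mask, ran_off_loop = asyncio.run(main())
```

How much this helps in practice depends on the underlying native call releasing the GIL while it runs, which this sketch does not demonstrate.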

Checklist

  • Pre-commit or other linting tools are used to fix the potential lint issues.
  • The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  • If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  • The documentation has been modified accordingly, like docstring or example tutorials.

Contributor

Copilot AI left a comment

Pull request overview

Offloads CPU-bound xgrammar guided-decoding operations (fill_bitmap and accept_token) onto a thread pool via asyncio.to_thread, so they no longer block the PyTorch engine's asyncio event loop. Also adds a benchmark script and documentation for measuring guided decoding overhead.

Changes:

  • Extract _fill_guided_bitmask / _accept_guided_tokens as sync helpers in FusedLogitsProcessor and call them from the event loop via asyncio.to_thread.
  • Move the accept_token step out of FusedLogitsProcessor.sampling() into BaseModelAgent.async_sampling_logits, with next_token_ids.cpu() materialized on the caller side.
  • Add benchmark/benchmark_guided.py and README section for guided-decoding TPOT/overhead benchmarking.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.

File | Description
lmdeploy/pytorch/engine/logits_process.py | Splits guided-decoding CPU work into sync helpers; removes accept_token from sampling().
lmdeploy/pytorch/engine/model_agent/agent.py | Invokes _accept_guided_tokens via asyncio.to_thread after sampling.
benchmark/benchmark_guided.py | New benchmark comparing guided vs. baseline runs, with TPOT/ITL focus.
benchmark/README.md | Documents the new benchmark and its metrics.

Comment thread lmdeploy/pytorch/engine/logits_process.py
Comment thread lmdeploy/pytorch/engine/model_agent/agent.py Outdated
Comment thread lmdeploy/pytorch/engine/logits_process.py Outdated
Comment thread benchmark/benchmark_guided.py
Comment thread benchmark/benchmark_guided.py
Comment thread benchmark/benchmark_guided.py
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (4)

benchmark/benchmark_guided.py:288

  • extract_metrics accepts a num_requests parameter that is never used inside the function body. Either remove the parameter (and the corresponding len(requests) argument at the call sites in main) or use it (e.g., to report a success rate). Carrying unused arguments around makes the API misleading.
def extract_metrics(profiler: Profiler, actual_output_lens: list[int], num_requests: int) -> dict:
    return {
        'elapsed': profiler.elapsed_time,
        'total_input': profiler.total_input,
        'total_output': profiler.total_output,
        'success': profiler.success,
        'avg_output_len': float(np.mean(actual_output_lens)) if actual_output_lens else 0.0,
        'median_output_len': float(np.median(actual_output_lens)) if actual_output_lens else 0.0,
        'output_throughput': profiler.output_throughput,
        'input_throughput': profiler.input_throughput,
        'rps': profiler.rps,
        'e2e_mean': profiler.e2e_mean,
        'e2e_p50': profiler.e2e_stat[0],
        'e2e_p99': profiler.e2e_stat[-1],
        'tpot_mean': profiler.tpot_mean,
        'tpot_p99': profiler.tpot_stat[-1],
        'ttft_mean': getattr(profiler, 'ttft_mean', float('inf')),
        'ttft_p99': getattr(profiler, 'ttft_stat', (float('inf'),))[-1],
        'itl_mean': getattr(profiler, 'itls_mean', float('inf')),
        'itl_p99': getattr(profiler, 'itls_stat', (float('inf'),))[-1],
    }

benchmark/benchmark_guided.py:262

  • Failed/cancelled requests are appended with an output length of 0, which then pollutes the avg_output_len / median_output_len reported in the comparison table (and skews any downstream interpretation of guided vs. baseline output length). It would be more honest to either skip non-successful sessions when computing averages, or use the actual generated token count (s.ns[-1]) regardless of success and report the number of failures separately.
        actual_output_lens = []
        for i, s in enumerate(sess):
            actual_output_lens.append(s.ns[-1] if s.status == Session.SUCCESS else 0)
            if s.status != Session.SUCCESS:
                logger.warning(f'Request {i}: {s.ns[-1]}/{s.req_output_len} tokens, finish != length/stop')
        return actual_output_lens
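
One way to follow the first suggestion (skip non-successful sessions and report failures separately) is sketched below. FakeSession and summarize_output_lens are hypothetical names introduced only for illustration; the benchmark's real Session class is not reproduced here.

```python
from dataclasses import dataclass
from statistics import mean, median


@dataclass
class FakeSession:
    """Hypothetical stand-in for the benchmark's Session objects."""
    generated: int  # plays the role of s.ns[-1]
    ok: bool        # plays the role of s.status == Session.SUCCESS


def summarize_output_lens(sessions):
    # Only successful sessions contribute to the length statistics.
    ok_lens = [s.generated for s in sessions if s.ok]
    return {
        'avg_output_len': float(mean(ok_lens)) if ok_lens else 0.0,
        'median_output_len': float(median(ok_lens)) if ok_lens else 0.0,
        # Failures are counted separately instead of being recorded as length 0.
        'num_failed': sum(1 for s in sessions if not s.ok),
    }


stats = summarize_output_lens([
    FakeSession(100, True),
    FakeSession(40, False),
    FakeSession(200, True),
])
```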

benchmark/benchmark_guided.py:187

  • sample_random_requests iterates for i in range(num_prompts) and indexes dataset[i], but dataset (filtered ShareGPT conversations) may contain fewer than num_prompts entries, in which case this will raise IndexError. Consider guarding with min(num_prompts, len(dataset)) or sampling with replacement.
    requests: list[tuple[str, int, int]] = []
    for i in range(num_prompts):
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer.encode(prompt)
        prompt_len = len(prompt_token_ids)
        if prompt_len > input_lens[i]:
            input_ids = prompt_token_ids[:input_lens[i]]
        else:
            ratio = (input_lens[i] + prompt_len - 1) // prompt_len
            input_ids = (prompt_token_ids * ratio)[:input_lens[i]]
        prompt = tokenizer.decode(input_ids)
        requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
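
A sketch of the "sampling with replacement" option mentioned above, assuming the dataset is a plain list. pick_prompts is a hypothetical helper, not part of the benchmark script.

```python
import random


def pick_prompts(dataset, num_prompts, seed=0):
    """Return exactly num_prompts entries, even if the dataset is smaller."""
    if not dataset:
        raise ValueError('dataset is empty')
    if num_prompts <= len(dataset):
        # Enough entries: keep the original in-order behaviour.
        return [dataset[i] for i in range(num_prompts)]
    # Too few entries: use every entry once, then top up by sampling
    # with replacement instead of indexing past the end.
    rng = random.Random(seed)
    extra = rng.choices(dataset, k=num_prompts - len(dataset))
    return list(dataset) + extra


picked = pick_prompts(['a', 'b'], 5)
```

The simpler alternative, clamping the loop to min(num_prompts, len(dataset)), avoids the IndexError as well but silently runs fewer requests than were asked for.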

benchmark/benchmark_guided.py:636

  • import csv as _csv is executed inside the per-row loop. Move the import to module top-level (alongside json, os, etc.) — re-importing each iteration is wasteful and inconsistent with the rest of the file's import style.
            with open(args.csv, 'a') as f:
                import csv as _csv
                writer = _csv.DictWriter(f, fieldnames=row.keys())
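
For comparison, a minimal sketch with the import at module top level. append_row and the row contents are hypothetical, and an in-memory buffer stands in for the file opened from args.csv.

```python
import csv
import io


def append_row(f, row, write_header):
    # The csv module is imported once at module load, not on every call.
    writer = csv.DictWriter(f, fieldnames=list(row.keys()))
    if write_header:
        writer.writeheader()
    writer.writerow(row)


buf = io.StringIO()
append_row(buf, {'mode': 'guided', 'tpot_mean': 0.012}, write_header=True)
append_row(buf, {'mode': 'baseline', 'tpot_mean': 0.01}, write_header=False)
lines = buf.getvalue().splitlines()
```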

Comment thread lmdeploy/pytorch/engine/logits_process.py Outdated
Comment thread benchmark/benchmark_guided.py
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

Comment thread lmdeploy/pytorch/engine/logits_process.py
Comment thread benchmark/benchmark_guided.py
Comment thread benchmark/benchmark_guided.py
Collaborator

@grimoire grimoire left a comment

LGTM

… loop blocking

- Extract _fill_guided_bitmask and _accept_guided_tokens as sync methods
- Wrap both with asyncio.to_thread to prevent CPU-bound xgrammar ops
  (fill_bitmap, accept_token) from blocking the asyncio event loop
- Move result.cpu() to caller side (agent.py) instead of storing as member
- Keep _wait_stream_once intact (confirmed not the root cause)
4 participants