FP8 kv cache quantization#4563
Conversation
Adds FP8 KV cache quantization (QuantPolicy.FP8 = 16) using torch.float8_e4m3fn with per-token symmetric scale (no zero point). Key design: - Reuses existing fill_kv_cache_blocked_fp8() with group_size=head_dim for per-token scale semantics in the fill path - Dequant in flatten_kv_cache and paged_attention via x.to(f32)*scale - Scale tensor shape [..., 1]: symmetric, no zero point - No bit packing (head_dim unchanged, unlike INT4/TURBO_QUANT) Also fixes pre-existing TestFillKVCacheBlockedFP8 test failures caused by calling .max() on float8_e4m3fn tensors (cast to float32 first). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Avoid constructing a temporary cu_seqlen_q tensor in the FP8 cache-fill path by letting fill_kv_cache_blocked_fp8 consume the existing q_start_loc and q_seq_length metadata directly. The kernel keeps the old cumulative-seqlen mode for direct callers via a USE_CU_SEQLEN constexpr. Move default paged-decode FP8 dequant scaling across the attention dot products. K scales are applied after QK, and V scales are applied to probabilities before PV, which preserves the per-token/head scale algebra while avoiding full K/V tile dequantization in the hot decode loop. Add a focused FP8 paged-attention test that compares against a dequantized-FP8 reference, including a split-head-dim case, so the fused scale placement is covered without conflating it with expected quantization error.
Split normal FP8 KV cache from the dynamic per-token/head FP8 path. Normal fp8/fp8_e4m3 and fp8_e5m2 now use scalar K/V scales with FP8 cache tensors and no k_scales_zeros/v_scales_zeros metadata allocation, while fp8_per_token_head variants keep the existing per-token/head scale-cache behavior. Thread scalar k_scale/v_scale through PyTorch attention dispatch, cache fill, flatten, and paged decode kernels so normal FP8 can quantize on cache write and apply scalar dequant in decode/prefill without materialized metadata tensors. Add optional one-shot calculate_kv_scales support and guard CUDA graph capture while scale calculation is pending, mirroring vLLM's eager first-pass behavior. Add focused CLI/config/cache descriptor tests and scalar/per-token FP8 kernel reference coverage. Validation: py_compile on changed runtime/kernel/test files; pytest -q tests/test_lmdeploy/test_fp8_kv_cache_policy.py; git diff --check. CUDA kernel tests were not run because nvidia-smi cannot communicate with the driver in this environment.
Remove the deprecated-style dynamic KV scale calculation path and keep normal FP8 on the vLLM-aligned scalar-scale behavior with default scales. Drop the experimental per-token/head FP8 policy and tests so the public surface only exposes fp8, fp8_e4m3, and fp8_e5m2. Sadly we have to remove some potentially useful features to keep this PR concise and solid.
There was a problem hiding this comment.
Pull request overview
This PR adds PyTorch-backend FP8 KV-cache quantization for paged attention, using per-tensor scalar K/V scales and storing the KV cache in torch.float8_e4m3fn (fp8) or torch.float8_e5m2 (fp8_e5m2). It wires the scales through attention/backends/kernels and introduces targeted kernel + CLI/config tests.
Changes:
- Add new quant policies
FP8/FP8_E5M2with CLI aliases and basic policy/config validation. - Implement FP8 per-tensor-scale paths across KV cache fill, paged decode attention, and KV flatten/recovery kernels.
- Add kernel tests and end-to-end quant-policy tests for FP8 KV cache behavior.
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_lmdeploy/test_quant_policy.py | Adds FP8 quant-policy pipeline/accuracy tests and adjusts fixture scopes. |
| tests/test_lmdeploy/test_fp8_kv_cache_policy.py | New tests for CLI parsing + engine/config acceptance/rejection + cache-engine helpers. |
| tests/pytorch/kernel/test_paged_attention.py | Adds FP8-scalar quantized paged-attention kernel tests (E4M3/E5M2). |
| tests/pytorch/kernel/test_flatten_kv_cache.py | Adds FP8-scalar KV flatten kernel tests (E4M3/E5M2) + reference flatten. |
| tests/pytorch/kernel/test_fill_kv_cache.py | Adds FP8-scalar KV fill kernel tests (E4M3/E5M2). |
| lmdeploy/pytorch/nn/attention.py | Plumbs scalar k_scale/v_scale buffers through attention forward for FP8 KV. |
| lmdeploy/pytorch/kernels/cuda/pagedattention.py | Adds FP8 quant-policy handling in the paged-attention Triton kernel wrapper and kernel logic. |
| lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py | Adds FP8-scalar Triton flatten kernel and routes FP8 policies to it. |
| lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py | Adds FP8-scalar Triton fill kernel and routes FP8 policies to it. |
| lmdeploy/pytorch/engine/cache_engine.py | Makes FP8 quant policies allocate FP8-typed KV cache tensors and logs policy description. |
| lmdeploy/pytorch/config.py | Adds an import-order lint suppression (noqa: I001). |
| lmdeploy/pytorch/backends/dlinfer/attention.py | Extends attention backend API surface to accept scalar k_scale/v_scale. |
| lmdeploy/pytorch/backends/cuda/attention/fa3.py | Plumbs scalar k_scale/v_scale through FA3 backend calls. |
| lmdeploy/pytorch/backends/cuda/attention/default.py | Plumbs scalar k_scale/v_scale through default CUDA attention backend calls. |
| lmdeploy/pytorch/backends/attention.py | Extends base attention backend interface signature with scalar scales. |
| lmdeploy/messages.py | Adds FP8 quant policies and extends engine-config validation/docs. |
| lmdeploy/cli/utils.py | Adds quant-policy string aliases and custom parsing for CLI. |
Comments suppressed due to low confidence (1)
lmdeploy/messages.py:483
- PytorchEngineConfig validation allows FP8 quant policies regardless of device_type (it only restricts quantization to CUDA/ASCEND), but CacheEngine later asserts FP8 quantization is CUDA-only. Consider adding an explicit check here to reject QuantPolicy.FP8/FP8_E5M2 when device_type != 'cuda', so users get a clear configuration-time error instead of a runtime assertion deeper in the engine.
assert self.quant_policy in (
QuantPolicy.NONE,
QuantPolicy.INT4,
QuantPolicy.INT8,
QuantPolicy.FP8,
QuantPolicy.FP8_E5M2,
QuantPolicy.TURBO_QUANT,
), 'invalid quant_policy'
assert self.device_type in ['cuda', 'ascend', 'maca', 'camb'], (f'invalid device_type: {self.device_type}')
assert self.kernel_block_size >= 16 and \
(self.kernel_block_size & (self.kernel_block_size - 1)) == 0, \
f'kernel_block_size must be >= 16 and a power of 2, but got {self.kernel_block_size}'
assert self.block_size >= self.kernel_block_size and \
self.block_size % self.kernel_block_size == 0, \
(f'block_size must be >= kernel_block_size and an integer multiple '
f'of kernel_block_size, but got block_size {self.block_size} '
f'and kernel_block_size {self.kernel_block_size}')
if self.quant_policy > 0 and self.device_type not in ['cuda', 'ascend']:
assert False, \
'kv cache quantization only works for CUDA and ASCEND.'
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks' | ||
| assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT), \ | ||
| 'invalid quant_policy' | ||
| assert self.quant_policy in ( |
There was a problem hiding this comment.
try:
self.quant_policy = QuantPolicy(self.quant_policy)
except ValueError as e:
raise ValueError(f'invalid quant_policy: {self.quant_policy}') from e
might be better for the check
| kv_start_loc: Start location of each KV sequence [batch_size]. | ||
| kv_seqlens: Length of each KV sequence [batch_size]. | ||
| quant_policy: Quantization policy (0=none, 4=int4, 8=int8/fp8). | ||
| quant_policy: Quantization policy (0=none, 4=int4, 8=int8, 16/17=per-tensor fp8). |
There was a problem hiding this comment.
Since we would send quant_policy here, can we reuse k_scales_zeros instead of create new variable?
| ) | ||
| elif quant_policy in (QuantPolicy.FP8, QuantPolicy.FP8_E5M2): | ||
| if k_scale is None: | ||
| k_scale = torch.ones((), device=k_caches.device, dtype=torch.float32) |
There was a problem hiding this comment.
Use lru_cache to avoid repeat ones kernel, or fuse it into the kernel.
Summary
Add PyTorch-backend FP8 KV-cache quantization for paged attention.
The default
fp8policy uses normal FP8 KV cache with scalar per-tensor K/V scales, matching the common vLLM behavior and avoiding per-token/per-head scale metadata on the hot path.What Changed
fp8/fp8_e4m3: E4M3 FP8 KV cachefp8_e5m2: E5M2 FP8 KV cachetorch.float8_e4m3fnortorch.float8_e5m2.k_scale/v_scalestate on attention layers.Usage
End-to-End Benchmark
Model: Qwen3.5-35B-A3B
Backend: LMDeploy PyTorch, TP=2
Dataset: ShareGPT
Baseline: BF16 KV,
--quant-policy 0Candidates: FP8 E4M3 KV,
--quant-policy fp8; FP8 E5M2 KV,--quant-policy fp8_e5m2Positive TTFT delta means lower/better latency. Small TTFT deltas should be treated as noise.
Accuracy Check
Model: Qwen3.5-397B-A17B-FP8
Backend: LMDeploy PyTorch, TP=8
Candidate: FP8 KV,
--quant-policy fp8Validation
tests/pytorch/kernel/test_fill_kv_cache.pytests/pytorch/kernel/test_flatten_kv_cache.pytests/pytorch/kernel/test_paged_attention.pytests/test_lmdeploy/test_fp8_kv_cache_policy.pytests/test_lmdeploy/test_quant_policy.pyAssistance
Assisted with Codex + GPT-5.5 High