FP8 kv cache quantization #4563
Open
CUHKSZzxy wants to merge 25 commits into
Conversation
Adds FP8 KV cache quantization (QuantPolicy.FP8 = 16) using torch.float8_e4m3fn with per-token symmetric scale (no zero point). Key design: - Reuses existing fill_kv_cache_blocked_fp8() with group_size=head_dim for per-token scale semantics in the fill path - Dequant in flatten_kv_cache and paged_attention via x.to(f32)*scale - Scale tensor shape [..., 1]: symmetric, no zero point - No bit packing (head_dim unchanged, unlike INT4/TURBO_QUANT) Also fixes pre-existing TestFillKVCacheBlockedFP8 test failures caused by calling .max() on float8_e4m3fn tensors (cast to float32 first). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Avoid constructing a temporary cu_seqlen_q tensor in the FP8 cache-fill path by letting fill_kv_cache_blocked_fp8 consume the existing q_start_loc and q_seq_length metadata directly. The kernel keeps the old cumulative-seqlen mode for direct callers via a USE_CU_SEQLEN constexpr.

Move the default paged-decode FP8 dequant scaling across the attention dot products: K scales are applied after QK, and V scales are applied to the probabilities before PV, which preserves the per-token/head scale algebra while avoiding full K/V tile dequantization in the hot decode loop.

Add a focused FP8 paged-attention test that compares against a dequantized-FP8 reference, including a split-head-dim case, so the fused scale placement is covered without conflating it with expected quantization error.
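The fused scale placement above relies on a simple algebraic identity: scaling K by ks[t] before the QK dot product equals scaling the logit afterwards, and scaling V by vs[t] commutes with the probability-weighted sum. A small pure-Python check (one query, per-token scales; shapes and names are illustrative, not the kernel's):

```python
import math


def attn_dequant_first(q, K, V, ks, vs):
    # reference path: dequantize K/V first, then standard attention
    logits = [sum(qi * (Kti * ks[t]) for qi, Kti in zip(q, K[t]))
              for t in range(len(K))]
    m = max(logits)
    e = [math.exp(x - m) for x in logits]
    p = [ei / sum(e) for ei in e]
    return [sum(p[t] * (V[t][d] * vs[t]) for t in range(len(V)))
            for d in range(len(V[0]))]


def attn_fused_scales(q, K, V, ks, vs):
    # fused path: K scales applied after QK, V scales applied to the
    # probabilities before PV -- no K/V tile dequantization needed
    logits = [ks[t] * sum(qi * Kti for qi, Kti in zip(q, K[t]))
              for t in range(len(K))]
    m = max(logits)
    e = [math.exp(x - m) for x in logits]
    p = [ei / sum(e) for ei in e]
    pv = [p[t] * vs[t] for t in range(len(V))]
    return [sum(pv[t] * V[t][d] for t in range(len(V)))
            for d in range(len(V[0]))]
```

Both paths produce identical outputs for any scales, which is why the kernel can defer the dequant without changing the attention result.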
Split normal FP8 KV cache from the dynamic per-token/head FP8 path. Normal fp8/fp8_e4m3 and fp8_e5m2 now use scalar K/V scales with FP8 cache tensors and no k_scales_zeros/v_scales_zeros metadata allocation, while the fp8_per_token_head variants keep the existing per-token/head scale-cache behavior.

Thread the scalar k_scale/v_scale through PyTorch attention dispatch, cache fill, flatten, and paged decode kernels so normal FP8 can quantize on cache write and apply scalar dequant in decode/prefill without materialized metadata tensors. Add optional one-shot calculate_kv_scales support and guard CUDA graph capture while scale calculation is pending, mirroring vLLM's eager first-pass behavior.

Add focused CLI/config/cache-descriptor tests and scalar/per-token FP8 kernel reference coverage.

Validation: py_compile on changed runtime/kernel/test files; pytest -q tests/test_lmdeploy/test_fp8_kv_cache_policy.py; git diff --check. CUDA kernel tests were not run because nvidia-smi cannot communicate with the driver in this environment.
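The scalar-scale variant above can be sketched the same way (illustrative only; one k_scale/v_scale per layer instead of per-token metadata tensors, FP8 E4M3 simulated by clamping to its finite max):

```python
# Simulated FP8 E4M3 finite max; the real cache stores torch.float8_e4m3fn
# and threads the scalar k_scale/v_scale through attention dispatch.
FP8_E4M3_MAX = 448.0


def calculate_kv_scale(kv):
    """One-shot calibration: derive a single scale from the first-pass amax."""
    amax = max(abs(x) for x in kv)
    return amax / FP8_E4M3_MAX if amax > 0 else 1.0


def quantize_on_write(kv, scale):
    # cache write: divide by the scalar scale, clamp into the FP8 range
    return [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x / scale)) for x in kv]


def dequantize_on_read(kv_q, scale):
    # decode/prefill read: scalar dequant, no per-token metadata lookup
    return [x * scale for x in kv_q]
```

The hot path only ever multiplies by one scalar, which is the point of this split versus the per-token/head variant.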
Remove the deprecated-style dynamic KV scale calculation path and keep normal FP8 on the vLLM-aligned scalar-scale behavior with default scales. Drop the experimental per-token/head FP8 policy and tests so the public surface only exposes fp8, fp8_e4m3, and fp8_e5m2. Sadly we have to remove some potentially useful features to keep this PR concise and solid.
Contributor
Pull request overview
This PR adds PyTorch-backend FP8 KV-cache quantization for paged attention, using per-tensor scalar K/V scales and storing the KV cache in torch.float8_e4m3fn (fp8) or torch.float8_e5m2 (fp8_e5m2). It wires the scales through attention/backends/kernels and introduces targeted kernel + CLI/config tests.
Changes:
- Add new quant policies `FP8`/`FP8_E5M2` with CLI aliases and basic policy/config validation.
- Implement FP8 per-tensor-scale paths across KV cache fill, paged decode attention, and KV flatten/recovery kernels.
- Add kernel tests and end-to-end quant-policy tests for FP8 KV cache behavior.
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_lmdeploy/test_quant_policy.py | Adds FP8 quant-policy pipeline/accuracy tests and adjusts fixture scopes. |
| tests/test_lmdeploy/test_fp8_kv_cache_policy.py | New tests for CLI parsing + engine/config acceptance/rejection + cache-engine helpers. |
| tests/pytorch/kernel/test_paged_attention.py | Adds FP8-scalar quantized paged-attention kernel tests (E4M3/E5M2). |
| tests/pytorch/kernel/test_flatten_kv_cache.py | Adds FP8-scalar KV flatten kernel tests (E4M3/E5M2) + reference flatten. |
| tests/pytorch/kernel/test_fill_kv_cache.py | Adds FP8-scalar KV fill kernel tests (E4M3/E5M2). |
| lmdeploy/pytorch/nn/attention.py | Plumbs scalar k_scale/v_scale buffers through attention forward for FP8 KV. |
| lmdeploy/pytorch/kernels/cuda/pagedattention.py | Adds FP8 quant-policy handling in the paged-attention Triton kernel wrapper and kernel logic. |
| lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py | Adds FP8-scalar Triton flatten kernel and routes FP8 policies to it. |
| lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py | Adds FP8-scalar Triton fill kernel and routes FP8 policies to it. |
| lmdeploy/pytorch/engine/cache_engine.py | Makes FP8 quant policies allocate FP8-typed KV cache tensors and logs policy description. |
| lmdeploy/pytorch/config.py | Adds an import-order lint suppression (noqa: I001). |
| lmdeploy/pytorch/backends/dlinfer/attention.py | Extends attention backend API surface to accept scalar k_scale/v_scale. |
| lmdeploy/pytorch/backends/cuda/attention/fa3.py | Plumbs scalar k_scale/v_scale through FA3 backend calls. |
| lmdeploy/pytorch/backends/cuda/attention/default.py | Plumbs scalar k_scale/v_scale through default CUDA attention backend calls. |
| lmdeploy/pytorch/backends/attention.py | Extends base attention backend interface signature with scalar scales. |
| lmdeploy/messages.py | Adds FP8 quant policies and extends engine-config validation/docs. |
| lmdeploy/cli/utils.py | Adds quant-policy string aliases and custom parsing for CLI. |
Comments suppressed due to low confidence (1)
lmdeploy/messages.py:483
- PytorchEngineConfig validation allows FP8 quant policies regardless of device_type (it only restricts quantization to CUDA/ASCEND), but CacheEngine later asserts FP8 quantization is CUDA-only. Consider adding an explicit check here to reject QuantPolicy.FP8/FP8_E5M2 when device_type != 'cuda', so users get a clear configuration-time error instead of a runtime assertion deeper in the engine.
```python
assert self.quant_policy in (
    QuantPolicy.NONE,
    QuantPolicy.INT4,
    QuantPolicy.INT8,
    QuantPolicy.FP8,
    QuantPolicy.FP8_E5M2,
    QuantPolicy.TURBO_QUANT,
), 'invalid quant_policy'
assert self.device_type in ['cuda', 'ascend', 'maca', 'camb'], (f'invalid device_type: {self.device_type}')
assert self.kernel_block_size >= 16 and \
    (self.kernel_block_size & (self.kernel_block_size - 1)) == 0, \
    f'kernel_block_size must be >= 16 and a power of 2, but got {self.kernel_block_size}'
assert self.block_size >= self.kernel_block_size and \
    self.block_size % self.kernel_block_size == 0, \
    (f'block_size must be >= kernel_block_size and an integer multiple '
     f'of kernel_block_size, but got block_size {self.block_size} '
     f'and kernel_block_size {self.kernel_block_size}')
if self.quant_policy > 0 and self.device_type not in ['cuda', 'ascend']:
    assert False, \
        'kv cache quantization only works for CUDA and ASCEND.'
```
Summary
Add PyTorch-backend FP8 KV-cache quantization for paged attention.
The default `fp8` policy uses normal FP8 KV cache with scalar per-tensor K/V scales, matching the common vLLM behavior and avoiding per-token/per-head scale metadata on the hot path.

What Changed

- `fp8`/`fp8_e4m3`: E4M3 FP8 KV cache
- `fp8_e5m2`: E5M2 FP8 KV cache
- KV cache tensors are stored as `torch.float8_e4m3fn` or `torch.float8_e5m2`.
- Scalar `k_scale`/`v_scale` state on attention layers.

Usage
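Based on the `--quant-policy` aliases this PR adds, serving with an FP8 KV cache might look like the following (the serve subcommand form and model path are assumptions, not taken from this PR):

```shell
# FP8 E4M3 KV cache (model path is a placeholder)
lmdeploy serve api_server Qwen/Qwen3-8B --backend pytorch --quant-policy fp8

# FP8 E5M2 KV cache
lmdeploy serve api_server Qwen/Qwen3-8B --backend pytorch --quant-policy fp8_e5m2
```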
End-to-End Benchmark
Model: Qwen3.5-35B-A3B
Backend: LMDeploy PyTorch, TP=2
Dataset: ShareGPT
Baseline: BF16 KV, `--quant-policy 0`
Candidates: FP8 E4M3 KV, `--quant-policy fp8`; FP8 E5M2 KV, `--quant-policy fp8_e5m2`

Positive TTFT delta means lower/better latency. Small TTFT deltas should be treated as noise.
Accuracy Check
Model: Qwen3.5-397B-A17B-FP8
Backend: LMDeploy PyTorch, TP=8
Candidate: FP8 KV, `--quant-policy fp8`

Validation
- tests/pytorch/kernel/test_fill_kv_cache.py
- tests/pytorch/kernel/test_flatten_kv_cache.py
- tests/pytorch/kernel/test_paged_attention.py
- tests/test_lmdeploy/test_fp8_kv_cache_policy.py
- tests/test_lmdeploy/test_quant_policy.py