
FP8 kv cache quantization #4563

Open

CUHKSZzxy wants to merge 25 commits into InternLM:main from CUHKSZzxy:feat/fp8-kv-cache-quant

Conversation

Collaborator

@CUHKSZzxy CUHKSZzxy commented Apr 29, 2026

Summary

Add PyTorch-backend FP8 KV-cache quantization for paged attention.

The default fp8 policy uses normal FP8 KV cache with scalar per-tensor K/V scales, matching the common vLLM behavior and avoiding per-token/per-head scale metadata on the hot path.
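For intuition, the per-tensor scheme amounts to one scalar scale per K/V tensor: the cache write casts to FP8 and the read multiplies back, so no per-token or per-head metadata is needed. A minimal PyTorch sketch of that idea follows (illustrative only; the PR does this inside fused Triton fill/decode kernels, and the helper names below are made up):

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for E4M3

def fp8_cache_write(x, scale):
    # Quantize K or V with a single scalar scale for the whole tensor.
    return (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

def fp8_cache_read(x_fp8, scale):
    # Dequantize on read: one multiply, no per-token/per-head scale tensors.
    return x_fp8.to(torch.float32) * scale

k = torch.randn(256, 8, 128, dtype=torch.bfloat16)  # [tokens, heads, head_dim]
k_scale = torch.tensor(1.0)                         # scalar per-tensor scale
k_fp8 = fp8_cache_write(k, k_scale)
k_restored = fp8_cache_read(k_fp8, k_scale).to(torch.bfloat16)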

What Changed

  • Add PyTorch KV-cache support for:
    • fp8 / fp8_e4m3: E4M3 FP8 KV cache
    • fp8_e5m2: E5M2 FP8 KV cache
  • Store K/V cache as torch.float8_e4m3fn or torch.float8_e5m2.
  • Add scalar k_scale / v_scale state on attention layers.
  • Add FP8 paths for cache fill, paged decode, and prefill flatten.
  • Keep normal FP8 metadata-free: no per-token/per-head scale tensors are allocated.
  • Add CLI/config parsing and focused kernel/e2e tests.

Usage

# no KV-cache quantization
lmdeploy serve api_server <model> --backend pytorch --quant-policy 0

# FP8 E4M3 KV cache
lmdeploy serve api_server <model> --backend pytorch --quant-policy fp8

# FP8 E5M2 KV cache
lmdeploy serve api_server <model> --backend pytorch --quant-policy fp8_e5m2
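A rough offline equivalent of the CLI flags above, assuming the PyTorch engine config accepts the new policy value directly (per this PR, QuantPolicy.FP8 maps to 16; treat the exact value as an assumption):

from lmdeploy import pipeline, PytorchEngineConfig

backend_config = PytorchEngineConfig(quant_policy=16)  # FP8 E4M3 KV cache (QuantPolicy.FP8 in this PR)
pipe = pipeline('<model>', backend_config=backend_config)
print(pipe(['Hello, FP8 KV cache!']))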

End-to-End Benchmark

Model: Qwen3.5-35B-A3B
Backend: LMDeploy PyTorch, TP=2
Dataset: ShareGPT
Baseline: BF16 KV, --quant-policy 0
Candidates: FP8 E4M3 KV, --quant-policy fp8; FP8 E5M2 KV, --quant-policy fp8_e5m2

| Output len | Prompts | BF16 KV tok/s | FP8 E4M3 KV tok/s | E4M3 delta | FP8 E5M2 KV tok/s | E5M2 delta |
| --- | --- | --- | --- | --- | --- | --- |
| None | 1000 | 5835.55 | 5842.50 | +0.12% | 5617.81 | -3.73% |
| 2048 | 1000 | 15091.14 | 17384.36 | +15.20% | 17158.46 | +13.70% |
| 4096 | 500 | 13480.78 | 15083.34 | +11.89% | 14451.87 | +7.20% |
| 8192 | 200 | 7199.99 | 7390.87 | +2.65% | 8236.96 | +14.40% |

| Output len | Prompts | BF16 KV TTFT ms | FP8 E4M3 KV TTFT ms | E4M3 delta | FP8 E5M2 KV TTFT ms | E5M2 delta |
| --- | --- | --- | --- | --- | --- | --- |
| None | 1000 | 3669.65 | 3629.63 | +1.09% | 4914.25 | -33.92% |
| 2048 | 1000 | 3809.86 | 3678.69 | +3.44% | 3866.09 | -1.48% |
| 4096 | 500 | 1925.52 | 2018.82 | -4.85% | 1903.33 | +1.15% |
| 8192 | 200 | 979.81 | 983.78 | -0.41% | 965.03 | +1.51% |

Throughput deltas are computed as (candidate - baseline) / baseline, so positive means higher throughput; TTFT deltas are (baseline - candidate) / baseline, so positive means lower (better) latency. Small TTFT deltas should be treated as noise.

Accuracy Check

Model: Qwen3.5-397B-A17B-FP8
Backend: LMDeploy PyTorch, TP=8
Candidate: FP8 KV, --quant-policy fp8

| Benchmark | FP8 KV cache quant | Official Qwen3.5-397B-A17B-FP8 | Delta | Note |
| --- | --- | --- | --- | --- |
| AIME2025 repeat-32 | 95.83 | 91.3 | +4.53 | Official card reports AIME26, not AIME2025. |
| GPQA_diamond repeat-4 | 86.99 | 88.4 | -1.41 | Official card reports GPQA; benchmark naming may differ. |

Validation

  • tests/pytorch/kernel/test_fill_kv_cache.py
  • tests/pytorch/kernel/test_flatten_kv_cache.py
  • tests/pytorch/kernel/test_paged_attention.py
  • tests/test_lmdeploy/test_fp8_kv_cache_policy.py
  • tests/test_lmdeploy/test_quant_policy.py
  • Qwen3.5-35B-A3B TP=2 ShareGPT e2e benchmark above

CUHKSZzxy and others added 19 commits April 23, 2026 14:57
Adds FP8 KV cache quantization (QuantPolicy.FP8 = 16) using
torch.float8_e4m3fn with per-token symmetric scale (no zero point).

Key design:
- Reuses existing fill_kv_cache_blocked_fp8() with group_size=head_dim
  for per-token scale semantics in the fill path
- Dequant in flatten_kv_cache and paged_attention via x.to(f32)*scale
- Scale tensor shape [..., 1]: symmetric, no zero point
- No bit packing (head_dim unchanged, unlike INT4/TURBO_QUANT)

Also fixes pre-existing TestFillKVCacheBlockedFP8 test failures caused
by calling .max() on float8_e4m3fn tensors (cast to float32 first).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
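As a rough sketch of the per-token scheme this commit describes (not the fill_kv_cache_blocked_fp8 kernel itself): each head_dim-sized group gets one symmetric scale of shape [..., 1], and dequant is the single multiply mentioned above.

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def quant_per_token(x):
    # x: [num_tokens, num_heads, head_dim]; one symmetric scale per token/head group.
    scale = (x.abs().amax(dim=-1, keepdim=True).float() / FP8_MAX).clamp(min=1e-12)
    x_fp8 = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return x_fp8, scale  # scale shape [..., 1], no zero point

def dequant_per_token(x_fp8, scale):
    return x_fp8.to(torch.float32) * scale  # matches the flatten/decode dequant path

k = torch.randn(16, 4, 64, dtype=torch.bfloat16)
k_fp8, k_scale = quant_per_token(k)
k_ref = dequant_per_token(k_fp8, k_scale)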
Avoid constructing a temporary cu_seqlen_q tensor in the FP8 cache-fill path by letting fill_kv_cache_blocked_fp8 consume the existing q_start_loc and q_seq_length metadata directly. The kernel keeps the old cumulative-seqlen mode for direct callers via a USE_CU_SEQLEN constexpr.

Move default paged-decode FP8 dequant scaling across the attention dot products. K scales are applied after QK, and V scales are applied to probabilities before PV, which preserves the per-token/head scale algebra while avoiding full K/V tile dequantization in the hot decode loop.

Add a focused FP8 paged-attention test that compares against a dequantized-FP8 reference, including a split-head-dim case, so the fused scale placement is covered without conflating it with expected quantization error.
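The fused placement relies on simple scale algebra: with scalar (or broadcastable) scales, q @ (K * s_k).T == s_k * (q @ K.T) and P @ (V * s_v) == (P * s_v) @ V. A small check of that identity, using plain float tensors as stand-ins for the FP8 tiles (this only verifies the scale placement, not the quantization itself):

import torch

torch.manual_seed(0)
s_k, s_v = 0.05, 0.04              # illustrative scalar K/V scales
q = torch.randn(1, 64)
k = torch.randn(32, 64)            # stand-in for the FP8 K tile
v = torch.randn(32, 64)            # stand-in for the FP8 V tile

# Reference: dequantize K/V first, then attend.
p_ref = torch.softmax(q @ (k * s_k).T, dim=-1)
out_ref = p_ref @ (v * s_v)

# Fused placement: apply s_k after QK, apply s_v to the probabilities before PV.
p = torch.softmax((q @ k.T) * s_k, dim=-1)
out = (p * s_v) @ v

assert torch.allclose(out, out_ref, atol=1e-6)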
Split normal FP8 KV cache from the dynamic per-token/head FP8 path. Normal fp8/fp8_e4m3 and fp8_e5m2 now use scalar K/V scales with FP8 cache tensors and no k_scales_zeros/v_scales_zeros metadata allocation, while fp8_per_token_head variants keep the existing per-token/head scale-cache behavior.

Thread scalar k_scale/v_scale through PyTorch attention dispatch, cache fill, flatten, and paged decode kernels so normal FP8 can quantize on cache write and apply scalar dequant in decode/prefill without materialized metadata tensors. Add optional one-shot calculate_kv_scales support and guard CUDA graph capture while scale calculation is pending, mirroring vLLM's eager first-pass behavior.

Add focused CLI/config/cache descriptor tests and scalar/per-token FP8 kernel reference coverage. Validation: py_compile on changed runtime/kernel/test files; pytest -q tests/test_lmdeploy/test_fp8_kv_cache_policy.py; git diff --check. CUDA kernel tests were not run because nvidia-smi cannot communicate with the driver in this environment.
Remove the deprecated-style dynamic KV scale calculation path and keep normal FP8 on the vLLM-aligned scalar-scale behavior with default scales.

Drop the experimental per-token/head FP8 policy and tests so the public surface only exposes fp8, fp8_e4m3, and fp8_e5m2.

Sadly we have to remove some potentially useful features to keep this PR concise and solid.
@CUHKSZzxy CUHKSZzxy marked this pull request as ready for review May 12, 2026 12:35
Copilot AI review requested due to automatic review settings May 12, 2026 12:35
Contributor

Copilot AI left a comment


Pull request overview

This PR adds PyTorch-backend FP8 KV-cache quantization for paged attention, using per-tensor scalar K/V scales and storing the KV cache in torch.float8_e4m3fn (fp8) or torch.float8_e5m2 (fp8_e5m2). It wires the scales through attention/backends/kernels and introduces targeted kernel + CLI/config tests.

Changes:

  • Add new quant policies FP8 / FP8_E5M2 with CLI aliases and basic policy/config validation.
  • Implement FP8 per-tensor-scale paths across KV cache fill, paged decode attention, and KV flatten/recovery kernels.
  • Add kernel tests and end-to-end quant-policy tests for FP8 KV cache behavior.

Reviewed changes

Copilot reviewed 17 out of 17 changed files in this pull request and generated 5 comments.

Summary per file

| File | Description |
| --- | --- |
| tests/test_lmdeploy/test_quant_policy.py | Adds FP8 quant-policy pipeline/accuracy tests and adjusts fixture scopes. |
| tests/test_lmdeploy/test_fp8_kv_cache_policy.py | New tests for CLI parsing + engine/config acceptance/rejection + cache-engine helpers. |
| tests/pytorch/kernel/test_paged_attention.py | Adds FP8-scalar quantized paged-attention kernel tests (E4M3/E5M2). |
| tests/pytorch/kernel/test_flatten_kv_cache.py | Adds FP8-scalar KV flatten kernel tests (E4M3/E5M2) + reference flatten. |
| tests/pytorch/kernel/test_fill_kv_cache.py | Adds FP8-scalar KV fill kernel tests (E4M3/E5M2). |
| lmdeploy/pytorch/nn/attention.py | Plumbs scalar k_scale/v_scale buffers through attention forward for FP8 KV. |
| lmdeploy/pytorch/kernels/cuda/pagedattention.py | Adds FP8 quant-policy handling in the paged-attention Triton kernel wrapper and kernel logic. |
| lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py | Adds FP8-scalar Triton flatten kernel and routes FP8 policies to it. |
| lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py | Adds FP8-scalar Triton fill kernel and routes FP8 policies to it. |
| lmdeploy/pytorch/engine/cache_engine.py | Makes FP8 quant policies allocate FP8-typed KV cache tensors and logs the policy description. |
| lmdeploy/pytorch/config.py | Adds an import-order lint suppression (noqa: I001). |
| lmdeploy/pytorch/backends/dlinfer/attention.py | Extends the attention backend API surface to accept scalar k_scale/v_scale. |
| lmdeploy/pytorch/backends/cuda/attention/fa3.py | Plumbs scalar k_scale/v_scale through FA3 backend calls. |
| lmdeploy/pytorch/backends/cuda/attention/default.py | Plumbs scalar k_scale/v_scale through the default CUDA attention backend calls. |
| lmdeploy/pytorch/backends/attention.py | Extends the base attention backend interface signature with scalar scales. |
| lmdeploy/messages.py | Adds FP8 quant policies and extends engine-config validation/docs. |
| lmdeploy/cli/utils.py | Adds quant-policy string aliases and custom parsing for the CLI. |
Comments suppressed due to low confidence (1)

lmdeploy/messages.py:483

  • PytorchEngineConfig validation allows FP8 quant policies regardless of device_type (it only restricts quantization to CUDA/ASCEND), but CacheEngine later asserts FP8 quantization is CUDA-only. Consider adding an explicit check here to reject QuantPolicy.FP8/FP8_E5M2 when device_type != 'cuda', so users get a clear configuration-time error instead of a runtime assertion deeper in the engine.
        assert self.quant_policy in (
            QuantPolicy.NONE,
            QuantPolicy.INT4,
            QuantPolicy.INT8,
            QuantPolicy.FP8,
            QuantPolicy.FP8_E5M2,
            QuantPolicy.TURBO_QUANT,
        ), 'invalid quant_policy'
        assert self.device_type in ['cuda', 'ascend', 'maca', 'camb'], (f'invalid device_type: {self.device_type}')
        assert self.kernel_block_size >= 16 and \
               (self.kernel_block_size & (self.kernel_block_size - 1)) == 0, \
               f'kernel_block_size must be >= 16 and a power of 2, but got {self.kernel_block_size}'
        assert self.block_size >= self.kernel_block_size and \
               self.block_size % self.kernel_block_size == 0, \
               (f'block_size must be >= kernel_block_size and an integer multiple '
                f'of kernel_block_size, but got block_size {self.block_size} '
                f'and kernel_block_size {self.kernel_block_size}')
        if self.quant_policy > 0 and self.device_type not in ['cuda', 'ascend']:
            assert False, \
                   'kv cache quantization only works for CUDA and ASCEND.'
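One possible shape of the configuration-time guard suggested here (a hypothetical sketch, not code from this PR; the QuantPolicy import path is an assumption):

from lmdeploy.messages import QuantPolicy  # assumption: the enum lives next to the engine config in this PR

def reject_fp8_kv_on_non_cuda(quant_policy, device_type):
    # Fail at config time instead of deep inside CacheEngine.
    if quant_policy in (QuantPolicy.FP8, QuantPolicy.FP8_E5M2) and device_type != 'cuda':
        raise ValueError('FP8 KV cache quantization only works on CUDA, '
                         f'got device_type={device_type!r}')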


Outdated comment threads: lmdeploy/cli/utils.py, lmdeploy/messages.py (×2), tests/test_lmdeploy/test_fp8_kv_cache_policy.py, tests/test_lmdeploy/test_quant_policy.py
@lvhan028 added the enhancement (New feature or request) label on May 13, 2026
