
[KVCache] DSA for v1 cache manager #7787

Open

Moonchild1227 wants to merge 12 commits into PaddlePaddle:develop from Moonchild1227:feat/dsa-for-v1

Conversation

Moonchild1227 (Contributor) commented May 12, 2026

Motivation

Move the per-layer KV cache allocation logic from CacheController down into AttentionBackend, making CacheController variant-agnostic. Add support for the DSA (DeepSeek V3.2-Exp-BF16) cache layout (uint8 key + uint8 indexer), and lay an extensible foundation for future attention variants (no CacheController changes needed).

Modifications

  • base_attention_backend.py: add a default create_kv_cache() implementation (GQA/MHA key + value, with block_wise_fp8 scale support); add default create_host_kv_cache() / free_host_kv_cache() implementations
  • dsa_attention_backend.py: override create_kv_cache() to return {"key": uint8, "indexer": uint8}; override create_host_kv_cache() to raise NotImplementedError (host cache offload is not yet pushed down)
  • mla_attention_backend.py: override create_kv_cache() to return {"key": tensor}; override create_host_kv_cache() to allocate only the key buffer
  • cache_controller.py: rewrite initialize_kv_cache / initialize_mtp_kv_cache to allocate uniformly through attn_backend.create_kv_cache(); add _format_cache_name(); rewrite initialize_host_cache / free_host_cache to delegate to the backend; delete MLACacheController / DSACacheController / create_cache_controller() (a sketch of the resulting interface follows below)
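
For illustration, a minimal sketch of the backend-level allocation interface described above, reconstructed from this description and the review excerpts below. The get_kv_cache_shape() stub, all shapes, and the zero-fill are assumptions, not the exact FastDeploy implementation:

# Hedged sketch of the new backend-level allocation API, not the exact PR code.
# get_kv_cache_shape() and all shapes here are illustrative assumptions.
import paddle


class AttentionBackend:
    def get_kv_cache_shape(self, max_num_blocks, kv_cache_quant_type=None):
        # Placeholder per-layer shapes: (key, value, indexer).
        return (
            [max_num_blocks, 4, 64, 128],
            [max_num_blocks, 4, 64, 128],
            [max_num_blocks, 64, 129],
        )

    def create_kv_cache(self, num_layers, num_blocks, cache_dtype,
                        kv_cache_quant_type=None, layer_offset=0):
        """Default GQA/MHA layout: a key and a value tensor per layer,
        plus per-block scales when block_wise_fp8 is requested."""
        key_shape, value_shape, _ = self.get_kv_cache_shape(
            max_num_blocks=num_blocks, kv_cache_quant_type=kv_cache_quant_type
        )
        caches = {}
        for layer_idx in range(layer_offset, layer_offset + num_layers):
            caches[("key", layer_idx)] = paddle.zeros(key_shape, dtype=cache_dtype)
            caches[("value", layer_idx)] = paddle.zeros(value_shape, dtype=cache_dtype)
            if kv_cache_quant_type == "block_wise_fp8":
                caches[("key_scale", layer_idx)] = paddle.zeros([1], dtype="float32")
                caches[("value_scale", layer_idx)] = paddle.zeros([1], dtype="float32")
        return caches


class DSAAttentionBackend(AttentionBackend):
    def create_kv_cache(self, num_layers, num_blocks, cache_dtype,
                        kv_cache_quant_type=None, layer_offset=0):
        """DSA layout: uint8 key + uint8 indexer, no separate value or scales."""
        key_shape, _, indexer_shape = self.get_kv_cache_shape(
            max_num_blocks=num_blocks, kv_cache_quant_type="uint8"
        )
        caches = {}
        for layer_idx in range(layer_offset, layer_offset + num_layers):
            caches[("key", layer_idx)] = paddle.zeros(key_shape, dtype="uint8")
            caches[("indexer", layer_idx)] = paddle.zeros(indexer_shape, dtype="uint8")
        return caches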

Usage or Command

N/A

Accuracy Tests

End-to-end /v1/chat/completions requests against DSA (DeepSeek V3.2-Exp-BF16) pass verification.

# v0
python3 gsm8k.py
🎯 Evaluation Complete: Accuracy = 95.22% (657/690)
time: 28:44
# v1
python3 gsm8k.py
🎯 Evaluation Complete: Accuracy = 94.35% (651/690)
time: 23:38

# v1: 6 answers were correct but were emitted as markdown-formatted code, which the evaluation script failed to recognize.

MLA / GQA model verification to be added.

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings May 12, 2026 06:31
paddle-bot (Bot) commented May 12, 2026

Thanks for your contribution!

paddle-bot added the contributor (External developers) label May 12, 2026
Copilot AI (Contributor) left a comment

Pull request overview

This PR moves per-layer KV cache allocation down into AttentionBackend (via a new create_kv_cache interface), so the CacheController in cache_manager/v1 is responsible only for role-to-storage-name mapping, registration, and optional set_data_ipc pinning, reducing the controller's coupling to individual attention variants (GQA/MLA/DSA).

Changes:

  • AttentionBackend gains pin_kv_cache_for_cudagraph and a default create_kv_cache(...) (GQA/MHA: key/value, plus scales for fp8).
  • The MLA/DSA backends override create_kv_cache: MLA allocates key only; DSA returns key + indexer (uint8).
  • CacheController.initialize_kv_cache / initialize_mtp_kv_cache now call attn_backend.create_kv_cache per layer, with a new "indexer" role in the storage-name mapping and cudagraph pin logic (see the sketch below).
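
For reference, a minimal sketch of the controller-side flow this describes; the create_kv_cache call mirrors the excerpts quoted later in this review, while _register_cache is a hypothetical name for the registration step:

caches = attn_backend.create_kv_cache(
    num_layers=self._num_layers,
    num_blocks=num_gpu_blocks,
    cache_dtype=cache_dtype,
    kv_cache_quant_type=kv_cache_quant_type,
)
for (role, layer_idx), tensor in caches.items():
    # role -> storage-name mapping; "indexer" is a newly supported role
    name = self._format_cache_name(role, layer_idx)
    self._register_cache(name, tensor)  # hypothetical registration helper
    if attn_backend.pin_kv_cache_for_cudagraph:
        set_data_ipc(tensor, name)  # optional pin; set_data_ipc signature assumed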

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| fastdeploy/model_executor/layers/attention/base_attention_backend.py | Adds a generic KV cache allocation entry point and a cudagraph pin flag to the attention backend. |
| fastdeploy/model_executor/layers/attention/mla_attention_backend.py | MLA backend override: allocates only the compressed latent key cache and requires pinning. |
| fastdeploy/model_executor/layers/attention/dsa_attention_backend.py | DSA backend override: allocates a uint8 key + uint8 indexer and requires pinning. |
| fastdeploy/cache_manager/v1/cache_controller.py | Controller refactored to role registration / name mapping plus optional pinning; the main model and MTP share one allocation path. |

PaddlePaddle-bot commented May 12, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-13 20:46:47

The CI report is generated from the following code (updated every 30 minutes):


1 Task Overview

⚠️ 1 Required task has failed (handle it first); 3 Required tasks are still running.

| Total runs (reruns) | Total tasks | ✅ Passed | ❌ Failed | ⏳ Running | ⏸️ Waiting | Skipped |
| --- | --- | --- | --- | --- | --- | --- |
| 42 (0) | 42 | 34 | 2 | 5 | 1 | 0 |

2 Task Status Summary

2.1 Required tasks: 6/10 passed

Required tasks block merging; failures should be handled first.

| Status | Task | Duration | Root cause | Suggested fix | Log | Rerun |
| --- | --- | --- | --- | --- | --- | --- |
| ❌ | Approval | 9s | PR issue: new logger.info() calls trigger the logging-behavior approval check | Ask xyxinyang or zyyzghb to approve the PR | Job | - |
| ⏳ | run_tests_with_coverage | - | Running | - | Job | - |
| ⏳ | run_ce_cases | - | Running | - | Job | - |
| ⏳ | run_xpu_4cards_cases | - | Running | - | Job | - |
| ✅ | The remaining 6 required tasks passed | - | - | - | - | - |

2.2 Optional tasks: 28/32 passed

Optional tasks do not block merging; failures are informational only.

| Status | Task | Duration | Log | Rerun |
| --- | --- | --- | --- | --- |
| - | Cleanup artifacts | 6s | Job | - |
| - | run_iluvatar_cases | - | - | - |
| - | Trigger Jenkins for PR | - | - | - |
| ⏸️ | CI_HPU | - | - | - |
| ✅ | The remaining 28 optional tasks passed | - | - | - |

3 Failure Details (Required only)

Approval: code style / approval gate (confidence: high)

Approval

  • Status: ❌ Failed
  • Error type: code style (approval gate)
  • Confidence: high
  • Root-cause summary: the PR adds logger.info() calls, which trigger the logging-behavior approval check; waiting for approval from a designated RD
  • Analyzer: generic analysis (fallback)

Root-cause details:
The PR adds several logger.info() calls (in the KVCache DSA code), triggering the approval check in the CI's check_approval.sh script. That check requires any PR that modifies logging behavior (.info/.debug/.error/log_request) to be approved by at least one designated FastDeploy RD before it can pass.

Key log:

Detected log modification in diff:
+        logger.info(
+        logger.info(
+        logger.info(f"[free_host_kv_cache][...] freeing host cache buffers.")
...
0. You must have one FastDeploy RD (xyxinyang(zhouchong), zyyzghb(zhangyongyue)) approval
   for modifying logging behavior (.info/.debug/.error/log_request).
There are 1 approved errors.
##[error]Process completed with exit code 6.

Suggested fix:

  1. Ask xyxinyang (zhouchong) or zyyzghb (zhangyongyue) to approve this PR; once approved, the Approval CI job will rerun automatically and pass.

Fix summary: ask xyxinyang or zyyzghb to complete the PR approval.

Related change: the PR adds logger.info() calls (free_host_kv_cache and other KVCache DSA code).

Link: view log

codecov-commenter commented May 12, 2026

Codecov Report

❌ Patch coverage is 26.40000% with 92 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@c2df4c6). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...xecutor/layers/attention/base_attention_backend.py | 12.06% | 51 Missing ⚠️ |
| fastdeploy/cache_manager/v1/cache_controller.py | 55.26% | 15 Missing and 2 partials ⚠️ |
| ...executor/layers/attention/mla_attention_backend.py | 10.52% | 17 Missing ⚠️ |
| ...executor/layers/attention/dsa_attention_backend.py | 30.00% | 7 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7787   +/-   ##
==========================================
  Coverage           ?   72.10%           
==========================================
  Files              ?      398           
  Lines              ?    55976           
  Branches           ?     8749           
==========================================
  Hits               ?    40364           
  Misses             ?    12844           
  Partials           ?     2768           
| Flag | Coverage Δ |
| --- | --- |
| GPU | 72.10% <26.40%> (?) |



Copilot AI review requested due to automatic review settings May 12, 2026 08:11

Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.


Copilot AI review requested due to automatic review settings May 13, 2026 02:44
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (2)

fastdeploy/model_executor/layers/attention/base_attention_backend.py:137

  • The docstring of create_host_kv_cache says it returns an empty dict when host alloc is unavailable, but the implementation raises RuntimeError when cuda_host_alloc is None. That makes it hard for callers (such as CacheController) to implement the documented fallback. Suggest either returning {} as documented and letting the caller skip swap-space initialization, or fixing the docstring and having the caller catch the exception explicitly.
        Returns:
            Dict keyed by ``(role, layer_idx)``. Empty dict if host alloc is
            unavailable on the current platform.
        """
        if cuda_host_alloc is None:
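
A minimal sketch of the first option (honoring the docstring by degrading instead of raising); the warning text is an assumption:

if cuda_host_alloc is None:
    # Assumed fix sketch: degrade per the docstring instead of raising.
    logger.warning(
        f"[create_host_kv_cache][{type(self).__name__}] cuda_host_alloc "
        "is unavailable on this platform; returning an empty host cache dict"
    )
    return {}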

fastdeploy/cache_manager/v1/cache_controller.py:544

  • initialize_host_cache currently catches only NotImplementedError. But the default AttentionBackend.create_host_kv_cache() raises RuntimeError when cuda_host_alloc is None (and some backends may likewise raise RuntimeError), so enabling swap space fails outright during initialization. Suggest also catching RuntimeError here (and TypeError/AttributeError if necessary), logging a warning, and skipping host cache initialization so the system degrades gracefully on platforms without pinned host alloc.
        try:
            host_caches = attn_backend.create_host_kv_cache(
                num_layers=num_layers,
                num_blocks=num_host_blocks,
                cache_item_bytes=cache_item_bytes,
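
A sketch of the broadened caller-side handling this comment suggests; placement and the warning wording are assumptions:

try:
    host_caches = attn_backend.create_host_kv_cache(
        num_layers=num_layers,
        num_blocks=num_host_blocks,
        cache_item_bytes=cache_item_bytes,
    )
except (NotImplementedError, RuntimeError) as e:
    # Degrade gracefully on platforms without pinned host alloc.
    logger.warning(f"[CacheController] host cache unavailable: {e}; skipping swap space setup")
    return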

Comment on lines 43 to 49
class AttentionBackend(ABC):
    """The base class of attention backends"""

    @abstractmethod
    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize the forward metadata."""
        raise NotImplementedError
Comment on lines +295 to +299
caches = attn_backend.create_kv_cache(
    num_layers=self._num_layers,
    num_blocks=num_gpu_blocks,
    cache_dtype=cache_dtype,
    kv_cache_quant_type=kv_cache_quant_type,

        )
        return caches

    def create_host_kv_cache(
Collaborator left a comment

Please also support a free_host_kv_cache method here, moving the controller's implementation down into it.
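
For illustration, a minimal sketch of what such a default could look like; cuda_host_free and the pointer-valued dict are assumptions about the surrounding code, not confirmed APIs:

def free_host_kv_cache(self, host_caches):
    """Release pinned host buffers returned by create_host_kv_cache.

    Sketch only: assumes host_caches maps (role, layer_idx) -> raw pointer
    and that a cuda_host_free counterpart to cuda_host_alloc exists.
    """
    if cuda_host_free is None:
        return  # nothing to free on platforms without pinned host alloc
    for (role, layer_idx), ptr in host_caches.items():
        cuda_host_free(ptr)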


Copilot AI review requested due to automatic review settings May 13, 2026 07:02
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

Comment on lines +134 to +141
        Returns:
            Dict keyed by ``(role, layer_idx)``. Empty dict if host alloc is
            unavailable on the current platform.
        """
        if cuda_host_alloc is None:
            raise RuntimeError(
                f"[create_host_kv_cache][{type(self).__name__}] cuda_host_alloc "
                "is not available on this platform"
            )
Comment on lines +299 to +312
        DSA cache: uint8 key cache + uint8 indexer cache (no separate value, no scales).

        `cache_dtype` is ignored; DSA always stores packed fp8+scales as uint8.
        `kv_cache_quant_type` is coerced to "uint8" internally.
        """
        key_shape, _, indexer_shape = self.get_kv_cache_shape(max_num_blocks=num_blocks, kv_cache_quant_type="uint8")
        logger.info(
            f"[create_kv_cache][DSA] num_layers={num_layers} layer_offset={layer_offset} "
            f"key_shape={key_shape} indexer_shape={indexer_shape} dtype=uint8"
        )
        caches = {}
        for layer_idx in range(layer_offset, layer_offset + num_layers):
            caches[("key", layer_idx)] = paddle.full(shape=key_shape, fill_value=0, dtype="uint8")
            caches[("indexer", layer_idx)] = paddle.full(shape=indexer_shape, fill_value=0, dtype="uint8")
Comment on lines +156 to +158
# fp8 scales use float32 (4 bytes), shape [num_blocks, k1, k2].
scale_elems = key_shape[1] * key_shape[2] if is_fp8 else 0
scale_bytes = num_blocks * 4 * scale_elems if is_fp8 else 0
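
For concreteness, a worked instance of this byte accounting with made-up shape values:

# Illustrative numbers only, not from the PR:
key_shape = [1024, 64, 128]                 # [num_blocks, k1, k2]
num_blocks = key_shape[0]
scale_elems = key_shape[1] * key_shape[2]   # 64 * 128 = 8192 fp32 scales per block
scale_bytes = num_blocks * 4 * scale_elems  # 1024 * 4 * 8192 = 33,554,432 bytes (32 MiB)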

Moonchild1227 marked this pull request as ready for review May 13, 2026 08:57

chang-wenbin previously approved these changes May 13, 2026
Copilot AI review requested due to automatic review settings May 13, 2026 11:31
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

Comment on lines +139 to +141
            raise RuntimeError(
                f"[create_host_kv_cache][{type(self).__name__}] cuda_host_alloc "
                "is not available on this platform"
            )
Comment on lines +756 to +760
            if kv_cache_quant_type == "block_wise_fp8":
                caches[("key_scale", layer_idx)] = paddle.zeros([1], dtype="float32")
                if resolved_val_shape is not None:
                    caches[("value_scale", layer_idx)] = paddle.zeros([1], dtype="float32")
        return caches

Copilot AI review requested due to automatic review settings May 13, 2026 12:15
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Comment on lines +135 to +136
Dict keyed by ``(role, layer_idx)``. Empty dict if host alloc is
unavailable on the current platform.
Comment on lines +156 to +164
# fp8 scales use float32 (4 bytes), shape [num_blocks, k1, k2].
scale_elems = key_shape[1] * key_shape[2] if is_fp8 else 0
scale_bytes = num_blocks * 4 * scale_elems if is_fp8 else 0

logger.info(
f"[create_host_kv_cache][{type(self).__name__}] num_layers={num_layers} "
f"layer_offset={layer_offset} num_blocks={num_blocks} "
f"key_bytes_per_layer={key_bytes} value_bytes_per_layer={value_bytes} "
f"scale_bytes_per_layer={scale_bytes} kv_cache_quant_type={kv_cache_quant_type}"
Comment on lines +1033 to +1034
name->ptr bookkeeping; the backend reference is captured at
``initialize_host_cache`` time.
Comment on lines +298 to +302
        caches = self.attn_backend.create_kv_cache(
            num_layers=self._num_layers,
            num_blocks=num_gpu_blocks,
            cache_dtype=cache_dtype,
            kv_cache_quant_type=kv_cache_quant_type,
PaddlePaddle-bot left a comment

🤖 Paddle-CI-Agent | pr_review | 2026-05-13 20:31:34

📋 Review Summary

PR overview: moves the per-layer KV cache allocation logic from CacheController down into the individual AttentionBackends, adds DSA layout support (key + indexer, uint8), and removes the original MLACacheController/DSACacheController subclasses.

Scope of changes: fastdeploy/cache_manager/v1/ and fastdeploy/model_executor/layers/attention/

Impact tags: [KVCache] [OP]


📝 PR Convention Check

The title format and description template are both compliant. The ## Accuracy Tests section notes "MLA / GQA model verification to be added"; please add the corresponding accuracy data, or explain why it is missing, before merging. The "Add unit tests" checklist item is unchecked even though test files were updated; please tick it.


Issues

| Level | File | Summary |
| --- | --- | --- |
| 🔴 Bug | base_attention_backend.py:139 | create_host_kv_cache raises RuntimeError, but the caller catches only NotImplementedError, so initialization crashes on non-GPU platforms |
| 🟡 Suggestion | mla_attention_backend.py:658 | checking if cuda_host_alloc is None right after a direct import is dead code and provides no platform protection |
| 🟡 Suggestion | cache_controller.py:554 | after the early return on NotImplementedError, attributes such as _host_key_cache_shape are never assigned; downstream access raises AttributeError |
| 🟡 Suggestion | test_cache_controller.py:763 | make_mock_attn_backend mocks only create_kv_cache, not create_host_kv_cache, so the host cache test path may be masked |

Overall assessment

The refactoring is well conceived: pushing variant-specific allocation logic down into the backends meaningfully reduces CacheController's complexity. One P0 exception-type mismatch must be fixed, however: on non-GPU platforms create_host_kv_cache raises RuntimeError rather than NotImplementedError, so the caller's exception handling is bypassed entirely. Two P1 risks (the dead-code guard and the uninitialized attributes) and a test coverage gap should also be addressed.

            unavailable on the current platform.
        """
        if cuda_host_alloc is None:
            raise RuntimeError(

🔴 Bug: RuntimeError does not match the NotImplementedError the caller expects

cache_controller.initialize_host_cache only has except NotImplementedError; when cuda_host_alloc is None (a non-GPU platform), the RuntimeError raised here propagates upward and crashes initialization.

Suggestion: change the RuntimeError to NotImplementedError, or change the caller to except (NotImplementedError, RuntimeError).

from fastdeploy.cache_manager.ops import cuda_host_alloc

if cuda_host_alloc is None:
raise RuntimeError("[create_host_kv_cache][MLA] cuda_host_alloc is not available")

🟡 Suggestion: the None check can never fire (dead code)

cuda_host_alloc is imported directly via from ... import cuda_host_alloc, and a direct import never yields None (a missing package raises ImportError instead). The if cuda_host_alloc is None check is therefore always False and provides no platform protection.

Suggestion: wrap the import in a module-level try/except (consistent with base_attention_backend.py):

try:
    from fastdeploy.cache_manager.ops import cuda_host_alloc
except Exception:
    cuda_host_alloc = None

# 然后在函数体内检查
if cuda_host_alloc is None:
    raise NotImplementedError(...)

logger.warning(
f"[CacheController] Host kv cache offload not supported by "
f"{type(attn_backend).__name__}: {e}. Skipping swap space setup."
)

🟡 Suggestion: after the early return, attributes such as _host_key_cache_shape are never assigned

When DSA's create_host_kv_cache raises NotImplementedError, this code returns early, but _host_key_cache_shape, _host_value_cache_shape, _host_cache_scale_shape, and _num_host_blocks are only assigned later. If downstream code (e.g. transfer_manager) accesses these attributes, it raises AttributeError.

Suggestion: initialize these attributes to None in __init__ so the object always carries a complete attribute set.
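
A sketch of the suggested __init__ defaults; the attribute names come from the comment above, the rest is illustrative:

class CacheController:
    def __init__(self):
        # Ensure the attribute set is complete even when host cache
        # initialization is skipped (names from the review comment above).
        self._host_key_cache_shape = None
        self._host_value_cache_shape = None
        self._host_cache_scale_shape = None
        self._num_host_blocks = None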

        return caches

    backend.create_kv_cache.side_effect = fake_create_kv_cache
    return backend

🟡 Suggestion: make_mock_attn_backend does not mock create_host_kv_cache

Only create_kv_cache is mocked; create_host_kv_cache is auto-generated by MagicMock and returns a MagicMock object. cache_controller.initialize_host_cache then calls .items() on it and gets a meaningless iteration result, which may mask real host cache initialization problems.

Suggested addition:

backend.create_host_kv_cache.return_value = {}  # 或按需返回 {(role, layer_idx): ptr}


Labels

contributor External developers


7 participants