diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
index 539ade769..add32918c 100644
--- a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
+++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
@@ -1,13 +1,86 @@
import dataclasses
+import inspect
import torch
from typing import TYPE_CHECKING, Tuple
from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState
from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.utils.log_utils import init_logger
if TYPE_CHECKING:
from lightllm.common.basemodel.infer_struct import InferStateInfo
+logger = init_logger(__name__)
+
+# this flash_mla extra-cache fork only instantiates h_q in {64, 128}; pad TP-split q heads up
+# to the nearest supported count (zero heads are discarded from the output slice).
+FLASHMLA_SUPPORTED_HEADS = (64, 128)
+
+
+def _target_q_heads(h_q: int) -> int:
+ target = next((h for h in FLASHMLA_SUPPORTED_HEADS if h >= h_q), None)
+ assert target is not None, f"num q heads {h_q} exceeds flash_mla support {FLASHMLA_SUPPORTED_HEADS}"
+ return target
+
+
+def _pad_q_heads(
+ q_4d: torch.Tensor,
+ attn_sink: torch.Tensor,
+ q_out: torch.Tensor = None,
+ sink_out: torch.Tensor = None,
+):
+ h_q = q_4d.shape[2]
+ if h_q in FLASHMLA_SUPPORTED_HEADS:
+ return q_4d, attn_sink, h_q
+ target = _target_q_heads(h_q)
+ if q_out is not None:
+ q_out[:, :, :h_q, :].copy_(q_4d)
+ q_out[:, :, h_q:target, :].zero_()
+ sink_out[:h_q].copy_(attn_sink)
+ sink_out[h_q:target].zero_()
+ return q_out, sink_out[:target], h_q
+ q_pad = torch.nn.functional.pad(q_4d, (0, 0, 0, target - h_q))
+ sink_pad = torch.nn.functional.pad(attn_sink, (0, target - h_q))
+ return q_pad, sink_pad, h_q
+
+
+def _view_dsv4_flashmla_cache(layer_buffer: torch.Tensor, page_size: int) -> torch.Tensor:
+ from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_MLA_BYTES_PER_TOKEN
+
+ usable = page_size * DSV4_MLA_BYTES_PER_TOKEN
+ return layer_buffer[:, :usable].view(layer_buffer.shape[0], page_size, 1, DSV4_MLA_BYTES_PER_TOKEN)
+
+
+@dataclasses.dataclass
+class _Dsv4Metadata:
+ swa_indices: torch.Tensor
+ swa_lengths: torch.Tensor
+ extra_cache: torch.Tensor = None
+ extra_indices: torch.Tensor = None
+ extra_lengths: torch.Tensor = None
+
+
+def _metadata_from_dict(infer_state, nsa_dict: dict) -> "_Dsv4Metadata":
+ """Bundle the model-built FINAL index tensors (carried in nsa_dict by DeepseekV4IndexInfer) with
+ the layer-keyed fp8 extra-cache byte view. The cache view is data-independent (a fixed per-layer
+ buffer slice), so it is built here -- a genuine flash_mla ABI concern -- rather than on the model
+ side; only the index/length tensors cross the att_control boundary."""
+ from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_C128_PAGE_SIZE, DSV4_C4_PAGE_SIZE
+
+ ratio = nsa_dict["compress_ratio"]
+ extra_cache = None
+ if ratio:
+ page = DSV4_C4_PAGE_SIZE if ratio == 4 else DSV4_C128_PAGE_SIZE
+ extra_buffer = infer_state.mem_manager.get_compressed_kv_buffer(nsa_dict["layer_index"])
+ extra_cache = _view_dsv4_flashmla_cache(extra_buffer, page)
+ return _Dsv4Metadata(
+ swa_indices=nsa_dict["swa_indices"],
+ swa_lengths=nsa_dict["swa_lengths"],
+ extra_cache=extra_cache,
+ extra_indices=nsa_dict.get("extra_indices"),
+ extra_lengths=nsa_dict.get("extra_lengths"),
+ )
+
class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend):
def __init__(self, model):
@@ -17,6 +90,61 @@ def __init__(self, model):
torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device)
for _ in range(2)
]
+ self.prefill_flash_mla, self.prefill_flash_mla_supports_out = self._load_prefill_flash_mla()
+ self.prefill_q_workspace = None
+ self.prefill_out_workspace = None
+ self.prefill_real_out_workspace = None
+ self.prefill_sink_workspace = None
+ self.prefill_workspace_shape = None
+ self.prefill_real_out_shape = None
+ self.prefill_workspace_token_capacity = int(model.batch_max_tokens or 0)
+ if self.prefill_flash_mla_supports_out:
+ logger.info("DSV4 FlashMLA prefill uses vLLM out= workspace path")
+ else:
+ logger.warning("DSV4 FlashMLA prefill out= path unavailable; falling back to allocating FlashMLA output")
+
+ def _load_prefill_flash_mla(self):
+ try:
+ from vllm.v1.attention.ops import flashmla as flash_mla
+
+ sig = inspect.signature(flash_mla.flash_mla_with_kvcache)
+ if "out" in sig.parameters:
+ return flash_mla, True
+ except Exception:
+ pass
+
+ import flash_mla
+
+ return flash_mla, "out" in inspect.signature(flash_mla.flash_mla_with_kvcache).parameters
+
+ def _ensure_prefill_workspace(self, token_num: int, head_num: int, target_heads: int, head_dim: int, dtype, device):
+ capacity = max(token_num, self.prefill_workspace_token_capacity)
+ workspace_shape = (capacity, 1, target_heads, head_dim)
+ if (
+ self.prefill_workspace_shape != (target_heads, head_dim, dtype, device)
+ or self.prefill_q_workspace is None
+ or self.prefill_q_workspace.shape[0] < capacity
+ ):
+ self.prefill_q_workspace = torch.empty(workspace_shape, dtype=dtype, device=device)
+ self.prefill_out_workspace = torch.empty(workspace_shape, dtype=dtype, device=device)
+ self.prefill_sink_workspace = torch.empty((target_heads,), dtype=torch.float32, device=device)
+ self.prefill_workspace_shape = (target_heads, head_dim, dtype, device)
+
+ real_out_shape = (capacity, head_num, head_dim)
+ if (
+ self.prefill_real_out_shape != (head_num, head_dim, dtype, device)
+ or self.prefill_real_out_workspace is None
+ or self.prefill_real_out_workspace.shape[0] < capacity
+ ):
+ self.prefill_real_out_workspace = torch.empty(real_out_shape, dtype=dtype, device=device)
+ self.prefill_real_out_shape = (head_num, head_dim, dtype, device)
+
+ return (
+ self.prefill_q_workspace[:token_num],
+ self.prefill_out_workspace[:token_num],
+ self.prefill_real_out_workspace[:token_num],
+ self.prefill_sink_workspace,
+ )
def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState":
return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state)
@@ -31,9 +159,20 @@ class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState):
ke: torch.Tensor = None
lengths: torch.Tensor = None
ragged_mem_index: torch.Tensor = None
+ flashmla_sched_meta: object = None
def init_state(self):
self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
+ self.flashmla_sched_meta = {}
+ return
+
+ def ensure_nsa_ks_ke(self):
+ """Build the ragged ks/ke/lengths (+ ragged_mem_index) the DeepSeek-3.2 indexer consumes. The
+ indexer calls this explicitly before reading them; DeepSeek-V4 uses its own indexer and never
+ calls it, so V4 prefill skips the alloc + gen_nsa_ks_ke kernel. Idempotent + layer-independent:
+ the first call in a forward computes, the other layers reuse."""
+ if self.ks is not None:
+ return
self.ragged_mem_index = torch.empty(
self.infer_state.total_token_num,
dtype=torch.int32,
@@ -52,6 +191,13 @@ def init_state(self):
)
return
+ def _get_flashmla_sched_meta(self, compress_ratio: int):
+ sched_meta = self.flashmla_sched_meta.get(compress_ratio)
+ if sched_meta is None:
+ sched_meta = self.backend.prefill_flash_mla.get_mla_metadata()[0]
+ self.flashmla_sched_meta[compress_ratio] = sched_meta
+ return sched_meta
+
def prefill_att(
self,
q: torch.Tensor,
@@ -59,9 +205,17 @@ def prefill_att(
v: torch.Tensor,
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
+ out: torch.Tensor = None,
) -> torch.Tensor:
assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention"
assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required"
+ if att_control.nsa_prefill_dict.get("flashmla_kvcache"):
+ return self._flashmla_kvcache_prefill_att(
+ q=q,
+ packed_kv=k,
+ nsa_dict=att_control.nsa_prefill_dict,
+ out=out,
+ )
return self._nsa_prefill_att(q=q, packed_kv=k, att_control=att_control)
def _nsa_prefill_att(
@@ -78,6 +232,8 @@ def _nsa_prefill_att(
kv_lora_rank = nsa_dict["kv_lora_rank"]
topk_mem_indices = nsa_dict["topk_mem_indices"]
prefill_cache_kv = nsa_dict["prefill_cache_kv"]
+ attn_sink = nsa_dict.get("attn_sink")
+ topk_length = nsa_dict.get("topk_length")
if self.infer_state.prefix_total_token_num > 0:
# 当前推理生成的token kv部分从 prefill_cache_kv 中获取,历史
@@ -101,9 +257,69 @@ def _nsa_prefill_att(
indices=topk_indices,
sm_scale=softmax_scale,
d_v=kv_lora_rank,
+ attn_sink=attn_sink,
+ topk_length=topk_length,
)
return mla_out
+ def _flashmla_kvcache_prefill_att(
+ self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict, out: torch.Tensor = None
+ ) -> torch.Tensor:
+ attn_sink = nsa_dict["attn_sink"]
+ metadata = _metadata_from_dict(self.infer_state, nsa_dict)
+ return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict, out=out)
+
+ def _flashmla_kvcache_att(
+ self,
+ q: torch.Tensor,
+ packed_kv: torch.Tensor,
+ metadata: _Dsv4Metadata,
+ attn_sink: torch.Tensor,
+ nsa_dict: dict,
+ out: torch.Tensor = None,
+ ) -> torch.Tensor:
+ from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE
+
+ q_4d = q.unsqueeze(1).contiguous()
+ num_real_heads = q_4d.shape[2]
+ target_heads = _target_q_heads(num_real_heads)
+ q_workspace, full_out_workspace, real_out_workspace, sink_workspace = self.backend._ensure_prefill_workspace(
+ q_4d.shape[0],
+ num_real_heads,
+ target_heads,
+ q_4d.shape[-1],
+ q_4d.dtype,
+ q_4d.device,
+ )
+ q_for_flash, sink_for_flash, num_real_heads = _pad_q_heads(q_4d, attn_sink, q_workspace, sink_workspace)
+ k_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE)
+ sched_meta = self._get_flashmla_sched_meta(nsa_dict["compress_ratio"])
+ flash_mla = self.backend.prefill_flash_mla
+ kwargs = dict(
+ q=q_for_flash,
+ k_cache=k_cache,
+ block_table=None,
+ cache_seqlens=None,
+ head_dim_v=nsa_dict["head_dim_v"],
+ tile_scheduler_metadata=sched_meta,
+ num_splits=None,
+ softmax_scale=nsa_dict["softmax_scale"],
+ causal=False,
+ is_fp8_kvcache=True,
+ indices=metadata.swa_indices,
+ attn_sink=sink_for_flash,
+ topk_length=metadata.swa_lengths,
+ extra_k_cache=metadata.extra_cache,
+ extra_indices_in_kvcache=metadata.extra_indices,
+ extra_topk_length=metadata.extra_lengths,
+ )
+ if self.backend.prefill_flash_mla_supports_out:
+ kwargs["out"] = full_out_workspace
+ full_out, _ = flash_mla.flash_mla_with_kvcache(**kwargs)
+ real_out = out if out is not None else real_out_workspace
+ real_out.copy_(full_out[:, 0, :num_real_heads, :])
+ return real_out
+
@dataclasses.dataclass
class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState):
@@ -143,7 +359,23 @@ def init_state(self):
)
import flash_mla
- self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata()
+ # one sched_meta per layer type: the lazy config locks extra-cache geometry (page size,
+ # presence) on first invocation, so swa-only/c4/c128 layers must not share one object.
+ self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)}
+ return
+
+ def ensure_nsa_ks_ke(self):
+ # decode builds ks/ke eagerly in init_state (outside the cuda graph, for capture safety), so
+ # they are already available -- this satisfies the shared DeepSeek-3.2 indexer ensure contract.
+ return
+
+ def reset_sched_meta_for_capture(self):
+ # cuda-graph capture hook: the warmup pass already locked/stored sched meta on this
+ # (shared) state object; reset so the capture pass re-plans INSIDE the graph and every
+ # replay re-plans from the live tensors instead of binding warmup leftovers.
+ import flash_mla
+
+ self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)}
return
def decode_att(
@@ -156,6 +388,12 @@ def decode_att(
) -> torch.Tensor:
assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention"
assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required"
+ if att_control.nsa_decode_dict.get("flashmla_kvcache"):
+ return self._flashmla_kvcache_decode_att(
+ q=q,
+ packed_kv=k,
+ nsa_dict=att_control.nsa_decode_dict,
+ )
return self._nsa_decode_att(q=q, packed_kv=k, att_control=att_control)
def _nsa_decode_att(
@@ -170,6 +408,11 @@ def _nsa_decode_att(
topk_mem_indices = nsa_dict["topk_mem_indices"]
softmax_scale = nsa_dict["softmax_scale"]
kv_lora_rank = nsa_dict["kv_lora_rank"]
+ attn_sink = nsa_dict.get("attn_sink")
+ topk_length = nsa_dict.get("topk_length")
+ extra_k_cache = nsa_dict.get("extra_k_cache")
+ extra_indices = nsa_dict.get("extra_indices_in_kvcache")
+ extra_topk_length = nsa_dict.get("extra_topk_length")
if topk_mem_indices.ndim == 2:
topk_mem_indices = topk_mem_indices.unsqueeze(1)
@@ -189,10 +432,54 @@ def _nsa_decode_att(
block_table=None,
cache_seqlens=None,
head_dim_v=kv_lora_rank,
- tile_scheduler_metadata=self.flashmla_sched_meta,
+ tile_scheduler_metadata=self.flashmla_sched_meta[0],
softmax_scale=softmax_scale,
causal=False,
is_fp8_kvcache=True,
indices=topk_mem_indices,
+ attn_sink=attn_sink,
+ topk_length=topk_length,
+ extra_k_cache=extra_k_cache,
+ extra_indices_in_kvcache=extra_indices,
+ extra_topk_length=extra_topk_length,
)
return o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d]
+
+ def _flashmla_kvcache_decode_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor:
+ attn_sink = nsa_dict["attn_sink"]
+ metadata = _metadata_from_dict(self.infer_state, nsa_dict)
+ return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict)
+
+ def _flashmla_kvcache_att(
+ self,
+ q: torch.Tensor,
+ packed_kv: torch.Tensor,
+ metadata: _Dsv4Metadata,
+ attn_sink: torch.Tensor,
+ nsa_dict: dict,
+ ) -> torch.Tensor:
+ import flash_mla
+ from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE
+
+ q_4d = q.unsqueeze(1).contiguous()
+ q_4d, attn_sink, num_real_heads = _pad_q_heads(q_4d, attn_sink)
+ k_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE)
+ out, _ = flash_mla.flash_mla_with_kvcache(
+ q=q_4d,
+ k_cache=k_cache,
+ block_table=None,
+ cache_seqlens=None,
+ head_dim_v=nsa_dict["head_dim_v"],
+ tile_scheduler_metadata=self.flashmla_sched_meta[nsa_dict["compress_ratio"]],
+ num_splits=None,
+ softmax_scale=nsa_dict["softmax_scale"],
+ causal=False,
+ is_fp8_kvcache=True,
+ indices=metadata.swa_indices,
+ attn_sink=attn_sink,
+ topk_length=metadata.swa_lengths,
+ extra_k_cache=metadata.extra_cache,
+ extra_indices_in_kvcache=metadata.extra_indices,
+ extra_topk_length=metadata.extra_lengths,
+ )
+ return out[:, 0, :num_real_heads].contiguous()
diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 94f9d4c1a..f440b9821 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -291,6 +291,15 @@ def _init_custom(self):
@torch.no_grad()
def forward(self, model_input: ModelInput):
+ # decode 槽位 prep: 放在 to_cuda 前, 优先使用 b_req_idx/b_seq_len 的 CPU mirror,
+ # 且此刻已在 forward 的 CUDA stream 上 -> 与后续 attention 同流, 无跨流竞态、无 D2H。
+ # mem_indexes_cpu is None 时跳过: cudagraph warmup 的输入全在 CUDA 且 b_req_idx 全为 HOLD, prep 本就是 no-op。
+ if not model_input.is_prefill and model_input.mem_indexes_cpu is not None:
+ self.req_manager.prepare_decode(
+ model_input.b_req_idx_cpu,
+ model_input.b_seq_len_cpu,
+ model_input.mem_indexes_cpu,
+ )
model_input.to_cuda()
assert model_input.mem_indexes.is_cuda
@@ -371,6 +380,15 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0
)
new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2)
+ new_model_input.b_req_idx_cpu = F.pad(
+ new_model_input.b_req_idx_cpu,
+ (0, padded_batch_size),
+ mode="constant",
+ value=self.req_manager.HOLD_REQUEST_ID,
+ )
+ new_model_input.b_seq_len_cpu = F.pad(
+ new_model_input.b_seq_len_cpu, (0, padded_batch_size), mode="constant", value=2
+ )
new_model_input.mem_indexes = F.pad(
new_model_input.mem_indexes,
(0, padded_batch_size),
@@ -428,6 +446,15 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle
new_model_input.b_mtp_index = F.pad(new_model_input.b_mtp_index, (0, 1), mode="constant", value=0)
new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, 1), mode="constant", value=padded_token_num)
new_model_input.b_ready_cache_len = F.pad(new_model_input.b_ready_cache_len, (0, 1), mode="constant", value=0)
+ new_model_input.b_req_idx_cpu = F.pad(
+ new_model_input.b_req_idx_cpu, (0, 1), mode="constant", value=self.req_manager.HOLD_REQUEST_ID
+ )
+ new_model_input.b_seq_len_cpu = F.pad(
+ new_model_input.b_seq_len_cpu, (0, 1), mode="constant", value=padded_token_num
+ )
+ new_model_input.b_ready_cache_len_cpu = F.pad(
+ new_model_input.b_ready_cache_len_cpu, (0, 1), mode="constant", value=0
+ )
b_q_seq_len = new_model_input.b_seq_len - new_model_input.b_ready_cache_len
new_model_input.b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
# 构建新的list, 使用 append 可能会让外面使用的数组引用发生变化,导致错误。
@@ -521,6 +548,15 @@ def _prefill(
alloc_mem_index=infer_state.mem_index,
max_q_seq_len=infer_state.max_q_seq_len,
)
+ if model_input.b_req_idx_cpu is not None:
+ self.req_manager.prepare_prefill(
+ b_req_idx=infer_state.b_req_idx,
+ b_ready_cache_len=infer_state.b_ready_cache_len,
+ b_seq_len=infer_state.b_seq_len,
+ b_req_idx_cpu=model_input.b_req_idx_cpu,
+ b_ready_cache_len_cpu=model_input.b_ready_cache_len_cpu,
+ b_seq_len_cpu=model_input.b_seq_len_cpu,
+ )
prefill_mem_indexes_ready_event = torch.cuda.Event()
prefill_mem_indexes_ready_event.record()
@@ -741,6 +777,15 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
alloc_mem_index=infer_state0.mem_index,
max_q_seq_len=infer_state0.max_q_seq_len,
)
+ if model_input0.b_req_idx_cpu is not None:
+ self.req_manager.prepare_prefill(
+ b_req_idx=infer_state0.b_req_idx,
+ b_ready_cache_len=infer_state0.b_ready_cache_len,
+ b_seq_len=infer_state0.b_seq_len,
+ b_req_idx_cpu=model_input0.b_req_idx_cpu,
+ b_ready_cache_len_cpu=model_input0.b_ready_cache_len_cpu,
+ b_seq_len_cpu=model_input0.b_seq_len_cpu,
+ )
infer_state0.init_some_extra_state(self)
infer_state0.init_att_state()
@@ -754,6 +799,15 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
alloc_mem_index=infer_state1.mem_index,
max_q_seq_len=infer_state1.max_q_seq_len,
)
+ if model_input1.b_req_idx_cpu is not None:
+ self.req_manager.prepare_prefill(
+ b_req_idx=infer_state1.b_req_idx,
+ b_ready_cache_len=infer_state1.b_ready_cache_len,
+ b_seq_len=infer_state1.b_seq_len,
+ b_req_idx_cpu=model_input1.b_req_idx_cpu,
+ b_ready_cache_len_cpu=model_input1.b_ready_cache_len_cpu,
+ b_seq_len_cpu=model_input1.b_seq_len_cpu,
+ )
infer_state1.init_some_extra_state(self)
infer_state1.init_att_state()
@@ -781,6 +835,14 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
@torch.no_grad()
def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
+ # decode 槽位 prep: 在 to_cuda 前使用 CPU mirror, 且已在 forward 的 CUDA stream 上 (见 forward 注释)。
+ for mi in (model_input0, model_input1):
+ if mi.mem_indexes_cpu is not None:
+ self.req_manager.prepare_decode(
+ mi.b_req_idx_cpu,
+ mi.b_seq_len_cpu,
+ mi.mem_indexes_cpu,
+ )
model_input0.to_cuda()
model_input1.to_cuda()
assert self.args.enable_tpsp_mix_mode
diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py
index 1795ff9a8..7110d4895 100644
--- a/lightllm/common/basemodel/batch_objs.py
+++ b/lightllm/common/basemodel/batch_objs.py
@@ -42,6 +42,9 @@ class ModelInput:
multimodal_params: list = None
# cpu 变量
mem_indexes_cpu: torch.Tensor = None
+ b_req_idx_cpu: torch.Tensor = None
+ b_seq_len_cpu: torch.Tensor = None
+ b_ready_cache_len_cpu: torch.Tensor = None
# prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理
# 的一些变量
b_prefill_has_output_cpu: List[bool] = None # 标记进行prefill的请求是否具有输出
@@ -53,6 +56,18 @@ class ModelInput:
# 的 draft 模型的输入
mtp_draft_input_hiddens: Optional[torch.Tensor] = None
+ def _capture_cpu_mirror(self, tensor_name: str, mirror_name: str):
+ tensor = getattr(self, tensor_name)
+ if tensor is not None and not tensor.is_cuda:
+ setattr(self, mirror_name, tensor)
+ return
+
+ def capture_cpu_mirrors(self):
+ self._capture_cpu_mirror("b_req_idx", "b_req_idx_cpu")
+ self._capture_cpu_mirror("b_seq_len", "b_seq_len_cpu")
+ self._capture_cpu_mirror("b_ready_cache_len", "b_ready_cache_len_cpu")
+ return
+
def to_cuda(self):
if self.input_ids is not None:
self.input_ids = self.input_ids.cuda(non_blocking=True)
@@ -82,6 +97,7 @@ def to_cuda(self):
self.b_shared_seq_len = self.b_shared_seq_len.cuda(non_blocking=True)
def __post_init__(self):
+ self.capture_cpu_mirrors()
self.check_input()
def check_input(self):
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index 782150661..e1d96b744 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -14,6 +14,20 @@
logger = init_logger(__name__)
+def _reset_att_state_sched_meta(infer_state: InferStateInfo):
+ # capture 前调用: warmup 趟用 copy.copy 浅拷贝共享 decode_att_state,其内部惰性初始化的
+ # 调度对象(如 FlashMLASchedMeta,首次内核调用时按当时数据规划并回写)会被 warmup 的
+ # dummy 负载锁定;若不重置,捕获趟将绑定为 dummy 规划的调度张量,所有 replay 都用错误
+ # 的 tile schedule(DSV4 实测 gsm8k 0.96 -> 0.74)。重置后规划发生在捕获区内,随 replay 重算。
+ for att_state in (infer_state.decode_att_state, infer_state.decode_att_state1):
+ if att_state is None:
+ continue
+ reset_fn = getattr(att_state, "reset_sched_meta_for_capture", None)
+ if reset_fn is not None:
+ reset_fn()
+ return
+
+
class CudaGraph:
# CudaGraph forward pass for the decoding stage.
@@ -94,6 +108,8 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo):
if param_name not in pure_para_set:
delattr(infer_state, param_name)
+ _reset_att_state_sched_meta(infer_state)
+
with torch.cuda.graph(graph_obj, pool=self.mempool):
model_output = decode_func(infer_state)
self.graph[batch_size] = (graph_obj, infer_state, model_output)
@@ -128,6 +144,9 @@ def _capture_decode_overlap(
if para_name not in pure_para_set1:
delattr(infer_state1, para_name)
+ _reset_att_state_sched_meta(infer_state)
+ _reset_att_state_sched_meta(infer_state1)
+
with torch.cuda.graph(graph_obj, pool=self.mempool):
model_output, model_output1 = decode_func(infer_state, infer_state1)
self.graph[batch_size] = (
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
index fca9b80fc..24842ed38 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
@@ -68,12 +68,13 @@ def __init__(
auto_update_redundancy_expert=self.auto_update_redundancy_expert,
)
self.lock = threading.Lock()
+ self._moe_weight_finalized = False
self._create_weight()
def _init_config(self, network_config: Dict[str, Any]):
self.n_group = network_config.get("n_group", 0)
self.use_grouped_topk = self.n_group > 0
- self.norm_topk_prob = network_config["norm_topk_prob"]
+ self.norm_topk_prob = network_config.get("norm_topk_prob", False)
self.topk_group = network_config.get("topk_group", 0)
self.num_experts_per_tok = network_config["num_experts_per_tok"]
self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0)
@@ -136,6 +137,7 @@ def experts(
is_prefill: Optional[bool] = None,
) -> torch.Tensor:
"""Backward compatible method that routes to platform-specific implementation."""
+ self._finalize_moe_weight()
return self.fuse_moe_impl(
input_tensor=input_tensor,
router_logits=router_logits,
@@ -152,6 +154,25 @@ def experts(
per_expert_scale=self.per_expert_scale,
)
+ def experts_with_preselected(
+ self,
+ input_tensor: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_prefill: Optional[bool] = None,
+ clamp_limit: Optional[float] = None,
+ ) -> torch.Tensor:
+ self._finalize_moe_weight()
+ return self.fuse_moe_impl.fused_experts_with_topk(
+ input_tensor=input_tensor,
+ w13=self.w13,
+ w2=self.w2,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ is_prefill=is_prefill,
+ clamp_limit=clamp_limit,
+ )
+
def low_latency_dispatch(
self,
hidden_states: torch.Tensor,
@@ -280,7 +301,18 @@ def verify_load(self):
e_score_correction_bias_load_ok = (
True if self.e_score_correction_bias is None else getattr(self.e_score_correction_bias, "load_ok", False)
)
- return weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok
+ load_ok = weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok
+ if load_ok:
+ self._finalize_moe_weight()
+ return load_ok
+
+ def _finalize_moe_weight(self):
+ if self._moe_weight_finalized:
+ return
+ finalize = getattr(self.quant_method, "finalize_moe_weight", None)
+ if finalize is not None:
+ finalize(self)
+ self._moe_weight_finalized = True
def _create_weight(self):
intermediate_size = self.split_inter_size
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py
index 67bb90e4e..282c0abdc 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py
@@ -2,9 +2,15 @@
from .triton_impl import FuseMoeTriton
from .marlin_impl import FuseMoeMarlin
from .deepgemm_impl import FuseMoeDeepGEMM
+from .mxfp4_impl import FuseMoeMXFP4
def select_fuse_moe_impl(quant_method: QuantizationMethod, enable_ep_moe: bool):
+ if quant_method.method_name == "marlin-mxfp4w4a16-b32":
+ if enable_ep_moe:
+ raise RuntimeError("marlin-mxfp4w4a16-b32 does not support enable_ep_moe yet")
+ return FuseMoeMXFP4
+
if enable_ep_moe:
return FuseMoeDeepGEMM
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py
index 4d4614c00..72acf2430 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py
@@ -76,7 +76,9 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = None,
+ clamp_limit: Optional[float] = None,
):
+ assert clamp_limit is None, "EP deepgemm fused MoE does not support clamp_limit yet"
output = fused_experts(
hidden_states=input_tensor,
w13=w13,
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py
index 0094b09b1..1fdfd94d0 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py
@@ -30,7 +30,9 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = None,
+ clamp_limit: Optional[float] = None,
):
+ assert clamp_limit is None, "awq_marlin fused MoE does not support clamp_limit yet"
w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point
w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py
new file mode 100644
index 000000000..97cf23811
--- /dev/null
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py
@@ -0,0 +1,44 @@
+import torch
+from typing import Optional
+
+from lightllm.common.quantization.quantize_method import WeightPack
+from .triton_impl import FuseMoeTriton
+
+
+class FuseMoeMXFP4(FuseMoeTriton):
+ def create_workspace(self):
+ return None
+
+ def _fused_experts(
+ self,
+ input_tensor: torch.Tensor,
+ w13: WeightPack,
+ w2: WeightPack,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ router_logits: Optional[torch.Tensor] = None,
+ is_prefill: Optional[bool] = None,
+ clamp_limit: Optional[float] = None,
+ ):
+ try:
+ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
+ from vllm.model_executor.layers.fused_moe.experts.marlin_moe import fused_marlin_moe
+ from vllm.scalar_type import scalar_types
+ except Exception as e:
+ raise RuntimeError(f"MXFP4 fused MoE requires vLLM fused kernels, error={repr(e)}") from e
+
+ return fused_marlin_moe(
+ hidden_states=input_tensor.contiguous(),
+ w1=w13.weight,
+ w2=w2.weight,
+ bias1=None,
+ bias2=None,
+ w1_scale=w13.weight_scale,
+ w2_scale=w2.weight_scale,
+ topk_weights=topk_weights.to(torch.float32).contiguous(),
+ topk_ids=topk_ids.to(torch.long).contiguous(),
+ quant_type_id=scalar_types.float4_e2m1f.id,
+ global_num_experts=self.n_routed_experts,
+ activation=MoEActivation.SILU,
+ clamp_limit=clamp_limit,
+ )
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
index a0d30547a..8967dda34 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
@@ -94,6 +94,7 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: bool = False,
+ clamp_limit: Optional[float] = None,
):
w13_weight, w13_scale = w13.weight, w13.weight_scale
w2_weight, w2_scale = w2.weight, w2.weight_scale
@@ -111,9 +112,30 @@ def _fused_experts(
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w13_scale,
w2_scale=w2_scale,
+ limit=clamp_limit,
)
return input_tensor
+ def fused_experts_with_topk(
+ self,
+ input_tensor: torch.Tensor,
+ w13: WeightPack,
+ w2: WeightPack,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_prefill: Optional[bool] = None,
+ clamp_limit: Optional[float] = None,
+ ):
+ return self._fused_experts(
+ input_tensor=input_tensor,
+ w13=w13,
+ w2=w2,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ is_prefill=is_prefill,
+ clamp_limit=clamp_limit,
+ )
+
def __call__(
self,
input_tensor: torch.Tensor,
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py
index cb2e370cb..28fe6e430 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py
@@ -49,7 +49,7 @@ def check_ep_expert_dtype(quant_method: Any):
"EP MoE requires --expert_dtype to be one of ['fp8', 'fp4'], "
f"but the resolved fused_moe quant method is `{expert_dtype}`. "
"Please start with --expert_dtype fp8 or --expert_dtype fp4. "
- "Note that --expert_dtype fp4 is only supported on SM100 GPUs."
+ "Note that --expert_dtype fp4 with EP MoE is only supported on SM100 GPUs."
)
if expert_dtype == "deepgemm-fp4fp8-b32" and not is_sm100_gpu():
raise RuntimeError(
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py
index 45c7ea73c..82fc9131c 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py
@@ -24,6 +24,7 @@ def _silu_and_mul_kernel_fast(
NEED_MASK: tl.constexpr,
layout: tl.constexpr = "blocked", # "blocked" or "interleaved"
USE_LIMIT_AND_ALPHA: tl.constexpr = False,
+ USE_LIMIT_ONLY: tl.constexpr = False,
USE_TANH_APPROXIMATE_GELU: tl.constexpr = False,
):
stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
@@ -76,6 +77,11 @@ def _silu_and_mul_kernel_fast(
mask=mask,
)
else:
+ if USE_LIMIT_ONLY:
+ # clamped swiglu (DeepSeek-V4 swiglu_limit): clamp 后接标准 silu,
+ # 无 gpt-oss 的 alpha 缩放与 (up+1)。
+ gate = tl.minimum(gate, limit)
+ up = tl.minimum(tl.maximum(up, -limit), limit)
if USE_TANH_APPROXIMATE_GELU:
# tanh-approx GELU, matching Gemma's gelu_pytorch_tanh MLP.
gate_cubed = gate * gate * gate
@@ -124,7 +130,8 @@ def silu_and_mul_fwd(
):
assert input.is_contiguous()
assert output.is_contiguous()
- assert (limit is None and alpha is None) or (limit is not None and alpha is not None)
+ # limit+alpha: gpt-oss 语义 (up+1)*silu(alpha*gate); 仅 limit: clamp 后标准 silu (DeepSeek-V4)
+ assert alpha is None or limit is not None
stride_input_m = input.stride(0)
stride_input_n = input.stride(1)
@@ -147,6 +154,7 @@ def silu_and_mul_fwd(
while triton.cdiv(size_m, BLOCK_M) > 8192:
BLOCK_M *= 2
USE_LIMIT_AND_ALPHA = limit is not None and alpha is not None
+ USE_LIMIT_ONLY = limit is not None and alpha is None
grid = (
triton.cdiv(size_n, BLOCK_N),
@@ -171,6 +179,7 @@ def silu_and_mul_fwd(
num_warps=num_warps,
layout=layout,
USE_LIMIT_AND_ALPHA=USE_LIMIT_AND_ALPHA,
+ USE_LIMIT_ONLY=USE_LIMIT_ONLY,
USE_TANH_APPROXIMATE_GELU=ffn_use_tanh_approximate_gelu(),
)
return
diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py
index 05544e149..95f7e8ab7 100644
--- a/lightllm/common/kv_cache_mem_manager/__init__.py
+++ b/lightllm/common/kv_cache_mem_manager/__init__.py
@@ -4,6 +4,7 @@
from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
from .deepseek2_mem_manager import Deepseek2MemoryManager
from .deepseek3_2mem_manager import Deepseek3_2MemoryManager
+from .deepseek4_mem_manager import DeepseekV4MemoryManager
from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager
from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager
from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager
@@ -17,6 +18,7 @@
"PPLINT8KVMemoryManager",
"Deepseek2MemoryManager",
"Deepseek3_2MemoryManager",
+ "DeepseekV4MemoryManager",
"FP8PerTokenGroupQuantDeepseek3_2MemoryManager",
"FP8StaticPerHeadQuantMemManager",
"FP8StaticPerTensorQuantMemManager",
diff --git a/lightllm/common/kv_cache_mem_manager/allocator.py b/lightllm/common/kv_cache_mem_manager/allocator.py
index 850c15877..0179ed271 100644
--- a/lightllm/common/kv_cache_mem_manager/allocator.py
+++ b/lightllm/common/kv_cache_mem_manager/allocator.py
@@ -3,13 +3,13 @@
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.log_utils import init_logger
-from typing import Union, List
+from typing import Union, List, Optional
logger = init_logger(__name__)
class KvCacheAllocator:
- def __init__(self, size: int) -> None:
+ def __init__(self, size: int, shared_name: Optional[str] = None) -> None:
self.size = size
self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
@@ -26,9 +26,11 @@ def __init__(self, size: int) -> None:
rank_in_node = get_current_rank_in_node()
# 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。
- self.shared_can_use_token_num = SharedInt(
- f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
- )
+ # shared_name 为 None 时使用主 kv 池的默认名(router 调度据此估算);DeepSeek-V4 的压缩子池等
+ # 需要各自独立的计数器,传入区别于主池的唯一名,避免多个 allocator 写同一个共享计数器。
+ if shared_name is None:
+ shared_name = f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
+ self.shared_can_use_token_num = SharedInt(shared_name)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
return
diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py
new file mode 100644
index 000000000..784065d96
--- /dev/null
+++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py
@@ -0,0 +1,792 @@
+import torch
+from typing import List, Optional, Union
+from .mem_manager import MemoryManager
+from .operator import DeepseekV4MemOperator
+from .allocator import KvCacheAllocator
+from lightllm.utils.dist_utils import get_current_rank_in_node
+from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+# fp8_ds_mla packed-latent byte layout (ABI shared with the flash_mla extra-cache fork and
+# sglang/vllm): 448B NoPE fp8 + 64*2B RoPE bf16 + 7B ue8m0 scale + 1B pad = 584B per token,
+# stored in page slabs whose tail carries the per-token scale bytes.
+DSV4_MLA_NOPE_DIM = 448 # 448B
+DSV4_MLA_ROPE_DIM = 64 # 64 dim
+DSV4_MLA_HEAD_DIM = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM # 512
+DSV4_MLA_QUANT_GROUP_SIZE = 64 # 64
+DSV4_MLA_SCALE_BYTES = DSV4_MLA_NOPE_DIM // DSV4_MLA_QUANT_GROUP_SIZE + 1 # 8 (7 ue8m0 + 1 pad)
+DSV4_MLA_BYTES_PER_TOKEN = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM * 2 + DSV4_MLA_SCALE_BYTES # 584
+DSV4_MLA_DATA_BYTES_PER_TOKEN = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM * 2 # 576
+DSV4_MLA_PAGE_ALIGN_BYTES = DSV4_MLA_DATA_BYTES_PER_TOKEN # 576
+DSV4_INDEXER_HEAD_DIM = 128 # 128
+DSV4_INDEXER_SCALE_BYTES = 4 # 4B fp32 scale
+DSV4_INDEXER_BYTES_PER_TOKEN = DSV4_INDEXER_HEAD_DIM + DSV4_INDEXER_SCALE_BYTES # 132
+DSV4_FP8_E4M3_MAX = 448.0 # 448.0
+DSV4_FP8_SCALE_MIN = 1e-4 # 1e-4
+DSV4_SWA_PAGE_SIZE = 128 # 128 slots/page
+DSV4_C4_PAGE_SIZE = 64 # 64 slots/page
+DSV4_C128_PAGE_SIZE = 2 # 2 slots/page
+DSV4_PROMPT_CACHE_PAGE_SIZE = DSV4_C4_PAGE_SIZE * 4 # 256 (= c4 ratio)
+# compressor state ring: c4 overlap 对为每页 2 个分组槽 × ratio 4 行;c128 离线聚合为每页 1 组。
+DSV4_C4_STATE_RING = 8 # 8 rows/page
+DSV4_C128_STATE_RING = 128 # 128 rows/page
+# swa 池占 full token 空间的比例(sglang DSV4 默认 swa_full_tokens_ratio=0.1 同值)。
+# 瞬时借页/驱逐走 swa 压力阀;池子大小仅按 ratio 切分,不再叠加结构性余量。
+DSV4_SWA_FULL_TOKENS_RATIO = 0.1 # 0.1
+
+
+def _ceil_div(a: int, b: int) -> int:
+ return (a + b - 1) // b
+
+
+class PackedPagePool:
+ """fp8_ds_mla 风格的 page-slab 存储: 每页前段连续放 token 的 data 字节,页尾放 per-token scale 字节。
+
+ 寻址是纯 token 槽位 (page = slot // page_size),page 只是 scale-tail/对齐的物理打包技巧,
+ 不存在页粒度的分配。``write``/``read`` 是 torch 参考实现(单测 oracle);生产写入走
+ triton packed writer(destindex_copy_kv_flashmla_dsv4 等),kernel 直接消费 ``buffer``。
+ """
+
+ def __init__(
+ self,
+ size: int,
+ page_size: int,
+ layer_num: int,
+ data_bytes: int,
+ scale_bytes: int,
+ align_bytes: int = 1,
+ device: str = "cuda",
+ ):
+ self.size = size
+ self.page_size = page_size
+ self.layer_num = layer_num
+ self.data_bytes_per_token = data_bytes
+ self.scale_bytes_per_token = scale_bytes
+ self.bytes_per_token = data_bytes + scale_bytes
+ self.num_pages = _ceil_div(size + 1, page_size)
+ self.bytes_per_page = _ceil_div(page_size * self.bytes_per_token, align_bytes) * align_bytes
+ self.scale_offset_in_page = page_size * data_bytes
+ self.buffer = torch.zeros((layer_num, self.num_pages, self.bytes_per_page), dtype=torch.uint8, device=device)
+ self.HOLD_TOKEN_MEMINDEX = size
+
+ def get_layer_buffer(self, layer_index: int) -> torch.Tensor:
+ return self.buffer[layer_index]
+
+ def _loc_offsets(self, loc: torch.Tensor):
+ loc = loc.long()
+ page = torch.div(loc, self.page_size, rounding_mode="floor")
+ token = loc % self.page_size
+ page_base = page * self.bytes_per_page
+ data_offsets = page_base + token * self.data_bytes_per_token
+ scale_offsets = page_base + self.scale_offset_in_page + token * self.scale_bytes_per_token
+ return data_offsets, scale_offsets
+
+ def write(self, layer_index: int, loc: torch.Tensor, packed: torch.Tensor) -> None:
+ if loc.numel() == 0:
+ return
+ loc = loc.reshape(-1)
+ packed = packed.reshape(-1, self.bytes_per_token).contiguous()
+ flat = self.buffer[layer_index].view(-1)
+ data_offsets, scale_offsets = self._loc_offsets(loc)
+ data_range = torch.arange(self.data_bytes_per_token, device=loc.device)
+ scale_range = torch.arange(self.scale_bytes_per_token, device=loc.device)
+ flat[data_offsets.unsqueeze(1) + data_range.unsqueeze(0)] = packed[:, : self.data_bytes_per_token]
+ flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)] = packed[:, self.data_bytes_per_token :]
+ return
+
+ def read(self, layer_index: int, loc: torch.Tensor) -> torch.Tensor:
+ loc = loc.reshape(-1)
+ if loc.numel() == 0:
+ return torch.empty((0, self.bytes_per_token), dtype=torch.uint8, device=self.buffer.device)
+ flat = self.buffer[layer_index].view(-1)
+ data_offsets, scale_offsets = self._loc_offsets(loc)
+ data_range = torch.arange(self.data_bytes_per_token, device=loc.device)
+ scale_range = torch.arange(self.scale_bytes_per_token, device=loc.device)
+ data = flat[data_offsets.unsqueeze(1) + data_range.unsqueeze(0)]
+ scale = flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)]
+ return torch.cat([data, scale], dim=1).contiguous()
+
+
+class DeepseekV4MemoryManager(MemoryManager):
+ """DeepSeek-V4 KV cache: 窗口 latent(全层) + c4/c128 压缩 latent(压实层) + c4 indexer-K。
+
+ 与兄弟 manager 一致的 token-slot 设计;req 索引的表都在 DeepseekV4ReqManager。
+
+ - ``swa_pool``: 584B packed latent,所有层。池子小于 full token 空间;prep 阶段
+ ``alloc_swa_prefill/decode`` 按**页**(128 槽,位置对齐: slot(p)=page_base+p%128)分配,
+ 映射记录到 ``full_to_swa_indexs``(以 full token 槽位为键)。出窗槽位由 DeepseekV4ReqManager
+ 在 prep 阶段批量惰性回收(``evict_swa``,页存活计数减到 0 才整页归还);full 槽位释放时
+ ``free`` 级联回收对应 swa 槽,所以 radix 驱逐/请求释放/暂停无需任何额外协议。
+ 页 allocator 触底时先走 swa free hook(radix 对 ref==0 节点 free)再 assert。
+ 没有 ring buffer,prefill chunk 大小不受 sliding_window 限制。
+ - ``c4_pool``/``c128_pool``: 压缩 latent,按 qwen3next 的层号压实手法只为压缩层建层;
+ c4 另带 packed indexer-K 池。槽位映射(``full_to_c4/c128_indexs``)以组末 token 的 full
+ 槽位为键(prep 阶段分配/scatter),``free`` 级联回收,与 swa 完全同构。
+ - 写入走标准 operator 路径(``pack_mla_kv_to_cache``),内部为 triton packed writer;
+ torch codecs 保留为 ABI 的可执行规格(单测 oracle)。
+ """
+
+ operator_class = DeepseekV4MemOperator
+
+ mla_nope_dim = DSV4_MLA_NOPE_DIM # 448
+ mla_rope_dim = DSV4_MLA_ROPE_DIM # 64
+ mla_head_dim = DSV4_MLA_HEAD_DIM # 512
+ mla_quant_group_size = DSV4_MLA_QUANT_GROUP_SIZE # 64
+ mla_scale_bytes = DSV4_MLA_SCALE_BYTES # 8
+ mla_bytes_per_token = DSV4_MLA_BYTES_PER_TOKEN # 584
+ indexer_head_dim_default = DSV4_INDEXER_HEAD_DIM # 128
+ indexer_bytes_per_token = DSV4_INDEXER_BYTES_PER_TOKEN # 132
+
+ def __init__(
+ self,
+ size,
+ dtype,
+ head_num,
+ head_dim,
+ layer_num,
+ compress_rates: List[int],
+ indexer_head_dim: int = 128,
+ max_request_num: Optional[int] = None,
+ sliding_window: Optional[int] = None,
+ swa_full_tokens_ratio: float = DSV4_SWA_FULL_TOKENS_RATIO,
+ always_copy=False,
+ mem_fraction=0.9,
+ ):
+ assert head_num == 1, "DeepSeek-V4 是 MLA(MQA),dense latent 的 head_num 必须为 1"
+ assert head_dim == self.mla_head_dim, f"DeepSeek-V4 packed KV 期望 head_dim={self.mla_head_dim}"
+ assert (
+ indexer_head_dim == self.indexer_head_dim_default
+ ), f"DeepSeek-V4 packed indexer-K 期望 indexer_head_dim={self.indexer_head_dim_default}"
+ assert len(compress_rates) == layer_num, f"compress_rates 长度 {len(compress_rates)} 必须等于 layer_num {layer_num}"
+ assert all(r in (0, 4, 128) for r in compress_rates), "compress_rates 取值只能是 0/4/128"
+
+ self.compress_rates = list(compress_rates)
+ self.n_c4 = sum(1 for r in self.compress_rates if r == 4)
+ self.n_c128 = sum(1 for r in self.compress_rates if r == 128)
+ self.indexer_head_dim = indexer_head_dim
+ self.max_request_num = max_request_num
+ self.sliding_window = sliding_window
+ self.swa_full_tokens_ratio = float(swa_full_tokens_ratio)
+
+ # 全局层号 -> 各压缩池内的压实层号(同 qwen3next 的层号压实手法)
+ self.layer_to_c4_idx = {}
+ self.layer_to_c128_idx = {}
+ c4 = c128 = 0
+ for lid, r in enumerate(self.compress_rates):
+ if r == 4:
+ self.layer_to_c4_idx[lid] = c4
+ c4 += 1
+ elif r == 128:
+ self.layer_to_c128_idx[lid] = c128
+ c128 += 1
+
+ super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
+
+ # ------------------------------------------------------------------ sizing
+ def _planned_swa_size(self, full_size: int) -> int:
+ return _ceil_div(int(full_size * self.swa_full_tokens_ratio), DSV4_SWA_PAGE_SIZE) * DSV4_SWA_PAGE_SIZE
+
+ @staticmethod
+ def _paged_state_rows(num_swa_pages: int, ring: int, ratio: int) -> int:
+ rows = num_swa_pages * ring + ring + 1
+ return _ceil_div(rows, ratio) * ratio
+
+ @staticmethod
+ def _init_state_sentinel(buffer: torch.Tensor) -> None:
+ half = buffer.shape[-1] // 2
+ buffer[:, -1, :half].zero_()
+ buffer[:, -1, half:].fill_(float("-inf"))
+ return
+
+ def get_cell_size(self):
+ kv_bytes = self.mla_bytes_per_token
+ indexer_bytes = self.indexer_bytes_per_token
+ state_dtype_bytes = torch._utils._element_size(torch.float32)
+ c4_state_width = 4 * self.head_dim + 4 * self.indexer_head_dim
+ c128_state_width = 2 * self.head_dim
+ c4_state_bytes = DSV4_C4_STATE_RING / DSV4_SWA_PAGE_SIZE * c4_state_width * state_dtype_bytes * self.n_c4
+ c128_state_bytes = (
+ DSV4_C128_STATE_RING / DSV4_SWA_PAGE_SIZE * c128_state_width * state_dtype_bytes * self.n_c128
+ )
+ swa_slot = kv_bytes * self.layer_num + c4_state_bytes + c128_state_bytes
+ compressed = (kv_bytes + indexer_bytes) * self.n_c4 / 4 + kv_bytes * self.n_c128 / 128
+
+ return swa_slot * self.swa_full_tokens_ratio + compressed
+
+ # ------------------------------------------------------------------ buffers
+ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
+ rank_in_node = get_current_rank_in_node()
+ server = get_unique_server_name()
+
+ self.swa_size = self._planned_swa_size(size)
+ self.swa_pool = PackedPagePool(
+ size=self.swa_size,
+ page_size=DSV4_SWA_PAGE_SIZE,
+ layer_num=layer_num,
+ data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN,
+ scale_bytes=self.mla_scale_bytes,
+ align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES,
+ )
+ # 注意: 该别名是 page 索引([layer, num_pages, bytes_per_page])而非 token 索引,
+ # 只允许 get_att_input_params 的消费者使用;token 索引语义的继承接口已显式 fence。
+ self.kv_buffer = self.swa_pool.buffer
+ # 页粒度分配(页 = 128 槽,位置对齐): 槽位不变式 slot(p) = page_base + p%128。
+ # swa_size 整页对齐 ⇒ HOLD 槽(swa_size)独占池子最后一个物理页,永不参与分配。
+ self.swa_num_pages = self.swa_size // DSV4_SWA_PAGE_SIZE
+ self.swa_page_allocator = KvCacheAllocator(
+ self.swa_num_pages, shared_name=f"{server}_dsv4_swa_can_use_page_num_{rank_in_node}"
+ )
+ # 页存活计数 = 指向该页的有效 full_to_swa 行数;减到 0 归还 allocator(出窗逐 token
+ # 回收下,「部分出窗页」计数 > 0 自然受保护)。下标含 HOLD 页(只读不增减)。
+ self.swa_page_live_count = torch.zeros((self.swa_pool.num_pages,), dtype=torch.int32, device="cuda")
+ # swa free hook(可选): 页 allocator 触底时回调(radix 对 ref==0 节点 free swa 页),
+ # 由 backend 在 radix cache 创建后 register;assert 仍是最后防线。
+ self._free_radix_unreferenced_swa_fn = None
+ self.full_to_swa_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda")
+ self.full_to_swa_indexs[size] = self.swa_pool.HOLD_TOKEN_MEMINDEX
+
+ self.c4_size = _ceil_div(size, 4)
+ self.c128_size = _ceil_div(size, 128)
+ self.c4_pool: Optional[PackedPagePool] = None
+ self.c4_indexer_pool: Optional[PackedPagePool] = None
+ self.c4_allocator: Optional[KvCacheAllocator] = None
+ self.c4_page_allocator: Optional[KvCacheAllocator] = None
+ self.c4_page_live_count: Optional[torch.Tensor] = None
+ self.c128_pool: Optional[PackedPagePool] = None
+ self.c128_allocator: Optional[KvCacheAllocator] = None
+ self.c4_state_buffer: Optional[torch.Tensor] = None
+ self.c4_indexer_state_buffer: Optional[torch.Tensor] = None
+ self.c128_state_buffer: Optional[torch.Tensor] = None
+ # 压缩槽映射: 键 = 组末 token(位置 (g+1)%ratio==0)的 full 槽位,值 = 压缩池槽位。
+ # 与 full_to_swa_indexs 同构: radix 持有 full 槽 => 映射行存活,free 级联回收。
+ self.full_to_c4_indexs: Optional[torch.Tensor] = None
+ self.full_to_c128_indexs: Optional[torch.Tensor] = None
+ if self.n_c4 > 0:
+ self.c4_pool = PackedPagePool(
+ size=self.c4_size,
+ page_size=DSV4_C4_PAGE_SIZE,
+ layer_num=self.n_c4,
+ data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN,
+ scale_bytes=self.mla_scale_bytes,
+ align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES,
+ )
+ self.c4_indexer_pool = PackedPagePool(
+ size=self.c4_size,
+ page_size=DSV4_C4_PAGE_SIZE,
+ layer_num=self.n_c4,
+ data_bytes=self.indexer_head_dim,
+ scale_bytes=DSV4_INDEXER_SCALE_BYTES,
+ )
+ self.c4_num_pages = self.c4_size // DSV4_C4_PAGE_SIZE
+ assert self.c4_num_pages > 0, "DeepSeek-V4 c4 pool must have at least one usable full page"
+ self.c4_page_allocator = KvCacheAllocator(
+ self.c4_num_pages, shared_name=f"{server}_dsv4_c4_can_use_page_num_{rank_in_node}"
+ )
+ self.c4_page_live_count = torch.zeros((self.c4_pool.num_pages,), dtype=torch.int32, device="cuda")
+ self.full_to_c4_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda")
+ self.full_to_c4_indexs[size] = self.c4_pool.HOLD_TOKEN_MEMINDEX
+ # c4 compressor 在途状态(attention + indexer): swa 页派生寻址(翻译③),随 swa 页
+ # 生灭 -> radix 命中零拷贝续算。行数 = 页数*ring + ring(HOLD 页) + 1(哨兵),
+ # 取整到 ratio;末行哨兵 kv=0/score=-inf(KVAndScore.clear 语义),其余行由内核在
+ # 组起点覆写,无需按页清零。last_dim = 2*coff*head_dim(overlap coff=2)。
+ state_rows = self._paged_state_rows(self.swa_num_pages, DSV4_C4_STATE_RING, 4)
+ self.c4_state_buffer = torch.zeros(
+ (self.n_c4, state_rows, 4 * self.head_dim), dtype=torch.float32, device="cuda"
+ )
+ self.c4_indexer_state_buffer = torch.zeros(
+ (self.n_c4, state_rows, 4 * self.indexer_head_dim), dtype=torch.float32, device="cuda"
+ )
+ for buf in (self.c4_state_buffer, self.c4_indexer_state_buffer):
+ self._init_state_sentinel(buf)
+ if self.n_c128 > 0:
+ self.c128_pool = PackedPagePool(
+ size=self.c128_size,
+ page_size=DSV4_C128_PAGE_SIZE,
+ layer_num=self.n_c128,
+ data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN,
+ scale_bytes=self.mla_scale_bytes,
+ align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES,
+ )
+ self.c128_allocator = KvCacheAllocator(
+ self.c128_size, shared_name=f"{server}_dsv4_c128_can_use_token_num_{rank_in_node}"
+ )
+ self.full_to_c128_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda")
+ self.full_to_c128_indexs[size] = self.c128_pool.HOLD_TOKEN_MEMINDEX
+ # c128 compressor 在途状态: 与 c4 同样由 full->swa 推导行号,但 ring=128 且无 overlap。
+ # last_dim = 2*head_dim;末行是 swa 缺失/出窗时读取的哨兵。
+ state_rows = self._paged_state_rows(self.swa_num_pages, DSV4_C128_STATE_RING, 128)
+ self.c128_state_buffer = torch.zeros(
+ (self.n_c128, state_rows, 2 * self.head_dim), dtype=torch.float32, device="cuda"
+ )
+ self._init_state_sentinel(self.c128_state_buffer)
+
+ logger.info(
+ f"DeepseekV4MemoryManager pools: full_tokens={size} swa={self.swa_size}({self.swa_num_pages}p) "
+ f"c4={self.c4_size}(L={self.n_c4}) c128={self.c128_size}(L={self.n_c128}) "
+ f"packed_kv_bytes={self.mla_bytes_per_token} indexer_bytes={self.indexer_bytes_per_token}"
+ )
+
+ # ------------------------------------------------------------------ buffer accessors
+ def get_att_input_params(self, layer_index: int):
+ return self.swa_pool.get_layer_buffer(layer_index)
+
+ def _pool_and_local_layer(self, layer_index: int):
+ r = self.compress_rates[layer_index]
+ if r == 4:
+ return self.c4_pool, self.layer_to_c4_idx[layer_index]
+ if r == 128:
+ return self.c128_pool, self.layer_to_c128_idx[layer_index]
+ raise AssertionError(f"layer {layer_index} (rate {r}) 不是压缩层,没有压缩池")
+
+ def get_compressed_kv_buffer(self, layer_index: int) -> torch.Tensor:
+ pool, local_layer = self._pool_and_local_layer(layer_index)
+ return pool.get_layer_buffer(local_layer)
+
+ def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor:
+ assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 indexer-K"
+ return self.c4_indexer_pool.get_layer_buffer(self.layer_to_c4_idx[layer_index])
+
+ def get_c4_state_buffer(self, layer_index: int) -> torch.Tensor:
+ assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 paged compressor state"
+ return self.c4_state_buffer[self.layer_to_c4_idx[layer_index]]
+
+ def get_c4_indexer_state_buffer(self, layer_index: int) -> torch.Tensor:
+ assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 paged indexer state"
+ return self.c4_indexer_state_buffer[self.layer_to_c4_idx[layer_index]]
+
+ def get_c128_state_buffer(self, layer_index: int) -> torch.Tensor:
+ assert self.compress_rates[layer_index] == 128, "只有 c128(HCA) 层有 paged compressor state"
+ return self.c128_state_buffer[self.layer_to_c128_idx[layer_index]]
+
+ # ------------------------------------------------------------------ swa slot lifecycle
+ def register_swa_free_hook(self, fn) -> None:
+ """fn(need_pages): 在页 allocator 不足时尝试腾页(radix 对 ref==0 节点 free swa)。"""
+ self._free_radix_unreferenced_swa_fn = fn
+ return
+
+ def _alloc_swa_pages(self, need_pages: int) -> torch.Tensor:
+ if need_pages > self.swa_page_allocator.can_use_mem_size and self._free_radix_unreferenced_swa_fn is not None:
+ self._free_radix_unreferenced_swa_fn(need_pages - self.swa_page_allocator.can_use_mem_size)
+ return self.swa_page_allocator.alloc(need_pages)
+
+ def _count_swa_pages(self, swa_slots: torch.Tensor, delta: int) -> torch.Tensor:
+ """按 slot 所在页更新存活计数,返回触达的页(去重)。"""
+ pages = torch.div(swa_slots.long(), DSV4_SWA_PAGE_SIZE, rounding_mode="floor")
+ ones = torch.full(pages.shape, delta, dtype=torch.int32, device=pages.device)
+ self.swa_page_live_count.index_add_(0, pages, ones)
+ return torch.unique(pages)
+
+ def alloc_swa_prefill(
+ self,
+ b_req_idx: torch.Tensor,
+ b_ready_cache_len: torch.Tensor,
+ b_seq_len: torch.Tensor,
+ req_to_token_indexs: torch.Tensor,
+ b_req_idx_cpu: torch.Tensor,
+ b_ready_cache_len_cpu: torch.Tensor,
+ b_seq_len_cpu: torch.Tensor,
+ ) -> None:
+ """prefill prep: 为各请求位置 [ready, seq) 的新 token 分配位置对齐的 swa 槽。
+
+ 槽位不变式: slot(p) = page_base(p 所在页) + p%128,page_base % 128 == 0。
+ 续页(start 非整页,只可能是首页)的 base 从上一 token 的映射派生
+ (full_to_swa[req_to_token[req, start-1]],该 token 必在保留窗内);其余页全新分配。
+ radix 命中(ready 必 128 对齐)的借用方从全新页开始,与节点持有页天然不相交。
+ 必须在 init_req_to_token_indexes 之后调用(scatter 目标经 req_to_token 行)。
+ """
+ page = DSV4_SWA_PAGE_SIZE
+ hold_req_id = self.max_request_num # padding 行的请求 id(req_manager.HOLD_REQUEST_ID)
+
+ req_list = b_req_idx_cpu.tolist()
+ ready_list = b_ready_cache_len_cpu.tolist()
+ seq_list = b_seq_len_cpu.tolist()
+
+ segs = [] # (req_idx, start, end, n_new_pages, has_cont_page)
+ total_new_pages = 0
+ for req_idx, start, end in zip(req_list, ready_list, seq_list):
+ req_idx, start, end = int(req_idx), int(start), int(end)
+ if req_idx == hold_req_id or end <= start:
+ continue
+ first_new_page = _ceil_div(start, page)
+ n_new = max(0, (end - 1) // page - first_new_page + 1)
+ segs.append((req_idx, start, end, n_new, start % page != 0))
+ total_new_pages += n_new
+ if not segs:
+ return
+
+ new_pages = self._alloc_swa_pages(total_new_pages).cuda(non_blocking=True).long() if total_new_pages else None
+ page_cursor = 0
+ for req_idx, start, end, n_new, has_cont in segs:
+ positions = torch.arange(start, end, dtype=torch.long, device="cuda")
+ page_local = torch.div(positions, page, rounding_mode="floor") - start // page
+ bases = torch.empty(((end - 1) // page - start // page + 1,), dtype=torch.long, device="cuda")
+ if has_cont:
+ prev_slot = int(self.full_to_swa_indexs[req_to_token_indexs[req_idx, start - 1].long()].item())
+ # 续页不变式: 上一 token 必驻留(retain >= 2)且位置对齐(未来 resume/MTP 改动的哨兵)。
+ assert prev_slot >= 0 and prev_slot % page == (start - 1) % page
+ bases[0] = prev_slot - (start - 1) % page
+ if n_new:
+ bases[1 if has_cont else 0 :] = new_pages[page_cursor : page_cursor + n_new] * page
+ page_cursor += n_new
+ slots = (bases[page_local] + positions % page).to(torch.int32)
+ self.full_to_swa_indexs[req_to_token_indexs[req_idx, start:end].long()] = slots
+ self._count_swa_pages(slots, 1)
+ return
+
+ def alloc_swa_decode(
+ self,
+ b_req_idx_cpu: torch.Tensor,
+ b_seq_len_cpu: torch.Tensor,
+ mem_indexes: torch.Tensor,
+ req_to_token_indexs: torch.Tensor,
+ ) -> None:
+ """decode prep: 本步 token(位置 seq-1)的 swa 槽。整页起点开新页,否则上一 token 槽 +1
+ (位置对齐不变式保证同页连续)。scatter 目标用 mem_indexes(此刻 req_to_token 尚未写本步)。
+
+ 注意: 续槽从上一位置的映射派生,故同一请求的多行(MTP 多 token/步)在同一批内不支持
+ (DSV4 启动参数已拒绝 MTP;支持需按步内顺序分段派生)。"""
+ page = DSV4_SWA_PAGE_SIZE
+ hold_req_id = self.max_request_num
+ req_list = b_req_idx_cpu.tolist()
+ seq_list = b_seq_len_cpu.tolist()
+ cont_rows, cont_prev_pos, new_rows = [], [], []
+ for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)):
+ req_idx, seq_len = int(req_idx), int(seq_len)
+ if req_idx == hold_req_id or seq_len <= 0:
+ continue
+ if (seq_len - 1) % page == 0:
+ new_rows.append(i)
+ else:
+ cont_rows.append(i)
+ cont_prev_pos.append(seq_len - 2)
+ mem_indexes = mem_indexes.cuda().long().reshape(-1)
+ if cont_rows:
+ req_rows = torch.tensor([req_list[i] for i in cont_rows], dtype=torch.long, device="cuda")
+ prev_full = req_to_token_indexs[req_rows, torch.tensor(cont_prev_pos, device="cuda")].long()
+ prev_slots = self.full_to_swa_indexs[prev_full]
+ # 续槽不变式哨兵: 上一位置必驻留(retain 覆盖)。prep 阶段本就有同步,代价可忽略。
+ assert bool((prev_slots >= 0).all())
+ slots = prev_slots + 1
+ self.full_to_swa_indexs[mem_indexes[cont_rows]] = slots
+ self._count_swa_pages(slots, 1)
+ if new_rows:
+ pages = self._alloc_swa_pages(len(new_rows)).cuda(non_blocking=True).long()
+ slots = (pages * page).to(torch.int32)
+ self.full_to_swa_indexs[mem_indexes[new_rows]] = slots
+ self._count_swa_pages(slots, 1)
+ return
+
+ def evict_swa(self, full_slots: torch.Tensor) -> None:
+ """回收 full 槽位对应的 swa 槽(出窗惰性回收 / free 级联 / 压力阀共用)。
+ 未映射(-1)的槽位跳过;页计数减到 0 时整页归还 allocator。"""
+ if full_slots.numel() == 0:
+ return
+ full_slots = full_slots.cuda().long().reshape(-1)
+ full_slots = torch.unique(full_slots[full_slots != self.HOLD_TOKEN_MEMINDEX])
+ if full_slots.numel() == 0:
+ return
+ swa_slots = self.full_to_swa_indexs[full_slots]
+ valid = swa_slots >= 0
+ valid_slots = swa_slots[valid]
+ if valid_slots.numel() == 0:
+ return
+ self.full_to_swa_indexs[full_slots[valid]] = -1
+ touched = self._count_swa_pages(valid_slots, -1)
+ empty = touched[self.swa_page_live_count[touched] == 0]
+ if empty.numel() > 0:
+ self.swa_page_allocator.free(empty.to(torch.int32))
+ return
+
+ def _evict_compress(self, full_slots: torch.Tensor, mapping: torch.Tensor, allocator: KvCacheAllocator) -> None:
+ full_slots = full_slots.cuda().long().reshape(-1)
+ # 去重: 同批重复槽会 gather 出重复的压缩槽 -> allocator 双重释放(free 已去重,直呼叫方防御)。
+ full_slots = torch.unique(full_slots[full_slots != self.HOLD_TOKEN_MEMINDEX])
+ if full_slots.numel() == 0:
+ return
+ slots = mapping[full_slots]
+ valid = slots >= 0
+ valid_slots = slots[valid]
+ if valid_slots.numel() == 0:
+ return
+ allocator.free(valid_slots)
+ mapping[full_slots[valid]] = -1
+ return
+
+ def alloc_c4_pages(self, need_pages: int) -> torch.Tensor:
+ assert self.c4_page_allocator is not None, "DeepSeek-V4 c4 page allocator is not initialized"
+ return self.c4_page_allocator.alloc(need_pages)
+
+ def count_c4_slots(self, c4_slots: torch.Tensor, delta: int) -> torch.Tensor:
+ """按 c4 slot 所在页更新存活计数,返回触达的页(去重)。"""
+ assert self.c4_page_live_count is not None, "DeepSeek-V4 c4 page live count is not initialized"
+ pages = torch.div(c4_slots.long(), DSV4_C4_PAGE_SIZE, rounding_mode="floor")
+ ones = torch.full(pages.shape, delta, dtype=torch.int32, device=pages.device)
+ self.c4_page_live_count.index_add_(0, pages, ones)
+ return torch.unique(pages)
+
+ def evict_c4(self, full_slots: torch.Tensor) -> None:
+ """回收 full 槽位(组末 token)映射的 c4 槽。非组末/未映射(-1)的槽位跳过。"""
+ if self.c4_page_allocator is None or full_slots.numel() == 0:
+ return
+ full_slots = full_slots.cuda().long().reshape(-1)
+ full_slots = torch.unique(full_slots[full_slots != self.HOLD_TOKEN_MEMINDEX])
+ if full_slots.numel() == 0:
+ return
+ slots = self.full_to_c4_indexs[full_slots]
+ valid = slots >= 0
+ valid_slots = slots[valid]
+ if valid_slots.numel() == 0:
+ return
+ self.full_to_c4_indexs[full_slots[valid]] = -1
+ touched = self.count_c4_slots(valid_slots, -1)
+ empty = touched[self.c4_page_live_count[touched] == 0]
+ if empty.numel() > 0:
+ self.c4_page_allocator.free(empty.to(torch.int32))
+ return
+
+ def evict_c128(self, full_slots: torch.Tensor) -> None:
+ """回收 full 槽位(组末 token)映射的 c128 槽。非组末/未映射(-1)的槽位跳过。"""
+ if self.c128_allocator is None or full_slots.numel() == 0:
+ return
+ self._evict_compress(full_slots, self.full_to_c128_indexs, self.c128_allocator)
+ return
+
+ # ------------------------------------------------------------------ alloc/free (cascade)
+ def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
+ """释放 full token 槽位,级联回收其 swa 槽与 c4/c128 压缩槽。radix 驱逐、请求释放/暂停都走这里。
+
+ 先对 full 槽去重: 同批重复槽位会让映射 gather 出重复的压缩/swa 槽,导致 allocator 双重释放。"""
+ if isinstance(free_index, list):
+ free_index = torch.tensor(free_index, dtype=torch.int64)
+ if free_index.numel() > 0:
+ free_index = torch.unique(free_index)
+ self.evict_swa(free_index)
+ self.evict_c4(free_index)
+ self.evict_c128(free_index)
+ super().free(free_index)
+ return
+
+ def free_all(self):
+ super().free_all()
+ self.swa_page_allocator.free_all()
+ self.swa_page_live_count.zero_()
+ self.full_to_swa_indexs.fill_(-1)
+ self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX
+ if self.c4_page_allocator is not None:
+ self.c4_page_allocator.free_all()
+ self.c4_page_live_count.zero_()
+ self.full_to_c4_indexs.fill_(-1)
+ self.full_to_c4_indexs[self.HOLD_TOKEN_MEMINDEX] = self.c4_pool.HOLD_TOKEN_MEMINDEX
+ if self.c128_allocator is not None:
+ self.c128_allocator.free_all()
+ self.full_to_c128_indexs.fill_(-1)
+ self.full_to_c128_indexs[self.HOLD_TOKEN_MEMINDEX] = self.c128_pool.HOLD_TOKEN_MEMINDEX
+ return
+
+ def alloc_c4(self, need_size) -> torch.Tensor:
+ raise AssertionError("DeepSeek-V4 c4 uses page-safe allocation; call alloc_c4_pages instead")
+
+ def alloc_c128(self, need_size) -> torch.Tensor:
+ return self.c128_allocator.alloc(need_size)
+
+ def free_c4(self, free_index) -> None:
+ raise AssertionError("DeepSeek-V4 c4 uses page live-count release; call evict_c4 instead")
+
+ def free_c128(self, free_index) -> None:
+ self.c128_allocator.free(free_index)
+
+ # ------------------------------------------------------------------ packed codecs (torch reference)
+ # 与 sglang/vllm 的 fp8_ds_mla 字节布局逐位对齐(ue8m0 幂次 scale)。这些 torch 实现是该 ABI 的
+ # 可执行规格(单测 oracle,triton writer 与其逐字节对拍),不可删除。
+ def _pack_mla_kv(self, kv: torch.Tensor) -> torch.Tensor:
+ kv = kv.reshape(-1, self.mla_head_dim)
+ out = torch.empty((kv.shape[0], self.mla_bytes_per_token), dtype=torch.uint8, device=kv.device)
+ nope = kv[:, : self.mla_nope_dim].float().reshape(-1, self.mla_scale_bytes - 1, self.mla_quant_group_size)
+ scale = torch.clamp(nope.abs().amax(dim=-1) / DSV4_FP8_E4M3_MAX, min=DSV4_FP8_SCALE_MIN)
+ scale_exp = torch.ceil(torch.log2(scale)).to(torch.int32)
+ scale = torch.exp2(scale_exp.float())
+ nope_fp8 = torch.clamp(nope / scale.unsqueeze(-1), -DSV4_FP8_E4M3_MAX, DSV4_FP8_E4M3_MAX).to(
+ torch.float8_e4m3fn
+ )
+ out[:, : self.mla_nope_dim].copy_(nope_fp8.reshape(-1, self.mla_nope_dim).view(dtype=torch.uint8))
+ rope_start = self.mla_nope_dim
+ rope_end = rope_start + self.mla_rope_dim * 2
+ rope = kv[:, self.mla_nope_dim : self.mla_head_dim].contiguous().to(torch.bfloat16)
+ out[:, rope_start:rope_end].copy_(rope.view(dtype=torch.uint8).reshape(-1, self.mla_rope_dim * 2))
+ scale_start = rope_end
+ scale_end = scale_start + self.mla_scale_bytes - 1
+ out[:, scale_start:scale_end].copy_((scale_exp + 127).to(torch.uint8))
+ out[:, scale_end].zero_()
+ return out
+
+ def _unpack_mla_kv(self, packed: torch.Tensor) -> torch.Tensor:
+ packed = packed.reshape(-1, self.mla_bytes_per_token)
+ if packed.shape[0] == 0:
+ return torch.empty((0, self.mla_head_dim), dtype=self.dtype, device=packed.device)
+ nope_fp8 = packed[:, : self.mla_nope_dim].view(dtype=torch.float8_e4m3fn).float()
+ nope_fp8 = nope_fp8.reshape(-1, self.mla_scale_bytes - 1, self.mla_quant_group_size)
+ rope_start = self.mla_nope_dim
+ rope_end = rope_start + self.mla_rope_dim * 2
+ scale_start = rope_end
+ scale_end = scale_start + self.mla_scale_bytes - 1
+ scale_exp = packed[:, scale_start:scale_end].to(torch.int32) - 127
+ scale = torch.exp2(scale_exp.float())
+ nope = (nope_fp8 * scale.reshape(-1, self.mla_scale_bytes - 1, 1)).reshape(-1, self.mla_nope_dim)
+ rope = packed[:, rope_start:rope_end].view(dtype=torch.bfloat16)
+ return torch.cat([nope.to(self.dtype), rope.to(self.dtype)], dim=-1)
+
+ def _pack_indexer_k(self, indexer_k: torch.Tensor) -> torch.Tensor:
+ indexer_k = indexer_k.reshape(-1, self.indexer_head_dim)
+ out = torch.empty(
+ (indexer_k.shape[0], self.indexer_bytes_per_token),
+ dtype=torch.uint8,
+ device=indexer_k.device,
+ )
+ k_float = indexer_k.float()
+ scale = torch.clamp(
+ k_float.abs().amax(dim=-1, keepdim=True) / DSV4_FP8_E4M3_MAX,
+ min=DSV4_FP8_SCALE_MIN,
+ )
+ k_fp8 = torch.clamp(k_float / scale, -DSV4_FP8_E4M3_MAX, DSV4_FP8_E4M3_MAX).to(torch.float8_e4m3fn)
+ out[:, : self.indexer_head_dim].copy_(k_fp8.view(dtype=torch.uint8))
+ out[:, self.indexer_head_dim :].copy_(scale.view(dtype=torch.uint8).reshape(-1, DSV4_INDEXER_SCALE_BYTES))
+ return out
+
+ def _unpack_indexer_k(self, packed: torch.Tensor) -> torch.Tensor:
+ packed = packed.reshape(-1, self.indexer_bytes_per_token)
+ if packed.shape[0] == 0:
+ return torch.empty((0, self.indexer_head_dim), dtype=self.dtype, device=packed.device)
+ k_fp8 = packed[:, : self.indexer_head_dim].view(dtype=torch.float8_e4m3fn).float()
+ scale = packed[:, self.indexer_head_dim :].view(dtype=torch.float32)
+ return (k_fp8 * scale).to(self.dtype)
+
+ # ------------------------------------------------------------------ cache write paths
+ def pack_mla_kv_to_cache(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
+ """标准 operator 写入路径。要求本步已对 mem_index 调过 ``alloc_swa``(prep 阶段);
+ HOLD/padding 槽位映射到 swa HOLD 槽,写入无害。"""
+ if kv.shape[0] == 0:
+ return
+ from lightllm.models.deepseek_v4.triton_kernel.destindex_copy_kv_flashmla_dsv4 import (
+ destindex_copy_kv_flashmla_dsv4,
+ )
+
+ swa_slots = self.full_to_swa_indexs[mem_index.cuda().long().reshape(-1)]
+ destindex_copy_kv_flashmla_dsv4(
+ kv.reshape(-1, self.mla_head_dim),
+ swa_slots,
+ self.swa_pool.get_layer_buffer(layer_index),
+ self.swa_pool.page_size,
+ )
+ return
+
+ def pack_mla_kv_to_cache_fused_norm_rope(
+ self,
+ layer_index: int,
+ mem_index: torch.Tensor,
+ kv: torch.Tensor,
+ kv_weight: torch.Tensor,
+ eps: float,
+ freqs_cis: torch.Tensor,
+ positions: torch.Tensor,
+ ):
+ """同 pack_mla_kv_to_cache,但 rmsnorm + 尾部交错 rope 融合进写入 kernel
+ (sglang fused_k_norm_rope_flashmla,即 sglang _compute_kv_to_cache 的池侧),
+ 省掉 bf16 kv 中间量。kv 为 wkv 投影原始输出 [T, head_dim+rope_dim]。"""
+ if kv.shape[0] == 0:
+ return
+ from lightllm.third_party.sglang_jit.dsv4 import fused_k_norm_rope_flashmla
+
+ swa_slots = self.full_to_swa_indexs[mem_index.cuda().long().reshape(-1)]
+ # 未映射槽位(-1, 如 decode 图 warmup 的 HOLD 行: prep 跳过 alloc_swa)对老 triton
+ # 写入核是显式 no-op;sglang fused 核无负槽位防护(负页偏移=非法访存),mask 到
+ # swa HOLD 槽(垃圾桶语义,与 padding 行写入一致)。
+ swa_slots = torch.where(swa_slots < 0, torch.full_like(swa_slots, self.swa_pool.HOLD_TOKEN_MEMINDEX), swa_slots)
+ fused_k_norm_rope_flashmla(
+ kv=kv,
+ kv_weight=kv_weight,
+ eps=eps,
+ freqs_cis=freqs_cis,
+ positions=positions,
+ out_loc=swa_slots,
+ kvcache=self.swa_pool.get_layer_buffer(layer_index),
+ page_size=self.swa_pool.page_size,
+ )
+ return
+
+ def pack_compressed_kv_to_cache(self, layer_index: int, slots: torch.Tensor, comp: torch.Tensor):
+ if comp.shape[0] == 0:
+ return
+ from lightllm.models.deepseek_v4.triton_kernel.destindex_copy_kv_flashmla_dsv4 import (
+ destindex_copy_kv_flashmla_dsv4,
+ )
+
+ pool, local_layer = self._pool_and_local_layer(layer_index)
+ destindex_copy_kv_flashmla_dsv4(
+ comp.reshape(-1, self.mla_head_dim),
+ slots.to(comp.device),
+ pool.get_layer_buffer(local_layer),
+ pool.page_size,
+ )
+
+ def pack_indexer_k_to_cache(self, layer_index: int, slots: torch.Tensor, indexer_k: torch.Tensor):
+ if indexer_k.shape[0] == 0:
+ return
+ assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 indexer-K"
+ from lightllm.models.deepseek_v4.triton_kernel.destindex_copy_indexer_k_dsv4 import (
+ destindex_copy_indexer_k_dsv4,
+ )
+
+ destindex_copy_indexer_k_dsv4(
+ indexer_k.reshape(-1, self.indexer_head_dim),
+ slots.to(indexer_k.device),
+ self.c4_indexer_pool.get_layer_buffer(self.layer_to_c4_idx[layer_index]),
+ self.c4_indexer_pool.page_size,
+ )
+
+ def gather_indexer_k(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor:
+ """反量化 gather c4 indexer-K: slots [N](c4 槽位,HOLD 合法) -> [N, indexer_head_dim] bf16。
+ indexer top-k 打分用(纯张量操作,cuda-graph 安全)。"""
+ assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 indexer-K"
+ pool = self.c4_indexer_pool
+ flat = pool.get_layer_buffer(self.layer_to_c4_idx[layer_index]).view(-1)
+ data_offsets, scale_offsets = pool._loc_offsets(slots.reshape(-1))
+ data_range = torch.arange(pool.data_bytes_per_token, device=flat.device)
+ scale_range = torch.arange(pool.scale_bytes_per_token, device=flat.device)
+ k_fp8 = flat[data_offsets.unsqueeze(1) + data_range.unsqueeze(0)].view(torch.float8_e4m3fn)
+ scale = flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)].contiguous().view(torch.float32)
+ return (k_fp8.float() * scale).to(torch.bfloat16)
+
+ # ------------------------------------------------------------------ fenced inherited APIs
+ # kv_buffer 是 page 索引的 uint8 slab,基类按 token 索引读写的接口会静默写坏数据,显式拦截。
+ def get_index_kv_buffer(self, index):
+ raise NotImplementedError("DeepSeek-V4 packed page-slab cache does not support token-indexed kv_buffer io")
+
+ def load_index_kv_buffer(self, index, load_tensor_dict):
+ raise NotImplementedError("DeepSeek-V4 packed page-slab cache does not support token-indexed kv_buffer io")
+
+ def alloc_kv_move_buffer(self, max_req_total_len):
+ raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented")
+
+ def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
+ raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented")
+
+ def write_mem_to_page_kv_move_buffer(self, *args, **kwargs):
+ raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented")
+
+ def read_page_kv_move_buffer_to_mem(self, *args, **kwargs):
+ raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented")
+
+ def send_to_decode_node(self, *args, **kwargs):
+ raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented")
+
+ def receive_from_prefill_node(self, *args, **kwargs):
+ raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented")
+
+ def send_to_decode_node_p2p(self, *args, **kwargs):
+ raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented")
+
+ def receive_from_prefill_node_p2p(self, *args, **kwargs):
+ raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented")
diff --git a/lightllm/common/kv_cache_mem_manager/operator/__init__.py b/lightllm/common/kv_cache_mem_manager/operator/__init__.py
index 85c37ad39..442c2e300 100644
--- a/lightllm/common/kv_cache_mem_manager/operator/__init__.py
+++ b/lightllm/common/kv_cache_mem_manager/operator/__init__.py
@@ -5,6 +5,7 @@
from .deepseek import (
Deepseek2MemOperator,
Deepseek3_2MemOperator,
+ DeepseekV4MemOperator,
FP8PerTokenGroupQuantDeepseek3_2MemOperator,
)
from .fp8_quant import (
diff --git a/lightllm/common/kv_cache_mem_manager/operator/deepseek.py b/lightllm/common/kv_cache_mem_manager/operator/deepseek.py
index 6e05b96e1..0725ce9b9 100644
--- a/lightllm/common/kv_cache_mem_manager/operator/deepseek.py
+++ b/lightllm/common/kv_cache_mem_manager/operator/deepseek.py
@@ -8,7 +8,9 @@
class Deepseek2MemOperator(NormalMemOperator):
def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
- from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager
+ from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import (
+ Deepseek2MemoryManager,
+ )
mem_manager: Deepseek2MemoryManager = self.mem_manager
@@ -30,7 +32,9 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv:
class Deepseek3_2MemOperator(Deepseek2MemOperator):
def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
- from lightllm.common.kv_cache_mem_manager.deepseek3_2mem_manager import Deepseek3_2MemoryManager
+ from lightllm.common.kv_cache_mem_manager.deepseek3_2mem_manager import (
+ Deepseek3_2MemoryManager,
+ )
mem_manager: Deepseek3_2MemoryManager = self.mem_manager
from ...basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv
@@ -78,3 +82,14 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv:
o_rope,
)
return
+
+
+class DeepseekV4MemOperator(BaseMemManagerOperator):
+ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
+ from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import (
+ DeepseekV4MemoryManager,
+ )
+
+ mem_manager: DeepseekV4MemoryManager = self.mem_manager
+ mem_manager.pack_mla_kv_to_cache(layer_index, mem_index, kv)
+ return
diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py
index cd534d53e..4db55e155 100644
--- a/lightllm/common/quantization/__init__.py
+++ b/lightllm/common/quantization/__init__.py
@@ -14,6 +14,7 @@
EXPERT_DTYPE_TO_QUANT_TYPE = {
"fp8": "deepgemm-fp8w8a8-b128",
"fp4": "deepgemm-fp4fp8-b32",
+ "mxfp4": "marlin-mxfp4w4a16-b32",
}
SUPPORTED_EXPERT_DTYPES = tuple(EXPERT_DTYPE_TO_QUANT_TYPE)
@@ -64,10 +65,13 @@ def _mapping_quant_method(self):
logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}")
# fp8 量化下,部分 MoE 模型(如 DeepSeek-V4),可以单独声明 expert 权重精度,
- # 按其值给 fused_moe 选用对应的 deepgemm 量化方法。
+ # 按其值给 fused_moe 选用对应的量化方法。
expert_dtype = self.expert_dtype or self.network_config_.get("expert_dtype", None)
if expert_dtype is None:
return
+ # DeepSeek-V4 的 fp4 发布版自带预打包 MXFP4 专家。
+ if expert_dtype == "fp4" and self.network_config_.get("model_type") == "deepseek_v4":
+ expert_dtype = "mxfp4"
target = self._get_expert_quant_type(expert_dtype)
for layer_num in range(self.layer_num):
if self.expert_dtype is not None:
diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py
index ec1ee90fd..bedf22ee9 100644
--- a/lightllm/common/quantization/deepgemm.py
+++ b/lightllm/common/quantization/deepgemm.py
@@ -198,6 +198,144 @@ def _create_weight(
return mm_param, mm_param_list
+@QUANTMETHODS.register(["marlin-mxfp4w4a16-b32"], platform="cuda")
+class MXFP4MoEQuantizationMethod(QuantizationMethod):
+ def __init__(self):
+ super().__init__()
+ self.block_size = 32
+ self.weight_suffix = "weight"
+ self.weight_zero_point_suffix = None
+ self.weight_scale_suffix = "scale"
+ self.has_weight_scale = True
+ self.has_weight_zero_point = False
+
+ @property
+ def method_name(self):
+ return "marlin-mxfp4w4a16-b32"
+
+ def quantize(self, weight: torch.Tensor, output: WeightPack):
+ raise NotImplementedError("marlin-mxfp4w4a16-b32 only loads pre-packed MXFP4 expert weights")
+
+ def apply(
+ self,
+ input_tensor: torch.Tensor,
+ weight_pack: "WeightPack",
+ out: Optional[torch.Tensor] = None,
+ workspace: Optional[torch.Tensor] = None,
+ use_custom_tensor_mananger: bool = True,
+ bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ raise NotImplementedError("marlin-mxfp4w4a16-b32 is only implemented for fused MoE expert weights")
+
+ def _probe_marlin_layout(self, size_n: int, size_k: int, dtype: torch.dtype, device_id: int):
+ """用零输入走一遍真实的 per-expert repack 路径,探出 marlin 终态布局的形状与类型。
+ 只调用 finalize 同款的 vllm 函数,不复刻其内部公式,杜绝形状漂移。结果按维度缓存
+ (各 MoE 层同维,全程只探两次: w13 一次、w2 一次)。"""
+ cache_key = (size_n, size_k, dtype)
+ cache = getattr(self, "_marlin_layout_cache", None)
+ if cache is None:
+ cache = self._marlin_layout_cache = {}
+ if cache_key in cache:
+ return cache[cache_key]
+
+ import vllm._custom_ops as ops
+ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+ get_marlin_input_dtype,
+ marlin_permute_scales,
+ )
+ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+ mxfp4_marlin_process_scales,
+ )
+
+ input_dtype = get_marlin_input_dtype()
+ is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
+ device = f"cuda:{device_id}"
+ qweight = torch.zeros((size_n, size_k // 2), dtype=torch.int8, device=device).view(torch.int32).T.contiguous()
+ marlin_qweight = ops.gptq_marlin_repack(
+ b_q_weight=qweight,
+ perm=torch.empty(0, dtype=torch.int, device=device),
+ size_k=size_k,
+ size_n=size_n,
+ num_bits=4,
+ is_a_8bit=is_a_8bit,
+ )
+ scale = torch.zeros((size_k // self.block_size, size_n), dtype=dtype, device=device)
+ marlin_scale = marlin_permute_scales(
+ s=scale, size_k=size_k, size_n=size_n, group_size=self.block_size, is_a_8bit=is_a_8bit
+ )
+ marlin_scale = mxfp4_marlin_process_scales(marlin_scale, input_dtype=input_dtype)
+ layout = (
+ (tuple(marlin_qweight.shape), marlin_qweight.dtype),
+ (tuple(marlin_scale.shape), marlin_scale.dtype),
+ )
+ cache[cache_key] = layout
+ return layout
+
+ def _create_weight(
+ self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1
+ ) -> Tuple[WeightPack, List[WeightPack]]:
+ out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims
+ assert in_dim % self.block_size == 0, "MXFP4 scale dimension must be divisible by block_size"
+ expert_prefix = (num_experts,) if num_experts > 1 else ()
+ # CPU 暂存区: load_hf_weights 灌入原始预打包 MXFP4,finalize 时 repack 进 CUDA 终态。
+ weight = torch.empty(expert_prefix + (out_dim, in_dim // 2), dtype=torch.int8, device="cpu")
+ weight_scale = torch.empty(
+ expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.float8_e8m0fnu, device="cpu"
+ )
+ mm_param = WeightPack(weight=weight, weight_scale=weight_scale)
+ # CUDA 终态(marlin 布局)在构造期物化,使 mem manager 的 profile 看到真实权重占用
+ # ("构造即分配、load 只灌数"的框架契约,与其它 quant 方法一致;惰性到 finalize 才
+ # 进卡会让空卡 profile 把 kv 池撑到挤爆权重加载)。finalize 时 repack 结果拷入。
+ (w_shape, w_dtype), (s_shape, s_dtype) = self._probe_marlin_layout(out_dim, in_dim, dtype, device_id)
+ mm_param.marlin_weight = torch.empty((num_experts,) + w_shape, dtype=w_dtype, device=f"cuda:{device_id}")
+ mm_param.marlin_weight_scale = torch.empty((num_experts,) + s_shape, dtype=s_dtype, device=f"cuda:{device_id}")
+ mm_param_list = self._split_weight_pack(
+ mm_param,
+ weight_out_dims=out_dims,
+ weight_split_dim=-2,
+ weight_scale_out_dims=out_dims,
+ weight_scale_split_dim=-2,
+ )
+ return mm_param, mm_param_list
+
+ def finalize_moe_weight(self, moe_weight):
+ try:
+ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+ prepare_moe_mxfp4_layer_for_marlin,
+ )
+ except Exception as e:
+ raise RuntimeError(f"marlin-mxfp4w4a16-b32 requires vLLM MXFP4 packing utilities, error={repr(e)}") from e
+
+ class _MXFP4Layer:
+ pass
+
+ device = torch.device("cuda", moe_weight.device_id_)
+ layer = _MXFP4Layer()
+ layer.params_dtype = moe_weight.data_type_
+ w13 = moe_weight.w13.weight.view(torch.uint8).to(device=device, non_blocking=True).contiguous()
+ w2 = moe_weight.w2.weight.view(torch.uint8).to(device=device, non_blocking=True).contiguous()
+ w13_scale = moe_weight.w13.weight_scale.to(device=device, non_blocking=True).contiguous()
+ w2_scale = moe_weight.w2.weight_scale.to(device=device, non_blocking=True).contiguous()
+ (
+ w13_new,
+ w2_new,
+ w13_scale_new,
+ w2_scale_new,
+ _,
+ _,
+ ) = prepare_moe_mxfp4_layer_for_marlin(layer, w13, w2, w13_scale, w2_scale, None, None)
+ # repack 结果拷入构造期预分配的 marlin 终态 buffer(与 AWQ marlin 路径同形态),
+ # CPU 暂存与 repack 临时随引用释放;shape 失配会在 copy_ 处显式报错(探针保证一致)。
+ moe_weight.w13.marlin_weight.copy_(w13_new)
+ moe_weight.w13.marlin_weight_scale.copy_(w13_scale_new)
+ moe_weight.w2.marlin_weight.copy_(w2_new)
+ moe_weight.w2.marlin_weight_scale.copy_(w2_scale_new)
+ moe_weight.w13.weight = moe_weight.w13.marlin_weight
+ moe_weight.w13.weight_scale = moe_weight.w13.marlin_weight_scale
+ moe_weight.w2.weight = moe_weight.w2.marlin_weight
+ moe_weight.w2.weight_scale = moe_weight.w2.marlin_weight_scale
+
+
def _deepgemm_fp8_nt(a_tuple, b_tuple, out):
if HAS_DEEPGEMM:
if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"):
diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py
index 01e9c4ad3..1af4991a3 100644
--- a/lightllm/common/req_manager.py
+++ b/lightllm/common/req_manager.py
@@ -1,17 +1,26 @@
import torch
import collections
+from dataclasses import dataclass
from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig
from lightllm.utils.log_utils import init_logger
-from .kv_cache_mem_manager import MemoryManager
+from .kv_cache_mem_manager import MemoryManager, DeepseekV4MemoryManager
from typing import List, Optional, TYPE_CHECKING
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter
-from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
+from lightllm.common.basemodel.triton_kernel.gen_sampling_params import (
+ update_req_to_token_id_counter,
+)
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
from lightllm.utils.config_utils import get_vocab_size
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
from lightllm.common.linear_att_cache_manager.layer_cache import LayerCache
-from lightllm.common.linear_att_cache_manager.linear_att_buffer_manager import LinearAttCacheManager
+from lightllm.common.linear_att_cache_manager.linear_att_buffer_manager import (
+ LinearAttCacheManager,
+)
+from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import (
+ DSV4_C4_PAGE_SIZE,
+ DSV4_PROMPT_CACHE_PAGE_SIZE,
+)
if TYPE_CHECKING:
from lightllm.server.router.model_infer.infer_batch import InferReq
@@ -19,6 +28,58 @@
logger = init_logger(__name__)
+@dataclass
+class DeepseekV4PromptCachePayload:
+ """prompt cache 载荷: 只剩 swa 按页有效性 bitmap。
+
+ 槽位与 compressor 状态都不进载荷: full_to_swa/full_to_c4/full_to_c128 以 full token 槽位
+ 为键(radix 持有 full 槽 ⇒ 映射行存活,free 级联回收);c4/c128 compressor 状态以 swa
+ 页派生寻址(随 swa 页生灭,命中零拷贝续算)。prompt cache 对齐到 256 token,
+ 避免共享前缀停在 c4 物理页中间。
+
+ * ``swa_page_valid``: cpu bool [cache_len // page],插入时按当下 full_to_swa 映射写定
+ (页内 token 映射全有效才为 True)。匹配层据此把命中裁剪到"结尾页有效"的 page 边界,
+ swa 压力阀回收节点页时清零。"""
+
+ cache_len: int
+ swa_page_valid: Optional[torch.Tensor] = None
+
+
+class DeepseekV4PromptCacheValueOps:
+ def __init__(self, req_manager: "DeepseekV4ReqManager"):
+ self.req_manager = req_manager
+
+ def slice(self, payload: DeepseekV4PromptCachePayload, start: int, end: int):
+ return self.req_manager.slice_prompt_cache_payload(payload, start, end)
+
+ def concat(self, payloads: List[DeepseekV4PromptCachePayload]):
+ return self.req_manager.concat_prompt_cache_payloads(payloads)
+
+ def free(self, payload: DeepseekV4PromptCachePayload):
+ # 槽位资源全部由 mem_manager.free(full_slots) 级联回收,载荷本身没有需要释放的资源。
+ return
+
+ def invalidate_swa_pages(self, payload: DeepseekV4PromptCachePayload) -> None:
+ """swa 压力阀回收了该节点的 swa 页后清 bitmap: 后续命中按缩短语义裁剪,不会复活。"""
+ if payload is not None and payload.swa_page_valid is not None:
+ payload.swa_page_valid.fill_(False)
+ return
+
+ def valid_match_length(self, payload: Optional[DeepseekV4PromptCachePayload], natural_len: int) -> int:
+ """radix 匹配裁剪: 返回 <= natural_len 的最大 prompt-cache 边界 L',使结尾页有效。
+
+ 有效性可能非单调(owner 生前从左驱逐、后续阀从尾回收),按候选边界回查 bitmap;
+ 中段 invalid 页不挡更靠后的有效命中(注意力只回看最后一个窗口)。"""
+ page = self.req_manager.get_prompt_cache_page_size()
+ if payload is None or payload.swa_page_valid is None:
+ return 0
+ n_pages = min(natural_len // page, int(payload.swa_page_valid.numel()))
+ valid_idx = torch.nonzero(payload.swa_page_valid[:n_pages])
+ if valid_idx.numel() == 0:
+ return 0
+ return (int(valid_idx[-1]) + 1) * page
+
+
class _ReqNode:
def __init__(self, index):
self.index = index
@@ -100,6 +161,24 @@ def free_all(self):
self.req_list = _ReqLinkedList(self.max_request_num)
return
+ def prepare_prefill(
+ self,
+ b_req_idx,
+ b_ready_cache_len,
+ b_seq_len,
+ b_req_idx_cpu=None,
+ b_ready_cache_len_cpu=None,
+ b_seq_len_cpu=None,
+ ):
+ """prefill 在 init_req_to_token_indexes 之后调用的钩子。基类 no-op; 需要
+ prefill KV 槽位 prep 的模型 (DeepSeek-V4) override。"""
+ return
+
+ def prepare_decode(self, b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu):
+ """每个 decode step 在 to_cuda 之前调用的钩子 (优先使用 CPU mirror, 且已在 forward 的
+ CUDA stream 上)。基类 no-op; 需要 per-step KV 槽位 prep 的模型 (DeepSeek-V4) override。"""
+ return
+
class ReqSamplingParamsManager:
"""
@@ -299,3 +378,602 @@ def copy_small_page_buffer_to_linear_att_state(
self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state
self.req_to_ssm_state.buffer[:, dest_req_idx, ...] = ssm_state
return
+
+
+class DeepseekV4ReqManager(ReqManager):
+ """DeepSeek-V4 的请求级管理。
+
+ 在基类 ReqManager 之上补 V4 专有的 per-request 结构。该对象在 mem manager profile 前创建,
+ 所以初始化只依赖 config 派生出的 compress_rates/head_dim/indexer_head_dim/sliding_window;
+ 真实 mem_manager 会在 `_init_mem_manager()` 后通过 `bind_mem_manager()` 接入。
+
+ * 压缩槽位不在本类: ``full_to_c4/c128_indexs``(mem manager)以组末 token 的 full 槽位为键。
+ 本类只负责 prep 阶段的分配与 scatter(``prepare_prefill`` /
+ ``prepare_decode_compress_slots``)——必须先于 attention metadata 构建/图捕获;
+ 条目内容由 layer-infer 的 compressor 前向写入。
+ * compressor 在途状态不在本类: c4/c128 都在 mem manager 的 swa 页派生池,
+ 随页生灭,命中零拷贝续算。
+ * SWA 槽位分配/出窗回收(``prepare_prefill_swa`` / ``prepare_decode_swa``): 每步 prep 阶段
+ 为新 token 调 mem_manager.alloc_swa,并按 per-req 水位线(``_swa_evict_marks``)惰性回收
+ 已出窗位置的 swa 槽。水位线首次置为该请求首个 chunk 的 ready_cache_len(radix 共享前缀
+ 的边界),因此共享前缀的 swa 槽永远不会被本请求回收(归 radix 经 mem_manager.free 级联释放)。
+ """
+
+ def __init__(
+ self,
+ max_request_num,
+ max_sequence_length,
+ mem_manager: Optional[DeepseekV4MemoryManager] = None,
+ compress_rates: Optional[List[int]] = None,
+ head_dim: Optional[int] = None,
+ indexer_head_dim: Optional[int] = None,
+ sliding_window: Optional[int] = None,
+ ):
+ super().__init__(max_request_num, max_sequence_length, mem_manager)
+
+ self.sliding_window = sliding_window
+ # 出窗回收水位线: -1 表示该 req 尚未见过 prefill chunk(首个 chunk 的 ready_cache_len
+ # 即共享前缀边界,作为永不下探的回收下界)。
+ self._swa_evict_marks = [-1 for _ in range(max_request_num + 1)]
+ self.compress_rates = list(compress_rates)
+ self.n_c4 = sum(1 for r in self.compress_rates if r == 4)
+ self.n_c128 = sum(1 for r in self.compress_rates if r == 128)
+ self.head_dim = head_dim
+ self.indexer_head_dim = indexer_head_dim
+ self.layer_to_c4_idx = {}
+ self.layer_to_c128_idx = {}
+ self.mem_manager = mem_manager
+ c4 = c128 = 0
+ for lid, r in enumerate(self.compress_rates):
+ if r == 4:
+ self.layer_to_c4_idx[lid] = c4
+ c4 += 1
+ elif r == 128:
+ self.layer_to_c128_idx[lid] = c128
+ c128 += 1
+
+ return
+
+ # ------------------------------------------------------------------ swa slot prep (per step)
+ def _swa_retain_len(self) -> int:
+ """出窗回收的保留长度 = window + 一个 radix 页。
+
+ 多留一页使「最近一个完成的 prompt-cache 边界」的结尾页恒驻留: 若回收只留 window,
+ 则任何非对齐时刻该边界的结尾页都已被部分回收,插入门会把所有插入裁到 0。
+ V4 prompt-cache 页取 256 token,正好覆盖一个 c4 物理页对应的 token 范围。"""
+ return int(self.sliding_window) + self.get_prompt_cache_page_size()
+
+ def prepare_prefill_swa(
+ self,
+ b_req_idx: torch.Tensor,
+ b_ready_cache_len: torch.Tensor,
+ b_seq_len: torch.Tensor,
+ b_req_idx_cpu: torch.Tensor,
+ b_ready_cache_len_cpu: torch.Tensor,
+ b_seq_len_cpu: torch.Tensor,
+ ) -> None:
+ """prefill prep: 为本 chunk 全部新 token(位置 [ready, seq))分配位置对齐的 swa 槽,
+ 并回收已出窗位置的槽。
+
+ 本 chunk 起点 L = ready_cache_len,首个新 token(位置 L)的窗口是 [L-W+1, L];回收
+ 边界再额外保留一个 radix 页(_swa_retain_len),即位置 < L-retain+1。先回收再分配。
+ 必须在 init_req_to_token_indexes 之后调用(位置对齐分配经 req_to_token 行派生/scatter)。"""
+ self.mem_manager: DeepseekV4MemoryManager
+ if self.sliding_window is not None:
+ retain = self._swa_retain_len()
+ evict_slots = []
+ req_list = b_req_idx_cpu.tolist()
+ ready_list = b_ready_cache_len_cpu.tolist()
+ for req_idx, ready_len in zip(req_list, ready_list):
+ req_idx = int(req_idx)
+ if req_idx == self.HOLD_REQUEST_ID:
+ continue
+ ready_len = int(ready_len)
+ mark = self._swa_evict_marks[req_idx]
+ if mark < 0:
+ # 首个 chunk: [0, ready_len) 是 radix 共享前缀,其 swa 槽归 radix 所有,不可回收。
+ self._swa_evict_marks[req_idx] = ready_len
+ continue
+ evict_end = ready_len - retain + 1
+ if evict_end > mark:
+ evict_slots.append(self.req_to_token_indexs[req_idx, mark:evict_end])
+ self._swa_evict_marks[req_idx] = evict_end
+ if evict_slots:
+ self.mem_manager.evict_swa(torch.cat(evict_slots))
+ self.mem_manager.alloc_swa_prefill(
+ b_req_idx,
+ b_ready_cache_len,
+ b_seq_len,
+ self.req_to_token_indexs,
+ b_req_idx_cpu=b_req_idx_cpu,
+ b_ready_cache_len_cpu=b_ready_cache_len_cpu,
+ b_seq_len_cpu=b_seq_len_cpu,
+ )
+ return
+
+ def prepare_decode(self, b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu):
+ """decode 每步槽位 prep: 先 swa 再 compress。由 BaseModel.forward / microbatch_overlap_decode
+ 在 to_cuda 之前调用 (CPU mirror + forward 的 CUDA stream); 不再放在 _decode 里。"""
+ self.prepare_decode_swa(b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu)
+ self.prepare_decode_compress_slots(b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu)
+ return
+
+ def prepare_prefill(
+ self,
+ b_req_idx: torch.Tensor,
+ b_ready_cache_len: torch.Tensor,
+ b_seq_len: torch.Tensor,
+ b_req_idx_cpu: torch.Tensor,
+ b_ready_cache_len_cpu: torch.Tensor,
+ b_seq_len_cpu: torch.Tensor,
+ ) -> None:
+ """prefill 槽位 prep: 先 swa 再 compress。由 BaseModel 在
+ init_req_to_token_indexes 之后、attention metadata 构建之前调用。"""
+ self.prepare_prefill_swa(
+ b_req_idx,
+ b_ready_cache_len,
+ b_seq_len,
+ b_req_idx_cpu=b_req_idx_cpu,
+ b_ready_cache_len_cpu=b_ready_cache_len_cpu,
+ b_seq_len_cpu=b_seq_len_cpu,
+ )
+ self.prepare_prefill_compress_slots(
+ b_req_idx,
+ b_ready_cache_len,
+ b_seq_len,
+ b_req_idx_cpu=b_req_idx_cpu,
+ b_ready_cache_len_cpu=b_ready_cache_len_cpu,
+ b_seq_len_cpu=b_seq_len_cpu,
+ )
+ return
+
+ def prepare_decode_swa(
+ self,
+ b_req_idx_cpu: torch.Tensor,
+ b_seq_len_cpu: torch.Tensor,
+ mem_indexes: torch.Tensor,
+ ) -> None:
+ """decode prep: 回收出窗槽并为本步新 token 分配位置对齐的 swa 槽。当前 query 位置
+ seq_len-1 的窗口是 [seq_len-W, seq_len-1];回收边界额外保留一个 radix 页
+ (_swa_retain_len),即位置 < seq_len-retain。先回收再分配。
+ seq_len/req_idx 从 CPU 镜像读(host 算术,无 D2H);水位线 _swa_evict_marks 仍是 host 状态。"""
+ assert self.mem_manager is not None
+ if self.sliding_window is not None:
+ retain = self._swa_retain_len()
+ evict_slots = []
+ req_list = b_req_idx_cpu.tolist()
+ seq_list = b_seq_len_cpu.tolist()
+ for req_idx, seq_len in zip(req_list, seq_list):
+ if req_idx == self.HOLD_REQUEST_ID:
+ continue
+ mark = self._swa_evict_marks[req_idx]
+ if mark < 0:
+ # 未经过 prefill prep 的保守路径: 不回收旧位置,仅推进水位线。
+ self._swa_evict_marks[req_idx] = max(0, seq_len - retain)
+ continue
+ evict_end = seq_len - retain
+ if evict_end > mark:
+ evict_slots.append(self.req_to_token_indexs[req_idx, mark:evict_end])
+ self._swa_evict_marks[req_idx] = evict_end
+ if evict_slots:
+ self.mem_manager.evict_swa(torch.cat(evict_slots))
+ self.mem_manager.alloc_swa_decode(b_req_idx_cpu, b_seq_len_cpu, mem_indexes, self.req_to_token_indexs)
+ return
+
+ def init_compress_state(self, req_idx: int):
+ """新请求开始时重置 runtime 水位线(对应 mamba 的 init_linear_att_state 调用点)。
+
+ c4/c128 compressor state 都随 swa 页寻址,由内核按组覆写;请求复用时不做 per-req 清零。"""
+ self.clear_runtime_state(req_idx)
+ return
+
+ # ------------------------------------------------------------------ compress slot prep (per step)
+ def _compress_mapping_alloc(self, ratio: int):
+ assert self.mem_manager is not None, "DeepSeek-V4 mem manager is not bound yet"
+ if ratio == 4:
+ raise AssertionError("DeepSeek-V4 c4 uses page-safe allocation")
+ if ratio == 128:
+ return self.mem_manager.full_to_c128_indexs, self.mem_manager.alloc_c128
+ raise AssertionError(f"invalid DeepSeek-V4 compress ratio {ratio}")
+
+ def _c4_group_end_full_slots(self, req_rows, entries: torch.Tensor) -> torch.Tensor:
+ """组末 token 的 full 槽位 id (token 位置 = entry*4+3);req_rows 可为标量 req_idx 或行张量。"""
+ return self.req_to_token_indexs[req_rows, entries * 4 + 3].long()
+
+ def _register_c4_slots(self, full_slots: torch.Tensor, slots: torch.Tensor) -> None:
+ """写入 full->c4 槽映射并按页累加存活计数。"""
+ self.mem_manager.full_to_c4_indexs[full_slots] = slots
+ self.mem_manager.count_c4_slots(slots, 1)
+
+ def _scatter_c4_prefill_slots_slow(self, req_idx: int, first: int, last: int) -> None:
+ """Idempotence fallback for overlapped/repeated c4 prep."""
+ page = DSV4_C4_PAGE_SIZE
+ mapping = self.mem_manager.full_to_c4_indexs
+ for page_base in range((first // page) * page, last, page):
+ e0 = max(first, page_base)
+ e1 = min(last, page_base + page)
+ entries = torch.arange(e0, e1, dtype=torch.long, device="cuda")
+ full_slots = self._c4_group_end_full_slots(req_idx, entries)
+ existing = mapping[full_slots]
+ missing = existing < 0
+ if not bool(missing.any()):
+ continue
+
+ mapped = torch.nonzero(existing >= 0, as_tuple=False)
+ if mapped.numel() > 0:
+ j = int(mapped[0].item())
+ base = int(existing[j].item()) - ((e0 + j) % page)
+ elif e0 > page_base:
+ prev_full = self.req_to_token_indexs[req_idx, e0 * 4 - 1].long()
+ prev_slot = int(mapping[prev_full].item())
+ assert prev_slot >= 0 and prev_slot % page == (e0 - 1) % page
+ base = prev_slot - ((e0 - 1) % page)
+ else:
+ base = int(self.mem_manager.alloc_c4_pages(1)[0].item()) * page
+
+ slots = (base + entries % page).to(torch.int32)
+ if mapped.numel() > 0:
+ assert bool((existing[existing >= 0] == slots[existing >= 0]).all())
+ self._register_c4_slots(full_slots[missing], slots[missing])
+ return
+
+ def _scatter_c4_prefill_slots(self, req_idx: int, first: int, last: int) -> None:
+ """为 logical c4 entry [first, last) 分配 page-safe c4 槽。
+
+ 不变式: logical entry e 映射到 physical_page * 64 + e % 64,同一 logical page
+ 内 entry 共享 physical_page。这是 DeepGEMM paged MQA logits 直接消费 page table 的前提。
+ """
+ if last <= first:
+ return
+ mapping = self.mem_manager.full_to_c4_indexs
+ full_slots = self._c4_group_end_full_slots(req_idx, torch.arange(first, last, device="cuda"))
+ need = mapping[full_slots] < 0
+ if not bool(need.any()):
+ return
+ if not bool(need.all()):
+ self._scatter_c4_prefill_slots_slow(req_idx, first, last)
+ return
+ self._scatter_c4_prefill_slots_fresh(req_idx, first, last)
+ return
+
+ def _scatter_c4_prefill_slots_fresh(self, req_idx: int, first: int, last: int) -> None:
+ """Sync-free fast path for entries [first, last) the caller already knows are all fresh:
+ page-safe alloc with the continuation base read on-GPU. KvCacheAllocator.alloc is CPU-side
+ (watermark + pinned buffer, no D2H), so this has no syncs."""
+ page = DSV4_C4_PAGE_SIZE
+ mapping = self.mem_manager.full_to_c4_indexs
+ entries = torch.arange(first, last, dtype=torch.long, device="cuda")
+ full_slots = self._c4_group_end_full_slots(req_idx, entries)
+ first_page = first // page
+ n_pages = (last - 1) // page - first_page + 1
+ bases = torch.empty((n_pages,), dtype=torch.long, device="cuda")
+ base_start = 0
+ if first % page != 0: # chunk starts mid-page -> continue the prev chunk's physical page
+ prev_full = self.req_to_token_indexs[req_idx, first * 4 - 1].long()
+ bases[0] = mapping[prev_full].long() - ((first - 1) % page)
+ base_start = 1
+ if n_pages - base_start > 0:
+ bases[base_start:] = (
+ self.mem_manager.alloc_c4_pages(n_pages - base_start).cuda(non_blocking=True).long() * page
+ )
+ page_local = torch.div(entries, page, rounding_mode="floor") - first_page
+ slots = (bases[page_local] + entries % page).to(torch.int32)
+ self._register_c4_slots(full_slots, slots)
+ return
+
+ def _scatter_c4_prefill_slots_batched(self, req_list, ready_list, seq_list) -> None:
+ """Whole-batch c4 prefill scatter in O(1) GPU ops (independent of request count). The per-req
+ loop cost O(N) launches + 2-3 D2H syncs each; here every request's group-end entries are
+ flattened (ragged) and processed in one gather / idempotency-check / page-alloc / scatter /
+ count. Falls back to the per-req idempotent path on partial/re-run; preserves the page
+ invariant (logical entry e -> physical_page*64 + e%64, same logical page shares a physical
+ page). KvCacheAllocator.alloc is CPU-side so one batched alloc has no D2H."""
+ page = DSV4_C4_PAGE_SIZE
+ mapping = self.mem_manager.full_to_c4_indexs
+ device = mapping.device
+
+ # host plan (cheap int arithmetic, no GPU). dup req in one call would break vectorized
+ # continuation ordering -> fall back to the safe per-req path.
+ plan, seen, duplicate_req = [], set(), False
+ c4_page_need = 0
+ for req_idx, ready_len, seq_len in zip(req_list, ready_list, seq_list):
+ req_idx = int(req_idx)
+ if req_idx == self.HOLD_REQUEST_ID:
+ continue
+ first, last = int(ready_len) // 4, int(seq_len) // 4
+ if last <= first:
+ continue
+ c4_page_need += (last - 1) // page - first // page + 1 # 上界=区间触及页数, 复用本循环
+ duplicate_req |= req_idx in seen
+ seen.add(req_idx)
+ plan.append((req_idx, first, last))
+ if not plan:
+ return
+ # 兑现: 在所有分支(dup/fresh/batched)的 alloc_c4_pages 之前统一腾页
+ self._realize_c4_pages(c4_page_need)
+ if duplicate_req:
+ for req_idx, first, last in plan:
+ self._scatter_c4_prefill_slots(req_idx, first, last)
+ return
+
+ def to_cuda_long(key, data):
+ return g_pin_mem_manager.gen_from_list(key=key, data=data, dtype=torch.int64).to(device, non_blocking=True)
+
+ reqs, firsts, lasts = zip(*plan)
+ counts = [last - first for first, last in zip(firsts, lasts)]
+ first_pages = [first // page for first in firsts]
+ page_counts = [((last - 1) // page) - fp + 1 for last, fp in zip(lasts, first_pages)]
+ page_offsets, total_pages = [], 0
+ for n_pages in page_counts:
+ page_offsets.append(total_pages)
+ total_pages += n_pages
+ total_entries = sum(counts)
+
+ # one pinned H2D copy for all per-request metadata (5 cols), then per-entry ragged expansion
+ meta = to_cuda_long(
+ "dsv4_c4_prefill_meta",
+ [x for row in zip(reqs, firsts, first_pages, counts, page_offsets) for x in row],
+ ).view(-1, 5)
+ reqs_t, firsts_t, first_pages_t, counts_t, page_offsets_t = meta.unbind(1)
+ seg = torch.repeat_interleave(torch.arange(len(plan), device=device), counts_t, output_size=total_entries)
+ seg_starts = counts_t.cumsum(0) - counts_t
+ entries = firsts_t[seg] + torch.arange(total_entries, device=device) - seg_starts[seg]
+ full_slots = self._c4_group_end_full_slots(reqs_t[seg], entries)
+
+ if not bool((mapping[full_slots] < 0).all()): # the single batched idempotency sync
+ for req_idx, first, last in plan:
+ self._scatter_c4_prefill_slots(req_idx, first, last)
+ return
+
+ # physical base per logical page: fresh pages from one alloc; mid-page continuations read prev
+ cont = [(off, req, first) for off, req, first in zip(page_offsets, reqs, firsts) if first % page != 0]
+ if not cont:
+ page_bases = self.mem_manager.alloc_c4_pages(total_pages).to(device, non_blocking=True).long() * page
+ else:
+ page_bases = torch.empty(total_pages, dtype=torch.long, device=device)
+ new_pos = [
+ pos
+ for off, n_pages, first in zip(page_offsets, page_counts, firsts)
+ for pos in range(off + int(first % page != 0), off + n_pages)
+ ]
+ if new_pos:
+ new_pos_t = to_cuda_long("dsv4_c4_prefill_new_pos", new_pos)
+ page_bases[new_pos_t] = (
+ self.mem_manager.alloc_c4_pages(len(new_pos)).to(device, non_blocking=True).long() * page
+ )
+ cont_t = to_cuda_long("dsv4_c4_prefill_cont", [x for row in cont for x in row]).view(-1, 3)
+ prev_slot = mapping[self.req_to_token_indexs[cont_t[:, 1], cont_t[:, 2] * 4 - 1].long()].long()
+ cont_off = (cont_t[:, 2] - 1) % page
+ assert bool((prev_slot >= 0).all()) and bool(((prev_slot % page) == cont_off).all())
+ page_bases[cont_t[:, 0]] = prev_slot - cont_off
+
+ page_idx = page_offsets_t[seg] + torch.div(entries, page, rounding_mode="floor") - first_pages_t[seg]
+ slots = (page_bases[page_idx] + entries % page).to(torch.int32)
+ self._register_c4_slots(full_slots, slots)
+ return
+
+ def _scatter_c4_decode_slots(self, req_list, seq_list, mem_indexes: torch.Tensor) -> None:
+ page = DSV4_C4_PAGE_SIZE
+ mapping = self.mem_manager.full_to_c4_indexs
+ mem_indexes = mem_indexes.cuda().long().reshape(-1)
+
+ cont_rows, cont_prev_pos, cont_offsets = [], [], []
+ new_rows = []
+ for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)):
+ req_idx, seq_len = int(req_idx), int(seq_len)
+ if req_idx == self.HOLD_REQUEST_ID or seq_len <= 0 or seq_len % 4 != 0:
+ continue
+ entry = seq_len // 4 - 1
+ offset = entry % page
+ if offset == 0:
+ new_rows.append(i)
+ else:
+ cont_rows.append(i)
+ cont_prev_pos.append(entry * 4 - 1)
+ cont_offsets.append(offset)
+
+ if cont_rows:
+ req_rows = torch.tensor([req_list[i] for i in cont_rows], dtype=torch.long, device="cuda")
+ prev_pos = torch.tensor(cont_prev_pos, dtype=torch.long, device="cuda")
+ prev_full = self.req_to_token_indexs[req_rows, prev_pos].long()
+ prev_slots = mapping[prev_full]
+ offsets = torch.tensor(cont_offsets, dtype=torch.int32, device="cuda")
+ assert bool((prev_slots >= 0).all())
+ assert bool(((prev_slots % page) == (offsets - 1)).all())
+ self._register_c4_slots(mem_indexes[cont_rows], (prev_slots + 1).to(torch.int32))
+
+ if new_rows:
+ self._realize_c4_pages(len(new_rows)) # 兑现: 精确需求, 复用已算的 new_rows
+ pages = self.mem_manager.alloc_c4_pages(len(new_rows)).cuda(non_blocking=True).long()
+ self._register_c4_slots(mem_indexes[new_rows], (pages * page).to(torch.int32))
+ return
+
+ def _scatter_compress_slots(self, ratio: int, full_slots: torch.Tensor) -> None:
+ """为组末 full 槽位分配压缩槽并写入映射。已映射(>=0)的行跳过——重复 prep 幂等。"""
+ if full_slots.numel() == 0:
+ return
+ mapping, alloc = self._compress_mapping_alloc(ratio)
+ full_slots = full_slots.cuda().long().reshape(-1)
+ # 去重: 同批重复键会让后写覆盖先写,先分配的压缩槽成为孤儿(allocator 泄漏)。
+ need = torch.unique(full_slots[mapping[full_slots] < 0])
+ if need.numel() == 0:
+ return
+ if ratio == 128: # _scatter_compress_slots 仅用于 c128; 兑现其槽(复用已算的 need.numel())
+ self._realize_c128_slots(int(need.numel()))
+ new_slots = alloc(need.numel()).cuda(non_blocking=True).to(torch.int32)
+ mapping[need] = new_slots
+ return
+
+ def _realize_c4_pages(self, need_pages: int) -> None:
+ """压缩池兑现 —— 和主池在 prep 里调 free_radix_cache_to_get_enough_token 同一套路:
+ base_backend admission 已按"空闲+可回收"放行本步请求,这里在真分配前(scatter 已算好 need)
+ 把可回收的无引用 radix 节点驱逐出来腾出 c4 页,避免 alloc_c4_pages 触底 assert。
+ 可回收仍不足时由 admission 的 wait_pause 兜底。"""
+ if self.n_c4 == 0 or need_pages <= 0:
+ return
+ # 延迟 import: infer_batch 在模块顶 import 了 req_manager,顶层 import 会循环引用
+ from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+ if g_infer_context.radix_cache is not None:
+ g_infer_context.radix_cache.free_radix_cache_to_get_enough_c4_pages(need_pages)
+ return
+
+ def _realize_c128_slots(self, need_slots: int) -> None:
+ if self.n_c128 == 0 or need_slots <= 0:
+ return
+ from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+ if g_infer_context.radix_cache is not None:
+ g_infer_context.radix_cache.free_radix_cache_to_get_enough_c128_slots(need_slots)
+ return
+
+ def prepare_prefill_compress_slots(
+ self,
+ b_req_idx: torch.Tensor,
+ b_ready_cache_len: torch.Tensor,
+ b_seq_len: torch.Tensor,
+ b_req_idx_cpu: torch.Tensor,
+ b_ready_cache_len_cpu: torch.Tensor,
+ b_seq_len_cpu: torch.Tensor,
+ ) -> None:
+ """prefill prep: 为本 chunk 内的组末 token(位置 (g+1)*ratio-1 ∈ [ready, seq))分配压缩槽,
+ scatter 进 full_to_c4/c128_indexs。必须在 init_req_to_token_indexes 之后(组末 full 槽
+ 从 req_to_token_indexs 取)、attention metadata 构建之前调用。"""
+ if self.n_c4 == 0 and self.n_c128 == 0:
+ return
+ req_list = b_req_idx_cpu.tolist()
+ ready_list = b_ready_cache_len_cpu.tolist()
+ seq_list = b_seq_len_cpu.tolist()
+ if self.n_c4 > 0:
+ self._scatter_c4_prefill_slots_batched(req_list, ready_list, seq_list)
+
+ if self.n_c128 > 0:
+ ratio = 128
+ end_slots = []
+ for req_idx, ready_len, seq_len in zip(req_list, ready_list, seq_list):
+ req_idx = int(req_idx)
+ if req_idx == self.HOLD_REQUEST_ID:
+ continue
+ first, last = int(ready_len) // ratio, int(seq_len) // ratio
+ if last > first:
+ ends = self.req_to_token_indexs[req_idx, ratio - 1 : last * ratio : ratio]
+ end_slots.append(ends[first:])
+ if end_slots:
+ self._scatter_compress_slots(ratio, torch.cat(end_slots))
+ return
+
+ def prepare_decode_compress_slots(
+ self,
+ b_req_idx_cpu: torch.Tensor,
+ b_seq_len_cpu: torch.Tensor,
+ mem_indexes: torch.Tensor,
+ ) -> None:
+ """decode prep: 本步 token 关闭一个组(seq_len % ratio == 0)时为其分配压缩槽并 scatter。
+ 组末 full 槽即本步的 mem_index(此刻 req_to_token_indexs 尚未写入本步槽位)。
+ 从 CPU 镜像读 seq_len/req_idx(host 算术,无 D2H);非关组步 rows 为空 => 不调 _scatter,零同步。"""
+ if self.n_c4 == 0 and self.n_c128 == 0:
+ return
+ req_list = b_req_idx_cpu.tolist()
+ seq_list = b_seq_len_cpu.tolist()
+ if self.n_c4 > 0:
+ self._scatter_c4_decode_slots(req_list, seq_list, mem_indexes)
+
+ if self.n_c128 > 0:
+ ratio = 128
+ rows = [
+ i
+ for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list))
+ if req_idx != self.HOLD_REQUEST_ID and seq_len > 0 and seq_len % ratio == 0
+ ]
+ if rows:
+ self._scatter_compress_slots(ratio, mem_indexes.reshape(-1)[rows])
+ return
+
+ def alloc(self):
+ req_idx = super().alloc()
+ if req_idx is not None:
+ self.init_compress_state(req_idx)
+ return req_idx
+
+ def clear_runtime_state(self, req_idx: int):
+ # swa 槽位本身由 mem_manager.free 级联回收(随 full 槽位),这里只复位出窗水位线。
+ self._swa_evict_marks[req_idx] = -1
+ return
+
+ def get_prompt_cache_value_ops(self):
+ return DeepseekV4PromptCacheValueOps(self)
+
+ def get_prompt_cache_page_size(self):
+ return DSV4_PROMPT_CACHE_PAGE_SIZE
+
+ def compute_swa_page_valid(self, full_slots: torch.Tensor) -> torch.Tensor:
+ """按当下 full_to_swa 映射给出按页有效性: full_slots [L](L 为 page 整数倍) ->
+ cpu bool [L/page],页内全部映射有效才为 True。GPU gather + 同步,测试/校验用;
+ 插入热路径用 swa_page_valid_from_watermark(纯 CPU,免同步)。"""
+ page = self.get_prompt_cache_page_size()
+ assert full_slots.numel() % page == 0
+ if full_slots.numel() == 0:
+ return torch.zeros((0,), dtype=torch.bool)
+ swa = self.mem_manager.full_to_swa_indexs[full_slots.cuda().long().reshape(-1)]
+ return (swa.view(-1, page) >= 0).all(dim=1).cpu()
+
+ def swa_page_valid_from_watermark(self, req_idx: int, cache_len: int) -> torch.Tensor:
+ """插入时的按页有效性,纯 CPU: 请求自有 token 的 swa 映射只被出窗水位线回收
+ (阀不触活跃请求,级联只在 free 时),页 p 全驻留 ⟺ 页起点 128p >= 水位线。
+
+ 与 compute_swa_page_valid 在插入时刻对自有 token 等价,但不做 GPU gather/同步——
+ router 关键路径上每次插入省一次对全部在途 kernel 的等待。bitmap 中借入前缀
+ ([0, ready) 的页)的行在 radix insert 切片时被丢弃(既有节点保留自己的 bitmap),
+ 其取值无影响。"""
+ page = self.get_prompt_cache_page_size()
+ mark = max(0, self._swa_evict_marks[req_idx])
+ n_pages = int(cache_len) // page
+ return torch.arange(n_pages, dtype=torch.long) * page >= mark
+
+ def slice_prompt_cache_payload(self, payload: DeepseekV4PromptCachePayload, start: int, end: int):
+ start = int(start)
+ end = int(end)
+ page = self.get_prompt_cache_page_size()
+ # radix page 保证分裂点页对齐,bitmap 可整页切分。
+ return DeepseekV4PromptCachePayload(
+ cache_len=end - start,
+ swa_page_valid=payload.swa_page_valid[start // page : end // page].clone()
+ if payload.swa_page_valid is not None
+ else None,
+ )
+
+ def concat_prompt_cache_payloads(self, payloads: List[DeepseekV4PromptCachePayload]):
+ if len(payloads) == 0:
+ return None
+ bitmaps = [p.swa_page_valid for p in payloads]
+ return DeepseekV4PromptCachePayload(
+ cache_len=sum(p.cache_len for p in payloads),
+ swa_page_valid=torch.cat(bitmaps, dim=0) if all(b is not None for b in bitmaps) else None,
+ )
+
+ def build_prompt_cache_payload(
+ self,
+ req_idx: int,
+ cache_len: int,
+ ) -> DeepseekV4PromptCachePayload:
+ """构造插入载荷。compressor 状态不进载荷(c4 随 swa 页生灭、c128 边界自然归零),
+ cache_len 不再受序列末端约束——任意 128 对齐前缀皆可插入。
+ swa_page_valid 不在此填: 它必须用插入时刻的映射(infer batch 在 insert 前补)。"""
+ assert self.mem_manager is not None
+ return DeepseekV4PromptCachePayload(cache_len=int(cache_len))
+
+ def free(self, free_req_indexes, free_token_index):
+ """dense/swa/压缩槽全部经 mem_manager.free(free_token_index) 级联回收。"""
+ for req_index in free_req_indexes:
+ self.clear_runtime_state(req_index)
+ super().free(free_req_indexes, free_token_index)
+ return
+
+ def free_req(self, free_req_index: int):
+ self.clear_runtime_state(free_req_index)
+ return super().free_req(free_req_index)
+
+ def free_all(self):
+ super().free_all()
+ self._swa_evict_marks = [-1 for _ in range(self.max_request_num + 1)]
+ return
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=4096,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=4096,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 000000000..520409766
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=4096,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "12288": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "192": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "24576": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "384": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "48": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "49152": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "6": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "600": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "6144": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "768": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "96": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=4096,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=6,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=4096,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=6,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 000000000..ac4ce1ba5
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=4096,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=6,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 5,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=6}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=6}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 000000000..6aa8d18c5
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=6}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 2
+ },
+ "100": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE": 512,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 2
+ },
+ "8192": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=6}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=6}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 000000000..e2da8bc96
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=6}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 2
+ },
+ "64": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_DIM": 64,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "8192": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=0,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=0,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 000000000..588fd4a93
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=0,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 5,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 1,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 1,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 4,
+ "num_warps": 8
+ },
+ "2048": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 4,
+ "num_warps": 2
+ },
+ "256": {
+ "BLOCK_SEQ": 2,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 5,
+ "num_warps": 1
+ },
+ "32": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 1,
+ "num_warps": 2
+ },
+ "64": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 5,
+ "num_warps": 8
+ },
+ "8192": {
+ "BLOCK_SEQ": 2,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 3,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 000000000..4d7e8f118
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_M": 128,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "1024": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "12288": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "1536": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "192": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "24576": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "32": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 32,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "384": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "4096": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "48": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "49152": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "6": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "600": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "6144": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "768": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 32,
+ "NUM_STAGES": 1,
+ "num_warps": 8
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "96": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=256,N=4096,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=256,N=4096,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json
new file mode 100644
index 000000000..b1aae6bfb
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=256,N=4096,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "12288": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "24576": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "48": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "49152": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "6": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "600": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "6144": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "96": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=6,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=6,use_fp8_w8a8=true}_NVIDIA_H200.json
new file mode 100644
index 000000000..9ffb0efd1
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=6,use_fp8_w8a8=true}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 5,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=6}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=6}_NVIDIA_H200.json
new file mode 100644
index 000000000..85a20d9b1
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=6}_NVIDIA_H200.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 1
+ },
+ "100": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 8
+ },
+ "2048": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE": 512,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=6}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=6}_NVIDIA_H200.json
new file mode 100644
index 000000000..de2f015a0
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=6}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 4,
+ "NUM_STAGE": 4,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "2048": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 2
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 1
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "4096": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=0,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=0,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 000000000..88742a0b1
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=0,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 4,
+ "num_warps": 1
+ },
+ "100": {
+ "BLOCK_SEQ": 2,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 5,
+ "num_warps": 2
+ },
+ "1024": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 1,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 5,
+ "num_warps": 2
+ },
+ "16": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 1,
+ "num_warps": 2
+ },
+ "256": {
+ "BLOCK_SEQ": 2,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 5,
+ "num_warps": 2
+ },
+ "32": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 5,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 1,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "8192": {
+ "BLOCK_SEQ": 2,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 2,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 000000000..fbd364973
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_M": 128,
+ "BLOCK_N": 32,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "12288": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "1536": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 32,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "192": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "2048": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "24576": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "32": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "384": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "4096": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "48": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 32,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "49152": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "6": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 32,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "600": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "6144": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "768": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "96": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index f619b1d88..3d376d160 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -20,6 +20,7 @@
from lightllm.models.phi3.model import Phi3TpPartModel
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel
+from lightllm.models.deepseek_v4.model import DeepseekV4TpPartModel
from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel
from lightllm.models.internvl.model import (
InternVLLlamaTpPartModel,
diff --git a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py
index 30e5a5924..a8f851de2 100644
--- a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py
+++ b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py
@@ -29,6 +29,8 @@ def _rotary_kernel(
BLOCK_SEQ: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
NUM_STAGE: tl.constexpr,
+ HAS_K: tl.constexpr,
+ INVERSE: tl.constexpr,
):
head_start_index = tl.program_id(0)
seq_block_index = tl.program_id(1)
@@ -44,6 +46,8 @@ def _rotary_kernel(
off_dimcos_sin = seq_index * stride_cosbs + cos_range * stride_cosd
cos = tl.load(Cos + off_dimcos_sin)
sin = tl.load(Sin + off_dimcos_sin)
+ if INVERSE:
+ sin = -sin
if HEAD_PARALLEL_NUM == 1:
for q_head_index in tl.static_range(0, HEAD_Q, step=1):
@@ -56,18 +60,19 @@ def _rotary_kernel(
tl.store(Q + off_q0, out_q0)
tl.store(Q + off_q1, out_q1)
- for k_head_index in tl.static_range(0, HEAD_K, step=1):
- off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd
- off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd
+ if HAS_K:
+ for k_head_index in tl.static_range(0, HEAD_K, step=1):
+ off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd
+ off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd
- k0 = tl.load(K + off_k0)
- k1 = tl.load(K + off_k1)
+ k0 = tl.load(K + off_k0)
+ k1 = tl.load(K + off_k1)
- out_k0 = k0 * cos - k1 * sin
- out_k1 = k0 * sin + k1 * cos
+ out_k0 = k0 * cos - k1 * sin
+ out_k1 = k0 * sin + k1 * cos
- tl.store(K + off_k0, out_k0)
- tl.store(K + off_k1, out_k1)
+ tl.store(K + off_k0, out_k0)
+ tl.store(K + off_k1, out_k1)
else:
for q_head_index in tl.range(head_start_index, HEAD_Q, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE):
off_q0 = seq_index * stride_qbs + q_head_index * stride_qh + dim_range0 * stride_qd
@@ -79,18 +84,19 @@ def _rotary_kernel(
tl.store(Q + off_q0, out_q0)
tl.store(Q + off_q1, out_q1)
- for k_head_index in tl.range(head_start_index, HEAD_K, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE):
- off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd
- off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd
+ if HAS_K:
+ for k_head_index in tl.range(head_start_index, HEAD_K, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE):
+ off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd
+ off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd
- k0 = tl.load(K + off_k0)
- k1 = tl.load(K + off_k1)
+ k0 = tl.load(K + off_k0)
+ k1 = tl.load(K + off_k1)
- out_k0 = k0 * cos - k1 * sin
- out_k1 = k0 * sin + k1 * cos
+ out_k0 = k0 * cos - k1 * sin
+ out_k1 = k0 * sin + k1 * cos
- tl.store(K + off_k0, out_k0)
- tl.store(K + off_k1, out_k1)
+ tl.store(K + off_k0, out_k0)
+ tl.store(K + off_k1, out_k1)
return
@@ -109,7 +115,10 @@ def get_test_configs():
def get_static_key(q, k):
- head_num_q, head_num_k, head_dim = q.shape[1], k.shape[1], q.shape[2]
+ assert q is not None, "q can not be None"
+ head_num_q = q.shape[1]
+ head_num_k = k.shape[1] if k is not None else 0
+ head_dim = q.shape[2]
return {
"Q_HEAD_NUM": head_num_q,
"K_HEAD_NUM": head_num_k,
@@ -126,12 +135,17 @@ def get_static_key(q, k):
mutates_args=["q", "k"],
)
@torch.no_grad()
-def rotary_emb_fwd(q, k, cos, sin, run_config=None):
+def rotary_emb_fwd(q, k, cos, sin, inverse=False, run_config=None):
+ assert q is not None, "q can not be None"
+ has_k = k is not None and k.shape[1] != 0
total_len = q.shape[0]
- head_num_q, head_num_k = q.shape[1], k.shape[1]
+ head_num_q = q.shape[1]
+ head_num_k = k.shape[1] if k is not None else 0
head_dim = q.shape[2]
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
- assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}"
+ if k is not None:
+ assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}"
+ assert k.shape[2] == head_dim, f"k shape {k.shape} q head_dim {head_dim}"
assert triton.next_power_of_2(head_dim) == head_dim
if not run_config:
@@ -157,9 +171,9 @@ def rotary_emb_fwd(q, k, cos, sin, run_config=None):
stride_qbs=q.stride(0),
stride_qh=q.stride(1),
stride_qd=q.stride(2),
- stride_kbs=k.stride(0),
- stride_kh=k.stride(1),
- stride_kd=k.stride(2),
+ stride_kbs=k.stride(0) if k is not None else 0,
+ stride_kh=k.stride(1) if k is not None else 0,
+ stride_kd=k.stride(2) if k is not None else 0,
stride_cosbs=cos.stride(0),
stride_cosd=cos.stride(1),
stride_sinbs=sin.stride(0),
@@ -171,6 +185,8 @@ def rotary_emb_fwd(q, k, cos, sin, run_config=None):
BLOCK_SEQ=BLOCK_SEQ,
BLOCK_DMODEL=head_dim,
NUM_STAGE=num_stages,
+ HAS_K=has_k,
+ INVERSE=inverse,
num_warps=num_warps,
num_stages=num_stages,
)
diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py
index d6eaebe2f..8c7506b1e 100644
--- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py
@@ -206,6 +206,7 @@ def _get_indices(
weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale
weights = weights.unsqueeze(-1) * q_scale
+ att_state.ensure_nsa_ks_ke()
ks = att_state.ks
ke = att_state.ke
lengths = att_state.lengths
diff --git a/lightllm/models/deepseek_v4/__init__.py b/lightllm/models/deepseek_v4/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/deepseek_v4/encoding/__init__.py b/lightllm/models/deepseek_v4/encoding/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/deepseek_v4/encoding/encoding_dsv4.py b/lightllm/models/deepseek_v4/encoding/encoding_dsv4.py
new file mode 100644
index 000000000..6cbd5f9bf
--- /dev/null
+++ b/lightllm/models/deepseek_v4/encoding/encoding_dsv4.py
@@ -0,0 +1,762 @@
+"""
+DeepSeek-V4 Encoding
+
+A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
+with tool calling, thinking mode, and quick instruction task support.
+"""
+
+from typing import Any, Dict, List, Union, Optional, Tuple
+import copy
+import json
+import re
+
+# ============================================================
+# Special Tokens
+# ============================================================
+
+bos_token: str = "<|begin▁of▁sentence|>"
+eos_token: str = "<|end▁of▁sentence|>"
+thinking_start_token: str = ""
+thinking_end_token: str = ""
+dsml_token: str = "|DSML|"
+
+USER_SP_TOKEN = "<|User|>"
+ASSISTANT_SP_TOKEN = "<|Assistant|>"
+LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
+
+# Task special tokens for internal classification tasks
+DS_TASK_SP_TOKENS = {
+ "action": "<|action|>",
+ "query": "<|query|>",
+ "authority": "<|authority|>",
+ "domain": "<|domain|>",
+ "title": "<|title|>",
+ "read_url": "<|read_url|>",
+}
+VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
+
+# ============================================================
+# Templates
+# ============================================================
+
+system_msg_template: str = "{content}"
+user_msg_template: str = "{content}"
+latest_reminder_msg_template: str = "{content}"
+assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
+assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
+thinking_template: str = "{reasoning_content}"
+
+response_format_template: str = (
+ "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
+)
+tool_call_template: str = '<{dsml_token}invoke name="{name}">\n{arguments}\n{dsml_token}invoke>'
+tool_calls_template = "<{dsml_token}{tc_block_name}>\n{tool_calls}\n{dsml_token}{tc_block_name}>"
+tool_calls_block_name: str = "tool_calls"
+
+tool_output_template: str = "{content}"
+
+REASONING_EFFORT_MAX = (
+ "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
+ "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n" # noqa: E501
+ "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n" # noqa: E501
+)
+
+TOOLS_TEMPLATE = """## Tools
+
+You have access to a set of tools to help answer the user's question. You can invoke tools by writing a \
+"<{dsml_token}tool_calls>" block like the following:
+
+<{dsml_token}tool_calls>
+<{dsml_token}invoke name="$TOOL_NAME">
+<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE{dsml_token}parameter>
+...
+{dsml_token}invoke>
+<{dsml_token}invoke name="$TOOL_NAME2">
+...
+{dsml_token}invoke>
+{dsml_token}tool_calls>
+
+String parameters should be specified as is and set `string="true"`. For all other types (numbers, \
+booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
+
+If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete \
+reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
+
+Otherwise, output directly after {thinking_end_token} with tool calls or final response.
+
+### Available Tool Schemas
+
+{tool_schemas}
+
+You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
+"""
+
+# ============================================================
+# Utility Functions
+# ============================================================
+
+
+def to_json(value: Any) -> str:
+ """Serialize a value to JSON string."""
+ try:
+ return json.dumps(value, ensure_ascii=False)
+ except:
+ return json.dumps(value, ensure_ascii=True)
+
+
+def tools_from_openai_format(tools):
+ """Extract function definitions from OpenAI-format tool list."""
+ return [tool["function"] for tool in tools]
+
+
+def tool_calls_from_openai_format(tool_calls):
+ """Convert OpenAI-format tool calls to internal format."""
+ return [
+ {
+ "name": tool_call["function"]["name"],
+ "arguments": tool_call["function"]["arguments"],
+ }
+ for tool_call in tool_calls
+ ]
+
+
+def tool_calls_to_openai_format(tool_calls):
+ """Convert internal tool calls to OpenAI format."""
+ return [
+ {
+ "type": "function",
+ "function": {
+ "name": tool_call["name"],
+ "arguments": tool_call["arguments"],
+ },
+ }
+ for tool_call in tool_calls
+ ]
+
+
+def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
+ """
+ Encode tool call arguments into DSML parameter format.
+
+ Args:
+ tool_call: Dict with "name" and "arguments" (JSON string) keys.
+
+ Returns:
+ DSML-formatted parameter string.
+ """
+ p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}{dsml_token}parameter>'
+ P_dsml_strs = []
+
+ try:
+ arguments = json.loads(tool_call["arguments"])
+ except Exception:
+ arguments = {"arguments": tool_call["arguments"]}
+
+ for k, v in arguments.items():
+ p_dsml_str = p_dsml_template.format(
+ dsml_token=dsml_token,
+ key=k,
+ is_str="true" if isinstance(v, str) else "false",
+ value=v if isinstance(v, str) else to_json(v),
+ )
+ P_dsml_strs.append(p_dsml_str)
+
+ return "\n".join(P_dsml_strs)
+
+
+def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
+ """
+ Decode DSML parameters back to a tool call dict.
+
+ Args:
+ tool_name: Name of the tool.
+ tool_args: Dict mapping param_name -> (value, is_string_flag).
+
+ Returns:
+ Dict with "name" and "arguments" (JSON string) keys.
+ """
+
+ def _decode_value(key: str, value: str, string: str):
+ if string == "true":
+ value = to_json(value)
+ return f"{to_json(key)}: {value}"
+
+ tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
+ return dict(name=tool_name, arguments=tool_args_json)
+
+
+def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
+ """
+ Render tool schemas into the system prompt format.
+
+ Args:
+ tools: List of tool schema dicts (each with name, description, parameters).
+
+ Returns:
+ Formatted tools section string.
+ """
+ tools_json = [to_json(t) for t in tools]
+
+ return TOOLS_TEMPLATE.format(
+ tool_schemas="\n".join(tools_json),
+ dsml_token=dsml_token,
+ thinking_start_token=thinking_start_token,
+ thinking_end_token=thinking_end_token,
+ )
+
+
+def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
+ """Find the index of the last user/developer message."""
+ last_user_index = -1
+ for idx in range(len(messages) - 1, -1, -1):
+ if messages[idx].get("role") in ["user", "developer"]:
+ last_user_index = idx
+ break
+ return last_user_index
+
+
+# ============================================================
+# Message Rendering
+# ============================================================
+
+
+def render_message(
+ index: int,
+ messages: List[Dict[str, Any]],
+ thinking_mode: str,
+ drop_thinking: bool = True,
+ reasoning_effort: Optional[str] = None,
+) -> str:
+ """
+ Render a single message at the given index into its encoded string form.
+
+ This is the core function that converts each message in the conversation
+ into the DeepSeek-V4 format.
+
+ Args:
+ index: Index of the message to render.
+ messages: Full list of messages in the conversation.
+ thinking_mode: Either "chat" or "thinking".
+ drop_thinking: Whether to drop reasoning content from earlier turns.
+ reasoning_effort: Optional reasoning effort level ("max", "high", or None).
+
+ Returns:
+ Encoded string for this message.
+ """
+ assert 0 <= index < len(messages)
+ assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
+
+ prompt = ""
+ msg = messages[index]
+ last_user_idx = find_last_user_index(messages)
+
+ role = msg.get("role")
+ content = msg.get("content")
+ tools = msg.get("tools")
+ response_format = msg.get("response_format")
+ tool_calls = msg.get("tool_calls")
+ reasoning_content = msg.get("reasoning_content")
+ wo_eos = msg.get("wo_eos", False)
+
+ if tools:
+ tools = tools_from_openai_format(tools)
+ if tool_calls:
+ tool_calls = tool_calls_from_openai_format(tool_calls)
+
+ # Reasoning effort prefix (only at index 0 in thinking mode with max effort)
+ assert reasoning_effort in ["max", None, "high"], f"Invalid reasoning effort: {reasoning_effort}"
+ if index == 0 and thinking_mode == "thinking" and reasoning_effort == "max":
+ prompt += REASONING_EFFORT_MAX
+
+ if role == "system":
+ prompt += system_msg_template.format(content=content or "")
+ if tools:
+ prompt += "\n\n" + render_tools(tools)
+ if response_format:
+ prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
+
+ elif role == "developer":
+ assert content, f"Invalid message for role `{role}`: {msg}"
+
+ content_developer = USER_SP_TOKEN
+ content_developer += content
+
+ if tools:
+ content_developer += "\n\n" + render_tools(tools)
+ if response_format:
+ content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
+
+ prompt += user_msg_template.format(content=content_developer)
+
+ elif role == "user":
+ prompt += USER_SP_TOKEN
+
+ # Handle content blocks (tool results mixed with text)
+ content_blocks = msg.get("content_blocks")
+ if content_blocks:
+ parts = []
+ for block in content_blocks:
+ block_type = block.get("type")
+ if block_type == "text":
+ parts.append(block.get("text", ""))
+ elif block_type == "tool_result":
+ tool_content = block.get("content", "")
+ if isinstance(tool_content, list):
+ text_parts = []
+ for b in tool_content:
+ if b.get("type") == "text":
+ text_parts.append(b.get("text", ""))
+ else:
+ text_parts.append(f"[Unsupported {b.get('type')}]")
+ tool_content = "\n\n".join(text_parts)
+ parts.append(tool_output_template.format(content=tool_content))
+ else:
+ parts.append(f"[Unsupported {block_type}]")
+ prompt += "\n\n".join(parts)
+ else:
+ prompt += content or ""
+
+ elif role == "latest_reminder":
+ prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
+
+ elif role == "tool":
+ raise NotImplementedError(
+ "deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()"
+ )
+
+ elif role == "assistant":
+ thinking_part = ""
+ tc_content = ""
+
+ if tool_calls:
+ tc_list = [
+ tool_call_template.format(
+ dsml_token=dsml_token, name=tc.get("name"), arguments=encode_arguments_to_dsml(tc)
+ )
+ for tc in tool_calls
+ ]
+ tc_content += "\n\n" + tool_calls_template.format(
+ dsml_token=dsml_token,
+ tool_calls="\n".join(tc_list),
+ tc_block_name=tool_calls_block_name,
+ )
+
+ summary_content = content or ""
+ rc = reasoning_content or ""
+
+ # Check if previous message has a task - if so, this is a task output (no thinking)
+ prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
+
+ if thinking_mode == "thinking" and not prev_has_task:
+ if not drop_thinking or index > last_user_idx:
+ thinking_part = thinking_template.format(reasoning_content=rc) + thinking_end_token
+ else:
+ thinking_part = ""
+
+ if wo_eos:
+ prompt += assistant_msg_wo_eos_template.format(
+ reasoning=thinking_part,
+ content=summary_content,
+ tool_calls=tc_content,
+ )
+ else:
+ prompt += assistant_msg_template.format(
+ reasoning=thinking_part,
+ content=summary_content,
+ tool_calls=tc_content,
+ )
+ else:
+ raise NotImplementedError(f"Unknown role: {role}")
+
+ # Append transition tokens based on what follows
+ if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
+ return prompt
+
+ task = messages[index].get("task")
+ if task is not None:
+ # Task special token for internal classification tasks
+ assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
+ task_sp_token = DS_TASK_SP_TOKENS[task]
+
+ if task != "action":
+ # Non-action tasks: append task sp token directly after the message
+ prompt += task_sp_token
+ else:
+ # Action task: append Assistant + thinking token + action sp token
+ prompt += ASSISTANT_SP_TOKEN
+ prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
+ prompt += task_sp_token
+
+ elif messages[index].get("role") in ["user", "developer"]:
+ # Normal generation: append Assistant + thinking token
+ prompt += ASSISTANT_SP_TOKEN
+ if not drop_thinking and thinking_mode == "thinking":
+ prompt += thinking_start_token
+ elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
+ prompt += thinking_start_token
+ else:
+ prompt += thinking_end_token
+
+ return prompt
+
+
+# ============================================================
+# Preprocessing
+# ============================================================
+
+
+def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Merge tool messages into the preceding user message using content_blocks format.
+
+ DeepSeek-V4 does not have a standalone "tool" role; instead, tool results
+ are encoded as blocks within user messages.
+
+ This function converts a standard OpenAI-format conversation (with separate
+ "tool" role messages) into V4 format where tool results are merged into
+ user messages.
+
+ Args:
+ messages: List of message dicts in OpenAI format.
+
+ Returns:
+ Processed message list with tool messages merged into user messages.
+ """
+ merged: List[Dict[str, Any]] = []
+
+ for msg in messages:
+ msg = copy.deepcopy(msg)
+ role = msg.get("role")
+
+ if role == "tool":
+ # Convert tool message to a user message with tool_result block
+ tool_block = {
+ "type": "tool_result",
+ "tool_use_id": msg.get("tool_call_id", ""),
+ "content": msg.get("content", ""),
+ }
+ # Merge into previous message if it's already a user (merged tool)
+ if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
+ merged[-1]["content_blocks"].append(tool_block)
+ else:
+ merged.append(
+ {
+ "role": "user",
+ "content_blocks": [tool_block],
+ }
+ )
+ elif role == "user":
+ text_block = {"type": "text", "text": msg.get("content", "")}
+ if (
+ merged
+ and merged[-1].get("role") == "user"
+ and "content_blocks" in merged[-1]
+ and merged[-1].get("task") is None
+ ):
+ merged[-1]["content_blocks"].append(text_block)
+ else:
+ new_msg = {
+ "role": "user",
+ "content": msg.get("content", ""),
+ "content_blocks": [text_block],
+ }
+ # Preserve extra fields (task, wo_eos, mask, etc.)
+ for key in ("task", "wo_eos", "mask"):
+ if key in msg:
+ new_msg[key] = msg[key]
+ merged.append(new_msg)
+ else:
+ merged.append(msg)
+
+ return merged
+
+
+def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Sort tool_result blocks within user messages by the order of tool_calls
+ in the preceding assistant message.
+
+ Args:
+ messages: Preprocessed message list (after merge_tool_messages).
+
+ Returns:
+ Message list with sorted tool result blocks.
+ """
+ last_tool_call_order: Dict[str, int] = {}
+
+ for msg in messages:
+ role = msg.get("role")
+ if role == "assistant" and msg.get("tool_calls"):
+ last_tool_call_order = {}
+ for idx, tc in enumerate(msg["tool_calls"]):
+ tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
+ if tc_id:
+ last_tool_call_order[tc_id] = idx
+
+ elif role == "user" and msg.get("content_blocks"):
+ tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
+ if len(tool_blocks) > 1 and last_tool_call_order:
+ sorted_blocks = sorted(tool_blocks, key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0))
+ sorted_idx = 0
+ new_blocks = []
+ for block in msg["content_blocks"]:
+ if block.get("type") == "tool_result":
+ new_blocks.append(sorted_blocks[sorted_idx])
+ sorted_idx += 1
+ else:
+ new_blocks.append(block)
+ msg["content_blocks"] = new_blocks
+
+ return messages
+
+
+# ============================================================
+# Main Encoding Function
+# ============================================================
+
+
+def encode_messages(
+ messages: List[Dict[str, Any]],
+ thinking_mode: str,
+ context: Optional[List[Dict[str, Any]]] = None,
+ drop_thinking: bool = True,
+ add_default_bos_token: bool = True,
+ reasoning_effort: Optional[str] = None,
+) -> str:
+ """
+ Encode a list of messages into the DeepSeek-V4 prompt format.
+
+ This is the main entry point for encoding conversations. It handles:
+ - BOS token insertion
+ - Thinking mode with optional reasoning content dropping
+ - Tool message merging into user messages
+ - Multi-turn conversation context
+
+ Args:
+ messages: List of message dicts to encode.
+ thinking_mode: Either "chat" or "thinking".
+ context: Optional preceding context messages (already encoded prefix).
+ drop_thinking: If True, drop reasoning_content from earlier assistant turns
+ (only keep reasoning for messages after the last user message).
+ add_default_bos_token: Whether to prepend BOS token at conversation start.
+ reasoning_effort: Optional reasoning effort level ("max", "high", or None).
+
+ Returns:
+ The encoded prompt string.
+ """
+ context = context if context else []
+
+ # Preprocess: merge tool messages and sort tool results
+ messages = merge_tool_messages(messages)
+ messages = sort_tool_results_by_call_order(context + messages)[len(context) :]
+ if context:
+ context = merge_tool_messages(context)
+ context = sort_tool_results_by_call_order(context)
+
+ full_messages = context + messages
+
+ prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
+
+ # Resolve drop_thinking: if any message has tools defined, don't drop thinking
+ effective_drop_thinking = drop_thinking
+ if any(m.get("tools") for m in full_messages):
+ effective_drop_thinking = False
+
+ if thinking_mode == "thinking" and effective_drop_thinking:
+ full_messages = _drop_thinking_messages(full_messages)
+ # After dropping, recalculate how many messages to render
+ # (context may have shrunk too)
+ num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
+ context_len = len(full_messages) - num_to_render
+ else:
+ num_to_render = len(messages)
+ context_len = len(context)
+
+ for idx in range(num_to_render):
+ prompt += render_message(
+ idx + context_len,
+ full_messages,
+ thinking_mode=thinking_mode,
+ drop_thinking=effective_drop_thinking,
+ reasoning_effort=reasoning_effort,
+ )
+
+ return prompt
+
+
+def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Drop reasoning_content and non-essential messages before the last user message.
+
+ Behavior:
+ - Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept.
+ - Messages at or after the last user index are always kept.
+ - Assistant messages before the last user get reasoning_content removed.
+ - Developer messages before the last user are dropped entirely.
+ """
+ last_user_idx = find_last_user_index(messages)
+ result = []
+ keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
+
+ for idx, msg in enumerate(messages):
+ role = msg.get("role")
+ if role in keep_roles or idx >= last_user_idx:
+ result.append(msg)
+ elif role == "assistant":
+ msg = copy.copy(msg)
+ msg.pop("reasoning_content", None)
+ result.append(msg)
+ # developer and other roles before last_user_idx are dropped
+
+ return result
+
+
+# ============================================================
+# Parsing (Decoding model output)
+# ============================================================
+
+
+def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
+ """
+ Read text from index until one of the stop strings is found.
+
+ Returns:
+ Tuple of (new_index, content_before_stop, matched_stop_string_or_None).
+ """
+ min_pos = len(text)
+ matched_stop = None
+
+ for s in stop:
+ pos = text.find(s, index)
+ if pos != -1 and pos < min_pos:
+ min_pos = pos
+ matched_stop = s
+
+ if matched_stop:
+ content = text[index:min_pos]
+ return min_pos + len(matched_stop), content, matched_stop
+ else:
+ content = text[index:]
+ return len(text), content, None
+
+
+def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
+ """
+ Parse DSML tool calls from text starting at the given index.
+
+ Args:
+ index: Starting position in text.
+ text: The full text to parse.
+
+ Returns:
+ Tuple of (new_index, last_stop_token, list_of_tool_call_dicts).
+ Each tool call dict has "name" and "arguments" keys.
+ """
+ tool_calls: List[Dict[str, Any]] = []
+ stop_token = None
+ tool_calls_end_token = f"{dsml_token}{tool_calls_block_name}>"
+
+ while index < len(text):
+ index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
+ if _ != ">\n":
+ raise ValueError(f"Tool call format error: expected '>\\n' but got '{_}'")
+
+ if stop_token == tool_calls_end_token:
+ break
+
+ if stop_token is None:
+ raise ValueError("Missing special token in tool calls")
+
+ index, tool_name_content, stop_token = _read_until_stop(
+ index, text, [f"<{dsml_token}parameter", f"{dsml_token}invoke"]
+ )
+
+ p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
+ if len(p_tool_name) != 1:
+ raise ValueError(f"Tool name format error: '{tool_name_content}'")
+ tool_name = p_tool_name[0]
+
+ tool_args: Dict[str, Tuple[str, str]] = {}
+ while stop_token == f"<{dsml_token}parameter":
+ index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
+
+ param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
+ if len(param_kv) != 1:
+ raise ValueError(f"Parameter format error: '{param_content}'")
+ param_name, string, param_value = param_kv[0]
+
+ if param_name in tool_args:
+ raise ValueError(f"Duplicate parameter name: '{param_name}'")
+ tool_args[param_name] = (param_value, string)
+
+ index, content, stop_token = _read_until_stop(
+ index, text, [f"<{dsml_token}parameter", f"{dsml_token}invoke"]
+ )
+ if content != ">\n":
+ raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
+
+ tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
+ tool_calls.append(tool_call)
+
+ return index, stop_token, tool_calls
+
+
+def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
+ """
+ Parse a model completion text into a structured assistant message.
+
+ This function takes the raw text output from the model (a single assistant turn)
+ and extracts:
+ - reasoning_content (thinking block)
+ - content (summary/response)
+ - tool_calls (if any)
+
+ NOTE: This function is designed to parse only correctly formatted strings and
+ will raise ValueError for malformed output.
+
+ Args:
+ text: The raw completion text (including EOS token).
+ thinking_mode: Either "chat" or "thinking".
+
+ Returns:
+ Dict with keys: "role", "content", "reasoning_content", "tool_calls".
+ tool_calls are in OpenAI format.
+ """
+ summary_content, reasoning_content, tool_calls = "", "", []
+ index, stop_token = 0, None
+ tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
+
+ is_thinking = thinking_mode == "thinking"
+ is_tool_calling = False
+
+ if is_thinking:
+ index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
+ reasoning_content = content_delta
+ assert stop_token == thinking_end_token, "Invalid thinking format: missing "
+
+ index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
+ summary_content = content_delta
+ if stop_token == tool_calls_start_token:
+ is_tool_calling = True
+ else:
+ assert stop_token == eos_token, "Invalid format: missing EOS token"
+
+ if is_tool_calling:
+ index, stop_token, tool_calls = parse_tool_calls(index, text)
+
+ index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
+ assert not tool_ends_text, "Unexpected content after tool calls"
+
+ assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
+
+ for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
+ assert (
+ sp_token not in summary_content and sp_token not in reasoning_content
+ ), f"Unexpected special token '{sp_token}' in content"
+
+ return {
+ "role": "assistant",
+ "content": summary_content,
+ "reasoning_content": reasoning_content,
+ "tool_calls": tool_calls_to_openai_format(tool_calls),
+ }
diff --git a/lightllm/models/deepseek_v4/infer_struct.py b/lightllm/models/deepseek_v4/infer_struct.py
new file mode 100644
index 000000000..e14e84c96
--- /dev/null
+++ b/lightllm/models/deepseek_v4/infer_struct.py
@@ -0,0 +1,98 @@
+import torch
+from lightllm.common.basemodel import InferStateInfo
+from lightllm.common.req_manager import DeepseekV4ReqManager
+from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager
+
+
+class DeepseekV4InferStateInfo(InferStateInfo):
+ req_manager: DeepseekV4ReqManager
+ mem_manager: DeepseekV4MemoryManager
+
+ """Per-token interleaved-rope cos/sin for the two rope variants (sliding / compressed), following
+ the gemma4 two-variant convention (_cos_cached_* -> position_cos_*). The full rope tables are
+ model constants and live on the model / layer infers, not here."""
+
+ def __init__(self):
+ super().__init__()
+ self.position_cos_sliding = None
+ self.position_sin_sliding = None
+ self.position_cos_compress = None
+ self.position_sin_compress = None
+ # layer-independent sparse-index metadata, built once per forward in init_some_extra_state
+ # (None until then so copy_for_cuda_graph's tensor-attr loop skips them).
+ self.dsv4_sparse_req_idx = None
+ self.dsv4_swa_indices = None
+ self.dsv4_swa_lengths = None
+ self.dsv4_c128_indices = None
+ self.dsv4_c128_lengths = None
+ self.dsv4_workspace = None
+ # token -> batch-position map for the compressor; built per prefill forward in init_some_extra_state.
+ self._dsv4_token_to_batch_idx = None
+ # lazily-built (first c4 layer) cache of layer-independent paged-c4 metadata; reused by the
+ # other c4 layers in the same forward. Plain tuple (not a tensor attr) so copy_for_cuda_graph
+ # ignores it -- it's a capture-time wiring of layer0->others, not a staged graph input.
+ self._c4_paged_meta = None
+
+ def _dsv4_index_max_kv_seq_len(self, model):
+ if (
+ not self.is_prefill
+ and model.graph is not None
+ and model.graph.can_run(self.batch_size, self.max_kv_seq_len)
+ ):
+ return model.graph.graph_max_len_in_batch
+ return self.max_kv_seq_len
+
+ def init_some_extra_state(self, model):
+ self._c4_paged_meta = None # reset per forward before any c4 layer runs
+ super().init_some_extra_state(model) # sets position_ids, b_q_seq_len, b_q_start_loc (prefill)
+ pos = self.position_ids
+ self.position_cos_sliding = torch.index_select(model._cos_cached_sliding, 0, pos) # [T, rope_dim//2]
+ self.position_sin_sliding = torch.index_select(model._sin_cached_sliding, 0, pos)
+ self.position_cos_compress = torch.index_select(model._cos_cached_compress, 0, pos)
+ self.position_sin_compress = torch.index_select(model._sin_cached_compress, 0, pos)
+ # Per-token request id (decode: one token per req; prefill: ragged -> repeat by q-len).
+ # Layer-independent; the swa kernel + build_metadata's c4/c128 readers all reuse it.
+ if self.is_prefill:
+ self.dsv4_sparse_req_idx = torch.repeat_interleave(self.b_req_idx, self.b_q_seq_len.long())
+ self._dsv4_token_to_batch_idx = torch.repeat_interleave(
+ torch.arange(self.b_req_idx.shape[0], device=self.b_req_idx.device),
+ self.b_q_seq_len.long(),
+ output_size=pos.numel(),
+ ).to(torch.int32)
+ else:
+ self.dsv4_sparse_req_idx = self.b_req_idx
+ self._dsv4_token_to_batch_idx = None
+ # Sliding-window indices are layer-independent, so build them once into the model workspace.
+ from lightllm.models.deepseek_v4.triton_kernel.build_swa_index_dsv4 import build_swa_index
+
+ workspace = model.dsv4_workspace
+ self.dsv4_workspace = workspace
+ self.dsv4_swa_indices, self.dsv4_swa_lengths = workspace.swa(self.microbatch_index, pos.numel())
+ self.dsv4_swa_indices, self.dsv4_swa_lengths = build_swa_index(
+ req_idx=self.dsv4_sparse_req_idx,
+ positions=self.position_ids,
+ req_to_token_indexs=self.req_manager.req_to_token_indexs,
+ full_to_swa_indexs=self.mem_manager.full_to_swa_indexs,
+ swa_index=self.dsv4_swa_indices,
+ swa_length=self.dsv4_swa_lengths,
+ )
+ from lightllm.models.deepseek_v4.triton_kernel.build_compress_index_dsv4 import build_compress_index
+
+ cap = workspace.compress_cap(self._dsv4_index_max_kv_seq_len(model), 128)
+ self.dsv4_c128_indices, self.dsv4_c128_lengths = workspace.c128(self.microbatch_index, pos.numel(), cap)
+ build_compress_index(
+ self.dsv4_sparse_req_idx,
+ self.position_ids,
+ self.req_manager.req_to_token_indexs,
+ self.mem_manager.full_to_c128_indexs,
+ 128,
+ self.dsv4_c128_indices,
+ self.dsv4_c128_lengths,
+ )
+ # prefill-cudagraph 桶填充的 HOLD 尾请求的 q 行数。其注意力读 HOLD 槽位(内容被并发写
+ # 竞争,每轮不同),输出必须清零,否则 pad 行 hidden 不确定 -> MoE 路由抖动 -> 共享 expert
+ # 批次组成变化 -> 真实行 GEMM 归约顺序变化(ulp 级),44 层放大后翻转低置信 token。
+ self._dsv4_prefill_pad_q_len = 0
+ if self.is_prefill and self.b_req_idx.numel() > 0:
+ if int(self.b_req_idx[-1].item()) == self.req_manager.HOLD_REQUEST_ID:
+ self._dsv4_prefill_pad_q_len = int((self.b_seq_len[-1] - self.b_ready_cache_len[-1]).item())
diff --git a/lightllm/models/deepseek_v4/layer_infer/__init__.py b/lightllm/models/deepseek_v4/layer_infer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/deepseek_v4/layer_infer/compressor.py b/lightllm/models/deepseek_v4/layer_infer/compressor.py
new file mode 100644
index 000000000..695ac33bc
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_infer/compressor.py
@@ -0,0 +1,473 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+from triton.language.extra import libdevice
+
+from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager
+from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import (
+ DSV4_C4_STATE_RING,
+ DSV4_C128_STATE_RING,
+ DSV4_SWA_PAGE_SIZE,
+)
+
+
+@dataclass
+class CoreCompressorMetadata:
+ layer_idx: int
+ compress_ratio: int
+ out_slots: torch.Tensor
+ mem_index: torch.Tensor
+ state_buffer: torch.Tensor
+ out_buffer: torch.Tensor
+ out_page_size: int
+ position_ids: torch.Tensor
+ b_req_idx: torch.Tensor
+ b_seq_len: torch.Tensor
+ b_ready_cache_len: Optional[torch.Tensor]
+ b_q_start_loc: Optional[torch.Tensor]
+ req_to_token_indexs: torch.Tensor
+ full_to_swa_indexs: torch.Tensor
+ token_to_batch_idx: Optional[torch.Tensor]
+ kv_score: Optional[torch.Tensor]
+ is_prefill: bool
+
+
+@triton.jit
+def _add_ape_to_kv_score_kernel(
+ kv_score,
+ kv_score_stride0,
+ kv_score_stride1,
+ ape,
+ ape_stride0,
+ positions,
+ STATE_WIDTH: tl.constexpr,
+ COMPRESS_RATIO: tl.constexpr,
+ BLOCK: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ offs = tl.arange(0, BLOCK)
+ mask = offs < STATE_WIDTH
+
+ position = tl.load(positions + token_idx)
+ ape_row = position % COMPRESS_RATIO
+ score = tl.load(kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1, mask=mask)
+ ape_value = tl.load(ape + ape_row * ape_stride0 + offs, mask=mask)
+ tl.store(
+ kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1,
+ score + ape_value,
+ mask=mask,
+ )
+ return
+
+
+@triton.jit
+def _save_partial_states_kernel(
+ kv_score,
+ kv_score_stride0,
+ kv_score_stride1,
+ positions,
+ token_to_batch_idx,
+ b_req_idx,
+ b_seq_len,
+ mem_index,
+ full_to_swa,
+ state_buffer,
+ STATE_WIDTH: tl.constexpr,
+ STATE_LAST_DIM: tl.constexpr,
+ COMPRESS_RATIO: tl.constexpr,
+ IS_C4: tl.constexpr,
+ IS_PREFILL: tl.constexpr,
+ SWA_PAGE_SIZE: tl.constexpr,
+ STATE_RING: tl.constexpr,
+ BLOCK: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ batch_idx = tl.load(token_to_batch_idx + token_idx) if IS_PREFILL else token_idx
+ position = tl.load(positions + token_idx)
+ seq_len = tl.load(b_seq_len + batch_idx)
+
+ if IS_C4:
+ same_page_next = (position % SWA_PAGE_SIZE) + STATE_RING < SWA_PAGE_SIZE
+ if same_page_next and position + STATE_RING < seq_len:
+ return
+ else:
+ if position + COMPRESS_RATIO < seq_len:
+ return
+
+ full_slot = tl.load(mem_index + token_idx).to(tl.int64)
+ swa_slot = tl.load(full_to_swa + full_slot).to(tl.int64)
+ if swa_slot < 0:
+ return
+ state_row = (swa_slot // SWA_PAGE_SIZE) * STATE_RING + (swa_slot % STATE_RING)
+
+ offs = tl.arange(0, BLOCK)
+ mask = offs < STATE_WIDTH
+ kv = tl.load(kv_score + token_idx * kv_score_stride0 + offs * kv_score_stride1, mask=mask)
+ score = tl.load(kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1, mask=mask)
+ state_base = state_buffer + state_row * STATE_LAST_DIM
+ tl.store(state_base + offs, kv, mask=mask)
+ tl.store(state_base + STATE_WIDTH + offs, score, mask=mask)
+ return
+
+
+@triton.jit
+def _fused_compress_norm_rope_insert_kernel(
+ kv_score,
+ kv_score_stride0,
+ kv_score_stride1,
+ state_buffer,
+ positions,
+ token_to_batch_idx,
+ b_req_idx,
+ b_seq_len,
+ b_ready_cache_len,
+ b_q_start_loc,
+ req_to_token,
+ req_to_token_stride0,
+ full_to_swa,
+ out_slots,
+ norm_weight,
+ rms_eps,
+ cos_table,
+ cos_stride0,
+ sin_table,
+ sin_stride0,
+ out_buffer,
+ HEAD_DIM: tl.constexpr,
+ STATE_WIDTH: tl.constexpr,
+ STATE_LAST_DIM: tl.constexpr,
+ COMPRESS_RATIO: tl.constexpr,
+ WINDOW_SIZE: tl.constexpr,
+ IS_C4: tl.constexpr,
+ IS_PREFILL: tl.constexpr,
+ SWA_PAGE_SIZE: tl.constexpr,
+ STATE_RING: tl.constexpr,
+ ROPE_HEAD_DIM: tl.constexpr,
+ FP8_MAX: tl.constexpr,
+ SCALE_MIN: tl.constexpr,
+ NOPE_DIM: tl.constexpr,
+ QUANT_BLOCK: tl.constexpr,
+ SCALE_BYTES: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ BYTES_PER_PAGE: tl.constexpr,
+ BLOCK: tl.constexpr,
+ OUTPUT_BF16: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ out_slot = tl.load(out_slots + token_idx).to(tl.int64)
+ if out_slot < 0:
+ return
+
+ position = tl.load(positions + token_idx)
+ if (position + 1) % COMPRESS_RATIO != 0:
+ return
+
+ batch_idx = tl.load(token_to_batch_idx + token_idx) if IS_PREFILL else token_idx
+ req_idx = tl.load(b_req_idx + batch_idx).to(tl.int64)
+ seq_len = tl.load(b_seq_len + batch_idx)
+ if IS_PREFILL:
+ ready_len = tl.load(b_ready_cache_len + batch_idx)
+ q_start = tl.load(b_q_start_loc + batch_idx)
+ else:
+ ready_len = position
+ q_start = token_idx
+
+ token_offsets = tl.arange(0, WINDOW_SIZE)
+ start = position - WINDOW_SIZE + 1
+ gather_pos = start + token_offsets
+ valid_pos = (gather_pos >= 0) & (gather_pos < seq_len)
+ use_current = (gather_pos >= ready_len) & valid_pos if IS_PREFILL else gather_pos == position
+ current_idx = q_start + (gather_pos - ready_len) if IS_PREFILL else token_idx + token_offsets * 0
+
+ if IS_C4:
+ full_slot = tl.load(
+ req_to_token + req_idx * req_to_token_stride0 + gather_pos,
+ mask=valid_pos & (~use_current),
+ other=0,
+ ).to(tl.int64)
+ swa_slot = tl.load(full_to_swa + full_slot, mask=valid_pos & (~use_current), other=-1).to(tl.int64)
+ state_row = (swa_slot // SWA_PAGE_SIZE) * STATE_RING + (swa_slot % STATE_RING)
+ state_valid = valid_pos & (~use_current) & (swa_slot >= 0)
+ head_offset = tl.where(token_offsets >= COMPRESS_RATIO, HEAD_DIM, 0)
+ else:
+ full_slot = tl.load(
+ req_to_token + req_idx * req_to_token_stride0 + gather_pos,
+ mask=valid_pos & (~use_current),
+ other=0,
+ ).to(tl.int64)
+ swa_slot = tl.load(full_to_swa + full_slot, mask=valid_pos & (~use_current), other=-1).to(tl.int64)
+ state_row = (swa_slot // SWA_PAGE_SIZE) * STATE_RING + (swa_slot % STATE_RING)
+ state_valid = valid_pos & (~use_current) & (swa_slot >= 0)
+ head_offset = token_offsets * 0
+
+ offs = tl.arange(0, BLOCK)
+ dim_mask = offs < HEAD_DIM
+ current_mask = use_current[:, None] & dim_mask[None, :]
+ state_mask = state_valid[:, None] & dim_mask[None, :]
+
+ cur_kv = tl.load(
+ kv_score + current_idx[:, None] * kv_score_stride0 + (head_offset[:, None] + offs[None, :]) * kv_score_stride1,
+ mask=current_mask,
+ other=0.0,
+ )
+ cur_score = tl.load(
+ kv_score
+ + current_idx[:, None] * kv_score_stride0
+ + (STATE_WIDTH + head_offset[:, None] + offs[None, :]) * kv_score_stride1,
+ mask=current_mask,
+ other=float("-inf"),
+ )
+ state_kv = tl.load(
+ state_buffer + state_row[:, None] * STATE_LAST_DIM + head_offset[:, None] + offs[None, :],
+ mask=state_mask,
+ other=0.0,
+ )
+ state_score = tl.load(
+ state_buffer + state_row[:, None] * STATE_LAST_DIM + STATE_WIDTH + head_offset[:, None] + offs[None, :],
+ mask=state_mask,
+ other=float("-inf"),
+ )
+
+ kv = tl.where(current_mask, cur_kv, state_kv)
+ score = tl.where(current_mask, cur_score, state_score)
+ score = tl.softmax(score, dim=0)
+ compressed_kv = tl.sum(kv * score, axis=0)
+
+ rms_w = tl.load(norm_weight + offs, mask=dim_mask, other=0.0)
+ variance = tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_DIM
+ rrms = tl.rsqrt(variance + rms_eps)
+ normed = compressed_kv * rrms * rms_w
+
+ num_pairs: tl.constexpr = BLOCK // 2
+ nope_pairs: tl.constexpr = NOPE_DIM // 2
+ pair_2d = tl.reshape(normed, (num_pairs, 2))
+ even, odd = tl.split(pair_2d)
+ pair_idx = tl.arange(0, num_pairs)
+ rope_pair_local = pair_idx - nope_pairs
+ is_rope_pair = rope_pair_local >= 0
+ cs_idx = tl.maximum(rope_pair_local, 0)
+ compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO
+ cos_v = tl.load(cos_table + compressed_pos * cos_stride0 + cs_idx, mask=is_rope_pair, other=1.0)
+ sin_v = tl.load(sin_table + compressed_pos * sin_stride0 + cs_idx, mask=is_rope_pair, other=0.0)
+ new_even = even * cos_v - odd * sin_v
+ new_odd = odd * cos_v + even * sin_v
+ rotated = tl.interleave(new_even, new_odd)
+
+ if OUTPUT_BF16:
+ # indexer-K path: emit the post-rope full HEAD_DIM vector as dense bf16 (token-indexed),
+ # leaving the fp8 single-amax pack to destindex_copy_indexer_k_dsv4 (the c4_indexer_pool
+ # ABI differs from the latent slab: whole-vector fp8 + one fp32 scale, no bf16 rope tail).
+ tl.store(out_buffer + token_idx * HEAD_DIM + offs, rotated.to(tl.bfloat16), mask=dim_mask)
+ return
+
+ page = out_slot // PAGE_SIZE
+ token_in_page = out_slot % PAGE_SIZE
+ data_base = page * BYTES_PER_PAGE + token_in_page * (NOPE_DIM + ROPE_HEAD_DIM * 2)
+ scale_base = page * BYTES_PER_PAGE + PAGE_SIZE * (NOPE_DIM + ROPE_HEAD_DIM * 2) + token_in_page * SCALE_BYTES
+
+ n_quant_blocks: tl.constexpr = BLOCK // QUANT_BLOCK
+ n_nope_blocks: tl.constexpr = NOPE_DIM // QUANT_BLOCK
+ quant_input = normed.to(tl.bfloat16).to(tl.float32)
+ quant_2d = tl.reshape(quant_input, (n_quant_blocks, QUANT_BLOCK))
+ abs_2d = tl.abs(quant_2d)
+ block_absmax = tl.max(abs_2d, axis=1)
+ scale_exp = tl.ceil(libdevice.log2(tl.maximum(block_absmax / FP8_MAX, SCALE_MIN))).to(tl.int32)
+ scale = ((scale_exp + 127) << 23).to(tl.float32, bitcast=True)
+ kv_fp8 = tl.clamp(quant_2d / scale[:, None], -FP8_MAX, FP8_MAX).to(tl.float8e4nv)
+ kv_u8 = tl.reshape(kv_fp8.to(tl.uint8, bitcast=True), (BLOCK,))
+ tl.store(out_buffer + data_base + offs, kv_u8, mask=offs < NOPE_DIM)
+
+ scale_idx = tl.arange(0, SCALE_BYTES)
+ scale_bytes = tl.where(scale_idx < n_nope_blocks, scale_exp + 127, 0).to(tl.uint8)
+ tl.store(out_buffer + scale_base + scale_idx, scale_bytes)
+
+ rope_local = offs - NOPE_DIM
+ rope_mask = (offs >= NOPE_DIM) & dim_mask
+ rope_ptr = (out_buffer + data_base + NOPE_DIM).to(tl.pointer_type(tl.bfloat16))
+ tl.store(rope_ptr + rope_local, rotated.to(tl.bfloat16), mask=rope_mask)
+ return
+
+
+def prepare_compress_states(*, infer_state, layer_idx: int, compress_ratio: int, is_in_indexer: bool = False):
+ if compress_ratio == 0:
+ return None
+
+ mem_manager: DeepseekV4MemoryManager = infer_state.mem_manager
+ if is_in_indexer:
+ # c4 Lightning-Indexer key compression: same window/state machinery as the c4 latent
+ # compressor but with index_head_dim, a separate state pool, and a DENSE bf16 scratch
+ # out_buffer (the kernel's OUTPUT_BF16 path); the fp8 pack into c4_indexer_pool is done
+ # afterwards by pack_indexer_k_to_cache.
+ assert compress_ratio == 4, "只有 c4(CSA) 层有 indexer-K"
+ out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)]
+ state_buffer = mem_manager.get_c4_indexer_state_buffer(layer_idx)
+ out_buffer = torch.empty(
+ (infer_state.mem_index.numel(), mem_manager.indexer_head_dim),
+ dtype=torch.bfloat16,
+ device=infer_state.mem_index.device,
+ )
+ out_page_size = 1 # unused under OUTPUT_BF16 (token-indexed dense scratch, not paged)
+ else:
+ if compress_ratio == 4:
+ out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)]
+ state_buffer = mem_manager.get_c4_state_buffer(layer_idx)
+ out_pool = mem_manager.c4_pool
+ elif compress_ratio == 128:
+ out_slots = mem_manager.full_to_c128_indexs[infer_state.mem_index.long().reshape(-1)]
+ state_buffer = mem_manager.get_c128_state_buffer(layer_idx)
+ out_pool = mem_manager.c128_pool
+ else:
+ raise AssertionError(f"invalid DeepSeek-V4 compress ratio {compress_ratio}")
+ out_buffer = mem_manager.get_compressed_kv_buffer(layer_idx)
+ out_page_size = out_pool.page_size
+
+ token_to_batch_idx = infer_state.b_req_idx
+ if infer_state.is_prefill:
+ token_to_batch_idx = getattr(infer_state, "_dsv4_token_to_batch_idx", None)
+ if token_to_batch_idx is None or token_to_batch_idx.numel() != infer_state.position_ids.numel():
+ q_lens = (infer_state.b_seq_len - infer_state.b_ready_cache_len).to(torch.long)
+ batch_idx = torch.arange(infer_state.b_req_idx.shape[0], device=infer_state.b_req_idx.device)
+ token_to_batch_idx = torch.repeat_interleave(
+ batch_idx, q_lens, output_size=infer_state.position_ids.numel()
+ ).to(torch.int32)
+ infer_state._dsv4_token_to_batch_idx = token_to_batch_idx
+
+ return CoreCompressorMetadata(
+ layer_idx=layer_idx,
+ compress_ratio=compress_ratio,
+ out_slots=out_slots,
+ mem_index=infer_state.mem_index,
+ state_buffer=state_buffer,
+ out_buffer=out_buffer,
+ out_page_size=out_page_size,
+ position_ids=infer_state.position_ids,
+ b_req_idx=infer_state.b_req_idx,
+ b_seq_len=infer_state.b_seq_len,
+ b_ready_cache_len=infer_state.b_ready_cache_len,
+ b_q_start_loc=infer_state.b_q_start_loc,
+ req_to_token_indexs=infer_state.req_manager.req_to_token_indexs,
+ full_to_swa_indexs=mem_manager.full_to_swa_indexs,
+ token_to_batch_idx=token_to_batch_idx,
+ kv_score=None,
+ is_prefill=infer_state.is_prefill,
+ )
+
+
+def prepare_partial_states(
+ *,
+ kv_score: torch.Tensor,
+ metadata: Optional[CoreCompressorMetadata],
+ ape: torch.Tensor,
+ compress_ratio: int,
+):
+ if metadata is None or kv_score.shape[0] == 0:
+ return
+ state_width = kv_score.shape[-1] // 2
+ _add_ape_to_kv_score_kernel[(kv_score.shape[0],)](
+ kv_score,
+ kv_score.stride(0),
+ kv_score.stride(1),
+ ape,
+ ape.stride(0),
+ metadata.position_ids,
+ STATE_WIDTH=state_width,
+ COMPRESS_RATIO=compress_ratio,
+ BLOCK=triton.next_power_of_2(state_width),
+ num_warps=4,
+ )
+ return
+
+
+def fused_compress(
+ *,
+ kv_score: torch.Tensor,
+ metadata: Optional[CoreCompressorMetadata],
+ norm_weight: torch.Tensor,
+ ape: torch.Tensor,
+ eps: float,
+ head_dim: int,
+ qk_rope_head_dim: int,
+ compress_ratio: int,
+ cos_table: torch.Tensor,
+ sin_table: torch.Tensor,
+ output_bf16: bool = False,
+):
+ if metadata is None or kv_score.shape[0] == 0:
+ return
+
+ state_width = kv_score.shape[-1] // 2
+ state_last_dim = metadata.state_buffer.shape[-1]
+ is_c4 = compress_ratio == 4
+ state_ring = DSV4_C4_STATE_RING if is_c4 else DSV4_C128_STATE_RING
+ block_state = triton.next_power_of_2(state_width)
+ block_head = triton.next_power_of_2(head_dim)
+
+ _fused_compress_norm_rope_insert_kernel[(kv_score.shape[0],)](
+ kv_score,
+ kv_score.stride(0),
+ kv_score.stride(1),
+ metadata.state_buffer,
+ metadata.position_ids,
+ metadata.token_to_batch_idx,
+ metadata.b_req_idx,
+ metadata.b_seq_len,
+ metadata.b_ready_cache_len if metadata.b_ready_cache_len is not None else metadata.b_seq_len,
+ metadata.b_q_start_loc if metadata.b_q_start_loc is not None else metadata.b_seq_len,
+ metadata.req_to_token_indexs,
+ metadata.req_to_token_indexs.stride(0),
+ metadata.full_to_swa_indexs,
+ metadata.out_slots,
+ norm_weight,
+ eps,
+ cos_table,
+ cos_table.stride(0),
+ sin_table,
+ sin_table.stride(0),
+ metadata.out_buffer,
+ HEAD_DIM=head_dim,
+ STATE_WIDTH=state_width,
+ STATE_LAST_DIM=state_last_dim,
+ COMPRESS_RATIO=compress_ratio,
+ WINDOW_SIZE=compress_ratio * (2 if is_c4 else 1),
+ IS_C4=is_c4,
+ IS_PREFILL=metadata.is_prefill,
+ SWA_PAGE_SIZE=DSV4_SWA_PAGE_SIZE,
+ STATE_RING=state_ring,
+ ROPE_HEAD_DIM=qk_rope_head_dim,
+ FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
+ SCALE_MIN=1e-4,
+ NOPE_DIM=head_dim - qk_rope_head_dim,
+ QUANT_BLOCK=64,
+ SCALE_BYTES=(head_dim - qk_rope_head_dim) // 64 + 1,
+ PAGE_SIZE=metadata.out_page_size,
+ BYTES_PER_PAGE=metadata.out_buffer.shape[-1],
+ BLOCK=block_head,
+ OUTPUT_BF16=output_bf16,
+ num_warps=4,
+ )
+
+ _save_partial_states_kernel[(kv_score.shape[0],)](
+ kv_score,
+ kv_score.stride(0),
+ kv_score.stride(1),
+ metadata.position_ids,
+ metadata.token_to_batch_idx,
+ metadata.b_req_idx,
+ metadata.b_seq_len,
+ metadata.mem_index,
+ metadata.full_to_swa_indexs,
+ metadata.state_buffer,
+ STATE_WIDTH=state_width,
+ STATE_LAST_DIM=state_last_dim,
+ COMPRESS_RATIO=compress_ratio,
+ IS_C4=is_c4,
+ IS_PREFILL=metadata.is_prefill,
+ SWA_PAGE_SIZE=DSV4_SWA_PAGE_SIZE,
+ STATE_RING=state_ring,
+ BLOCK=block_state,
+ num_warps=4,
+ )
+ return
diff --git a/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py
new file mode 100644
index 000000000..080ebabd8
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py
@@ -0,0 +1,75 @@
+import torch
+
+try:
+ import vllm.model_executor.layers.mhc # noqa: F401
+except Exception as e:
+ raise RuntimeError("DeepSeek-V4 requires vLLM mHC custom ops; failed to import vllm MHC kernels") from e
+
+
+# vllm DeepseekV4DecoderLayer.hc_post_alpha
+HC_POST_ALPHA = 2.0
+
+
+def hc_pre(residual, hc_fn, hc_scale, hc_base, rms_eps, hc_eps, sinkhorn_iters, norm_weight, norm_eps):
+ """Standalone hc_pre for the first layer. residual:[T, hc, dim] ->
+ (x[T,dim], residual, post_mix[T,hc,1], res_mix[T,hc,hc]); the sub-layer RMSNorm is fused via norm_weight."""
+ post_mix, res_mix, x = torch.ops.vllm.mhc_pre_tilelang(
+ residual=residual,
+ fn=hc_fn,
+ hc_scale=hc_scale,
+ hc_base=hc_base,
+ rms_eps=rms_eps,
+ hc_pre_eps=hc_eps,
+ hc_sinkhorn_eps=hc_eps,
+ hc_post_mult_value=HC_POST_ALPHA,
+ sinkhorn_repeat=sinkhorn_iters,
+ norm_weight=norm_weight,
+ norm_eps=norm_eps,
+ )
+ return x, residual, post_mix, res_mix
+
+
+def hc_fused_post_pre(
+ x, residual, post_mix, res_mix, hc_fn, hc_scale, hc_base, rms_eps, hc_eps, sinkhorn_iters, norm_weight, norm_eps
+):
+ """hc_post of the previous sub-layer fused with hc_pre of the next one (norm fused too).
+ Returns (x[T,dim], residual[T,hc,dim], post_mix, res_mix)."""
+ residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre_tilelang(
+ x=x,
+ residual=residual,
+ post_layer_mix=post_mix,
+ comb_res_mix=res_mix,
+ fn=hc_fn,
+ hc_scale=hc_scale,
+ hc_base=hc_base,
+ rms_eps=rms_eps,
+ hc_pre_eps=hc_eps,
+ hc_sinkhorn_eps=hc_eps,
+ hc_post_mult_value=HC_POST_ALPHA,
+ sinkhorn_repeat=sinkhorn_iters,
+ norm_weight=norm_weight,
+ norm_eps=norm_eps,
+ )
+ return x, residual, post_mix, res_mix
+
+
+def hc_post(x, residual, post_mix, res_mix):
+ """Complete the hc_post left pending by the last layer. -> streams [T, hc, dim]."""
+ return torch.ops.vllm.mhc_post_tilelang(x, residual, post_mix, res_mix)
+
+
+def hc_head(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, rms_eps, hc_eps, alloc_func):
+ """Final stream collapse before the lm_head. streams:[N, hc*dim] -> [N, dim]."""
+ out = alloc_func((streams.shape[0], dim), dtype=streams.dtype, device=streams.device)
+ torch.ops.vllm.hc_head_fused_kernel_tilelang(
+ streams.view(-1, hc_mult, dim).contiguous(),
+ hc_fn,
+ hc_scale,
+ hc_base,
+ out,
+ dim,
+ rms_eps,
+ hc_eps,
+ hc_mult,
+ )
+ return out
diff --git a/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py
new file mode 100644
index 000000000..8eddfb3b9
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py
@@ -0,0 +1,27 @@
+from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
+from .hyper_connection import hc_head, hc_post
+from ..infer_struct import DeepseekV4InferStateInfo
+
+
+class DeepseekV4PostLayerInfer(LlamaPostLayerInfer):
+ """Collapse the hc_mult residual streams (hc_head) to [T, hidden], then final norm + lm_head."""
+
+ def token_forward(self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight):
+ cfg = layer_weight.network_config_
+ if isinstance(input_embdings, tuple):
+ # truncated-layer runs (autotune warmup) end before the last layer's _hc_ffn_out
+ # collapse; finish the pending hc_post here.
+ streams = hc_post(*input_embdings)
+ input_embdings = streams.reshape(streams.shape[0], -1)
+ collapsed = hc_head(
+ input_embdings,
+ layer_weight.hc_head_fn_.weight,
+ layer_weight.hc_head_scale_.weight,
+ layer_weight.hc_head_base_.weight,
+ cfg["hc_mult"],
+ cfg["hidden_size"],
+ cfg["rms_norm_eps"],
+ cfg.get("hc_eps", 1e-6),
+ self.alloc_tensor,
+ )
+ return super().token_forward(collapsed, infer_state, layer_weight)
diff --git a/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py
new file mode 100644
index 000000000..b95f5a14a
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py
@@ -0,0 +1,24 @@
+import torch
+import torch.distributed as dist
+from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
+from lightllm.distributed.communication_op import all_reduce
+from ..infer_struct import DeepseekV4InferStateInfo
+
+
+class DeepseekV4PreLayerInfer(LlamaPreLayerInfer):
+ """Token embedding, then expand to the hc_mult parallel residual streams [T, hc_mult*hidden]."""
+
+ def __init__(self, network_config):
+ super().__init__(network_config)
+ self.hc_mult = network_config["hc_mult"]
+ return
+
+ def context_forward(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight):
+ input_embdings = super().context_forward(input_ids, infer_state, layer_weight)
+ t, hidden = input_embdings.shape
+ return input_embdings.unsqueeze(1).expand(t, self.hc_mult, hidden).reshape(t, self.hc_mult * hidden)
+
+ def token_forward(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight):
+ input_embdings = super().token_forward(input_ids, infer_state, layer_weight)
+ t, hidden = input_embdings.shape
+ return input_embdings.unsqueeze(1).expand(t, self.hc_mult, hidden).reshape(t, self.hc_mult * hidden)
diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py
new file mode 100644
index 000000000..664bc2357
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,775 @@
+import torch
+import torch.distributed as dist
+from lightllm.common.basemodel import TransformerLayerInferTpl
+from lightllm.common.basemodel.attention.base_att import AttControl
+from lightllm.distributed.communication_op import all_reduce
+from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer
+from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import DeepseekV4TransformerLayerWeight
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor
+from lightllm.utils.vllm_utils import vllm_ops
+from .hyper_connection import hc_pre, hc_fused_post_pre, hc_post
+from .compressor import fused_compress as fused_compress_op
+from .compressor import prepare_partial_states
+from .compressor import prepare_compress_states
+from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
+from ..infer_struct import DeepseekV4InferStateInfo
+import deep_gemm
+from lightllm.third_party.sglang_jit.dsv4 import topk_transform_512
+
+
+_C4_PREFILL_LOGITS_BUDGET_BYTES = 512 * 1024 * 1024
+
+
+class DeepseekV4TransformerLayerInfer(Deepseek3_2TransformerLayerInfer):
+ def __init__(self, layer_num, network_config):
+ TransformerLayerInferTpl.__init__(self, layer_num, network_config)
+ self.eps_ = network_config["rms_norm_eps"]
+ self.embed_dim_ = network_config["hidden_size"]
+ self.num_heads = network_config["num_attention_heads"]
+ self.head_dim_ = network_config["head_dim"]
+ self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
+ self.qk_nope_head_dim = self.head_dim_ - self.qk_rope_head_dim
+ self.v_head_dim = self.head_dim_
+ self.o_groups = network_config["o_groups"]
+ self.hc_mult = network_config["hc_mult"]
+ self.sinkhorn_iters = network_config["hc_sinkhorn_iters"]
+ self.hc_eps = network_config["hc_eps"]
+ self.compress_ratio = network_config["compress_ratios"][layer_num]
+ self.is_hash = layer_num < network_config["num_hash_layers"]
+ self.is_last_layer = layer_num == network_config["n_layer"] - 1
+ # complex64 rope table for this layer's variant (sliding / compressed); set by
+ # DeepseekV4TpPartModel._init_to_get_rotary once the tables are built. The full compress
+ # cos/sin tables (compressor entry rope uses entry positions, not token positions) are
+ # wired there too.
+ self.freqs_cis = None
+ self.cos_compress_table = None
+ self.sin_compress_table = None
+ self.num_experts_per_tok = network_config["num_experts_per_tok"]
+ self.routed_scaling_factor = network_config["routed_scaling_factor"]
+ self.swiglu_limit = network_config["swiglu_limit"]
+ self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5)
+ self.tp_q_head_num_ = self.num_heads // self.tp_world_size_
+ self.tp_groups = self.o_groups // self.tp_world_size_
+ self.enable_ep_moe = get_env_start_args().enable_ep_moe
+ self.compressor = CompressorInfer(
+ layer_idx=self.layer_num_, network_config=self.network_config_, tp_world_size=self.tp_world_size_
+ )
+ self.index_infer = DeepseekV4IndexInfer(
+ layer_idx=self.layer_num_, network_config=self.network_config_, tp_world_size=self.tp_world_size_
+ )
+ self.dsv4_prefill_aux_stream = None
+
+ # ------------------------------------------------------------------ forward (HC-threaded)
+ def _hc_attn_in(self, input_embdings, layer_weight: DeepseekV4TransformerLayerWeight):
+ """Layer input -> attention input (attn_norm fused). First layer gets the raw streams
+ and runs a standalone hc_pre; later layers get (x, residual, post_mix, res_mix) and fuse
+ the previous layer's ffn hc_post with this layer's attn hc_pre."""
+ if torch.is_tensor(input_embdings):
+ residual = input_embdings.view(-1, self.hc_mult, self.embed_dim_)
+ return hc_pre(
+ residual,
+ layer_weight.hc_attn_fn_.weight,
+ layer_weight.hc_attn_scale_.weight,
+ layer_weight.hc_attn_base_.weight,
+ self.eps_,
+ self.hc_eps,
+ self.sinkhorn_iters,
+ layer_weight.attn_norm_.weight,
+ self.eps_,
+ )
+ x, residual, post_mix, res_mix = input_embdings
+ return hc_fused_post_pre(
+ x,
+ residual,
+ post_mix,
+ res_mix,
+ layer_weight.hc_attn_fn_.weight,
+ layer_weight.hc_attn_scale_.weight,
+ layer_weight.hc_attn_base_.weight,
+ self.eps_,
+ self.hc_eps,
+ self.sinkhorn_iters,
+ layer_weight.attn_norm_.weight,
+ self.eps_,
+ )
+
+ def _hc_ffn_in(self, x, residual, post_mix, res_mix, layer_weight: DeepseekV4TransformerLayerWeight):
+ """Attention output -> ffn input (ffn_norm fused): fused attn hc_post + ffn hc_pre."""
+ return hc_fused_post_pre(
+ x,
+ residual,
+ post_mix,
+ res_mix,
+ layer_weight.hc_ffn_fn_.weight,
+ layer_weight.hc_ffn_scale_.weight,
+ layer_weight.hc_ffn_base_.weight,
+ self.eps_,
+ self.hc_eps,
+ self.sinkhorn_iters,
+ layer_weight.ffn_norm_.weight,
+ self.eps_,
+ )
+
+ def _hc_ffn_out(self, x, residual, post_mix, res_mix):
+ """Mid layers leave the ffn hc_post pending for the next layer's fused post+pre; the last
+ layer completes it and hands the flat streams [T, hc_mult*hidden] back to the model loop."""
+ if not self.is_last_layer:
+ return x, residual, post_mix, res_mix
+ streams = hc_post(x, residual, post_mix, res_mix)
+ return streams.reshape(streams.shape[0], -1)
+
+ def context_forward(
+ self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ x, residual, post_mix, res_mix = self._hc_attn_in(input_embdings, layer_weight)
+ x = self.context_attention_forward(x, infer_state, layer_weight)
+ x, residual, post_mix, res_mix = self._hc_ffn_in(x, residual, post_mix, res_mix, layer_weight)
+ x = self._ffn(x, infer_state, layer_weight)
+ out = self._hc_ffn_out(x, residual, post_mix, res_mix)
+ return out
+
+ def token_forward(
+ self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ x, residual, post_mix, res_mix = self._hc_attn_in(input_embdings, layer_weight)
+ x = self.token_attention_forward(x, infer_state, layer_weight)
+ x, residual, post_mix, res_mix = self._hc_ffn_in(x, residual, post_mix, res_mix, layer_weight)
+ x = self._ffn(x, infer_state, layer_weight)
+ return self._hc_ffn_out(x, residual, post_mix, res_mix)
+
+ # ------------------------------------------------------------------ shared projections / cache
+ def _select_rope(self, infer_state: DeepseekV4InferStateInfo):
+ if self.compress_ratio:
+ return infer_state.position_cos_compress, infer_state.position_sin_compress
+ return infer_state.position_cos_sliding, infer_state.position_sin_sliding
+
+ def _get_qkv(
+ self,
+ input: torch.Tensor,
+ infer_state: DeepseekV4InferStateInfo,
+ layer_weight: DeepseekV4TransformerLayerWeight,
+ ):
+ from lightllm.third_party.sglang_jit.dsv4 import fused_q_norm_rope
+
+ input = self._tpsp_allgather(input=input, infer_state=infer_state)
+ T = input.shape[0]
+ # wq_a and wkv share `input` -> one fused fp8 GEMM, split [q_lora_rank | head_dim]. qa is a
+ # row-strided view (rmsnorm honors stride(0)); kv feeds a sglang jit kernel -> contiguous.
+ qkv = layer_weight.wq_a_wkv_.mm(input)
+ qa = layer_weight.q_norm_(qkv[:, : -self.head_dim_], eps=self.eps_)
+ q_in = layer_weight.wq_b_.mm(qa).view(T, self.tp_q_head_num_, self.head_dim_)
+ # per-(token, head) weightless self-RMSNorm + interleaved rope on the last rope_dim dims,
+ # fused in one sglang dsv4 jit kernel (fp32 norm/rotation, bf16 in between -- same as eager).
+ q = self.alloc_tensor(q_in.shape, dtype=q_in.dtype, device=q_in.device)
+ fused_q_norm_rope(q_in, q, self.eps_, self.freqs_cis, infer_state.position_ids)
+ # kv: rmsnorm + rope + fp8 pack + scatter 进 swa 池,一个 sglang jit kernel 完成
+ # (同 sglang _compute_kv_to_cache),替代 eager norm/rope/cat + _post_cache_kv。
+ # bf16 kv 中间量没有其他消费者: flashmla 路径注意力读 cache,压缩器/indexer 取 x。
+ infer_state.mem_manager.pack_mla_kv_to_cache_fused_norm_rope(
+ layer_index=self.layer_num_,
+ mem_index=infer_state.mem_index,
+ kv=qkv[:, -self.head_dim_ :].contiguous(),
+ kv_weight=layer_weight.kv_norm_.weight,
+ eps=self.eps_,
+ freqs_cis=self.freqs_cis,
+ positions=infer_state.position_ids,
+ )
+ return q, qa, input
+
+ def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight):
+ # o: [T, tp_q_head_num_, head_dim_] after inverse rope -> grouped low-rank O -> [T, embed_dim_]
+ position_cos, position_sin = self._select_rope(infer_state)
+ rotary_emb_fwd(o[..., -self.qk_rope_head_dim :], None, position_cos, position_sin, inverse=True)
+ T = o.shape[0]
+ if layer_weight.o_proj_fp8:
+ # one group per rank -> a single fp8 GEMM (deepgemm .mm quantizes o to fp8 internally)
+ o = layer_weight.wo_a_.mm(o.reshape(T, -1)) # [T, o_lora]
+ else:
+ o = o.reshape(T, self.tp_groups, -1).transpose(0, 1).contiguous() # [groups, T, per_group_in]
+ o = layer_weight.wo_a_.bmm(o).transpose(0, 1).reshape(T, -1) # [T, groups*o_lora]
+ o = layer_weight.wo_b_.mm(o)
+ return self._tpsp_reduce(input=o, infer_state=infer_state)
+
+ # ------------------------------------------------------------------ attention (prefill)
+ def context_attention_forward(
+ self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ # _get_qkv writes the chunk's packed latent into the swa pool (fused kernel) before
+ # attention reads it back via full_to_swa indices (this custom forward bypasses the
+ # tpl _post_cache_kv path).
+ q, q_lora, full_x = self._get_qkv(x, infer_state, layer_weight)
+ o = self._context_attention_wrapper_run(q, q_lora, full_x, infer_state, layer_weight)
+ return self._get_o(o, infer_state, layer_weight)
+
+ def _context_attention_wrapper_run(
+ self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ if torch.cuda.is_current_stream_capturing():
+ q = q.contiguous()
+ q_lora = q_lora.contiguous()
+ x = x.contiguous()
+ _q = tensor_to_no_ref_tensor(q)
+ _q_lora = tensor_to_no_ref_tensor(q_lora)
+ _x = tensor_to_no_ref_tensor(x)
+
+ pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
+ pre_capture_graph.__exit__(None, None, None)
+
+ infer_state.prefill_cuda_graph_create_graph_obj()
+ infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
+ # Same graph-split output handoff as the template, but avoid its dry-run because
+ # DSV4 attention mutates compressor/cache state before returning.
+ o = self.alloc_tensor((q.shape[0], self.tp_q_head_num_, self.head_dim_), dtype=q.dtype, device=q.device)
+ _o = tensor_to_no_ref_tensor(o)
+
+ def att_func(new_infer_state: DeepseekV4InferStateInfo):
+ tmp_o = self._context_attention_kernel(_q, _q_lora, _x, new_infer_state, layer_weight, out=_o)
+ assert tmp_o.shape == _o.shape
+ if tmp_o.data_ptr() != _o.data_ptr():
+ _o.copy_(tmp_o)
+ return
+
+ infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=att_func, after_graph=pre_capture_graph)
+ return o
+
+ return self._context_attention_kernel(q, q_lora, x, infer_state, layer_weight)
+
+ def _compress_and_index(self, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight):
+ cos_table, sin_table = self.cos_compress_table, self.sin_compress_table
+ aux_stream = self.dsv4_prefill_aux_stream
+ if self.compress_ratio == 4 and aux_stream is not None and not torch.cuda.is_current_stream_capturing():
+ # _dsv4_token_to_batch_idx is built in init_some_extra_state (default stream, before this fork),
+ # so both the aux indexer-compressor and the main compressor read a ready, race-free tensor.
+ main_stream = torch.cuda.current_stream()
+ aux_stream.wait_stream(main_stream) # fork: aux waits for x / q_lora produced on main
+ with torch.cuda.stream(aux_stream):
+ # x / q_lora are main-allocated and read here -> record so the allocator won't reuse them.
+ x.record_stream(aux_stream)
+ q_lora.record_stream(aux_stream)
+ self.index_infer.write_indexer_k(
+ x, infer_state, layer_weight, cos_table, sin_table, use_custom_tensor_manager=False
+ )
+ meta = self.index_infer.build_metadata(
+ x, q_lora, infer_state, layer_weight, use_custom_tensor_manager=False
+ )
+ self.compressor.prepare_states(x, infer_state, layer_weight)
+ self.compressor.fused_compress(infer_state, layer_weight, cos_table, sin_table)
+ main_stream.wait_stream(aux_stream) # join before prefill_att reads the indices / latent KV
+ # extra_indices / extra_lengths were allocated on aux -> record on main so they survive until consumed.
+ for _t in (meta.get("extra_indices"), meta.get("extra_lengths")):
+ if _t is not None:
+ _t.record_stream(main_stream)
+ return meta
+
+ # serial fallback -- semantics identical to the original sequence.
+ self.compressor.prepare_states(x, infer_state, layer_weight)
+ self.compressor.fused_compress(infer_state, layer_weight, cos_table, sin_table)
+ # write c4 Lightning-Indexer keys BEFORE build_metadata so the scorer reads fresh+accumulated entries.
+ self.index_infer.write_indexer_k(x, infer_state, layer_weight, cos_table, sin_table)
+ return self.index_infer.build_metadata(x, q_lora, infer_state, layer_weight)
+
+ def _context_attention_kernel(
+ self,
+ q,
+ q_lora,
+ x,
+ infer_state: DeepseekV4InferStateInfo,
+ layer_weight: DeepseekV4TransformerLayerWeight,
+ out=None,
+ ):
+ meta = self._compress_and_index(q_lora, x, infer_state, layer_weight)
+ att_control = AttControl(
+ nsa_prefill=True,
+ nsa_prefill_dict={
+ "flashmla_kvcache": True,
+ "layer_index": self.layer_num_,
+ "compress_ratio": self.compress_ratio,
+ "head_dim_v": self.v_head_dim,
+ "softmax_scale": self.softmax_scale,
+ "attn_sink": layer_weight.attn_sink_.weight,
+ **meta,
+ },
+ )
+ out = infer_state.prefill_att_state.prefill_att(
+ q=q,
+ k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_),
+ v=None,
+ att_control=att_control,
+ out=out,
+ )
+ pad_q_len = getattr(infer_state, "_dsv4_prefill_pad_q_len", 0)
+ if pad_q_len:
+ # pad 行读 HOLD 槽位(参见 infer_struct._dsv4_prefill_pad_q_len),清零以保持确定性
+ out[-pad_q_len:] = 0
+ return out
+
+ # ------------------------------------------------------------------ attention (decode)
+ def token_attention_forward(
+ self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ q, q_lora, full_x = self._get_qkv(x, infer_state, layer_weight)
+ o = self._token_attention_kernel(q, q_lora, full_x, infer_state, layer_weight)
+ return self._get_o(o, infer_state, layer_weight)
+
+ def _token_attention_kernel(
+ self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ self.compressor.prepare_states(x, infer_state, layer_weight)
+ self.compressor.fused_compress(infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table)
+ self.index_infer.write_indexer_k(x, infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table)
+ meta = self.index_infer.build_metadata(x, q_lora, infer_state, layer_weight)
+ att_control = AttControl(
+ nsa_decode=True,
+ nsa_decode_dict={
+ "flashmla_kvcache": True,
+ "layer_index": self.layer_num_,
+ "compress_ratio": self.compress_ratio,
+ "head_dim_v": self.v_head_dim,
+ "softmax_scale": self.softmax_scale,
+ "attn_sink": layer_weight.attn_sink_.weight,
+ **meta,
+ },
+ )
+ return infer_state.decode_att_state.decode_att(
+ q=q,
+ k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_),
+ v=None,
+ att_control=att_control,
+ )
+
+ # ------------------------------------------------------------------ moe
+ def _routed_experts(self, x, weights, indices, layer_weight: DeepseekV4TransformerLayerWeight):
+ return layer_weight.experts_.experts_with_preselected(
+ input_tensor=x,
+ topk_weights=weights,
+ topk_ids=indices,
+ clamp_limit=float(self.swiglu_limit),
+ )
+
+ def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight):
+ x = x.view(-1, self.embed_dim_)
+ if not self.enable_ep_moe:
+ x = self._tpsp_allgather(input=x, infer_state=infer_state)
+
+ logits = layer_weight.gate_weight_.mm(x).float().contiguous()
+ weights, indices = self._select_experts(logits, infer_state, layer_weight)
+ # shared expert 必须先于 routed 计算: fp8 路径 (FuseMoeTriton) 的 fused_experts
+ # 是 inplace 的,_routed_experts 返回后 x 已被覆盖为 routed 输出。
+ # 复用 Llama 的 _ffn_tp: fused gate_up matmul + silu_and_mul triton kernel,无 swiglu clamp,
+ # 对齐参考 DeepseekV4MLP(=LlamaMLP)。swiglu_limit clamp 只属于 routed 专家 (见 _routed_experts)。
+ shared = self._ffn_tp(input=x, infer_state=infer_state, layer_weight=layer_weight)
+ routed = self._routed_experts(x, weights, indices, layer_weight)
+ if self.enable_ep_moe:
+ if self.tp_world_size_ > 1:
+ all_reduce(
+ shared,
+ op=dist.ReduceOp.SUM,
+ group=infer_state.dist_group,
+ async_op=False,
+ )
+ return routed + shared
+ out = routed + shared
+ return self._tpsp_reduce(input=out, infer_state=infer_state)
+
+ def _select_experts(
+ self, logits, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ return self._select_experts_vllm(logits, infer_state, layer_weight)
+
+ def _select_experts_vllm(
+ self, logits, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ M = logits.shape[0]
+ bias = None
+ input_tokens = None
+ hash_indices_table = None
+ indices_dtype = torch.int64
+ if self.is_hash:
+ hash_indices_table = layer_weight.gate_tid2eid_.weight
+ if not hash_indices_table.is_contiguous():
+ hash_indices_table = hash_indices_table.contiguous()
+ indices_dtype = hash_indices_table.dtype
+ input_tokens = infer_state.input_ids.to(dtype=indices_dtype).contiguous()
+ else:
+ bias = layer_weight.gate_bias_.weight
+
+ weights = self.alloc_tensor((M, self.num_experts_per_tok), dtype=torch.float32, device=logits.device)
+ indices = self.alloc_tensor((M, self.num_experts_per_tok), dtype=indices_dtype, device=logits.device)
+ token_expert_indices = self.alloc_tensor((M, self.num_experts_per_tok), dtype=torch.int32, device=logits.device)
+ vllm_ops.topk_hash_softplus_sqrt(
+ weights,
+ indices,
+ token_expert_indices,
+ logits,
+ True,
+ self.routed_scaling_factor,
+ bias,
+ input_tokens,
+ hash_indices_table,
+ )
+ return weights, indices.long()
+
+
+class CompressorInfer:
+ """Window-softmax compressor. is_in_indexer=False compresses the c4/c128 latent KV into the
+ paged fp8 slab (attention extra_k); is_in_indexer=True reuses the SAME machinery (mirroring
+ sglang's Compressor(is_in_indexer=...)) with the indexer weights/dims/state pool to produce the
+ per-c4-entry Lightning-Indexer keys, emitted as dense bf16 (OUTPUT_BF16) then fp8-packed into
+ c4_indexer_pool by the caller. Indexer mode is c4-only."""
+
+ def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int, is_in_indexer: bool = False):
+ super().__init__()
+ self.layer_idx_ = layer_idx
+ self.network_config_ = network_config
+ self.tp_world_size_ = tp_world_size
+ self.is_in_indexer = is_in_indexer
+ self.compress_ratio = network_config["compress_ratios"][layer_idx]
+ self.head_dim = network_config["head_dim"]
+ self.index_head_dim = network_config["index_head_dim"]
+ self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
+ self.eps = network_config["rms_norm_eps"]
+ self._metadata = None
+
+ def prepare_states(
+ self,
+ x: torch.Tensor,
+ infer_state: DeepseekV4InferStateInfo,
+ layer_weight: DeepseekV4TransformerLayerWeight,
+ use_custom_tensor_manager: bool = True,
+ ):
+ # use_custom_tensor_manager=False routes the .mm outputs through torch.empty (stream-aware)
+ # instead of the stream-blind global cache -- required when this runs on the prefill aux stream.
+ self._metadata = prepare_compress_states(
+ infer_state=infer_state,
+ layer_idx=self.layer_idx_,
+ compress_ratio=self.compress_ratio,
+ is_in_indexer=self.is_in_indexer,
+ )
+ if self._metadata is not None:
+ if self.is_in_indexer:
+ # fused wkv/wgate GEMM -> [T, 2*coff*idx_hd] in the [kv | score] layout directly
+ # (same as the attention compressor_wkv_gate_).
+ self._metadata.kv_score = layer_weight.idx_cmp_wkv_gate_.mm(
+ x, use_custom_tensor_mananger=use_custom_tensor_manager
+ ).float()
+ ape = layer_weight.idx_cmp_ape_.weight
+ else:
+ self._metadata.kv_score = layer_weight.compressor_wkv_gate_.mm(x).float()
+ ape = layer_weight.compressor_ape_.weight
+ prepare_partial_states(
+ kv_score=self._metadata.kv_score,
+ metadata=self._metadata,
+ ape=ape,
+ compress_ratio=self.compress_ratio,
+ )
+ return self._metadata
+
+ def fused_compress(
+ self,
+ infer_state: DeepseekV4InferStateInfo,
+ layer_weight: DeepseekV4TransformerLayerWeight,
+ cos_table: torch.Tensor,
+ sin_table: torch.Tensor,
+ ):
+ if self.compress_ratio == 0:
+ return None
+ metadata = self._metadata
+ if metadata is None:
+ raise RuntimeError("DeepSeek-V4 compressor.prepare_states must run before fused_compress")
+ if self.is_in_indexer:
+ norm_weight = layer_weight.idx_cmp_norm_.weight
+ ape = layer_weight.idx_cmp_ape_.weight
+ head_dim = self.index_head_dim
+ else:
+ norm_weight = layer_weight.compressor_norm_.weight
+ ape = layer_weight.compressor_ape_.weight
+ head_dim = self.head_dim
+ return fused_compress_op(
+ kv_score=metadata.kv_score,
+ metadata=metadata,
+ norm_weight=norm_weight,
+ ape=ape,
+ eps=self.eps,
+ head_dim=head_dim,
+ qk_rope_head_dim=self.qk_rope_head_dim,
+ compress_ratio=self.compress_ratio,
+ cos_table=cos_table,
+ sin_table=sin_table,
+ output_bf16=self.is_in_indexer,
+ )
+
+
+class DeepseekV4IndexInfer:
+ """Model-side builder for the FlashMLA sparse-index metadata. Mirrors deepseek3_2's NsaInfer
+ boundary (the model owns ALL index construction; the attention backend only forwards final
+ tensors to flash_mla.flash_mla_with_kvcache) AND its c4 implementation: hadamard'd fp8 q/K, a
+ ragged gather of the compressed c4 keys, deep_gemm.fp8_mqa_logits, then topk -- adapted for the
+ replicated indexer (no gather-q/all_reduce), the c4-compressed entry space, and topk-512 (no
+ inheritance only because of those data-shape differences). swa metadata is precomputed in
+ init_some_extra_state; this class owns the c4 entry gather (build_compress_index) AND the c4
+ Lightning-Indexer scoring (gather + deep_gemm.fp8_mqa_logits + topk). Holds only static per-layer
+ config; all per-request data flows in via args. Invoke from _context/_token_attention_kernel
+ (after compressor.fused_compress, before *_att) so the c4 scorer/topk keep the same cuda-graph
+ capture position they had when this lived in the backend. The indexer is replicated (no TP collective)."""
+
+ def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int):
+ self.layer_idx_ = layer_idx
+ self.compress_ratio = network_config["compress_ratios"][layer_idx]
+ self.index_topk = network_config["index_topk"]
+ self.index_head_dim = network_config["index_head_dim"]
+ self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
+ self.index_n_heads = network_config["index_n_heads"]
+ self.tp_world_size_ = tp_world_size
+ self.indexer_score_scale = self.index_head_dim ** -0.5
+ self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5
+ # c4 layers own a second compressor (is_in_indexer) that writes the Lightning-Indexer key
+ # pool every step; _c4_indices gathers it back + scores via deep_gemm.fp8_mqa_logits.
+ self.indexer_compressor = (
+ CompressorInfer(layer_idx, network_config, tp_world_size, is_in_indexer=True)
+ if self.compress_ratio == 4
+ else None
+ )
+
+ def write_indexer_k(
+ self,
+ x,
+ infer_state: DeepseekV4InferStateInfo,
+ layer_weight,
+ cos_table,
+ sin_table,
+ use_custom_tensor_manager=True,
+ ):
+ """c4-only: compress this step's tokens into per-c4-entry indexer keys and pack them into
+ c4_indexer_pool. MUST run before build_metadata so the scorer (gather + deep_gemm.fp8_mqa_logits)
+ reads the finished entries; runs every step (incl. in the decode graph) so keys accumulate for
+ later long-context scoring. No-op on c128 / dense layers."""
+ if self.compress_ratio != 4:
+ return
+ self.indexer_compressor.prepare_states(
+ x, infer_state, layer_weight, use_custom_tensor_manager=use_custom_tensor_manager
+ )
+ self.indexer_compressor.fused_compress(infer_state, layer_weight, cos_table, sin_table)
+ scratch = self.indexer_compressor._metadata.out_buffer # [T, index_head_dim] bf16 (group-end rows valid)
+ # Rotate K (post norm+rope) by the SAME 1/sqrt(d) Hadamard the q kernel applies, so
+ # (Hq)·(Hk)=q·k (H orthogonal) and the fp8 quant of K stays accurate.
+ from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform
+
+ scratch = hadamard_transform(scratch, scale=self.index_head_dim ** -0.5)
+ mem_manager = infer_state.mem_manager
+ positions = infer_state.position_ids
+ out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)]
+ # only group-end tokens finish a c4 entry; mask the rest to -1 so the packer skips them
+ # (mid-group tokens share the group's c4 slot -> avoids racing a finished slot).
+ completed = ((positions + 1) % 4 == 0) & (out_slots >= 0)
+ masked_slots = torch.where(completed, out_slots, torch.full_like(out_slots, -1)).to(torch.int32)
+ mem_manager.pack_indexer_k_to_cache(self.layer_idx_, masked_slots, scratch)
+
+ def build_metadata(
+ self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight, use_custom_tensor_manager=True
+ ):
+ """Return the final flash_mla index tensors for this layer's compress variant. swa indices and
+ the per-token req_idx are layer-independent and precomputed once in init_some_extra_state
+ (read here); only the c4 scorer is per-layer. The backend pairs these with the
+ (data-independent, layer-keyed) fp8 cache-byte views it owns."""
+ swa_indices = infer_state.dsv4_swa_indices.unsqueeze(1)
+ swa_lengths = infer_state.dsv4_swa_lengths
+ positions = infer_state.position_ids
+ extra_indices = extra_lengths = None
+ if self.compress_ratio == 4:
+ idx_q_fp8, weights = self._indexer_q_weight(
+ x, q_lora, infer_state, layer_weight, use_custom_tensor_manager=use_custom_tensor_manager
+ )
+ extra_indices, extra_lengths = self._c4_indices(infer_state, idx_q_fp8, weights, positions)
+ elif self.compress_ratio == 128:
+ extra_indices = infer_state.dsv4_c128_indices.unsqueeze(1)
+ extra_lengths = infer_state.dsv4_c128_lengths
+ return {
+ "swa_indices": swa_indices,
+ "swa_lengths": swa_lengths,
+ "extra_indices": extra_indices,
+ "extra_lengths": extra_lengths,
+ }
+
+ def _indexer_q_weight(
+ self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight, use_custom_tensor_manager=True
+ ):
+ """fp8 indexer q (mirrors deepseek3_2 NsaInfer): wq_b -> rope(last rope dims) -> 1/sqrt(d)
+ Hadamard -> per-token fp8 quant. Returns (idx_q_fp8 [T,H,d], weights [T,H]); the per-token q
+ fp8 scale and the head_dim^-0.5 * n_heads^-0.5 score scale are folded into weights -- the
+ deep_gemm.fp8_mqa_logits contract (fp8 q carries no companion scale). Replicated -> full heads."""
+ # Fused: wq_b mm -> rope(last rope dims) -> 1/sqrt(d) Hadamard -> per-token fp8 quant, with the
+ # per-token q scale + indexer_weight_scale folded into weights, all in ONE kernel (was 4 kernels:
+ # rotary_emb_fwd + hadamard_transform + act_quant + weights mul). freqs_cis is the compress rope
+ # table (same one the main compress-layer Q path uses); positions indexed inside the kernel.
+ from lightllm.third_party.sglang_jit.dsv4.elementwise import fused_q_indexer_rope_hadamard_quant
+
+ token_num = q_lora.shape[0]
+ if x.shape[0] != token_num:
+ raise RuntimeError(
+ f"DeepSeek-V4 indexer expects full-token hidden states, got x={x.shape[0]} q_lora={token_num}"
+ )
+ idx_q = layer_weight.idx_wq_b_.mm(q_lora, use_custom_tensor_mananger=use_custom_tensor_manager).view(
+ token_num, self.index_n_heads, self.index_head_dim
+ )
+ raw_w = layer_weight.idx_weights_proj_.mm(x, use_custom_tensor_mananger=use_custom_tensor_manager).view(
+ token_num, self.index_n_heads
+ ) # [T, H] raw
+ idx_q_fp8, weights = fused_q_indexer_rope_hadamard_quant(
+ idx_q, raw_w, self.indexer_weight_scale, self.freqs_cis, infer_state.position_ids
+ ) # fp8 [T,H,d]; weights [T,H,1] with q-scale + weight_scale folded
+ return idx_q_fp8, weights.squeeze(-1).contiguous()
+
+ def _c4_indices(self, infer_state: DeepseekV4InferStateInfo, idx_q_fp8, weights, positions):
+ """c4 scorer via the page-safe deep_gemm.fp8_paged_mqa_logits over the paged c4 indexer pool,
+ then masked topk-512 -> c4 slots. Fixed shapes (c4_cap pinned per graph bucket) keep the decode
+ cuda graph capturable."""
+ mem_manager = infer_state.mem_manager
+ workspace = infer_state.dsv4_workspace
+ index_topk = self.index_topk
+ max_entries = max(1, int(infer_state.max_kv_seq_len) // 4)
+ c4_cap = ((max_entries + 63) // 64) * 64
+
+ # entry space fits the budget -> every causal entry is selected; no scoring needed. The
+ # captured decode graph (graph_max_len -> max_entries > topk) always takes the scorer branch
+ # below, so this only shortcuts tiny eager contexts.
+ if max_entries <= index_topk:
+ from ..triton_kernel.build_compress_index_dsv4 import build_compress_index
+
+ slots, lengths = workspace.c4(infer_state.microbatch_index, positions.shape[0], c4_cap)
+ slots, lengths = build_compress_index(
+ infer_state.dsv4_sparse_req_idx,
+ positions,
+ infer_state.req_manager.req_to_token_indexs,
+ mem_manager.full_to_c4_indexs,
+ 4,
+ slots,
+ lengths,
+ )
+ return slots.unsqueeze(1), lengths
+
+ c4_len = torch.div(infer_state.b_seq_len, 4, rounding_mode="floor").to(torch.int32) # entries/req
+
+ device = positions.device
+ page_size = mem_manager.c4_indexer_pool.page_size
+
+ # The page table / row_page_table / valid_len / ctx_lens / paged-logits metadata / topk_lengths
+ # are LAYER-INDEPENDENT (depend on request layout + c4_cap, not on weights/layer). Build them on
+ # the first c4 layer of the forward and reuse on the other ~20 c4 layers (was rebuilt per layer:
+ # build_c4_indexer_page_table + a [T,npages] gather + clamp/reshape + get_paged_mqa_logits_metadata
+ # each, i.e. ~20x redundant index/copy/clamp launches). Lazy (not init_some_extra_state) so it is
+ # computed inside the decode cuda graph with the capture-forced shapes -> no graph-cap mismatch.
+ cached = getattr(infer_state, "_c4_paged_meta", None)
+ if cached is None:
+ from ..triton_kernel.gather_c4_indexer_k_dsv4 import build_c4_indexer_page_table
+
+ b_req_idx = infer_state.b_req_idx
+ batch = b_req_idx.shape[0]
+ page_table = build_c4_indexer_page_table(
+ mem_manager,
+ b_req_idx,
+ c4_len,
+ c4_cap,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.req_manager.HOLD_REQUEST_ID,
+ )
+
+ if infer_state.is_prefill:
+ token_batch_pos = torch.repeat_interleave(
+ torch.arange(batch, device=device, dtype=torch.int32),
+ infer_state.b_q_seq_len,
+ output_size=positions.numel(),
+ )
+ row_page_table = page_table[token_batch_pos.long()].contiguous()
+ else:
+ row_page_table = page_table
+
+ valid_len = ((positions + 1) // 4).to(torch.int32)
+ ctx_lens = torch.clamp(valid_len, min=1).reshape(-1, 1).contiguous()
+ metadata = deep_gemm.get_paged_mqa_logits_metadata(
+ ctx_lens,
+ page_size,
+ deep_gemm.get_num_sms(),
+ )
+ topk_lengths = torch.clamp(
+ torch.minimum(valid_len, torch.full_like(valid_len, index_topk)), min=1
+ ).contiguous()
+ cached = (row_page_table, valid_len, ctx_lens, metadata, topk_lengths)
+ infer_state._c4_paged_meta = cached
+
+ row_page_table, valid_len, ctx_lens, metadata, topk_lengths = cached
+ kv_cache = mem_manager.c4_indexer_pool.get_layer_buffer(mem_manager.layer_to_c4_idx[self.layer_idx_]).view(
+ mem_manager.c4_indexer_pool.num_pages,
+ page_size,
+ 1,
+ self.index_head_dim + 4,
+ )
+ top_slots, _ = workspace.c4(infer_state.microbatch_index, idx_q_fp8.shape[0], index_topk)
+ if infer_state.is_prefill:
+ rows_per_chunk = max(1, _C4_PREFILL_LOGITS_BUDGET_BYTES // (c4_cap * 4))
+ if idx_q_fp8.shape[0] > rows_per_chunk:
+ for start in range(0, idx_q_fp8.shape[0], rows_per_chunk):
+ end = min(start + rows_per_chunk, idx_q_fp8.shape[0])
+ chunk_ctx_lens = ctx_lens[start:end]
+ self._c4_score_topk(
+ idx_q_fp8[start:end],
+ kv_cache,
+ weights[start:end],
+ chunk_ctx_lens,
+ row_page_table[start:end],
+ deep_gemm.get_paged_mqa_logits_metadata(
+ chunk_ctx_lens,
+ page_size,
+ deep_gemm.get_num_sms(),
+ ),
+ c4_cap,
+ valid_len[start:end],
+ top_slots[start:end],
+ page_size,
+ )
+ return top_slots.unsqueeze(1), topk_lengths
+
+ self._c4_score_topk(
+ idx_q_fp8,
+ kv_cache,
+ weights,
+ ctx_lens,
+ row_page_table,
+ metadata,
+ c4_cap,
+ valid_len,
+ top_slots,
+ page_size,
+ )
+ return top_slots.unsqueeze(1), topk_lengths
+
+ @staticmethod
+ def _c4_score_topk(
+ idx_q_fp8,
+ kv_cache,
+ weights,
+ ctx_lens,
+ row_page_table,
+ metadata,
+ c4_cap,
+ valid_len,
+ top_slots,
+ page_size,
+ ):
+ logits = deep_gemm.fp8_paged_mqa_logits(
+ idx_q_fp8.unsqueeze(1),
+ kv_cache,
+ weights,
+ ctx_lens,
+ row_page_table,
+ metadata,
+ c4_cap,
+ False,
+ )
+ topk_transform_512(
+ logits,
+ valid_len,
+ row_page_table,
+ top_slots,
+ page_size,
+ )
diff --git a/lightllm/models/deepseek_v4/layer_weights/__init__.py b/lightllm/models/deepseek_v4/layer_weights/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..54f29ce57
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,37 @@
+import torch
+from lightllm.common.basemodel import PreAndPostLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ EmbeddingWeight,
+ LMHeadWeight,
+ RMSNormWeight,
+ ParameterWeight,
+)
+
+
+class DeepseekV4PreAndPostLayerWeight(PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config):
+ super().__init__(data_type, network_config)
+
+ hidden = network_config["hidden_size"]
+ vocab = network_config["vocab_size"]
+ hc_mult = network_config["hc_mult"]
+
+ # embeddings / lm_head / final norm (bf16, vocab tensor-parallel). V4 has no `model.` prefix
+ # and does not tie embeddings (tie_word_embeddings=false).
+ self.wte_weight_ = EmbeddingWeight(
+ dim=hidden, vocab_size=vocab, weight_name="embed.weight", data_type=self.data_type_
+ )
+ self.lm_head_weight_ = LMHeadWeight(
+ dim=hidden, vocab_size=vocab, weight_name="head.weight", data_type=self.data_type_
+ )
+ self.final_norm_weight_ = RMSNormWeight(dim=hidden, weight_name="norm.weight", data_type=self.data_type_)
+
+ # final hyper-connection head (collapses the hc_mult residual streams before the lm_head)
+ self.hc_head_fn_ = ParameterWeight(
+ weight_name="hc_head_fn", data_type=torch.float32, weight_shape=(hc_mult, hc_mult * hidden)
+ )
+ self.hc_head_base_ = ParameterWeight(
+ weight_name="hc_head_base", data_type=torch.float32, weight_shape=(hc_mult,)
+ )
+ self.hc_head_scale_ = ParameterWeight(weight_name="hc_head_scale", data_type=torch.float32, weight_shape=(1,))
+ return
diff --git a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py
new file mode 100644
index 000000000..5ecce0a76
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,334 @@
+import torch
+from lightllm.common.basemodel import TransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ ROWMMWeight,
+ COLMMWeight,
+ ROWBMMWeight,
+ RMSNormWeight,
+ ParameterWeight,
+ TpAttSinkWeight,
+ FusedMoeWeight,
+)
+from ..triton_kernel.quant_convert import dequant_fp8_block_to_bf16
+
+
+class DeepseekV4TransformerLayerWeight(TransformerLayerWeight):
+ """Per-layer weights for DeepSeek-V4-Flash.
+
+ DS4 does not share DS2/DS3.2's ``model.layers.*.self_attn/mlp`` layout. Its attention is
+ HC + CSA, and routed experts are checkpointed as MXFP4 (fp4 release) or
+ FP8 block-128 (fp8 release, same layout as the dense fp8 weights).
+ """
+
+ def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
+ super().__init__(layer_num, data_type, network_config, quant_cfg)
+ return
+
+ def _parse_config(self):
+ cfg = self.network_config_
+ self.hidden = cfg["hidden_size"]
+ self.n_heads = cfg["num_attention_heads"]
+ self.head_dim = cfg["head_dim"]
+ self.rope_dim = cfg["qk_rope_head_dim"]
+ self.q_lora_rank = cfg["q_lora_rank"]
+ self.o_lora_rank = cfg["o_lora_rank"]
+ self.o_groups = cfg["o_groups"]
+ self.index_n_heads = cfg["index_n_heads"]
+ self.index_head_dim = cfg["index_head_dim"]
+ self.n_routed_experts = cfg["n_routed_experts"]
+ self.moe_inter = cfg["moe_intermediate_size"]
+ self.num_hash_layers = cfg["num_hash_layers"]
+ self.vocab_size = cfg["vocab_size"]
+ self.hc_mult = cfg["hc_mult"]
+ self.mix_hc = (2 + self.hc_mult) * self.hc_mult
+ self.compress_ratio = cfg["compress_ratios"][self.layer_num_]
+ self.has_compressor = self.compress_ratio != 0
+ self.has_indexer = self.compress_ratio == 4
+ self.is_hash = self.layer_num_ < self.num_hash_layers
+ assert self.n_heads % self.tp_world_size_ == 0
+ assert self.o_groups % self.tp_world_size_ == 0
+ self.prefix = f"layers.{self.layer_num_}"
+
+ def _init_weight(self):
+ self._init_qkvo()
+ if self.has_compressor:
+ self._init_compressor()
+ if self.has_indexer:
+ self._init_indexer()
+ self._init_moe()
+ self._init_norm()
+ self._init_hyper_connection()
+
+ # ------------------------------------------------------------------ attention
+ def _init_qkvo(self):
+ p = f"{self.prefix}.attn"
+ # q low-rank A and kv (single replicated head) both consume the same attention input ->
+ # fuse into one fp8 GEMM; _get_qkv splits the [q_lora_rank | head_dim] output. (q_b is
+ # column-parallel over heads.)
+ self.wq_a_wkv_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[self.q_lora_rank, self.head_dim],
+ weight_names=[f"{p}.wq_a.weight", f"{p}.wkv.weight"],
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("wq_a"),
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ self.wq_b_ = ROWMMWeight(
+ in_dim=self.q_lora_rank,
+ out_dims=[self.n_heads * self.head_dim],
+ weight_names=f"{p}.wq_b.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("wq_b"),
+ )
+ self.q_norm_ = RMSNormWeight(dim=self.q_lora_rank, weight_name=f"{p}.q_norm.weight", data_type=self.data_type_)
+ self.kv_norm_ = RMSNormWeight(dim=self.head_dim, weight_name=f"{p}.kv_norm.weight", data_type=self.data_type_)
+ self.attn_sink_ = TpAttSinkWeight(
+ all_q_head_num=self.n_heads, weight_name=f"{p}.attn_sink", data_type=torch.float32
+ )
+ # grouped low-rank output projection (wo_a per-group [in, o_lora], wo_b row-parallel
+ # [groups*o_lora -> hidden]).
+ per_group_in = self.n_heads * self.head_dim // self.o_groups
+ # When o_groups == tp_world_size (e.g. the daily tp8 config) each rank owns exactly ONE
+ # group, so the grouped O-proj collapses to a single GEMM -> run it in fp8 (deepgemm)
+ # instead of dequantizing wo_a to bf16. sglang does the same (fp8 wo_a is default-on there).
+ # For >1 group per rank (tp < o_groups) the per-group inputs differ (block-diagonal), so
+ # keep the bf16 grouped bmm.
+ self.o_proj_fp8 = (self.o_groups // self.tp_world_size_) == 1
+ if self.o_proj_fp8:
+ self.wo_a_ = ROWMMWeight(
+ in_dim=per_group_in,
+ out_dims=[self.o_groups * self.o_lora_rank],
+ weight_names=f"{p}.wo_a.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("wo_a"),
+ )
+ else:
+ self.wo_a_ = ROWBMMWeight(
+ dim0=self.o_groups,
+ dim1=per_group_in,
+ dim2=self.o_lora_rank,
+ weight_names=f"{p}.wo_a.weight",
+ data_type=self.data_type_,
+ quant_method=None,
+ )
+ self.wo_b_ = COLMMWeight(
+ in_dim=self.o_groups * self.o_lora_rank,
+ out_dims=[self.hidden],
+ weight_names=f"{p}.wo_b.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("wo_b"),
+ )
+
+ # ------------------------------------------------------------------ compressor / indexer
+ def _init_compressor(self):
+ prefix = f"{self.prefix}.attn.compressor"
+ head_dim = self.head_dim
+ ratio = self.compress_ratio
+
+ coff = 2 if ratio == 4 else 1
+ # wkv/wgate are bf16 (no scale) and replicated (single KV head).
+ self.compressor_wkv_gate_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[coff * head_dim, coff * head_dim],
+ weight_names=[f"{prefix}.wkv.weight", f"{prefix}.wgate.weight"],
+ data_type=self.data_type_,
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ self.compressor_norm_ = RMSNormWeight(
+ dim=head_dim, weight_name=f"{prefix}.norm.weight", data_type=self.data_type_
+ )
+ self.compressor_ape_ = ParameterWeight(
+ weight_name=f"{prefix}.ape", data_type=torch.float32, weight_shape=(ratio, coff * head_dim)
+ )
+
+ def _init_indexer(self):
+ p = f"{self.prefix}.attn.indexer"
+ # The Lightning-Indexer is REPLICATED across TP ranks (like sglang/vllm), not head-sharded:
+ # q_lora and the attn input are already full on every rank, so each rank scores all
+ # index_n_heads locally and the c4 top-k is identical everywhere -- no gather/all_reduce.
+ # wq_b is FP8 in the checkpoint -> de-quantized to bf16 at load.
+ self.idx_wq_b_ = ROWMMWeight(
+ in_dim=self.q_lora_rank,
+ out_dims=[self.index_n_heads * self.index_head_dim],
+ weight_names=f"{p}.wq_b.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("idx_wq_b"),
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ self.idx_weights_proj_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[self.index_n_heads],
+ weight_names=f"{p}.weights_proj.weight",
+ data_type=self.data_type_,
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ coff = 2 # indexer compressor always uses ratio 4 (overlap)
+ # wkv/wgate share the same input -> one fused bf16 GEMM producing the [kv | gate] layout
+ # directly (same as the attention compressor_wkv_gate_).
+ self.idx_cmp_wkv_gate_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[coff * self.index_head_dim, coff * self.index_head_dim],
+ weight_names=[f"{p}.compressor.wkv.weight", f"{p}.compressor.wgate.weight"],
+ data_type=self.data_type_,
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ self.idx_cmp_norm_ = RMSNormWeight(
+ dim=self.index_head_dim, weight_name=f"{p}.compressor.norm.weight", data_type=self.data_type_
+ )
+ self.idx_cmp_ape_ = ParameterWeight(
+ weight_name=f"{p}.compressor.ape", data_type=torch.float32, weight_shape=(4, coff * self.index_head_dim)
+ )
+
+ # ------------------------------------------------------------------ moe
+ def _init_moe(self):
+ p = f"{self.prefix}.ffn"
+ # Router gate in bf16 (matches the sglang/vLLM DeepSeek references, which run the gate GEMM in
+ # the model dtype); the bf16 GEMM output is cast back to fp32 in _ffn for topk_hash_softplus_sqrt.
+ self.gate_weight_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[self.n_routed_experts],
+ weight_names=f"{p}.gate.weight",
+ data_type=torch.bfloat16,
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ if self.is_hash:
+ self.gate_tid2eid_ = ParameterWeight(
+ weight_name=f"{p}.gate.tid2eid",
+ data_type=torch.int64,
+ weight_shape=(self.vocab_size, self.network_config_["num_experts_per_tok"]),
+ )
+ else:
+ self.gate_bias_ = ParameterWeight(
+ weight_name=f"{p}.gate.bias", data_type=torch.float32, weight_shape=(self.n_routed_experts,)
+ )
+ # shared expert (dense, bf16 after de-quant): w1=gate, w3=up fused (row), w2=down (col).
+ # Named gate_up_proj/down_proj so the inherited Llama `_ffn_tp` (fused gate_up matmul +
+ # silu_and_mul triton kernel, no swiglu clamp) drives it directly. Order [w1, w3] = [gate, up]
+ # matches silu_and_mul_fwd's blocked layout (first half gate, second half up).
+ sp = f"{p}.shared_experts"
+ self.gate_up_proj = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[self.moe_inter, self.moe_inter],
+ weight_names=[f"{sp}.w1.weight", f"{sp}.w3.weight"],
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("shared_gate"),
+ )
+ self.down_proj = COLMMWeight(
+ in_dim=self.moe_inter,
+ out_dims=[self.hidden],
+ weight_names=f"{sp}.w2.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("shared_down"),
+ )
+ self.experts_ = FusedMoeWeight(
+ gate_proj_name="w1",
+ down_proj_name="w2",
+ up_proj_name="w3",
+ e_score_correction_bias_name="",
+ weight_prefix=f"{p}.experts",
+ n_routed_experts=self.n_routed_experts,
+ hidden_size=self.hidden,
+ moe_intermediate_size=self.moe_inter,
+ data_type=self.data_type_,
+ quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"),
+ layer_num=self.layer_num_,
+ network_config=self.network_config_,
+ )
+
+ def _init_norm(self):
+ self.attn_norm_ = RMSNormWeight(
+ dim=self.hidden, weight_name=f"{self.prefix}.attn_norm.weight", data_type=self.data_type_
+ )
+ self.ffn_norm_ = RMSNormWeight(
+ dim=self.hidden, weight_name=f"{self.prefix}.ffn_norm.weight", data_type=self.data_type_
+ )
+
+ def _init_hyper_connection(self):
+ p = self.prefix
+ self.hc_attn_fn_ = ParameterWeight(
+ weight_name=f"{p}.hc_attn_fn",
+ data_type=torch.float32,
+ weight_shape=(self.mix_hc, self.hc_mult * self.hidden),
+ )
+ self.hc_attn_base_ = ParameterWeight(
+ weight_name=f"{p}.hc_attn_base", data_type=torch.float32, weight_shape=(self.mix_hc,)
+ )
+ self.hc_attn_scale_ = ParameterWeight(
+ weight_name=f"{p}.hc_attn_scale", data_type=torch.float32, weight_shape=(3,)
+ )
+ self.hc_ffn_fn_ = ParameterWeight(
+ weight_name=f"{p}.hc_ffn_fn",
+ data_type=torch.float32,
+ weight_shape=(self.mix_hc, self.hc_mult * self.hidden),
+ )
+ self.hc_ffn_base_ = ParameterWeight(
+ weight_name=f"{p}.hc_ffn_base", data_type=torch.float32, weight_shape=(self.mix_hc,)
+ )
+ self.hc_ffn_scale_ = ParameterWeight(
+ weight_name=f"{p}.hc_ffn_scale", data_type=torch.float32, weight_shape=(3,)
+ )
+
+ # ------------------------------------------------------------------ loading
+ def load_hf_weights(self, weights):
+ self._dequant_in_place(weights)
+ return super().load_hf_weights(weights)
+
+ def _fp8_scale_renames(self):
+ """Map weight name -> the scale name its quant method loads (e.g. `weight_scale_inv`
+ for DeepGEMM). Read from each MM weight's own `weight_scale_names`, so the rename
+ target always matches what that weight will look up; no-quant weights have None
+ entries and are skipped."""
+ renames = {}
+ for attr in self.__dict__.values():
+ weight_names = getattr(attr, "weight_names", ())
+ scale_names = getattr(attr, "weight_scale_names", ())
+ for weight_name, scale_name in zip(weight_names, scale_names):
+ if scale_name is not None:
+ renames[weight_name] = scale_name
+ return renames
+
+ def _dequant_in_place(self, weights):
+ p = self.prefix + "."
+ scale_renames = self._fp8_scale_renames()
+ # Convert every `.scale` belonging to this layer. Weights are loaded incrementally
+ # per safetensors shard, so the paired weight may live in another shard:
+ # - routed expert `.scale` follows the fused_moe quant method's weight_scale_suffix:
+ # MXFP4 consumes `.scale` as-is, FP8 DeepGEMM expects `.weight_scale_inv` (rename only);
+ # - FP8 matmul scales only need renaming for DeepGEMM, no weight required;
+ # - FP8 pairs on no-quant paths (wo_a's ROWBMMWeight) are expanded to bf16,
+ # the only case that truly requires weight and scale in the same shard.
+ expert_scale_suffix = self.experts_.quant_method.weight_scale_suffix
+ for scale_k in [k for k in list(weights.keys()) if k.startswith(p) and k.endswith(".scale")]:
+ if scale_k.startswith(f"{p}ffn.experts."):
+ if expert_scale_suffix is not None and expert_scale_suffix != "scale":
+ weights[scale_k[: -len("scale")] + expert_scale_suffix] = weights[scale_k].to(torch.float32)
+ del weights[scale_k]
+ continue
+ k = scale_k[: -len(".scale")] + ".weight"
+ target = scale_renames.get(k)
+ if target is not None: # FP8 e4m3, block-128 scale, run by DeepGEMM directly
+ weights[target] = weights[scale_k].to(torch.float32)
+ del weights[scale_k]
+ else:
+ weights[k] = dequant_fp8_block_to_bf16(weights[k], weights[scale_k]).to(self.data_type_)
+ del weights[scale_k]
+ # grouped-O (bf16 path only): reshape [groups*o_lora, in] -> [groups, in, o_lora] for the
+ # batched matmul. The fp8 path keeps wo_a as a plain [groups*o_lora, in] fp8 GEMM weight
+ # (its `.scale` is renamed to `.weight_scale_inv` by the loop above, not dequantized).
+ if not self.o_proj_fp8:
+ woa = f"{self.prefix}.attn.wo_a.weight"
+ if woa in weights and weights[woa].dim() == 2:
+ w = weights[woa]
+ per_group_in = self.n_heads * self.head_dim // self.o_groups
+ weights[woa] = w.view(self.o_groups, self.o_lora_rank, per_group_in).transpose(1, 2).contiguous()
+ return
diff --git a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py
new file mode 100644
index 000000000..edb2ce15d
--- /dev/null
+++ b/lightllm/models/deepseek_v4/model.py
@@ -0,0 +1,296 @@
+import copy
+import importlib.util
+import json
+import os
+
+import torch
+from lightllm.models.registry import ModelRegistry
+from lightllm.models.llama.model import LlamaTpPartModel
+from lightllm.common.req_manager import DeepseekV4ReqManager
+from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager
+from lightllm.models.deepseek_v4.layer_weights.pre_and_post_layer_weight import (
+ DeepseekV4PreAndPostLayerWeight,
+)
+from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import (
+ DeepseekV4TransformerLayerWeight,
+)
+from lightllm.models.deepseek_v4.layer_infer.pre_layer_infer import (
+ DeepseekV4PreLayerInfer,
+)
+from lightllm.models.deepseek_v4.layer_infer.post_layer_infer import (
+ DeepseekV4PostLayerInfer,
+)
+from lightllm.models.deepseek_v4.layer_infer.transformer_layer_infer import (
+ DeepseekV4TransformerLayerInfer,
+)
+from lightllm.common.basemodel.attention import get_nsa_prefill_att_backend_class, get_nsa_decode_att_backend_class
+from lightllm.models.deepseek_v4.infer_struct import DeepseekV4InferStateInfo
+from lightllm.models.deepseek_v4.workspace import DeepseekV4Workspace
+from lightllm.models.llama.yarn_rotary_utils import (
+ find_correction_range,
+ linear_ramp_mask,
+)
+from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num, get_env_start_args
+from lightllm.utils.log_utils import init_logger
+from lightllm.distributed.communication_op import dist_group_manager
+
+logger = init_logger(__name__)
+
+
+@ModelRegistry("deepseek_v4")
+class DeepseekV4TpPartModel(LlamaTpPartModel):
+ req_manager: DeepseekV4ReqManager
+ mem_manager: DeepseekV4MemoryManager
+
+ pre_and_post_weight_class = DeepseekV4PreAndPostLayerWeight
+ transformer_weight_class = DeepseekV4TransformerLayerWeight
+
+ pre_layer_infer_class = DeepseekV4PreLayerInfer
+ post_layer_infer_class = DeepseekV4PostLayerInfer
+ transformer_layer_infer_class = DeepseekV4TransformerLayerInfer
+
+ infer_state_class = DeepseekV4InferStateInfo
+
+ def _verify_params(self):
+ assert self.load_way == "HF", "only support HF format weights"
+ assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
+ assert self.config["o_groups"] % self.tp_world_size_ == 0
+ assert self.config["index_n_heads"] % self.tp_world_size_ == 0
+ return
+
+ def _init_req_manager(self):
+ create_max_seq_len = 0
+ if self.batch_max_tokens is not None:
+ create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens)
+ if self.max_seq_length is not None:
+ create_max_seq_len = max(create_max_seq_len, self.max_seq_length)
+
+ self._dsv4_req_manager_seq_len = create_max_seq_len
+ layer_num = self.config["n_layer"] + get_added_mtp_kv_layer_num()
+ self._dsv4_compress_rates = self._get_compress_rates(layer_num)
+ self.req_manager = DeepseekV4ReqManager(
+ self.max_req_num,
+ create_max_seq_len,
+ compress_rates=self._dsv4_compress_rates,
+ head_dim=self.config["head_dim"],
+ indexer_head_dim=self.config["index_head_dim"],
+ sliding_window=self.config["sliding_window"],
+ )
+ return
+
+ def _get_compress_rates(self, layer_num):
+ rates = list(self.config["compress_ratios"])
+ return rates[:layer_num]
+
+ def _init_mem_manager(self):
+ layer_num = self.config["n_layer"] + get_added_mtp_kv_layer_num()
+ compress_rates = getattr(self, "_dsv4_compress_rates", self._get_compress_rates(layer_num))
+ sliding_window = int(self.config["sliding_window"])
+ self.mem_manager = DeepseekV4MemoryManager(
+ self.max_total_token_num,
+ dtype=self.data_type,
+ head_num=1,
+ head_dim=self.config["head_dim"],
+ layer_num=layer_num,
+ compress_rates=compress_rates,
+ indexer_head_dim=self.config["index_head_dim"],
+ max_request_num=self.max_req_num,
+ sliding_window=sliding_window,
+ mem_fraction=self.mem_fraction,
+ )
+ self.req_manager.mem_manager = self.mem_manager
+ return
+
+ def _init_att_backend(self):
+ args = get_env_start_args()
+ if args.llm_kv_type == "None":
+ args.llm_kv_type = "fp8kv_dsa"
+ # TODO: 支持其他 kv type
+ if args.llm_kv_type != "fp8kv_dsa":
+ raise RuntimeError("DeepSeek-V4 requires llm_kv_type=fp8kv_dsa for packed FlashMLA sparse attention")
+ self.prefill_att_backend = get_nsa_prefill_att_backend_class(index=0)(model=self)
+ self.decode_att_backend = get_nsa_decode_att_backend_class(index=0)(model=self)
+ return
+
+ def _init_custom(self):
+ self._init_to_get_rotary()
+ self.dsv4_workspace = DeepseekV4Workspace(self)
+ if os.getenv("LIGHTLLM_DSV4_PREFILL_OVERLAP", "1") == "1":
+ prefill_aux_stream = torch.cuda.Stream()
+ for layer in self.layers_infer:
+ layer.dsv4_prefill_aux_stream = prefill_aux_stream
+ dist_group_manager.new_deepep_group(
+ self.config["n_routed_experts"],
+ self.config["hidden_size"],
+ self.config.get("num_experts_per_tok", 1),
+ self.config.get("moe_intermediate_size", self.config.get("intermediate_size")),
+ )
+ return
+
+ def _init_to_get_rotary(self):
+ # Interleaved (GPT-J) rope. Build complex64 freqs_cis tables (_freqs_cis_*) following the
+ # gemma4 two-variant convention; the fused sglang q kernel consumes them directly, while
+ # _cos_cached_*/_sin_cached_* are .real/.imag views of the same storage for the kv rope,
+ # inverse rope and compressor paths (deepseek2's interleaved triton rotary_emb_fwd).
+ # Sliding-window layers use base rope_theta (no YaRN);
+ # compressed (CSA/HCA) layers use compress_rope_theta with configured rope_scaling.
+ # Kept fp32 for accuracy (the apply upcasts anyway).
+ cfg = self.config
+ rs = cfg.get("rope_scaling", {}) or {}
+ dim = cfg["qk_rope_head_dim"]
+ # The rope tables MUST span every absolute position any request can produce (the served
+ # max_req_total_len / max_position_embeddings, up to 1M). Capping them shorter makes
+ # init_some_extra_state's index_select(cos/sin, position_ids) read OOB past the table at
+ # contexts beyond the cap (device-side assert / crash). ~268MB total at 1M, fp32x32 x4 views.
+ max_seq = max(int(self.max_seq_length), int(cfg.get("max_position_embeddings", 8192)))
+ freq_exponents = torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim
+ positions = torch.arange(max_seq, dtype=torch.float32, device="cuda")
+
+ sliding_freqs = 1.0 / (cfg["rope_theta"] ** freq_exponents)
+ f = torch.outer(positions, sliding_freqs) # [max_seq, dim//2]
+ self._freqs_cis_sliding = torch.complex(f.cos(), f.sin())
+
+ compress_freqs = 1.0 / (cfg["compress_rope_theta"] ** freq_exponents)
+ rope_type = rs.get("rope_type", rs.get("type", "default"))
+ orig_max = rs.get("original_max_position_embeddings", 0)
+ if rope_type == "yarn" and orig_max > 0:
+ beta_fast = rs.get("beta_fast", 32)
+ beta_slow = rs.get("beta_slow", 1)
+ factor = rs.get("factor", 1)
+ if factor is None:
+ factor = cfg.get("max_position_embeddings", max_seq) / orig_max
+ low, high = find_correction_range(beta_fast, beta_slow, dim, cfg["compress_rope_theta"], orig_max)
+ smooth = 1 - linear_ramp_mask(low, high, dim // 2).cuda()
+ compress_freqs = compress_freqs / factor * (1 - smooth) + compress_freqs * smooth
+ f = torch.outer(positions, compress_freqs) # [max_seq, dim//2]
+ self._freqs_cis_compress = torch.complex(f.cos(), f.sin())
+ self._cos_cached_sliding = self._freqs_cis_sliding.real
+ self._sin_cached_sliding = self._freqs_cis_sliding.imag
+ self._cos_cached_compress = self._freqs_cis_compress.real
+ self._sin_cached_compress = self._freqs_cis_compress.imag
+ # Each layer uses exactly one rope variant; wire its table once here (layers are already
+ # built: _init_infer_layer runs before _init_custom) instead of relaying via infer_state.
+ # The compressor needs the full compress tables (entry rope positions != token positions).
+ for layer in self.layers_infer:
+ layer.freqs_cis = self._freqs_cis_compress if layer.compress_ratio else self._freqs_cis_sliding
+ layer.cos_compress_table = self._cos_cached_compress
+ layer.sin_compress_table = self._sin_cached_compress
+ # the indexer-Q fused kernel (compress rope) needs the complex compress freqs table.
+ if getattr(layer, "index_infer", None) is not None:
+ layer.index_infer.freqs_cis = self._freqs_cis_compress
+ return
+
+
+class DeepSeekV4Tokenizer:
+ """Tokenizer wrapper for DeepSeek-V4's Python prompt encoding."""
+
+ # DeepSeek-V4 has a per-request thinking mode (...) toggled via
+ # chat_template_kwargs={"thinking": true}. It has no Jinja chat_template string,
+ # so advertise thinking support explicitly for tokenizer_supports_force_thinking().
+ supports_thinking = True
+
+ def __init__(self, tokenizer, model_dir):
+ self.tokenizer = tokenizer
+ self.model_dir = model_dir
+ self._encoding_module = None
+ self._added_vocab = None
+
+ def __getattr__(self, name):
+ return getattr(self.tokenizer, name)
+
+ def get_added_vocab(self):
+ if self._added_vocab is None:
+ self._added_vocab = self.tokenizer.get_added_vocab()
+ return self._added_vocab
+
+ def _get_encoding_module(self):
+ if self._encoding_module is not None:
+ return self._encoding_module
+
+ # Prefer the encoder shipped inside the model dir (respects any model-specific
+ # customization); fall back to the copy vendored in this repo, because some
+ # DeepSeek-V4 releases (e.g. the FP8 weights) do NOT ship an encoding/ dir.
+ # vLLM/sglang likewise vendor this encoder in-tree instead of depending on the
+ # model directory.
+ encoding_path = os.path.join(self.model_dir, "encoding", "encoding_dsv4.py")
+ if not os.path.exists(encoding_path):
+ encoding_path = os.path.join(os.path.dirname(__file__), "encoding", "encoding_dsv4.py")
+ if not os.path.exists(encoding_path):
+ raise FileNotFoundError(f"DeepSeek-V4 encoding file not found: {encoding_path}")
+
+ spec = importlib.util.spec_from_file_location("lightllm_deepseek_v4_encoding_dsv4", encoding_path)
+ if spec is None or spec.loader is None:
+ raise ImportError(f"failed to load DeepSeek-V4 encoding module from {encoding_path}")
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ self._encoding_module = module
+ return module
+
+ def apply_chat_template(
+ self,
+ conversation=None,
+ messages=None,
+ tools=None,
+ tokenize=False,
+ add_generation_prompt=True,
+ thinking=None,
+ enable_thinking=None,
+ **kwargs,
+ ):
+ msgs = conversation if conversation is not None else messages
+ if msgs is None:
+ raise ValueError("Either 'conversation' or 'messages' must be provided")
+
+ msgs = copy.deepcopy(msgs)
+
+ # The model's DSML encoder (encode_arguments_to_dsml in encoding_dsv4.py) expects
+ # function.arguments as a JSON string and parses it internally. Upstream,
+ # build_prompt._normalize_tool_call_arguments converts arguments from the OpenAI
+ # JSON string to a dict (needed by Qwen3.x-style Jinja templates). A dict hits the
+ # encoder's except-branch and gets wrapped under a single name="arguments" param,
+ # which the model then imitates and amplifies across turns until required fields go
+ # missing. Re-serialize dicts back to JSON strings so the encoder emits one
+ # per real arg.
+ for msg in msgs:
+ for tc in msg.get("tool_calls") or []:
+ fn = tc.get("function")
+ if isinstance(fn, dict) and isinstance(fn.get("arguments"), dict):
+ fn["arguments"] = json.dumps(fn["arguments"], ensure_ascii=False)
+
+ if tools:
+ wrapped_tools = []
+ for tool in tools:
+ if "function" in tool:
+ wrapped_tools.append(tool)
+ else:
+ wrapped_tools.append({"type": "function", "function": tool})
+
+ injected = False
+ for msg in msgs:
+ if msg.get("role") == "system":
+ existing = msg.get("tools") or []
+ msg["tools"] = existing + wrapped_tools
+ injected = True
+ break
+
+ if not injected:
+ msgs.insert(0, {"role": "system", "content": "", "tools": wrapped_tools})
+
+ if thinking is None:
+ thinking = bool(enable_thinking) if enable_thinking is not None else False
+ thinking_mode = "thinking" if thinking else "chat"
+ effort = kwargs.get("reasoning_effort")
+ if effort not in ("max", "high", None):
+ effort = None
+ encoding = self._get_encoding_module()
+ prompt = encoding.encode_messages(
+ msgs,
+ thinking_mode=thinking_mode,
+ drop_thinking=kwargs.get("drop_thinking", True),
+ add_default_bos_token=kwargs.get("add_default_bos_token", True),
+ reasoning_effort=effort,
+ )
+
+ if tokenize:
+ return self.tokenizer.encode(prompt, add_special_tokens=False)
+ return prompt
diff --git a/lightllm/models/deepseek_v4/triton_kernel/__init__.py b/lightllm/models/deepseek_v4/triton_kernel/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py
new file mode 100644
index 000000000..dae5121ab
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py
@@ -0,0 +1,79 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _build_compress_index_kernel(
+ req_idx_ptr,
+ pos_ptr,
+ req_to_token_ptr,
+ req_to_token_stride0,
+ full_to_c_ptr,
+ index_ptr,
+ index_stride0,
+ length_ptr,
+ cap,
+ RATIO: tl.constexpr,
+ BLOCK_E: tl.constexpr,
+):
+ t = tl.program_id(0)
+ eb = tl.program_id(1)
+ req = tl.load(req_idx_ptr + t).to(tl.int64)
+ pos = tl.load(pos_ptr + t).to(tl.int64)
+ raw_len = (pos + 1) // RATIO
+
+ e = eb * BLOCK_E + tl.arange(0, BLOCK_E)
+ e_mask = e < cap
+ valid = (e < raw_len) & e_mask
+ # group-end token of compressed entry e: position e*RATIO + (RATIO-1).
+ end_pos = e * RATIO + (RATIO - 1)
+ safe_pos = tl.where(valid, end_pos, 0)
+ full_slot = tl.load(req_to_token_ptr + req * req_to_token_stride0 + safe_pos, mask=valid, other=0).to(tl.int64)
+ c_slot = tl.load(full_to_c_ptr + full_slot, mask=valid, other=-1).to(tl.int32)
+ tl.store(index_ptr + t * index_stride0 + e, c_slot, mask=e_mask)
+
+ if eb == 0:
+ tl.store(length_ptr + t, tl.maximum(raw_len, 1).to(tl.int32))
+
+
+def build_compress_index(
+ req_idx: torch.Tensor,
+ positions: torch.Tensor,
+ req_to_token_indexs: torch.Tensor,
+ full_to_c_indexs: torch.Tensor,
+ ratio: int,
+ index: torch.Tensor,
+ length: torch.Tensor,
+):
+ """Fused two-level group-end gather for the c4/c128 compressed-entry index tables.
+
+ For token t (at request `req_idx[t]`, absolute `positions[t]`) and compressed entry e:
+ slot[t, e] = full_to_c[ req_to_token[req, e*ratio + (ratio-1)] ] (the group-end token's full slot)
+ with slot = -1 where e >= (pos+1)//ratio (beyond the causal compressed length) or where the
+ full->c map is unset. Writes index [T, cap] and length [T] = clamp((pos+1)//ratio, 1).
+
+ Replaces the eager _gather_compress_slots/_c128/c4-causal torch chain. The caller owns the
+ output storage, so this wrapper does not allocate on the hot path.
+ """
+ T = positions.shape[0]
+ cap = index.shape[1]
+ if T == 0:
+ return index, length
+ BLOCK_E = 256
+ grid = (T, triton.cdiv(cap, BLOCK_E))
+ _build_compress_index_kernel[grid](
+ req_idx,
+ positions,
+ req_to_token_indexs,
+ req_to_token_indexs.stride(0),
+ full_to_c_indexs,
+ index,
+ index.stride(0),
+ length,
+ cap,
+ RATIO=ratio,
+ BLOCK_E=BLOCK_E,
+ num_warps=4,
+ )
+ return index, length
diff --git a/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py
new file mode 100644
index 000000000..a1ef2d5be
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py
@@ -0,0 +1,72 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _build_swa_index_kernel(
+ req_idx_ptr,
+ pos_ptr,
+ req_to_token_ptr,
+ req_to_token_stride0,
+ full_to_swa_ptr,
+ swa_index_ptr,
+ swa_index_stride0,
+ swa_length_ptr,
+ WINDOW: tl.constexpr,
+ BLOCK_W: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ req = tl.load(req_idx_ptr + token_idx).to(tl.int64)
+ pos = tl.load(pos_ptr + token_idx).to(tl.int64)
+
+ w = tl.arange(0, BLOCK_W)
+ w_mask = w < WINDOW
+ # most-recent-first window, identical to the eager _swa_indices (offset = position - arange).
+ offset = pos - w
+ valid = (offset >= 0) & w_mask
+ safe_offset = tl.where(valid, offset, 0)
+ full_slot = tl.load(req_to_token_ptr + req * req_to_token_stride0 + safe_offset, mask=valid, other=0).to(tl.int64)
+ swa_slot = tl.load(full_to_swa_ptr + full_slot, mask=valid, other=-1)
+ out = tl.where(valid, swa_slot, -1).to(tl.int32)
+ tl.store(swa_index_ptr + token_idx * swa_index_stride0 + w, out, mask=w_mask)
+
+ length = tl.minimum(tl.maximum(pos + 1, 1), WINDOW).to(tl.int32)
+ tl.store(swa_length_ptr + token_idx, length)
+
+
+def build_swa_index(
+ req_idx: torch.Tensor,
+ positions: torch.Tensor,
+ req_to_token_indexs: torch.Tensor,
+ full_to_swa_indexs: torch.Tensor,
+ swa_index: torch.Tensor,
+ swa_length: torch.Tensor,
+):
+ """Per-token sliding-window FlashMLA index table, built ONCE per forward (layer-independent:
+ full_to_swa is a single global map and the window is a model constant, so every layer's swa
+ indices are identical). Replaces DeepseekV4IndexInfer._swa_indices: for token t at
+ (req_idx, position) gather the last `window` tokens' full slots via req_to_token, then map
+ full -> swa; out-of-range positions store -1.
+
+ Writes (swa_index [T, window] int32, swa_length [T] int32). The caller owns the output storage;
+ the reader adds the s_q axis via unsqueeze(1).
+ """
+ T = positions.shape[0]
+ window = swa_index.shape[1]
+ if T == 0:
+ return swa_index, swa_length
+ _build_swa_index_kernel[(T,)](
+ req_idx,
+ positions,
+ req_to_token_indexs,
+ req_to_token_indexs.stride(0),
+ full_to_swa_indexs,
+ swa_index,
+ swa_index.stride(0),
+ swa_length,
+ WINDOW=window,
+ BLOCK_W=triton.next_power_of_2(window),
+ num_warps=4,
+ )
+ return swa_index, swa_length
diff --git a/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py
new file mode 100644
index 000000000..3510b92c3
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py
@@ -0,0 +1,92 @@
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel_destindex_copy_indexer_k_dsv4(
+ K,
+ Dest_loc,
+ O_fp8,
+ O_f32,
+ stride_k_bs,
+ stride_k_d,
+ FP8_MIN: tl.constexpr,
+ FP8_MAX: tl.constexpr,
+ SCALE_MIN: tl.constexpr,
+ HEAD_DIM: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ BYTES_PER_PAGE: tl.constexpr,
+):
+ cur_index = tl.program_id(0)
+ dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
+ # negative dest (unmapped slot) is a no-op, not an OOB write into a neighboring page.
+ if dest_index < 0:
+ return
+
+ page = dest_index // PAGE_SIZE
+ token_in_page = dest_index % PAGE_SIZE
+
+ offs_d = tl.arange(0, HEAD_DIM)
+ vals = tl.load(K + cur_index * stride_k_bs + offs_d * stride_k_d).to(tl.float32)
+ amax = tl.max(tl.abs(vals), axis=0)
+ # per-token plain fp32 scale (not ue8m0), matching DeepseekV4MemoryManager._pack_indexer_k
+ scale = tl.maximum(amax / FP8_MAX, SCALE_MIN)
+ k_fp8 = tl.clamp(vals / scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)
+
+ data_base = page * BYTES_PER_PAGE + token_in_page * HEAD_DIM
+ tl.store(O_fp8 + data_base + offs_d, k_fp8)
+ scale_idx = (page * BYTES_PER_PAGE + PAGE_SIZE * HEAD_DIM) // 4 + token_in_page
+ tl.store(O_f32 + scale_idx, scale)
+ return
+
+
+@torch.no_grad()
+def destindex_copy_indexer_k_dsv4(
+ K: torch.Tensor,
+ DestLoc: torch.Tensor,
+ O_buffer: torch.Tensor,
+ page_size: int,
+):
+ """Packed indexer-K page-slab writer (DeepSeek-V4 c4/CSA layers).
+
+ K: [T, 128] bf16 unquantized indexer keys.
+ DestLoc: [T] int — c4-pool-local token slots; must already be allocated by the caller.
+ Negative slots (unmapped) are skipped.
+ O_buffer: [num_pages, bytes_per_page] uint8 — one layer's slab from the c4 indexer
+ PackedPagePool (128B fp8 data region + 4B fp32 scale tail per token).
+
+ Bit-compatible with DeepseekV4MemoryManager._pack_indexer_k + PackedPagePool.write.
+ """
+ seq_len = DestLoc.shape[0]
+ if seq_len == 0:
+ return
+ head_dim, scale_bytes = 128, 4
+
+ K = K.reshape(-1, head_dim)
+ assert K.shape[0] == seq_len, f"Expected K shape[0]={seq_len}, got {K.shape[0]}"
+ assert K.dtype == torch.bfloat16, f"Expected bf16 indexer K, got {K.dtype}"
+ bytes_per_page = O_buffer.shape[-1]
+ assert O_buffer.dtype == torch.uint8 and O_buffer.is_contiguous()
+ assert bytes_per_page % 4 == 0
+ assert bytes_per_page >= page_size * (head_dim + scale_bytes)
+
+ flat = O_buffer.view(-1)
+ _fwd_kernel_destindex_copy_indexer_k_dsv4[(seq_len,)](
+ K,
+ DestLoc,
+ flat.view(torch.float8_e4m3fn),
+ flat.view(torch.float32),
+ K.stride(0),
+ K.stride(1),
+ FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
+ FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
+ SCALE_MIN=1e-4,
+ HEAD_DIM=head_dim,
+ PAGE_SIZE=page_size,
+ BYTES_PER_PAGE=bytes_per_page,
+ num_warps=1,
+ num_stages=1,
+ )
+ return
diff --git a/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py
new file mode 100644
index 000000000..a3ec6ed8c
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py
@@ -0,0 +1,121 @@
+import torch
+
+import triton
+import triton.language as tl
+from triton.language.extra import libdevice
+
+
+@triton.jit
+def _fwd_kernel_destindex_copy_kv_flashmla_dsv4(
+ KV,
+ Dest_loc,
+ O_fp8,
+ O_bf16,
+ O_u8,
+ stride_kv_bs,
+ stride_kv_d,
+ FP8_MIN: tl.constexpr,
+ FP8_MAX: tl.constexpr,
+ SCALE_MIN: tl.constexpr,
+ NOPE_DIM: tl.constexpr,
+ ROPE_DIM: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+ NUM_GROUPS: tl.constexpr,
+ SCALE_BYTES: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ BYTES_PER_PAGE: tl.constexpr,
+):
+ cur_index = tl.program_id(0)
+ dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
+ # negative dest (unmapped slot, e.g. full_to_c* rows that never closed a group) is a no-op,
+ # not an OOB write into a neighboring page.
+ if dest_index < 0:
+ return
+
+ page = dest_index // PAGE_SIZE
+ token_in_page = dest_index % PAGE_SIZE
+ data_base = page * BYTES_PER_PAGE + token_in_page * (NOPE_DIM + ROPE_DIM * 2)
+ scale_base = page * BYTES_PER_PAGE + PAGE_SIZE * (NOPE_DIM + ROPE_DIM * 2) + token_in_page * SCALE_BYTES
+
+ # nope: per-group ue8m0 quant. SCALE_BYTES(=NUM_GROUPS+1) lanes cover the exponent bytes
+ # plus the trailing zero pad byte in one store. libdevice.log2 (not tl.log2, which is the
+ # approx instruction) and the bit-packed 2**e keep this bit-exact with the torch oracle
+ # DeepseekV4MemoryManager._pack_mla_kv.
+ offs_g = tl.arange(0, SCALE_BYTES)
+ offs_e = tl.arange(0, GROUP_SIZE)
+ group_mask = offs_g < NUM_GROUPS
+ kv_ptrs = KV + cur_index * stride_kv_bs + (offs_g[:, None] * GROUP_SIZE + offs_e[None, :]) * stride_kv_d
+ vals = tl.load(kv_ptrs, mask=group_mask[:, None], other=0.0).to(tl.float32)
+ amax = tl.max(tl.abs(vals), axis=1)
+ scale_exp = tl.ceil(libdevice.log2(tl.maximum(amax / FP8_MAX, SCALE_MIN))).to(tl.int32)
+ scale = ((scale_exp + 127) << 23).to(tl.float32, bitcast=True)
+ kv_fp8 = tl.clamp(vals / scale[:, None], min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)
+ tl.store(O_fp8 + data_base + offs_g[:, None] * GROUP_SIZE + offs_e[None, :], kv_fp8, mask=group_mask[:, None])
+ scale_bytes = tl.where(group_mask, scale_exp + 127, 0).to(tl.uint8)
+ tl.store(O_u8 + scale_base + offs_g, scale_bytes)
+
+ # rope: bf16 passthrough into the data region right after the nope bytes
+ offs_r = tl.arange(0, ROPE_DIM)
+ rope = tl.load(KV + cur_index * stride_kv_bs + (NOPE_DIM + offs_r) * stride_kv_d)
+ tl.store(O_bf16 + (data_base + NOPE_DIM) // 2 + offs_r, rope)
+ return
+
+
+@torch.no_grad()
+def destindex_copy_kv_flashmla_dsv4(
+ KV: torch.Tensor,
+ DestLoc: torch.Tensor,
+ O_buffer: torch.Tensor,
+ page_size: int,
+):
+ """fp8_ds_mla packed page-slab writer (DeepSeek-V4 ABI, all latent pools).
+
+ KV: [T, 512] bf16 — 448 normed-latent dims + 64 rope'd dims per token.
+ DestLoc: [T] int — pool-local token slots (page = slot // page_size); the pool HOLD slot is
+ a valid in-bounds row, negative slots (unmapped) are skipped. Slots must already be
+ resolved/allocated by the caller.
+ O_buffer: [num_pages, bytes_per_page] uint8 — one layer's slab from PackedPagePool
+ (swa page=128 / c4 page=64 / c128 page=2 all share this kernel).
+
+ Per token: 448B fp8(e4m3) in 7x64 ue8m0 groups + 128B bf16 rope in the page data region;
+ 7 exponent bytes (e+127) + 1 zero pad at the page scale tail. Bit-compatible with
+ DeepseekV4MemoryManager._pack_mla_kv + PackedPagePool.write.
+ """
+ seq_len = DestLoc.shape[0]
+ if seq_len == 0:
+ return
+ nope_dim, rope_dim, group_size = 448, 64, 64
+ head_dim = nope_dim + rope_dim
+ scale_bytes = nope_dim // group_size + 1
+
+ KV = KV.reshape(-1, head_dim)
+ assert KV.shape[0] == seq_len, f"Expected KV shape[0]={seq_len}, got {KV.shape[0]}"
+ assert KV.dtype == torch.bfloat16, f"Expected bf16 KV (rope bytes are stored as-is), got {KV.dtype}"
+ bytes_per_page = O_buffer.shape[-1]
+ assert O_buffer.dtype == torch.uint8 and O_buffer.is_contiguous()
+ assert bytes_per_page % 2 == 0
+ assert bytes_per_page >= page_size * (nope_dim + rope_dim * 2 + scale_bytes)
+
+ flat = O_buffer.view(-1)
+ _fwd_kernel_destindex_copy_kv_flashmla_dsv4[(seq_len,)](
+ KV,
+ DestLoc,
+ flat.view(torch.float8_e4m3fn),
+ flat.view(torch.bfloat16),
+ flat,
+ KV.stride(0),
+ KV.stride(1),
+ FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
+ FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
+ SCALE_MIN=1e-4,
+ NOPE_DIM=nope_dim,
+ ROPE_DIM=rope_dim,
+ GROUP_SIZE=group_size,
+ NUM_GROUPS=nope_dim // group_size,
+ SCALE_BYTES=scale_bytes,
+ PAGE_SIZE=page_size,
+ BYTES_PER_PAGE=bytes_per_page,
+ num_warps=4,
+ num_stages=1,
+ )
+ return
diff --git a/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py
new file mode 100644
index 000000000..b50eab672
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py
@@ -0,0 +1,71 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _build_c4_indexer_page_table_kernel(
+ req_idx_ptr, # [batch] int
+ c4_len_ptr, # [batch] int
+ req_to_token_ptr,
+ req_to_token_stride0,
+ full_to_c4_ptr,
+ page_table_ptr, # [batch, page_cap] int32
+ page_cap,
+ hold_req_id,
+ RATIO: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+):
+ p = tl.program_id(0)
+ r = tl.program_id(1)
+ req = tl.load(req_idx_ptr + r).to(tl.int64)
+ c4_len = tl.load(c4_len_ptr + r).to(tl.int64)
+ page_start = p * PAGE_SIZE
+ active = (req != hold_req_id) & (page_start < c4_len)
+
+ full_pos0 = page_start * RATIO + (RATIO - 1)
+ full_slot0 = tl.load(
+ req_to_token_ptr + req * req_to_token_stride0 + full_pos0,
+ mask=active,
+ other=0,
+ ).to(tl.int64)
+ c4_slot0 = tl.load(full_to_c4_ptr + full_slot0, mask=active, other=0).to(tl.int64)
+ phys_page = c4_slot0 // PAGE_SIZE
+ tl.store(page_table_ptr + r * page_cap + p, tl.where(active, phys_page, 0).to(tl.int32))
+
+
+@torch.no_grad()
+def build_c4_indexer_page_table(
+ mem_manager,
+ b_req_idx: torch.Tensor,
+ c4_len: torch.Tensor,
+ c4_cap: int,
+ req_to_token_indexs: torch.Tensor,
+ hold_req_id: int,
+):
+ """Build the logical-c4-page -> physical-c4-page table expected by DeepGEMM paged logits.
+
+ Safe only when each logical c4 page maps to a physical page with matching offsets:
+ c4_slot(entry p*64 + o) == page_table[p] * 64 + o
+ which the current token-slot allocator guarantees.
+ """
+ pool = mem_manager.c4_indexer_pool
+ page_size = pool.page_size
+ assert c4_cap % page_size == 0
+ batch = b_req_idx.shape[0]
+ page_cap = c4_cap // page_size
+ page_table = torch.empty((batch, page_cap), dtype=torch.int32, device=b_req_idx.device)
+ _build_c4_indexer_page_table_kernel[(page_cap, batch)](
+ b_req_idx,
+ c4_len,
+ req_to_token_indexs,
+ req_to_token_indexs.stride(0),
+ mem_manager.full_to_c4_indexs,
+ page_table,
+ page_cap,
+ int(hold_req_id),
+ RATIO=4,
+ PAGE_SIZE=page_size,
+ num_warps=1,
+ )
+ return page_table
diff --git a/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py b/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py
new file mode 100644
index 000000000..47d87d493
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py
@@ -0,0 +1,16 @@
+import torch
+
+
+def e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor:
+ """float8_e8m0fnu encodes 2**(byte-127); torch decodes it correctly on .to(float32)."""
+ return scale.to(torch.float32)
+
+
+def dequant_fp8_block_to_bf16(weight_e4m3: torch.Tensor, scale_e8m0: torch.Tensor, block_size: int = 128):
+ """De-quantize an FP8 e4m3 weight [out, in] with block-[bs,bs] ue8m0 scale to bf16."""
+ from lightllm.models.deepseek2.triton_kernel.weight_dequant import weight_dequant
+
+ w = weight_e4m3.cuda().contiguous()
+ s = e8m0_to_fp32(scale_e8m0).cuda().contiguous()
+ # weight_dequant runs with torch default dtype for the output; force bf16 result.
+ return weight_dequant(w, s, block_size)
diff --git a/lightllm/models/deepseek_v4/workspace.py b/lightllm/models/deepseek_v4/workspace.py
new file mode 100644
index 000000000..562a53bbf
--- /dev/null
+++ b/lightllm/models/deepseek_v4/workspace.py
@@ -0,0 +1,51 @@
+import torch
+from lightllm.utils.envs_utils import get_env_start_args
+
+
+class DeepseekV4Workspace:
+ def __init__(self, model):
+ self.token_capacity = int(model.batch_max_tokens)
+ self.sliding_window = int(model.config["sliding_window"])
+ self.index_topk = int(model.config["index_topk"])
+ self.c128_cap = self.compress_cap(model.max_seq_length, 128)
+ args = get_env_start_args()
+ overlap = args.enable_decode_microbatch_overlap or args.enable_prefill_microbatch_overlap
+ self.microbatch_count = 1 + int(overlap)
+
+ self.swa_indices = self._alloc(self.sliding_window)
+ self.swa_lengths = torch.empty((self.microbatch_count, self.token_capacity), dtype=torch.int32, device="cuda")
+ self.c4_indices = self._alloc(self.index_topk)
+ self.c4_lengths = torch.empty((self.microbatch_count, self.token_capacity), dtype=torch.int32, device="cuda")
+ self.c128_indices = self._alloc(self.c128_cap)
+ self.c128_lengths = torch.empty((self.microbatch_count, self.token_capacity), dtype=torch.int32, device="cuda")
+
+ @staticmethod
+ def compress_cap(max_kv_seq_len: int, ratio: int) -> int:
+ entries = max(1, int(max_kv_seq_len) // ratio)
+ return ((entries + 63) // 64) * 64
+
+ def _alloc(self, width: int) -> torch.Tensor:
+ return torch.empty((self.microbatch_count, self.token_capacity * width), dtype=torch.int32, device="cuda")
+
+ @staticmethod
+ def _view(buffer: torch.Tensor, token_num: int, width: int) -> torch.Tensor:
+ return torch.as_strided(buffer, (token_num, width), (width, 1))
+
+ def swa(self, microbatch_index: int, token_num: int):
+ return (
+ self._view(self.swa_indices[microbatch_index], token_num, self.sliding_window),
+ self.swa_lengths[microbatch_index, :token_num],
+ )
+
+ def c4(self, microbatch_index: int, token_num: int, width: int):
+ assert width <= self.index_topk, f"c4 width {width} exceeds allocated {self.index_topk}"
+ return (
+ self._view(self.c4_indices[microbatch_index], token_num, width),
+ self.c4_lengths[microbatch_index, :token_num],
+ )
+
+ def c128(self, microbatch_index: int, token_num: int, width: int):
+ return (
+ self._view(self.c128_indices[microbatch_index], token_num, width),
+ self.c128_lengths[microbatch_index, :token_num],
+ )
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 1bdf8f342..70a786edf 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -161,6 +161,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
"qwen",
"deepseekv31",
"deepseekv32",
+ "deepseekv4",
"glm47",
"kimi_k2",
"qwen3_coder",
@@ -174,6 +175,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
choices=[
"deepseek-r1",
"deepseek-v3",
+ "deepseek-v4",
"glm45",
"gpt-oss",
"kimi",
@@ -620,8 +622,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
type=str,
default=None,
choices=["fp8", "fp4"],
- help="""Expert quantization dtype for EP MoE. Supported values are
- fp8 and fp4. Note that fp4 is only supported on SM100 GPUs.""",
+ help="""Requested dtype for MoE expert weights, fp8 or fp4. Resolves the fused_moe
+ quant method: fp8 -> deepgemm-fp8w8a8-b128; fp4 -> deepgemm-fp4fp8-b32 (online
+ quantization) on SM100 GPUs, or marlin-mxfp4w4a16-b32 (Marlin W4A16, TP only) on other GPUs.
+ Defaults to `expert_dtype` in config.json if present. Per-layer override:
+ --quant_cfg mix_bits with name `fused_moe`.""",
)
parser.add_argument(
"--vit_quant_type",
diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py
index 1737d2774..dcf3d2a0a 100644
--- a/lightllm/server/api_models.py
+++ b/lightllm/server/api_models.py
@@ -221,7 +221,7 @@ class ChatCompletionRequest(BaseModel):
parallel_tool_calls: Optional[bool] = True
# OpenAI parameters for reasoning and others
- reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
+ reasoning_effort: Optional[Literal["none", "low", "medium", "high", "max"]] = None
chat_template_kwargs: Optional[Dict] = None
separate_reasoning: Optional[bool] = True
stream_reasoning: Optional[bool] = False
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 0d934c44c..1df04af3a 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -165,8 +165,13 @@ def _is_force_thinking_mode(request: ChatCompletionRequest) -> bool:
return False
if reasoning_parser in ["qwen3-thinking", "gpt-oss", "minimax"]:
return True
- if reasoning_parser in ["deepseek-v3"]:
- return request.chat_template_kwargs is not None and request.chat_template_kwargs.get("thinking") is True
+ if reasoning_parser in ["deepseek-v3", "deepseek-v4"]:
+ chat_template_kwargs = request.chat_template_kwargs or {}
+ if "thinking" in chat_template_kwargs:
+ return chat_template_kwargs["thinking"] is True
+ if request.reasoning_effort is not None:
+ return request.reasoning_effort != "none"
+ return False
if reasoning_parser in ["qwen3", "glm45", "nano_v3", "interns1", "gemma4"]:
# qwen3, glm45, nano_v3, interns1, and gemma4 are reasoning by default;
return not request.chat_template_kwargs or request.chat_template_kwargs.get("enable_thinking", True) is True
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 3cf431d65..c13f562af 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -10,7 +10,11 @@
from .metrics.manager import start_metric_manager
from .embed_cache.manager import start_cache_manager
from lightllm.utils.log_utils import init_logger
-from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name
+from lightllm.utils.envs_utils import (
+ set_env_start_args,
+ set_unique_server_name,
+ get_unique_server_name,
+)
from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive
from .detokenization.manager import start_detokenization_process
from .router.manager import start_router_process
@@ -23,6 +27,8 @@
has_vision_module,
is_linear_att_mixed_model,
auto_set_max_req_total_len,
+ get_model_type,
+ get_config_json,
)
from lightllm.utils.dist_check_utils import auto_configure_allreduce_flags_from_args
@@ -108,6 +114,19 @@ def normal_or_p_d_start(args):
else:
args.enable_multimodal = True
+ model_type = get_model_type(args.model_dir)
+ if model_type == "deepseek_v4":
+ if args.run_mode != "normal":
+ raise NotImplementedError("DeepSeek-V4 currently supports only run_mode=normal in LightLLM.")
+ if args.enable_cpu_cache or args.enable_disk_cache:
+ raise NotImplementedError("DeepSeek-V4 CPU/disk KV cache is not supported yet.")
+ if args.mtp_mode is not None or args.mtp_draft_model_dir is not None or args.mtp_step != 0:
+ raise NotImplementedError("DeepSeek-V4 MTP/speculative decoding is not supported yet.")
+ if args.enable_ep_moe:
+ raise NotImplementedError("DeepSeek-V4 EP MoE is not supported yet; use TP for now.")
+ if "prompt_cache_kv_buffer" in get_config_json(args.model_dir):
+ raise NotImplementedError("DeepSeek-V4 prompt_cache_kv_buffer is not supported yet.")
+
if args.enable_cpu_cache:
# 生成一个用于创建cpu kv cache的共享内存id。
args.cpu_kv_cache_shm_id = uuid.uuid1().int % 123456789
@@ -333,7 +352,14 @@ def normal_or_p_d_start(args):
from lightllm.utils.config_utils import get_dtype
args.data_type = get_dtype(args.model_dir)
- assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]
+ assert args.data_type in [
+ "fp16",
+ "float16",
+ "bf16",
+ "bfloat16",
+ "fp32",
+ "float32",
+ ]
already_uesd_ports = [args.port]
if args.nccl_port is not None:
@@ -425,7 +451,6 @@ def normal_or_p_d_start(args):
)
if not args.disable_vision:
-
if not args.visual_use_proxy_mode:
from .visualserver.manager import start_visual_process
@@ -609,7 +634,14 @@ def visual_only_start(args):
from lightllm.utils.config_utils import get_dtype
args.data_type = get_dtype(args.model_dir)
- assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]
+ assert args.data_type in [
+ "fp16",
+ "float16",
+ "bf16",
+ "bfloat16",
+ "fp32",
+ "float32",
+ ]
logger.info(f"alloced ports: {can_use_ports}")
diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py
index 54d22a0d0..bda1aafea 100644
--- a/lightllm/server/build_prompt.py
+++ b/lightllm/server/build_prompt.py
@@ -53,6 +53,13 @@ def tokenizer_supports_force_thinking() -> bool:
assert tokenizer is not None
+ # Tokenizers that encode prompts in Python (e.g. DeepSeek-V4) have no Jinja
+ # chat_template string to inspect, so advertise thinking support via an
+ # explicit attribute instead.
+ if getattr(tokenizer, "supports_thinking", False):
+ logger.info("tokenizer_supports_force_thinking : True (explicit attribute)")
+ return True
+
try:
ans = "thinking" in tokenizer.chat_template or "enable_thinking" in tokenizer.chat_template
logger.debug(f"chat_template: {tokenizer.chat_template}")
@@ -144,6 +151,8 @@ async def build_prompt(request, tools) -> str:
if request.chat_template_kwargs:
kwargs.update(request.chat_template_kwargs)
+ if request.reasoning_effort is not None and "reasoning_effort" not in kwargs:
+ kwargs["reasoning_effort"] = request.reasoning_effort
# 修复一些parser类型是默认打开thinking,但是 tokenizer有时候不知道打开了thinking。导致
# 构建的reasoning parser 和 tokenizer 的行为不对齐导致的问题。
from .api_openai import _is_force_thinking_mode
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 40c802815..8f2df6eba 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -42,6 +42,7 @@ class StartArgs:
"choices": [
"deepseek-r1",
"deepseek-v3",
+ "deepseek-v4",
"glm45",
"gpt-oss",
"kimi",
diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py
index dfcb2f8d9..c08cb83fa 100644
--- a/lightllm/server/function_call_parser.py
+++ b/lightllm/server/function_call_parser.py
@@ -40,6 +40,7 @@
"[TOOL_CALLS]",
"<|tool▁calls▁begin|>",
"<|DSML|function_calls>",
+ "<|DSML|tool_calls>",
]
@@ -1480,11 +1481,14 @@ class DeepSeekV32Detector(BaseFormatDetector):
Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3.2
"""
- def __init__(self):
+ def __init__(self, block_name: str = "function_calls"):
super().__init__()
self.dsml_token = "|DSML|"
- self.bot_token = f"<{self.dsml_token}function_calls>"
- self.eot_token = f"{self.dsml_token}function_calls>"
+ # DeepSeek V3.2 wraps tool calls in a `function_calls` block; V4 uses
+ # `tool_calls`. Only the outer block name differs — the invoke/parameter
+ # grammar is identical — so subclasses just override block_name.
+ self.bot_token = f"<{self.dsml_token}{block_name}>"
+ self.eot_token = f"{self.dsml_token}{block_name}>"
self.invoke_start_prefix = f"<{self.dsml_token}invoke"
self.invoke_end_token = f"{self.dsml_token}invoke>"
self.param_end_token = f"{self.dsml_token}parameter>"
@@ -1589,8 +1593,10 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
try:
# Try to find complete invoke blocks first
- complete_invoke_match = self.invoke_regex.search(current_text)
- if complete_invoke_match:
+ while True:
+ complete_invoke_match = self.invoke_regex.search(current_text)
+ if not complete_invoke_match:
+ break
func_name = complete_invoke_match.group(1)
invoke_body = complete_invoke_match.group(2)
@@ -1654,8 +1660,7 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
self.current_tool_name_sent = False
self._accumulated_params = []
self.streamed_args_for_tool.append("")
-
- return StreamingParseResult(normal_text="", calls=calls)
+ current_text = self._buffer
# Partial invoke: name is known but parameters are still streaming
partial_match = self.partial_invoke_regex.search(current_text)
@@ -1694,9 +1699,10 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
if param_matches and len(param_matches) > len(self._accumulated_params):
self._accumulated_params = param_matches
current_args_json = self._dsml_params_to_json(param_matches)
+ open_args_json = current_args_json[:-1] # drop trailing '}'
sent = len(self.streamed_args_for_tool[self.current_tool_id])
- argument_diff = current_args_json[sent:]
+ argument_diff = open_args_json[sent:]
if argument_diff:
calls.append(
@@ -1962,6 +1968,32 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip()
+class DeepSeekV4Detector(DeepSeekV32Detector):
+ """
+ Detector for DeepSeek V4 model function call format using DSML.
+
+ Identical grammar to V3.2 (``<|DSML|invoke name="...">`` blocks with
+ ``<|DSML|parameter name="k" string="true|false">v|DSML|parameter>``
+ tags), except the outer block is named ``tool_calls`` instead of
+ ``function_calls`` — matching the model's own encoding (encoding_dsv4.py:
+ ``tool_calls_block_name = "tool_calls"``) and system prompt.
+
+ Format Structure:
+ ```
+ <|DSML|tool_calls>
+ <|DSML|invoke name="get_weather">
+ <|DSML|parameter name="location" string="true">Hangzhou|DSML|parameter>
+ |DSML|invoke>
+ |DSML|tool_calls>
+ ```
+
+ Reference: https://huggingface.co/deepseek-ai/DeepSeek-V4
+ """
+
+ def __init__(self):
+ super().__init__(block_name="tool_calls")
+
+
class FunctionCallParser:
"""
Parser for function/tool calls in model outputs.
@@ -1975,6 +2007,7 @@ class FunctionCallParser:
"deepseekv3": DeepSeekV3Detector,
"deepseekv31": DeepSeekV31Detector,
"deepseekv32": DeepSeekV32Detector,
+ "deepseekv4": DeepSeekV4Detector,
"glm47": Glm47Detector,
"kimi_k2": KimiK2Detector,
"llama3": Llama32Detector,
diff --git a/lightllm/server/reasoning_parser.py b/lightllm/server/reasoning_parser.py
index 8a8d07355..f351d8a6c 100644
--- a/lightllm/server/reasoning_parser.py
+++ b/lightllm/server/reasoning_parser.py
@@ -903,6 +903,7 @@ class ReasoningParser:
DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
"deepseek-r1": DeepSeekR1Detector,
"deepseek-v3": Qwen3Detector,
+ "deepseek-v4": Qwen3Detector,
"glm45": Qwen3Detector,
"gpt-oss": GptOssDetector,
"kimi": KimiDetector,
diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index 21e26c585..e45b556b5 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -2,9 +2,12 @@
import torch
import numpy as np
import collections
-from typing import Tuple, Dict, Set, List, Optional, Union
+from typing import Any, Tuple, Dict, Set, List, Optional, Union
from sortedcontainers import SortedSet
from .shared_arr import SharedArray
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
class UniqueTimeIdGenerator:
@@ -25,6 +28,7 @@ def __init__(self):
self.parent: TreeNode = None
self.token_id_key: torch.Tensor = None
self.token_mem_index_value: torch.Tensor = None # 用于记录存储的 token_index 为每个元素在 token mem 中的index位置
+ self.token_extra_value: Any = None
self.ref_counter = 0
self.time_id = time_gen.generate_time_id() # 用于标识时间周期
@@ -34,14 +38,17 @@ def __init__(self):
def get_compare_key(self):
return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id)
- def split_node(self, prefix_len):
+ def split_node(self, prefix_len, child_key_fn=None, extra_value_ops=None):
split_parent_node = TreeNode()
split_parent_node.parent = self.parent
- split_parent_node.parent.children[self.token_id_key[0].item()] = split_parent_node
+ split_parent_node.parent.children[child_key_fn(self.token_id_key)] = split_parent_node
split_parent_node.token_id_key = self.token_id_key[0:prefix_len]
split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len]
+ if self.token_extra_value is not None and extra_value_ops is not None:
+ split_parent_node.token_extra_value = extra_value_ops.slice(self.token_extra_value, 0, prefix_len)
+ self.token_extra_value = extra_value_ops.slice(self.token_extra_value, prefix_len, len(self.token_id_key))
split_parent_node.children = {}
- split_parent_node.children[self.token_id_key[prefix_len].item()] = self
+ split_parent_node.children[child_key_fn(self.token_id_key[prefix_len:])] = self
split_parent_node.ref_counter = self.ref_counter
new_len = len(split_parent_node.token_mem_index_value)
@@ -56,11 +63,12 @@ def split_node(self, prefix_len):
self.node_prefix_total_len = self.parent.node_prefix_total_len + new_len
return split_parent_node
- def add_and_return_new_child(self, token_id_key, token_mem_index_value):
+ def add_and_return_new_child(self, token_id_key, token_mem_index_value, token_extra_value=None, child_key=None):
child = TreeNode()
child.token_id_key = token_id_key
child.token_mem_index_value = token_mem_index_value
- first_token_key = child.token_id_key[0].item()
+ child.token_extra_value = token_extra_value
+ first_token_key = child.token_id_key[0].item() if child_key is None else child_key
assert first_token_key not in self.children.keys()
self.children[first_token_key] = child
child.parent = self
@@ -71,9 +79,17 @@ def add_and_return_new_child(self, token_id_key, token_mem_index_value):
return child
def remove_child(self, child_node: "TreeNode"):
- del self.children[child_node.token_id_key[0].item()]
- child_node.parent = None
- return
+ child_key = child_node.token_id_key[0].item()
+ if child_key in self.children:
+ del self.children[child_key]
+ child_node.parent = None
+ return
+ for key, value in list(self.children.items()):
+ if value is child_node:
+ del self.children[key]
+ child_node.parent = None
+ return
+ raise KeyError("child node not found")
def update_time(self):
self.time_id = time_gen.generate_time_id()
@@ -103,12 +119,22 @@ class RadixCache:
unique_name 主要用于解决单机,多实列部署时的shm冲突
"""
- def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None):
+ def __init__(
+ self,
+ unique_name,
+ total_token_num,
+ rank_in_node,
+ mem_manager=None,
+ page_size: int = 1,
+ extra_value_ops=None,
+ ):
from lightllm.common.kv_cache_mem_manager import MemoryManager
self.mem_manager: MemoryManager = mem_manager
self._key_dtype = torch.int64
self._value_dtype = torch.int64
+ self.page_size = max(1, int(page_size))
+ self.extra_value_ops = extra_value_ops
self.root_node = TreeNode()
self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype)
@@ -124,31 +150,89 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None)
f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64
)
self.tree_total_tokens_num.arr[0] = 0
+ self.swa_tree_total_pages_num = 0
+ self.swa_refed_pages_num = 0
+ # 每个 prompt-cache 页折算多少 swa 页(DSV4 为 256/128=2);非 swa 场景为 0,_node_swa_pages_num 退化为常数 0。
+ self._swa_pages_per_prompt_page = self._probe_swa_pages_per_prompt_page()
+
+ def _probe_swa_pages_per_prompt_page(self) -> int:
+ """构造期探测一次 mem_manager 是否带 swa_pool,缓存折算系数,避免热路径反复 getattr。"""
+ if self.mem_manager is None or self.extra_value_ops is None:
+ return 0
+ swa_pool = getattr(self.mem_manager, "swa_pool", None)
+ swa_page_size = getattr(swa_pool, "page_size", None)
+ if swa_page_size is None:
+ return 0
+ return (self.page_size + int(swa_page_size) - 1) // int(swa_page_size)
+
+ def _node_swa_pages_num(self, node: TreeNode) -> int:
+ if self._swa_pages_per_prompt_page == 0 or node.token_extra_value is None:
+ return 0
+ valid = node.token_extra_value.swa_page_valid
+ if valid is None:
+ return 0
+ return int(valid.sum().item()) * self._swa_pages_per_prompt_page
+
+ def _align_len(self, length: int) -> int:
+ if self.page_size <= 1:
+ return int(length)
+ return int(length) // self.page_size * self.page_size
+
+ def align_len(self, length: int) -> int:
+ return self._align_len(length)
+
+ def _child_key(self, key: torch.Tensor):
+ if self.page_size <= 1:
+ return key[0].item()
+ return tuple(key[: self.page_size].tolist())
+
+ def _match_len(self, key: torch.Tensor, node_key: torch.Tensor) -> int:
+ prefix_len = match(key, node_key)
+ return self._align_len(prefix_len)
+
+ def _slice_extra(self, extra_value, start: int, end: int):
+ if extra_value is None:
+ return None
+ assert self.extra_value_ops is not None
+ return self.extra_value_ops.slice(extra_value, start, end)
- def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]:
+ def _concat_extra(self, values: list):
+ values = [v for v in values if v is not None]
+ if len(values) == 0:
+ return None
+ assert self.extra_value_ops is not None
+ return self.extra_value_ops.concat(values)
+
+ def insert(self, key, value=None, extra_value=None) -> Tuple[int, Optional[TreeNode]]:
if value is None:
value = key
+ align_len = self._align_len(len(key))
+ key = key[:align_len]
+ value = value[:align_len]
+ if extra_value is not None:
+ extra_value = self._slice_extra(extra_value, 0, align_len)
+
assert len(key) == len(value) # and len(key) >= 1
if len(key) == 0:
return 0, None
- return self._insert_helper(self.root_node, key, value)
+ return self._insert_helper(self.root_node, key, value, extra_value)
- def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]:
+ def _insert_helper(self, node: TreeNode, key, value, extra_value) -> Tuple[int, Optional[TreeNode]]:
handle_stack = collections.deque()
update_list = collections.deque()
- handle_stack.append((node, key, value))
+ handle_stack.append((node, key, value, extra_value))
ans_prefix_len = 0
ans_node = None
while len(handle_stack) != 0:
- node, key, value = handle_stack.popleft()
- ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value)
- if len(ans_tuple) == 4:
- (_prefix_len, new_node, new_key, new_value) = ans_tuple
+ node, key, value, extra_value = handle_stack.popleft()
+ ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value, extra_value=extra_value)
+ if len(ans_tuple) == 5:
+ (_prefix_len, new_node, new_key, new_value, new_extra_value) = ans_tuple
ans_prefix_len += _prefix_len
- handle_stack.append((new_node, new_key, new_value))
+ handle_stack.append((new_node, new_key, new_value, new_extra_value))
else:
_prefix_len, ans_node = ans_tuple
ans_prefix_len += _prefix_len
@@ -166,15 +250,15 @@ def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[Tree
return ans_prefix_len, ans_node
def _insert_helper_no_recursion(
- self, node: TreeNode, key: torch.Tensor, value: torch.Tensor
- ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor]]:
+ self, node: TreeNode, key: torch.Tensor, value: torch.Tensor, extra_value=None
+ ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor, Any]]:
if node.is_leaf():
self.evict_tree_set.discard(node)
- first_key_id = key[0].item()
+ first_key_id = self._child_key(key)
if first_key_id in node.children.keys():
child: TreeNode = node.children[first_key_id]
- prefix_len = match(key, child.token_id_key)
+ prefix_len = self._match_len(key, child.token_id_key)
if prefix_len == len(key):
if prefix_len == len(child.token_id_key):
if child.is_leaf():
@@ -184,10 +268,14 @@ def _insert_helper_no_recursion(
self.evict_tree_set.add(child)
return prefix_len, child
elif prefix_len < len(child.token_id_key):
+ if prefix_len == 0:
+ return 0, node
if child.is_leaf():
self.evict_tree_set.discard(child)
- split_parent_node = child.split_node(prefix_len)
+ split_parent_node = child.split_node(
+ prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops
+ )
if split_parent_node.is_leaf():
self.evict_tree_set.add(split_parent_node)
@@ -199,15 +287,26 @@ def _insert_helper_no_recursion(
assert False, "can not run to here"
elif prefix_len < len(key) and prefix_len < len(child.token_id_key):
+ if prefix_len == 0:
+ return 0, node
if child.is_leaf():
self.evict_tree_set.discard(child)
+ new_extra_value = self._slice_extra(extra_value, prefix_len, len(key))
key = key[prefix_len:]
value = value[prefix_len:]
- split_parent_node = child.split_node(prefix_len)
- new_node = split_parent_node.add_and_return_new_child(key, value)
+ split_parent_node = child.split_node(
+ prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops
+ )
+ new_node = split_parent_node.add_and_return_new_child(
+ key,
+ value,
+ token_extra_value=new_extra_value,
+ child_key=self._child_key(key),
+ )
# update total token num
self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
+ self.swa_tree_total_pages_num += self._node_swa_pages_num(new_node)
if split_parent_node.is_leaf():
self.evict_tree_set.add(split_parent_node)
@@ -218,20 +317,37 @@ def _insert_helper_no_recursion(
self.evict_tree_set.add(child)
return prefix_len, new_node
elif prefix_len < len(key) and prefix_len == len(child.token_id_key):
- return (prefix_len, child, key[prefix_len:], value[prefix_len:])
+ return (
+ prefix_len,
+ child,
+ key[prefix_len:],
+ value[prefix_len:],
+ self._slice_extra(extra_value, prefix_len, len(key)),
+ )
else:
assert False, "can not run to here"
else:
- new_node = node.add_and_return_new_child(key, value)
+ new_node = node.add_and_return_new_child(
+ key,
+ value,
+ token_extra_value=extra_value,
+ child_key=first_key_id,
+ )
# update total token num
self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
+ self.swa_tree_total_pages_num += self._node_swa_pages_num(new_node)
if new_node.is_leaf():
self.evict_tree_set.add(new_node)
return 0, new_node
def match_prefix(self, key, update_refs=False):
- assert len(key) != 0
+ key = key[: self._align_len(len(key))]
+ if len(key) == 0:
+ return None, 0, None
+ key = self._trim_key_by_extra_value_validity(key)
+ if len(key) == 0:
+ return None, 0, None
ans_value_list = []
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
if tree_node != self.root_node:
@@ -245,6 +361,30 @@ def match_prefix(self, key, update_refs=False):
self.dec_node_ref_counter(self.root_node)
return None, 0, None
+ def _trim_key_by_extra_value_validity(self, key: torch.Tensor) -> torch.Tensor:
+ """命中有效性裁剪(extra_value_ops 提供 valid_match_length 时启用,如 DeepSeek-V4 的
+ swa 按页 bitmap): 先做一次只读探测遍历得到自然命中与沿路 extra_value,按其有效边界截短
+ key,随后的正常遍历(加引用/分裂)只走截短后的前缀 —— 引用计数与最终返回值在同一次遍历
+ 内保持一致,不存在事后裁剪导致的漏减/多减。
+
+ 探测遍历可能分裂部分命中的节点(与正常遍历同语义,树不变式不受影响)。裁剪只会缩短命中,
+ 没有任何失败路径。"""
+ if self.extra_value_ops is None:
+ return key
+ valid_match_length = getattr(self.extra_value_ops, "valid_match_length", None)
+ if valid_match_length is None:
+ return key
+ probe_values = []
+ probe_node = self._match_prefix_helper(self.root_node, key, probe_values, update_refs=False)
+ if probe_node == self.root_node or len(probe_values) == 0:
+ return key
+ natural_len = sum(len(v) for v in probe_values)
+ extra_value = self.get_extra_value_by_node(probe_node)
+ valid_len = int(valid_match_length(extra_value, natural_len))
+ if valid_len < natural_len:
+ return key[:valid_len]
+ return key
+
def _match_prefix_helper(
self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False
) -> TreeNode:
@@ -286,24 +426,29 @@ def _match_prefix_helper_no_recursion(
# from 0 to 1 need update refs token num
if node.ref_counter == 1:
self.refed_tokens_num.arr[0] += len(node.token_mem_index_value)
+ self.swa_refed_pages_num += self._node_swa_pages_num(node)
if len(key) == 0:
return node
- first_key_id = key[0].item()
+ first_key_id = self._child_key(key)
if first_key_id not in node.children.keys():
return node
else:
child = node.children[first_key_id]
- prefix_len = match(key, child.token_id_key)
+ prefix_len = self._match_len(key, child.token_id_key)
if prefix_len == len(child.token_id_key):
ans_value_list.append(child.token_mem_index_value)
return (child, key[prefix_len:])
elif prefix_len < len(child.token_id_key):
+ if prefix_len == 0:
+ return node
if child.is_leaf():
self.evict_tree_set.discard(child)
- split_parent_node = child.split_node(prefix_len)
+ split_parent_node = child.split_node(
+ prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops
+ )
ans_value_list.append(split_parent_node.token_mem_index_value)
if update_refs:
@@ -311,6 +456,7 @@ def _match_prefix_helper_no_recursion(
# from 0 to 1 need update refs token num
if split_parent_node.ref_counter == 1:
self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value)
+ self.swa_refed_pages_num += self._node_swa_pages_num(split_parent_node)
if child.is_leaf():
self.evict_tree_set.add(child)
@@ -334,8 +480,11 @@ def evict(self, need_remove_tokens, evict_callback):
), "error evict tree node state"
num_evicted += len(node.token_mem_index_value)
evict_callback(node.token_mem_index_value)
+ if self.extra_value_ops is not None and node.token_extra_value is not None:
+ self.extra_value_ops.free(node.token_extra_value)
# update total token num
self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value)
+ self.swa_tree_total_pages_num -= self._node_swa_pages_num(node)
parent_node: TreeNode = node.parent
parent_node.remove_child(node)
if parent_node.is_leaf():
@@ -369,11 +518,12 @@ def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]:
child_node.token_mem_index_value = torch.cat(
[parent_node.token_mem_index_value, child_node.token_mem_index_value]
)
+ child_node.token_extra_value = self._concat_extra([parent_node.token_extra_value, child_node.token_extra_value])
child_node.node_value_len = len(child_node.token_mem_index_value)
child_node.time_id = max(parent_node.time_id, child_node.time_id)
grandparent_node = parent_node.parent
- key_in_grandparent = parent_node.token_id_key[0].item()
+ key_in_grandparent = self._child_key(parent_node.token_id_key)
grandparent_node.children[key_in_grandparent] = child_node
child_node.parent = grandparent_node
@@ -417,6 +567,8 @@ def clear_tree_nodes(self):
self.tree_total_tokens_num.arr[0] = 0
self.refed_tokens_num.arr[0] = 0
+ self.swa_tree_total_pages_num = 0
+ self.swa_refed_pages_num = 0
return
def dec_node_ref_counter(self, node: TreeNode):
@@ -430,6 +582,7 @@ def dec_node_ref_counter(self, node: TreeNode):
while node is not None:
if node.ref_counter == 1:
self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value)
+ self.swa_refed_pages_num -= self._node_swa_pages_num(node)
node.ref_counter -= 1
node = node.parent
@@ -449,6 +602,7 @@ def add_node_ref_counter(self, node: TreeNode):
while node is not None:
if node.ref_counter == 0:
self.refed_tokens_num.arr[0] += len(node.token_mem_index_value)
+ self.swa_refed_pages_num += self._node_swa_pages_num(node)
node.ref_counter += 1
node = node.parent
@@ -469,12 +623,28 @@ def get_mem_index_value_by_node(self, node: TreeNode) -> Optional[torch.Tensor]:
ans_list.reverse()
return torch.concat(ans_list, dim=0)
+ def get_extra_value_by_node(self, node: TreeNode):
+ if node is None or self.extra_value_ops is None:
+ return None
+
+ ans_list = []
+ while node is not None:
+ if node.token_extra_value is not None:
+ ans_list.append(node.token_extra_value)
+ node = node.parent
+
+ ans_list.reverse()
+ return self._concat_extra(ans_list)
+
def get_refed_tokens_num(self):
return self.refed_tokens_num.arr[0]
def get_tree_total_tokens_num(self):
return self.tree_total_tokens_num.arr[0]
+ def get_unrefed_swa_pages_num(self):
+ return self.swa_tree_total_pages_num - self.swa_refed_pages_num
+
def print_self(self, indent=0):
self._print_helper(self.root_node, indent)
@@ -489,6 +659,38 @@ def _print_helper(self, node: TreeNode, indent):
self._print_helper(child, indent=indent + 2)
return
+ def free_unreferenced_swa_pages(self, need_pages: int) -> None:
+ """DeepSeek-V4 swa free hook: 页 allocator 触底时,沿 LRU 序(evict_tree_set)只对
+ ref_count==0 的节点链 free 其 swa 页(full 槽与压缩条目保留——节点仍可服务更长前缀的
+ 中段命中),并清载荷 bitmap 位使后续命中按缩短语义裁剪。所有权判定直接复用 radix
+ 引用计数: 节点被任何活跃请求借用即 ref>0,其页不可达。不够时由 allocator 的 assert
+ 兜底(最后防线)。"""
+ if self.mem_manager is None or self.extra_value_ops is None:
+ return
+ invalidate = getattr(self.extra_value_ops, "invalidate_swa_pages", None)
+ if invalidate is None:
+ return
+ allocator = self.mem_manager.swa_page_allocator
+ target = allocator.can_use_mem_size + int(need_pages)
+ for leaf in list(self.evict_tree_set):
+ if allocator.can_use_mem_size >= target:
+ break
+ node = leaf
+ # 叶子起步沿父链回收: 引用计数向上累加(add_node_ref_counter 走父链),
+ # 因此 ref==0 的祖先必无任何活跃借用方。重复访问无害(evict_swa/-1 跳过)。
+ # 每回收一个节点就复查目标,避免多回收(无谓削减命中可用性)。
+ while node is not None and node is not self.root_node and node.ref_counter == 0:
+ if len(node.token_mem_index_value) > 0:
+ old_pages = self._node_swa_pages_num(node)
+ self.mem_manager.evict_swa(node.token_mem_index_value)
+ if node.token_extra_value is not None:
+ invalidate(node.token_extra_value)
+ self.swa_tree_total_pages_num -= old_pages - self._node_swa_pages_num(node)
+ if allocator.can_use_mem_size >= target:
+ return
+ node = node.parent
+ return
+
def free_radix_cache_to_get_enough_token(self, need_token_num):
assert self.mem_manager is not None
if need_token_num > self.mem_manager.allocator.can_use_mem_size:
@@ -504,6 +706,44 @@ def release_mem(mem_index):
self.mem_manager.free(mem_index)
return
+ def _free_radix_full_nodes_until(self, allocator, need: int) -> None:
+ """DeepSeek-V4 压缩池(c4/c128)兑现: 沿 LRU 序逐个驱逐 ref_count==0 的整个 full radix 节点,
+ 经 mem_manager.free() 级联回收其 c4 页 / c128 槽(evict_c4/evict_c128),每驱逐一个就复查
+ *真实* allocator(不靠计数,稳),直到够或已无可驱逐的无引用节点。后者(空闲+可回收仍不足)
+ 由上游 base_backend admission 的 wait_pause 兜底,allocator 的 assert 是最后防线。"""
+ if self.mem_manager is None or allocator is None:
+ return
+ while allocator.can_use_mem_size < need:
+ # 无可驱逐的无引用 token => 停(admission 应已 wait_pause)
+ if self.tree_total_tokens_num.arr[0] <= self.refed_tokens_num.arr[0]:
+ # 兜底没兜住:admission/realize 估算漂移了。打日志便于定位(否则只会撞下游隐晦的
+ # allocator "error alloc state" assert)。
+ logger.warning(
+ f"dsv4 compress-pool realize could not free enough: need={need} "
+ f"free={allocator.can_use_mem_size} tree_total={self.tree_total_tokens_num.arr[0]} "
+ f"refed={self.refed_tokens_num.arr[0]} (admission should have paused this req)"
+ )
+ return
+ release_mems = []
+ # 复用已测的 evict():弹一个 LRU、ref==0 的叶子(>=1 token),其 full 槽经 free 级联回收压缩槽
+ self.evict(1, lambda mem_index: release_mems.append(mem_index))
+ self.mem_manager.free(torch.concat(release_mems))
+ return
+
+ def free_radix_cache_to_get_enough_c4_pages(self, need_pages: int) -> None:
+ allocator = getattr(self.mem_manager, "c4_page_allocator", None) if self.mem_manager is not None else None
+ if allocator is None or need_pages <= 0:
+ return
+ self._free_radix_full_nodes_until(allocator, need_pages)
+ return
+
+ def free_radix_cache_to_get_enough_c128_slots(self, need_slots: int) -> None:
+ allocator = getattr(self.mem_manager, "c128_allocator", None) if self.mem_manager is not None else None
+ if allocator is None or need_slots <= 0:
+ return
+ self._free_radix_full_nodes_until(allocator, need_slots)
+ return
+
class _RadixCacheReadOnlyClient:
"""
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 5c2d0d45f..1d5db33bf 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -8,7 +8,7 @@
from sortedcontainers import SortedDict
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Callable, Any, Union
-from lightllm.common.req_manager import ReqManager, ReqManagerForMamba
+from lightllm.common.req_manager import DeepseekV4ReqManager, ReqManager, ReqManagerForMamba
from lightllm.utils.infer_utils import mark_start, mark_end
from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode
@@ -50,7 +50,9 @@ def register(
vocab_size: int,
):
self.args = get_env_start_args()
- from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
+ from lightllm.server.router.model_infer.mode_backend.base_backend import (
+ ModeBackend,
+ )
self.backend: ModeBackend = backend
self.req_manager = req_manager
@@ -122,11 +124,20 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
return req_objs
def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
+ is_dsv4_req_manager = hasattr(self.req_manager, "build_prompt_cache_payload")
if self.radix_cache is None:
free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
+ if is_dsv4_req_manager:
+ # 槽位随 full 槽经 mem_manager.free 级联回收。pause 路径不释放 req_idx,
+ # 必须在此复位出窗水位线 + 清 c128 在途状态(恢复命中走 extend,不会再有
+ # restore/zero 时机;c4 状态随 swa 页生灭,无需处理)。
+ self.req_manager.init_compress_state(req.req_idx)
else:
if not self.is_linear_att_mixed_model:
- self._full_att_free_req(free_token_index=free_token_index, req=req)
+ if is_dsv4_req_manager:
+ self._dsv4_full_att_free_req(free_token_index=free_token_index, req=req)
+ else:
+ self._full_att_free_req(free_token_index=free_token_index, req=req)
else:
self._linear_att_free_req(free_token_index=free_token_index, req=req)
assert len(req.linear_att_len_to_big_page_id) == 0
@@ -134,6 +145,11 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
req.shm_req.shm_cur_kv_len = req.cur_kv_len
return
+ def _append_free_token_index(self, free_token_index: List, tensor: torch.Tensor):
+ if tensor.numel() > 0:
+ free_token_index.append(tensor)
+ return
+
def _full_att_free_req(self, free_token_index: List, req: "InferReq"):
input_token_ids = req.get_input_token_ids()
key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
@@ -149,6 +165,68 @@ def _full_att_free_req(self, free_token_index: List, req: "InferReq"):
req.shared_kv_node = None
return
+ def _dsv4_full_att_free_req(self, free_token_index: List, req: "InferReq"):
+ if req.cur_kv_len == 0:
+ free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0:0])
+ return
+
+ old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
+ inserted_len = old_prefix_len
+ duplicate_prefix_len = old_prefix_len
+
+ # 载荷只剩按页 bitmap(compressor 状态随 swa 页生灭/边界自然归零,不进载荷),
+ # 任意 128 对齐前缀皆可插入——含生成段(floor(cur_kv_len) 边界,回收保留尾页保证其驻留)。
+ cache_len = self.radix_cache.align_len(req.cur_kv_len)
+ self.req_manager: DeepseekV4ReqManager
+ if cache_len > old_prefix_len:
+ payload = self.req_manager.build_prompt_cache_payload(req.req_idx, cache_len)
+ value = self.req_manager.req_to_token_indexs[req.req_idx][:cache_len].detach().cpu()
+ # 按页有效性 bitmap 用插入时刻的映射写定(此后只会被阀清 0,不会复活)。水位线
+ # 纯 CPU 推导,避免 router 关键路径上的 GPU gather 同步(每插入一次要等全部在途
+ # decode kernel)。插入门: 截掉结尾的 invalid 页 —— 它们生来不可命中,还会永久
+ # 挡住后续更长前缀复用同一段 token(全量重插会因前缀已存在而保留旧 bitmap)。
+ page_size = self.req_manager.get_prompt_cache_page_size()
+ bitmap = self.req_manager.swa_page_valid_from_watermark(req.req_idx, cache_len)
+ n_pages = int(bitmap.numel())
+ while n_pages > 0 and not bool(bitmap[n_pages - 1]):
+ n_pages -= 1
+ gated_len = n_pages * page_size
+ if gated_len < cache_len:
+ logger.info(
+ f"DeepSeek-V4 prompt cache insert gate: trailing swa pages already evicted, "
+ f"shrink insert {cache_len} -> {gated_len}"
+ )
+ cache_len = gated_len
+ payload.cache_len = cache_len
+ payload.swa_page_valid = bitmap[:n_pages].clone()
+
+ if cache_len > old_prefix_len:
+ input_token_ids = req.get_input_token_ids()
+ key = torch.tensor(input_token_ids[0:cache_len], dtype=torch.int64, device="cpu")
+ duplicate_prefix_len, cache_node = self.radix_cache.insert(key, value[:cache_len], extra_value=payload)
+ inserted_len = 0 if cache_node is None else cache_node.node_prefix_total_len
+ if inserted_len != cache_len:
+ inserted_len = old_prefix_len
+ duplicate_prefix_len = old_prefix_len
+
+ dense_row = self.req_manager.req_to_token_indexs[req.req_idx]
+ self._append_free_token_index(free_token_index, dense_row[old_prefix_len:duplicate_prefix_len])
+ self._append_free_token_index(free_token_index, dense_row[inserted_len : req.cur_kv_len])
+ if len(free_token_index) == 0:
+ free_token_index.append(dense_row[0:0])
+ # 释放的 full 槽经 mem_manager.free 级联回收 swa/c4/c128(映射键控,无需收集槽位)。
+
+ # pause 路径不会走 req_manager.free/init: 复位出窗水位线(残留水位线会破坏下一次
+ # prefill 的共享前缀保护)并清 c128 在途状态(恢复命中走 extend 续算,若残留暂停前的
+ # 半窗聚合会算错;c128 状态在 128 对齐命中边界本应为零)。
+ self.req_manager.init_compress_state(req.req_idx)
+
+ if req.shared_kv_node is not None:
+ assert req.shared_kv_node.node_prefix_total_len <= max(inserted_len, old_prefix_len)
+ self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
+ req.shared_kv_node = None
+ return
+
def _linear_att_free_req(self, free_token_index: List, req: "InferReq"):
assert g_infer_context.is_linear_att_mixed_model is True
args = get_env_start_args()
@@ -325,7 +403,12 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
self.req_manager.free_token(free_token_index)
return self
- def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int):
+ def recover_paused_reqs(
+ self,
+ paused_reqs: List["InferReq"],
+ is_master_in_dp: bool,
+ can_alloc_token_num: int,
+ ):
if paused_reqs:
for req in paused_reqs:
@@ -354,6 +437,110 @@ def get_can_alloc_token_num(self):
)
return self.req_manager.mem_manager.allocator.can_use_mem_size + radix_cache_unref_token_num
+ def get_can_alloc_dsv4_swa_page_num(self):
+ mem_manager = self.req_manager.mem_manager
+ allocator = getattr(mem_manager, "swa_page_allocator", None)
+ if allocator is None:
+ return None
+
+ radix_cache_unref_page_num = 0
+ if self.radix_cache is not None:
+ radix_cache_unref_page_num = self.radix_cache.get_unrefed_swa_pages_num()
+ return int(allocator.can_use_mem_size) + radix_cache_unref_page_num
+
+ def get_dsv4_swa_prefill_need_page_num(self, req: "InferReq", is_chuncked_prefill: bool):
+ page_size = self._get_dsv4_swa_page_size()
+ if page_size is None:
+ return 0
+
+ start = int(req.cur_kv_len)
+ if is_chuncked_prefill:
+ end = int(req.get_chuncked_input_token_len())
+ else:
+ end = int(req.get_cur_total_len())
+ if end <= start:
+ return 0
+ first_new_page = (start + page_size - 1) // page_size
+ last_page = (end - 1) // page_size
+ return last_page - first_new_page + 1
+
+ def get_dsv4_swa_decode_need_page_num(self, req: "InferReq"):
+ page_size = self._get_dsv4_swa_page_size()
+ if page_size is None:
+ return 0
+
+ seq_len = int(req.get_cur_total_len())
+ if seq_len <= 0:
+ return 0
+ return 1 if (seq_len - 1) % page_size == 0 else 0
+
+ def _get_dsv4_swa_page_size(self):
+ mem_manager = self.req_manager.mem_manager
+ allocator = getattr(mem_manager, "swa_page_allocator", None)
+ if allocator is None:
+ return None
+ return mem_manager.swa_pool.page_size
+
+ # ---- DeepSeek-V4 compressed-pool (c4/c128) admission, mirror of the swa helpers above ----
+ # c4 is paged for fp8_paged_mqa_logits: 64 c4-slots/page == 256 full tokens. The prompt-cache
+ # radix is 256-aligned (DSV4_PROMPT_CACHE_PAGE_SIZE) and c4 is NOT windowed, so reclaimable c4
+ # pages derive exactly from the unref token count (`// 256`) — no separate counter needed
+ # (unlike swa, which is windowed). c128 is slot-based: 1 slot per 128 full tokens (`// 128`).
+ def get_can_alloc_dsv4_c4_page_num(self):
+ allocator = getattr(self.req_manager.mem_manager, "c4_page_allocator", None)
+ if allocator is None:
+ return None
+ radix_unref_page_num = 0
+ if self.radix_cache is not None:
+ radix_unref_page_num = (
+ self.radix_cache.get_tree_total_tokens_num() - self.radix_cache.get_refed_tokens_num()
+ ) // 256
+ return int(allocator.can_use_mem_size) + int(radix_unref_page_num)
+
+ def get_can_alloc_dsv4_c128_slot_num(self):
+ allocator = getattr(self.req_manager.mem_manager, "c128_allocator", None)
+ if allocator is None:
+ return None
+ radix_unref_slot_num = 0
+ if self.radix_cache is not None:
+ radix_unref_slot_num = (
+ self.radix_cache.get_tree_total_tokens_num() - self.radix_cache.get_refed_tokens_num()
+ ) // 128
+ return int(allocator.can_use_mem_size) + int(radix_unref_slot_num)
+
+ def get_dsv4_c4_decode_need_page_num(self, req: "InferReq"):
+ if getattr(self.req_manager.mem_manager, "c4_page_allocator", None) is None:
+ return 0
+ seq_len = int(req.get_cur_total_len())
+ # 与 _scatter_c4_decode_slots 一致: 关组(seq%4==0)且组末 c4-entry 落页首(entry%64==0) -> 开新页
+ if seq_len > 0 and seq_len % 4 == 0 and (seq_len // 4 - 1) % 64 == 0:
+ return 1
+ return 0
+
+ def get_dsv4_c128_decode_need_slot_num(self, req: "InferReq"):
+ if getattr(self.req_manager.mem_manager, "c128_allocator", None) is None:
+ return 0
+ seq_len = int(req.get_cur_total_len())
+ return 1 if (seq_len > 0 and seq_len % 128 == 0) else 0
+
+ def get_dsv4_c4_prefill_need_page_num(self, req: "InferReq", is_chuncked_prefill: bool):
+ if getattr(self.req_manager.mem_manager, "c4_page_allocator", None) is None:
+ return 0
+ start = int(req.cur_kv_len)
+ end = int(req.get_chuncked_input_token_len()) if is_chuncked_prefill else int(req.get_cur_total_len())
+ first, last = start // 4, end // 4
+ if last <= first:
+ return 0
+ # 安全上界: 覆盖 c4-entry 区间 [first,last) 触及的全部 64-页(忽略已分配的延续页 -> 偏多, 安全)
+ return (last - 1) // 64 - first // 64 + 1
+
+ def get_dsv4_c128_prefill_need_slot_num(self, req: "InferReq", is_chuncked_prefill: bool):
+ if getattr(self.req_manager.mem_manager, "c128_allocator", None) is None:
+ return 0
+ start = int(req.cur_kv_len)
+ end = int(req.get_chuncked_input_token_len()) if is_chuncked_prefill else int(req.get_cur_total_len())
+ return max(0, end // 128 - start // 128)
+
def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: List["InferReq"]):
"""
该函数用于在线性混合模型prefill后,如果存在大页匹配的情况下,将线性层状态复制到
@@ -382,7 +569,9 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
)
big_page_buffer_ids = big_page_buffer_ids.cuda(non_blocking=True)
- from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer
+ from lightllm.common.basemodel.triton_kernel.linear_att_copy import (
+ copy_linear_att_state_to_kv_buffer,
+ )
copy_linear_att_state_to_kv_buffer(
b_req_idx=b_req_idx,
@@ -412,9 +601,10 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, src_buffer_idx, ...]
dst_buffer_idx = req.tail_linear_att_small_page_buffer_id
- dst_conv_state, dst_ssm_state = self.radix_cache.linear_att_small_page_buffers.get_state_cache(
- buffer_idx=dst_buffer_idx
- )
+ (
+ dst_conv_state,
+ dst_ssm_state,
+ ) = self.radix_cache.linear_att_small_page_buffers.get_state_cache(buffer_idx=dst_buffer_idx)
# TODO 对于非连续对象调用 copy_ 效率并不高
dst_conv_state.copy_(gpu_conv_state, non_blocking=True)
dst_ssm_state.copy_(gpu_ssm_state, non_blocking=True)
@@ -591,6 +781,8 @@ def _init_all_state(self):
self.cur_output_len = 0
g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)
+ if hasattr(g_infer_context.req_manager, "init_compress_state"):
+ g_infer_context.req_manager.init_compress_state(req_idx=self.req_idx)
self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
# token healing mode 才被使用的管理对象
@@ -626,6 +818,9 @@ def _match_radix_cache(self):
ready_cache_len = share_node.node_prefix_total_len
# 从 cpu 到 gpu 是流内阻塞操作
g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor
+ # DeepSeek-V4 命中无需任何恢复: 槽位由 full_to_* 映射键控(radix 持有 full 槽即有效,
+ # 命中长度已在 match_prefix 内按 bitmap 裁剪),c4 compressor 状态随 swa 页常驻
+ # (零拷贝续算),c128 状态在 128 对齐边界自然归零(init_compress_state 已清)。
self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换
self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度
@@ -639,7 +834,10 @@ def _linear_match_radix_cache(self):
enable_prompt_cache = (not self.sampling_param.disable_prompt_cache) and g_infer_context.radix_cache is not None
linear_hash_list = self.shm_req.linear_att_token_hash_list.get_all()
linear_att_hash_page_size = self.args.linear_att_hash_page_size
- match_tokens = min(len(linear_hash_list) * linear_att_hash_page_size, self.get_cur_total_len() - 1)
+ match_tokens = min(
+ len(linear_hash_list) * linear_att_hash_page_size,
+ self.get_cur_total_len() - 1,
+ )
match_tokens = max(0, match_tokens)
match_tokens = (match_tokens // linear_att_hash_page_size) * linear_att_hash_page_size
match_block_num = match_tokens // linear_att_hash_page_size
@@ -705,7 +903,8 @@ def _linear_match_radix_cache(self):
# 将 对应的 value_tensors 中的 kv 数据 拷贝到 tail_mems 中对应的数据去
radix_cache.mem_manager.operator.copy_mem_to_mem(
- value_tensor[cur_big_page_tokens:shared_kv_len], tail_mems
+ value_tensor[cur_big_page_tokens:shared_kv_len],
+ tail_mems,
)
self.shared_kv_node = share_node # 只是为了保证 copy_small_page_buffer_to_linear_att_state 正确调用
@@ -736,7 +935,8 @@ def _linear_match_radix_cache(self):
assert self.tail_linear_att_small_page_buffer_id is None
# 恢复linear att 状态
g_infer_context.req_manager.copy_big_page_buffer_to_linear_att_state(
- big_page_buffer_idx=share_node.big_page_buffer_idx, req=self
+ big_page_buffer_idx=share_node.big_page_buffer_idx,
+ req=self,
)
self.shm_req.shm_cur_kv_len = self.cur_kv_len
@@ -792,6 +992,7 @@ def get_input_token_ids(self):
def get_chuncked_input_token_ids(self):
chunked_start = self.cur_kv_len
chunked_end = min(self.get_cur_total_len(), chunked_start + self.args.chunked_prefill_size)
+ chunked_end = self._align_chuncked_end_for_prompt_cache(chunked_start, chunked_end)
return self.shm_req.shm_prompt_ids.arr[0:chunked_end]
def get_chuncked_input_token_ids_for_linear_att(self):
@@ -812,6 +1013,23 @@ def get_chuncked_input_token_ids_for_linear_att(self):
def get_chuncked_input_token_len(self):
chunked_start = self.cur_kv_len
chunked_end = min(self.get_cur_total_len(), chunked_start + self.args.chunked_prefill_size)
+ return self._align_chuncked_end_for_prompt_cache(chunked_start, chunked_end)
+
+ def _align_chuncked_end_for_prompt_cache(self, chunked_start: int, chunked_end: int):
+ radix_cache = g_infer_context.radix_cache
+ page_size = getattr(radix_cache, "page_size", 1) if radix_cache is not None else 1
+ if page_size <= 1 or self.sampling_param.disable_prompt_cache:
+ return chunked_end
+ prompt_end = int(self.shm_req.input_len)
+ chunked_start = int(chunked_start)
+ chunked_end = int(chunked_end)
+ if chunked_end >= prompt_end:
+ return chunked_end
+
+ assert self.args.chunked_prefill_size % page_size == 0, (
+ f"chunked_prefill_size={self.args.chunked_prefill_size} must be divisible by "
+ f"prompt-cache page_size={page_size}"
+ )
return chunked_end
def get_chuncked_input_token_len_for_linear_att(self):
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index a65dfb1bb..07eaeefa1 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -153,6 +153,8 @@ def init_model(self, kvargs):
self.model: TpPartBaseModel = self.model # for easy typing
set_random_seed(2147483647)
self.is_linear_att_mixed_model = isinstance(self.model.req_manager, ReqManagerForMamba)
+ if hasattr(self.model.req_manager, "build_prompt_cache_payload"):
+ self.support_overlap = False
if self.is_linear_att_mixed_model:
self.linear_att_cache_manager = LinearAttCacheManager(
@@ -164,6 +166,7 @@ def init_model(self, kvargs):
if not self.use_dynamic_prompt_cache:
self.radix_cache = None
+ setattr(self.args, "dynamic_prompt_cache_page_size", 1)
else:
if self.is_linear_att_mixed_model:
self.radix_cache = LinearAttPagedRadixCache(
@@ -175,13 +178,25 @@ def init_model(self, kvargs):
kv_cache_mem_manager=self.model.mem_manager,
linear_att_small_page_buffers=self.linear_att_cache_manager,
)
+ setattr(self.args, "dynamic_prompt_cache_page_size", 1)
else:
+ radix_page_size = 1
+ radix_extra_value_ops = None
+ if hasattr(self.model.req_manager, "get_prompt_cache_value_ops"):
+ radix_page_size = self.model.req_manager.get_prompt_cache_page_size()
+ radix_extra_value_ops = self.model.req_manager.get_prompt_cache_value_ops()
+ setattr(self.args, "dynamic_prompt_cache_page_size", radix_page_size)
self.radix_cache = RadixCache(
unique_name=get_unique_server_name(),
total_token_num=self.model.mem_manager.size,
rank_in_node=self.rank_in_node,
mem_manager=self.model.mem_manager,
+ page_size=radix_page_size,
+ extra_value_ops=radix_extra_value_ops,
)
+ if radix_extra_value_ops is not None and hasattr(self.model.mem_manager, "register_swa_free_hook"):
+ # swa 页 allocator 触底时让 radix 对 ref==0 节点 free swa 页(DeepSeek-V4)。
+ self.model.mem_manager.register_swa_free_hook(self.radix_cache.free_unreferenced_swa_pages)
if "prompt_cache_kv_buffer" in model_cfg:
assert self.use_dynamic_prompt_cache
@@ -582,6 +597,9 @@ def _get_classed_reqs(
prefill_tokens = 0
can_alloc_token_num = g_infer_context.get_can_alloc_token_num()
+ can_alloc_dsv4_swa_page_num = g_infer_context.get_can_alloc_dsv4_swa_page_num()
+ can_alloc_dsv4_c4_page_num = g_infer_context.get_can_alloc_dsv4_c4_page_num()
+ can_alloc_dsv4_c128_slot_num = g_infer_context.get_can_alloc_dsv4_c128_slot_num()
for req_obj in ready_reqs:
@@ -615,9 +633,23 @@ def _get_classed_reqs(
if is_decode:
token_num = req_obj.decode_need_token_num()
- if token_num <= can_alloc_token_num:
+ swa_page_num = g_infer_context.get_dsv4_swa_decode_need_page_num(req_obj)
+ c4_page_num = g_infer_context.get_dsv4_c4_decode_need_page_num(req_obj)
+ c128_slot_num = g_infer_context.get_dsv4_c128_decode_need_slot_num(req_obj)
+ if (
+ token_num <= can_alloc_token_num
+ and (can_alloc_dsv4_swa_page_num is None or swa_page_num <= can_alloc_dsv4_swa_page_num)
+ and (can_alloc_dsv4_c4_page_num is None or c4_page_num <= can_alloc_dsv4_c4_page_num)
+ and (can_alloc_dsv4_c128_slot_num is None or c128_slot_num <= can_alloc_dsv4_c128_slot_num)
+ ):
decode_reqs.append(req_obj)
can_alloc_token_num -= token_num
+ if can_alloc_dsv4_swa_page_num is not None:
+ can_alloc_dsv4_swa_page_num -= swa_page_num
+ if can_alloc_dsv4_c4_page_num is not None:
+ can_alloc_dsv4_c4_page_num -= c4_page_num
+ if can_alloc_dsv4_c128_slot_num is not None:
+ can_alloc_dsv4_c128_slot_num -= c128_slot_num
else:
if wait_pause_count < pause_max_req_num:
req_obj.wait_pause = True
@@ -632,10 +664,30 @@ def _get_classed_reqs(
token_num = req_obj.prefill_need_token_num(is_chuncked_prefill=not self.disable_chunked_prefill)
if prefill_tokens + token_num > self.batch_max_tokens:
continue
- if token_num <= can_alloc_token_num:
+ swa_page_num = g_infer_context.get_dsv4_swa_prefill_need_page_num(
+ req_obj, is_chuncked_prefill=not self.disable_chunked_prefill
+ )
+ c4_page_num = g_infer_context.get_dsv4_c4_prefill_need_page_num(
+ req_obj, is_chuncked_prefill=not self.disable_chunked_prefill
+ )
+ c128_slot_num = g_infer_context.get_dsv4_c128_prefill_need_slot_num(
+ req_obj, is_chuncked_prefill=not self.disable_chunked_prefill
+ )
+ if (
+ token_num <= can_alloc_token_num
+ and (can_alloc_dsv4_swa_page_num is None or swa_page_num <= can_alloc_dsv4_swa_page_num)
+ and (can_alloc_dsv4_c4_page_num is None or c4_page_num <= can_alloc_dsv4_c4_page_num)
+ and (can_alloc_dsv4_c128_slot_num is None or c128_slot_num <= can_alloc_dsv4_c128_slot_num)
+ ):
prefill_tokens += token_num
prefill_reqs.append(req_obj)
can_alloc_token_num -= token_num
+ if can_alloc_dsv4_swa_page_num is not None:
+ can_alloc_dsv4_swa_page_num -= swa_page_num
+ if can_alloc_dsv4_c4_page_num is not None:
+ can_alloc_dsv4_c4_page_num -= c4_page_num
+ if can_alloc_dsv4_c128_slot_num is not None:
+ can_alloc_dsv4_c128_slot_num -= c128_slot_num
else:
if wait_pause_count < pause_max_req_num:
req_obj.wait_pause = True
diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
index 792a10a78..40d1c175a 100644
--- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
@@ -402,6 +402,7 @@ def _draft_decode_eagle(
draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input)
draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output)
draft_model_input.b_seq_len += 1
+ draft_model_input.b_seq_len_cpu += 1
draft_model_input.max_kv_seq_len += 1
eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs]
draft_model_input.mem_indexes = torch.cat(
diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
index e6b9d1c18..b051a1a3a 100644
--- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
@@ -603,6 +603,7 @@ def _draft_decode_eagle(
draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input)
# update the meta info of the inference
draft_model_input.b_seq_len += 1
+ draft_model_input.b_seq_len_cpu += 1
draft_model_input.max_kv_seq_len += 1
eagle_mem_indexes_i = eagle_mem_indexes[_step * real_req_num : (_step + 1) * real_req_num]
eagle_mem_indexes_i = F.pad(
@@ -967,6 +968,7 @@ def _draft_decode_eagle_overlap(
)
draft_model_input0.b_seq_len += 1
+ draft_model_input0.b_seq_len_cpu += 1
draft_model_input0.max_kv_seq_len += 1
eagle_mem_indexes_i = eagle_mem_indexes0[_step * real_req_num0 : (_step + 1) * real_req_num0]
eagle_mem_indexes_i = F.pad(
@@ -981,6 +983,7 @@ def _draft_decode_eagle_overlap(
).view(-1)
draft_model_input1.b_seq_len += 1
+ draft_model_input1.b_seq_len_cpu += 1
draft_model_input1.max_kv_seq_len += 1
eagle_mem_indexes_i = eagle_mem_indexes1[_step * real_req_num1 : (_step + 1) * real_req_num1]
eagle_mem_indexes_i = F.pad(
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index e1a4e421d..6aea7cd67 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -90,6 +90,11 @@ def get_tokenizer(
)
logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.")
return DeepSeekV32Tokenizer(hf_tokenizer)
+ if model_type == "deepseek_v4":
+ from ..models.deepseek_v4.model import DeepSeekV4Tokenizer
+
+ logger.info("Using DeepSeek-V4 tokenizer mode with Python-based chat template encoding.")
+ return DeepSeekV4Tokenizer(tokenizer, tokenizer_name)
if model_cfg["architectures"][0] == "TarsierForConditionalGeneration":
from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor
diff --git a/lightllm/third_party/__init__.py b/lightllm/third_party/__init__.py
new file mode 100644
index 000000000..2adb50db2
--- /dev/null
+++ b/lightllm/third_party/__init__.py
@@ -0,0 +1 @@
+"""Third-party source subsets vendored for LightLLM runtime support."""
diff --git a/lightllm/third_party/sglang_jit/LICENSE b/lightllm/third_party/sglang_jit/LICENSE
new file mode 100755
index 000000000..9c422689c
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2023-2024 SGLang Team
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/lightllm/third_party/sglang_jit/README.md b/lightllm/third_party/sglang_jit/README.md
new file mode 100644
index 000000000..4f68c9cfd
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/README.md
@@ -0,0 +1,13 @@
+# Vendored SGLang JIT Subset
+
+This directory contains the minimal SGLang JIT source subset needed by the
+DeepSeek-V4 LightLLM implementation.
+
+Source: https://github.com/sgl-project/sglang
+Commit: 8cea0473ea5299bc04885f8f6ba71269415a39b5
+License: Apache License 2.0, copied in `LICENSE`.
+
+Local changes:
+- The Python imports were moved from `sglang.jit_kernel.*` to
+ `lightllm.third_party.sglang_jit.*`.
+- The package exports only the DSv4 functions used by LightLLM.
diff --git a/lightllm/third_party/sglang_jit/__init__.py b/lightllm/third_party/sglang_jit/__init__.py
new file mode 100644
index 000000000..164d545b4
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/__init__.py
@@ -0,0 +1 @@
+"""Vendored SGLang JIT kernels used by DeepSeek-V4."""
diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh
new file mode 100644
index 000000000..3a89e8114
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh
@@ -0,0 +1,522 @@
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+
+#include
+
+namespace {
+
+using Plan128 = device::compress::PrefillPlan;
+using IndiceT = int32_t;
+
+/// \brief Each thread will handle this many elements (split along head_dim)
+constexpr int32_t kTileElements = 2;
+/// \brief Each warp will handle this many elements (split along 128)
+constexpr int32_t kElementsPerWarp = 8;
+constexpr uint32_t kNumWarps = 128 / kElementsPerWarp;
+constexpr uint32_t kBlockSize = device::kWarpThreads * kNumWarps;
+
+/// \brief Need to reduce register usage to increase occupancy
+#define C128_KERNEL __global__ __launch_bounds__(kBlockSize, 2)
+
+struct Compress128DecodeParams {
+ /**
+ * \brief Shape: `[num_indices, 128, head_dim * 2]` \n
+ * last dimension layout:
+ * | kv current | score current |
+ */
+ void* __restrict__ kv_score_buffer;
+ /** \brief Shape: `[batch_size, head_dim * 2]` */
+ const void* __restrict__ kv_score_input;
+ /** \brief Shape: `[batch_size, head_dim]` */
+ void* __restrict__ kv_compressed_output;
+ /** \brief Shape: `[128, head_dim]` (called `ape`) */
+ const void* __restrict__ score_bias;
+ /** \brief Shape: `[batch_size, ]`*/
+ const IndiceT* __restrict__ indices;
+ /** \brief Shape: `[batch_size, ]` */
+ const IndiceT* __restrict__ seq_lens;
+ /** \NOTE: `batch_size` <= `num_indices` */
+ uint32_t batch_size;
+};
+
+struct Compress128PrefillParams {
+ /**
+ * \brief Shape: `[num_indices, 128, head_dim * 2]` \n
+ * last dimension layout:
+ * | kv current | score current |
+ */
+ void* __restrict__ kv_score_buffer;
+ /** \brief Shape: `[batch_size, head_dim * 2]` */
+ const void* __restrict__ kv_score_input;
+ /** \brief Shape: `[batch_size, head_dim]` */
+ void* __restrict__ kv_compressed_output;
+ /** \brief Shape: `[128, head_dim]` (called `ape`) */
+ const void* __restrict__ score_bias;
+ /** \brief Shape: `[batch_size, ]`*/
+ const IndiceT* __restrict__ indices;
+ /** \brief Shape: `[batch_size, ]`*/
+ const int32_t* __restrict__ load_indices;
+ /** \brief The following part is plan info. */
+ const Plan128* __restrict__ compress_plan;
+ const Plan128* __restrict__ write_plan;
+ uint32_t num_compress;
+ uint32_t num_write;
+};
+
+struct Compress128SharedBuffer {
+ using Storage = device::AlignedVector;
+ Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict
+ SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) {
+ return data[warp_id][lane_id];
+ }
+ SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) {
+ return data[warp_id][lane_id][tile_id];
+ }
+};
+
+template
+SGL_DEVICE void c128_write(
+ T* kv_score_buf, //
+ const T* kv_score_src,
+ const int64_t head_dim,
+ const int32_t write_pos,
+ const uint32_t lane_id) {
+ using namespace device;
+
+ using Storage = AlignedVector;
+ const auto element_size = head_dim * 2;
+ const auto gmem = tile::Memory{lane_id, kWarpThreads};
+ kv_score_buf += write_pos * element_size;
+
+ /// NOTE: Layout | [0] = kv | [1] = score |
+ Storage kv_score[2];
+#pragma unroll
+ for (int32_t i = 0; i < 2; ++i) {
+ kv_score[i] = gmem.load(kv_score_src + head_dim * i);
+ }
+#pragma unroll
+ for (int32_t i = 0; i < 2; ++i) {
+ gmem.store(kv_score_buf + head_dim * i, kv_score[i]);
+ }
+}
+
+template
+SGL_DEVICE void c128_forward(
+ const InFloat* kv_score_buf,
+ const InFloat* kv_score_src,
+ OutFloat* kv_out,
+ const InFloat* score_bias,
+ const int64_t head_dim,
+ const int32_t window_len,
+ const uint32_t warp_id,
+ const uint32_t lane_id) {
+ using namespace device;
+
+ const auto element_size = head_dim * 2;
+ const auto score_offset = head_dim;
+
+ /// NOTE: part 1: load kv + score
+ using StorageIn = AlignedVector;
+ const auto gmem_in = tile::Memory{lane_id, kWarpThreads};
+ StorageIn kv[kElementsPerWarp];
+ StorageIn score[kElementsPerWarp];
+ StorageIn bias[kElementsPerWarp];
+ const int32_t warp_offset = warp_id * kElementsPerWarp;
+
+#pragma unroll
+ for (int32_t i = 0; i < 8; ++i) {
+ const int32_t j = i + warp_offset;
+ bias[i] = gmem_in.load(score_bias + j * head_dim);
+ }
+
+#pragma unroll
+ for (int32_t i = 0; i < kElementsPerWarp; ++i) {
+ const int32_t j = i + warp_offset;
+ const InFloat* src;
+ __builtin_assume(j < 128);
+ if (j < window_len) {
+ src = kv_score_buf + j * element_size;
+ } else {
+ /// NOTE: k in [-127, 0]. We'll load from the ragged `kv_score_src`
+ const int32_t k = j - 127;
+ src = kv_score_src + k * element_size;
+ }
+ kv[i] = gmem_in.load(src);
+ score[i] = gmem_in.load(src + score_offset);
+ }
+
+ /// NOTE: part 2: safe online softmax + weighted sum
+ using TmpStorage = typename Compress128SharedBuffer::Storage;
+ __shared__ Compress128SharedBuffer s_local_val_max;
+ __shared__ Compress128SharedBuffer s_local_exp_sum;
+ __shared__ Compress128SharedBuffer s_local_product;
+
+ TmpStorage tmp_val_max;
+ TmpStorage tmp_exp_sum;
+ TmpStorage tmp_product;
+
+#pragma unroll
+ for (int32_t i = 0; i < kTileElements; ++i) {
+ float score_fp32[kElementsPerWarp];
+
+#pragma unroll
+ for (int32_t j = 0; j < kElementsPerWarp; ++j) {
+ score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]);
+ }
+
+ float max_value = score_fp32[0];
+ float sum_exp_value = 0.0f;
+
+#pragma unroll
+ for (int32_t j = 1; j < kElementsPerWarp; ++j) {
+ const auto fp32_score = score_fp32[j];
+ max_value = fmaxf(max_value, fp32_score);
+ }
+
+ float sum_product = 0.0f;
+#pragma unroll
+ for (int32_t j = 0; j < 8; ++j) {
+ const auto fp32_score = score_fp32[j];
+ const auto exp_score = expf(fp32_score - max_value);
+ sum_product += cast(kv[j][i]) * exp_score;
+ sum_exp_value += exp_score;
+ }
+
+ tmp_val_max[i] = max_value;
+ tmp_exp_sum[i] = sum_exp_value;
+ tmp_product[i] = sum_product;
+ }
+
+ // naturally aligned, so no bank conflict
+ s_local_val_max(warp_id, lane_id) = tmp_val_max;
+ s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum;
+ s_local_product(warp_id, lane_id) = tmp_product;
+
+ __syncthreads();
+
+ /// NOTE: part 3: online softmax
+ /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce
+ /// each reduce will consume `kNumWarps` threads (use partial warp reduction)
+ constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps;
+ constexpr uint32_t kIteration = kReductionCount / kBlockSize;
+
+#pragma unroll
+ for (uint32_t i = 0; i < kIteration; ++i) {
+ /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)`
+ const uint32_t j = i * kBlockSize + warp_id * kWarpThreads + lane_id;
+ /// NOTE: Range `[0, kNumWarps)`
+ const uint32_t local_warp_id = j % kNumWarps;
+ /// NOTE: Range `[0, kTileElements * kWarpThreads)`
+ const uint32_t local_elem_id = j / kNumWarps;
+ /// NOTE: Range `[0, kTileElements)`
+ const uint32_t local_tile_id = local_elem_id % kTileElements;
+ /// NOTE: Range `[0, kWarpThreads)`
+ const uint32_t local_lane_id = local_elem_id / kTileElements;
+ /// NOTE: each warp will access the whole tile (all `kTileElements`)
+ /// and for different lanes, the memory access only differ in `local_warp_id`
+ /// so there's no bank conflict in shared memory access.
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs");
+ const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id);
+ const auto global_val_max = warp::reduce_max(local_val_max);
+ const auto rescale = expf(local_val_max - global_val_max);
+ const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale);
+ const auto final_scale = rescale / global_exp_sum;
+ const auto global_product = warp::reduce_sum(local_product * final_scale);
+ kv_out[local_elem_id] = cast(global_product);
+ }
+}
+
+template
+C128_KERNEL void flash_c128_decode(const __grid_constant__ Compress128DecodeParams params) {
+ using namespace device;
+
+ constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64
+ constexpr uint32_t kNumSplit = kHeadDim / kTileDim;
+ constexpr int64_t kElementSize = kHeadDim * 2;
+ static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim");
+
+ const auto& [
+ _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score
+ indices, seq_lens, batch_size // decode info
+ ] = params;
+ const uint32_t warp_id = threadIdx.x / kWarpThreads;
+ const uint32_t lane_id = threadIdx.x % kWarpThreads;
+
+ const uint32_t global_bid = blockIdx.x / kNumSplit; // batch id
+ const uint32_t global_sid = blockIdx.x % kNumSplit; // split id
+ if (global_bid >= batch_size) return;
+
+ const int32_t index = indices[global_bid];
+ const int32_t seq_len = seq_lens[global_bid];
+ const int64_t split_offset = global_sid * kTileDim;
+
+ // kv score
+ const auto kv_score_buffer = static_cast(_kv_score_buffer);
+ const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset;
+
+ // kv input
+ const auto kv_score_input = static_cast(_kv_score_input);
+ const auto kv_src = kv_score_input + global_bid * kElementSize + split_offset;
+
+ // kv output
+ const auto kv_compressed_output = static_cast(_kv_compressed_output);
+ const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset;
+
+ // score bias (ape)
+ const auto score_bias = static_cast(_score_bias) + split_offset;
+
+ PDLWaitPrimary();
+
+ /// NOTE: the write must be visible to the subsequent c128_forward,
+ /// so only the last warp can write to HBM
+ /// In addition, `position` = `seq_len - 1`. To avoid underflow, we use `seq_len + 127`
+ if (warp_id == kNumWarps - 1) {
+ c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 127) % 128, lane_id);
+ }
+ if (seq_len % 128 == 0) {
+ c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, /*window_len=*/128, warp_id, lane_id);
+ }
+
+ PDLTriggerSecondary();
+}
+
+// compress kernel
+template
+C128_KERNEL void flash_c128_prefill(const __grid_constant__ Compress128PrefillParams params) {
+ using namespace device;
+
+ constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64
+ constexpr uint32_t kNumSplit = kHeadDim / kTileDim;
+ constexpr int64_t kElementSize = kHeadDim * 2;
+ static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim");
+
+ const auto& [
+ _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score
+ indices, load_indices, compress_plan, write_plan, num_compress, num_write // prefill plan
+ ] = params;
+ const uint32_t warp_id = threadIdx.x / kWarpThreads;
+ const uint32_t lane_id = threadIdx.x % kWarpThreads;
+
+ uint32_t global_id;
+ if constexpr (kWrite) {
+ // for write kernel, we use global warp_id to dispatch work
+ global_id = (blockIdx.x * blockDim.x + threadIdx.x) / kWarpThreads;
+ } else {
+ // for compress kernel, we use block id to dispatch work
+ global_id = blockIdx.x; // block id
+ }
+ const uint32_t global_pid = global_id / kNumSplit; // plan id
+ const uint32_t global_sid = global_id % kNumSplit; // split id
+
+ /// NOTE: compiler can optimize this if-else at compile time
+ const auto num_plans = kWrite ? num_write : num_compress;
+ const auto plan_ptr = kWrite ? write_plan : compress_plan;
+ if (global_pid >= num_plans) return;
+
+ const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid];
+ const auto indices_ptr = kWrite ? indices : load_indices;
+
+ const int64_t split_offset = global_sid * kTileDim;
+
+ // kv input
+ const auto kv_score_input = static_cast(_kv_score_input);
+ const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset;
+
+ // kv output
+ const auto kv_compressed_output = static_cast(_kv_compressed_output);
+ const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset;
+
+ // score bias (ape)
+ const auto score_bias = static_cast(_score_bias) + split_offset;
+
+ if (ragged_id == 0xFFFFFFFF) [[unlikely]]
+ return;
+
+ const int32_t index = indices_ptr[global_bid];
+ // kv score
+ const auto kv_score_buffer = static_cast(_kv_score_buffer);
+ const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset;
+
+ PDLWaitPrimary();
+
+ // only responsible for the compress part
+ if constexpr (kWrite) {
+ c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 128, lane_id);
+ } else {
+ c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, window_len, warp_id, lane_id);
+ }
+
+ PDLTriggerSecondary();
+}
+
+template
+struct FlashCompress128Kernel {
+ static constexpr auto decode_kernel = flash_c128_decode;
+ template
+ static constexpr auto prefill_kernel = flash_c128_prefill;
+ static constexpr auto prefill_c_kernel = prefill_kernel*kWrite=*/false>;
+ static constexpr auto prefill_w_kernel = prefill_kernel*kWrite=*/true>;
+ static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64
+ static constexpr uint32_t kNumSplit = kHeadDim / kTileDim;
+ static constexpr uint32_t kWriteBlockSize = 128;
+ static constexpr uint32_t kWarpsPerWriteBlock = kWriteBlockSize / device::kWarpThreads;
+
+ static void run_decode(
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView indices,
+ const tvm::ffi::TensorView seq_lens,
+ const tvm::ffi::Optional /* UNUSED */) {
+ using namespace host;
+
+ // this should not happen in practice
+ auto B = SymbolicSize{"batch_size"};
+ auto device = SymbolicDevice{};
+ device.set_options();
+
+ TensorMatcher({-1, 128, kHeadDim * 2}) // kv score
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_score_buffer);
+ TensorMatcher({B, kHeadDim * 2}) // kv score input
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_score_input);
+ TensorMatcher({B, kHeadDim}) // kv compressed output
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device)
+ .verify(ape);
+ TensorMatcher({B}) // indices
+ .with_dtype()
+ .with_device(device)
+ .verify(indices);
+ TensorMatcher({B}) // seq lens
+ .with_dtype()
+ .with_device(device)
+ .verify(seq_lens);
+
+ const auto batch_size = static_cast(B.unwrap());
+ const auto params = Compress128DecodeParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(),
+ .indices = static_cast(indices.data_ptr()),
+ .seq_lens = static_cast(seq_lens.data_ptr()),
+ .batch_size = batch_size,
+ };
+
+ const uint32_t num_blocks = batch_size * kNumSplit;
+ LaunchKernel(num_blocks, kBlockSize, device.unwrap()) //
+ .enable_pdl(kUsePDL)(decode_kernel, params);
+ }
+
+ static void run_prefill(
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView indices,
+ const tvm::ffi::TensorView compress_plan,
+ const tvm::ffi::TensorView write_plan,
+ const tvm::ffi::Optional extra) {
+ using namespace host;
+
+ auto B = SymbolicSize{"batch_size"};
+ auto N = SymbolicSize{"num_q_tokens"};
+ auto X = SymbolicSize{"compress_tokens"};
+ auto Y = SymbolicSize{"write_tokens"};
+ auto device_ = SymbolicDevice{};
+ device_.set_options();
+
+ TensorMatcher({-1, 128, kHeadDim * 2}) // kv score
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_buffer);
+ TensorMatcher({N, kHeadDim * 2}) // kv score input
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_input);
+ TensorMatcher({N, kHeadDim}) // kv compressed output
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device_)
+ .verify(ape);
+ TensorMatcher({B}) // indices
+ .with_dtype()
+ .with_device(device_)
+ .verify(indices);
+ TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan
+ .with_dtype()
+ .with_device(device_)
+ .verify(compress_plan);
+ TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan
+ .with_dtype()
+ .with_device(device_)
+ .verify(write_plan);
+
+ // might be needed for prefill write
+ const auto load_indices = extra.value_or(indices);
+ TensorMatcher({B}) // [read_positions]
+ .with_dtype()
+ .with_device(device_)
+ .verify(load_indices);
+
+ const auto device = device_.unwrap();
+ const auto batch_size = static_cast(B.unwrap());
+ const auto num_q_tokens = static_cast(N.unwrap());
+ const auto num_c = static_cast(X.unwrap());
+ const auto num_w = static_cast(Y.unwrap());
+ const auto params = Compress128PrefillParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(),
+ .indices = static_cast(indices.data_ptr()),
+ .load_indices = static_cast(load_indices.data_ptr()),
+ .compress_plan = static_cast(compress_plan.data_ptr()),
+ .write_plan = static_cast(write_plan.data_ptr()),
+ .num_compress = num_c,
+ .num_write = num_w,
+ };
+ RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size");
+ RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan");
+
+ constexpr auto kBlockSize_C = kBlockSize;
+ constexpr auto kBlockSize_W = kWriteBlockSize;
+ if (const auto num_c_blocks = num_c * kNumSplit) {
+ LaunchKernel(num_c_blocks, kBlockSize_C, device) //
+ .enable_pdl(kUsePDL)(prefill_c_kernel, params);
+ }
+ if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerWriteBlock)) {
+ LaunchKernel(num_w_blocks, kBlockSize_W, device) //
+ .enable_pdl(kUsePDL)(prefill_w_kernel, params);
+ }
+ }
+};
+
+} // namespace
diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh
new file mode 100644
index 000000000..b49747060
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh
@@ -0,0 +1,726 @@
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace device::compress {
+
+/// \brief Plan entry for online compress 128 prefill.
+/// Each entry describes a contiguous segment of tokens that lies inside a
+/// single 128-chunk. Multiple segments can map to the same batch id when the
+/// extend tokens span chunk boundaries.
+///
+/// **Layout compatibility:** the field order/types match `PrefillPlan` so that
+/// downstream kernels (e.g. `fused_norm_rope` in `CompressExtend` mode) can
+/// consume the compress_plan tensor as-if it were a `PrefillPlan` tensor --
+/// they only read `ragged_id` and `position`, both of which carry identical
+/// semantics here (the LAST token of the segment in q-ragged and global
+/// coordinates respectively).
+///
+/// Note that `window_len` here means "number of real tokens in this segment"
+/// (1..128), which differs from `PrefillPlan::window_len`. Downstream kernels
+/// that share the tensor MUST NOT read it under that name.
+struct alignas(16) OnlinePrefillPlan {
+ /// \brief Ragged-q position of the LAST token in this segment.
+ /// Equal to `segment_start_ragged + window_len - 1`.
+ uint32_t ragged_id;
+ /// \brief Index into the `indices` / `load_indices` arrays.
+ uint32_t batch_id;
+ /// \brief Global position of the LAST token in this segment.
+ /// For compress plans, `position % 128 == 127` (chunk-closing); for write
+ /// plans, `position % 128 < 127`.
+ uint32_t position;
+ /// \brief Number of real tokens in this segment (1..128).
+ /// The first segment token sits at `position - window_len + 1` (global) and
+ /// at `ragged_id - window_len + 1` (ragged).
+ uint32_t window_len;
+};
+
+static_assert(alignof(OnlinePrefillPlan) == alignof(PrefillPlan));
+static_assert(sizeof(OnlinePrefillPlan) == sizeof(PrefillPlan));
+
+} // namespace device::compress
+
+namespace host::compress {
+
+using device::compress::OnlinePrefillPlan;
+using OnlinePrefillPlanTensorDtype = uint8_t;
+inline constexpr int64_t kOnlinePrefillPlanDim = 16;
+
+static_assert(alignof(OnlinePrefillPlan) == sizeof(OnlinePrefillPlan));
+static_assert(sizeof(OnlinePrefillPlan) == kOnlinePrefillPlanDim * sizeof(OnlinePrefillPlanTensorDtype));
+
+} // namespace host::compress
+
+namespace {
+
+using OnlinePlan = device::compress::OnlinePrefillPlan;
+using IndiceT = int32_t;
+
+/// \brief Need to reduce register usage to increase occupancy
+struct Compress128OnlineDecodeParams {
+ /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]` \n */
+ void* __restrict__ kv_score_buffer;
+ /** \brief Shape: `[batch_size, head_dim * 2]` */
+ const void* __restrict__ kv_score_input;
+ /** \brief Shape: `[batch_size, head_dim]` */
+ void* __restrict__ kv_compressed_output;
+ /** \brief Shape: `[128, head_dim]` (called `ape`) */
+ const void* __restrict__ score_bias;
+ /** \brief Shape: `[batch_size, ]`*/
+ const IndiceT* __restrict__ indices;
+ /** \brief Shape: `[batch_size, ]` */
+ const IndiceT* __restrict__ seq_lens;
+ /** \NOTE: `batch_size` <= `num_indices` */
+ uint32_t batch_size;
+};
+
+/// \brief Need to reduce register usage to increase occupancy
+struct Compress128OnlinePrefillParams {
+ /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]` \n */
+ void* __restrict__ kv_score_buffer;
+ /** \brief Shape: `[num_q_tokens, head_dim * 2]` */
+ const void* __restrict__ kv_score_input;
+ /** \brief Shape: `[num_q_tokens, head_dim]` */
+ void* __restrict__ kv_compressed_output;
+ /** \brief Shape: `[128, head_dim]` (called `ape`) */
+ const void* __restrict__ score_bias;
+ /** \brief Shape: `[batch_size, ]`*/
+ const IndiceT* __restrict__ indices;
+ /** \brief Shape: `[batch_size, ]`*/
+ const IndiceT* __restrict__ load_indices;
+ /// \brief Plan for segments that close a chunk (write to `kv_compressed_output`).
+ /// Shape: `[num_compress, 16]` (uint8).
+ const OnlinePlan* __restrict__ compress_plan;
+ /// \brief Plan for the trailing partial segment of each batch (write back to
+ /// `kv_score_buffer`). Shape: `[num_write, 16]` (uint8).
+ const OnlinePlan* __restrict__ write_plan;
+ uint32_t num_compress;
+ uint32_t num_write;
+};
+
+// 4 elements per thread, kHeadDim / 4 threads per block
+template
+__global__ void flash_c128_online_decode(const __grid_constant__ Compress128OnlineDecodeParams params) {
+ using namespace device;
+ constexpr uint32_t kVecSize = 4;
+ constexpr uint32_t kBlockSize = kHeadDim / kVecSize;
+ using Vec = AlignedVector;
+ const auto gmem = tile::Memory::cta(kBlockSize);
+ const auto batch_id = blockIdx.x;
+ const auto index = params.indices[batch_id];
+ const auto seq_len = params.seq_lens[batch_id];
+
+ const auto kv_score_buffer = static_cast(params.kv_score_buffer);
+ const auto kv_buf = kv_score_buffer + index * (kHeadDim * 3);
+ const auto kv_score_input = static_cast(params.kv_score_input);
+ const auto kv_src = kv_score_input + batch_id * (kHeadDim * 2);
+
+ /// NOTE: kv_score_buffer layout is [max, sum, kv] (slot 0 / 1 / 2). Reads,
+ /// writes, and the prefill kernel must all agree on this order.
+ const auto max_score_vec = gmem.load(kv_buf, 0);
+ const auto sum_score_vec = gmem.load(kv_buf, 1);
+ const auto old_kv_vec = gmem.load(kv_buf, 2);
+
+ /// NOTE: kv_score_input layout is | kv | score | (head_dim each), matching
+ /// the offline c128 kernel and the online prefill kernel.
+ const auto new_kv_vec = gmem.load(kv_src, 0);
+ const auto new_score_raw_vec = gmem.load(kv_src, 1);
+
+ /// NOTE: the new token sits at global position `seq_len - 1`, so its
+ /// position inside the 128-chunk is `(seq_len - 1) % 128`. The previous
+ /// `seq_len % 128` was off by one (`bias[127]` vs `bias[0]`, etc.).
+ const auto pos_in_chunk = (seq_len - 1) % 128;
+ const auto bias_vec = gmem.load(params.score_bias, pos_in_chunk);
+
+ Vec out_kv_vec;
+ Vec out_max_vec;
+ Vec out_sum_vec;
+ if (pos_in_chunk != 0) {
+ // Mid-chunk: combine prior partial state with the new token via online softmax.
+#pragma unroll
+ for (uint32_t i = 0; i < 4; ++i) {
+ const auto old_max = max_score_vec[i];
+ const auto old_kv = old_kv_vec[i];
+ const auto new_score = new_score_raw_vec[i] + bias_vec[i];
+ const auto new_kv = new_kv_vec[i];
+ const auto new_max = fmax(old_max, new_score);
+ const auto old_sum = sum_score_vec[i] * expf(old_max - new_max);
+ const auto new_exp = expf(new_score - new_max);
+ const auto new_sum = old_sum + new_exp;
+ out_kv_vec[i] = (old_kv * old_sum + new_kv * new_exp) / new_sum;
+ out_max_vec[i] = new_max;
+ out_sum_vec[i] = new_sum;
+ }
+ } else {
+ // First token of a new 128-chunk: initialize state with this token alone.
+#pragma unroll
+ for (uint32_t i = 0; i < 4; ++i) {
+ out_kv_vec[i] = new_kv_vec[i];
+ out_max_vec[i] = new_score_raw_vec[i] + bias_vec[i];
+ out_sum_vec[i] = 1.0f; // exp(score - max) with max == score
+ }
+ }
+
+ if (pos_in_chunk == 127) {
+ // Chunk just closed: emit the compressed kv. No need to update the buffer
+ // -- the next chunk's first token will overwrite it.
+ const auto kv_out = static_cast(params.kv_compressed_output) + batch_id * kHeadDim;
+ gmem.store(kv_out, out_kv_vec);
+ } else {
+ // Otherwise persist the running [max, sum, kv] state for the next step.
+ gmem.store(kv_buf, out_max_vec, 0);
+ gmem.store(kv_buf, out_sum_vec, 1);
+ gmem.store(kv_buf, out_kv_vec, 2);
+ }
+}
+
+constexpr int32_t kTileElements = 2; // split (along head-dim)
+/// \brief Each warp will handle this many elements (split along softmax-128)
+constexpr int32_t kElementsPerWarp = 8;
+constexpr uint32_t kNumWarps = 128 / kElementsPerWarp;
+constexpr uint32_t kPrefillBlockSize = device::kWarpThreads * kNumWarps;
+using PrefillStorage = device::AlignedVector;
+
+struct Compress128SharedBuffer {
+ using Storage = device::AlignedVector;
+ Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict
+ SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) {
+ return data[warp_id][lane_id];
+ }
+ SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) {
+ return data[warp_id][lane_id][tile_id];
+ }
+};
+
+template
+SGL_DEVICE void c128_prefill_forward(
+ const PrefillStorage (&kv)[kElementsPerWarp],
+ const PrefillStorage (&score)[kElementsPerWarp],
+ float* kv_out,
+ float* max_out,
+ float* sum_out,
+ const uint32_t warp_id,
+ const uint32_t lane_id) {
+ using namespace device;
+
+ /// NOTE: part 2: safe online softmax + weighted sum
+ using TmpStorage = typename Compress128SharedBuffer::Storage;
+ __shared__ Compress128SharedBuffer s_local_val_max;
+ __shared__ Compress128SharedBuffer s_local_exp_sum;
+ __shared__ Compress128SharedBuffer s_local_product;
+
+ TmpStorage tmp_val_max;
+ TmpStorage tmp_exp_sum;
+ TmpStorage tmp_product;
+
+#pragma unroll
+ for (int32_t i = 0; i < kTileElements; ++i) {
+ float score_fp32[kElementsPerWarp];
+
+#pragma unroll
+ for (int32_t j = 0; j < kElementsPerWarp; ++j) {
+ score_fp32[j] = score[j][i];
+ }
+
+ float max_value = score_fp32[0];
+ float sum_exp_value = 0.0f;
+
+#pragma unroll
+ for (int32_t j = 1; j < kElementsPerWarp; ++j) {
+ const auto fp32_score = score_fp32[j];
+ max_value = fmaxf(max_value, fp32_score);
+ }
+
+ float sum_product = 0.0f;
+#pragma unroll
+ for (int32_t j = 0; j < 8; ++j) {
+ const auto fp32_score = score_fp32[j];
+ const auto exp_score = expf(fp32_score - max_value);
+ sum_product += cast(kv[j][i]) * exp_score;
+ sum_exp_value += exp_score;
+ }
+
+ tmp_val_max[i] = max_value;
+ tmp_exp_sum[i] = sum_exp_value;
+ tmp_product[i] = sum_product;
+ }
+
+ // naturally aligned, so no bank conflict
+ s_local_val_max(warp_id, lane_id) = tmp_val_max;
+ s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum;
+ s_local_product(warp_id, lane_id) = tmp_product;
+
+ __syncthreads();
+
+ /// NOTE: part 3: online softmax
+ /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce
+ /// each reduce will consume `kNumWarps` threads (use partial warp reduction)
+ constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps;
+ constexpr uint32_t kIteration = kReductionCount / kPrefillBlockSize;
+
+#pragma unroll
+ for (uint32_t i = 0; i < kIteration; ++i) {
+ /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)`
+ const uint32_t j = i * kPrefillBlockSize + warp_id * kWarpThreads + lane_id;
+ /// NOTE: Range `[0, kNumWarps)`
+ const uint32_t local_warp_id = j % kNumWarps;
+ /// NOTE: Range `[0, kTileElements * kWarpThreads)`
+ const uint32_t local_elem_id = j / kNumWarps;
+ /// NOTE: Range `[0, kTileElements)`
+ const uint32_t local_tile_id = local_elem_id % kTileElements;
+ /// NOTE: Range `[0, kWarpThreads)`
+ const uint32_t local_lane_id = local_elem_id / kTileElements;
+ /// NOTE: each warp will access the whole tile (all `kTileElements`)
+ /// and for different lanes, the memory access only differ in `local_warp_id`
+ /// so there's no bank conflict in shared memory access.
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs");
+ const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id);
+ const auto global_val_max = warp::reduce_max(local_val_max);
+ const auto rescale = expf(local_val_max - global_val_max);
+ const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale);
+ const auto final_scale = rescale / global_exp_sum;
+ const auto global_product = warp::reduce_sum(local_product * final_scale);
+ kv_out[local_elem_id] = global_product;
+ if constexpr (kNeedData) {
+ max_out[local_elem_id] = global_val_max;
+ sum_out[local_elem_id] = global_exp_sum;
+ }
+ }
+ if constexpr (kNeedData) __syncthreads();
+}
+
+/// \brief Sentinel score for padded positions in a 128-segment.
+/// Must be finite so that `score - max` never produces NaN even when an
+/// entire warp has only padded positions.
+constexpr float kPadScore = -FLT_MAX;
+
+/// \brief Online compress 128 prefill. Two passes share this body:
+/// - `kWrite=false` (compress pass): handles segments that close a chunk.
+/// May load prior partial state from the buffer, but never writes to it,
+/// so concurrent blocks can read the same slot without racing.
+/// - `kWrite=true` (write pass): handles the trailing partial segment of each
+/// batch. Each batch contributes at most one such plan, so concurrent blocks
+/// touch disjoint buffer slots.
+///
+/// The two passes MUST run as separate kernel launches (in stream order) so
+/// that all reads in pass 1 finish before any writes in pass 2 start.
+template
+__global__ __launch_bounds__(kPrefillBlockSize, 2) //
+ void flash_c128_online_prefill(const __grid_constant__ Compress128OnlinePrefillParams params) {
+ using namespace device;
+
+ constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64
+ constexpr uint32_t kNumSplit = kHeadDim / kTileDim;
+ static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim");
+
+ /// NOTE: the compiler folds the if-else at compile time.
+ const auto num_plans = kWrite ? params.num_write : params.num_compress;
+ const auto plan_ptr = kWrite ? params.write_plan : params.compress_plan;
+ const uint32_t global_id = blockIdx.x;
+ const uint32_t global_pid = global_id / kNumSplit; // plan id
+ const uint32_t global_sid = global_id % kNumSplit; // split id
+ if (global_pid >= num_plans) return;
+ const auto [ragged_id, batch_id, position, window_len] = plan_ptr[global_pid];
+ if (ragged_id == 0xFFFFFFFFu) [[unlikely]]
+ return;
+
+ const uint32_t warp_id = threadIdx.x / kWarpThreads;
+ const uint32_t lane_id = threadIdx.x % kWarpThreads;
+ const int32_t split_offset = global_sid * kTileDim; // int32 is enough
+
+ const auto kv_score_buffer = static_cast(params.kv_score_buffer);
+ const auto kv_score_input = static_cast(params.kv_score_input);
+ const auto kv_compressed_output = static_cast(params.kv_compressed_output);
+ const auto score_bias_base = static_cast(params.score_bias);
+
+ constexpr int64_t kElementSize = kHeadDim * 2; // | kv | score |
+ const uint32_t chunk_offset = (position % 128u) + 1u - window_len;
+ const uint32_t window_end = chunk_offset + window_len; // exclusive, in [1, 128]
+ const int32_t segment_start = ragged_id - (position % 128u); // can be negative, but safe
+ const int32_t load_index = chunk_offset != 0 ? params.load_indices[batch_id] : -1;
+ const int32_t store_index = kWrite ? params.indices[batch_id] : -1;
+
+ PDLWaitPrimary();
+
+ // 2 * 8 = 16 register per elem. in theory we should consume 48 register here
+ PrefillStorage kv[kElementsPerWarp];
+ PrefillStorage score[kElementsPerWarp];
+ PrefillStorage bias[kElementsPerWarp];
+ const auto warp_offset = warp_id * kElementsPerWarp;
+
+#pragma unroll
+ for (uint32_t i = 0; i < kElementsPerWarp; ++i) {
+ const uint32_t j = i + warp_offset;
+ if (j >= chunk_offset && j < window_end) {
+ const auto kv_src_ptr = kv_score_input + (segment_start + j) * kElementSize + split_offset;
+ const auto score_src_ptr = kv_src_ptr + kHeadDim;
+ const auto bias_src_ptr = score_bias_base + j * kHeadDim + split_offset;
+ kv[i].load(kv_src_ptr, lane_id);
+ score[i].load(score_src_ptr, lane_id);
+ bias[i].load(bias_src_ptr, lane_id);
+ }
+ }
+
+#pragma unroll
+ for (uint32_t i = 0; i < kElementsPerWarp; ++i) {
+ const uint32_t j = i + warp_offset;
+ const bool is_valid = (j >= chunk_offset && j < window_end);
+#pragma unroll
+ for (uint32_t ii = 0; ii < kTileElements; ++ii) {
+ score[i][ii] = is_valid ? score[i][ii] + bias[i][ii] : kPadScore;
+ /// NOTE: must zero out kv on padded slots -- `c128_prefill_forward`
+ /// computes `kv * exp_score` where `exp_score = expf(-FLT_MAX - max) ??? 0`,
+ /// and IEEE-754 makes `NaN * 0 = NaN` / `+-inf * 0 = NaN`. An
+ /// uninitialized register can hold a NaN/inf bit pattern, so without
+ /// this reset a single padded warp can poison the whole softmax.
+ kv[i][ii] = is_valid ? kv[i][ii] : 0.0f;
+ }
+ }
+
+ __shared__ alignas(16) float seg_kv[kTileDim];
+ __shared__ alignas(16) float seg_max[kTileDim];
+ __shared__ alignas(16) float seg_sum[kTileDim];
+
+ c128_prefill_forward(kv, score, seg_kv, seg_max, seg_sum, warp_id, lane_id);
+
+ PDLTriggerSecondary();
+
+ if (warp_id == 0) {
+ PrefillStorage out_kv_vec, out_max_vec, out_sum_vec;
+ out_kv_vec.load(seg_kv, lane_id);
+ out_max_vec.load(seg_max, lane_id);
+ out_sum_vec.load(seg_sum, lane_id);
+ if (chunk_offset != 0) {
+ /// NOTE: load (max, sum, kv) of the in-progress chunk for this index.
+ /// `load_indices` may differ from `indices` when the prior partial state
+ /// lives on a different slot than the slot we ultimately write to.
+ const auto buf_load = kv_score_buffer + load_index * (kHeadDim * 3) + split_offset;
+ PrefillStorage buf_max_vec, buf_sum_vec, buf_kv_vec;
+ buf_max_vec.load(buf_load + 0 * kHeadDim, lane_id);
+ buf_sum_vec.load(buf_load + 1 * kHeadDim, lane_id);
+ buf_kv_vec.load(buf_load + 2 * kHeadDim, lane_id);
+#pragma unroll
+ for (uint32_t ii = 0; ii < kTileElements; ++ii) {
+ const float m1 = buf_max_vec[ii];
+ const float s1 = buf_sum_vec[ii];
+ const float k1 = buf_kv_vec[ii];
+ const float m2 = out_max_vec[ii];
+ const float s2 = out_sum_vec[ii];
+ const float k2 = out_kv_vec[ii];
+ const float new_max = fmaxf(m1, m2);
+ const float new_s1 = s1 * expf(m1 - new_max);
+ const float new_s2 = s2 * expf(m2 - new_max);
+ const float new_sum = new_s1 + new_s2;
+ const float new_kv = (k1 * new_s1 + k2 * new_s2) / new_sum;
+ out_max_vec[ii] = new_max;
+ out_sum_vec[ii] = new_sum;
+ out_kv_vec[ii] = new_kv;
+ }
+ }
+
+ if constexpr (kWrite) {
+ const auto buf_store = kv_score_buffer + store_index * (kHeadDim * 3) + split_offset;
+ reinterpret_cast(buf_store + 0 * kHeadDim)[lane_id] = out_max_vec;
+ reinterpret_cast(buf_store + 1 * kHeadDim)[lane_id] = out_sum_vec;
+ reinterpret_cast(buf_store + 2 * kHeadDim)[lane_id] = out_kv_vec;
+ } else {
+ const auto out_ptr = kv_compressed_output + ragged_id * kHeadDim + split_offset;
+ reinterpret_cast(out_ptr)[lane_id] = out_kv_vec;
+ }
+ }
+}
+
+template
+struct FlashCompress128OnlineKernel {
+ static constexpr auto decode_kernel = flash_c128_online_decode;
+ template
+ static constexpr auto prefill_kernel = flash_c128_online_prefill;
+ static constexpr auto prefill_c_kernel = prefill_kernel*kWrite=*/false>;
+ static constexpr auto prefill_w_kernel = prefill_kernel*kWrite=*/true>;
+ static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64
+ static constexpr uint32_t kNumSplit = kHeadDim / kTileDim;
+ static constexpr uint32_t kDecodeBlockSize = kHeadDim / 4;
+
+ static void run_decode(
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView indices,
+ const tvm::ffi::TensorView seq_lens,
+ const tvm::ffi::Optional /* UNUSED */) {
+ using namespace host;
+
+ auto B = SymbolicSize{"batch_size"};
+ auto device = SymbolicDevice{};
+ device.set_options();
+
+ TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv)
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_score_buffer);
+ TensorMatcher({B, kHeadDim * 2}) // kv score input
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_score_input);
+ TensorMatcher({B, kHeadDim}) // kv compressed output
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device)
+ .verify(ape);
+ TensorMatcher({B}).with_dtype().with_device(device).verify(indices);
+ TensorMatcher({B}).with_dtype().with_device(device).verify(seq_lens);
+
+ const auto batch_size = static_cast(B.unwrap());
+ const auto params = Compress128OnlineDecodeParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(),
+ .indices = static_cast(indices.data_ptr()),
+ .seq_lens = static_cast(seq_lens.data_ptr()),
+ .batch_size = batch_size,
+ };
+ LaunchKernel(batch_size, kDecodeBlockSize, device.unwrap()) //
+ .enable_pdl(kUsePDL)(decode_kernel, params);
+ }
+
+ static void run_prefill(
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView indices,
+ const tvm::ffi::TensorView compress_plan,
+ const tvm::ffi::TensorView write_plan,
+ const tvm::ffi::Optional extra) {
+ using namespace host;
+ using host::compress::kOnlinePrefillPlanDim;
+ using host::compress::OnlinePrefillPlanTensorDtype;
+
+ auto B = SymbolicSize{"batch_size"};
+ auto N = SymbolicSize{"num_q_tokens"};
+ auto X = SymbolicSize{"compress_tokens"};
+ auto Y = SymbolicSize{"write_tokens"};
+ auto device_ = SymbolicDevice{};
+ device_.set_options();
+
+ TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv) ??? 2D
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_buffer);
+ TensorMatcher({N, kHeadDim * 2}) // kv score input
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_input);
+ TensorMatcher({N, kHeadDim}) // kv compressed output
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device_)
+ .verify(ape);
+ TensorMatcher({B}) // indices
+ .with_dtype()
+ .with_device(device_)
+ .verify(indices);
+ TensorMatcher({X, kOnlinePrefillPlanDim}) // compress plan
+ .with_dtype()
+ .with_device(device_)
+ .verify(compress_plan);
+ TensorMatcher({Y, kOnlinePrefillPlanDim}) // write plan
+ .with_dtype()
+ .with_device(device_)
+ .verify(write_plan);
+
+ /// NOTE: `extra` is `load_indices`. When the previous partial state lives
+ /// on a slot different from the destination slot (e.g. paged buffers), the
+ /// caller must supply this; otherwise it defaults to `indices`.
+ const auto load_indices = extra.value_or(indices);
+ TensorMatcher({B}).with_dtype().with_device(device_).verify(load_indices);
+
+ const auto device = device_.unwrap();
+ const auto num_c = static_cast(X.unwrap());
+ const auto num_w = static_cast(Y.unwrap());
+ const auto params = Compress128OnlinePrefillParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(),
+ .indices = static_cast(indices.data_ptr()),
+ .load_indices = static_cast(load_indices.data_ptr()),
+ .compress_plan = static_cast(compress_plan.data_ptr()),
+ .write_plan = static_cast(write_plan.data_ptr()),
+ .num_compress = num_c,
+ .num_write = num_w,
+ };
+
+ /// NOTE: pass 1 reads the buffer (for the first segment of each batch
+ /// that started mid-chunk) and writes only to `kv_compressed_output`.
+ /// Pass 2 then writes the trailing partial state of each batch back to
+ /// the buffer. Stream serialization between the two launches enforces
+ /// read-before-write on shared buffer slots.
+ if (const auto num_c_blocks = num_c * kNumSplit) {
+ LaunchKernel(num_c_blocks, kPrefillBlockSize, device) //
+ .enable_pdl(kUsePDL)(prefill_c_kernel, params);
+ }
+ if (const auto num_w_blocks = num_w * kNumSplit) {
+ LaunchKernel(num_w_blocks, kPrefillBlockSize, device) //
+ .enable_pdl(kUsePDL)(prefill_w_kernel, params);
+ }
+ }
+};
+
+} // namespace
+
+namespace host::compress {
+
+using OnlinePlanResult = tvm::ffi::Tuple;
+
+struct OnlinePrefillCompressParams {
+ OnlinePrefillPlan* __restrict__ compress_plan;
+ OnlinePrefillPlan* __restrict__ write_plan;
+ const int64_t* __restrict__ seq_lens;
+ const int64_t* __restrict__ extend_lens;
+ uint32_t batch_size;
+ uint32_t num_tokens;
+};
+
+/// \brief Build the compress + write plans for online compress 128 prefill.
+///
+/// Each batch's `[prefix_len, prefix_len + extend_len)` range is split at
+/// 128-aligned boundaries. Every resulting segment falls into one of:
+/// - **compress**: closes a 128-chunk (`chunk_offset + window_len == 128`).
+/// These plans only read the buffer (when starting mid-chunk) and write the
+/// compressed kv to `kv_compressed_output`.
+/// - **write**: trailing partial of the batch (`chunk_offset + window_len < 128`).
+/// May read the buffer and always writes the new partial state back to it.
+/// Each batch produces at most one such plan.
+///
+/// The two plans MUST be dispatched as separate kernel launches in stream
+/// order so that pass-1 reads of a buffer slot complete before any pass-2
+/// write of the same slot.
+inline OnlinePlanResult plan_online_prefill_host(const OnlinePrefillCompressParams& params, const bool use_cuda_graph) {
+ const auto& [compress_plan, write_plan, seq_lens, extend_lens, batch_size, num_tokens] = params;
+
+ uint32_t counter = 0;
+ uint32_t compress_count = 0;
+ uint32_t write_count = 0;
+ for (const auto i : irange(batch_size)) {
+ const uint32_t seq_len = static_cast(seq_lens[i]);
+ const uint32_t extend_len = static_cast(extend_lens[i]);
+ RuntimeCheck(0 < extend_len && extend_len <= seq_len);
+ const uint32_t prefix_len = seq_len - extend_len;
+ const uint32_t end_pos = prefix_len + extend_len;
+ /// NOTE: split the extend range into per-128-chunk segments. Each segment
+ /// stays inside one chunk, so the kernel can decide load/store from
+ /// `chunk_offset` and `window_len` alone.
+ uint32_t pos = prefix_len;
+ while (pos < end_pos) {
+ const uint32_t chunk_start = (pos / 128u) * 128u;
+ const uint32_t seg_end = std::min(end_pos, chunk_start + 128u); // exclusive
+ const uint32_t seg_len = seg_end - pos;
+ const uint32_t chunk_off = pos - chunk_start;
+ /// NOTE: store last-token coordinates so that downstream consumers
+ /// (e.g. `fused_norm_rope`) can read `ragged_id` and `position` with the
+ /// same semantics as `PrefillPlan`. The segment start is recoverable as
+ /// `ragged_id - window_len + 1` and `position - window_len + 1`.
+ const uint32_t last_pos = seg_end - 1;
+ const uint32_t last_ragged = counter + (last_pos - prefix_len);
+ const auto plan = OnlinePrefillPlan{
+ .ragged_id = last_ragged,
+ .batch_id = i,
+ .position = last_pos,
+ .window_len = seg_len,
+ };
+ if (chunk_off + seg_len == 128u) {
+ // full chunk, must be complete, maybe read the buffer, no write
+ RuntimeCheck(compress_count < num_tokens);
+ compress_plan[compress_count++] = plan;
+ } else {
+ // last chunk, must be incomplete, maybe read the buffer, must write
+ RuntimeCheck(write_count < num_tokens);
+ write_plan[write_count++] = plan;
+ }
+ pos = seg_end;
+ }
+ counter += extend_len;
+ }
+ RuntimeCheck(counter == num_tokens, "input size ", counter, " != num_q_tokens ", num_tokens);
+ if (!use_cuda_graph) return OnlinePlanResult{compress_count, write_count};
+ /// NOTE: pad both plans with sentinel entries so cuda-graph runs always see
+ /// the same number of blocks. The kernel skips plans whose `ragged_id` is -1.
+ constexpr auto kInvalid = static_cast(-1);
+ constexpr auto kInvalidPlan = OnlinePrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid};
+ for (const auto i : irange(compress_count, num_tokens)) {
+ compress_plan[i] = kInvalidPlan;
+ }
+ for (const auto i : irange(write_count, num_tokens)) {
+ write_plan[i] = kInvalidPlan;
+ }
+ return OnlinePlanResult{num_tokens, num_tokens};
+}
+
+inline OnlinePlanResult plan_online_prefill(
+ const tvm::ffi::TensorView extend_lens,
+ const tvm::ffi::TensorView seq_lens,
+ const tvm::ffi::TensorView compress_plan,
+ const tvm::ffi::TensorView write_plan,
+ const bool use_cuda_graph) {
+ auto N = SymbolicSize{"batch_size"};
+ auto M = SymbolicSize{"num_tokens"};
+ auto device = SymbolicDevice{};
+ /// NOTE: only host (CPU/cuda-host) planning is implemented for now. The
+ device.set_options();
+ TensorMatcher({N}) //
+ .with_dtype()
+ .with_device(device)
+ .verify(extend_lens)
+ .verify(seq_lens);
+ TensorMatcher({M, kOnlinePrefillPlanDim}) //
+ .with_dtype()
+ .with_device(device)
+ .verify(compress_plan)
+ .verify(write_plan);
+ const auto params = OnlinePrefillCompressParams{
+ .compress_plan = static_cast(compress_plan.data_ptr()),
+ .write_plan = static_cast(write_plan.data_ptr()),
+ .seq_lens = static_cast(seq_lens.data_ptr()),
+ .extend_lens = static_cast(extend_lens.data_ptr()),
+ .batch_size = static_cast(N.unwrap()),
+ .num_tokens = static_cast(M.unwrap()),
+ };
+ return plan_online_prefill_host(params, use_cuda_graph);
+}
+
+} // namespace host::compress
+
+namespace {
+
+[[maybe_unused]]
+constexpr auto& plan_compress_online_prefill = host::compress::plan_online_prefill;
+
+} // namespace
diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh
new file mode 100644
index 000000000..71e600dc3
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh
@@ -0,0 +1,875 @@
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace {
+
+using PlanD = device::compress::DecodePlan;
+using PlanC = device::compress::CompressPlan;
+
+// ---------------------------------------------------------------------------
+// Decode kernel: 1 token / batch. Each block handles one batch.
+// 4 elements per thread -> kBlockSize = head_dim / 4.
+// ---------------------------------------------------------------------------
+
+struct Compress128OnlineDecodeParams {
+ void* __restrict__ kv_score_buffer; // [num_slots, 1, head_dim * 3]
+ const void* __restrict__ kv_score_input; // [batch_size, head_dim * 2]
+ void* __restrict__ kv_compressed_output; // [batch_size, head_dim]
+ const void* __restrict__ score_bias; // [128, head_dim]
+ const PlanD* __restrict__ plan_d;
+ uint32_t batch_size;
+};
+
+template
+__global__ void flash_c128_online_decode_v2(const __grid_constant__ Compress128OnlineDecodeParams params) {
+ using namespace device;
+ constexpr uint32_t kVecSize = 4;
+ constexpr uint32_t kBlockSize = kHeadDim / kVecSize;
+ using Vec = AlignedVector;
+ const auto gmem = tile::Memory::cta(kBlockSize);
+ const auto batch_id = blockIdx.x;
+ if (batch_id >= params.batch_size) return;
+
+ // Wait for the plan-finalize kernel to publish `plan.read_page_0 / write_loc`
+ // before reading the plan. The plan kernel runs on the same stream and does
+ // NOT issue a PDL trigger, so launching this kernel with PDL means our
+ // pre-wait global reads can race with the plan kernel's writes.
+ PDLWaitPrimary();
+
+ const auto plan = params.plan_d[batch_id];
+ const auto pos_in_chunk = (plan.seq_len - 1) % 128;
+
+ const auto kv_score_buffer = static_cast(params.kv_score_buffer);
+ const auto kv_score_input = static_cast(params.kv_score_input);
+ const auto kv_load_buf = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3);
+ const auto kv_store_buf = kv_score_buffer + plan.write_loc * (kHeadDim * 3);
+ const auto kv_src = kv_score_input + batch_id * (kHeadDim * 2);
+
+ // Buffer layout: [max | sum | kv] (slot 0 / 1 / 2 of the head_dim*3 row).
+ const auto new_kv_vec = gmem.load(kv_src, 0);
+ const auto new_score_raw_vec = gmem.load(kv_src, 1);
+ const auto bias_vec = gmem.load(params.score_bias, pos_in_chunk);
+
+ Vec out_kv_vec;
+ Vec out_max_vec;
+ Vec out_sum_vec;
+ if (pos_in_chunk != 0) {
+ // Mid-chunk: combine prior partial state with the new token.
+ const auto max_score_vec = gmem.load(kv_load_buf, 0);
+ const auto sum_score_vec = gmem.load(kv_load_buf, 1);
+ const auto old_kv_vec = gmem.load(kv_load_buf, 2);
+#pragma unroll
+ for (uint32_t i = 0; i < kVecSize; ++i) {
+ const auto old_max = max_score_vec[i];
+ const auto old_kv = old_kv_vec[i];
+ const auto new_score = new_score_raw_vec[i] + bias_vec[i];
+ const auto new_kv = new_kv_vec[i];
+ const auto new_max = fmaxf(old_max, new_score);
+ const auto old_sum = sum_score_vec[i] * expf(old_max - new_max);
+ const auto new_exp = expf(new_score - new_max);
+ const auto new_sum = old_sum + new_exp;
+ out_kv_vec[i] = (old_kv * old_sum + new_kv * new_exp) / new_sum;
+ out_max_vec[i] = new_max;
+ out_sum_vec[i] = new_sum;
+ }
+ } else {
+ // First token of a new chunk: state == this token alone.
+#pragma unroll
+ for (uint32_t i = 0; i < kVecSize; ++i) {
+ out_kv_vec[i] = new_kv_vec[i];
+ out_max_vec[i] = new_score_raw_vec[i] + bias_vec[i];
+ out_sum_vec[i] = 1.0f;
+ }
+ }
+
+ if (pos_in_chunk == 127) {
+ // Chunk just closed: emit compressed kv, no buffer update.
+ const auto kv_out = static_cast(params.kv_compressed_output) + batch_id * kHeadDim;
+ gmem.store(kv_out, out_kv_vec);
+ } else {
+ gmem.store(kv_store_buf, out_max_vec, 0);
+ gmem.store(kv_store_buf, out_sum_vec, 1);
+ gmem.store(kv_store_buf, out_kv_vec, 2);
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Prefill kernel: 1 segment / block. Two passes (compress + write) share the
+// kernel template, parameterized by `kWrite`.
+// 16 warps per block; each warp handles 8 of the 128 chunk positions.
+// ---------------------------------------------------------------------------
+
+constexpr int32_t kTileElements = 2; // split along head-dim
+constexpr int32_t kElementsPerWarp = 8; // split along the 128-chunk
+constexpr uint32_t kNumWarps = 128 / kElementsPerWarp;
+constexpr uint32_t kPrefillBlockSize = device::kWarpThreads * kNumWarps;
+using PrefillStorage = device::AlignedVector;
+
+struct Compress128OnlinePrefillParams {
+ void* __restrict__ kv_score_buffer; // [num_slots, 1, head_dim * 3]
+ const void* __restrict__ kv_score_input; // [num_q_tokens, head_dim * 2]
+ void* __restrict__ kv_compressed_output; // [num_compress, head_dim]
+ const void* __restrict__ score_bias; // [128, head_dim]
+ const PlanC* __restrict__ plan_c; // close-chunk segments
+ const PlanC* __restrict__ plan_w; // trailing partial segments
+ uint32_t num_compress;
+ uint32_t num_write;
+};
+
+struct Compress128SharedBuffer {
+ using Storage = device::AlignedVector;
+ Storage data[kNumWarps][device::kWarpThreads + 1]; // +1 to avoid bank conflict
+ SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) {
+ return data[warp_id][lane_id];
+ }
+ SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) {
+ return data[warp_id][lane_id][tile_id];
+ }
+};
+
+/// \brief Sentinel score for padded positions in a 128-segment.
+constexpr float kPadScore = -FLT_MAX;
+
+[[maybe_unused]]
+SGL_DEVICE void c128_prefill_segment_softmax(
+ const PrefillStorage (&kv)[kElementsPerWarp],
+ const PrefillStorage (&score)[kElementsPerWarp],
+ float* seg_kv,
+ float* seg_max,
+ float* seg_sum,
+ const uint32_t warp_id,
+ const uint32_t lane_id) {
+ using namespace device;
+
+ // Per-warp running state (max, sum, kv) for kTileElements head-dim slots.
+ using TmpStorage = typename Compress128SharedBuffer::Storage;
+ __shared__ Compress128SharedBuffer s_local_val_max;
+ __shared__ Compress128SharedBuffer s_local_exp_sum;
+ __shared__ Compress128SharedBuffer s_local_product;
+
+ TmpStorage tmp_val_max;
+ TmpStorage tmp_exp_sum;
+ TmpStorage tmp_product;
+
+#pragma unroll
+ for (int32_t i = 0; i < kTileElements; ++i) {
+ float score_fp32[kElementsPerWarp];
+#pragma unroll
+ for (int32_t j = 0; j < kElementsPerWarp; ++j) {
+ score_fp32[j] = score[j][i];
+ }
+ float max_value = score_fp32[0];
+#pragma unroll
+ for (int32_t j = 1; j < kElementsPerWarp; ++j) {
+ max_value = fmaxf(max_value, score_fp32[j]);
+ }
+ float sum_exp_value = 0.0f;
+ float sum_product = 0.0f;
+#pragma unroll
+ for (int32_t j = 0; j < kElementsPerWarp; ++j) {
+ const auto exp_score = expf(score_fp32[j] - max_value);
+ sum_product += kv[j][i] * exp_score;
+ sum_exp_value += exp_score;
+ }
+ tmp_val_max[i] = max_value;
+ tmp_exp_sum[i] = sum_exp_value;
+ tmp_product[i] = sum_product;
+ }
+
+ // Aligned writes (no bank conflict thanks to `+1` padding).
+ s_local_val_max(warp_id, lane_id) = tmp_val_max;
+ s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum;
+ s_local_product(warp_id, lane_id) = tmp_product;
+
+ __syncthreads();
+
+ // Cross-warp reduction. Same recipe as c128_online.cuh: each block-thread
+ // pair reduces a (tile_id, lane_id) slot using a kNumWarps-wide warp shuffle.
+ constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps;
+ constexpr uint32_t kIteration = kReductionCount / kPrefillBlockSize;
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs");
+
+#pragma unroll
+ for (uint32_t i = 0; i < kIteration; ++i) {
+ const uint32_t j = i * kPrefillBlockSize + warp_id * kWarpThreads + lane_id;
+ const uint32_t local_warp_id = j % kNumWarps;
+ const uint32_t local_elem_id = j / kNumWarps;
+ const uint32_t local_tile_id = local_elem_id % kTileElements;
+ const uint32_t local_lane_id = local_elem_id / kTileElements;
+ const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id);
+ const auto global_val_max = warp::reduce_max(local_val_max);
+ const auto rescale = expf(local_val_max - global_val_max);
+ const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale);
+ const auto final_scale = rescale / global_exp_sum;
+ const auto global_product = warp::reduce_sum(local_product * final_scale);
+ seg_kv[local_elem_id] = global_product;
+ seg_max[local_elem_id] = global_val_max;
+ seg_sum[local_elem_id] = global_exp_sum;
+ }
+ __syncthreads();
+}
+
+/// \brief Online compress 128 prefill v2.
+///
+/// `kWrite=false` (compress pass): handles segments that close a 128-chunk.
+/// Reads optional prior state from `read_page_0` (-1 = none), emits compressed
+/// kv to `kv_compressed_output[plan_id]` (compact).
+/// `kWrite=true` (write pass) : handles trailing partial segments.
+/// Reads optional prior state from `read_page_0` (-1 = none), writes new
+/// running state to `read_page_1`.
+template
+__global__ __launch_bounds__(kPrefillBlockSize, 2) //
+ void flash_c128_online_prefill_v2(const __grid_constant__ Compress128OnlinePrefillParams params) {
+ using namespace device;
+
+ constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64
+ constexpr uint32_t kNumSplit = kHeadDim / kTileDim;
+ static_assert(kHeadDim % kTileDim == 0);
+
+ // Compile-time fold to the right plan list.
+ const auto num_plans = kWrite ? params.num_write : params.num_compress;
+ const auto plan_ptr = kWrite ? params.plan_w : params.plan_c;
+ const uint32_t global_id = blockIdx.x;
+ const uint32_t global_pid = global_id / kNumSplit;
+ const uint32_t global_sid = global_id % kNumSplit;
+ if (global_pid >= num_plans) return;
+
+ const uint32_t warp_id = threadIdx.x / kWarpThreads;
+ const uint32_t lane_id = threadIdx.x % kWarpThreads;
+ const int32_t split_offset = global_sid * kTileDim;
+
+ // The previous kernel (plan-finalize stage 1) does NOT issue a PDL trigger,
+ // so PDLWaitPrimary effectively waits for stage 1 to complete. Read the plan
+ // AFTER the wait so the freshly-written `read_page_0` (= state-pool slot) is
+ // visible. Reading it before the wait is a real race -- with PDL enabled the
+ // kernel can begin executing before stage 1's stores propagate, and we'd see
+ // the stage-0 batch_id placeholder in `read_page_0` instead of the slot.
+ PDLWaitPrimary();
+
+ const auto plan = plan_ptr[global_pid];
+ if (plan.is_invalid()) [[unlikely]]
+ return;
+
+ const auto kv_score_buffer = static_cast(params.kv_score_buffer);
+ const auto kv_score_input = static_cast(params.kv_score_input);
+ const auto kv_compressed_output = static_cast(params.kv_compressed_output);
+ const auto score_bias_base = static_cast(params.score_bias);
+
+ constexpr int64_t kElementSize = kHeadDim * 2; // | kv | score |
+
+ // The plan stores last-token coordinates; segment start is recoverable as
+ // ragged_id - window_len + 1.
+ const uint32_t window_len = plan.buffer_len;
+ const uint32_t position = plan.seq_len - 1;
+ const uint32_t pos_in_chunk_end = (position % 128u) + 1u; // exclusive, in [1, 128]
+ const uint32_t chunk_offset = pos_in_chunk_end - window_len; // in [0, 127]
+ const int32_t segment_start_ragged = static_cast(plan.ragged_id) - static_cast(position % 128u);
+
+ // --- Stage 1: load kv / score / bias for this warp's 8 chunk positions.
+ PrefillStorage kv[kElementsPerWarp];
+ PrefillStorage score[kElementsPerWarp];
+ PrefillStorage bias[kElementsPerWarp];
+ const uint32_t warp_offset = warp_id * kElementsPerWarp;
+
+#pragma unroll
+ for (uint32_t i = 0; i < kElementsPerWarp; ++i) {
+ const uint32_t j = i + warp_offset;
+ if (j >= chunk_offset && j < pos_in_chunk_end) {
+ const auto kv_src_ptr = kv_score_input + (segment_start_ragged + j) * kElementSize + split_offset;
+ const auto score_src_ptr = kv_src_ptr + kHeadDim;
+ const auto bias_src_ptr = score_bias_base + j * kHeadDim + split_offset;
+ kv[i].load(kv_src_ptr, lane_id);
+ score[i].load(score_src_ptr, lane_id);
+ bias[i].load(bias_src_ptr, lane_id);
+ }
+ }
+
+ // --- Stage 2: pad invalid positions. score = -FLT_MAX, kv = 0 (so that
+ // kv * exp(score-max) ??? 0 / 0 cleanly without producing NaN/inf).
+#pragma unroll
+ for (uint32_t i = 0; i < kElementsPerWarp; ++i) {
+ const uint32_t j = i + warp_offset;
+ const bool is_valid = (j >= chunk_offset && j < pos_in_chunk_end);
+#pragma unroll
+ for (uint32_t ii = 0; ii < kTileElements; ++ii) {
+ score[i][ii] = is_valid ? score[i][ii] + bias[i][ii] : kPadScore;
+ kv[i][ii] = is_valid ? kv[i][ii] : 0.0f;
+ }
+ }
+
+ // --- Stage 3: warp-tile online softmax over the 128-position chunk.
+ __shared__ alignas(16) float seg_kv[kTileDim];
+ __shared__ alignas(16) float seg_max[kTileDim];
+ __shared__ alignas(16) float seg_sum[kTileDim];
+ c128_prefill_segment_softmax(kv, score, seg_kv, seg_max, seg_sum, warp_id, lane_id);
+
+ PDLTriggerSecondary();
+
+ // --- Stage 4: warp 0 folds with prior partial state (if any) and writes.
+ if (warp_id == 0) {
+ PrefillStorage out_kv_vec, out_max_vec, out_sum_vec;
+ out_kv_vec.load(seg_kv, lane_id);
+ out_max_vec.load(seg_max, lane_id);
+ out_sum_vec.load(seg_sum, lane_id);
+
+ if (chunk_offset != 0 && plan.read_page_0 >= 0) {
+ // Combine with prior partial state for this slot.
+ const auto buf_load = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3) + split_offset;
+ PrefillStorage buf_max_vec, buf_sum_vec, buf_kv_vec;
+ buf_max_vec.load(buf_load + 0 * kHeadDim, lane_id);
+ buf_sum_vec.load(buf_load + 1 * kHeadDim, lane_id);
+ buf_kv_vec.load(buf_load + 2 * kHeadDim, lane_id);
+#pragma unroll
+ for (uint32_t ii = 0; ii < kTileElements; ++ii) {
+ const float m1 = buf_max_vec[ii];
+ const float s1 = buf_sum_vec[ii];
+ const float k1 = buf_kv_vec[ii];
+ const float m2 = out_max_vec[ii];
+ const float s2 = out_sum_vec[ii];
+ const float k2 = out_kv_vec[ii];
+ const float new_max = fmaxf(m1, m2);
+ const float new_s1 = s1 * expf(m1 - new_max);
+ const float new_s2 = s2 * expf(m2 - new_max);
+ const float new_sum = new_s1 + new_s2;
+ const float new_kv = (k1 * new_s1 + k2 * new_s2) / new_sum;
+ out_max_vec[ii] = new_max;
+ out_sum_vec[ii] = new_sum;
+ out_kv_vec[ii] = new_kv;
+ }
+ }
+
+ if constexpr (kWrite) {
+ // For trailing-partial segments the load and store slots collapse to the
+ // segment's own chunk slot (the request keeps a single in-progress
+ // chunk's running state at any time), so we reuse `read_page_0`.
+ const auto buf_store = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3) + split_offset;
+ reinterpret_cast(buf_store + 0 * kHeadDim)[lane_id] = out_max_vec;
+ reinterpret_cast(buf_store + 1 * kHeadDim)[lane_id] = out_sum_vec;
+ reinterpret_cast(buf_store + 2 * kHeadDim)[lane_id] = out_kv_vec;
+ } else {
+ // Compact output: one row per compress plan, indexed by `global_pid`.
+ const auto out_ptr = kv_compressed_output + global_pid * kHeadDim + split_offset;
+ reinterpret_cast(out_ptr)[lane_id] = out_kv_vec;
+ }
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Host wrapper: matches the c128_v2 / c4_v2 host API style (run_decode /
+// run_prefill methods on a kernel-class template). We only expose `kHeadDim`
+// + `kUsePDL`; the dtype is fixed to fp32 for the online state pool.
+// ---------------------------------------------------------------------------
+
+template
+struct FlashCompress128OnlineKernel {
+ static constexpr auto decode_kernel = flash_c128_online_decode_v2;
+ template
+ static constexpr auto prefill_kernel = flash_c128_online_prefill_v2;
+ static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64
+ static constexpr uint32_t kNumSplit = kHeadDim / kTileDim;
+ static constexpr uint32_t kDecodeBlockSize = kHeadDim / 4;
+
+ static void run_decode(
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView plan_d_) {
+ using namespace host;
+
+ auto B = SymbolicSize{"batch_size"};
+ auto device_ = SymbolicDevice{};
+ device_.set_options();
+
+ TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv)
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_buffer);
+ TensorMatcher({B, kHeadDim * 2}) // kv score input
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_input);
+ TensorMatcher({B, kHeadDim}) // kv compressed output (sparse by batch_id)
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device_)
+ .verify(ape);
+
+ const auto plan_d = compress::verify_plan_d(plan_d_, B, device_);
+ const auto batch_size = static_cast(B.unwrap());
+ if (batch_size == 0) return;
+ const auto params = Compress128OnlineDecodeParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(),
+ .plan_d = plan_d,
+ .batch_size = batch_size,
+ };
+ LaunchKernel(batch_size, kDecodeBlockSize, device_.unwrap()) //
+ .enable_pdl(kUsePDL)(decode_kernel, params);
+ }
+
+ static void run_prefill(
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView plan_c_,
+ const tvm::ffi::TensorView plan_w_) {
+ using namespace host;
+
+ auto N = SymbolicSize{"num_q_tokens"};
+ auto C = SymbolicSize{"num_c_plans"};
+ auto W = SymbolicSize{"num_w_plans"};
+ auto device_ = SymbolicDevice{};
+ device_.set_options();
+
+ TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_buffer);
+ TensorMatcher({N, kHeadDim * 2}) // kv score input (ragged)
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_input);
+ TensorMatcher({C, kHeadDim}) // kv compressed output (compact, by plan_c index)
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device_)
+ .verify(ape);
+
+ // Both compress and write segments use PlanC layout. plan_c uses
+ // read_page_1=-1 (unused); plan_w uses read_page_1=store_slot.
+ const auto plan_c = compress::verify_plan_c(plan_c_, C, device_);
+ const auto plan_w = compress::verify_plan_c(plan_w_, W, device_);
+ const auto device = device_.unwrap();
+ const auto num_q_tokens = static_cast(N.unwrap());
+ const auto num_c = static_cast(C.unwrap());
+ const auto num_w = static_cast(W.unwrap());
+ RuntimeCheck(num_q_tokens >= num_w, "invalid prefill plan: num_q < num_w");
+ const auto params = Compress128OnlinePrefillParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(),
+ .plan_c = plan_c,
+ .plan_w = plan_w,
+ .num_compress = num_c,
+ .num_write = num_w,
+ };
+
+ // The two passes MUST be serialized in stream order: pass 1 reads slots
+ // that pass 2 may write to; running them in parallel would race.
+ if (const auto num_c_blocks = num_c * kNumSplit) {
+ LaunchKernel(num_c_blocks, kPrefillBlockSize, device) //
+ .enable_pdl(kUsePDL)(prefill_kernel*kWrite=*/false>, params);
+ }
+ if (const auto num_w_blocks = num_w * kNumSplit) {
+ LaunchKernel(num_w_blocks, kPrefillBlockSize, device) //
+ .enable_pdl(kUsePDL)(prefill_kernel*kWrite=*/true>, params);
+ }
+ }
+};
+
+} // namespace
+
+// ===========================================================================
+// Plan builders. Mirrors the offline v2 pattern (`c_plan.cuh`):
+// - Decode: a single GPU kernel reads seq_lens / req_to_token /
+// req_pool_indices on device and emits the final PlanD tensor in one go.
+// - Prefill: stage 0 (host, on CPU pinned memory) splits each batch's
+// extend range into per-chunk segments and emits PlanC entries with the
+// batch_id stashed in `read_page_0` as a placeholder. Stage 1 is a tiny
+// GPU kernel that finalizes `read_page_0` to `req_to_token[rid][chunk_start]`,
+// so the slot tensors never leave GPU memory. The online state pool keeps
+// a single in-progress chunk per request, so each segment's load and
+// store slot collapse to one value (the slot for the segment's own chunk),
+// and `read_page_1` is unused.
+// ===========================================================================
+
+namespace host::compress {
+
+using device::compress::CompressPlan;
+using device::compress::DecodePlan;
+
+// ---------------------------------------------------------------------------
+// Decode plan builder.
+// ---------------------------------------------------------------------------
+
+struct OnlineDecodePlanParams {
+ DecodePlan* __restrict__ plan_d;
+ const int64_t* __restrict__ seq_lens;
+ const int64_t* __restrict__ req_pool_indices;
+ const int32_t* __restrict__ req_to_token;
+ const int64_t* __restrict__ full_to_swa; // (full_cache_size,) int64
+ int64_t stride_r2t;
+ int32_t swa_page_size;
+ uint32_t batch_size;
+};
+
+__global__ void plan_c128_online_decode_kernel(const OnlineDecodePlanParams params) {
+ const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx >= params.batch_size) return;
+ const auto seq_len = static_cast(params.seq_lens[idx]);
+ const auto rid = params.req_pool_indices[idx];
+ const int32_t chunk_start = static_cast((seq_len - 1u) / 128u * 128u);
+ const int32_t full_loc = params.req_to_token[rid * params.stride_r2t + chunk_start];
+ const int32_t swa_loc = static_cast(params.full_to_swa[full_loc]);
+ const int32_t slot = swa_loc / params.swa_page_size;
+ params.plan_d[idx] = DecodePlan{
+ .seq_len = seq_len,
+ .write_loc = slot,
+ .read_page_0 = slot,
+ .read_page_1 = -1,
+ };
+}
+
+/// \brief Build the decode plan tensor. Caller (Python) pre-allocates
+/// `plan_d_dev` as a `(batch_size, 16)` device uint8 tensor; this routine
+/// only fills it. See `plan_online_prefill` for the rationale (avoid
+/// `ffi::empty` + dlpack roundtrip / PyTorch caching-allocator stream
+/// tracking issue that surfaces as IMA in unrelated downstream kernels).
+inline void plan_online_decode(
+ const tvm::ffi::TensorView seq_lens,
+ const tvm::ffi::TensorView req_pool_indices,
+ const tvm::ffi::TensorView req_to_token,
+ const tvm::ffi::TensorView full_to_swa,
+ const tvm::ffi::TensorView plan_d_dev_,
+ const int32_t swa_page_size) {
+ auto B = SymbolicSize{"batch_size"};
+ auto device_ = SymbolicDevice{};
+ device_.set_options();
+
+ auto seq_dtype = SymbolicDType{};
+ TensorMatcher({B}) //
+ .with_dtype(seq_dtype)
+ .with_device(device_)
+ .verify(seq_lens);
+ TensorMatcher({B}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(req_pool_indices);
+ TensorMatcher({-1, -1}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(req_to_token);
+ TensorMatcher({-1}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(full_to_swa);
+ TensorMatcher({B, sizeof(DecodePlan)}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(plan_d_dev_);
+ RuntimeCheck(swa_page_size > 0);
+
+ const auto batch_size = static_cast(B.unwrap());
+ if (batch_size == 0) return;
+
+ const auto device = device_.unwrap();
+ constexpr uint32_t kBlockSize = 256;
+ const uint32_t num_blocks = host::div_ceil(batch_size, kBlockSize);
+ const auto stride_r2t = req_to_token.stride(0);
+ const auto params = OnlineDecodePlanParams{
+ .plan_d = static_cast(plan_d_dev_.data_ptr()),
+ .seq_lens = static_cast(seq_lens.data_ptr()),
+ .req_pool_indices = static_cast(req_pool_indices.data_ptr()),
+ .req_to_token = static_cast(req_to_token.data_ptr()),
+ .full_to_swa = static_cast(full_to_swa.data_ptr()),
+ .stride_r2t = stride_r2t,
+ .swa_page_size = swa_page_size,
+ .batch_size = batch_size,
+ };
+ LaunchKernel(num_blocks, kBlockSize, device)(plan_c128_online_decode_kernel, params);
+}
+
+// ---------------------------------------------------------------------------
+// Prefill plan builder: host stage 0 + GPU stage 1.
+// ---------------------------------------------------------------------------
+
+struct OnlinePrefillStage0Params {
+ CompressPlan* __restrict__ plan_c;
+ CompressPlan* __restrict__ plan_w;
+ const int64_t* __restrict__ seq_lens;
+ const int64_t* __restrict__ extend_lens;
+ uint32_t batch_size;
+ uint32_t num_q_tokens;
+};
+
+inline std::tuple _plan_prefill_partial(const OnlinePrefillStage0Params& p) {
+ uint32_t counter = 0;
+ uint32_t compress_count = 0;
+ uint32_t write_count = 0;
+ for (const auto i : irange(p.batch_size)) {
+ const uint32_t seq_len = static_cast(p.seq_lens[i]);
+ const uint32_t extend_len = static_cast(p.extend_lens[i]);
+ RuntimeCheck(0 < extend_len && extend_len <= seq_len);
+ const uint32_t prefix_len = seq_len - extend_len;
+ const uint32_t end_pos = prefix_len + extend_len;
+
+ uint32_t pos = prefix_len;
+ while (pos < end_pos) {
+ const uint32_t chunk_start = (pos / 128u) * 128u;
+ const uint32_t seg_end = std::min(end_pos, chunk_start + 128u); // exclusive
+ const uint32_t seg_len = seg_end - pos;
+ const uint32_t chunk_off = pos - chunk_start;
+ const uint32_t last_pos = seg_end - 1;
+ const uint32_t last_ragged = counter + (last_pos - prefix_len);
+ RuntimeCheck(last_ragged < (1u << 16), "PlanC.ragged_id is uint16; ragged ", last_ragged, " overflows");
+ RuntimeCheck(seg_len <= 128u);
+ // Stash batch_id in `read_page_0` for stage 1 to translate. A
+ // chunk-aligned segment never loads, so we still need stage 1 to fill
+ // a slot in -- the kernel keys the load on `chunk_offset != 0`.
+ const auto plan = CompressPlan{
+ .seq_len = last_pos + 1u,
+ .ragged_id = static_cast(last_ragged),
+ .buffer_len = static_cast(seg_len),
+ .read_page_0 = static_cast(i), // batch_id placeholder
+ .read_page_1 = -1, // unused, kept so MSB layout is stable
+ };
+ if (chunk_off + seg_len == 128u) {
+ // close-chunk segment
+ RuntimeCheck(compress_count < p.num_q_tokens);
+ p.plan_c[compress_count++] = plan;
+ } else {
+ // trailing partial segment
+ RuntimeCheck(write_count < p.num_q_tokens);
+ p.plan_w[write_count++] = plan;
+ }
+ pos = seg_end;
+ }
+ counter += extend_len;
+ }
+ RuntimeCheck(counter == p.num_q_tokens, "input size ", counter, " != num_q_tokens ", p.num_q_tokens);
+ return std::tuple{compress_count, write_count};
+}
+
+struct OnlinePrefillStage1Params {
+ CompressPlan* __restrict__ plan_c;
+ CompressPlan* __restrict__ plan_w;
+ const int64_t* __restrict__ req_pool_indices; // (batch_size,)
+ const int32_t* __restrict__ req_to_token; // (num_reqs, max_tokens)
+ const int64_t* __restrict__ full_to_swa; // (full_cache_size,)
+ int64_t stride_r2t;
+ int32_t swa_page_size;
+ uint32_t num_c;
+ uint32_t num_w;
+};
+
+__global__ void plan_c128_online_prefill_kernel(const OnlinePrefillStage1Params params) {
+ const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const uint32_t total = params.num_c + params.num_w;
+ if (idx >= total) return;
+
+ const bool is_compress = idx < params.num_c;
+ CompressPlan* const plan_ptr = is_compress ? ¶ms.plan_c[idx] : ¶ms.plan_w[idx - params.num_c];
+ auto plan = *plan_ptr;
+ const auto batch_id = plan.read_page_0;
+ const auto rid = params.req_pool_indices[batch_id];
+ const int32_t position = static_cast(plan.seq_len - 1u);
+ const int32_t chunk_start = (position / 128) * 128;
+ const int32_t full_loc = params.req_to_token[rid * params.stride_r2t + chunk_start];
+ const int32_t swa_loc = static_cast(params.full_to_swa[full_loc]);
+ plan.read_page_0 = swa_loc / params.swa_page_size;
+ *plan_ptr = plan;
+}
+
+using OnlinePrefillPlan = tvm::ffi::Tuple;
+
+inline OnlinePrefillPlan plan_online_prefill(
+ const tvm::ffi::TensorView seq_lens,
+ const tvm::ffi::TensorView extend_lens,
+ const tvm::ffi::TensorView req_pool_indices,
+ const tvm::ffi::TensorView req_to_token,
+ const tvm::ffi::TensorView full_to_swa,
+ const tvm::ffi::TensorView plan_c_pin,
+ const tvm::ffi::TensorView plan_w_pin,
+ const tvm::ffi::TensorView plan_c_dev_,
+ const tvm::ffi::TensorView plan_w_dev_,
+ const int32_t swa_page_size) {
+ auto B = SymbolicSize{"batch_size"};
+ auto N = SymbolicSize{"num_q_tokens"};
+ auto cpu = SymbolicDevice{};
+ auto device_ = SymbolicDevice{};
+ cpu.set_options();
+ device_.set_options();
+
+ TensorMatcher({B}) //
+ .with_dtype()
+ .with_device(cpu)
+ .verify(seq_lens)
+ .verify(extend_lens);
+ TensorMatcher({B}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(req_pool_indices);
+ TensorMatcher({-1, -1}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(req_to_token);
+ TensorMatcher({-1}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(full_to_swa);
+ TensorMatcher({N, sizeof(CompressPlan)}) //
+ .with_dtype()
+ .with_device(cpu)
+ .verify(plan_c_pin)
+ .verify(plan_w_pin);
+ TensorMatcher({N, sizeof(CompressPlan)}) //
+ .with_dtype