Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
7c7bd61
one pass
WANDY666 Jun 3, 2026
d790ad2
Optimization
WANDY666 Jun 5, 2026
a161244
add prompt cache
WANDY666 Jun 5, 2026
61eed87
support cudagraph
WANDY666 Jun 5, 2026
19866d0
refact tokenizer
WANDY666 Jun 8, 2026
29c6082
add statement
WANDY666 Jun 8, 2026
ffafdbf
format
WANDY666 Jun 8, 2026
e8009cb
pass gsm8k but need review
WANDY666 Jun 11, 2026
b3b8123
fix
WANDY666 Jun 11, 2026
6002866
fix rope
WANDY666 Jun 11, 2026
6bc34ad
dsv4: enable decode cudagraph; fix warmup-baked FlashMLASchedMeta
WANDY666 Jun 11, 2026
e78e0d4
dsv4: enable prefill cudagraph; zero pad-row attention output
Jun 11, 2026
c09dc6a
fix profile
WANDY666 Jun 11, 2026
c07e38c
support fp8
WANDY666 Jun 12, 2026
ff71706
optimize
WANDY666 Jun 12, 2026
d7dd6e0
fix
WANDY666 Jun 12, 2026
3a5dcdc
compress infer
WANDY666 Jun 14, 2026
d76450f
add c128 to mem_manager
WANDY666 Jun 14, 2026
07d2308
refact
WANDY666 Jun 15, 2026
d4dcd8a
opt
WANDY666 Jun 15, 2026
62c16d5
opt
WANDY666 Jun 15, 2026
69824d0
delete launch.sh
WANDY666 Jun 15, 2026
df70ecb
fix
WANDY666 Jun 15, 2026
1ad981d
restore
WANDY666 Jun 16, 2026
7b17bb5
support parser
WANDY666 Jun 16, 2026
6837abd
fix
WANDY666 Jun 16, 2026
e1376fe
Merge branch 'main' of https://github.com/ModelTC/LightLLM into suppo…
WANDY666 Jun 16, 2026
02a24ce
add c4 paged indexes
WANDY666 Jun 18, 2026
52a1528
fix chunk_size and page_size
WANDY666 Jun 18, 2026
0dbc90b
add sglang third_party
WANDY666 Jun 18, 2026
e8c49d1
fix tpsp
WANDY666 Jun 18, 2026
88309b5
fix profile
WANDY666 Jun 21, 2026
cf433fb
fix swa insufficient
WANDY666 Jun 22, 2026
40f5810
fix
WANDY666 Jun 22, 2026
f527ca2
rename
WANDY666 Jun 22, 2026
255e90d
tune config
WANDY666 Jun 22, 2026
d88dc71
prepare opt
WANDY666 Jun 22, 2026
a56c79b
delete
WANDY666 Jun 22, 2026
e286943
item1: wire fused_q_indexer_rope_hadamard_quant (rope+hadamard+fp8qua…
WANDY666 Jun 23, 2026
58b145b
item3: lazy-cache layer-independent c4 paged metadata (page_table/ctx…
WANDY666 Jun 23, 2026
a0379bb
gate-bf16 (flag) + drop redundant attn_sink fp32 copy + lazy gen_nsa_…
WANDY666 Jun 23, 2026
b796d48
cache prefill FlashMLA sched-meta per compress-ratio (was rebuilt eve…
WANDY666 Jun 23, 2026
e07b85e
2-stream
WANDY666 Jun 23, 2026
51f5c84
fix parser
WANDY666 Jun 24, 2026
2f12a07
fix multi-invoke
WANDY666 Jun 24, 2026
da3fec2
speed up prepare
WANDY666 Jun 24, 2026
82cb6d6
fix arguments
WANDY666 Jun 24, 2026
bc22591
tune H100
WANDY666 Jun 24, 2026
77baa05
add encoding_dsv4
WANDY666 Jun 25, 2026
2225e1a
fix c4 error
WANDY666 Jun 26, 2026
112247b
fuse wq_a+wkv & indexer wkv+wgate GEMMs; fp8 wo_a at tp8 (1 group/rank)
WANDY666 Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 197 additions & 2 deletions lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,60 @@
from lightllm.common.basemodel.infer_struct import InferStateInfo


# this flash_mla extra-cache fork only instantiates h_q in {64, 128}; pad TP-split q heads up
# to the nearest supported count (zero heads are discarded from the output slice).
FLASHMLA_SUPPORTED_HEADS = (64, 128)


def _pad_q_heads(q_4d: torch.Tensor, attn_sink: torch.Tensor):
h_q = q_4d.shape[2]
if h_q in FLASHMLA_SUPPORTED_HEADS:
return q_4d, attn_sink, h_q
target = next((h for h in FLASHMLA_SUPPORTED_HEADS if h >= h_q), None)
assert target is not None, f"num q heads {h_q} exceeds flash_mla support {FLASHMLA_SUPPORTED_HEADS}"
q_pad = torch.nn.functional.pad(q_4d, (0, 0, 0, target - h_q))
sink_pad = torch.nn.functional.pad(attn_sink, (0, target - h_q))
return q_pad, sink_pad, h_q


def _view_dsv4_flashmla_cache(layer_buffer: torch.Tensor, page_size: int) -> torch.Tensor:
from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_MLA_BYTES_PER_TOKEN

usable = page_size * DSV4_MLA_BYTES_PER_TOKEN
return layer_buffer[:, :usable].view(layer_buffer.shape[0], page_size, 1, DSV4_MLA_BYTES_PER_TOKEN)


@dataclasses.dataclass
class _Dsv4Metadata:
swa_indices: torch.Tensor
swa_lengths: torch.Tensor
extra_cache: torch.Tensor = None
extra_indices: torch.Tensor = None
extra_lengths: torch.Tensor = None


def _metadata_from_dict(infer_state, nsa_dict: dict) -> "_Dsv4Metadata":
"""Bundle the model-built FINAL index tensors (carried in nsa_dict by DeepseekV4IndexInfer) with
the layer-keyed fp8 extra-cache byte view. The cache view is data-independent (a fixed per-layer
buffer slice), so it is built here -- a genuine flash_mla ABI concern -- rather than on the model
side; only the index/length tensors cross the att_control boundary."""
from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_C128_PAGE_SIZE, DSV4_C4_PAGE_SIZE

ratio = nsa_dict["compress_ratio"]
extra_cache = None
if ratio:
page = DSV4_C4_PAGE_SIZE if ratio == 4 else DSV4_C128_PAGE_SIZE
extra_buffer = infer_state.mem_manager.get_compressed_kv_buffer(nsa_dict["layer_index"])
extra_cache = _view_dsv4_flashmla_cache(extra_buffer, page)
return _Dsv4Metadata(
swa_indices=nsa_dict["swa_indices"],
swa_lengths=nsa_dict["swa_lengths"],
extra_cache=extra_cache,
extra_indices=nsa_dict.get("extra_indices"),
extra_lengths=nsa_dict.get("extra_lengths"),
)


class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend):
def __init__(self, model):
super().__init__(model=model)
Expand All @@ -31,9 +85,20 @@ class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState):
ke: torch.Tensor = None
lengths: torch.Tensor = None
ragged_mem_index: torch.Tensor = None
flashmla_sched_meta: object = None

def init_state(self):
self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
self.flashmla_sched_meta = {}
return

def ensure_nsa_ks_ke(self):
"""Build the ragged ks/ke/lengths (+ ragged_mem_index) the DeepSeek-3.2 indexer consumes. The
indexer calls this explicitly before reading them; DeepSeek-V4 uses its own indexer and never
calls it, so V4 prefill skips the alloc + gen_nsa_ks_ke kernel. Idempotent + layer-independent:
the first call in a forward computes, the other layers reuse."""
if self.ks is not None:
return
self.ragged_mem_index = torch.empty(
self.infer_state.total_token_num,
dtype=torch.int32,
Expand All @@ -52,6 +117,15 @@ def init_state(self):
)
return

def _get_flashmla_sched_meta(self, compress_ratio: int):
sched_meta = self.flashmla_sched_meta.get(compress_ratio)
if sched_meta is None:
import flash_mla

sched_meta = flash_mla.get_mla_metadata()[0]
self.flashmla_sched_meta[compress_ratio] = sched_meta
return sched_meta

def prefill_att(
self,
q: torch.Tensor,
Expand All @@ -62,6 +136,12 @@ def prefill_att(
) -> torch.Tensor:
assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention"
assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required"
if att_control.nsa_prefill_dict.get("flashmla_kvcache"):
return self._flashmla_kvcache_prefill_att(
q=q,
packed_kv=k,
nsa_dict=att_control.nsa_prefill_dict,
)
return self._nsa_prefill_att(q=q, packed_kv=k, att_control=att_control)

def _nsa_prefill_att(
Expand All @@ -78,6 +158,8 @@ def _nsa_prefill_att(
kv_lora_rank = nsa_dict["kv_lora_rank"]
topk_mem_indices = nsa_dict["topk_mem_indices"]
prefill_cache_kv = nsa_dict["prefill_cache_kv"]
attn_sink = nsa_dict.get("attn_sink")
topk_length = nsa_dict.get("topk_length")

if self.infer_state.prefix_total_token_num > 0:
# 当前推理生成的token kv部分从 prefill_cache_kv 中获取,历史
Expand All @@ -101,9 +183,51 @@ def _nsa_prefill_att(
indices=topk_indices,
sm_scale=softmax_scale,
d_v=kv_lora_rank,
attn_sink=attn_sink,
topk_length=topk_length,
)
return mla_out

def _flashmla_kvcache_prefill_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor:
attn_sink = nsa_dict["attn_sink"]
metadata = _metadata_from_dict(self.infer_state, nsa_dict)
return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict)

def _flashmla_kvcache_att(
self,
q: torch.Tensor,
packed_kv: torch.Tensor,
metadata: _Dsv4Metadata,
attn_sink: torch.Tensor,
nsa_dict: dict,
) -> torch.Tensor:
import flash_mla
from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE

q_4d = q.unsqueeze(1).contiguous()
q_4d, attn_sink, num_real_heads = _pad_q_heads(q_4d, attn_sink)
k_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE)
sched_meta = self._get_flashmla_sched_meta(nsa_dict["compress_ratio"])
out, _ = flash_mla.flash_mla_with_kvcache(
q=q_4d,
k_cache=k_cache,
block_table=None,
cache_seqlens=None,
head_dim_v=nsa_dict["head_dim_v"],
tile_scheduler_metadata=sched_meta,
num_splits=None,
softmax_scale=nsa_dict["softmax_scale"],
causal=False,
is_fp8_kvcache=True,
indices=metadata.swa_indices,
attn_sink=attn_sink,
topk_length=metadata.swa_lengths,
extra_k_cache=metadata.extra_cache,
extra_indices_in_kvcache=metadata.extra_indices,
extra_topk_length=metadata.extra_lengths,
)
return out[:, 0, :num_real_heads].contiguous()


@dataclasses.dataclass
class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState):
Expand Down Expand Up @@ -143,7 +267,23 @@ def init_state(self):
)
import flash_mla

self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata()
# one sched_meta per layer type: the lazy config locks extra-cache geometry (page size,
# presence) on first invocation, so swa-only/c4/c128 layers must not share one object.
self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)}
return

def ensure_nsa_ks_ke(self):
# decode builds ks/ke eagerly in init_state (outside the cuda graph, for capture safety), so
# they are already available -- this satisfies the shared DeepSeek-3.2 indexer ensure contract.
return

def reset_sched_meta_for_capture(self):
# cuda-graph capture hook: the warmup pass already locked/stored sched meta on this
# (shared) state object; reset so the capture pass re-plans INSIDE the graph and every
# replay re-plans from the live tensors instead of binding warmup leftovers.
import flash_mla

self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)}
return

def decode_att(
Expand All @@ -156,6 +296,12 @@ def decode_att(
) -> torch.Tensor:
assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention"
assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required"
if att_control.nsa_decode_dict.get("flashmla_kvcache"):
return self._flashmla_kvcache_decode_att(
q=q,
packed_kv=k,
nsa_dict=att_control.nsa_decode_dict,
)
return self._nsa_decode_att(q=q, packed_kv=k, att_control=att_control)

def _nsa_decode_att(
Expand All @@ -170,6 +316,11 @@ def _nsa_decode_att(
topk_mem_indices = nsa_dict["topk_mem_indices"]
softmax_scale = nsa_dict["softmax_scale"]
kv_lora_rank = nsa_dict["kv_lora_rank"]
attn_sink = nsa_dict.get("attn_sink")
topk_length = nsa_dict.get("topk_length")
extra_k_cache = nsa_dict.get("extra_k_cache")
extra_indices = nsa_dict.get("extra_indices_in_kvcache")
extra_topk_length = nsa_dict.get("extra_topk_length")

if topk_mem_indices.ndim == 2:
topk_mem_indices = topk_mem_indices.unsqueeze(1)
Expand All @@ -189,10 +340,54 @@ def _nsa_decode_att(
block_table=None,
cache_seqlens=None,
head_dim_v=kv_lora_rank,
tile_scheduler_metadata=self.flashmla_sched_meta,
tile_scheduler_metadata=self.flashmla_sched_meta[0],
softmax_scale=softmax_scale,
causal=False,
is_fp8_kvcache=True,
indices=topk_mem_indices,
attn_sink=attn_sink,
topk_length=topk_length,
extra_k_cache=extra_k_cache,
extra_indices_in_kvcache=extra_indices,
extra_topk_length=extra_topk_length,
)
return o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d]

def _flashmla_kvcache_decode_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor:
attn_sink = nsa_dict["attn_sink"]
metadata = _metadata_from_dict(self.infer_state, nsa_dict)
return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict)

def _flashmla_kvcache_att(
self,
q: torch.Tensor,
packed_kv: torch.Tensor,
metadata: _Dsv4Metadata,
attn_sink: torch.Tensor,
nsa_dict: dict,
) -> torch.Tensor:
import flash_mla
from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE

q_4d = q.unsqueeze(1).contiguous()
q_4d, attn_sink, num_real_heads = _pad_q_heads(q_4d, attn_sink)
k_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE)
out, _ = flash_mla.flash_mla_with_kvcache(
q=q_4d,
k_cache=k_cache,
block_table=None,
cache_seqlens=None,
head_dim_v=nsa_dict["head_dim_v"],
tile_scheduler_metadata=self.flashmla_sched_meta[nsa_dict["compress_ratio"]],
num_splits=None,
softmax_scale=nsa_dict["softmax_scale"],
causal=False,
is_fp8_kvcache=True,
indices=metadata.swa_indices,
attn_sink=attn_sink,
topk_length=metadata.swa_lengths,
extra_k_cache=metadata.extra_cache,
extra_indices_in_kvcache=metadata.extra_indices,
extra_topk_length=metadata.extra_lengths,
)
return out[:, 0, :num_real_heads].contiguous()
62 changes: 62 additions & 0 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,15 @@ def _init_custom(self):

@torch.no_grad()
def forward(self, model_input: ModelInput):
# decode 槽位 prep: 放在 to_cuda 前, 优先使用 b_req_idx/b_seq_len 的 CPU mirror,
# 且此刻已在 forward 的 CUDA stream 上 -> 与后续 attention 同流, 无跨流竞态、无 D2H。
# mem_indexes_cpu is None 时跳过: cudagraph warmup 的输入全在 CUDA 且 b_req_idx 全为 HOLD, prep 本就是 no-op。
if not model_input.is_prefill and model_input.mem_indexes_cpu is not None:
self.req_manager.prepare_decode(
model_input.b_req_idx_cpu,
model_input.b_seq_len_cpu,
model_input.mem_indexes_cpu,
)
model_input.to_cuda()
assert model_input.mem_indexes.is_cuda

Expand Down Expand Up @@ -371,6 +380,15 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0
)
new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2)
new_model_input.b_req_idx_cpu = F.pad(
new_model_input.b_req_idx_cpu,
(0, padded_batch_size),
mode="constant",
value=self.req_manager.HOLD_REQUEST_ID,
)
new_model_input.b_seq_len_cpu = F.pad(
new_model_input.b_seq_len_cpu, (0, padded_batch_size), mode="constant", value=2
)
new_model_input.mem_indexes = F.pad(
new_model_input.mem_indexes,
(0, padded_batch_size),
Expand Down Expand Up @@ -428,6 +446,15 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle
new_model_input.b_mtp_index = F.pad(new_model_input.b_mtp_index, (0, 1), mode="constant", value=0)
new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, 1), mode="constant", value=padded_token_num)
new_model_input.b_ready_cache_len = F.pad(new_model_input.b_ready_cache_len, (0, 1), mode="constant", value=0)
new_model_input.b_req_idx_cpu = F.pad(
new_model_input.b_req_idx_cpu, (0, 1), mode="constant", value=self.req_manager.HOLD_REQUEST_ID
)
new_model_input.b_seq_len_cpu = F.pad(
new_model_input.b_seq_len_cpu, (0, 1), mode="constant", value=padded_token_num
)
new_model_input.b_ready_cache_len_cpu = F.pad(
new_model_input.b_ready_cache_len_cpu, (0, 1), mode="constant", value=0
)
b_q_seq_len = new_model_input.b_seq_len - new_model_input.b_ready_cache_len
new_model_input.b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
# 构建新的list, 使用 append 可能会让外面使用的数组引用发生变化,导致错误。
Expand Down Expand Up @@ -521,6 +548,15 @@ def _prefill(
alloc_mem_index=infer_state.mem_index,
max_q_seq_len=infer_state.max_q_seq_len,
)
if model_input.b_req_idx_cpu is not None:
self.req_manager.prepare_prefill(
b_req_idx=infer_state.b_req_idx,
b_ready_cache_len=infer_state.b_ready_cache_len,
b_seq_len=infer_state.b_seq_len,
b_req_idx_cpu=model_input.b_req_idx_cpu,
b_ready_cache_len_cpu=model_input.b_ready_cache_len_cpu,
b_seq_len_cpu=model_input.b_seq_len_cpu,
)
prefill_mem_indexes_ready_event = torch.cuda.Event()
prefill_mem_indexes_ready_event.record()

Expand Down Expand Up @@ -741,6 +777,15 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
alloc_mem_index=infer_state0.mem_index,
max_q_seq_len=infer_state0.max_q_seq_len,
)
if model_input0.b_req_idx_cpu is not None:
self.req_manager.prepare_prefill(
b_req_idx=infer_state0.b_req_idx,
b_ready_cache_len=infer_state0.b_ready_cache_len,
b_seq_len=infer_state0.b_seq_len,
b_req_idx_cpu=model_input0.b_req_idx_cpu,
b_ready_cache_len_cpu=model_input0.b_ready_cache_len_cpu,
b_seq_len_cpu=model_input0.b_seq_len_cpu,
)
infer_state0.init_some_extra_state(self)
infer_state0.init_att_state()

Expand All @@ -754,6 +799,15 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
alloc_mem_index=infer_state1.mem_index,
max_q_seq_len=infer_state1.max_q_seq_len,
)
if model_input1.b_req_idx_cpu is not None:
self.req_manager.prepare_prefill(
b_req_idx=infer_state1.b_req_idx,
b_ready_cache_len=infer_state1.b_ready_cache_len,
b_seq_len=infer_state1.b_seq_len,
b_req_idx_cpu=model_input1.b_req_idx_cpu,
b_ready_cache_len_cpu=model_input1.b_ready_cache_len_cpu,
b_seq_len_cpu=model_input1.b_seq_len_cpu,
)
infer_state1.init_some_extra_state(self)
infer_state1.init_att_state()

Expand Down Expand Up @@ -781,6 +835,14 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod

@torch.no_grad()
def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
# decode 槽位 prep: 在 to_cuda 前使用 CPU mirror, 且已在 forward 的 CUDA stream 上 (见 forward 注释)。
for mi in (model_input0, model_input1):
if mi.mem_indexes_cpu is not None:
self.req_manager.prepare_decode(
mi.b_req_idx_cpu,
mi.b_seq_len_cpu,
mi.mem_indexes_cpu,
)
model_input0.to_cuda()
model_input1.to_cuda()
assert self.args.enable_tpsp_mix_mode
Expand Down
Loading