Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 52 additions & 11 deletions fastdeploy/model_executor/layers/attention/dsa_attention_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,28 @@ def forward_mixed(
"""
Mixed模式的前向传播
"""
res = DSAAttentionBackend.forward_static(
q, v, compressed_kv, k_pe, forward_meta.caches[2 * layer.layer_id], forward_meta, self.attn_softmax_scale
)
return res

@staticmethod
def forward_static(
q: paddle.Tensor,
indexer_topk: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
latent_cache: paddle.Tensor,
forward_meta: ForwardMeta,
attn_softmax_scale: float,
) -> paddle.Tensor:

latent_cache = forward_meta.caches[2 * layer.layer_id] if hasattr(forward_meta, "caches") else None
assert len(q.shape) == 3
assert len(compressed_kv.shape) == 2
assert len(k_pe.shape) == 3
assert k_pe.shape[1] == 1
assert compressed_kv.shape[0] == k_pe.shape[0]
assert len(latent_cache.shape) == 4

if current_platform.is_cuda():
import flash_mla
Expand All @@ -352,43 +372,64 @@ def forward_mixed(
"fp8_ds_mla",
)

q_num_heads = q.shape[1]
ceil64_num_heads = (q_num_heads + 63) // 64 * 64

fmha_out_prefill = None
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
if ceil64_num_heads != q_num_heads:
new_q = paddle.empty([q.shape[0], ceil64_num_heads, q.shape[2]], dtype=q.dtype)
new_q[:, :q_num_heads, :] = q
else:
new_q = q

# Concatenate compressed_kv and k_pe along the last axis before invoking flash_mla_sparse_fwd.
kv = paddle.concat([compressed_kv.unsqueeze(1), k_pe], axis=-1)
fmha_out_prefill, _, __ = flash_mla.flash_mla_sparse_fwd(
q, # q_input.contiguous(),
k, # kv.unsqueeze(1),
v, # indexer_top_k.unsqueeze(1),
sm_scale=self.attn_softmax_scale,
new_q,
kv,
indexer_topk,
sm_scale=attn_softmax_scale,
)

assert len(fmha_out_prefill.shape) == 3
fmha_out_prefill = fmha_out_prefill[:, :q_num_heads, :].contiguous()

# Decode
# if k is None:
if forward_meta.max_len_tensor_cpu[2]: # max_enc_len_this_time
if forward_meta.max_len_tensor_cpu[2]:

tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()
new_cache_shape = latent_cache.shape
assert new_cache_shape[1] == 1
new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1]

if ceil64_num_heads != q_num_heads:
new_q = paddle.empty([q.shape[0], ceil64_num_heads, q.shape[2]], dtype=q.dtype)
new_q[:, :q_num_heads, :] = q
else:
new_q = q

fmha_out_decode, _ = flash_mla.flash_mla_with_kvcache(
q.unsqueeze(1).contiguous(),
new_q.unsqueeze(1).contiguous(),
latent_cache.view(new_cache_shape),
None, # forward_meta.block_tables,
None, # cache_seqlens
512, # self.qk_nope_head_dim,
tile_scheduler_metadata,
None, # num_splits,
self.attn_softmax_scale,
attn_softmax_scale,
False, # casual
True, # is_fp8_kvcache
v, # indices,
indexer_topk, # indices,
None, # t.attn_sink,
None, # extra_k_cache,
None, # extra_indices_in_kvcache: Optional[torch.Tensor] = None,
None, # topk_length: Optional[torch.Tensor] = None,
None, # extra_topk_length: Optional[torch.Tensor] = None
)

fmha_out_decode = fmha_out_decode[:, :, :q_num_heads, :].contiguous()

if fmha_out_prefill is not None:

from fastdeploy.model_executor.ops.gpu import (
Expand All @@ -402,7 +443,7 @@ def forward_mixed(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
self.num_heads * 4,
q_num_heads * 4,
128,
1,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,7 @@ def forward_mixed(
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
causal=self.causal,
window_size=-1,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug 新增的 mha_baseline(..., window_size, ...) 在唯一调用处固定传 -1,SWA mask 永远不会生效。

Attention 已经按 layer_types 设置了 layer.sliding_window,但这里没有读取它;因此 Blackwell prefill 仍走全量 causal attention,mha_baseline 里新增的 window_size > 0 分支不可达,SWA MHA 输出会和预期滑窗语义不一致。

建议修复方式:在 forward_mixed 里按当前层传入窗口,例如:

window_size = getattr(layer, "sliding_window", 0) or -1
fmha_out = MLAAttentionBackend.mha_baseline(..., window_size=window_size, **self.flash_attn_kwargs)

并补充滑窗层的 prefill 对齐用例。

**self.flash_attn_kwargs,
)
return fmha_out
Expand Down Expand Up @@ -1155,7 +1156,7 @@ def flashmla_baseline(decoder_q, latent_cache, block_table, cache_seqlens, attn_
return res_baseline

@staticmethod
def mha_baseline(q, k, v, cu_seqlens_q, cu_seqlens_k, causal, softmax_scale):
def mha_baseline(q, k, v, cu_seqlens_q, cu_seqlens_k, causal, window_size, softmax_scale):

assert causal, "Only support causal attention for now"
bsz = cu_seqlens_q.shape[0] - 1
Expand Down Expand Up @@ -1191,7 +1192,12 @@ def mha_baseline(q, k, v, cu_seqlens_q, cu_seqlens_k, causal, softmax_scale):

tmp_zeros = np.zeros((q_len, kv_len)) - 1
for i in range(q_len):
tmp_zeros[i][: i + 1] = 0
if kv_len - q_len + i + 1 > window_size and window_size > 0:
ss = kv_len - q_len + i + 1 - window_size
tmp_zeros[i][ss : kv_len - q_len + i + 1] = 0
else:
# Attend to every key up to and including this i-th query's own position.
tmp_zeros[i][: kv_len - q_len + i + 1] = 0
mask = tmp_zeros * 1000
mask = paddle.to_tensor(mask, dtype=q.dtype)
p = p + mask[None, :]
Expand Down
121 changes: 118 additions & 3 deletions fastdeploy/model_executor/models/deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,78 @@
)


import triton
import triton.language as tl


@enable_compat_on_triton_kernel
@triton.jit
def get_swa_indexer_top_k_kernel(
    indexer_top_k,
    block_tables,
    cu_seqlens_q,
    seq_lens_encoder,
    seq_lens_decoder,
    batch_id_per_token,
    max_page_per_seq: tl.constexpr,
    window_size: tl.constexpr,
    page_size: tl.constexpr,
):
    """Write the sliding-window KV indices for one token.

    Each program instance handles the token whose id is ``tl.program_id(0)``
    and writes at most ``window_size`` integers into that token's row of
    ``indexer_top_k``. Slot 0 receives the index of the token itself, slot 1
    the previous position, and so on backwards. Slots past the valid window
    are not written, so they keep whatever value the caller initialised.

    Args:
        indexer_top_k: Output pointer; row ``t`` is assumed to start at
            element offset ``t * window_size`` (i.e. a contiguous buffer) --
            TODO confirm against the caller.
        block_tables: Pointer to the per-sequence page table, addressed as
            ``batch_id * max_page_per_seq + logical_page``.
        cu_seqlens_q: Pointer read at ``batch_id``; the loaded value is
            subtracted from the token id to get the token's offset inside
            its sequence (presumably cumulative query lengths).
        seq_lens_encoder: Pointer read at ``batch_id``; a value > 0 selects
            the "encoder" indexing branch.
        seq_lens_decoder: Pointer read at ``batch_id``; the loaded value is
            added to the in-sequence offset (presumably the number of tokens
            already held in the KV cache).
        batch_id_per_token: Pointer read at the token id; a negative value
            makes the program exit without writing anything.
        max_page_per_seq: Row stride of ``block_tables`` in elements.
        window_size: Number of slots per token row in ``indexer_top_k``.
        page_size: Number of positions per page used when translating a
            logical position through ``block_tables``.
    """
    token_id = tl.program_id(0)

    # Advance the output pointer to this token's row.
    indexer_top_k += token_id * window_size

    batch_id = tl.load(batch_id_per_token + token_id)
    # Negative batch id: leave the row untouched (presumably a padding token).
    if batch_id < 0:
        return

    # Advance the page-table pointer to this sequence's row.
    block_tables += batch_id * max_page_per_seq

    kv_len = tl.load(seq_lens_decoder + batch_id)
    encoder_len = tl.load(seq_lens_encoder + batch_id)
    cu_q_len = tl.load(cu_seqlens_q + batch_id)
    # Offset of this token inside its sequence, shifted by kv_len.
    token_id_in_this_batch = token_id - cu_q_len + kv_len

    # Near the start of a sequence fewer than window_size positions exist.
    valid_window_size = min(token_id_in_this_batch + 1, window_size)

    # Walk backwards from the current position; the output slot is the distance
    # from the current position, so slot 0 is the current token.
    for idx in range(token_id_in_this_batch, token_id_in_this_batch - valid_window_size, -1):
        if encoder_len > 0:
            # Encoder (prefill) case: the index is cu_q_len + idx, with no
            # page-table lookup.
            # NOTE(review): idx already includes kv_len, so this equals the
            # packed token index only when kv_len == 0 -- confirm that prefill
            # with a non-empty cache is not expected here.
            tmp = cu_q_len + idx
            tl.store(indexer_top_k + token_id_in_this_batch - idx, tmp)
        else:
            # Otherwise translate the logical position into a cache slot:
            # page id from the block table, plus the offset inside the page.
            tmp = tl.load(block_tables + idx // page_size)
            tmp = tmp * page_size + idx % page_size
            tl.store(indexer_top_k + token_id_in_this_batch - idx, tmp)


def get_swa_indexer_top_k(
    indexer_top_k,
    block_tables,
    cu_seqlens_q,
    seq_lens_encoder,
    seq_lens_decoder,
    batch_id_per_token,
    page_size: int = 64,
):
    """Fill ``indexer_top_k`` in place with sliding-window KV indices.

    Launches ``get_swa_indexer_top_k_kernel`` with one program per token
    (grid size ``indexer_top_k.shape[0]``). The window size is taken from
    ``indexer_top_k.shape[2]`` and the page-table row stride from
    ``block_tables.shape[1]``.

    Args:
        indexer_top_k: 3-D tensor whose second dimension must be 1; it is
            written in place. The kernel addresses row ``t`` at element
            offset ``t * shape[2]``, so the tensor is assumed to be
            contiguous -- TODO confirm at call sites.
        block_tables: 2-D page table; ``shape[1]`` is passed as
            ``max_page_per_seq``.
        cu_seqlens_q: Forwarded to the kernel unchanged.
        seq_lens_encoder: Forwarded to the kernel unchanged.
        seq_lens_decoder: Forwarded to the kernel unchanged.
        batch_id_per_token: Forwarded to the kernel unchanged.
        page_size: Number of positions per KV-cache page. Defaults to 64,
            the value that used to be hard-coded; pass the cache block size
            when it differs.

    Returns:
        None. The result is written into ``indexer_top_k``.

    Raises:
        AssertionError: If ``indexer_top_k`` is not 3-D or its second
            dimension is not 1.
    """
    assert indexer_top_k.ndim == 3
    assert indexer_top_k.shape[1] == 1

    token_num = indexer_top_k.shape[0]
    # Nothing to fill for an empty batch; skip the kernel launch entirely.
    if token_num == 0:
        return

    # One program instance per token.
    grid = (token_num,)

    get_swa_indexer_top_k_kernel[grid](
        indexer_top_k,
        block_tables,
        cu_seqlens_q,
        seq_lens_encoder,
        seq_lens_decoder,
        batch_id_per_token,
        max_page_per_seq=block_tables.shape[1],
        window_size=indexer_top_k.shape[2],
        page_size=page_size,
    )


class DeepSeekV3MLP(nn.Layer):
"""
DeepSeekV3MLP, for Dense FFN and Shared Experts Layer.
Expand Down Expand Up @@ -534,6 +606,52 @@ def forward(
)
else:
attn_out = fmqa_out

if False:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 这段新加的 SWA sparse attention 分支被 if False 固定屏蔽,运行时永远不会执行。

当前 PR 新增的 get_swa_indexer_top_k、DSAAttentionBackend.forward_static 调用和 512 窗口逻辑都只在这个死分支里使用;实际 DeepseekV3MLAAttention.forward 仍按上面的普通 MLA 路径返回,无法覆盖或验证这里的 SWA 行为。

建议修复方式:要么删除这段未启用代码;要么用明确的模型配置/attention backend 开关接入,并保证 prefill/decode 混合路径和精度测试一起提交。

q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])

q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
q_input.reshape_(
[
-1,
self.num_attention_heads_tp,
self.kv_lora_rank + self.qk_rope_head_dim,
]
)

self.index_topk = 512
indexer_top_k = paddle.full([q_input.shape[0], 1, self.index_topk], -1, dtype="int32")

get_swa_indexer_top_k(
indexer_top_k,
forward_meta.block_tables,
forward_meta.cu_seqlens_q,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
)

from fastdeploy.model_executor.layers.attention import DSAAttentionBackend

fmqa_out = DSAAttentionBackend.forward_static(
q=q_input.contiguous(),
indexer_topk=indexer_top_k,
compressed_kv=compressed_kv,
k_pe=key_pe,
latent_cache=forward_meta.caches[self.layer_id],
forward_meta=forward_meta,
attn_softmax_scale=self.attn_softmax_scale,
)

fmqa_out = fmqa_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose([1, 0, 2])

fmqa_out = (
self.kv_b_proj_bmm(fmqa_out, proj_type="v")
.transpose([1, 0, 2])
.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
)
attn_out = fmqa_out

if self.use_gated_attn:
gated_attn_act = getattr(self.fd_config.model_config, "gated_attn_act", "sigmoid")
if gated_attn_act == "sigmoid":
Expand All @@ -547,7 +665,6 @@ def forward(


import triton
import triton.language as tl


@enable_compat_on_triton_kernel
Expand Down Expand Up @@ -894,12 +1011,10 @@ def forward(
q_input = paddle.concat([q_nope_out.transpose([1, 0, 2]).contiguous(), query_pe], axis=-1)

compressed_kv = self.kv_a_layernorm(compressed_kv)[0]
kv = paddle.concat([compressed_kv, key_pe.squeeze(1)], axis=-1)

# dsa attention
fmha_out = self.dsa_attn(
q=q_input.contiguous(),
k=kv.unsqueeze(1).contiguous(),
v=indexer_top_k.unsqueeze(1).contiguous(),
qkv=None,
compressed_kv=compressed_kv,
Expand Down
Loading
Loading