-
Notifications
You must be signed in to change notification settings - Fork 753
Support swa mha #8053
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Support swa mha #8053
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -72,6 +72,78 @@ | |
| ) | ||
|
|
||
|
|
||
| import triton | ||
| import triton.language as tl | ||
|
|
||
|
|
||
@enable_compat_on_triton_kernel
@triton.jit
def get_swa_indexer_top_k_kernel(
    indexer_top_k,
    block_tables,
    cu_seqlens_q,
    seq_lens_encoder,
    seq_lens_decoder,
    batch_id_per_token,
    max_page_per_seq: tl.constexpr,
    window_size: tl.constexpr,
    page_size: tl.constexpr,
):
    """Fill one row of sliding-window KV indices per token.

    Each program instance handles the token given by ``tl.program_id(0)`` and
    writes into that token's ``window_size``-wide row of ``indexer_top_k``.
    Slot 0 receives the index for the token's own position, slot 1 the
    position before it, and so on, for at most ``window_size`` positions.
    Slots past the number of available positions are not written, so they keep
    whatever value the caller initialised them with.

    Args:
        indexer_top_k: output pointer, addressed as ``token * window_size + slot``.
        block_tables: page-table pointer, addressed as
            ``batch * max_page_per_seq + logical_page``.
        cu_seqlens_q: pointer read at ``batch``; used as the offset of the
            batch's first token in the flattened token axis.
        seq_lens_encoder: pointer read at ``batch``; a value > 0 selects the
            "encoder" branch below.
        seq_lens_decoder: pointer read at ``batch``; added to the token's
            in-batch offset to obtain its position in the sequence.
        batch_id_per_token: pointer read at ``token``; a negative value makes
            the program exit without writing anything.
        max_page_per_seq: row stride of ``block_tables``.
        window_size: row stride of ``indexer_top_k`` and the maximum number of
            slots written per token.
        page_size: number of positions per page in ``block_tables``.
    """
    pid = tl.program_id(0)
    out_row = indexer_top_k + pid * window_size

    bid = tl.load(batch_id_per_token + pid)
    if bid < 0:
        # Token does not belong to any batch entry: leave its row untouched.
        return

    page_row = block_tables + bid * max_page_per_seq

    cached_len = tl.load(seq_lens_decoder + bid)
    prefill_len = tl.load(seq_lens_encoder + bid)
    q_start = tl.load(cu_seqlens_q + bid)

    # Position of this token inside its own sequence.
    pos = pid - q_start + cached_len

    # Early tokens have fewer than `window_size` positions available.
    n_valid = min(pos + 1, window_size)

    for slot in range(0, n_valid):
        # slot 0 -> the token itself, slot k -> k positions earlier.
        src = pos - slot
        if prefill_len > 0:
            # encoder case: the index is an offset on the flattened token axis.
            # NOTE(review): `src` already includes `cached_len`; confirm this
            # is intended when both lengths are non-zero.
            tl.store(out_row + slot, q_start + src)
        else:
            # Otherwise translate the logical position through the page table
            # into a physical slot: page_id * page_size + offset within page.
            page = tl.load(page_row + src // page_size)
            tl.store(out_row + slot, page * page_size + src % page_size)
|
|
||
|
|
||
def get_swa_indexer_top_k(
    indexer_top_k,
    block_tables,
    cu_seqlens_q,
    seq_lens_encoder,
    seq_lens_decoder,
    batch_id_per_token,
    *,
    page_size=64,
):
    """Launch the sliding-window index kernel, one program per token.

    ``indexer_top_k`` is filled in place by ``get_swa_indexer_top_k_kernel``;
    nothing is returned.

    Args:
        indexer_top_k: 3-D tensor of shape ``[token_num, 1, window_size]``.
            The last dimension is passed to the kernel as ``window_size``.
        block_tables: tensor whose ``shape[1]`` is passed to the kernel as
            ``max_page_per_seq``.
        cu_seqlens_q: forwarded to the kernel unchanged.
        seq_lens_encoder: forwarded to the kernel unchanged.
        seq_lens_decoder: forwarded to the kernel unchanged.
        batch_id_per_token: forwarded to the kernel unchanged.
        page_size: number of positions per page in ``block_tables``. Defaults
            to 64, the value that was previously hard-coded; pass the real
            cache block size if it differs.

    Raises:
        AssertionError: if ``indexer_top_k`` is not ``[N, 1, W]`` shaped or
            ``page_size`` is not positive.
    """
    assert indexer_top_k.ndim == 3, "indexer_top_k must be 3-D: [token_num, 1, window_size]"
    assert indexer_top_k.shape[1] == 1, "indexer_top_k.shape[1] must be 1"
    assert page_size > 0, "page_size must be positive"

    token_num = indexer_top_k.shape[0]
    if token_num == 0:
        # Nothing to fill; avoid building an empty launch grid.
        return

    grid = (token_num,)

    get_swa_indexer_top_k_kernel[grid](
        indexer_top_k,
        block_tables,
        cu_seqlens_q,
        seq_lens_encoder,
        seq_lens_decoder,
        batch_id_per_token,
        max_page_per_seq=block_tables.shape[1],
        window_size=indexer_top_k.shape[2],
        page_size=page_size,
    )
|
|
||
|
|
||
| class DeepSeekV3MLP(nn.Layer): | ||
| """ | ||
| DeepSeekV3MLP, for Dense FFN and Shared Experts Layer. | ||
|
|
@@ -534,6 +606,52 @@ def forward( | |
| ) | ||
| else: | ||
| attn_out = fmqa_out | ||
|
|
||
| if False: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🟡 建议 这段新加的 SWA sparse attention 分支被 `if False:` 包住,永远不会执行;当前 PR 新增的 `get_swa_indexer_top_k` 及其 Triton kernel 也因此没有可达的调用路径。建议修复方式:要么删除这段未启用代码;要么用明确的模型配置/attention backend 开关接入,并保证 prefill/decode 混合路径和精度测试一起提交。 |
||
| q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2]) | ||
|
|
||
| q_input = paddle.concat([q_nope_out, query_pe], axis=-1) | ||
| q_input.reshape_( | ||
| [ | ||
| -1, | ||
| self.num_attention_heads_tp, | ||
| self.kv_lora_rank + self.qk_rope_head_dim, | ||
| ] | ||
| ) | ||
|
|
||
| self.index_topk = 512 | ||
| indexer_top_k = paddle.full([q_input.shape[0], 1, self.index_topk], -1, dtype="int32") | ||
|
|
||
| get_swa_indexer_top_k( | ||
| indexer_top_k, | ||
| forward_meta.block_tables, | ||
| forward_meta.cu_seqlens_q, | ||
| forward_meta.seq_lens_encoder, | ||
| forward_meta.seq_lens_decoder, | ||
| forward_meta.batch_id_per_token, | ||
| ) | ||
|
|
||
| from fastdeploy.model_executor.layers.attention import DSAAttentionBackend | ||
|
|
||
| fmqa_out = DSAAttentionBackend.forward_static( | ||
| q=q_input.contiguous(), | ||
| indexer_topk=indexer_top_k, | ||
| compressed_kv=compressed_kv, | ||
| k_pe=key_pe, | ||
| latent_cache=forward_meta.caches[self.layer_id], | ||
| forward_meta=forward_meta, | ||
| attn_softmax_scale=self.attn_softmax_scale, | ||
| ) | ||
|
|
||
| fmqa_out = fmqa_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose([1, 0, 2]) | ||
|
|
||
| fmqa_out = ( | ||
| self.kv_b_proj_bmm(fmqa_out, proj_type="v") | ||
| .transpose([1, 0, 2]) | ||
| .reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) | ||
| ) | ||
| attn_out = fmqa_out | ||
|
|
||
| if self.use_gated_attn: | ||
| gated_attn_act = getattr(self.fd_config.model_config, "gated_attn_act", "sigmoid") | ||
| if gated_attn_act == "sigmoid": | ||
|
|
@@ -547,7 +665,6 @@ def forward( | |
|
|
||
|
|
||
| import triton | ||
| import triton.language as tl | ||
|
|
||
|
|
||
| @enable_compat_on_triton_kernel | ||
|
|
@@ -894,12 +1011,10 @@ def forward( | |
| q_input = paddle.concat([q_nope_out.transpose([1, 0, 2]).contiguous(), query_pe], axis=-1) | ||
|
|
||
| compressed_kv = self.kv_a_layernorm(compressed_kv)[0] | ||
| kv = paddle.concat([compressed_kv, key_pe.squeeze(1)], axis=-1) | ||
|
|
||
| # dsa attention | ||
| fmha_out = self.dsa_attn( | ||
| q=q_input.contiguous(), | ||
| k=kv.unsqueeze(1).contiguous(), | ||
| v=indexer_top_k.unsqueeze(1).contiguous(), | ||
| qkv=None, | ||
| compressed_kv=compressed_kv, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🔴 Bug 新增的 `mha_baseline(..., window_size, ...)` 在唯一调用处固定传 `-1`,SWA mask 永远不会生效。`Attention` 已经按 `layer_types` 设置了 `layer.sliding_window`,但这里没有读取它;因此 Blackwell prefill 仍走全量 causal attention,`mha_baseline` 里新增的 `window_size > 0` 分支不可达,SWA MHA 输出会和预期滑窗语义不一致。
建议修复方式:在 `forward_mixed` 里按当前层传入窗口,例如把 `layer.sliding_window` 作为 `window_size` 传给 `mha_baseline`(原评论中的示例代码在页面抓取时丢失),并补充滑窗层的 prefill 对齐用例。