47 changes: 16 additions & 31 deletions fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -326,6 +326,7 @@ def extract_decoder_token_from_q(
assert len(cu_seqlens_q.shape) == 1
assert len(seq_lens_encoder.shape) == 1
assert len(seq_lens_decoder.shape) == 1
assert seq_lens_encoder.shape == seq_lens_decoder.shape

max_bsz = seq_lens_decoder.shape[0]
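Reviewer note: as a reading aid, here is a pure-paddle reference of what the surrounding call sites imply this Triton kernel computes — gather the single decode-step query of each batch slot and report the KV length it should attend over. The decode-slot test and whether `cache_seqlens` counts the new token are assumptions inferred from the callers, not taken from the kernel itself.

```python
# Hedged reference sketch, NOT the Triton kernel; speculative decoding
# (several decode tokens per slot) is deliberately ignored here.
import paddle

def extract_decoder_token_from_q_ref(q, cu_seqlens_q, seq_lens_encoder, seq_lens_decoder):
    max_bsz = seq_lens_decoder.shape[0]
    hidden = q.shape[-1]
    decoder_q = paddle.zeros([max_bsz, hidden], dtype=q.dtype)
    cache_seqlens = paddle.zeros([max_bsz], dtype="int32")
    for i in range(max_bsz):
        # a slot is in decode when it has no prefill tokens but a non-empty KV history
        if int(seq_lens_encoder[i]) == 0 and int(seq_lens_decoder[i]) > 0:
            decoder_q[i] = q[int(cu_seqlens_q[i])]           # one new query token per decode slot
            cache_seqlens[i] = int(seq_lens_decoder[i]) + 1  # past tokens + the new one (assumed)
    return decoder_q, cache_seqlens
```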

@@ -398,7 +399,7 @@ def insert_decoder_result_back(
max_bsz = seq_lens_encoder.shape[0]



hidden_dim = decoder_result.shape[-2] * decoder_result.shape[-1]
out = paddle.zeros([mixed_token_num, hidden_dim], dtype=decoder_result.dtype)
out = paddle.empty([mixed_token_num, hidden_dim], dtype=decoder_result.dtype)

BLOCK_SIZE = triton.next_power_of_2(hidden_dim)
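Reviewer note: the inverse scatter, again as a hedged pure-paddle reference. It also makes the `zeros` → `empty` change explicit: skipping the zero-fill is only safe if every output row that is read downstream is either written here or overwritten later (e.g. by `merge_prefill_decode_output` in deepseek_v3.py); which rows count as "decode" is inferred from the call sites, not from the kernel.

```python
# Hedged reference sketch, NOT the Triton kernel.
import paddle

def insert_decoder_result_back_ref(decoder_result, cu_seqlens_q, seq_lens_encoder,
                                   seq_lens_decoder, mixed_token_num):
    hidden_dim = decoder_result.shape[-2] * decoder_result.shape[-1]
    out = paddle.empty([mixed_token_num, hidden_dim], dtype=decoder_result.dtype)
    flat = decoder_result.reshape([decoder_result.shape[0], hidden_dim])
    for i in range(seq_lens_decoder.shape[0]):
        if int(seq_lens_encoder[i]) == 0 and int(seq_lens_decoder[i]) > 0:
            out[int(cu_seqlens_q[i])] = flat[i]  # prefill rows stay uninitialized on purpose
    return out
```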

@@ -525,6 +526,7 @@ def __init__(
self.useless_tensor = paddle.randn([1]).cast("int32")
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
self.prop = prop
self.is_blackwell = cc >= 100
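Reviewer note: the new gating below keys off `prop.major` instead of the single `is_blackwell` flag. For reference, a hedged helper (illustrative only, not FD's API) showing how the `cc = major * 10 + minor` value maps to the architectures this file branches on:

```python
# prop.major == 9  -> Hopper  (SM90),  prop.major == 10 -> Blackwell (SM100);
# cc matches the "cc >= 100" check above.
import paddle

def cuda_arch_name() -> str:
    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor
    if prop.major >= 10:
        return f"blackwell (sm{cc})"
    if prop.major == 9:
        return f"hopper (sm{cc})"
    return f"pre-hopper (sm{cc})"
```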

if self.flash_attn_func is None:
@@ -813,7 +815,7 @@ def forward_mixed(
self.max_seq_len,
)

if self.is_blackwell:
if self.prop.major == 10:


# TODO support FA4
fmha_out = MLAAttentionBackend.mha_baseline(
q,
@@ -857,7 +859,7 @@ def forward_mixed(
speculate_decoder,
)

if int(os.getenv("USE_FLASH_MLA", "0")) == 0:
if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
assert self.num_heads <= 64, "paddle MLA attention only supports num_heads <= 64"
if self.heads_need_padding:
q = paddle.nn.functional.pad(
@@ -910,17 +912,7 @@ def forward_mixed(


return fmha_out
else:
import flash_mla

decoder_q, cache_seqlens = extract_decoder_token_from_q(
q,
forward_meta.cu_seqlens_q,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
)

tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
token_num = q.shape[0]
decoder_q = q
decoder_q.reshape_([-1, 1, self.num_heads, 576])
if self.heads_need_padding:
padded_q = paddle.zeros(
@@ -933,22 +925,28 @@ def forward_mixed(
assert new_cache_shape[1] == 1
new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1]

if self.is_blackwell:
if self.prop.major == 10:
# blackwell
decoder_res = MLAAttentionBackend.mla_blackwell(
decoder_q,
latent_cache,
metadata.block_tables,
cache_seqlens,
forward_meta.cache_seqlens,
attn_softmax_scale=self.attn_softmax_scale,
)
else:

import flash_mla

tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()

decoder_res, _ = flash_mla.flash_mla_with_kvcache(
decoder_q,
# The external open-source repo stores its KV cache in a different layout than FD does.
# Fortunately the cached head count here is 1, so a plain view is enough; otherwise a lot of code above and below would have to change!
latent_cache.view(new_cache_shape),
metadata.block_tables,
cache_seqlens,
forward_meta.cache_seqlens,
512,  # head_dim_v (dv) of the latent cache
tile_scheduler_metadata,
num_splits,
@@ -958,15 +956,7 @@ def forward_mixed(
if self.heads_need_padding:
decoder_res = decoder_res[:, :, : self.num_heads, :].contiguous()

final_res = insert_decoder_result_back(
decoder_res,
forward_meta.cu_seqlens_q,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
token_num,
)

return final_res
return decoder_res
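Reviewer note: on the `latent_cache.view(new_cache_shape)` trick above — because the cached head axis has size 1, swapping it with the neighbouring `block_size` axis is a pure metadata change, so no copy is needed to satisfy flash_mla's expected layout. A minimal illustration (shapes are placeholders, not FD's real block configuration):

```python
import paddle

num_blocks, kv_heads, block_size, head_dim = 8, 1, 64, 576  # illustrative sizes
cache_fd = paddle.randn([num_blocks, kv_heads, block_size, head_dim])

# Reinterpret the same buffer with the head axis after block_size.
cache_flash_mla = cache_fd.view([num_blocks, block_size, kv_heads, head_dim])

# The element order is unchanged because the size-1 axis carries no stride information:
assert paddle.equal_all(cache_flash_mla, cache_fd.transpose([0, 2, 1, 3])).item()
```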

@staticmethod
def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale):
@@ -1016,11 +1006,6 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft
softmax_scale = attn_softmax_scale
output_scale = 1.0

import sys

sys.path.insert(
0, "/root/paddlejob/workspace/env_run/output/zkk/cutlass/examples/python/CuTeDSL/blackwell/mla"
)
from mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16

mla = BlackwellMultiHeadLatentAttentionForwardFP16(
47 changes: 47 additions & 0 deletions fastdeploy/model_executor/models/deepseek_v3.py
@@ -17,6 +17,7 @@
from __future__ import annotations

import math
import os
import re
from typing import Dict

@@ -344,6 +345,9 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None

self.prefix = prefix

prop = paddle.device.cuda.get_device_properties()
self.prop = prop
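Reviewer note: `get_device_properties()` is now queried in every attention layer's `__init__` as well as in the backend. If that ever shows up in startup profiles, a hedged sketch of caching the lookup once per process (helper name is illustrative, not FD's API):

```python
from functools import lru_cache

import paddle

@lru_cache(maxsize=1)
def cached_device_properties():
    # Device properties do not change within a process, so one query is enough.
    return paddle.device.cuda.get_device_properties()
```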

@staticmethod
def yarn_get_mscale(scale=1, mscale=1):
""" """
@@ -362,6 +366,8 @@ def forward(
fused_read_cache_and_interleave,
)

q_total_token_num = hidden_states.shape[0]

attn_out = None
if self.use_gated_attn:
gate_out = self.gate(hidden_states)
@@ -438,6 +444,36 @@ def forward(
attn_out = fmha_out

if need_do_decode: # max_dec_len_this_time

if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
pass
else:
from fastdeploy.model_executor.layers.attention.mla_attention_backend import (
extract_decoder_token_from_q,
insert_decoder_result_back,
)

decoder_query_nope, cache_seqlens = extract_decoder_token_from_q(
query_nope.reshape([0, -1]),
forward_meta.cu_seqlens_q,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
)

decoder_query_pe, cache_seqlens = extract_decoder_token_from_q(
query_pe.reshape([0, -1]),
forward_meta.cu_seqlens_q,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
)
assert decoder_query_nope.shape[0] == forward_meta.seq_lens_encoder.shape[0]
assert decoder_query_pe.shape[0] == forward_meta.seq_lens_encoder.shape[0]

forward_meta.cache_seqlens = cache_seqlens

query_nope = decoder_query_nope.reshape([0, -1, self.qk_nope_head_dim])
query_pe = decoder_query_pe.reshape([0, -1, self.qk_rope_head_dim])

q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])

q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
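Reviewer note: a hedged sketch of what the two lines above build for the absorbed-MLA decode path, assuming DeepSeek-V3's usual sizes (kv_lora_rank = 512, qk_nope_head_dim = 128, qk_rope_head_dim = 64 — consistent with the 576/512 constants in mla_attention_backend.py). `w_uk` is an illustrative stand-in for the "k" slice of `kv_b_proj`, not FD's actual parameter name.

```python
import paddle

toks, num_heads = 4, 16
qk_nope_head_dim, qk_rope_head_dim, kv_lora_rank = 128, 64, 512

query_nope = paddle.randn([toks, num_heads, qk_nope_head_dim])
query_pe = paddle.randn([toks, num_heads, qk_rope_head_dim])
w_uk = paddle.randn([num_heads, qk_nope_head_dim, kv_lora_rank])  # absorbed up-projection

# kv_b_proj_bmm(..., proj_type="k"): per-head batched matmul that maps the nope
# part of the query into the 512-dim latent space.
q_nope_out = paddle.bmm(query_nope.transpose([1, 0, 2]), w_uk).transpose([1, 0, 2])

# Concatenate with the rope part -> 576-dim query that attends directly over the
# compressed latent KV cache.
q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
assert q_input.shape == [toks, num_heads, kv_lora_rank + qk_rope_head_dim]  # 576
```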
@@ -466,6 +502,17 @@ def forward(
.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
)

if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
pass
else:
fmqa_out = insert_decoder_result_back(
fmqa_out.reshape([0, 1, self.num_attention_heads_tp, self.v_head_dim]),
forward_meta.cu_seqlens_q,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
q_total_token_num,
)
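Reviewer note: the condition `int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9` now appears three times (once in `forward_mixed`, twice in this `forward`). A hedged refactor sketch — not part of this PR, names are illustrative — that states the predicate once:

```python
import os

import paddle

def use_builtin_paddle_mla(prop=None) -> bool:
    """True when decode stays on paddle's fused MLA path (Hopper with FLASH_MLA disabled)."""
    prop = prop or paddle.device.cuda.get_device_properties()
    return int(os.getenv("USE_FLASH_MLA", "0")) == 0 and prop.major == 9
```

With one named predicate, the gather branch above, this scatter branch, and the branch in `forward_mixed` cannot silently drift apart.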

if need_do_prefill:
merge_prefill_decode_output(
attn_out,