@@ -56,7 +56,7 @@ def naive_attn(self, q_input, k_input, v_input, mask):
                 out[bsz, hi] = (np.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
         return out
 
-    def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
+    def paddle_flash_attn_mask(self, q_input, k_input, v_input, attn_out, mask):
         bsz = q_input.shape[0]
         cu_seq_q = paddle.arange(bsz + 1) * q_input.shape[1]
         cu_seq_k = paddle.arange(bsz + 1) * k_input.shape[1]
@@ -71,13 +71,14 @@ def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
         v_input_pad[0 : v_input.shape[0]] = v_input
         mask = paddle.to_tensor(mask).astype("int32")
 
-        out = flash_attention_mask(
+        flash_attention_mask(
             q_input,
             k_input,
             v_input_pad,
             cu_seq_q,
             cu_seq_k,
             seq_len_encoder,
+            attn_out,
             mask,
             int(q_input.shape[1]),
             int(k_input.shape[1]),
@@ -86,7 +87,6 @@ def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
             int(q_input.shape[0]),
             int(k_input.shape[0]),
         )
-        return out
 
     def test_flash_attention_mask(self):
         q_input = np.random.normal(0, 0.5, size=(self.bsz, self.q_seq_len, self.num_head, self.head_dim))
@@ -105,7 +105,8 @@ def test_flash_attention_mask(self):
         mask[text_len : text_len + image_len] = text_len + image_len + self.k_seq_len
 
         naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
-        paddle_attn_out = self.paddle_flash_attn_mask(q_input, k_input, v_input, mask)
+        paddle_attn_out = paddle.zeros(naive_attn_out.shape, dtype="bfloat16")
+        self.paddle_flash_attn_mask(q_input, k_input, v_input, paddle_attn_out, mask)
 
         max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
         self.assertLessEqual(max_diff, 0.05)
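The diff switches the test helper from returning the attention result to writing it into a caller-provided buffer: `flash_attention_mask` now receives a pre-allocated `attn_out` tensor and the helper returns nothing. Below is a minimal sketch of the new calling pattern, mirroring the test above; the shape values are placeholders and the commented-out call only restates what the diff shows, not a documented contract of the op.

import paddle

# Placeholder shape standing in for naive_attn_out.shape in the test;
# the real test derives it from the NumPy reference implementation.
out_shape = [2, 8, 128, 64]

# q_input, k_input, v_input and mask would be prepared exactly as in
# test_flash_attention_mask above (omitted here).

# New convention: the caller pre-allocates the bfloat16 output buffer and
# the op fills it in place instead of allocating and returning a result.
attn_out = paddle.zeros(out_shape, dtype="bfloat16")
# self.paddle_flash_attn_mask(q_input, k_input, v_input, attn_out, mask)

This matches the test change, where `paddle_attn_out` is zero-initialized with the reference output's shape before being passed into `paddle_flash_attn_mask` and then compared against `naive_attn_out`.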