
Commit d112353

format flash_mask_attn
1 parent cd2c4df commit d112353

File tree

2 files changed: +6, −5 lines changed


custom_ops/gpu_ops/flash_mask_attn/softmax.hpp

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ struct Softmax {
     using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
     TensorT row_max, row_sum;
 
-    CUTLASS_DEVICE Softmax() {};
+    CUTLASS_DEVICE Softmax(){};
 
     template <bool Is_first, bool Check_inf = false, typename Tensor0>
     __forceinline__ __device__ TensorT max(Tensor0 &acc_s,

tests/operators/test_flash_mask_attn.py

Lines changed: 5 additions & 4 deletions
@@ -56,7 +56,7 @@ def naive_attn(self, q_input, k_input, v_input, mask):
                 out[bsz, hi] = (np.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
         return out
 
-    def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
+    def paddle_flash_attn_mask(self, q_input, k_input, v_input, attn_out, mask):
         bsz = q_input.shape[0]
         cu_seq_q = paddle.arange(bsz + 1) * q_input.shape[1]
         cu_seq_k = paddle.arange(bsz + 1) * k_input.shape[1]
@@ -71,13 +71,14 @@ def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
         v_input_pad[0 : v_input.shape[0]] = v_input
         mask = paddle.to_tensor(mask).astype("int32")
 
-        out = flash_attention_mask(
+        flash_attention_mask(
             q_input,
             k_input,
             v_input_pad,
             cu_seq_q,
             cu_seq_k,
             seq_len_encoder,
+            attn_out,
             mask,
             int(q_input.shape[1]),
             int(k_input.shape[1]),
@@ -86,7 +87,6 @@ def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
             int(q_input.shape[0]),
             int(k_input.shape[0]),
         )
-        return out
 
     def test_flash_attention_mask(self):
         q_input = np.random.normal(0, 0.5, size=(self.bsz, self.q_seq_len, self.num_head, self.head_dim))
@@ -105,7 +105,8 @@ def test_flash_attention_mask(self):
         mask[text_len : text_len + image_len] = text_len + image_len + self.k_seq_len
 
         naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
-        paddle_attn_out = self.paddle_flash_attn_mask(q_input, k_input, v_input, mask)
+        paddle_attn_out = paddle.zeros(naive_attn_out.shape, dtype="bfloat16")
+        self.paddle_flash_attn_mask(q_input, k_input, v_input, paddle_attn_out, mask)
 
         max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
         self.assertLessEqual(max_diff, 0.05)
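
For readers skimming this commit, the net effect of the test change is that flash_attention_mask no longer returns its result: the caller pre-allocates an output tensor and the kernel writes into it, with the buffer passed right after seq_len_encoder. A minimal sketch of the updated pattern follows, assuming the helpers and inputs defined in this test file (naive_attn, paddle_flash_attn_mask, q_input, k_input, v_input, mask); the zero-initialized bfloat16 buffer is exactly what the diff above adds.

# Reference output from the NumPy implementation in this test file.
naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)

# New convention: pre-allocate the output buffer and pass it through the
# helper, which forwards it to flash_attention_mask; nothing is returned.
paddle_attn_out = paddle.zeros(naive_attn_out.shape, dtype="bfloat16")
self.paddle_flash_attn_mask(q_input, k_input, v_input, paddle_attn_out, mask)

# Compare against the NumPy reference within the existing 0.05 tolerance.
max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
self.assertLessEqual(max_diff, 0.05)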

0 commit comments
