diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 7982f61428c..23eea4ac127 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -42,6 +42,9 @@ void AppendAttentionKernel( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, @@ -137,6 +140,9 @@ void AppendAttentionKernel( qkv_out, key_cache, value_cache, + tmp_workspace, + tmp_m, + tmp_d, attn_mask, cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales, @@ -446,6 +452,9 @@ std::vector AppendAttention( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, @@ -579,6 +588,9 @@ std::vector AppendAttention( qkv, key_cache, value_cache, + tmp_workspace, + tmp_m, + tmp_d, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, @@ -655,6 +667,9 @@ std::vector AppendAttentionWithOutput( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, @@ -735,6 +750,9 @@ std::vector AppendAttentionWithOutput( qkv, key_cache, value_cache, + tmp_workspace, + tmp_m, + tmp_d, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, @@ -825,6 +843,9 @@ std::vector> AppendAttentionInferShape( const std::vector& qkv_shape, const std::vector& key_cache_shape, const std::vector& 
value_cache_shape, + const std::vector& tmp_workspace_shape, + const std::vector& tmp_m_shape, + const std::vector& tmp_d_shape, const std::vector& seq_lens_encoder_shape, const std::vector& seq_lens_decoder_shape, const std::vector& seq_lens_this_time_shape, @@ -890,6 +911,9 @@ std::vector AppendAttentionInferDtype( const paddle::DataType& qkv_dtype, const paddle::DataType& key_cache_dtype, const paddle::DataType& value_cache_dtype, + const paddle::DataType& tmp_workspace_dtype, + const paddle::DataType& tmp_m_dtype, + const paddle::DataType& tmp_d_dtype, const paddle::DataType& seq_lens_encoder_dtype, const paddle::DataType& seq_lens_decoder_dtype, const paddle::DataType& seq_lens_this_time_dtype, @@ -975,6 +999,9 @@ std::vector> AppendAttentionWithOutputInferShape( const std::vector& qkv_shape, const std::vector& key_cache_shape, const std::vector& value_cache_shape, + const std::vector& tmp_workspace_shape, + const std::vector& tmp_m_shape, + const std::vector& tmp_d_shape, const std::vector& seq_lens_encoder_shape, const std::vector& seq_lens_decoder_shape, const std::vector& seq_lens_this_time_shape, @@ -1033,6 +1060,9 @@ std::vector AppendAttentionWithOutputInferDtype( const paddle::DataType& qkv_dtype, const paddle::DataType& key_cache_dtype, const paddle::DataType& value_cache_dtype, + const paddle::DataType& tmp_workspace_dtype, + const paddle::DataType& tmp_m_dtype, + const paddle::DataType& tmp_d_dtype, const paddle::DataType& seq_lens_encoder_dtype, const paddle::DataType& seq_lens_decoder_dtype, const paddle::DataType& seq_lens_this_time_dtype, @@ -1091,6 +1121,9 @@ PD_BUILD_STATIC_OP(append_attention) .Inputs({"qkv", "key_cache", "value_cache", + "tmp_workspace", + "tmp_m", + "tmp_d", "seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time", @@ -1152,6 +1185,9 @@ PD_BUILD_STATIC_OP(append_attention_with_output) .Inputs({"qkv", "key_cache", "value_cache", + "tmp_workspace", + "tmp_m", + "tmp_d", "seq_lens_encoder", "seq_lens_decoder", 
"seq_lens_this_time", diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh index 95778e35a01..c5d90de0ba7 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh @@ -18,11 +18,15 @@ template void CascadeAppendAttentionC16Kernel( const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + const paddle::Tensor& + qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -35,9 +39,8 @@ void CascadeAppendAttentionC16Kernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -99,6 +102,9 @@ void CascadeAppendAttentionC16Kernel( qkv, cache_k, cache_v, + tmp_workspace, + tmp_m, + tmp_d, attn_mask, shift_bias, smooth_weight, @@ -127,13 +133,17 @@ void CascadeAppendAttentionC16Kernel( })})})})})}) } -template void CascadeAppendAttentionC16Kernel( +template void +CascadeAppendAttentionC16Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + 
paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -146,9 +156,8 @@ template void CascadeAppendAttentionC16Kernel& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -174,13 +183,17 @@ template void CascadeAppendAttentionC16Kernel( +template void +CascadeAppendAttentionC16Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -193,9 +206,8 @@ template void CascadeAppendAttentionC16Kernel& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -228,6 +240,9 @@ template void CascadeAppendAttentionC16Kernel( cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, 
head_dim] @@ -240,9 +255,8 @@ template void CascadeAppendAttentionC16Kernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -272,24 +286,26 @@ template void CascadeAppendAttentionC16Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] + cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] + cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] + cache_k_scale, // [num_kv_heads, head_dim] const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] + cache_v_scale, // [num_kv_heads, head_dim] const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] + cache_k_zp, // [num_kv_heads, head_dim] const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] + cache_v_zp, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] + shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -315,13 +331,17 @@ template void CascadeAppendAttentionC16Kernel( 
paddle::Tensor* out, const int sliding_window); -template void CascadeAppendAttentionC16Kernel( +template void +CascadeAppendAttentionC16Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -334,9 +354,8 @@ template void CascadeAppendAttentionC16Kernel& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -369,6 +388,9 @@ template void CascadeAppendAttentionC16Kernel( cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -381,9 +403,8 @@ template void CascadeAppendAttentionC16Kernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh 
b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh index 8c63ad2ef3b..d924d81fb7c 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh @@ -18,11 +18,15 @@ template void CascadeAppendAttentionC4Kernel( const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + const paddle::Tensor& + qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -35,9 +39,8 @@ void CascadeAppendAttentionC4Kernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -99,6 +102,9 @@ void CascadeAppendAttentionC4Kernel( qkv, cache_k, cache_v, + tmp_workspace, + tmp_m, + tmp_d, attn_mask, cache_k_scale.get(), cache_v_scale.get(), @@ -131,13 +137,17 @@ void CascadeAppendAttentionC4Kernel( })})})})})}) } -template void CascadeAppendAttentionC4Kernel( +template void +CascadeAppendAttentionC4Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const 
paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -150,9 +160,8 @@ template void CascadeAppendAttentionC4Kernel const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -178,13 +187,17 @@ template void CascadeAppendAttentionC4Kernel paddle::Tensor* out, const int sliding_window); -template void CascadeAppendAttentionC4Kernel( +template void +CascadeAppendAttentionC4Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -197,9 +210,8 @@ template void CascadeAppendAttentionC4Kernel& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -232,6 +244,9 @@ template void CascadeAppendAttentionC4Kernel( cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& 
cache_k_scale, // [num_kv_heads, head_dim] @@ -244,9 +259,8 @@ template void CascadeAppendAttentionC4Kernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -279,6 +293,9 @@ template void CascadeAppendAttentionC4Kernel( cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -291,9 +308,8 @@ template void CascadeAppendAttentionC4Kernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -319,13 +335,17 @@ template void CascadeAppendAttentionC4Kernel( paddle::Tensor* out, const int sliding_window); -template void CascadeAppendAttentionC4Kernel( +template void +CascadeAppendAttentionC4Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, 
head_dim] @@ -338,9 +358,8 @@ template void CascadeAppendAttentionC4Kernel& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -373,6 +392,9 @@ template void CascadeAppendAttentionC4Kernel( cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -385,9 +407,8 @@ template void CascadeAppendAttentionC4Kernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh index e748afb96a4..160ddd0e992 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh @@ -18,11 +18,15 @@ template void CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + const paddle::Tensor& + qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, 
num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -35,9 +39,8 @@ void CascadeAppendAttentionC8Kernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -87,58 +90,67 @@ void CascadeAppendAttentionC8Kernel( block_size, BLOCK_SIZE, {DISPATCH_BLOCKSHAPE_Q( - block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, { - DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, { - MultiQueryAppendC8Attention( - meta_data, - qkv, - cache_k, - cache_v, - attn_mask, - cache_k_scale.get(), - cache_v_scale.get(), - shift_bias, - smooth_weight, - sinks, + block_shape_q, + BLOCK_SHAPE_Q, + NUM_WARP_Q, + {DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, { + MultiQueryAppendC8Attention( + meta_data, + qkv, + cache_k, + cache_v, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_k_scale.get(), + cache_v_scale.get(), + shift_bias, + smooth_weight, + sinks, seq_lens_q, - seq_lens_kv, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_table, - batch_ids, - tile_ids_per_batch, - num_blocks, - max_seq_len, - max_dec_len, - quant_max_bound, - quant_min_bound, - in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - is_decoder, - stream, - out, + seq_lens_kv, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + max_seq_len, + max_dec_len, + quant_max_bound, + quant_min_bound, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + is_decoder, + 
stream, + out, sliding_window); - })})})})})})}) + })})})})})})}) } -template void CascadeAppendAttentionC8Kernel( +template void +CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, @@ -173,11 +185,15 @@ template void CascadeAppendAttentionC8Kernel( +template void +CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, @@ -212,11 +228,15 @@ template void CascadeAppendAttentionC8Kernel( +template void +CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, @@ -251,11 +271,15 @@ template void CascadeAppendAttentionC8Kernel( +template void +CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, @@ -295,6 +319,9 @@ template void CascadeAppendAttentionC8Kernel( const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + 
paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, @@ -334,6 +361,9 @@ template void CascadeAppendAttentionC8Kernel( const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, @@ -368,11 +398,15 @@ template void CascadeAppendAttentionC8Kernel( paddle::Tensor* out, const int sliding_window); -template void CascadeAppendAttentionC8Kernel( +template void +CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, @@ -407,11 +441,15 @@ template void CascadeAppendAttentionC8Kernel( +template void +CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, @@ -446,11 +484,15 @@ template void CascadeAppendAttentionC8Kernel( +template void +CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, @@ -485,11 +527,15 @@ template void CascadeAppendAttentionC8Kernel( +template 
void +CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, @@ -529,6 +575,9 @@ template void CascadeAppendAttentionC8Kernel( const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, @@ -568,6 +617,9 @@ template void CascadeAppendAttentionC8Kernel( const paddle::Tensor& qkv, const paddle::Tensor& cache_k, const paddle::Tensor& cache_v, + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, const paddle::optional& cache_v_scale, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h index d43a89c70d3..d33df94dabf 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h @@ -13,11 +13,11 @@ // limitations under the License. 
#pragma once -#include "helper.h" -#include "utils.cuh" #include "append_attention_c16_impl.cuh" -#include "append_attention_c8_impl.cuh" #include "append_attention_c4_impl.cuh" +#include "append_attention_c8_impl.cuh" +#include "helper.h" +#include "utils.cuh" template void CascadeAppendAttentionKernel( @@ -27,6 +27,9 @@ void CascadeAppendAttentionKernel( cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + paddle::Tensor& tmp_workspace, + paddle::Tensor& tmp_m, + paddle::Tensor& tmp_d, const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] @@ -39,9 +42,8 @@ void CascadeAppendAttentionKernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::optional& - sinks, // [num_heads] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -67,159 +69,173 @@ void CascadeAppendAttentionKernel( cudaStream_t& stream, paddle::Tensor* out, const int sliding_window) { - if (cache_quant_type_str == "none") { - CascadeAppendAttentionC16Kernel(meta_data, - qkv, - cache_k, - cache_v, - attn_mask, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - shift_bias, - smooth_weight, - sinks, - seq_lens_q, - seq_lens_kv, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_table, - batch_ids, - tile_ids_per_batch, - num_blocks, - block_shape_q, - max_seq_len, - max_dec_len, - quant_max_bound, - quant_min_bound, - in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - is_decoder, - enable_prefill, - stream, - out, - sliding_window); - } else if (cache_quant_type_str == "cache_int8") { - CascadeAppendAttentionC8Kernel(meta_data, - qkv, - 
cache_k, - cache_v, - attn_mask, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - shift_bias, - smooth_weight, - sinks, - seq_lens_q, - seq_lens_kv, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_table, - batch_ids, - tile_ids_per_batch, - num_blocks, - block_shape_q, - max_seq_len, - max_dec_len, - quant_max_bound, - quant_min_bound, - in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - is_decoder, - enable_prefill, - cache_quant_type_str, - stream, - out, - sliding_window); - } else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") { - CascadeAppendAttentionC8Kernel(meta_data, - qkv, - cache_k, - cache_v, - attn_mask, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - shift_bias, - smooth_weight, - sinks, - seq_lens_q, - seq_lens_kv, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_table, - batch_ids, - tile_ids_per_batch, - num_blocks, - block_shape_q, - max_seq_len, - max_dec_len, - quant_max_bound, - quant_min_bound, - in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - is_decoder, - enable_prefill, - cache_quant_type_str, - stream, - out, - sliding_window); - } else if (cache_quant_type_str == "cache_int4_zp") { - CascadeAppendAttentionC4Kernel(meta_data, - qkv, - cache_k, - cache_v, - attn_mask, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - shift_bias, - smooth_weight, - sinks, - seq_lens_q, - seq_lens_kv, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_table, - batch_ids, - tile_ids_per_batch, - num_blocks, - block_shape_q, - max_seq_len, - max_dec_len, - quant_max_bound, - quant_min_bound, - in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - is_decoder, - enable_prefill, - stream, - out, - sliding_window); - } else { - PD_THROW( - "cache_quant_type_str should 
be one of [none, cache_int8, " - "cache_int4_zp]"); - } + if (cache_quant_type_str == "none") { + CascadeAppendAttentionC16Kernel(meta_data, + qkv, + cache_k, + cache_v, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + shift_bias, + smooth_weight, + sinks, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + block_shape_q, + max_seq_len, + max_dec_len, + quant_max_bound, + quant_min_bound, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + is_decoder, + enable_prefill, + stream, + out, + sliding_window); + } else if (cache_quant_type_str == "cache_int8") { + CascadeAppendAttentionC8Kernel( + meta_data, + qkv, + cache_k, + cache_v, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + shift_bias, + smooth_weight, + sinks, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + block_shape_q, + max_seq_len, + max_dec_len, + quant_max_bound, + quant_min_bound, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + is_decoder, + enable_prefill, + cache_quant_type_str, + stream, + out, + sliding_window); + } else if (cache_quant_type_str == "cache_fp8" or + cache_quant_type_str == "block_wise_fp8") { + CascadeAppendAttentionC8Kernel(meta_data, + qkv, + cache_k, + cache_v, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + shift_bias, + smooth_weight, + sinks, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + block_shape_q, + max_seq_len, + max_dec_len, + quant_max_bound, + 
quant_min_bound, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + is_decoder, + enable_prefill, + cache_quant_type_str, + stream, + out, + sliding_window); + } else if (cache_quant_type_str == "cache_int4_zp") { + CascadeAppendAttentionC4Kernel(meta_data, + qkv, + cache_k, + cache_v, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + shift_bias, + smooth_weight, + sinks, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + block_shape_q, + max_seq_len, + max_dec_len, + quant_max_bound, + quant_min_bound, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + is_decoder, + enable_prefill, + stream, + out, + sliding_window); + } else { + PD_THROW( + "cache_quant_type_str should be one of [none, cache_int8, " + "cache_int4_zp]"); + } } diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index e3f03b98e83..9731fae50c0 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -852,6 +852,9 @@ void MultiQueryAppendAttention( const paddle::Tensor &qkv, const paddle::Tensor &cache_k, const paddle::Tensor &cache_v, + paddle::Tensor &tmp_workspace, + paddle::Tensor &tmp_m, + paddle::Tensor &tmp_d, const paddle::optional &attn_mask, const paddle::optional &shift_bias, const paddle::optional &smooth_weight, @@ -892,7 +895,7 @@ void MultiQueryAppendAttention( constexpr uint32_t num_frags_y = HEAD_DIM / 16; constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; - auto *allocator = paddle::GetAllocator(qkv.place()); + // auto *allocator = paddle::GetAllocator(qkv.place()); const float scale = 1.f / 
sqrt(HEAD_DIM); @@ -988,32 +991,6 @@ void MultiQueryAppendAttention( sliding_window); } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (ENABLE_PREFILL) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - launchWithPdlWhenEnabled( split_kv_kernel, grids, @@ -1047,9 +1024,9 @@ void MultiQueryAppendAttention( in_scale, chunk_size, num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast(const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), reinterpret_cast(out->data()), speculate_max_draft_token_num, sliding_window); @@ -1073,9 +1050,9 @@ void MultiQueryAppendAttention( blocks_merge, 0, stream, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast(const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), @@ -1208,44 +1185,6 @@ void MultiQueryAppendAttention( attn_mask_len, sliding_window); } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if 
(is_decoder) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - } else { - if (ENABLE_PREFILL) { - tmp_workspace = - allocator->Allocate(phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * - num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - } launchWithPdlWhenEnabled( split_kv_kernel, grids, @@ -1282,9 +1221,9 @@ void MultiQueryAppendAttention( in_scale, chunk_size, num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast(const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), reinterpret_cast(out->data()), speculate_max_draft_token_num, attn_mask_len, @@ -1309,9 +1248,10 @@ void MultiQueryAppendAttention( blocks_merge, 0, stream, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast( + const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + 
static_cast(tmp_d.data()), seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), @@ -1352,9 +1292,10 @@ void MultiQueryAppendAttention( blocks_merge, 0, stream, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast( + const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h index e11bdb50568..b74bb59da41 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h @@ -29,6 +29,9 @@ void MultiQueryAppendAttention( const paddle::Tensor &qkv, const paddle::Tensor &cache_k, const paddle::Tensor &cache_v, + paddle::Tensor &tmp_workspace, + paddle::Tensor &tmp_m, + paddle::Tensor &tmp_d, const paddle::optional &attn_mask, const paddle::optional &shift_bias, const paddle::optional &smooth_weight, diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh index 9748010b452..897a85b955e 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh @@ -1027,6 +1027,9 @@ void MultiQueryAppendC4Attention( const paddle::Tensor &qkv, const paddle::Tensor &cache_k, const paddle::Tensor &cache_v, + paddle::Tensor &tmp_workspace, + paddle::Tensor &tmp_m, + paddle::Tensor &tmp_d, const paddle::optional &attn_mask, const paddle::Tensor &cache_k_scale, const paddle::Tensor &cache_v_scale, @@ -1071,7 +1074,7 @@ void MultiQueryAppendC4Attention( constexpr uint32_t num_frags_y = HEAD_DIM / 16; constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; - auto *allocator = paddle::GetAllocator(qkv.place()); + // 
auto *allocator = paddle::GetAllocator(qkv.place()); const float scale = 1.f / sqrt(HEAD_DIM); @@ -1189,31 +1192,6 @@ void MultiQueryAppendC4Attention( speculate_max_draft_token_num, sliding_window); } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (ENABLE_PREFILL) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } launchWithPdlWhenEnabled( split_kv_kernel, grids, @@ -1256,9 +1234,9 @@ void MultiQueryAppendC4Attention( in_scale, chunk_size, num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast(const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), reinterpret_cast(out->data()), speculate_max_draft_token_num, sliding_window); @@ -1281,9 +1259,9 @@ void MultiQueryAppendC4Attention( blocks_merge, 0, stream, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast(const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), @@ -1437,44 +1415,6 @@ void 
MultiQueryAppendC4Attention( attn_mask_len, sliding_window); } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (is_decoder) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - } else { - if (ENABLE_PREFILL) { - tmp_workspace = - allocator->Allocate(phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * - num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - } launchWithPdlWhenEnabled( split_kv_kernel, grids, @@ -1520,9 +1460,9 @@ void MultiQueryAppendC4Attention( in_scale, chunk_size, num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast(const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), reinterpret_cast(out->data()), speculate_max_draft_token_num, attn_mask_len, @@ -1545,9 +1485,10 @@ void MultiQueryAppendC4Attention( blocks_merge, 0, stream, - reinterpret_cast(tmp_workspace->ptr()), - 
static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast( + const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), @@ -1587,9 +1528,10 @@ void MultiQueryAppendC4Attention( blocks_merge, 0, stream, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast( + const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_kernel.h index 1084e007434..ac2a8e1b661 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_kernel.h +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_kernel.h @@ -29,6 +29,9 @@ void MultiQueryAppendC4Attention( const paddle::Tensor &qkv, const paddle::Tensor &cache_k, const paddle::Tensor &cache_v, + paddle::Tensor &tmp_workspace, + paddle::Tensor &tmp_m, + paddle::Tensor &tmp_d, const paddle::optional &attn_mask, const paddle::Tensor &cache_k_scale, const paddle::Tensor &cache_v_scale, diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh index 59a838373e7..a0d80474ce4 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh @@ -1120,6 +1120,9 @@ void MultiQueryAppendC8Attention( const paddle::Tensor &qkv, const paddle::Tensor &cache_k, const paddle::Tensor &cache_v, + paddle::Tensor &tmp_workspace, + paddle::Tensor &tmp_m, + paddle::Tensor &tmp_d, const paddle::optional &attn_mask, const paddle::Tensor &cache_k_scale, const paddle::Tensor &cache_v_scale, @@ -1162,7 +1165,7 @@ void MultiQueryAppendC8Attention( constexpr uint32_t num_frags_y = 
HEAD_DIM / 16; constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; - auto *allocator = paddle::GetAllocator(qkv.place()); + // auto *allocator = paddle::GetAllocator(qkv.place()); const float scale = 1.f / sqrt(HEAD_DIM); bool is_scale_channel_wise = false; @@ -1316,31 +1319,6 @@ void MultiQueryAppendC8Attention( speculate_max_draft_token_num, sliding_window); } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (ENABLE_PREFILL) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } launchWithPdlWhenEnabled( split_kv_kernel, grids, @@ -1377,9 +1355,9 @@ void MultiQueryAppendC8Attention( in_scale, chunk_size, num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast(const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), reinterpret_cast(out->data()), speculate_max_draft_token_num, sliding_window); @@ -1401,9 +1379,9 @@ void MultiQueryAppendC8Attention( blocks_merge, 0, stream, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + 
reinterpret_cast(const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), @@ -1589,44 +1567,6 @@ void MultiQueryAppendC8Attention( attn_mask_len, sliding_window); } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (is_decoder) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - } else { - if (ENABLE_PREFILL) { - tmp_workspace = - allocator->Allocate(phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * - num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - } launchWithPdlWhenEnabled( split_kv_kernel, grids, @@ -1666,9 +1606,9 @@ void MultiQueryAppendC8Attention( in_scale, chunk_size, num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast(const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), 
reinterpret_cast(out->data()), speculate_max_draft_token_num, attn_mask_len, @@ -1692,9 +1632,10 @@ void MultiQueryAppendC8Attention( blocks_merge, 0, stream, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast( + const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), @@ -1734,9 +1675,10 @@ void MultiQueryAppendC8Attention( blocks_merge, 0, stream, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), + reinterpret_cast( + const_cast(tmp_workspace.data())), + static_cast(tmp_m.data()), + static_cast(tmp_d.data()), seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_kernel.h index 56d84e5b5bf..36ea08801ff 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_kernel.h +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_kernel.h @@ -31,6 +31,9 @@ void MultiQueryAppendC8Attention( const paddle::Tensor &qkv, const paddle::Tensor &cache_k, const paddle::Tensor &cache_v, + paddle::Tensor &tmp_workspace, + paddle::Tensor &tmp_m, + paddle::Tensor &tmp_d, const paddle::optional &attn_mask, const paddle::Tensor &cache_k_scale, const paddle::Tensor &cache_v_scale, diff --git a/custom_ops/gpu_ops/append_attn/template_config.json b/custom_ops/gpu_ops/append_attn/template_config.json index c462afe07ac..820d961214c 100644 --- a/custom_ops/gpu_ops/append_attn/template_config.json +++ b/custom_ops/gpu_ops/append_attn/template_config.json @@ -36,7 +36,7 @@ ], "max_instances_per_file": 80, "file_prefix": "multiquery_attention_c8_", - "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const 
paddle::Tensor &cache_v,\n const paddle::optional &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::optional &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n" + "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n paddle::Tensor &tmp_workspace,\n paddle::Tensor &tmp_m,\n paddle::Tensor &tmp_d,\n const paddle::optional &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::optional &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int 
encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n" }, "multiquery_attention_c4": { "name": "multiquery_attention_c4", @@ -71,7 +71,7 @@ ], "max_instances_per_file": 160, "file_prefix": "multiquery_attention_c4_", - "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional &cache_k_zp,\n const paddle::optional &cache_v_zp,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::optional &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n" + "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n paddle::Tensor &tmp_workspace,\n paddle::Tensor &tmp_m,\n paddle::Tensor &tmp_d,\n const paddle::optional &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional &cache_k_zp,\n const 
paddle::optional &cache_v_zp,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::optional &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n" }, "multiquery_attention_c16": { "name": "multiquery_attention_c16", @@ -106,7 +106,7 @@ ], "max_instances_per_file": 160, "file_prefix": "multiquery_attention_c16_", - "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional &attn_mask,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::optional &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t 
&stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n" + "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n paddle::Tensor &tmp_workspace,\n paddle::Tensor &tmp_m,\n paddle::Tensor &tmp_d,\n const paddle::optional &attn_mask,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::optional &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n" }, "multiquery_decoder_attention": { "name": "multiquery_decoder_attention", diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 99880fa80b6..6fa8f945542 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -467,6 +467,24 @@ def recycle_gpu_blocks(self, gpu_block_ids): """ recycle gpu blocks. 
""" + if ( + hasattr(self, "prefix_tree_status_signal") + and self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL + ): + # Prefix Tree Clearing, skip recycle gpu blocks + logger.warning("Prefix tree is not normal, skip recycle gpu blocks") + return + if not isinstance(gpu_block_ids, list): + gpu_block_ids = [gpu_block_ids] + if len(self.gpu_free_block_list) + len(gpu_block_ids) > self.num_gpu_blocks: + # The block allocation and recycling are abnormal, and the test results are not convincing + logger.error( + f"The number of free gpu blocks {len(self.gpu_free_block_list)} plus the number of recycled " + f"gpu blocks {len(gpu_block_ids)} exceeds the total number of gpu blocks {self.num_gpu_blocks} \n" + f"this indicates a block allocation and deallocation error, recycled blocks will be discarded {gpu_block_ids}" + ) + return + logger.info( f"recycle_gpu_blocks: {gpu_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}" ) @@ -1777,32 +1795,33 @@ def recv_data_transfer_result(self): logger.error(f"recv_data_transfer_result: {str(traceback.format_exc())}") raise e - def reset(self): + def reset(self, wait_for_tasks_done=False): """ Reset the RadixTree. 
""" - logger.info(f"wait for cache_task_inflight_signal to reset {self.cache_task_inflight_signal.value}") - while np.sum(self.cache_task_inflight_signal.value) != 0: - time.sleep(0.1) - logger.info("wait for recv_data_transfer_result done") - while not self.cache_task_queue.result_queue_empty(): - time.sleep(0.1) + if wait_for_tasks_done: + logger.info(f"wait for cache_task_inflight_signal to reset: {self.cache_task_inflight_signal.value}") + while np.sum(self.cache_task_inflight_signal.value) != 0: + time.sleep(0.1) + + logger.info("wait for recv_data_transfer_result done") + while not self.cache_task_queue.result_queue_empty(): + time.sleep(0.1) + + logger.info("wait for cpu_free_future to finish") + if self.cpu_free_future is not None: + self.cpu_free_future.result() + + logger.info("wait for gpu_free_task_future to finish") + if self.gpu_free_task_future is not None: + self.gpu_free_task_future.result() logger.info(f"Resetting the RadixTree! node_map len {len(self.node_map)}") - logger.info("waiting for cpu_free_future to finish") - if self.cpu_free_future is not None: - self.cpu_free_future.result() + # clear future & events self.cpu_free_future = None - logger.info("reset cpu_free_future") - - logger.info("waiting for gpu_free_task_future to finish") - if self.gpu_free_task_future is not None: - self.gpu_free_task_future.result() self.gpu_free_task_future = None - logger.info("reset gpu_free_task_future") - self.task_swapping_event.clear() # clear node map @@ -1847,10 +1866,11 @@ def clear_prefix_cache(self): prefix_tree_status_signal = self.prefix_tree_status_signal while True: if prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARING: - self.reset() + self.reset(wait_for_tasks_done=True) prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARED logger.info("Prefix cache tree is cleared.") if prefix_tree_status_signal.value[0] == PrefixTreeStatus.UPDATING: + self.reset(wait_for_tasks_done=False) prefix_tree_status_signal.value[0] = 
PrefixTreeStatus.NORMAL logger.info("Prefix cache tree is updated.") time.sleep(0.01) diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index f4e5fb39202..bab0f9a1177 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -535,6 +535,8 @@ def launch_api_server() -> None: api_server_logger.info(f"args: {args.__dict__}") fd_start_span("FD_START") + # Set control_socket_disable=True to prevent gunicorn.ctl file conflicts when multiple + # instances (e.g., Prefill and Decode services) run in the same directory (gunicorn 25.1.0+) options = { "bind": f"{args.host}:{args.port}", "workers": args.workers, @@ -542,6 +544,7 @@ def launch_api_server() -> None: "loglevel": "info", "graceful_timeout": args.timeout_graceful_shutdown, "timeout": args.timeout, + "control_socket_disable": True, } try: diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index eb6147561fb..36bccbf16ad 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -70,6 +70,7 @@ def allocate_launch_related_buffer( num_heads, kv_num_heads, block_size, + head_dim, ): # Initialize AttentionBackend buffers assert num_heads % kv_num_heads == 0 @@ -101,6 +102,16 @@ def allocate_launch_related_buffer( res["kv_batch_ids"] = paddle.full([kv_max_tile_size], 0, dtype="int32") res["kv_tile_ids_per_batch"] = paddle.full([kv_max_tile_size], 0, dtype="int32") res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + + max_partition_size = int(os.getenv("FLAGS_max_partition_size", 1024)) + max_num_chunk = (max_model_len + max_partition_size - 1) // max_partition_size + res["tmp_workspace"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads * head_dim], + 0, + 
dtype=paddle.get_default_dtype(), + ) + res["tmp_m"] = paddle.full([max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32") + res["tmp_d"] = paddle.full([max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32") return res @@ -313,6 +324,9 @@ def forward_mixed( qkv, cache_k, cache_v, + forward_meta.tmp_workspace, + forward_meta.tmp_m, + forward_meta.tmp_d, forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, @@ -370,6 +384,9 @@ def forward_mixed( qkv, cache_k, cache_v, + forward_meta.tmp_workspace, + forward_meta.tmp_m, + forward_meta.tmp_d, forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 37516129554..0f8303718c2 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -307,6 +307,9 @@ def forward_mixed( qkv, forward_meta.caches[2 * layer.layer_id], forward_meta.caches[2 * layer.layer_id + 1], + forward_meta.tmp_workspace, + forward_meta.tmp_m, + forward_meta.tmp_d, forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index bee89138f7a..cf8ff1d83e1 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -294,6 +294,9 @@ def forward_mixed( qkv, forward_meta.caches[2 * layer.layer_id], forward_meta.caches[2 * layer.layer_id + 1], + forward_meta.tmp_workspace, + forward_meta.tmp_m, + forward_meta.tmp_d, forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, 
forward_meta.seq_lens_this_time, diff --git a/fastdeploy/model_executor/layers/attention/ops/append_attention.py b/fastdeploy/model_executor/layers/attention/ops/append_attention.py index 05d36d0e4ef..547cfc25500 100644 --- a/fastdeploy/model_executor/layers/attention/ops/append_attention.py +++ b/fastdeploy/model_executor/layers/attention/ops/append_attention.py @@ -33,6 +33,9 @@ def append_attention( qkv: paddle.Tensor, key_cache: paddle.Tensor, value_cache: paddle.Tensor, + tmp_workspace: paddle.Tensor, + tmp_m: paddle.Tensor, + tmp_d: paddle.Tensor, seq_lens_encoder: paddle.Tensor, seq_lens_decoder: paddle.Tensor, seq_lens_this_time: paddle.Tensor, @@ -92,6 +95,9 @@ def append_attention( qkv, key_cache, value_cache, + tmp_workspace, + tmp_m, + tmp_d, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, @@ -155,6 +161,9 @@ def append_attention_with_output( qkv: paddle.Tensor, key_cache: paddle.Tensor, value_cache: paddle.Tensor, + tmp_workspace: paddle.Tensor, + tmp_m: paddle.Tensor, + tmp_d: paddle.Tensor, seq_lens_encoder: paddle.Tensor, seq_lens_decoder: paddle.Tensor, seq_lens_this_time: paddle.Tensor, @@ -215,6 +224,9 @@ def append_attention_with_output( qkv, key_cache, value_cache, + tmp_workspace, + tmp_m, + tmp_d, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index 6c44b4b263f..17d920e884f 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -389,6 +389,11 @@ def split_request_id(self, request_id: str): rollout_id = reversed_tmp_str[-1][::-1] return rollout_id + def clear_all_request(self): + """Clear all requests""" + self.routing_replay_table.fill_(-1) + self.routing_batch_to_request = {} + class StoreWrapper(object): def __init__(self, fd_config: False) -> None: diff --git a/fastdeploy/rl/rollout_model.py 
b/fastdeploy/rl/rollout_model.py index d343876a3fa..1171cceffca 100644 --- a/fastdeploy/rl/rollout_model.py +++ b/fastdeploy/rl/rollout_model.py @@ -593,10 +593,10 @@ def __init__(self, fd_config: FDConfig): fd_config.parallel_config.ep_group = dist.get_group( fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET ) - self.fd_config.parallel_config.tp_group = dist.get_group( + self.mtp_fd_config.parallel_config.tp_group = dist.get_group( fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET ) - self.fd_config.parallel_config.ep_group = dist.get_group( + self.mtp_fd_config.parallel_config.ep_group = dist.get_group( fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET ) self.update_mtp_config(self.mtp_fd_config) diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index c28242318db..f335d51afe8 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -1210,7 +1210,7 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F if substep != self.num_model_steps - 1: self._get_self_hidden_states(hidden_states) else: - if hasattr(self.model, "empty_input_forward"): + if hasattr(self.model, "empty_input_forward") and not is_dummy_run: self.model.empty_input_forward(forward_meta=self.forward_meta) def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False): @@ -1300,7 +1300,7 @@ def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = Fa if substep != self.num_model_steps - 1: self._get_self_hidden_states(hidden_states) else: - if hasattr(self.model, "empty_input_forward"): + if hasattr(self.model, "empty_input_forward") and not is_dummy_run: self.model.empty_input_forward(self.forward_meta) def _get_self_hidden_states(self, hidden_states): diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index b5dc565b254..021c6f850fd 100644 --- 
a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1871,6 +1871,7 @@ def _initialize_attn_backend(self) -> None: num_heads=num_heads, kv_num_heads=self.model_config.kv_num_heads, block_size=self.fd_config.cache_config.block_size, + head_dim=head_dim, ) self.share_inputs.update(res_buffer) @@ -2861,6 +2862,10 @@ def clear_requests(self): self.in_progress_prompt_logprobs.clear() self.forward_batch_reqs_list = [None for _ in range(self.scheduler_config.max_num_seqs)] + # Routing Replay + if self.routing_replay_manager: + self.routing_replay_manager.clear_all_request() + def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL""" # Update parameters diff --git a/tests/layers/test_append_attention.py b/tests/layers/test_append_attention.py index 01ad4bb932b..35110917c3c 100644 --- a/tests/layers/test_append_attention.py +++ b/tests/layers/test_append_attention.py @@ -510,6 +510,20 @@ def init_tensor(self): self.decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") self.decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() + decoder_max_partition_size = 32768 + decoder_step_token_num = self.seq_len + max_num_chunk = (self.max_seq_len + decoder_max_partition_size - 1) // decoder_max_partition_size + self.tmp_workspace = paddle.full( + [self.batch_size * decoder_step_token_num, max_num_chunk, self.q_num_head * self.dim_head], + 0, + dtype=paddle.get_default_dtype(), + ) + self.tmp_m = paddle.full( + [self.batch_size * decoder_step_token_num, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + self.tmp_d = paddle.full( + [self.batch_size * decoder_step_token_num, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") @@ -638,6 +652,9 @@ def 
cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask qkv_copy = copy.deepcopy(qkv) append_attention( qkv_copy, self.cache_k_T, self.cache_v_T, + self.tmp_workspace, + self.tmp_m, + self.tmp_d, self.seq_lens_encoder, @@ -701,6 +718,9 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask start_time = time.time() out = append_attention( qkv, self.cache_k, self.cache_v, + self.tmp_workspace, + self.tmp_m, + self.tmp_d, self.seq_lens_encoder, diff --git a/tests/layers/test_append_attention_with_output.py b/tests/layers/test_append_attention_with_output.py index 6c15de17ccc..7807bc5e4ae 100644 --- a/tests/layers/test_append_attention_with_output.py +++ b/tests/layers/test_append_attention_with_output.py @@ -391,6 +391,20 @@ def init_tensor(self): self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_max_partition_size = 32768 + decoder_step_token_num = self.seq_len + max_num_chunk = (self.max_seq_len + decoder_max_partition_size - 1) // decoder_max_partition_size + self.tmp_workspace = paddle.full( + [self.batch_size * decoder_step_token_num, max_num_chunk, self.q_num_head * self.dim_head], + 0, + dtype=paddle.get_default_dtype(), + ) + self.tmp_m = paddle.full( + [self.batch_size * decoder_step_token_num, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + self.tmp_d = paddle.full( + [self.batch_size * decoder_step_token_num, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) self.cache_shape = ( self.max_block_num, @@ -493,6 +507,9 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask qkv, self.cache_k, self.cache_v, + self.tmp_workspace, + self.tmp_m, + self.tmp_d, self.seq_lens_encoder, self.seq_lens_decoder, self.seq_lens_this_time, diff --git a/tests/layers/test_attention_layer.py
b/tests/layers/test_attention_layer.py index 5f10631cd82..2856e996316 100644 --- a/tests/layers/test_attention_layer.py +++ b/tests/layers/test_attention_layer.py @@ -232,6 +232,7 @@ def create_forward_meta( num_heads=fd_config.model_config.num_attention_heads, kv_num_heads=fd_config.model_config.num_key_value_heads, block_size=fd_config.cache_config.block_size, + head_dim=fd_config.model_config.head_dim, ) block_size = fd_config.cache_config.block_size diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py index 57a62044814..d1ab8b8b0e4 100644 --- a/tests/operators/test_tree_mask.py +++ b/tests/operators/test_tree_mask.py @@ -229,6 +229,15 @@ def run_append_c16_attention( kv_batch_ids = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") kv_tile_ids_per_batch = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + + max_num_chunk = (self.max_seq_len + self.max_partition_size - 1) // self.max_partition_size + tmp_workspace = paddle.full( + [self.bsz * decoder_step_token_num, max_num_chunk, self.num_q_head * self.head_dim], + 0, + dtype=paddle.get_default_dtype(), + ) + tmp_m = paddle.full([self.bsz * decoder_step_token_num, max_num_chunk, self.num_q_head], 0, dtype="float32") + tmp_d = paddle.full([self.bsz * decoder_step_token_num, max_num_chunk, self.num_q_head], 0, dtype="float32") q_norm_weight = np.ones([self.head_dim]) k_norm_weight = np.ones([self.head_dim]) self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32") @@ -263,6 +272,9 @@ def run_append_c16_attention( qkv, self.cache_k, self.cache_v, + tmp_workspace, + tmp_m, + tmp_d, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time,