From 93ef0b4fad60435f76e0351865493aad3de5a608 Mon Sep 17 00:00:00 2001
From: Linjun-AMD
Date: Mon, 2 Feb 2026 01:19:16 -0600
Subject: [PATCH 1/4] optimized some code for gptoss sink

Signed-off-by: Linjun-AMD
---
 ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 16 ++++++----
 ...ock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 16 ++++++----
 ...litkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 14 +++++++-
 ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 26 +++++++--------
 .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 16 ++++++----
 .../block_fmha_pipeline_qr_ks_vs_async.hpp    | 23 +++++++------
 ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 32 ++++++++++++-------
 7 files changed, 87 insertions(+), 56 deletions(-)

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
index 48e8f75ae7e..ad66d03d9e7 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
@@ -426,18 +426,22 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
         auto m = MLBlockTileType{};
         auto l = MLBlockTileType{};
 
-        clear_tile(o_acc);
-        if(__builtin_isinf_sign(sink_v) >= 0)
-        {
+        const auto init_m_val = [&]() {
 #if CK_TILE_FMHA_FWD_FAST_EXP2
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                          BiasEnum == BlockAttentionBiasEnum::ALIBI)
-                set_tile(m, sink_v * LOG2E * scale_s);
+                return sink_v * LOG2E * scale_s;
             else
-                set_tile(m, sink_v * LOG2E);
+                return sink_v * LOG2E;
 #else
-            set_tile(m, sink_v);
+            return sink_v;
 #endif
+        }();
+
+        clear_tile(o_acc);
+        if(__builtin_isinf_sign(sink_v) >= 0)
+        {
+            set_tile(m, init_m_val);
             set_tile(l, SMPLComputeDataType{1.0f});
         }
         else
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp
index e516fc8eea0..b8213f2110f 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp
@@ -227,18 +227,22 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
         auto m = MLBlockTileType{};
         auto l = MLBlockTileType{};
 
-        clear_tile(o_acc);
-        if(__builtin_isinf_sign(sink_v) >= 0)
-        {
+        const auto init_m_val = [&]() {
 #if CK_TILE_FMHA_FWD_FAST_EXP2
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                          BiasEnum == BlockAttentionBiasEnum::ALIBI)
-                set_tile(m, sink_v * C_LOG2E * scale_s);
+                return sink_v * C_LOG2E * scale_s;
             else
-                set_tile(m, sink_v * C_LOG2E);
+                return sink_v * C_LOG2E;
 #else
-            set_tile(m, sink_v);
+            return sink_v;
 #endif
+        }();
+
+        clear_tile(o_acc);
+        if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
+        {
+            set_tile(m, init_m_val);
             set_tile(l, SMPLComputeDataType{1.0f});
         }
         else
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
index adc8ea5a90c..0c9f000808a 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
@@ -254,10 +254,22 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
         auto m = MLBlockTileType{};
         auto l = MLBlockTileType{};
 
+        const auto init_m_val = [&]() {
+#if CK_TILE_FMHA_FWD_FAST_EXP2
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                         BiasEnum == BlockAttentionBiasEnum::ALIBI)
+                return sink_v * C_LOG2E * scale_s;
+            else
+                return sink_v * C_LOG2E;
+#else
+            return sink_v;
+#endif
+        }();
+
         clear_tile(o_acc);
         if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
         {
-            set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E});
+            set_tile(m, init_m_val);
             set_tile(l, SMPLComputeDataType{1.0f});
         }
         else
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
index c09330f8471..384cb90e27d 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
@@ -227,18 +227,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         auto m = MLBlockTileType{};
         auto l = MLBlockTileType{};
 
-        clear_tile(o_acc);
-        if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
-        {
+        const auto init_m_val = [&]() {
 #if CK_TILE_FMHA_FWD_FAST_EXP2
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                          BiasEnum == BlockAttentionBiasEnum::ALIBI)
-                set_tile(m, sink_v * C_LOG2E * scale_s);
+                return sink_v * C_LOG2E * scale_s;
             else
-                set_tile(m, sink_v * C_LOG2E);
+                return sink_v * C_LOG2E;
 #else
-            set_tile(m, sink_v);
+            return sink_v;
 #endif
+        }();
+
+        clear_tile(o_acc);
+        if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
+        {
+            set_tile(m, init_m_val);
             set_tile(l, SMPLComputeDataType{1.0f});
         }
         else
@@ -302,15 +306,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                 q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split - 1);
             if((__builtin_isinf_sign(sink_v) >= 0) && start >= end)
             {
-#if CK_TILE_FMHA_FWD_FAST_EXP2
-                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
-                             BiasEnum == BlockAttentionBiasEnum::ALIBI)
-                    set_tile(m, sink_v * C_LOG2E * scale_s);
-                else
-                    set_tile(m, sink_v * C_LOG2E);
-#else
-                set_tile(m, sink_v);
-#endif
+                set_tile(m, init_m_val);
                 set_tile(l, SMPLComputeDataType{1.0f});
             }
             else
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
index 2fbc9fdb545..999afef5035 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
@@ -238,18 +238,22 @@ struct BlockFmhaPipelineQRKSVS
         auto m = MLBlockTileType{};
         auto l = MLBlockTileType{};
 
-        clear_tile(o_acc);
-        if(__builtin_isinf_sign(sink_v) >= 0)
-        {
+        const auto init_m_val = [&]() {
 #if CK_TILE_FMHA_FWD_FAST_EXP2
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
                          BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
-                set_tile(m, sink_v * scale_s * C_LOG2E);
+                return sink_v * scale_s * C_LOG2E;
             else
-                set_tile(m, sink_v * C_LOG2E);
+                return sink_v * C_LOG2E;
 #else
-            set_tile(m, sink_v);
+            return sink_v;
 #endif
+        }();
+
+        clear_tile(o_acc);
+        if(__builtin_isinf_sign(sink_v) >= 0)
+        {
+            set_tile(m, init_m_val);
             set_tile(l, SMPLComputeDataType{1.0f});
         }
         else
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
index 81bd8d5ab52..08b3d9b30f2 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
@@ -279,22 +279,25 @@ struct BlockFmhaPipelineQRKSVSAsync
         using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
 
         // init Oacc, M, L
-        auto o_acc = OaccBlockTileType{};
-        auto m     = MLBlockTileType{};
-        auto l     = MLBlockTileType{};
-
-        clear_tile(o_acc);
-        if(__builtin_isinf_sign(sink_v) >= 0)
-        {
+        auto o_acc            = OaccBlockTileType{};
+        auto m                = MLBlockTileType{};
+        auto l                = MLBlockTileType{};
+        const auto init_m_val = [&]() {
 #if CK_TILE_FMHA_FWD_FAST_EXP2
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
                          BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
-                set_tile(m, sink_v * scale_s * LOG2E);
+                return sink_v * scale_s * LOG2E;
             else
-                set_tile(m, sink_v * LOG2E);
+                return sink_v * LOG2E;
 #else
-            set_tile(m, sink_v);
+            return sink_v;
 #endif
+        }();
+
+        clear_tile(o_acc);
+        if(__builtin_isinf_sign(sink_v) >= 0)
+        {
+            set_tile(m, init_m_val);
             set_tile(l, SMPLComputeDataType{1.0f});
         }
         else
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp
index aab79c52ae9..2fafa3a40a2 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp
@@ -193,18 +193,22 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
         auto m = MLBlockTileType{};
         auto l = MLBlockTileType{};
 
-        clear_tile(o_acc);
-        if(__builtin_isinf_sign(sink_v) >= 0)
-        {
+        const auto init_m_val = [&]() {
 #if CK_TILE_FMHA_FWD_FAST_EXP2
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                          BiasEnum == BlockAttentionBiasEnum::ALIBI)
-                set_tile(m, sink_v * C_LOG2E * scale_s);
+                return sink_v * C_LOG2E * scale_s;
             else
-                set_tile(m, sink_v * C_LOG2E);
+                return sink_v * C_LOG2E;
 #else
-            set_tile(m, sink_v);
+            return sink_v;
 #endif
+        }();
+
+        clear_tile(o_acc);
+        if(__builtin_isinf_sign(sink_v) >= 0)
+        {
+            set_tile(m, init_m_val);
             set_tile(l, SMPLComputeDataType{1.0f});
         }
         else
@@ -722,18 +726,22 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
         auto m = MLBlockTileType{};
         auto l = MLBlockTileType{};
 
-        clear_tile(o_acc);
-        if(__builtin_isinf_sign(sink_v) >= 0)
-        {
+        const auto init_m_val = [&]() {
 #if CK_TILE_FMHA_FWD_FAST_EXP2
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                          BiasEnum == BlockAttentionBiasEnum::ALIBI)
-                set_tile(m, sink_v * C_LOG2E * scale_s);
+                return sink_v * C_LOG2E * scale_s;
             else
-                set_tile(m, sink_v * C_LOG2E);
+                return sink_v * C_LOG2E;
 #else
-            set_tile(m, sink_v);
+            return sink_v;
 #endif
+        }();
+
+        clear_tile(o_acc);
+        if(__builtin_isinf_sign(sink_v) >= 0)
+        {
+            set_tile(m, init_m_val);
             set_tile(l, SMPLComputeDataType{1.0f});
         }
         else

From a812b044ae4f91abf0e0028592d248b5d4d87fa3 Mon Sep 17 00:00:00 2001
From: Linjun-AMD
Date: Mon, 2 Feb 2026 01:34:23 -0600
Subject: [PATCH 2/4] updated fmha_args

Signed-off-by: Linjun-AMD
---
 example/ck_tile/01_fmha/fmha_fwd.hpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp
index aedbb0e17c2..65c280d42dd 100644
--- a/example/ck_tile/01_fmha/fmha_fwd.hpp
+++ b/example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -232,7 +232,7 @@ struct fmha_fwd_args
     // array [batch + 1]. (Used with padding)
     const void* block_scale_seqstart_q_ptr;
     const void* block_scale_seqstart_k_ptr;
-    const void* sink_ptr;
+    const void* sink_ptr = nullptr;
 
     ck_tile::index_t seqlen_q;
     ck_tile::index_t seqlen_k;
@@ -329,7 +329,7 @@ struct fmha_fwd_pagedkv_args
     const void* seqstart_q_ptr;
     const void* seqstart_k_ptr;
     const void* seqlen_k_ptr;
-    const void* sink_ptr;
+    const void* sink_ptr = nullptr;
 
     ck_tile::index_t seqlen_q;
     ck_tile::index_t seqlen_k;
@@ -413,7 +413,7 @@ struct fmha_fwd_splitkv_args
    const void* seqstart_q_ptr;
     const void* seqstart_k_ptr;
     const void* seqlen_k_ptr;
-    const void* sink_ptr;
+    const void* sink_ptr = nullptr;
 
     ck_tile::index_t seqlen_q;
     ck_tile::index_t seqlen_k;
@@ -490,7 +490,6 @@ struct fmha_fwd_appendkv_args
 
     ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
     const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
-    const void* sink_ptr;
 
     ck_tile::index_t stride_q;
     ck_tile::index_t stride_k;
@@ -534,7 +533,7 @@ struct fmha_batch_prefill_args
     // 1) +
     // kargs.kv_last_page_lens[b]
     const void* seqstart_q_ptr;
-    const void* sink_ptr;
+    const void* sink_ptr = nullptr;
 
     ck_tile::index_t seqlen_q;
     ck_tile::index_t seqlen_k;

From 7f0d5cdcc9a6e409f6abce9a7b0d66ae9ff8497a Mon Sep 17 00:00:00 2001
From: Linjun-AMD
Date: Mon, 2 Feb 2026 01:36:34 -0600
Subject: [PATCH 3/4] update codegen for sink

Signed-off-by: Linjun-AMD
---
 example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
index b59f442663f..594758a2b7b 100644
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
@@ -1208,6 +1208,7 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
         cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"]
         cond &= kernel_ctx.pipeline.F_qscale == "no"
         cond &= kernel_ctx.pipeline.F_skip == "f"
+        cond &= kernel_ctx.pipeline.F_sink == "f"
         return cond
 
     return Product(name="Flash attention integration", rule=fit)

From c4d95db73db07519eea75839d080dfe57f2d1c1a Mon Sep 17 00:00:00 2001
From: Linjun-AMD
Date: Mon, 2 Feb 2026 15:48:26 +0800
Subject: [PATCH 4/4] Simplify condition for setting tile values

---
 .../fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp
index b8213f2110f..894d22e5f17 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp
@@ -240,7 +240,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
         }();
 
         clear_tile(o_acc);
-        if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
+        if(__builtin_isinf_sign(sink_v) >= 0)
         {
             set_tile(m, init_m_val);
             set_tile(l, SMPLComputeDataType{1.0f});
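
Note on the PATCH 1 refactor: each pipeline now computes the sink-dependent initial row maximum once, in an immediately invoked lambda (init_m_val), instead of repeating the #if CK_TILE_FMHA_FWD_FAST_EXP2 block at every place where m and l are initialized. The sketch below restates that logic with plain scalars so the intent is easier to follow; it is an illustration only, not CK Tile code. The free function, the std::isinf check (standing in for __builtin_isinf_sign), and the assumed (m = -inf, l = 0) fallback for the "no sink" case are simplifications of the tile-based pipelines.

    #include <cmath>
    #include <limits>

    // Scalar stand-ins for the per-row running softmax state (m, l) tiles.
    struct SinkInit
    {
        float m; // running max of the attention logits
        float l; // running sum of exp(logit - m)
    };

    inline SinkInit make_sink_init(float sink_v, float scale_s, bool has_bias, bool fast_exp2)
    {
        // Mirrors the hoisted init_m_val lambda: under the exp2-based softmax path the
        // sink logit is pre-scaled into log2 space (and by scale_s when a bias is applied).
        const float log2e = 1.4426950408889634f;
        const float init_m_val =
            fast_exp2 ? (has_bias ? sink_v * log2e * scale_s : sink_v * log2e) : sink_v;

        if(!(std::isinf(sink_v) && sink_v < 0.0f)) // sink present (sink_v != -inf)
        {
            // The sink contributes exp(sink - m) == 1 to the softmax denominator.
            return {init_m_val, 1.0f};
        }
        // Assumed default when no sink is given (the else branch is not shown in the diff).
        return {-std::numeric_limits<float>::infinity(), 0.0f};
    }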