From 210887f4f75adff909b21ed876c5769e43aa0b23 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 5 Jan 2026 03:04:03 -0600 Subject: [PATCH 01/15] Allow passing descales to fmha v3 kernel --- example/ck_tile/01_fmha/fmha_fwd.hpp | 6 + .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 111 +++++++++++++++--- 2 files changed, 103 insertions(+), 14 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 3ff4acfc156..cfa760f6c1e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -731,6 +731,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) return FmhaKernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, nullptr, // lse_ptr args.o_ptr, args.seqstart_q_ptr, @@ -764,6 +767,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) return FmhaKernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, nullptr, // lse_ptr args.o_ptr, args.seqlen_q, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index 6fe1de634d9..f82c5bc6526 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -27,6 +27,7 @@ struct FmhaFwdV3Kernel using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; using LSEDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; @@ -38,6 +39,7 @@ struct FmhaFwdV3Kernel static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -118,11 +120,21 @@ struct FmhaFwdV3Kernel float logits_soft_cap_rcp; }; + struct FmhaFwdCommonQScaleKargs + { + const void* q_descale_ptr = nullptr; + const void* k_descale_ptr = nullptr; + const void* v_descale_ptr = nullptr; + }; + struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -139,7 +151,10 @@ struct FmhaFwdV3Kernel : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -166,6 +181,9 @@ struct FmhaFwdV3Kernel MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, @@ -218,6 +236,7 @@ struct FmhaFwdV3Kernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for qscale {}, // placeholder for logits_soft_cap batch_stride_q, batch_stride_k, @@ -237,6 +256,12 @@ struct FmhaFwdV3Kernel kargs.nhead_stride_lse = nhead_stride_lse; kargs.batch_stride_lse = batch_stride_lse; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + 
kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } if constexpr(kHasLogitsSoftCap) { kargs.init_logits_soft_cap(logits_soft_cap); @@ -252,6 +277,9 @@ struct FmhaFwdV3Kernel MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, @@ -301,6 +329,7 @@ struct FmhaFwdV3Kernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for qscale {}, // placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), @@ -319,6 +348,12 @@ struct FmhaFwdV3Kernel kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } if constexpr(kHasLogitsSoftCap) { kargs.init_logits_soft_cap(logits_soft_cap); @@ -640,32 +675,80 @@ struct FmhaFwdV3Kernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); + const float scale_s = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); + float k_descale = *(reinterpret_cast(kargs.k_descale_ptr)); + return kargs.scale_s * q_descale * k_descale; + } + else + { + return kargs.scale_s; + } + }(); + AttentionVariant variant; const auto variant_params = [&] { if constexpr(kHasLogitsSoftCap) { return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; } else { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + return ck_tile::StandardAttentionParams{mask, scale_s}; } }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; auto o_acc_tile = [&]() { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - lse_dram_window, - mask, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr); + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + float v_descale = *(reinterpret_cast(kargs.v_descale_ptr)); + float scale_p = ck_tile::type_convert(ck_tile::numeric::max()); + float scale_o = v_descale / scale_p; + + auto o_acc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{scale_o}); + else + return ck_tile::scales{scale_o}; + }(); + + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{scale_p}, // p_compute_element_func + o_acc_element_func, + mask, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lse_dram_window, + mask, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr); + } }(); // O DRAM and O DRAM window From 8f973efed9b24c1e9377212f5babf5a3c6c4095e Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Jan 2026 00:48:51 -0600 Subject: [PATCH 02/15] Allow enabling quantization scale feature for FMHA v3 --- 
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index c25f57632fa..14dd9c8db27 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -300,7 +300,6 @@ struct BlockFmhaFwdV3Pipeline static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout && - (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && !kSkipMinSeqlenQ), "enable unsupported features"); From 5876cd86bbeaa7567f6d93f72434d12400807a42 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Jan 2026 00:50:35 -0600 Subject: [PATCH 03/15] Add fp8 32x32x32 warp gemm (C-transposed) --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 7 +++++++ include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp | 2 ++ 2 files changed, 9 insertions(+) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 7bcc9107da9..e86e5d9f86a 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -276,6 +276,13 @@ using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, 2>>; +template +using WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed = + WarpGemmImpl, + 2, + AttrNumAccess>>; + using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, 2>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index d6c21e88b56..940447cc22f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -98,6 +98,8 @@ template<> struct Dispatcher { u // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; }; From 941c7e67bbbf5f6bec7b681b5179a194fa7d396f Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Jan 2026 00:52:30 -0600 Subject: [PATCH 04/15] Add fp8 QK block gemm config --- .../block_fmha_fwd_v3_pipeline_default_policy.hpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp index ce097b6741b..957e404b35b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -239,10 +239,18 @@ struct BlockFmhaV3PipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename 
Problem::BlockFmhaShape::Gemm0WarpTile>>; - constexpr auto warp_gemm = []() { - if constexpr(std::is_same_v && - std::is_same_v && + constexpr auto warp_gemm = [] { + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + constexpr index_t swizzle_factor = 4; + return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution< + swizzle_factor>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) { /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here From fa404ca9703c0930242f8e6f8fff72a4036bc072 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Jan 2026 00:54:57 -0600 Subject: [PATCH 05/15] Add fp8 FMHA v3 instances --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 13 +++++++++++++ .../fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index dd65c0298b3..a1919e954fb 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1048,6 +1048,10 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: if (128, 128) in result.keys(): result[(128, 128)].append( FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + elif dtype in cls._DT_FP8BF16: + if (128, 128) in result.keys(): + result[(128, 128)].append( + FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1)) # fmt: skip return result @classmethod @@ -1085,6 +1089,15 @@ def get_pipelines( pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip + elif dtype in cls._DT_FP8BF16: + # no need lse/dropout kernels + # qr_async_trload_v3 only supports (generic) causal mask + for logits, qscale, mask in itertools.product( + ["t", "f"], + ["no", "pertensor"], + ["no", "causal"], + ): + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip return pipelines diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 14dd9c8db27..29c52e2fd6b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -436,7 +436,7 @@ struct BlockFmhaFwdV3Pipeline kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); + // static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc()); From c5e0c500401f9ab4601a038148e5f55c0558593a Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 14 Jan 2026 11:34:39 -0600 Subject: [PATCH 06/15] Fix fmha_fwd_v3() dispatch logic --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index a1919e954fb..4257b84fbfc 
100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -210,10 +210,12 @@ ((0 < args.window_size_left) or (0 < args.window_size_right)); const bool can_dispatch_v3 = (device_name.compare(0, 6, "gfx950") == 0) and - (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and + (((traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and + (traits.qscale_type == quant_scale_enum::no_scale)) or + ((traits.data_type.compare("fp8bf16") == 0) and + (traits.qscale_type == quant_scale_enum::pertensor))) and traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and - (not traits.has_lse) and (not traits.has_dropout) and - (traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and + (not traits.has_lse) and (not traits.has_dropout) and (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128); if ({F_is_v3_enabled} and can_dispatch_v3) {{ return fmha_fwd_v3(traits, args, config); From 1be789fc69c6c991030b205b32cf15c9e13c671a Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 20 Jan 2026 23:58:37 -0600 Subject: [PATCH 07/15] Add missing P tile fp32 -> fp8 converison logic --- .../ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 29c52e2fd6b..09ca59be420 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -853,12 +853,21 @@ struct BlockFmhaFwdV3Pipeline sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; } - else + else if constexpr(std::is_same_v) { auto casted = detail::cvt_pk_bf16_f32(x, y); sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; } + else if constexpr(std::is_same_v) + { + sp(sp_reg_idx).p.thread_buf_[idx] = type_convert(x); + sp(sp_reg_idx).p.thread_buf_[idx + 1] = type_convert(y); + } + else + { + static_assert(false, "unsupported data type for P"); + } }); /// Note: Place fmha_alu1() at the end of the phase. 
The surrounding inline assembly From 35af63aeed70d7347cb8ec8970cc9bcb3c9f2d4f Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 21 Jan 2026 00:21:20 -0600 Subject: [PATCH 08/15] Update functor creation logics --- .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index f82c5bc6526..c2e0fe0d4cc 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -712,29 +712,31 @@ struct FmhaFwdV3Kernel auto o_acc_element_func = [&]() { if constexpr(std::is_same_v) - return ck_tile::composes(ck_tile::saturates{}, - ck_tile::scales{scale_o}); + return make_composes( + ck_tile::saturates{}, + ck_tile::scales>{scale_o}); else - return ck_tile::scales{scale_o}; + return ck_tile::scales>{scale_o}; }(); - return FmhaPipeline{}(q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{scale_p}, // p_compute_element_func - o_acc_element_func, - mask, - scale_s, - variant, - variant_params, - block_indices, - smem_ptr); + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales>{scale_p}, // p_compute_element_func + o_acc_element_func, + mask, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr); } else { From c7a82048f586af11922e4624eadf3631dbf9774c Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 26 Jan 2026 04:00:49 -0600 Subject: [PATCH 09/15] Add CLAUDE.md and GEMINI.md --- CLAUDE.md | 112 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ GEMINI.md | 0 2 files changed, 112 insertions(+) create mode 100644 CLAUDE.md create mode 100644 GEMINI.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000000..697b12bbde8 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,112 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Build Commands + +```bash +# Development build (from build directory) +mkdir build && cd build +../script/cmake-ck-dev.sh .. # Uses gfx908;gfx90a;gfx942 by default +../script/cmake-ck-dev.sh .. gfx90a # Specific GPU target +make -j32 # Limit threads; ~2GB RAM per thread + +# Manual cmake configuration +cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx90a" \ + .. + +# Common build flags +-D DTYPES="fp16;fp32" # Build only specific data types (speeds up build) +-D CK_USE_FP8_ON_UNSUPPORTED_ARCH=ON # Enable FP8 on MI100/MI200 +-D BUILD_DEV=ON # Development mode with verbose errors +``` + +## Testing + +```bash +make -j check # Build and run all tests +make -j smoke # Quick tests only (<30s each) +make -j regression # Long-running tests (>=30s each) + +# CTest direct commands +ctest --output-on-failure -L "SMOKE_TEST" +ctest --output-on-failure -L "REGRESSION_TEST" + +# Run single test executable +./test/gemm/test_gemm_fp16 +``` + +## Architecture Overview + +### Two Programming Models + +1. 
**CK (Legacy)** - `/include/ck/`: Traditional template-based approach +2. **CK Tile (Modern)** - `/include/ck_tile/`: Newer unified model with simpler component structure + +CK Tile is independently maintained and preferred for new development. Include single headers like `#include "ck_tile/core.hpp"` or `#include "ck_tile/ops/fmha.hpp"`. + +### Four-Layer Architecture + +1. **Templated Tile Operators** - High-level tile abstractions +2. **Templated Kernel and Invoker** - Generic kernel templates +3. **Instantiated Kernel and Invoker** - Concrete implementations per data type/hardware +4. **Client API** - User-facing interfaces + +### Core Concepts + +- **Tile-Based Programming**: Operations work on tiles (sub-regions) of tensors +- **Tensor Coordinate Transformation**: Maps ND tensor indices to 1D memory offsets through transform primitives (merge/unmerge/embed) +- **Distributed Tensor**: Describes how threads collaboratively process a tensor tile + +### CK Tile Components + +- `core/` - Basic structures: array, tuple, sequence, numeric types, coordinate transforms +- `host/` - Kernel launch utilities, device buffers, reference implementations +- `ops/` - Operation implementations (gemm, fmha, reduce) +- `ref/` - CPU/GPU reference implementations for validation + +### Instance Organization + +`/library/src/tensor_operation_instance/gpu/` contains 100+ operation variants organized by: +- Operation type (gemm, batched_gemm, conv, etc.) +- Data type (fp16, fp32, fp8, bf16, int8) +- Instruction set (xdl, wmma, dl, dpp) + +## Key Directories + +- `/example/ck_tile/01_fmha/` - Flash Multi-Head Attention (main FMHA implementation) +- `/example/01_gemm/` - Foundational GEMM example +- `/profiler/` - Performance benchmarking tool (`make -j ckProfiler`) +- `/test/` - 68+ test directories with smoke/regression classification + +## Supported Hardware + +- **MI Series**: gfx908 (MI100), gfx90a (MI200), gfx942/gfx950 (MI300) +- **NAVI Series**: gfx1030-1032 (NAVI2x), gfx1100-1102 (NAVI3x), gfx1200-1201 (RDNA4) + +## Code Style + +Pre-commit hooks enforce formatting. Install with: +```bash +sudo script/install_precommit.sh +``` + +Bypass temporarily with `git commit --no-verify`. + +## Profiling + +```bash +make -j ckProfiler +./profiler/ckProfiler gemm_xdl -M 4096 -N 4096 -K 4096 -A fp16 -B fp16 -C fp16 +``` + +## sccache (Faster Rebuilds) + +```bash +sccache --start-server +cmake ... 
-DCMAKE_HIP_COMPILER_LAUNCHER=sccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=sccache +``` diff --git a/GEMINI.md b/GEMINI.md new file mode 100644 index 00000000000..e69de29bb2d From 298091abbf8cc18dfdbf7839abdcc130dd75d99b Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 28 Jan 2026 00:03:05 -0600 Subject: [PATCH 10/15] Ignore AI agent config dir --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 740d5464fb9..cee6233a887 100644 --- a/.gitignore +++ b/.gitignore @@ -96,3 +96,6 @@ experimental/grouped_convolution_tile_instances/instances/* !experimental/grouped_convolution_tile_instances/instances/*.in !experimental/grouped_convolution_tile_instances/instances/*.inc experimental/grouped_convolution_tile_instances/*.inc + +# AI agent config +.claude From e6015c344d603f8c0d2be35d3e4630ca0623b712 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 28 Jan 2026 00:23:15 -0600 Subject: [PATCH 11/15] Update CLAUDE.md --- CLAUDE.md | 273 ++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 202 insertions(+), 71 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 697b12bbde8..672ce56b8dc 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -2,111 +2,242 @@ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. +## Project Overview + +Composable Kernel (CK) is AMD's high-performance GPU kernel library for machine learning workloads. It uses HIP C++ with a tile-based programming model and tensor coordinate transformation techniques. + +**Two implementations exist:** +- **CK Tile** (`include/ck_tile/`) - Modern tile-programming API, preferred for new development +- **Legacy CK** (`include/ck/`) - Older implementation, still supported + ## Build Commands ```bash # Development build (from build directory) mkdir build && cd build -../script/cmake-ck-dev.sh .. # Uses gfx908;gfx90a;gfx942 by default -../script/cmake-ck-dev.sh .. gfx90a # Specific GPU target -make -j32 # Limit threads; ~2GB RAM per thread - -# Manual cmake configuration -cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ - -D CMAKE_BUILD_TYPE=Release \ - -D GPU_TARGETS="gfx90a" \ - .. - -# Common build flags --D DTYPES="fp16;fp32" # Build only specific data types (speeds up build) --D CK_USE_FP8_ON_UNSUPPORTED_ARCH=ON # Enable FP8 on MI100/MI200 --D BUILD_DEV=ON # Development mode with verbose errors -``` +../script/cmake-ck-dev.sh .. "gfx908;gfx90a;gfx942" +make -j32 # Use ~2GB RAM per thread -## Testing +# Build specific targets +make tile_example_fmha_fwd # FMHA forward example +make tile_example_fmha_bwd # FMHA backward example +make ckProfiler # Performance profiler -```bash -make -j check # Build and run all tests -make -j smoke # Quick tests only (<30s each) -make -j regression # Long-running tests (>=30s each) +# Tests +make smoke # Quick tests (<30s each) +make regression # Long tests (>=30s each) +make check # All tests + +# Single test +ctest -R "fmha" -V +``` + +**CMake options:** +- `GPU_TARGETS="gfx908;gfx90a;gfx942"` - Target GPU architectures +- `DTYPES="fp32;fp16;fp8;bf16;int8"` - Data types to build +- `BUILD_DEV=ON` - Development mode -# CTest direct commands -ctest --output-on-failure -L "SMOKE_TEST" -ctest --output-on-failure -L "REGRESSION_TEST" +## Code Formatting -# Run single test executable -./test/gemm/test_gemm_fp16 +CK uses clang-format. 
Install pre-commit hooks: +```bash +sudo script/install_precommit.sh ``` -## Architecture Overview +Disable temporarily with `git commit --no-verify`. -### Two Programming Models +## Architecture -1. **CK (Legacy)** - `/include/ck/`: Traditional template-based approach -2. **CK Tile (Modern)** - `/include/ck_tile/`: Newer unified model with simpler component structure +### Four-Layer Structure +1. **Templated Tile Operators** - Low-level tile operations +2. **Templated Kernel/Invoker** - Kernel templates with tile operators +3. **Instantiated Kernel/Invoker** - Concrete kernel instances +4. **Client API** - User-facing API -CK Tile is independently maintained and preferred for new development. Include single headers like `#include "ck_tile/core.hpp"` or `#include "ck_tile/ops/fmha.hpp"`. +### Key Directories +- `include/ck_tile/core/` - Core utilities (containers, data types, coordinate transforms) +- `include/ck_tile/ops/` - Operator implementations (fmha, gemm, softmax, etc.) +- `include/ck_tile/ops/fmha/pipeline/` - FMHA pipeline implementations (performance-critical) +- `example/ck_tile/` - Working examples with build recipes +- `codegen/` - Python-based kernel code generation +- `profiler/` - Performance profiling tools -### Four-Layer Architecture +### FMHA (Flash Attention) Architecture -1. **Templated Tile Operators** - High-level tile abstractions -2. **Templated Kernel and Invoker** - Generic kernel templates -3. **Instantiated Kernel and Invoker** - Concrete implementations per data type/hardware -4. **Client API** - User-facing interfaces +#### Directory Structure +``` +include/ck_tile/ops/fmha/ +├── kernel/ # Kernel entry points (fmha_fwd_kernel.hpp, fmha_fwd_v3_kernel.hpp) +├── pipeline/ # Pipeline implementations (performance-critical) +├── block/ # Block-level components (masking, dropout, position encoding) +└── api/ # High-level API wrappers +``` -### Core Concepts +#### Kernel Template Structure -- **Tile-Based Programming**: Operations work on tiles (sub-regions) of tensors -- **Tensor Coordinate Transformation**: Maps ND tensor indices to 1D memory offsets through transform primitives (merge/unmerge/embed) -- **Distributed Tensor**: Describes how threads collaboratively process a tensor tile +The kernel (`FmhaFwdKernel`, `FmhaFwdV3Kernel`) has two key template parameters: +- `FmhaPipeline` - Block tile pipeline handling Q*K and P*V computations +- `EpiloguePipeline` - Post-processing and output storage -### CK Tile Components +**Key data types extracted from pipeline:** +- `QDataType`, `KDataType`, `VDataType` - Input types (fp8, fp16, bf16) +- `PDataType` - Attention probability type after softmax +- `SaccDataType` - Scratch accumulator (typically float) +- `ODataType` - Output type -- `core/` - Basic structures: array, tuple, sequence, numeric types, coordinate transforms -- `host/` - Kernel launch utilities, device buffers, reference implementations -- `ops/` - Operation implementations (gemm, fmha, reduce) -- `ref/` - CPU/GPU reference implementations for validation +**Configuration flags:** +- `kIsGroupMode` - Variable-length sequences via seqstart pointers +- `kPadSeqLenQ/K`, `kPadHeadDimQ/V` - Padding control +- `kHasLogitsSoftCap` - Gemma-style logits softcap +- `kStoreLSE` - Store log-sum-exp for backward pass +- `QScaleEnum` - FP8 quantization (PERTENSOR, NONE) -### Instance Organization +#### Pipeline Implementations -`/library/src/tensor_operation_instance/gpu/` contains 100+ operation variants organized by: -- Operation type (gemm, batched_gemm, conv, 
etc.) -- Data type (fp16, fp32, fp8, bf16, int8) -- Instruction set (xdl, wmma, dl, dpp) +| Pipeline | Name | Description | +|----------|------|-------------| +| `BlockFmhaPipelineQRKSVS` | "qr" | LDS-based, all QKV in LDS. For medium sequences. | +| `BlockFmhaPipelineQRKSVSAsync` | "qr_async" | Q in registers, async K/V loading. For longer sequences. | +| `BlockFmhaFwdV3Pipeline` | "v3" | Next-gen with warp group coordination and instruction scheduling. | +| `BlockFmhaPipelineSplitKV` | - | Multi-pass with reduction for very long sequences. | +| `BlockFmhaPipelinePagedKV` | - | KV-cache paging for inference. | -## Key Directories +#### Attention Computation Flow (Online Softmax) -- `/example/ck_tile/01_fmha/` - Flash Multi-Head Attention (main FMHA implementation) -- `/example/01_gemm/` - Foundational GEMM example -- `/profiler/` - Performance benchmarking tool (`make -j ckProfiler`) -- `/test/` - 68+ test directories with smoke/regression classification +``` +Phase 1: GEMM0 (Q × K^T → S) +├── Load Q tile (M0 × K0) into registers +├── Loop over K tiles (N0 × K0): +│ ├── Async load K tile to LDS +│ ├── Sync barrier +│ └── Block GEMM with MFMA → S accumulator +└── Apply scale: S *= 1/sqrt(hdim) + +Phase 2: Online Softmax +├── Row-wise max: m_j = max(S_j) +├── Optional: logits softcap (tanh transform) +├── Exponential: P = exp(S - m_j) +├── Row-wise sum: l_j = sum(P_j) +└── Rescale accumulator: O *= exp(m_old - m_new) + +Phase 3: GEMM1 (P × V → O) +├── Convert P to compute type +├── Load V tiles (K1 × N1) +├── Block GEMM with MFMA → O accumulator +└── Finalize: O /= l_j + +Phase 4: Epilogue +├── Convert O to output type +├── Optional: store LSE = m/log(2) + log(l) +└── Write O tile to DRAM +``` -## Supported Hardware +#### Memory Management -- **MI Series**: gfx908 (MI100), gfx90a (MI200), gfx942/gfx950 (MI300) -- **NAVI Series**: gfx1030-1032 (NAVI2x), gfx1100-1102 (NAVI3x), gfx1200-1201 (RDNA4) +**LDS Layout:** +- K tiles: N0 × K0, double-buffered for async prefetch +- V tiles: K1 × N1, bank-conflict-aware padding +- Size computed via `Policy::GetSmemSize()` -## Code Style +**Async Copy Pattern:** +```cpp +async_load_tile_raw(k_lds_window, k_dram_window); // Non-blocking +move_tile_window(k_dram_window, {kN0, 0}); +// ... GEMM computation overlaps with load ... +s_waitcnt_vmcnt<0>(); // Wait before use +``` -Pre-commit hooks enforce formatting. Install with: -```bash -sudo script/install_precommit.sh +**Prefetching Strategy:** Load K[i+1] while computing with K[i] + +#### Block-Level Components + +**Masking (`block_masking.hpp`):** +- `MASK_FROM_TOP_LEFT` - Causal (lower triangular) +- `MASK_FROM_BOTTOM_RIGHT` - Future tokens +- Local attention via `window_size_left/right` +- `GenericAttentionMask::GetTileRangeAlongX()` - Skip fully masked tiles + +**Quantization (`block_attention_quant_scale_enum.hpp`):** +- `NONE` - Standard float operations +- `PERTENSOR` - Single scale per Q/K/V tensor (FP8) +- Flow: `Q_fp8 * scale → float → compute → saturate → O_fp8` + +#### Policy/Trait Configuration + +**TileFmhaTraits** - Core configuration: +```cpp +template < + bool kPadSeqLenQ, kPadSeqLenK, + bool kPadHeadDimQ, kPadHeadDimV, + bool kHasLogitsSoftCap, + BlockAttentionBiasEnum BiasEnum, + bool kStoreLSE, + bool kHasDropout, + BlockAttentionQuantScaleEnum QScaleEnum +> +struct TileFmhaTraits; ``` -Bypass temporarily with `git commit --no-verify`. 
+**Default Policy** provides: +- Alignment hints for DRAM loads +- GEMM configurations (MFMA instruction selection) +- LDS store/load descriptors +- Register tile distributions -## Profiling +#### Grid/Block Organization -```bash -make -j ckProfiler -./profiler/ckProfiler gemm_xdl -M 4096 -N 4096 -K 4096 -A fp16 -B fp16 -C fp16 +```cpp +dim3 GridSize(batch_size, nhead, ceil(max_seqlen_q / kM0) * ceil(hdim_v / kN1)); +dim3 BlockSize(kBlockSize); // Typically 256-512 threads ``` -## sccache (Faster Rebuilds) +#### V3 Pipeline Optimizations + +- **Warp Group Specialization** - 2 warp groups (4 waves each) with different roles +- **Phase Scheduling** - Explicit barriers for MFMA/VALU/TRANS timing +- **Packed FP32** - `v_pk_mul_f32` for two operations per instruction +- **Fast Exp2** - Bit manipulation approximation + +## Key Concepts + +- **Tile** - Fixed-size data chunk processed by a thread block +- **Block Tile** - Tile owned by entire thread block +- **Wave Tile** - Tile owned by a wavefront (64 threads on AMD) +- **LDS** - Local Data Share (AMD's shared memory) +- **MFMA** - Matrix Fused Multiply-Add (AMD's matrix core instruction) +- **XDL** - Crosslane Data Layout instructions + +See `TERMINOLOGY.md` and `ACRONYMS.md` for complete references. + +## Common Variable Naming + +| Symbol | Meaning | +|--------|---------| +| M, N, K | GEMM dimensions: A[M,K] × B[K,N] = C[M,N] | +| Q, K, V | Query, Key, Value (attention) | +| S | Sequence length | +| D | Head dimension | +| B | Batch size | +| H | Number of attention heads | + +## Running FMHA Examples ```bash -sccache --start-server -cmake ... -DCMAKE_HIP_COMPILER_LAUNCHER=sccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=sccache +# Basic FMHA forward +./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128 + +# With FP8 +./bin/tile_example_fmha_fwd -b=1 -h=8 -s=4096 -d=128 -prec=fp8 + +# Group mode (variable length) +./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -d=128 + +# With causal mask +./bin/tile_example_fmha_fwd -b=1 -h=8 -s=4096 -d=128 -mask=t ``` + +Use `-?` flag to see all options. + +## Codegen System + +Kernels are instantiated into separate files via Python scripts in `codegen/` to enable parallel compilation. Example: `example/ck_tile/01_fmha/codegen/generate.py`. 
From 078ee116ae1949499d1e0c39da7c168bf9cadf46 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Fri, 30 Jan 2026 09:49:28 -0600 Subject: [PATCH 12/15] [DEBUG] Use simpler V LDS descriptor #1 (without padding) --- ...ck_fmha_fwd_v3_pipeline_default_policy.hpp | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp index 957e404b35b..e8df0fda604 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -140,9 +140,10 @@ struct BlockFmhaV3PipelineDefaultPolicy constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + // Swap NumWarps and LaneGroups to store V in non-swizzled layout in LDS constexpr index_t N0 = NumIssues; - constexpr index_t N1 = LaneGroups; - constexpr index_t N2 = NumWarps; + constexpr index_t N1 = NumWarps; // was LaneGroups + constexpr index_t N2 = LaneGroups; // was NumWarps constexpr index_t K0 = LanesPerK; constexpr index_t K1 = KVector; @@ -150,7 +151,7 @@ struct BlockFmhaV3PipelineDefaultPolicy tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, + tuple, sequence<2, 0>>, sequence<1, 2>, sequence<0, 1>>{}); } @@ -315,8 +316,8 @@ struct BlockFmhaV3PipelineDefaultPolicy return BlockGemmARegBRegCRegV2{}; } - static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords - static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords + static constexpr ck_tile::index_t kKLdsPadInBytes = 0; // 4 dwords + static constexpr ck_tile::index_t kVLdsPadInBytes = 0; // 16 dwords template CK_TILE_DEVICE static constexpr auto @@ -497,13 +498,13 @@ struct BlockFmhaV3PipelineDefaultPolicy constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( make_tuple(number{}, // n0 - number{}, // n1 number{}, // n2 + number{}, // n1 number{}, // k0 number{}), // k1 make_tuple(number{}, - number{}, number{}, + number{}, number{}, number<1>{}), number<(IBuf + 2) * GetSingleSmemElementSpaceSize()>{}, @@ -518,7 +519,7 @@ struct BlockFmhaV3PipelineDefaultPolicy make_pass_through_transform(number{}), make_merge_transform(make_tuple( number{}, number{}, number{}))), - make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); return v_lds_block_desc_issues_warps_lanes; @@ -566,9 +567,9 @@ struct BlockFmhaV3PipelineDefaultPolicy v_lds_block_desc_0, make_tuple( make_merge_transform( - make_tuple(number{}, number{}, number{})), + make_tuple(number{}, number{}, number{})), make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{})); return v_lds_block_desc; From 7e2c9f4e4d66368ea369c4f8c2c9d9b5d5f6402b Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sat, 31 Jan 2026 13:45:33 -0600 Subject: [PATCH 13/15] Fix FP8 K tile half-stride bug in v3 pipeline Use WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed instead of WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution for GetQKBlockGemm() to avoid incorrect row stride during K tile loading. 
The SwizzleB encoding caused lane N to receive K row N/2 instead of row N. Co-Authored-By: Claude (claude-opus-4.5) --- .../pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp index e8df0fda604..2c626dfee12 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -245,9 +245,10 @@ struct BlockFmhaV3PipelineDefaultPolicy std::is_same_v && std::is_same_v) { - constexpr index_t swizzle_factor = 4; - return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution< - swizzle_factor>{}; + /// NOTICE: in order to use load_tile() for K tile with correct row stride, + /// we cannot use WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution here + /// because SwizzleB encoding has half-stride issue for K tile loading + return WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<>{}; } else if constexpr(std::is_same_v && std::is_same_v && From dbf5efecd026b26346776e9819a25d082984d481 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 2 Feb 2026 11:59:34 -0600 Subject: [PATCH 14/15] Disable qr_async_trload pipeline and dispatch to qr_async_trload only for fp8bf16 --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 4257b84fbfc..3885f23a66c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -210,16 +210,21 @@ ((0 < args.window_size_left) or (0 < args.window_size_right)); const bool can_dispatch_v3 = (device_name.compare(0, 6, "gfx950") == 0) and - (((traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and - (traits.qscale_type == quant_scale_enum::no_scale)) or + ((false) or ((traits.data_type.compare("fp8bf16") == 0) and (traits.qscale_type == quant_scale_enum::pertensor))) and traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and (not traits.has_dropout) and (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128); + if(traits.data_type.compare("fp8bf16") == 0) + printf("[POYENC] can_dispatch_v3(=%d) -> ", can_dispatch_v3); if ({F_is_v3_enabled} and can_dispatch_v3) {{ + if(traits.data_type.compare("fp8bf16") == 0) + printf("dispatch to fmha_fwd_v3()\\n"); return fmha_fwd_v3(traits, args, config); }} else {{ + if(traits.data_type.compare("fp8bf16") == 0) + printf("dispatch to fmha_fwd_v2()\\n"); return fmha_fwd_v2(traits, args, config); }} }} @@ -936,27 +941,14 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: } # fmt: skip elif dtype in cls._DT_FP16_BF16: return { - ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ( 80, 96) : [FmhaFwdTileSize(128, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ( 96, 
128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip elif dtype in cls._DT_FP8 or dtype in cls._DT_FP8BF16: return { - ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip elif dtype in cls._DT_FP8FP32: return { @@ -1065,6 +1057,7 @@ def get_pipelines( ) if dtype in cls._DT_FP16_BF16: qscale = "no" + """ for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), @@ -1090,7 +1083,7 @@ def get_pipelines( for logits, mask in itertools.product(["t", "f"], ["no", "causal"]): pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip - + """ elif dtype in cls._DT_FP8BF16: # no need lse/dropout kernels # qr_async_trload_v3 only supports (generic) causal mask @@ -1380,8 +1373,8 @@ def accept_only_v2(trait: FmhaFwdApiTrait) -> bool: FMHA_FWD_API_FOOTER_TEMPLATE.format( F_is_v3_enabled=BOOL_MAP[ # NOTE: enable v3 pipelines when ready - # 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) - False + 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + # False ] ), ] From d2dcb0b24c23e6c723d5650719b8702740ff2362 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 2 Feb 2026 12:00:28 -0600 Subject: [PATCH 15/15] Fix wrong mask kargs --- example/ck_tile/01_fmha/mask.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index f85b811116b..010224cb0ea 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -148,6 +148,9 @@ struct mask_info else if(str == "0") { tmp.type = mask_enum::no_mask; + tmp.left = -1; + tmp.right = -1; + tmp.sink = 0; } else if(str == "1" || str == "t") {
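
The last fix above makes the "0" (no-mask) spelling reset every field that later reaches the kernel's mask kargs, not just the type. The snippet below restates that invariant as a self-contained sketch; `mask_kind`, `mask_info_sketch`, and `decode_no_mask` are simplified stand-ins for the real types in `example/ck_tile/01_fmha/mask.hpp`, and the `-1` / `0` values simply mirror what the diff assigns.

```cpp
#include <string>

// Simplified stand-ins for mask_enum / mask_info (illustration only, not the real types).
enum class mask_kind { no_mask, causal, window };

struct mask_info_sketch
{
    mask_kind type = mask_kind::no_mask;
    int left  = -1;
    int right = -1;
    int sink  = 0;
};

// Mirrors the fixed branch: decoding "0" now assigns left/right/sink explicitly,
// so the no-mask case can no longer forward unintended window values in its kargs.
inline mask_info_sketch decode_no_mask(const std::string& str)
{
    mask_info_sketch tmp{};
    if(str == "0")
    {
        tmp.type  = mask_kind::no_mask;
        tmp.left  = -1;
        tmp.right = -1;
        tmp.sink  = 0;
    }
    // other spellings ("1"/"t" for causal, etc.) are handled elsewhere in the real decode
    return tmp;
}
```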