From 0aa1142bb589671a475a1c68f2615b1d48899df7 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 29 Jan 2026 09:36:55 +0800 Subject: [PATCH 1/6] [CK] Add FP8 KV_BLOCKSCALE support for batch prefill Implement per-page K/V quantization for paged attention: - Add KV_BLOCKSCALE enum to BlockAttentionQuantScaleEnum - Use exp2 shift trick to eliminate explicit P scaling overhead - Prefetch physical pages offset for KV cache, overlaps with computations --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 + .../01_fmha/codegen/ops/fmha_batch_prefill.py | 5 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 20 +- example/ck_tile/01_fmha/quant.hpp | 13 +- .../block_attention_quant_scale_enum.hpp | 12 +- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 106 +++- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 504 +++++++++++++++--- 7 files changed, 558 insertions(+), 104 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index cac6671ca5f..995fc8c9659 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -78,12 +78,14 @@ def get_mask_cpp_check_expr(mask: str) -> str: "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", + "kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE", } QSCALE_CHECK_MAP = { "no": "quant_scale_enum::no_scale", "pertensor": "quant_scale_enum::pertensor", "blockscale": "quant_scale_enum::blockscale", + "kv_blockscale": "quant_scale_enum::kv_blockscale", } BIAS_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 42f686e0c00..b575adc7d05 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -677,7 +677,7 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]: kv_lookup_table, ) in itertools.product( ["t", "f"], - ["pertensor"], + ["pertensor", "kv_blockscale"], get_mask_map(mask_impl).keys(), ["no"], SUPPORTED_KV_MEMORY_LAYOUT, @@ -740,6 +740,9 @@ def get_fwd_blobs( for page_size in SUPPORTED_PAGE_SIZE: if page_size == 1 and pipeline.F_kv_memory_layout != "linear": continue + # kv_blockscale only supports page_size=1024 + if pipeline.F_qscale == "kv_blockscale" and page_size != 1024: + continue k = FmhaFwdKernel( F_idx=0, F_hdim=hdim, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index aedbb0e17c2..1fe14982a15 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -602,6 +602,14 @@ struct fmha_batch_prefill_args std::variant, std::pair> drop_seed_offset; + + // KV_BLOCKSCALE: per-page K/V descales (Q per-tensor, K/V per-page) + // Layout: [num_block, num_kv_head, 2] where 2 = (k_descale, v_descale) + // Mutually exclusive with per-tensor k_descale_ptr/v_descale_ptr + const void* kv_block_descale_ptr = nullptr; + ck_tile::index_t kv_block_descale_stride_block = 0; // Stride along num_block dimension + ck_tile::index_t kv_block_descale_stride_head = 0; // Stride along num_kv_head dimension + ck_tile::index_t kv_block_descale_stride_kv = 1; // Stride for K/V index (last dim) }; template @@ -1225,7 +1233,11 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.p_drop, args.s_randval, 
args.drop_seed_offset, - args.sink_ptr); + args.sink_ptr, + args.kv_block_descale_ptr, + args.kv_block_descale_stride_block, + args.kv_block_descale_stride_head, + args.kv_block_descale_stride_kv); } else { // create batch mode kernel arguments @@ -1278,7 +1290,11 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.p_drop, args.s_randval, args.drop_seed_offset, - args.sink_ptr); + args.sink_ptr, + args.kv_block_descale_ptr, + args.kv_block_descale_stride_block, + args.kv_block_descale_stride_head, + args.kv_block_descale_stride_kv); } }(); diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index da588910b23..833a025f798 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -14,9 +14,10 @@ // keep sync with BlockAttentionQuantScaleEnum enum class quant_scale_enum { - no_scale = 0, - pertensor = 1, - blockscale, + no_scale = 0, + pertensor = 1, + blockscale = 2, + kv_blockscale = 3, // Q per-tensor, K/V per-page block scale }; struct quant_scale_info @@ -31,6 +32,8 @@ struct quant_scale_info os << "pt"; else if(type == quant_scale_enum::blockscale) os << "bs"; + else if(type == quant_scale_enum::kv_blockscale) + os << "kvbs"; } static quant_scale_info decode(std::string str) @@ -48,6 +51,10 @@ struct quant_scale_info { info.type = quant_scale_enum::blockscale; } + else if(str == "kvbs" || str == "3") + { + info.type = quant_scale_enum::kv_blockscale; + } else { throw std::invalid_argument("invalid quant scale value: " + str); diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 7e0f704bef8..84a2321708d 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -10,9 +10,10 @@ namespace ck_tile { // This class is used for codegen pattern matching enum class BlockAttentionQuantScaleEnum { - NO_SCALE = 0, - PERTENSOR = 1, - BLOCKSCALE, + NO_SCALE = 0, + PERTENSOR = 1, + BLOCKSCALE = 2, + KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale }; template @@ -33,5 +34,10 @@ struct BlockAttentionQuantScaleEnumToStr +struct BlockAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "kv_blockscale"; +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 86e1de3e9fd..03303a0683d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -185,13 +185,44 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t batch_stride_lse = 0; }; - struct FmhaFwdCommonQScaleKargs + // PERTENSOR: Q/K/V all use per-tensor descales + struct FmhaFwdPerTensorQScaleKargs { const void* q_descale_ptr = nullptr; const void* k_descale_ptr = nullptr; const void* v_descale_ptr = nullptr; }; + // KV_BLOCKSCALE: Q per-tensor, K/V per-page descales + struct FmhaFwdKVBlockScaleKargs + { + const void* q_descale_ptr = nullptr; // Per-tensor Q descale + const void* kv_block_descale_ptr = nullptr; // [num_block, num_kv_head, 2] + ck_tile::index_t kv_block_descale_stride_block = 0; // Stride along num_block dimension + ck_tile::index_t kv_block_descale_stride_head = 0; // Stride along num_kv_head dimension + ck_tile::index_t kv_block_descale_stride_kv = 1; // Stride for K/V index + }; + + // Helper 
template to select QScale Kargs type based on QScaleEnum + // EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>) + template + struct QScaleKargsSelector + { + using type = EmptyType; + }; + + template + struct QScaleKargsSelector + { + using type = FmhaFwdPerTensorQScaleKargs; + }; + + template + struct QScaleKargsSelector + { + using type = FmhaFwdKVBlockScaleKargs; + }; + struct FmhaFwdDropoutSeedOffset { template @@ -255,9 +286,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + QScaleKargsSelector>::type, std::conditional_t>, std::conditional_t> { @@ -276,9 +305,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + QScaleKargsSelector>::type, std::conditional_t>, std::conditional_t> { @@ -348,7 +375,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - const void* sink_ptr = nullptr) + const void* sink_ptr = nullptr, + const void* kv_block_descale_ptr = nullptr, + ck_tile::index_t kv_block_descale_stride_block = 0, + ck_tile::index_t kv_block_descale_stride_head = 0, + ck_tile::index_t kv_block_descale_stride_kv = 1) { Kargs kargs{{q_ptr, k_ptr, @@ -419,6 +450,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.kv_block_descale_ptr = kv_block_descale_ptr; + kargs.kv_block_descale_stride_block = kv_block_descale_stride_block; + kargs.kv_block_descale_stride_head = kv_block_descale_stride_head; + kargs.kv_block_descale_stride_kv = kv_block_descale_stride_kv; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -495,7 +534,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - const void* sink_ptr = nullptr) + const void* sink_ptr = nullptr, + const void* kv_block_descale_ptr = nullptr, + ck_tile::index_t kv_block_descale_stride_block = 0, + ck_tile::index_t kv_block_descale_stride_head = 0, + ck_tile::index_t kv_block_descale_stride_kv = 1) { Kargs kargs{{q_ptr, k_ptr, @@ -563,6 +606,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.kv_block_descale_ptr = kv_block_descale_ptr; + kargs.kv_block_descale_stride_block = kv_block_descale_stride_block; + kargs.kv_block_descale_stride_head = kv_block_descale_stride_head; + kargs.kv_block_descale_stride_kv = kv_block_descale_stride_kv; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -1162,6 +1213,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel return kargs.scale_s * q_descale * k_descale; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + // Q is per-tensor, K is per-page (handled in pipeline) + float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); + return kargs.scale_s * q_descale; + } else { return kargs.scale_s; @@ -1237,6 +1294,37 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel dropout, sink_value); } + else if constexpr(QScaleEnum == 
BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + // KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline + const float* kv_block_descale_ptr = + reinterpret_cast(kargs.kv_block_descale_ptr); + + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + variant_params.sm_scale, + variant, + variant_params, + block_indices, + smem_ptr, + page_idx, + stride_k_for_pipeline, + stride_v_for_pipeline, + kargs.batch_stride_k, + kargs.batch_stride_v, + dropout, + sink_value, + kv_block_descale_ptr, + kargs.kv_block_descale_stride_block, + kargs.kv_block_descale_stride_head, + kargs.kv_block_descale_stride_kv); + } else { return FmhaPipeline{}(q_dram_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 48e8f75ae7e..7622778c89e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -7,14 +7,21 @@ #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { -template -CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, - const index_t& stride_token, - const index_t& stride_page_block, - const CoordVecType& coord_vec, - OffsetVecType& kv_offset_vec, - index_t global_seq_offset = 0) + index_t kN0> +CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, + const CoordVecType& coord_vec, + index_t global_seq_offset, + index_t (&physical_pages)[kLoopCount]) { static constexpr index_t kLog2PageSize = [] { index_t shift = 0; @@ -42,18 +46,16 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, return shift; }(); - const index_t& thread_coord_start = coord_vec[kCoordAxis]; - constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; + const index_t& thread_coord_start = coord_vec[kCoordAxis]; + if constexpr(kIsKcache) { - // for k offsets + // K cache: per-token lookup (all tokens may be on different pages) static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - kv_offset_vec[k0] = static_cast(page_idx[page_id]) * stride_page_block + - static_cast(token_idx_in_page) * stride_token; + const index_t page_id = global_token_idx >> kLog2PageSize; + physical_pages[k0.value] = page_idx[page_id]; }); } else @@ -71,11 +73,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - - const long_index_t page_base_offset = - static_cast(page_idx[global_token_idx]) * stride_page_block; 
- - kv_offset_vec[k0] = page_base_offset; + physical_pages[k0.value] = page_idx[global_token_idx]; }); } else if constexpr(kVTileCrossesPages) @@ -85,70 +83,131 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - - const long_index_t page_base_offset = - static_cast(page_idx[page_id]) * stride_page_block; - - if constexpr(kKVMemoryLayout == - BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) - { - // Vectorized layout uses a packed [token/kVectorSize, head_dim, kVectorSize] - // address pattern. - const long_index_t token_offset = - static_cast((token_idx_in_page / kVectorSize) * - (stride_token * kVectorSize)) + - (token_idx_in_page % kVectorSize); - - kv_offset_vec[k0] = page_base_offset + token_offset; - } - else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT - { - kv_offset_vec[k0] = page_base_offset + - static_cast(token_idx_in_page) * stride_token; - } + const index_t page_id = global_token_idx >> kLog2PageSize; + physical_pages[k0.value] = page_idx[page_id]; }); } - else // !kVTileCrossesPages + else { - // V tile is fully contained in one page, so page_id is shared. - // Use lane0 to compute page_id once and broadcast page_base_offset. + // V tile fully contained in one page: lane0 lookup, broadcast to all const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); const index_t lane0_page_id = (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; + const index_t shared_physical_page = page_idx[lane0_page_id]; - const long_index_t page_base_offset = - static_cast(page_idx[lane0_page_id]) * stride_page_block; + static_for<0, kLoopCount, 1>{}( + [&](auto k0) { physical_pages[k0.value] = shared_physical_page; }); + } + } +} - static_for<0, kLoopCount, 1>{}([&](auto k0) { - // kLoopStride allows non-unit token spacing in the tile distribution. - const index_t token_idx_in_page = - (global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value) & - kInPageOffsetMask; +// kv_offset_array_transform: Converts logical token indices to physical memory offsets +// for paged KV cache access. +// +// This version uses pre-loaded physical_pages array from load_physical_pages(). 
+// Benefits: +// - page_idx is read only once (by load_physical_pages) +// - physical_pages can be prefetched before GEMM to hide memory latency +// - physical_pages can be reused for descale lookup (KV_BLOCKSCALE) +// +// Template parameters: +// - kCoordAxis: Which axis of coord_vec contains the thread's token coordinate +// - kPageBlockSize: Number of tokens per page (must be power of 2) +// - kLoopStart/kLoopCount/kLoopStride: Loop iteration parameters for static_for +// - kKVMemoryLayout: VECTORIZED_LAYOUT or LINEAR_LAYOUT +// - kIsKcache: true for K cache, false for V cache +// - kN0: Tile size in N dimension (used for page crossing detection) +// - kVectorSize: Vector size for vectorized layout (e.g., 8 for fp8) +// +// Memory layout for V cache: +// LINEAR_LAYOUT: [page, token_in_page, head_dim] +// VECTORIZED_LAYOUT: [page, token_in_page/kVectorSize, head_dim, kVectorSize] +// +template +CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_pages)[kLoopCount], + const index_t& stride_token, + const index_t& stride_page_block, + const CoordVecType& coord_vec, + OffsetVecType& kv_offset_vec, + index_t global_seq_offset = 0) +{ + static constexpr index_t kLog2PageSize = [] { + index_t shift = 0; + index_t val = kPageBlockSize; + while(val > 1) + { + val >>= 1; + shift++; + } + return shift; + }(); - if constexpr(kKVMemoryLayout == - BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) - { - // Vectorized layout offset - // Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize] - // Offset = (token_idx_in_page / kVectorSize) * (HeadDim * kVectorSize) + - // (token_idx_in_page % kVectorSize) + const index_t& thread_coord_start = coord_vec[kCoordAxis]; + constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; - const long_index_t token_offset = - static_cast((token_idx_in_page / kVectorSize) * - (stride_token * kVectorSize)) + - (token_idx_in_page % kVectorSize); + if constexpr(kIsKcache) + { + // K cache: per-token lookup + // Each token may be on a different page, so we use physical_pages[k0] for each. 
+ // Offset = physical_page * stride_page_block + token_idx_in_page * stride_token + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + const index_t physical_page = physical_pages[k0.value]; - kv_offset_vec[k0] = page_base_offset + token_offset; - } - else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT - { - kv_offset_vec[k0] = page_base_offset + - static_cast(token_idx_in_page) * stride_token; - } - }); - } + kv_offset_vec[k0] = static_cast(physical_page) * stride_page_block + + static_cast(token_idx_in_page) * stride_token; + }); + } + else // !kVTileCrossesPages + { + // V cache: use physical_pages[k0] for each token + // physical_pages was already populated correctly by load_physical_pages(), handling: + // - page_size=1: page_idx maps token_idx -> physical_page directly + // - V tile crosses pages: per-token page lookup + // - V tile in single page: lane0 lookup with broadcast to all lanes + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + const index_t physical_page = physical_pages[k0.value]; + + const long_index_t page_base_offset = + static_cast(physical_page) * stride_page_block; + + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // Vectorized layout offset calculation: + // Layout: [page, token_in_page/kVectorSize, head_dim, kVectorSize] + // Offset = page_base + (token/kVectorSize) * (head_dim * kVectorSize) + + // (token % kVectorSize) + const long_index_t token_offset = + static_cast((token_idx_in_page / kVectorSize) * + (stride_token * kVectorSize)) + + (token_idx_in_page % kVectorSize); + + kv_offset_vec[k0] = page_base_offset + token_offset; + } + else // LINEAR_LAYOUT + { + // Linear layout: [page, token_in_page, head_dim] + // Offset = page_base + token_idx_in_page * stride_token + kv_offset_vec[k0] = + page_base_offset + static_cast(token_idx_in_page) * stride_token; + } + }); } } @@ -209,6 +268,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; + static constexpr auto QScaleEnum = Problem::QScaleEnum; + + // For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] + // This avoids explicit P *= scale_p and v_descale /= scale_p operations + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -341,8 +406,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t page_stride_k, const index_t page_stride_v, DropoutType& dropout, - const float sink_v) const + const float sink_v, + // KV_BLOCKSCALE parameters (only used when QScaleEnum == KV_BLOCKSCALE) + const float* kv_block_descale_ptr = nullptr, + index_t kv_block_descale_stride_block = 0, + index_t kv_block_descale_stride_head = 0, + index_t kv_block_descale_stride_kv = 1) const { + // KV_BLOCKSCALE requires page_block_size >= kN0 to ensure + // all tokens in a main loop iteration belong to the same page + if 
constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + static_assert(kPageBlockSize >= kN0, "KV_BLOCKSCALE requires kPageBlockSize >= kN0"); + } + static_assert( std::is_same_v> && std::is_same_v> && @@ -494,6 +571,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0]; statically_indexed_array k_offsets; index_t current_seq_k = seqlen_k_start; + + // Load physical pages first, then compute offsets. + // k_physical_pages can be reused for descale lookup later. + index_t k_physical_pages[NRepeat] = {}; + load_physical_pages(page_idx, k_coord, current_seq_k, k_physical_pages); + kv_offset_array_transform, decltype(k_coord), 0, @@ -505,7 +596,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync true, kN0, kVectorSize>( - page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); + k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window.get_window_lengths(), @@ -644,6 +735,50 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync "V page-index Y dim must be valid"); statically_indexed_array v_offsets; + // V physical pages array for use with kv_offset_array_transform + // For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner + index_t v_physical_pages[V_PageIdxRepeat] = {}; + + // Prefetch V physical pages - can be called early to hide buffer load latency + auto prefetch_v_physical_pages = [&](auto k_loop_start) { + constexpr index_t kLoopStart = decltype(k_loop_start)::value; + if constexpr(V_KIterOuter > 1) + { + static_for<0, V_KIterOuter, 1>{}([&](auto k2) { + // Load physical pages for this k2 slice into the appropriate portion of array + index_t v_physical_pages_k2[V_KIterInner] = {}; + load_physical_pages(page_idx, v_coord, current_seq_k, v_physical_pages_k2); + + // Copy to merged array + static_for<0, V_KIterInner, 1>{}([&](auto k1) { + constexpr auto idx = k1.value + k2.value * V_KIterInner; + v_physical_pages[idx] = v_physical_pages_k2[k1.value]; + }); + }); + } + else + { + load_physical_pages(page_idx, v_coord, current_seq_k, v_physical_pages); + } + }; + + // Update V offsets using pre-loaded physical pages auto update_v_offsets = [&](auto k_loop_start) { constexpr index_t kLoopStart = decltype(k_loop_start)::value; // For 3D K decomposition (K2, K0, K1), compute offsets for each K2 slice @@ -653,6 +788,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { static_for<0, V_KIterOuter, 1>{}([&](auto k2) { statically_indexed_array v_offsets_k2; + // Extract physical pages for this k2 slice + index_t v_physical_pages_k2[V_KIterInner]; + static_for<0, V_KIterInner, 1>{}([&](auto k1) { + constexpr auto idx = k1.value + k2.value * V_KIterInner; + v_physical_pages_k2[k1.value] = v_physical_pages[idx]; + }); + kv_offset_array_transform, decltype(v_coord), I1, @@ -663,8 +805,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, false, kN0, - kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets_k2, current_seq_k); + kVectorSize>(v_physical_pages_k2, + stride_v, + page_stride_v, + v_coord, + v_offsets_k2, + current_seq_k); + static_for<0, V_KIterInner, 1>{}([&](auto k1) { constexpr auto idx = number{}; v_offsets[idx] = v_offsets_k2[k1]; @@ -684,9 +831,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync false, kN0, kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + 
v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); } }; + + // Prefetch V physical pages early to hide buffer load latency + prefetch_v_physical_pages(number<0>{}); update_v_offsets(number<0>{}); auto v_dram_window = make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), @@ -717,6 +867,41 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // main loop do { + // KV_BLOCKSCALE: load per-page K/V descale factors + // Uses k_physical_pages[0] from load_physical_pages to avoid redundant page_idx reads. + // Assumes kPageBlockSize >= kN0, so all tokens in one main loop iteration belong to + // the same page (single scale pair). + // + // TODO: Cross-page KV_BLOCKSCALE support + // Currently only supports kPageBlockSize >= kN0 (all tokens in tile on same page). + // To support smaller page sizes (cross-page tiles), need: + // + // 1. K descale: Load per-token k_descale_vec[NRepeat] based on k_physical_pages[k0] + // - After GEMM0 (S = Q × K^T), apply column-wise scaling: S[:,j] *= k_descale[j] + // - Requires modifying s_acc_element_func to accept column index + // + // 2. V descale: Load per-token v_descale_vec[V_PageIdxRepeat] based on + // v_physical_pages[k0] + // - Before GEMM1 (O = P × V), apply row-wise scaling to P: P[i,j] *= v_descale[j] + // - Or pre-scale V in LDS (more complex) + // + // 3. K and V may be on different pages for the same token index, so need separate + // lookups + // + [[maybe_unused]] float k_descale = 1.0f; + [[maybe_unused]] float v_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + const index_t scale_offset = + k_physical_pages[0] * kv_block_descale_stride_block + + block_indices.kv_head_idx * kv_block_descale_stride_head; + k_descale = kv_block_descale_ptr[scale_offset + 0 * kv_block_descale_stride_kv]; + v_descale = kv_block_descale_ptr[scale_offset + 1 * kv_block_descale_stride_kv]; + } + + // Prefetch V physical pages early - overlaps with GEMM0 computation + prefetch_v_physical_pages(number{}); + // STAGE 1, QK gemm clear_tile(s_acc); // initialize C if constexpr(k0_loops > 1) @@ -763,9 +948,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(1); auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + // V physical pages already prefetched before GEMM0 update_v_offsets(number{}); v_dram_window.update_page_idx(v_offsets); + // KV_BLOCKSCALE: apply k_descale to s_acc (dequantize QK result) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc); + } + const auto p = [&]() { const auto bias_tile = load_tile(bias_dram_window); // load bias tile @@ -875,6 +1067,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } const auto s = cast_tile(s_acc); // S{j} + + // Prefetch V physical pages early - overlaps with softmax computation + if constexpr(k1_loops > 1) + { + prefetch_v_physical_pages(number<2 * kK1>{}); + } + auto m_local = block_tile_reduce( s, sequence<1>{}, @@ -953,7 +1152,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); + // For KV_BLOCKSCALE: precompute (m - shift) once per row + // exp2(s - (m - shift)) = exp2(s - m + shift) = exp2(s - m) * 2^shift + // This scales P by 2^shift (≈448 for fp8_e4m3) without explicit multiply + auto 
validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap + row_max -= OCP_FP8_SHIFT; // for else branch +#else + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; +#endif + } #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -961,13 +1174,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { @@ -1049,6 +1262,22 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync }(); // STAGE 3, KV gemm + // KV_BLOCKSCALE: accumulate P*V into temporary tile before applying v_descale + auto o_acc_unscaled = decltype(o_acc){}; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + clear_tile(o_acc_unscaled); + } + + // Select GEMM1 target: o_acc_unscaled for KV_BLOCKSCALE (needs v_descale), o_acc + // otherwise + auto& gemm1_acc = [&]() -> auto& { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + return o_acc_unscaled; + else + return o_acc; + }(); + if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { @@ -1056,11 +1285,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { v_buf = load_tile( v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + // Update V offsets using previously prefetched physical pages update_v_offsets(number<(2 + i_k1.value) * kK1>{}); v_dram_window.update_page_idx(v_offsets); } + + // Prefetch V physical pages for NEXT iteration - overlaps with GEMM1 + if constexpr(i_k1 + 1 < k1_loops - 1) + { + prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{}); + } + block_sync_lds(); - gemm_1(o_acc, + gemm_1(gemm1_acc, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), get_slice_tile( @@ -1104,6 +1341,17 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync move_tile_window(k_dram_block_window, {kN0, 0}); k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + // KV_BLOCKSCALE: reload physical pages for the new tile + load_physical_pages(page_idx, k_coord, current_seq_k, k_physical_pages); + kv_offset_array_transform, decltype(k_coord), 0, @@ -1115,7 +1363,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync true, kN0, kVectorSize>( - page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); + k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) @@ -1131,13 +1379,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { block_sync_lds(); gemm_1( - o_acc, + gemm1_acc, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), get_slice_tile( v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); } + + // KV_BLOCKSCALE: apply v_descale and accumulate o_acc_unscaled into o_acc + // Note: No division by scale_p needed because: + // 1. P was scaled by 2^shift through exp2 shift trick + // 2. 
rowsum l was also scaled by 2^shift + // 3. Final O = sum(P*V) / l, so the 2^shift cancels out + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + tile_elementwise_inout( + [&v_descale](auto& o, auto& o_unscaled) { o += o_unscaled * v_descale; }, + o_acc, + o_acc_unscaled); + } } while(i_total_loops < num_total_loop); // store lse @@ -1257,6 +1518,77 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync dropout, sink_v); } + + // Overload for KV_BLOCKSCALE: K/V descale is per-page + // This is a convenience overload that forwards to the main operator() with kv_scale parameters + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + const index_t* page_idx, + const index_t stride_k, + const index_t stride_v, + const index_t page_stride_k, + const index_t page_stride_v, + DropoutType& dropout, + float sink_v, + const float* kv_block_descale_ptr, + index_t kv_block_descale_stride_block, + index_t kv_block_descale_stride_head, + index_t kv_block_descale_stride_kv) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + page_idx, + stride_k, + stride_v, + page_stride_k, + page_stride_v, + dropout, + sink_v, + kv_block_descale_ptr, + kv_block_descale_stride_block, + kv_block_descale_stride_head, + kv_block_descale_stride_kv); + } }; } // namespace ck_tile From dc74e66e7be47fefe4383a025068fdeb6dcbd22d Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 2 Feb 2026 22:57:29 +0800 Subject: [PATCH 2/6] Add runtime check nullptr for prevent quantization parameters. 
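For illustration, a minimal host-side sketch of the arguments these checks guard, assuming the per-tensor descale fields already exposed by fmha_batch_prefill_args and a contiguous [num_block, num_kv_head, 2] descale buffer (the helper and its parameters are hypothetical; only the struct fields come from this series):

    // Hypothetical helper: populate the KV_BLOCKSCALE arguments that must be
    // non-null before the kernel dereferences them.
    void set_kv_blockscale_args(fmha_batch_prefill_args& args,
                                const void* q_descale,        // 1 float, per-tensor Q descale
                                const void* kv_block_descale, // [num_block, num_kv_head, 2] floats
                                ck_tile::index_t num_kv_head)
    {
        args.q_descale_ptr        = q_descale;
        args.kv_block_descale_ptr = kv_block_descale;
        // Strides of a contiguous [num_block, num_kv_head, 2] layout.
        args.kv_block_descale_stride_block = num_kv_head * 2;
        args.kv_block_descale_stride_head  = 2;
        args.kv_block_descale_stride_kv    = 1; // 0 = k_descale, 1 = v_descale
    }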
--- .../ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 03303a0683d..873e65eac54 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -10,6 +10,7 @@ #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include #include #include #include @@ -1208,6 +1209,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel const float scale_s = [&] { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { + assert(kargs.q_descale_ptr != nullptr); + assert(kargs.k_descale_ptr != nullptr); float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); float k_descale = *(reinterpret_cast(kargs.k_descale_ptr)); @@ -1216,6 +1219,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) { // Q is per-tensor, K is per-page (handled in pipeline) + assert(kargs.q_descale_ptr != nullptr); float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); return kargs.scale_s * q_descale; } @@ -1251,6 +1255,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { // TODO - move global load of descale to pipeline + assert(kargs.v_descale_ptr != nullptr); float v_descale = *(reinterpret_cast(kargs.v_descale_ptr)); float scale_p = ck_tile::type_convert(ck_tile::numeric::max()); @@ -1297,6 +1302,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) { // KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline + assert(kargs.kv_block_descale_ptr != nullptr); const float* kv_block_descale_ptr = reinterpret_cast(kargs.kv_block_descale_ptr); From e1af9b7afbbb130575502ad096dfa26489a85bca Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 2 Feb 2026 23:27:10 +0800 Subject: [PATCH 3/6] 1. Relax kv_blockscale page_size restriction from == 1024 to >= kN0 2. Rename QScaleKargsSelector -> GetQScaleKargs for naming consistency 3. 
Remove unused BlockAttentionQuantScaleEnumToStr --- .../ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py | 5 +++-- .../fmha/block/block_attention_quant_scale_enum.hpp | 5 ----- .../ops/fmha/kernel/fmha_batch_prefill_kernel.hpp | 10 +++++----- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index b575adc7d05..8459f20d358 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -740,8 +740,9 @@ def get_fwd_blobs( for page_size in SUPPORTED_PAGE_SIZE: if page_size == 1 and pipeline.F_kv_memory_layout != "linear": continue - # kv_blockscale only supports page_size=1024 - if pipeline.F_qscale == "kv_blockscale" and page_size != 1024: + # kv_blockscale requires page_size >= kN0 (tile.F_bn0) + # This ensures all tokens in a main loop iteration belong to the same page + if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0: continue k = FmhaFwdKernel( F_idx=0, diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 84a2321708d..0c6075e0630 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -34,10 +34,5 @@ struct BlockAttentionQuantScaleEnumToStr -struct BlockAttentionQuantScaleEnumToStr -{ - static constexpr const char* name = "kv_blockscale"; -}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 873e65eac54..ecc1ef94852 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -207,19 +207,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel // Helper template to select QScale Kargs type based on QScaleEnum // EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>) template - struct QScaleKargsSelector + struct GetQScaleKargs { using type = EmptyType; }; template - struct QScaleKargsSelector + struct GetQScaleKargs { using type = FmhaFwdPerTensorQScaleKargs; }; template - struct QScaleKargsSelector + struct GetQScaleKargs { using type = FmhaFwdKVBlockScaleKargs; }; @@ -287,7 +287,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - QScaleKargsSelector>::type, + GetQScaleKargs>::type, std::conditional_t>, std::conditional_t> { @@ -306,7 +306,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - QScaleKargsSelector>::type, + GetQScaleKargs>::type, std::conditional_t>, std::conditional_t> { From 4933100b0fc7754db35a692fa7a796ddfff6de21 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 3 Feb 2026 09:00:42 +0800 Subject: [PATCH 4/6] use statically_indexed_array instead of c-style array. 
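Roughly, the change at each call site (illustrative only; NRepeat and page_idx are the names used in the pipeline):

    // Before: C-style array, indexed with the runtime value of the loop constant.
    index_t pages_c[NRepeat] = {};
    static_for<0, NRepeat, 1>{}([&](auto k0) { pages_c[k0.value] = page_idx[k0.value]; });

    // After: statically_indexed_array, indexed with the number<> object itself,
    // matching the existing k_offsets / v_offsets containers.
    statically_indexed_array<index_t, NRepeat> pages{};
    static_for<0, NRepeat, 1>{}([&](auto k0) { pages[k0] = page_idx[k0.value]; });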
--- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 7622778c89e..62e67a1fe11 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -21,7 +21,8 @@ namespace ck_tile { // - Crosses pages: per-token lookup // - Single page: lane0 lookup once, broadcast to all // Output: physical_pages array with kLoopCount elements -template {}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - physical_pages[k0.value] = page_idx[page_id]; + const index_t page_id = global_token_idx >> kLog2PageSize; + physical_pages[k0] = page_idx[page_id]; }); } else @@ -73,7 +74,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - physical_pages[k0.value] = page_idx[global_token_idx]; + physical_pages[k0] = page_idx[global_token_idx]; }); } else if constexpr(kVTileCrossesPages) @@ -83,8 +84,8 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - physical_pages[k0.value] = page_idx[page_id]; + const index_t page_id = global_token_idx >> kLog2PageSize; + physical_pages[k0] = page_idx[page_id]; }); } else @@ -96,7 +97,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, const index_t shared_physical_page = page_idx[lane0_page_id]; static_for<0, kLoopCount, 1>{}( - [&](auto k0) { physical_pages[k0.value] = shared_physical_page; }); + [&](auto k0) { physical_pages[k0] = shared_physical_page; }); } } } @@ -123,7 +124,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, // LINEAR_LAYOUT: [page, token_in_page, head_dim] // VECTORIZED_LAYOUT: [page, token_in_page/kVectorSize, head_dim, kVectorSize] // -template -CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_pages)[kLoopCount], +CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages, const index_t& stride_token, const index_t& stride_page_block, const CoordVecType& coord_vec, - OffsetVecType& kv_offset_vec, + IndexArrayType& kv_offset_vec, index_t global_seq_offset = 0) { static constexpr index_t kLog2PageSize = [] { @@ -164,7 +165,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_page const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - const index_t physical_page = physical_pages[k0.value]; + const index_t physical_page = physical_pages[k0]; kv_offset_vec[k0] = static_cast(physical_page) * stride_page_block + static_cast(token_idx_in_page) * stride_token; @@ -181,7 +182,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_page const index_t global_token_idx = 
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - const index_t physical_page = physical_pages[k0.value]; + const index_t physical_page = physical_pages[k0]; const long_index_t page_base_offset = static_cast(physical_page) * stride_page_block; @@ -574,8 +575,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Load physical pages first, then compute offsets. // k_physical_pages can be reused for descale lookup later. - index_t k_physical_pages[NRepeat] = {}; - load_physical_pages k_physical_pages{}; + load_physical_pages, + decltype(k_coord), 0, kPageBlockSize, 0, @@ -737,7 +739,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync statically_indexed_array v_offsets; // V physical pages array for use with kv_offset_array_transform // For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner - index_t v_physical_pages[V_PageIdxRepeat] = {}; + statically_indexed_array v_physical_pages{}; // Prefetch V physical pages - can be called early to hide buffer load latency auto prefetch_v_physical_pages = [&](auto k_loop_start) { @@ -746,8 +748,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { static_for<0, V_KIterOuter, 1>{}([&](auto k2) { // Load physical pages for this k2 slice into the appropriate portion of array - index_t v_physical_pages_k2[V_KIterInner] = {}; - load_physical_pages v_physical_pages_k2{}; + load_physical_pages, + decltype(v_coord), I1, kPageBlockSize, kLoopStart + k2.value * V_KLanes * V_KIterInner, @@ -759,14 +762,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Copy to merged array static_for<0, V_KIterInner, 1>{}([&](auto k1) { - constexpr auto idx = k1.value + k2.value * V_KIterInner; - v_physical_pages[idx] = v_physical_pages_k2[k1.value]; + constexpr auto idx = number{}; + v_physical_pages[idx] = v_physical_pages_k2[k1]; }); }); } else { - load_physical_pages, + decltype(v_coord), I1, kPageBlockSize, kLoopStart, @@ -789,10 +793,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static_for<0, V_KIterOuter, 1>{}([&](auto k2) { statically_indexed_array v_offsets_k2; // Extract physical pages for this k2 slice - index_t v_physical_pages_k2[V_KIterInner]; + statically_indexed_array v_physical_pages_k2; static_for<0, V_KIterInner, 1>{}([&](auto k1) { - constexpr auto idx = k1.value + k2.value * V_KIterInner; - v_physical_pages_k2[k1.value] = v_physical_pages[idx]; + constexpr auto idx = number{}; + v_physical_pages_k2[k1] = v_physical_pages[idx]; }); kv_offset_array_transform, @@ -893,7 +897,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) { const index_t scale_offset = - k_physical_pages[0] * kv_block_descale_stride_block + + k_physical_pages[number<0>{}] * kv_block_descale_stride_block + block_indices.kv_head_idx * kv_block_descale_stride_head; k_descale = kv_block_descale_ptr[scale_offset + 0 * kv_block_descale_stride_kv]; v_descale = kv_block_descale_ptr[scale_offset + 1 * kv_block_descale_stride_kv]; @@ -1342,7 +1346,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); // KV_BLOCKSCALE: reload physical pages for the new tile - load_physical_pages, + decltype(k_coord), 0, kPageBlockSize, 0, From 8b07580161b266447097a1d2806e07692f612ca1 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 3 Feb 2026 10:34:58 +0800 Subject: [PATCH 5/6] Rename stride fields in FmhaFwdKVBlockScaleKargs --- 
example/ck_tile/01_fmha/fmha_fwd.hpp | 20 +++---- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 56 +++++++++---------- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 28 +++++----- 3 files changed, 52 insertions(+), 52 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 1fe14982a15..6a8a2cb9b9c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -606,10 +606,10 @@ struct fmha_batch_prefill_args // KV_BLOCKSCALE: per-page K/V descales (Q per-tensor, K/V per-page) // Layout: [num_block, num_kv_head, 2] where 2 = (k_descale, v_descale) // Mutually exclusive with per-tensor k_descale_ptr/v_descale_ptr - const void* kv_block_descale_ptr = nullptr; - ck_tile::index_t kv_block_descale_stride_block = 0; // Stride along num_block dimension - ck_tile::index_t kv_block_descale_stride_head = 0; // Stride along num_kv_head dimension - ck_tile::index_t kv_block_descale_stride_kv = 1; // Stride for K/V index (last dim) + const void* kv_block_descale_ptr = nullptr; + ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension + ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension + ck_tile::index_t kv_stride_kv_block_descale = 1; // Stride for K/V index (last dim) }; template @@ -1235,9 +1235,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.drop_seed_offset, args.sink_ptr, args.kv_block_descale_ptr, - args.kv_block_descale_stride_block, - args.kv_block_descale_stride_head, - args.kv_block_descale_stride_kv); + args.nblock_stride_kv_block_descale, + args.nhead_stride_kv_block_descale, + args.kv_stride_kv_block_descale); } else { // create batch mode kernel arguments @@ -1292,9 +1292,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.drop_seed_offset, args.sink_ptr, args.kv_block_descale_ptr, - args.kv_block_descale_stride_block, - args.kv_block_descale_stride_head, - args.kv_block_descale_stride_kv); + args.nblock_stride_kv_block_descale, + args.nhead_stride_kv_block_descale, + args.kv_stride_kv_block_descale); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index ecc1ef94852..d78c26b9189 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -197,11 +197,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel // KV_BLOCKSCALE: Q per-tensor, K/V per-page descales struct FmhaFwdKVBlockScaleKargs { - const void* q_descale_ptr = nullptr; // Per-tensor Q descale - const void* kv_block_descale_ptr = nullptr; // [num_block, num_kv_head, 2] - ck_tile::index_t kv_block_descale_stride_block = 0; // Stride along num_block dimension - ck_tile::index_t kv_block_descale_stride_head = 0; // Stride along num_kv_head dimension - ck_tile::index_t kv_block_descale_stride_kv = 1; // Stride for K/V index + const void* q_descale_ptr = nullptr; // Per-tensor Q descale + const void* kv_block_descale_ptr = nullptr; // [num_block, num_kv_head, 2] + ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension + ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension + ck_tile::index_t kv_stride_kv_block_descale = 1; // Stride for K/V index }; // Helper template to select QScale Kargs type based on QScaleEnum @@ -376,11 +376,11 @@ struct 
FmhaBatchPrefillWithPagedKVCacheKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - const void* sink_ptr = nullptr, - const void* kv_block_descale_ptr = nullptr, - ck_tile::index_t kv_block_descale_stride_block = 0, - ck_tile::index_t kv_block_descale_stride_head = 0, - ck_tile::index_t kv_block_descale_stride_kv = 1) + const void* sink_ptr = nullptr, + const void* kv_block_descale_ptr = nullptr, + ck_tile::index_t nblock_stride_kv_block_descale = 0, + ck_tile::index_t nhead_stride_kv_block_descale = 0, + ck_tile::index_t kv_stride_kv_block_descale = 1) { Kargs kargs{{q_ptr, k_ptr, @@ -453,11 +453,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel } else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) { - kargs.q_descale_ptr = q_descale_ptr; - kargs.kv_block_descale_ptr = kv_block_descale_ptr; - kargs.kv_block_descale_stride_block = kv_block_descale_stride_block; - kargs.kv_block_descale_stride_head = kv_block_descale_stride_head; - kargs.kv_block_descale_stride_kv = kv_block_descale_stride_kv; + kargs.q_descale_ptr = q_descale_ptr; + kargs.kv_block_descale_ptr = kv_block_descale_ptr; + kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale; + kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale; + kargs.kv_stride_kv_block_descale = kv_stride_kv_block_descale; } if constexpr(kHasDropout) { @@ -535,11 +535,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - const void* sink_ptr = nullptr, - const void* kv_block_descale_ptr = nullptr, - ck_tile::index_t kv_block_descale_stride_block = 0, - ck_tile::index_t kv_block_descale_stride_head = 0, - ck_tile::index_t kv_block_descale_stride_kv = 1) + const void* sink_ptr = nullptr, + const void* kv_block_descale_ptr = nullptr, + ck_tile::index_t nblock_stride_kv_block_descale = 0, + ck_tile::index_t nhead_stride_kv_block_descale = 0, + ck_tile::index_t kv_stride_kv_block_descale = 1) { Kargs kargs{{q_ptr, k_ptr, @@ -609,11 +609,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel } else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) { - kargs.q_descale_ptr = q_descale_ptr; - kargs.kv_block_descale_ptr = kv_block_descale_ptr; - kargs.kv_block_descale_stride_block = kv_block_descale_stride_block; - kargs.kv_block_descale_stride_head = kv_block_descale_stride_head; - kargs.kv_block_descale_stride_kv = kv_block_descale_stride_kv; + kargs.q_descale_ptr = q_descale_ptr; + kargs.kv_block_descale_ptr = kv_block_descale_ptr; + kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale; + kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale; + kargs.kv_stride_kv_block_descale = kv_stride_kv_block_descale; } if constexpr(kHasDropout) { @@ -1327,9 +1327,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel dropout, sink_value, kv_block_descale_ptr, - kargs.kv_block_descale_stride_block, - kargs.kv_block_descale_stride_head, - kargs.kv_block_descale_stride_kv); + kargs.nblock_stride_kv_block_descale, + kargs.nhead_stride_kv_block_descale, + kargs.kv_stride_kv_block_descale); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 62e67a1fe11..592c9322e8f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -409,10 +409,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync DropoutType& dropout, const float sink_v, // KV_BLOCKSCALE parameters (only used when QScaleEnum == KV_BLOCKSCALE) - const float* kv_block_descale_ptr = nullptr, - index_t kv_block_descale_stride_block = 0, - index_t kv_block_descale_stride_head = 0, - index_t kv_block_descale_stride_kv = 1) const + const float* kv_block_descale_ptr = nullptr, + index_t nblock_stride_kv_block_descale = 0, + index_t nhead_stride_kv_block_descale = 0, + index_t kv_stride_kv_block_descale = 1) const { // KV_BLOCKSCALE requires page_block_size >= kN0 to ensure // all tokens in a main loop iteration belong to the same page @@ -897,10 +897,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) { const index_t scale_offset = - k_physical_pages[number<0>{}] * kv_block_descale_stride_block + - block_indices.kv_head_idx * kv_block_descale_stride_head; - k_descale = kv_block_descale_ptr[scale_offset + 0 * kv_block_descale_stride_kv]; - v_descale = kv_block_descale_ptr[scale_offset + 1 * kv_block_descale_stride_kv]; + k_physical_pages[number<0>{}] * nblock_stride_kv_block_descale + + block_indices.kv_head_idx * nhead_stride_kv_block_descale; + k_descale = kv_block_descale_ptr[scale_offset + 0 * kv_stride_kv_block_descale]; + v_descale = kv_block_descale_ptr[scale_offset + 1 * kv_stride_kv_block_descale]; } // Prefetch V physical pages early - overlaps with GEMM0 computation @@ -1557,9 +1557,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync DropoutType& dropout, float sink_v, const float* kv_block_descale_ptr, - index_t kv_block_descale_stride_block, - index_t kv_block_descale_stride_head, - index_t kv_block_descale_stride_kv) const + index_t nblock_stride_kv_block_descale, + index_t nhead_stride_kv_block_descale, + index_t kv_stride_kv_block_descale) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -1590,9 +1590,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync dropout, sink_v, kv_block_descale_ptr, - kv_block_descale_stride_block, - kv_block_descale_stride_head, - kv_block_descale_stride_kv); + nblock_stride_kv_block_descale, + nhead_stride_kv_block_descale, + kv_stride_kv_block_descale); } }; From c093935e0c6c2ad769da62bbd082a7aaf0e0febc Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 3 Feb 2026 12:36:00 +0800 Subject: [PATCH 6/6] Split kv_block_descale_ptr into k_descale_ptr and v_descale_ptr to maintain flexibility. 
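With the split, each buffer is indexed as [num_block, num_kv_head] and the two stride arguments are shared between them. Sketch of the intended per-page lookup in the pipeline main loop under that assumption (k_physical_pages and block_indices are the existing pipeline variables):

    const index_t scale_offset =
        k_physical_pages[number<0>{}] * nblock_stride_kv_block_descale +
        block_indices.kv_head_idx * nhead_stride_kv_block_descale;
    k_descale = k_descale_ptr[scale_offset];
    v_descale = v_descale_ptr[scale_offset];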
---
 example/ck_tile/01_fmha/fmha_fwd.hpp          | 15 +++-----
 .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 34 +++++++++----------
 ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 22 ++++++------
 3 files changed, 32 insertions(+), 39 deletions(-)

diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp
index 6a8a2cb9b9c..ee404010eff 100644
--- a/example/ck_tile/01_fmha/fmha_fwd.hpp
+++ b/example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -604,12 +604,11 @@ struct fmha_batch_prefill_args
         drop_seed_offset;

     // KV_BLOCKSCALE: per-page K/V descales (Q per-tensor, K/V per-page)
-    // Layout: [num_block, num_kv_head, 2] where 2 = (k_descale, v_descale)
-    // Mutually exclusive with per-tensor k_descale_ptr/v_descale_ptr
-    const void* kv_block_descale_ptr = nullptr;
+    // k_descale_ptr/v_descale_ptr are reused for KV_BLOCKSCALE mode:
+    // k_descale_ptr: [num_block, num_kv_head] - points to k block descale
+    // v_descale_ptr: [num_block, num_kv_head] - points to v block descale
     ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
     ck_tile::index_t nhead_stride_kv_block_descale = 0;  // Stride along num_kv_head dimension
-    ck_tile::index_t kv_stride_kv_block_descale = 1;     // Stride for K/V index (last dim)
 };

 template
@@ -1234,10 +1233,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
             args.s_randval,
             args.drop_seed_offset,
             args.sink_ptr,
-            args.kv_block_descale_ptr,
             args.nblock_stride_kv_block_descale,
-            args.nhead_stride_kv_block_descale,
-            args.kv_stride_kv_block_descale);
+            args.nhead_stride_kv_block_descale);
     }
     else
     { // create batch mode kernel arguments
@@ -1291,10 +1288,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
             args.s_randval,
             args.drop_seed_offset,
             args.sink_ptr,
-            args.kv_block_descale_ptr,
             args.nblock_stride_kv_block_descale,
-            args.nhead_stride_kv_block_descale,
-            args.kv_stride_kv_block_descale);
+            args.nhead_stride_kv_block_descale);
     }
     }();
diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
index d78c26b9189..c6628f66be8 100644
--- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
@@ -195,13 +195,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
     };

     // KV_BLOCKSCALE: Q per-tensor, K/V per-page descales
+    // K descale: [num_block, num_kv_head], V descale: [num_block, num_kv_head]
     struct FmhaFwdKVBlockScaleKargs
     {
         const void* q_descale_ptr = nullptr; // Per-tensor Q descale
-        const void* kv_block_descale_ptr = nullptr; // [num_block, num_kv_head, 2]
+        const void* k_descale_ptr = nullptr; // [num_block, num_kv_head]
+        const void* v_descale_ptr = nullptr; // [num_block, num_kv_head]
         ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
         ck_tile::index_t nhead_stride_kv_block_descale = 0;  // Stride along num_kv_head dimension
-        ck_tile::index_t kv_stride_kv_block_descale = 1;     // Stride for K/V index
     };

     // Helper template to select QScale Kargs type based on QScaleEnum
@@ -377,10 +378,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
                 drop_seed_offset,
             const void* sink_ptr = nullptr,
-            const void* kv_block_descale_ptr = nullptr,
             ck_tile::index_t nblock_stride_kv_block_descale = 0,
-            ck_tile::index_t nhead_stride_kv_block_descale = 0,
-            ck_tile::index_t kv_stride_kv_block_descale = 1)
+            ck_tile::index_t nhead_stride_kv_block_descale = 0)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -454,10 +453,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
         else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
         {
             kargs.q_descale_ptr = q_descale_ptr;
-            kargs.kv_block_descale_ptr = kv_block_descale_ptr;
+            kargs.k_descale_ptr = k_descale_ptr;
+            kargs.v_descale_ptr = v_descale_ptr;
             kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
             kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
-            kargs.kv_stride_kv_block_descale = kv_stride_kv_block_descale;
         }
         if constexpr(kHasDropout)
         {
@@ -536,10 +535,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
                 drop_seed_offset,
             const void* sink_ptr = nullptr,
-            const void* kv_block_descale_ptr = nullptr,
             ck_tile::index_t nblock_stride_kv_block_descale = 0,
-            ck_tile::index_t nhead_stride_kv_block_descale = 0,
-            ck_tile::index_t kv_stride_kv_block_descale = 1)
+            ck_tile::index_t nhead_stride_kv_block_descale = 0)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -610,10 +607,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
         else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
         {
             kargs.q_descale_ptr = q_descale_ptr;
-            kargs.kv_block_descale_ptr = kv_block_descale_ptr;
+            kargs.k_descale_ptr = k_descale_ptr;
+            kargs.v_descale_ptr = v_descale_ptr;
             kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
             kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
-            kargs.kv_stride_kv_block_descale = kv_stride_kv_block_descale;
         }
         if constexpr(kHasDropout)
         {
@@ -1302,9 +1299,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
         else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
         {
             // KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline
-            assert(kargs.kv_block_descale_ptr != nullptr);
-            const float* kv_block_descale_ptr =
-                reinterpret_cast<const float*>(kargs.kv_block_descale_ptr);
+            assert(kargs.k_descale_ptr != nullptr);
+            assert(kargs.v_descale_ptr != nullptr);
+            const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
+            const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);

             return FmhaPipeline{}(q_dram_window,
                                   k_dram_window,
@@ -1327,10 +1324,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
                                   kargs.batch_stride_v,
                                   dropout,
                                   sink_value,
-                                  kv_block_descale_ptr,
+                                  k_descale_ptr,
+                                  v_descale_ptr,
                                   kargs.nblock_stride_kv_block_descale,
-                                  kargs.nhead_stride_kv_block_descale,
-                                  kargs.kv_stride_kv_block_descale);
+                                  kargs.nhead_stride_kv_block_descale);
         }
         else
         {
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
index 592c9322e8f..6672940576c 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
@@ -409,10 +409,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                    DropoutType& dropout,
                    const float sink_v,
                    // KV_BLOCKSCALE parameters (only used when QScaleEnum == KV_BLOCKSCALE)
-                   const float* kv_block_descale_ptr = nullptr,
+                   const float* k_descale_ptr = nullptr,
+                   const float* v_descale_ptr = nullptr,
                    index_t nblock_stride_kv_block_descale = 0,
-                   index_t nhead_stride_kv_block_descale = 0,
-                   index_t kv_stride_kv_block_descale = 1) const
+                   index_t nhead_stride_kv_block_descale = 0) const
     {
         // KV_BLOCKSCALE requires page_block_size >= kN0 to ensure
         // all tokens in a main loop iteration belong to the same page
@@ -899,8 +899,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
             const index_t scale_offset =
                 k_physical_pages[number<0>{}] * nblock_stride_kv_block_descale +
                 block_indices.kv_head_idx * nhead_stride_kv_block_descale;
-            k_descale = kv_block_descale_ptr[scale_offset + 0 * kv_stride_kv_block_descale];
-            v_descale = kv_block_descale_ptr[scale_offset + 1 * kv_stride_kv_block_descale];
+            k_descale = k_descale_ptr[scale_offset];
+            v_descale = v_descale_ptr[scale_offset];
         }

         // Prefetch V physical pages early - overlaps with GEMM0 computation
@@ -1556,10 +1556,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                    const index_t page_stride_v,
                    DropoutType& dropout,
                    float sink_v,
-                   const float* kv_block_descale_ptr,
+                   const float* k_descale_ptr,
+                   const float* v_descale_ptr,
                    index_t nblock_stride_kv_block_descale,
-                   index_t nhead_stride_kv_block_descale,
-                   index_t kv_stride_kv_block_descale) const
+                   index_t nhead_stride_kv_block_descale) const
     {
         return operator()(q_dram_block_window_tmp,
                           identity{},
@@ -1589,10 +1589,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                           page_stride_v,
                           dropout,
                           sink_v,
-                          kv_block_descale_ptr,
+                          k_descale_ptr,
+                          v_descale_ptr,
                           nblock_stride_kv_block_descale,
-                          nhead_stride_kv_block_descale,
-                          kv_stride_kv_block_descale);
+                          nhead_stride_kv_block_descale);
     }
 };
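
Host-side note: with the reused per-tensor pointer fields, a caller would
typically allocate two contiguous float tables of shape [num_block, num_kv_head]
and set the strides to match. A minimal sketch against fmha_batch_prefill_args;
the k_block_descale/v_block_descale buffers and num_kv_head variable are
illustrative, not part of the patch:

    // Hypothetical row-major [num_block, num_kv_head] descale tables.
    args.k_descale_ptr = k_block_descale.data(); // per-page K descales
    args.v_descale_ptr = v_block_descale.data(); // per-page V descales
    args.nblock_stride_kv_block_descale = num_kv_head; // step to next physical page
    args.nhead_stride_kv_block_descale  = 1;           // step to next KV head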