2 changes: 2 additions & 0 deletions example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
@@ -78,12 +78,14 @@ def get_mask_cpp_check_expr(mask: str) -> str:
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
}

QSCALE_CHECK_MAP = {
"no": "quant_scale_enum::no_scale",
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"kv_blockscale": "quant_scale_enum::kv_blockscale",
}

BIAS_MAP = {
6 changes: 5 additions & 1 deletion example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
@@ -677,7 +677,7 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor"],
["pertensor", "kv_blockscale"],
get_mask_map(mask_impl).keys(),
["no"],
SUPPORTED_KV_MEMORY_LAYOUT,
@@ -740,6 +740,10 @@ def get_fwd_blobs(
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
# This ensures all tokens in a main loop iteration belong to the same page
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
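The page_size filter added to get_fwd_blobs above deserves a short illustration. The sketch below is not part of the PR and the helper names are made up; it shows, under the usual power-of-two assumption that makes page_size a multiple of kN0, why a main-loop tile of kN0 keys always falls inside a single page once page_size >= kN0, so only one K/V descale lookup is needed per iteration.

// Illustrative only (not ck_tile code): with page_size a multiple of kN0,
// every kN0-wide tile of keys lies inside exactly one page.
#include <cassert>

constexpr int first_page(int i_n0, int kN0, int page_size) { return (i_n0 * kN0) / page_size; }
constexpr int last_page(int i_n0, int kN0, int page_size)
{
    return (i_n0 * kN0 + kN0 - 1) / page_size;
}

int main()
{
    constexpr int kN0 = 128;
    for(int page_size : {128, 256, 512}) // page_size >= kN0, powers of two
        for(int i_n0 = 0; i_n0 < 64; ++i_n0)
            assert(first_page(i_n0, kN0, page_size) == last_page(i_n0, kN0, page_size));
    // With page_size < kN0 (the case filtered out above), a tile would span
    // multiple pages and need more than one descale per iteration.
    return 0;
}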
15 changes: 13 additions & 2 deletions example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -602,6 +602,13 @@ struct fmha_batch_prefill_args

std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;

// KV_BLOCKSCALE: Q uses a per-tensor descale; K/V use per-page descales.
// k_descale_ptr/v_descale_ptr are reused in KV_BLOCKSCALE mode:
// k_descale_ptr: [num_block, num_kv_head] - per-page K descales
// v_descale_ptr: [num_block, num_kv_head] - per-page V descales
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};

template <typename FmhaKernel>
@@ -1225,7 +1232,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.sink_ptr);
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
}
else
{ // create batch mode kernel arguments
@@ -1278,7 +1287,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.sink_ptr);
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
}
}();

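For reference, a minimal host-side sketch of how the two new stride members could be filled when the K/V descales are stored contiguously as a row-major [num_block, num_kv_head] float tensor. The helper function and variable names are hypothetical; only the four struct members used come from fmha_batch_prefill_args.

#include <vector>

// Hypothetical helper, assuming row-major descale[block][kv_head] storage.
void set_kv_block_descales(fmha_batch_prefill_args& args,
                           const std::vector<float>& k_block_descales, // num_block * num_kv_head
                           const std::vector<float>& v_block_descales, // num_block * num_kv_head
                           ck_tile::index_t num_kv_head)
{
    // element (block, head) sits at block * num_kv_head + head * 1
    args.k_descale_ptr                  = k_block_descales.data();
    args.v_descale_ptr                  = v_block_descales.data();
    args.nblock_stride_kv_block_descale = num_kv_head; // step to the next block
    args.nhead_stride_kv_block_descale  = 1;           // step to the next KV head
}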
13 changes: 10 additions & 3 deletions example/ck_tile/01_fmha/quant.hpp
@@ -14,9 +14,10 @@
// keep sync with BlockAttentionQuantScaleEnum
enum class quant_scale_enum
{
no_scale = 0,
pertensor = 1,
blockscale,
no_scale = 0,
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
};

struct quant_scale_info
@@ -31,6 +32,8 @@ struct quant_scale_info
os << "pt";
else if(type == quant_scale_enum::blockscale)
os << "bs";
else if(type == quant_scale_enum::kv_blockscale)
os << "kvbs";
}

static quant_scale_info decode(std::string str)
@@ -48,6 +51,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::blockscale;
}
else if(str == "kvbs" || str == "3")
{
info.type = quant_scale_enum::kv_blockscale;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);
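A minimal usage sketch of the new token, based only on the decode() shown above: both the short form and the numeric form select kv_blockscale.

quant_scale_info info_a = quant_scale_info::decode("kvbs");
quant_scale_info info_b = quant_scale_info::decode("3");
// info_a.type == quant_scale_enum::kv_blockscale
// info_b.type == quant_scale_enum::kv_blockscale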
include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp
@@ -10,9 +10,10 @@ namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockAttentionQuantScaleEnum
{
NO_SCALE = 0,
PERTENSOR = 1,
BLOCKSCALE,
NO_SCALE = 0,
PERTENSOR = 1,
BLOCKSCALE = 2,
KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
};

template <BlockAttentionQuantScaleEnum>
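BlockAttentionQuantScaleEnum above and quant_scale_enum in quant.hpp now carry the same four values, and the comment in quant.hpp asks to keep them in sync. Not part of the PR, but one way to enforce that, assuming a translation unit that sees both headers:

static_assert(static_cast<int>(quant_scale_enum::kv_blockscale) ==
                  static_cast<int>(ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE),
              "quant_scale_enum must stay in sync with BlockAttentionQuantScaleEnum");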
110 changes: 101 additions & 9 deletions include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
@@ -10,6 +10,7 @@
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"

#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
@@ -185,13 +186,45 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_lse = 0;
};

struct FmhaFwdCommonQScaleKargs
// PERTENSOR: Q/K/V all use per-tensor descales
struct FmhaFwdPerTensorQScaleKargs
{
const void* q_descale_ptr = nullptr;
const void* k_descale_ptr = nullptr;
const void* v_descale_ptr = nullptr;
};

// KV_BLOCKSCALE: Q per-tensor, K/V per-page descales
// K descale: [num_block, num_kv_head], V descale: [num_block, num_kv_head]
struct FmhaFwdKVBlockScaleKargs
{
const void* q_descale_ptr = nullptr; // Per-tensor Q descale
const void* k_descale_ptr = nullptr; // [num_block, num_kv_head]
const void* v_descale_ptr = nullptr; // [num_block, num_kv_head]
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};

// Helper template to select QScale Kargs type based on QScaleEnum
// EmptyType: fallback type used when QScaleEnum has no dedicated Kargs
// specialization (e.g., NO_SCALE -> FmhaFwdEmptyKargs<3>)
template <BlockAttentionQuantScaleEnum QScale, typename EmptyType>
struct GetQScaleKargs
{
using type = EmptyType;
};

template <typename EmptyType>
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::PERTENSOR, EmptyType>
{
using type = FmhaFwdPerTensorQScaleKargs;
};

template <typename EmptyType>
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::KV_BLOCKSCALE, EmptyType>
{
using type = FmhaFwdKVBlockScaleKargs;
};

struct FmhaFwdDropoutSeedOffset
{
template <typename T>
@@ -255,9 +288,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
GetQScaleKargs<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -276,9 +307,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
GetQScaleKargs<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -348,7 +377,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const void* sink_ptr = nullptr)
const void* sink_ptr = nullptr,
ck_tile::index_t nblock_stride_kv_block_descale = 0,
ck_tile::index_t nhead_stride_kv_block_descale = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -419,6 +450,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -495,7 +534,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const void* sink_ptr = nullptr)
const void* sink_ptr = nullptr,
ck_tile::index_t nblock_stride_kv_block_descale = 0,
ck_tile::index_t nhead_stride_kv_block_descale = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -563,6 +604,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -1157,11 +1206,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
const float scale_s = [&] {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
assert(kargs.q_descale_ptr != nullptr);
assert(kargs.k_descale_ptr != nullptr);
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));

return kargs.scale_s * q_descale * k_descale;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
// Q is per-tensor, K is per-page (handled in pipeline)
assert(kargs.q_descale_ptr != nullptr);
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
return kargs.scale_s * q_descale;
}
else
{
return kargs.scale_s;
@@ -1194,6 +1252,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
// TODO - move global load of descale to pipeline
assert(kargs.v_descale_ptr != nullptr);
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));

float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
@@ -1237,6 +1296,39 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
dropout,
sink_value);
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
// KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline
assert(kargs.k_descale_ptr != nullptr);
assert(kargs.v_descale_ptr != nullptr);
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);

return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout,
sink_value,
k_descale_ptr,
v_descale_ptr,
kargs.nblock_stride_kv_block_descale,
kargs.nhead_stride_kv_block_descale);
}
else
{
return FmhaPipeline{}(q_dram_window,
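The pipeline changes themselves are not in this diff, so the per-page lookup below is only a sketch of what the extra arguments imply; i_page and i_nhead_k are hypothetical names for the current page and KV-head indices. Because the codegen rejects page_size < kN0, a single page index covers the whole K/V tile of a main-loop iteration.

// Hypothetical per-iteration lookup inside the pipeline (not from this PR):
const float k_descale = k_descale_ptr[i_page * nblock_stride_kv_block_descale +
                                      i_nhead_k * nhead_stride_kv_block_descale];
const float v_descale = v_descale_ptr[i_page * nblock_stride_kv_block_descale +
                                      i_nhead_k * nhead_stride_kv_block_descale];
// k_descale would be applied to the S = Q*K^T tile (the kernel already folded the
// per-tensor Q descale into scale_s), and v_descale when accumulating P*V,
// mirroring what the PERTENSOR path does once for the whole tensor.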