Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,7 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"]
cond &= kernel_ctx.pipeline.F_qscale == "no"
cond &= kernel_ctx.pipeline.F_skip == "f"
cond &= kernel_ctx.pipeline.F_sink == "f"
return cond

return Product(name="Flash attention integration", rule=fit)
Expand Down
9 changes: 4 additions & 5 deletions example/ck_tile/01_fmha/fmha_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ struct fmha_fwd_args
// array [batch + 1]. (Used with padding)
const void* block_scale_seqstart_q_ptr;
const void* block_scale_seqstart_k_ptr;
const void* sink_ptr;
const void* sink_ptr = nullptr;

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
Expand Down Expand Up @@ -329,7 +329,7 @@ struct fmha_fwd_pagedkv_args
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
const void* sink_ptr;
const void* sink_ptr = nullptr;

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
Expand Down Expand Up @@ -413,7 +413,7 @@ struct fmha_fwd_splitkv_args
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
const void* sink_ptr;
const void* sink_ptr = nullptr;

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
Expand Down Expand Up @@ -490,7 +490,6 @@ struct fmha_fwd_appendkv_args
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr

const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
const void* sink_ptr;

ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
Expand Down Expand Up @@ -534,7 +533,7 @@ struct fmha_batch_prefill_args
// 1) +
// kargs.kv_last_page_lens[b]
const void* seqstart_q_ptr;
const void* sink_ptr;
const void* sink_ptr = nullptr;

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,18 +426,22 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
const auto init_m_val = [&]() {
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * LOG2E * scale_s);
return sink_v * LOG2E * scale_s;
else
set_tile(m, sink_v * LOG2E);
return sink_v * LOG2E;
#else
set_tile(m, sink_v);
return sink_v;
#endif
}();

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(m, init_m_val);
set_tile(l, SMPLComputeDataType{1.0f});
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,22 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
const auto init_m_val = [&]() {
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * C_LOG2E * scale_s);
return sink_v * C_LOG2E * scale_s;
else
set_tile(m, sink_v * C_LOG2E);
return sink_v * C_LOG2E;
#else
set_tile(m, sink_v);
return sink_v;
#endif
}();

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(m, init_m_val);
set_tile(l, SMPLComputeDataType{1.0f});
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,22 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};

const auto init_m_val = [&]() {
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
return sink_v * C_LOG2E * scale_s;
else
return sink_v * C_LOG2E;
#else
return sink_v;
#endif
}();

clear_tile(o_acc);
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
{
set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E});
set_tile(m, init_m_val);
set_tile(l, SMPLComputeDataType{1.0f});
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};

clear_tile(o_acc);
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
{
const auto init_m_val = [&]() {
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * C_LOG2E * scale_s);
return sink_v * C_LOG2E * scale_s;
else
set_tile(m, sink_v * C_LOG2E);
return sink_v * C_LOG2E;
#else
set_tile(m, sink_v);
return sink_v;
#endif
}();

clear_tile(o_acc);
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
{
set_tile(m, init_m_val);
set_tile(l, SMPLComputeDataType{1.0f});
}
else
Expand Down Expand Up @@ -302,15 +306,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split - 1);
if((__builtin_isinf_sign(sink_v) >= 0) && start >= end)
{
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * C_LOG2E * scale_s);
else
set_tile(m, sink_v * C_LOG2E);
#else
set_tile(m, sink_v);
#endif
set_tile(m, init_m_val);
set_tile(l, SMPLComputeDataType{1.0f});
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,22 @@ struct BlockFmhaPipelineQRKSVS
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
const auto init_m_val = [&]() {
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
set_tile(m, sink_v * scale_s * C_LOG2E);
return sink_v * scale_s * C_LOG2E;
else
set_tile(m, sink_v * C_LOG2E);
return sink_v * C_LOG2E;
#else
set_tile(m, sink_v);
return sink_v;
#endif
}();

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(m, init_m_val);
set_tile(l, SMPLComputeDataType{1.0f});
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,22 +279,25 @@ struct BlockFmhaPipelineQRKSVSAsync
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());

// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
const auto init_m_val = [&]() {
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
set_tile(m, sink_v * scale_s * LOG2E);
return sink_v * scale_s * LOG2E;
else
set_tile(m, sink_v * LOG2E);
return sink_v * LOG2E;
#else
set_tile(m, sink_v);
return sink_v;
#endif
}();

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(m, init_m_val);
set_tile(l, SMPLComputeDataType{1.0f});
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,18 +193,22 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
const auto init_m_val = [&]() {
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * C_LOG2E * scale_s);
return sink_v * C_LOG2E * scale_s;
else
set_tile(m, sink_v * C_LOG2E);
return sink_v * C_LOG2E;
#else
set_tile(m, sink_v);
return sink_v;
#endif
}();

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(m, init_m_val);
set_tile(l, SMPLComputeDataType{1.0f});
}
else
Expand Down Expand Up @@ -722,18 +726,22 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
const auto init_m_val = [&]() {
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * C_LOG2E * scale_s);
return sink_v * C_LOG2E * scale_s;
else
set_tile(m, sink_v * C_LOG2E);
return sink_v * C_LOG2E;
#else
set_tile(m, sink_v);
return sink_v;
#endif
}();

clear_tile(o_acc);
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(m, init_m_val);
set_tile(l, SMPLComputeDataType{1.0f});
}
else
Expand Down