Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ struct vk_device_struct {
bool subgroup_shuffle;
bool subgroup_ballot;
bool subgroup_clustered;
bool subgroup_vote;
bool multi_add;
bool shader_int64;
bool buffer_device_address;
Expand Down Expand Up @@ -4186,6 +4187,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);

device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);

const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;

device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
Expand Down Expand Up @@ -13521,8 +13525,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
default:
return false;
}
if (!coopmat2 && !device->subgroup_shuffle) {
// scalar FA uses subgroupShuffle
if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
// scalar/coopmat1 FA uses subgroupShuffle/subgroupAll
return false;
}
return true;
Expand Down
48 changes: 33 additions & 15 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_vote : enable

#include "types.glsl"
#include "flash_attn_base.glsl"
Expand Down Expand Up @@ -108,6 +109,38 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {

if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
max_mask = max(max_mask, m);
} else {
masksh[c][r] = float(0);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}

float Sf[Br][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Expand Down Expand Up @@ -153,21 +186,6 @@ void main() {
}

if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
} else {
masksh[c][r] = float(0);
}
}
}
barrier();

[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float mvf = masksh[c * cols_per_iter + col_tid][r];
Expand Down
35 changes: 34 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable

Expand Down Expand Up @@ -148,6 +149,37 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {

float mask_cache[Bc * Br / WorkGroupSize];
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
mask_cache[idx / WorkGroupSize] = m;
max_mask = max(max_mask, m);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}

[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
Expand Down Expand Up @@ -208,7 +240,8 @@ void main() {
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
float f = mask_cache[idx / WorkGroupSize];
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
}
}
}
Expand Down
56 changes: 37 additions & 19 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return max(x, y);
}

float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
return max(x, y);
}

ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return x;
}
Expand Down Expand Up @@ -142,49 +146,63 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {

coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;

uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
S = coopMatMulAdd(Qf16, K_T, S);

if (p.logit_softcap != 0.0f) {
[[unroll]]
for (int k = 0; k < S.length(); ++k) {
S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
}
}

coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

if (nem1_bounds_check) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t

coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;

coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);

coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;

coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}

coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;

uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
S = coopMatMulAdd(Qf16, K_T, S);

if (p.logit_softcap != 0.0f) {
[[unroll]]
for (int k = 0; k < S.length(); ++k) {
S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
}
}

if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}

// Clear padding elements to -inf, so they don't contribute to rowmax
if (Clamp != 0 &&
((j + 1) * Bc > KV ||
Expand Down
Loading