diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c6503f0326031..d74f96352216a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -520,6 +520,7 @@ struct vk_device_struct { bool subgroup_shuffle; bool subgroup_ballot; bool subgroup_clustered; + bool subgroup_vote; bool multi_add; bool shader_int64; bool buffer_device_address; @@ -4186,6 +4187,9 @@ static vk_device ggml_vk_get_device(size_t idx) { device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot); + device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote); + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; @@ -13521,8 +13525,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm default: return false; } - if (!coopmat2 && !device->subgroup_shuffle) { - // scalar FA uses subgroupShuffle + if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) { + // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll return false; } return true; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 2255f9c168e6e..4bef48b006ce7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -7,6 +7,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_KHR_shader_subgroup_shuffle : enable +#extension GL_KHR_shader_subgroup_vote : enable #include "types.glsl" #include "flash_attn_base.glsl" @@ -108,6 +109,38 @@ void main() { [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + float max_mask = NEG_FLT_MAX_OVER_2; + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { + float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c][r] = m; + max_mask = max(max_mask, m); + } else { + masksh[c][r] = float(0); + } + } + } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } + } + float Sf[Br][cols_per_thread]; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { @@ -153,21 +186,6 @@ void main() { } if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); - } else { - masksh[c][r] = float(0); - } - } - } - barrier(); - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { float mvf = masksh[c * cols_per_iter + col_tid][r]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 8699fa6c9cbb7..cd82e4abfabc4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -7,6 +7,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_vote : enable #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable @@ -148,6 +149,37 @@ void main() { [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { + float mask_cache[Bc * Br / WorkGroupSize]; + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + float max_mask = NEG_FLT_MAX_OVER_2; + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { + float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + mask_cache[idx / WorkGroupSize] = m; + max_mask = max(max_mask, m); + } + } + } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } + } + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t c = (idx + tid) / (HSK / 4); @@ -208,7 +240,8 @@ void main() { uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); + float f = mask_cache[idx / WorkGroupSize]; + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index fcfc60a878544..617d85108698a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { return max(x, y); } +float16_t maxReduceFp16(const in float16_t x, const in float16_t y) { + return max(x, y); +} + ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) { return x; } @@ -142,21 +146,7 @@ void main() { [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - coopmat S = coopmat(0); - - coopmat K_T; - - uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); - S = coopMatMulAdd(Qf16, K_T, S); - - if (p.logit_softcap != 0.0f) { - [[unroll]] - for (int k = 0; k < S.length(); ++k) { - S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); - } - } - + coopmat mv; if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; @@ -164,12 +154,17 @@ void main() { tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t - coopmat mv; + coopmat mv, mvmax; coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); - S += slopeMat*coopmat(mv); + // skip the block if the mask is entirely -inf + coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); + if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { + continue; + } } else { tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); // Don't clamp against nem1 when GQA is enabled @@ -177,14 +172,37 @@ void main() { tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); - coopmat mv; + coopmat mvmax; coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); - S += slopeMat*coopmat(mv); + // skip the block if the mask is entirely -inf + coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); + if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { + continue; + } } } + coopmat S = coopmat(0); + + coopmat K_T; + + uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); + S = coopMatMulAdd(Qf16, K_T, S); + + if (p.logit_softcap != 0.0f) { + [[unroll]] + for (int k = 0; k < S.length(); ++k) { + S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); + } + } + + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + S += slopeMat*coopmat(mv); + } + // Clear padding elements to -inf, so they don't contribute to rowmax if (Clamp != 0 && ((j + 1) * Bc > KV ||