Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1970,7 +1970,9 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
const bool has_mask = op->src[3] != nullptr;

if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
// note: always reserve the padding space to avoid graph reallocations
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
const bool has_kvpad = true;

if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
Expand All @@ -1979,7 +1981,8 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
} else {
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
const bool has_kvpad = true;

if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_NCPSG*(
Expand Down Expand Up @@ -2015,9 +2018,10 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);

// this optimization is not useful for the vector kernels
if (is_vec) {
return res;
}
// note: always reserve the blk buffer to avoid graph reallocations
//if (is_vec) {
// return res;
//}

const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
Expand All @@ -2044,13 +2048,16 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {

size_t res = 0;

if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
// note: always reserve the temp buffer to avoid graph reallocations
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
if (true) {
const int64_t nwg = 32;
const int64_t ne01_max = std::min(ne01, 32);

// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
}

return res;
Expand Down
Loading