diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 5266e523b9d40..d77e463497400 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -198,11 +198,21 @@ void main() {
     uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
 #endif

+#if defined(DATA_A_Q2_K)
+    ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
+
+    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
+        sums[i] = ACC_TYPE_VEC2(0.0f);
+    }
+
+#else
     ACC_TYPE sums[WMITER * TM * WNITER * TN];

     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
         sums[i] = ACC_TYPE(0.0f);
     }
+#endif
+

     for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
         [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
@@ -254,10 +264,15 @@ void main() {
                 block_b_to_registers(ib);

                 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+
+    #if defined(DATA_A_Q2_K)
+                    [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+                        const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
+    #else
                     [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                        const uint cache_a_idx = wsir * TM + cr;
                         const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-
+    #endif
+                        const uint cache_a_idx = wsir * TM + cr;
                         sums[sums_idx] += mmq_dot_product(cache_a_idx);
                     }
                 }
@@ -287,6 +302,26 @@ void main() {
             const u16vec2 row_idx = row_ids[row_i - ic * BN];
 #endif // MUL_MAT_ID

+    #if defined(DATA_A_Q2_K)
+            [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+                const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
+#ifdef MUL_MAT_ID
+                if (dr_warp + 2 * cr < p.M) {
+                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
+                }
+                if (dr_warp + 2 * cr + 1 < p.M) {
+                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
+                }
+#else
+                if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
+                    data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
+                }
+                if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
+                    data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
+                }
+#endif // MUL_MAT_ID
+            }
+    #else
             [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                 const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
 #ifdef MUL_MAT_ID
@@ -299,6 +334,7 @@ void main() {
                 }
 #endif // MUL_MAT_ID
             }
+    #endif
         }
     }
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
index 51b5bb11e7b47..fb93c05c35cbb 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
@@ -313,20 +313,27 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
     }
 }

-ACC_TYPE mmq_dot_product(const uint ib_a) {
-    int32_t sum_d = 0;
-    int32_t sum_m = 0;
+ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
+    i32vec2 sum_d = i32vec2(0);
+    i32vec2 sum_m = i32vec2(0);

     [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
-        const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
-        const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
-        const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
-
-        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
-        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
+        const u8vec2 scale = u8vec2(cache_a[ib_a    ].scales[iqs / 4],
+                                    cache_a[ib_a + 1].scales[iqs / 4]);
+        const i32vec2 scale_m = i32vec2(int32_t(scale.x >> 4) * 0x01010101,
+                                        int32_t(scale.y >> 4) * 0x01010101); // Duplicate 8-bit value across 32-bits.
+        const i32vec2 qs_a = i32vec2((cache_a[ib_a    ].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303,
+                                     (cache_a[ib_a + 1].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
+
+        sum_d.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]) * (scale.x & 0xF);
+        sum_d.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]) * (scale.y & 0xF);
+
+        sum_m.x += dotPacked4x8EXT(scale_m.x, cache_b.qs[iqs]);
+        sum_m.y += dotPacked4x8EXT(scale_m.y, cache_b.qs[iqs]);
     }

-    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
+    return ACC_TYPE_VEC2(mul_q8_1(sum_d.x, sum_m.x, cache_a[ib_a    ].dm, cache_b.ds, 1),
+                         mul_q8_1(sum_d.y, sum_m.y, cache_a[ib_a + 1].dm, cache_b.ds, 1));
 }
 #endif // MMQ_SHMEM
 #endif
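
With this change the Q2_K path computes two result rows per mmq_dot_product call (blocks ib_a and ib_a + 1) and returns an ACC_TYPE_VEC2, which is why mul_mmq.comp halves the sums array and the cr loops and writes the .x/.y components to the row pair dr_warp + 2 * cr and dr_warp + 2 * cr + 1. Since cache_b.qs[iqs] is shared between the two rows, each B word is loaded once and reused by all four dotPacked4x8EXT calls per iteration, and mul_q8_1 folds the per-row d/dmin pair (cache_a[...].dm) into the two integer sums at the end.

The scale_m line relies on a packing identity: duplicating the high nibble of a Q2_K scale byte across all four bytes of a 32-bit word turns the packed 4x8 dot product against the B quants into (scale >> 4) * (b0 + b1 + b2 + b3), the subtracted minimum term of the Q2_K dot product. Below is a standalone C sketch of that arithmetic, not part of this diff; dot_packed_4x8() is a hypothetical plain-C stand-in for the signed form of dotPacked4x8EXT, and all values are made up for illustration.

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Hypothetical stand-in for dotPacked4x8EXT (signed/signed form):
     * sum of the lane-wise products of four signed 8-bit lanes. */
    static int32_t dot_packed_4x8(uint32_t a, uint32_t b) {
        int32_t sum = 0;
        for (int i = 0; i < 4; i++) {
            sum += (int32_t)(int8_t)(a >> (8 * i)) * (int32_t)(int8_t)(b >> (8 * i));
        }
        return sum;
    }

    int main(void) {
        const uint8_t  scale = 0xA3;       /* Q2_K scale byte: sc = 3 (low nibble), m = 10 (high nibble) */
        const uint32_t qs_a  = 0x03000102; /* four unpacked 2-bit quants q = {2, 1, 0, 3}, one per byte  */
        const int8_t   b[4]  = {5, -7, 100, -1}; /* four signed Q8_1 quants                              */

        /* Pack the B quants into one 32-bit word, low byte first. */
        uint32_t qs_b = 0;
        for (int i = 0; i < 4; i++) {
            qs_b |= (uint32_t)(uint8_t)b[i] << (8 * i);
        }

        /* Duplicate the high nibble across all four bytes; the packed dot
         * product then collapses to m * (b0 + b1 + b2 + b3), i.e. the Q2_K
         * minimum term, at the cost of one dot instruction on the GPU. */
        const uint32_t scale_m = (uint32_t)(scale >> 4) * 0x01010101u;

        const int32_t sum_d = dot_packed_4x8(qs_a, qs_b) * (scale & 0xF);
        const int32_t sum_m = dot_packed_4x8(scale_m, qs_b);

        /* Cross-check both sums against their scalar definitions. */
        const int8_t q[4] = {2, 1, 0, 3};
        int32_t ref_d = 0, ref_m = 0;
        for (int i = 0; i < 4; i++) {
            ref_d += (scale & 0xF) * q[i] * b[i];
            ref_m += (scale >> 4) * b[i];
        }
        assert(sum_d == ref_d && sum_m == ref_m);
        printf("sum_d = %d, sum_m = %d\n", sum_d, sum_m); /* sum_d = 0, sum_m = 970 */
        return 0;
    }

Because sum_m depends only on the scale nibble and the B quants, the minimum term never requires unpacking the 2-bit quants a second time; the shader pays one extra packed dot product per row and leaves the final d/dmin scaling to mul_q8_1.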