Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,21 @@ void main() {
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif

#if defined(DATA_A_Q2_K)
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];

[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
sums[i] = ACC_TYPE_VEC2(0.0f);
}

#else
ACC_TYPE sums[WMITER * TM * WNITER * TN];

[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
}
#endif


for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
Expand Down Expand Up @@ -254,10 +264,15 @@ void main() {
block_b_to_registers(ib);

[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

#if defined(DATA_A_Q2_K)
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
#else
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;

#endif
const uint cache_a_idx = wsir * TM + cr;
sums[sums_idx] += mmq_dot_product(cache_a_idx);
}
}
Expand Down Expand Up @@ -287,6 +302,26 @@ void main() {

const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
#if defined(DATA_A_Q2_K)
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
#ifdef MUL_MAT_ID
if (dr_warp + 2 * cr < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
}
if (dr_warp + 2 * cr + 1 < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
}
#else
if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
}
if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
}
#endif // MUL_MAT_ID
}
#else
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
Expand All @@ -299,6 +334,7 @@ void main() {
}
#endif // MUL_MAT_ID
}
#endif
}
}
}
Expand Down
27 changes: 17 additions & 10 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -313,20 +313,27 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
}
}

// NOTE(review): this span looks like a diff rendering whose +/- markers were stripped, so the
// removed scalar version and the added two-component version of mmq_dot_product are interleaved
// (two signatures, one closing brace). Lines are tagged [old] / [new] below on that assumption;
// confirm against the actual file before relying on either form.
//
// Both forms accumulate two integer sums over eight packed words of cache_b.qs:
//   sum_d: dot product of the 2-bit A quants with B, weighted by the low nibble of the A scale.
//   sum_m: dot product of the high nibble of the A scale (replicated to all four bytes) with B.
// The sums are then handed to mul_q8_1 (defined outside this view) together with the A entry's
// `dm` and B's `ds`.
//
// [old] Scalar form: one cache_a entry (ib_a) per call, returns a single ACC_TYPE.
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t sum_d = 0;
int32_t sum_m = 0;
// [new] Paired form: processes cache_a[ib_a] and cache_a[ib_a + 1] against the same cache_b and
// returns the two results as .x (entry ib_a) and .y (entry ib_a + 1).
// NOTE(review): the visible call site in mul_mmq.comp passes `wsir * TM + cr` with cr < TM / 2,
// so successive calls would cover entries (cr, cr + 1), (cr + 1, cr + 2), ... while the
// write-back stores .x/.y to rows 2 * cr and 2 * cr + 1. Unless TM == 2, the caller presumably
// needs `wsir * TM + 2 * cr` — TODO confirm.
ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
i32vec2 sum_d = i32vec2(0);
i32vec2 sum_m = i32vec2(0);

[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
// [old] One scale byte and one qs word are shared by four consecutive iqs values (index iqs / 4).
const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
// [old] Shift by (iqs % 4) * 2 and mask with 0x03030303 to leave one 2-bit value in each byte.
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);

// [old] Low nibble of the scale weights the quant dot product; high nibble feeds sum_m.
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
// [new] Same extraction as above, done for both entries: .x from ib_a, .y from ib_a + 1.
const u8vec2 scale = u8vec2(cache_a[ib_a    ].scales[iqs / 4],
cache_a[ib_a + 1].scales[iqs / 4]);
const i32vec2 scale_m = i32vec2(int32_t(scale.x >> 4) * 0x01010101,
int32_t(scale.y >> 4) * 0x01010101); // Duplicate 8-bit value across 32-bits.
const i32vec2 qs_a = i32vec2((cache_a[ib_a    ].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303,
(cache_a[ib_a + 1].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);

// [new] Both entries are dotted against the same B word, cache_b.qs[iqs].
sum_d.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]) * (scale.x & 0xF);
sum_d.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]) * (scale.y & 0xF);

sum_m.x += dotPacked4x8EXT(scale_m.x, cache_b.qs[iqs]);
sum_m.y += dotPacked4x8EXT(scale_m.y, cache_b.qs[iqs]);
}

// [old] Single result for entry ib_a.
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
// [new] Each component uses its own entry's `dm`; `ds` comes from the shared cache_b.
return ACC_TYPE_VEC2(mul_q8_1(sum_d.x, sum_m.x, cache_a[ib_a    ].dm, cache_b.ds, 1),
mul_q8_1(sum_d.y, sum_m.y, cache_a[ib_a + 1].dm, cache_b.ds, 1));
}
#endif // MMQ_SHMEM
#endif
Expand Down
Loading