Skip to content

Commit 4d5d20b

Browse files
Add the subgroup matrix multiplication
1 parent 8c58124 commit 4d5d20b

File tree

1 file changed

+77
-5
lines changed

1 file changed

+77
-5
lines changed

examples/matmul/run.cpp

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,66 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
613613
return {unrolledCode, workgroupSize, precision};
614614
}
615615

616+
/**
 * @brief Instantiates a matmul shader template for fixed problem dimensions.
 *
 * Copies the template into a std::string, hands the {{precision}}, {{M}},
 * {{K}} and {{N}} placeholder/value pairs to replaceAll, and wraps the
 * resulting source in a KernelCode.
 *
 * @param shaderTemplate Shader source containing the placeholders above.
 * @param M Value substituted for {{M}}.
 * @param K Value substituted for {{K}}.
 * @param N Value substituted for {{N}}.
 * @param precision Numeric type substituted for {{precision}} and forwarded
 *                  to the returned KernelCode (defaults to kf32).
 * @return KernelCode built from the substituted source, a {256, 1, 1}
 *         workgroup size and @p precision.
 */
inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
                                 const size_t K, const size_t N,
                                 NumType precision = kf32) {
  std::string shaderSource = shaderTemplate;
  replaceAll(shaderSource, {{"{{precision}}", toString(precision)},
                            {"{{M}}", toString(M)},
                            {"{{K}}", toString(K)},
                            {"{{N}}", toString(N)}});
  return {shaderSource, {256, 1, 1}, precision};
}
626+
627+
628+
629+
// -----------------------------------------------------------------------------
// WGSL matrix-multiply kernel template built on the experimental Chromium
// subgroup-matrix builtins: subgroupMatrixLoad, subgroupMatrixStore,
// subgroupMatrixMultiply and subgroupMatrixMultiplyAccumulate.
//
// Placeholders substituted before compilation (see createMatmul12):
//   {{precision}}      element type of the A, B and C storage arrays
//   {{M}} {{K}} {{N}}  matrix dimensions, emitted as integer literals
//
// Addressing: A is indexed as M x K, B as K x N and C as M x N. Every offset
// has the form row * rowLength + col and every stride is a row length, i.e.
// row-major addressing, so each load/store passes col_major = false. The
// builtins take a pointer to the storage array, hence &A, &B and &C.
//
// Dispatch: workgroup (x, y) writes the 16x16 tile of C whose top-left element
// is (y * 16, x * 16).
//
// NOTE(review): tiles are always loaded and stored in full, so M, K and N are
// presumably expected to be multiples of 16 - confirm against the callers.
// -----------------------------------------------------------------------------
const char* kShaderSubgroupMatrixMultiply = R"(
enable chromium_experimental_subgroup_matrix;

@group(0) @binding(0) var<storage, read> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read> B: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;

// Each workgroup computes one 16x16 tile of C.
@compute @workgroup_size(256, 1, 1)
fn main(@builtin(workgroup_id) groupID: vec3<u32>) {

    let tileRow = groupID.y;
    let tileCol = groupID.x;

    let outRowStart = tileRow * 16u;
    let outColStart = tileCol * 16u;

    if (outRowStart >= {{M}} || outColStart >= {{N}}) {
        return;
    }

    var acc: subgroup_matrix_result<{{precision}}, 16, 16>;

    let kTiles = ({{K}} + 15u) / 16u;

    // Load the first tile and multiply to initialize accumulator
    let a_tile_0 = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(&A, outRowStart * {{K}}, false, {{K}});
    let b_tile_0 = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(&B, outColStart, false, {{N}});
    acc = subgroupMatrixMultiply<{{precision}}>(a_tile_0, b_tile_0);

    // Loop over the rest of the K-dimension
    for (var kTile: u32 = 1u; kTile < kTiles; kTile = kTile + 1u) {
        let k = kTile * 16u;
        let a_tile = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(&A, outRowStart * {{K}} + k, false, {{K}});
        let b_tile = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(&B, k * {{N}} + outColStart, false, {{N}});
        acc = subgroupMatrixMultiplyAccumulate(a_tile, b_tile, acc);
    }

    subgroupMatrixStore(&C, outRowStart * {{N}} + outColStart, acc, false, {{N}});
}
)";
674+
675+
616676
/**
617677
* @brief No-Op shader with matmul bindings for performance testing
618678
*/
@@ -775,6 +835,16 @@ Kernel selectMatmul(Context &ctx, int version,
775835
numtype);
776836
kernel = createKernel(ctx, matmul, bindings,
777837
/*nWorkgroups*/ nWorkgroups);
838+
} else if (version == 12) {
839+
// f32: Subgroup matrix multiply
840+
Shape wgSize = {256, 1, 1}; // One subgroup per workgroup
841+
Shape nWorkgroups = {cdiv(N, 16), cdiv(M, 16), 1};
842+
LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
843+
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
844+
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
845+
KernelCode matmul =
846+
createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, numtype);
847+
kernel = createKernel(ctx, matmul, bindings, nWorkgroups);
778848
}
779849
return kernel;
780850
}
@@ -865,7 +935,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
865935
// Use microsecond for more accurate time measurement
866936
auto duration =
867937
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
868-
float gflops = 2 * M * N *
938+
float gflops = 2.0f * M * N *
869939
K / // factor of 2 for multiplication & accumulation
870940
(static_cast<double>(duration.count()) / 1000000.0) /
871941
1000000000.0 * static_cast<float>(nIter);
@@ -876,7 +946,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
876946
show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());
877947

878948
LOG(kDefLog, kInfo, "\n\n===================================================================="
879-
"============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations "
949+
"============\nExecution Time: (M = %zu, K = %zu, N = %zu) x %zu iterations "
880950
":\n%.1f "
881951
"milliseconds / dispatch ~ %.2f "
882952
"GFLOPS\n================================================================"
@@ -917,15 +987,16 @@ const std::string versionToStr(int version){
917987
case 7: return "f32: 2D blocktiling with loop unrolling";
918988
case 8: return "f32: 2D blocktiling with loop unrolling and vectorization";
919989
case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
920-
case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
990+
case 10: return "f16: 2D blocktiling with loop unrolling and vectorization (default)";
921991
case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
992+
case 12: return "f32: Subgroup matrix multiply";
922993
default: return "Not specified";
923994
}
924995
}
925996

926997
int main() {
927998
char* version_str = getenv("MATMUL_VERSION");
928-
int version = version_str == NULL ? 10 : atoi(version_str);
999+
int version = version_str == NULL ? 12 : atoi(version_str);
9291000
// 1 == f32: No-Op
9301001
// 2 == f32: naive matmul
9311002
// 3 == f32: tiling
@@ -937,8 +1008,9 @@ int main() {
9371008
// 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
9381009
// 10 == f16: 2D blocktiling with loop unrolling and vectorization
9391010
// 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
1011+
// 12 == f32: Subgroup matrix multiply (default)
9401012
bool enableF16 = version == 10 || version ==11;
941-
bool transposedInput = version == 9 || version == 11;
1013+
bool transposedInput = version == 9 || version == 11 || version == 12;
9421014
NumType numtype = enableF16 ? kf16 : kf32;
9431015

9441016
size_t M, K, N; // Matrix dimensions

0 commit comments

Comments
 (0)