@@ -613,6 +613,66 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
613613 return {unrolledCode, workgroupSize, precision};
614614}
615615
616+ inline KernelCode createMatmul12 (const char *shaderTemplate, const size_t M,
617+ const size_t K, const size_t N,
618+ NumType precision = kf32) {
619+ std::string codeString (shaderTemplate);
620+ replaceAll (codeString, {{" {{precision}}" , toString (precision)},
621+ {" {{M}}" , toString (M)},
622+ {" {{K}}" , toString (K)},
623+ {" {{N}}" , toString (N)}});
624+ return {codeString, {256 , 1 , 1 }, precision};
625+ }
626+
627+
628+
629+ // ─────────────────────────────────────────────────────────────────────────────
630+ // Optimised WGSL matrix‑multiply kernel using subgroupMatrixLoad/Store
631+ // and subgroupMatrixMultiplyAccumulate
632+ // ─────────────────────────────────────────────────────────────────────────────
// ─────────────────────────────────────────────────────────────────────────────
// Optimised WGSL matrix-multiply kernel using subgroupMatrixLoad/Store
// and subgroupMatrixMultiplyAccumulate.
// Requires the chromium_experimental_subgroup_matrix WGSL extension.
// Placeholders {{precision}}, {{M}}, {{K}}, {{N}} are filled in by
// createMatmul12 before compilation.
// ─────────────────────────────────────────────────────────────────────────────
const char *kShaderSubgroupMatrixMultiply = R"(
enable chromium_experimental_subgroup_matrix;

@group(0) @binding(0) var<storage, read> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read> B: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;

// Each workgroup computes one 16x16 tile of C.
// NOTE(review): the loads/stores below are not masked, so M, N and K are
// assumed to be multiples of 16 — a ragged final tile (a tile that starts
// in range but extends past M, N, or K) would access out of bounds. Confirm
// callers only dispatch tile-aligned problem sizes.
@compute @workgroup_size(256, 1, 1)
fn main(@builtin(workgroup_id) groupID: vec3<u32>) {

    let tileRow = groupID.y;
    let tileCol = groupID.x;

    let outRowStart = tileRow * 16u;
    let outColStart = tileCol * 16u;

    // Skip workgroups whose tile starts entirely outside C.
    if (outRowStart >= {{M}} || outColStart >= {{N}}) {
        return;
    }

    var acc: subgroup_matrix_result<{{precision}}, 16, 16>;

    // Ceil-divide K into 16-wide tiles.
    let kTiles = ({{K}} + 15u) / 16u;

    // Load the first tile and multiply to initialize accumulator
    let a_tile_0 = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}}, true, {{K}});
    let b_tile_0 = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, outColStart, true, {{N}});
    acc = subgroupMatrixMultiply<{{precision}}>(a_tile_0, b_tile_0);

    // Loop over the rest of the K-dimension
    for (var kTile: u32 = 1u; kTile < kTiles; kTile = kTile + 1u) {
        let k = kTile * 16u;
        let a_tile = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}} + k, true, {{K}});
        let b_tile = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, k * {{N}} + outColStart, true, {{N}});
        acc = subgroupMatrixMultiplyAccumulate(a_tile, b_tile, acc);
    }

    subgroupMatrixStore(C, outRowStart * {{N}} + outColStart, acc, true, {{N}});
}
)";
674+
675+
616676/* *
617677 * @brief No-Op shader with matmul bindings for performance testing
618678 */
@@ -775,6 +835,16 @@ Kernel selectMatmul(Context &ctx, int version,
775835 numtype);
776836 kernel = createKernel (ctx, matmul, bindings,
777837 /* nWorkgroups*/ nWorkgroups);
838+ } else if (version == 12 ) {
839+ // f32: Subgroup matrix multiply
840+ Shape wgSize = {256 , 1 , 1 }; // One subgroup per workgroup
841+ Shape nWorkgroups = {cdiv (N, 16 ), cdiv (M, 16 ), 1 };
842+ LOG (kDefLog , kInfo , " M: %zu, K: %zu, N: %zu" , M, K, N);
843+ LOG (kDefLog , kInfo , " wgSize: ( %s )" , toString (wgSize).c_str ());
844+ LOG (kDefLog , kInfo , " nWorkgroups: ( %s )" , toString (nWorkgroups).c_str ());
845+ KernelCode matmul =
846+ createMatmul12 (kShaderSubgroupMatrixMultiply , M, K, N, numtype);
847+ kernel = createKernel (ctx, matmul, bindings, nWorkgroups);
778848 }
779849 return kernel;
780850}
@@ -865,7 +935,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
865935 // Use microsecond for more accurate time measurement
866936 auto duration =
867937 std::chrono::duration_cast<std::chrono::microseconds>(end - start);
868- float gflops = 2 * M * N *
938+ float gflops = 2 . 0f * M * N *
869939 K / // factor of 2 for multiplication & accumulation
870940 (static_cast <double >(duration.count ()) / 1000000.0 ) /
871941 1000000000.0 * static_cast <float >(nIter);
@@ -876,7 +946,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
876946 show<precision>(outputPtr.get (), M, N, " Output[0]" ).c_str ());
877947
878948 LOG (kDefLog , kInfo , " \n\n ===================================================================="
879- " ============\n Execution Time: (M = %d , K = %d , N = %d ) x %d iterations "
949+ " ============\n Execution Time: (M = %zu , K = %zu , N = %zu ) x %zu iterations "
880950 " :\n %.1f "
881951 " milliseconds / dispatch ~ %.2f "
882952 " GFLOPS\n ================================================================"
@@ -917,15 +987,16 @@ const std::string versionToStr(int version){
917987 case 7 : return " f32: 2D blocktiling with loop unrolling" ;
918988 case 8 : return " f32: 2D blocktiling with loop unrolling and vectorization" ;
919989 case 9 : return " f32: 2D blocktiling with loop unrolling, vectorization and transpose" ;
920- case 10 : return " f16: 2D blocktiling with loop unrolling and vectorization" ;
990+ case 10 : return " f16: 2D blocktiling with loop unrolling and vectorization (default) " ;
921991 case 11 : return " f16: 2D blocktiling with loop unrolling, vectorization and transpose" ;
992+ case 12 : return " f32: Subgroup matrix multiply" ;
922993 default : return " Not specified" ;
923994 }
924995}
925996
926997int main () {
927998 char * version_str = getenv (" MATMUL_VERSION" );
928- int version = version_str == NULL ? 10 : atoi (version_str);
999+ int version = version_str == NULL ? 12 : atoi (version_str);
9291000 // 1 == f32: No-Op
9301001 // 2 == f32: naive matmul
9311002 // 3 == f32: tiling
@@ -937,8 +1008,9 @@ int main() {
9371008 // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
9381009 // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
9391010 // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
1011+ // 12 == f32: Subgroup matrix multiply
9401012 bool enableF16 = version == 10 || version ==11 ;
941- bool transposedInput = version == 9 || version == 11 ;
1013+ bool transposedInput = version == 9 || version == 11 || version == 12 ;
9421014 NumType numtype = enableF16 ? kf16 : kf32;
9431015
9441016 size_t M, K, N; // Matrix dimensions
0 commit comments