 #include <future>
 #include <random>
 #include <cstdlib>
+#include <exception>
+#include <iostream>
 
 #include "gpu.hpp" // createContext, createTensor, createKernel, dispatchKernel,
                    // wait, resetCommandBuffer, toCPU
@@ -615,64 +617,79 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
 
 inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
                                  const size_t K, const size_t N,
+                                 const size_t TM, const size_t TN,
+                                 const Shape &workgroupSize = {256, 1, 1},
                                  NumType precision = kf32) {
   std::string codeString(shaderTemplate);
   replaceAll(codeString, {{"{{precision}}", toString(precision)},
                           {"{{M}}", toString(M)},
                           {"{{K}}", toString(K)},
-                          {"{{N}}", toString(N)}});
-  return {codeString, {256, 1, 1}, precision};
+                          {"{{N}}", toString(N)},
+                          {"{{TM}}", toString(TM)},
+                          {"{{TN}}", toString(TN)}
+                          });
+  return {codeString, workgroupSize, precision};
 }
 
-
-
 // ─────────────────────────────────────────────────────────────────────────────
 // Optimised WGSL matrix‑multiply kernel using subgroupMatrixLoad/Store
 // and subgroupMatrixMultiplyAccumulate
 // ─────────────────────────────────────────────────────────────────────────────
 const char *kShaderSubgroupMatrixMultiply = R"(
+enable subgroups;
 enable chromium_experimental_subgroup_matrix;
 
-@group(0) @binding(0) var<storage, read> A: array<{{precision}}>;
-@group(0) @binding(1) var<storage, read> B: array<{{precision}}>;
-@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(@builtin(workgroup_id) wg: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>,
+        @builtin(subgroup_id) sgid: u32,
+        @builtin(subgroup_size) ssize: u32) {
 
-  // Each workgroup computes one 16x16 tile of C.
-  @compute @workgroup_size(256, 1, 1)
-  fn main(@builtin(workgroup_id) groupID: vec3<u32>) {
+  let rowStart: u32 = wg.x * 8u * {{TM}};
+  let colStart: u32 = wg.y * 8u * {{TN}};
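+  // Each workgroup owns an (8*TM) x (8*TN) tile of C whose top-left element is (rowStart, colStart).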
 
-  let tileRow = groupID.y;
-  let tileCol = groupID.x;
+  if (rowStart >= u32({{M}}) || colStart >= u32({{N}})) { return; }
 
-  let outRowStart = tileRow * 16u;
-  let outColStart = tileCol * 16u;
+  let baseA: u32 = rowStart * {{K}};
+  let baseB: u32 = colStart;
+  let cBase: u32 = rowStart * {{N}} + colStart;
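+  // Flat row-major offsets: A's tile begins at row rowStart, B's tile at column colStart,
+  // and C's tile at element (rowStart, colStart).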
 
-  if (outRowStart >= {{M}} || outColStart >= {{N}}) {
-    return;
-  }
+  var Ax: array<subgroup_matrix_left<{{precision}}, 8, 8>, {{TM}}>;
+  var Bx: array<subgroup_matrix_right<{{precision}}, 8, 8>, {{TN}}>;
 
-  var acc: subgroup_matrix_result<{{precision}}, 16, 16>;
+  // TM x TN accumulators (8x8 each)
+  var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;
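+  // The accumulator grid is indexed column-major (i + j*TM); variables declared without an
+  // initializer start zero-valued in WGSL, so accumulation begins at 0.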
 
-  let kTiles = ({{K}} + 15u) / 16u;
+  for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
+    workgroupBarrier();
+    for (var i: u32 = 0; i < {{TM}}; i++) {
+      Ax[i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + i * 8u*{{K}} + k, false, {{K}});
+    }
 
-  // Load the first tile and multiply to initialize accumulator
-  let a_tile_0 = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}}, true, {{K}});
-  let b_tile_0 = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, outColStart, true, {{N}});
-  acc = subgroupMatrixMultiply<{{precision}}>(a_tile_0, b_tile_0);
+    for (var i: u32 = 0; i < {{TN}}; i++) {
+      Bx[i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k*{{N}} + 8u * i, false, {{N}});
+    }
 
-  // Loop over the rest of the K-dimension
-  for (var kTile: u32 = 1u; kTile < kTiles; kTile = kTile + 1u) {
-    let k = kTile * 16u;
-    let a_tile = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}} + k, true, {{K}});
-    let b_tile = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, k * {{N}} + outColStart, true, {{N}});
-    acc = subgroupMatrixMultiplyAccumulate(a_tile, b_tile, acc);
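+    // Multiply-accumulate: each (i, j) pair folds Ax[i] * Bx[j] into its 8x8 accumulator.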
+    for (var i: u32 = 0; i < {{TM}}; i++) {
+      for (var j: u32 = 0; j < {{TN}}; j++) {
+        accxx[i+j*{{TM}}] = subgroupMatrixMultiplyAccumulate(Ax[i], Bx[j], accxx[i+j*{{TM}}]);
+      }
     }
+  }
 
-  subgroupMatrixStore(C, outRowStart * {{N}} + outColStart, acc, true, {{N}});
+  workgroupBarrier();
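+  // Write each 8x8 accumulator back to its sub-tile of C, stepping 8 rows per i and 8 columns per j.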
+  for (var i: u32 = 0; i < {{TM}}; i++) {
+    for (var j: u32 = 0; j < {{TN}}; j++) {
+      subgroupMatrixStore(&C, cBase + i * 8u * {{N}} + 8u * j, accxx[i+j*{{TM}}], false, {{N}});
+    }
+  }
 }
 )";
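+// Illustrative use (sizes chosen only as an example): createMatmul12(kShaderSubgroupMatrixMultiply,
+// 4096, 4096, 4096, /*TM=*/2, /*TN=*/4, /*workgroupSize=*/{64, 1, 1}, kf16) expands the template so
+// that each workgroup produces a 16x32 tile of C; selectMatmul's version 12 below does exactly this.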
 
-
 /**
  * @brief No-Op shader with matmul bindings for performance testing
  */
@@ -743,26 +760,30 @@ Kernel selectMatmul(Context &ctx, int version,
                     const Bindings</* input, weights, output */ 3> &bindings,
                     size_t M, size_t K, size_t N, NumType numtype) {
   Kernel kernel;
+  CompilationInfo info;
   if (version == 1) {
     Shape wgSize = {256, 1, 1};
     Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
     KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 2) {
     Shape wgSize = {16, 16, 1};
     LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
     KernelCode matmul =
         createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
+                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize),
+                          NoParam{}, &info);
   } else if (version == 3) {
     static constexpr size_t tileSize = 16;
     KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
                                       /*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
     kernel =
         createKernel(ctx, matmul, bindings,
-                     /*nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
+                     /*nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}),
+                     NoParam{}, &info);
   } else if (version == 4 || version == 6) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 4;
@@ -781,7 +802,8 @@ Kernel selectMatmul(Context &ctx, int version,
                             numtype,
                             /*Loop unrolling*/ version == 6 ? true : false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 5 || version == 7) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -799,7 +821,8 @@ Kernel selectMatmul(Context &ctx, int version,
                             numtype,
                             /*Loop unrolling*/ version == 7 ? true : false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 8 || version == 10) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -817,7 +840,8 @@ Kernel selectMatmul(Context &ctx, int version,
                             numtype,
                             /*Loop unrolling*/ true);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 9 || version == 11) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -834,18 +858,37 @@ Kernel selectMatmul(Context &ctx, int version,
                              /*wgSize*/ wgSize,
                              numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 12) {
     // f32: Subgroup matrix multiply
-    Shape wgSize = {256, 1, 1}; // One subgroup per workgroup
-    Shape nWorkgroups = {cdiv(N, 16), cdiv(M, 16), 1};
+    static constexpr size_t TM = 2;
+    static constexpr size_t TN = 4;
+    Shape wgSize = {64, 1, 1}; // One subgroup per workgroup
+    Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN), 1};
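+    // With TM = 2 and TN = 4 each workgroup covers a 16 x 32 tile of C; for example,
+    // M = N = 4096 would give nWorkgroups = {4096 / 16, 4096 / 32, 1} = {256, 128, 1}.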
     LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
     LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
     LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
-    KernelCode matmul =
-        createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, numtype);
-    kernel = createKernel(ctx, matmul, bindings, nWorkgroups);
+    KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, wgSize, numtype);
+    kernel = createKernel(ctx, matmul, bindings, nWorkgroups,
+                          NoParam{}, &info);
+  }
+
+  if (info.status != WGPUCompilationInfoRequestStatus_Success) {
+    LOG(kDefLog, kError, "Failed to compile shader");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kError, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Shader compiled successfully");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kInfo, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
   }
+
   return kernel;
 }
 
@@ -866,36 +909,49 @@ void runTest(int version, size_t M, size_t K, size_t N,
   devDescriptor.requiredFeatureCount = 1;
   devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data();
 
-  Context ctx;
-  if (numtype == kf16) {
-    ctx = createContext(
-        {}, {},
-        /*device descriptor, enabling f16 in WGSL*/
-        {
-            .requiredFeatureCount = 1,
-            .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()
-        });
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9).");
-      exit(1);
+  WGPUDawnTogglesDescriptor toggles = {};
+  toggles.chain.sType = WGPUSType_DawnTogglesDescriptor;
+  const char *enableList[] = {"allow_unsafe_apis"};
+  toggles.enabledToggles = enableList;
+  toggles.enabledToggleCount = 1;
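+  // allow_unsafe_apis is Dawn's toggle for experimental features; the subgroup-matrix
+  // feature requested below is typically gated behind it.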
+
+  WGPUDeviceDescriptor devDesc = {};
+  devDesc.nextInChain = &toggles.chain;
+  // Keep the feature list in a named array so requiredFeatures does not dangle
+  // (a temporary std::array would be destroyed before createContext runs).
+  static constexpr WGPUFeatureName requiredFeatures[] = {
+      WGPUFeatureName_ShaderF16,
+      WGPUFeatureName_Subgroups,
+      WGPUFeatureName_ChromiumExperimentalSubgroupMatrix};
+  devDesc.requiredFeatureCount = 3;
+  devDesc.requiredFeatures = requiredFeatures;
+  devDesc.uncapturedErrorCallbackInfo = WGPUUncapturedErrorCallbackInfo{
+      .callback = [](WGPUDevice const *device, WGPUErrorType type, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[Uncaptured %d] %.*s\n", (int)type, (int)msg.length, msg.data);
       }
-    if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)");
-      exit(1);
+  };
+  devDesc.deviceLostCallbackInfo = WGPUDeviceLostCallbackInfo{
+      .mode = WGPUCallbackMode_AllowSpontaneous,
+      .callback = [](WGPUDevice const *device, WGPUDeviceLostReason reason, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
       }
-  }
-
-  if (numtype == kf32) {
-    ctx = createContext({}, {}, {});
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
-        ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter or device");
-      // stop execution
-      exit(1);
-    } else {
-      LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  };
+
+  Context ctx = createContext({}, {}, devDesc);
+
+  WGPULoggingCallbackInfo logCb{
+      .callback = [](WGPULoggingType type, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[WGPU %d] %.*s\n", (int)type, (int)msg.length, msg.data);
       }
-  }
+  };
+  wgpuDeviceSetLoggingCallback(ctx.device, logCb);
+
+  if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
+      ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
+    LOG(kDefLog, kError, "Failed to create adapter or device");
+    // stop execution
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  }
 
   Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
   Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major
@@ -983,14 +1039,15 @@ const std::string versionToStr(int version){
     case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
     case 10: return "f16: 2D blocktiling with loop unrolling and vectorization (default)";
     case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
-    case 12: return "f32: Subgroup matrix multiply";
+    case 12: return "f16: Subgroup matrix multiply with transpose";
     default: return "Not specified";
   }
 }
 
 int main() {
+  std::cout << "Starting matmul test..." << std::endl;
   char *version_str = getenv("MATMUL_VERSION");
-  int version = version_str == NULL ? 12 : atoi(version_str);
+  int version = version_str == NULL ? 11 : atoi(version_str);
   // 1 == f32: No-Op
   // 2 == f32: naive matmul
   // 3 == f32: tiling
@@ -1002,8 +1059,8 @@ int main() {
   // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
   // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
   // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
-  // 12 == f32: Subgroup matrix multiply
-  bool enableF16 = version == 10 || version == 11;
+  // 12 == f16: Subgroup matrix multiply with transpose
+  bool enableF16 = version == 10 || version == 11 || version == 12;
   bool transposedInput = version == 9 || version == 11 || version == 12;
   NumType numtype = enableF16 ? kf16 : kf32;
 