 #include <future>
 #include <random>
 #include <cstdlib>
+#include <exception>
+#include <iostream>

 #include "gpu.hpp" // createContext, createTensor, createKernel, dispatchKernel,
                    // wait, resetCommandBuffer, toCPU
@@ -615,64 +617,76 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si

 inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
                                  const size_t K, const size_t N,
+                                 const size_t TM, const size_t TN,
+                                 const Shape &workgroupSize = {256, 1, 1},
                                  NumType precision = kf32) {
   std::string codeString(shaderTemplate);
   replaceAll(codeString, {{"{{precision}}", toString(precision)},
                           {"{{M}}", toString(M)},
                           {"{{K}}", toString(K)},
-                          {"{{N}}", toString(N)}});
-  return {codeString, {256, 1, 1}, precision};
+                          {"{{N}}", toString(N)},
+                          {"{{TM}}", toString(TM)},
+                          {"{{TN}}", toString(TN)}
+                          });
+  return {codeString, workgroupSize, precision};
 }
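For orientation, a call along the following lines (a minimal sketch; the matrix sizes are illustrative, not taken from the benchmark) bakes the tile factors into the shader text at creation time, while {{workgroupSize}} is presumably substituted later by the KernelCode machinery, as with the other kernels in this file:

  // Hypothetical usage of createMatmul12; all values here are illustrative only.
  KernelCode code = createMatmul12(kShaderSubgroupMatrixMultiply,
                                   /*M=*/4096, /*K=*/4096, /*N=*/4096,
                                   /*TM=*/2, /*TN=*/4,
                                   /*workgroupSize=*/{64, 1, 1}, kf16);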

-
-
 // ─────────────────────────────────────────────────────────────────────────────
 // Optimised WGSL matrix‑multiply kernel using subgroupMatrixLoad/Store
 // and subgroupMatrixMultiplyAccumulate
 // ─────────────────────────────────────────────────────────────────────────────
 const char *kShaderSubgroupMatrixMultiply = R"(
+enable subgroups;
 enable chromium_experimental_subgroup_matrix;

-@group(0) @binding(0) var<storage, read> A: array<{{precision}}>;
-@group(0) @binding(1) var<storage, read> B: array<{{precision}}>;
-@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(@builtin(workgroup_id) wg: vec3<u32>) {

-// Each workgroup computes one 16x16 tile of C.
-@compute @workgroup_size(256, 1, 1)
-fn main(@builtin(workgroup_id) groupID: vec3<u32>) {
+    let rowStart: u32 = wg.x * 8u * {{TM}};
+    let colStart: u32 = wg.y * 8u * {{TN}};

-    let tileRow = groupID.y;
-    let tileCol = groupID.x;
+    if (rowStart >= u32({{M}}) || colStart >= u32({{N}})) { return; }

-    let outRowStart = tileRow * 16u;
-    let outColStart = tileCol * 16u;
+    let baseA: u32 = rowStart * {{K}};
+    let baseB: u32 = colStart;
+    let cBase: u32 = rowStart * {{N}} + colStart;

-    if (outRowStart >= {{M}} || outColStart >= {{N}}) {
-      return;
-    }
+    var Ax: array<subgroup_matrix_left<{{precision}}, 8, 8>, {{TM}}>;
+    var Bx: array<subgroup_matrix_right<{{precision}}, 8, 8>, {{TN}}>;

-    var acc: subgroup_matrix_result<{{precision}}, 16, 16>;
+    // TM x TN accumulators, each an 8x8 subgroup matrix
+    var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;

-    let kTiles = ({{K}} + 15u) / 16u;
+    for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
+      workgroupBarrier();
+      for (var i: u32 = 0; i < {{TM}}; i++) {
+        Ax[i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + i * 8u * {{K}} + k, false, {{K}});
+      }

-    // Load the first tile and multiply to initialize accumulator
-    let a_tile_0 = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}}, true, {{K}});
-    let b_tile_0 = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, outColStart, true, {{N}});
-    acc = subgroupMatrixMultiply<{{precision}}>(a_tile_0, b_tile_0);
+      for (var i: u32 = 0; i < {{TN}}; i++) {
+        Bx[i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k * {{N}} + 8u * i, false, {{N}});
+      }

-    // Loop over the rest of the K-dimension
-    for (var kTile: u32 = 1u; kTile < kTiles; kTile = kTile + 1u) {
-      let k = kTile * 16u;
-      let a_tile = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}} + k, true, {{K}});
-      let b_tile = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, k * {{N}} + outColStart, true, {{N}});
-      acc = subgroupMatrixMultiplyAccumulate(a_tile, b_tile, acc);
+      for (var i: u32 = 0; i < {{TM}}; i++) {
+        for (var j: u32 = 0; j < {{TN}}; j++) {
+          accxx[i + j * {{TM}}] = subgroupMatrixMultiplyAccumulate(Ax[i], Bx[j], accxx[i + j * {{TM}}]);
+        }
       }
+    }

-    subgroupMatrixStore(C, outRowStart * {{N}} + outColStart, acc, true, {{N}});
+    workgroupBarrier();
+    for (var i: u32 = 0; i < {{TM}}; i++) {
+      for (var j: u32 = 0; j < {{TN}}; j++) {
+        subgroupMatrixStore(&C, cBase + i * 8u * {{N}} + 8u * j, accxx[i + j * {{TM}}], false, {{N}});
+      }
+    }
 }
 )";

-
 /**
  * @brief No-Op shader with matmul bindings for performance testing
  */
@@ -743,26 +757,30 @@ Kernel selectMatmul(Context &ctx, int version,
                     const Bindings</* input, weights, output */ 3> &bindings,
                     size_t M, size_t K, size_t N, NumType numtype) {
   Kernel kernel;
+  CompilationInfo info;
   if (version == 1) {
     Shape wgSize = {256, 1, 1};
     Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
     KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 2) {
     Shape wgSize = {16, 16, 1};
     LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
     KernelCode matmul =
         createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
+                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize),
+                          NoParam{}, &info);
   } else if (version == 3) {
     static constexpr size_t tileSize = 16;
     KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
                                       /*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
     kernel =
         createKernel(ctx, matmul, bindings,
-                     /*nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
+                     /*nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}),
+                     NoParam{}, &info);
   } else if (version == 4 || version == 6) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 4;
@@ -781,7 +799,8 @@ Kernel selectMatmul(Context &ctx, int version,
                            numtype,
                            /*Loop unrolling*/ version == 6 ? true : false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 5 || version == 7) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -799,7 +818,8 @@ Kernel selectMatmul(Context &ctx, int version,
                            numtype,
                            /*Loop unrolling*/ version == 7 ? true : false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 8 || version == 10) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -817,7 +837,8 @@ Kernel selectMatmul(Context &ctx, int version,
                            numtype,
                            /*Loop unrolling*/ true);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 9 || version == 11) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -834,18 +855,37 @@ Kernel selectMatmul(Context &ctx, int version,
                                 /*wgSize*/ wgSize,
                                 numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 12) {
     // f32: Subgroup matrix multiply
-    Shape wgSize = {256, 1, 1}; // One subgroup per workgroup
-    Shape nWorkgroups = {cdiv(N, 16), cdiv(M, 16), 1};
+    static constexpr size_t TM = 2;
+    static constexpr size_t TN = 4;
+    Shape wgSize = {64, 1, 1}; // One subgroup per workgroup
+    Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN), 1};
     LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
     LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
     LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
-    KernelCode matmul =
-        createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, numtype);
-    kernel = createKernel(ctx, matmul, bindings, nWorkgroups);
+    KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, wgSize, numtype);
+    kernel = createKernel(ctx, matmul, bindings, nWorkgroups,
+                          NoParam{}, &info);
+  }
+
+  if (info.status != WGPUCompilationInfoRequestStatus_Success) {
+    LOG(kDefLog, kError, "Failed to compile shader");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kError, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Shader compiled successfully");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kInfo, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
   }
+
   return kernel;
 }
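The reporting block above indexes CompilationInfo as a set of parallel arrays. The exact definition lives in gpu.hpp; the sketch below only restates the shape implied by its use here and is not the authoritative declaration:

  struct CompilationInfo {                      // sketch inferred from usage above
    WGPUCompilationInfoRequestStatus status;    // overall compile result
    std::vector<std::string> messages;          // one entry per diagnostic
    std::vector<uint64_t> lineNums;             // line of each diagnostic
    std::vector<uint64_t> linePos;              // position within the line
  };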
 
@@ -866,36 +906,49 @@ void runTest(int version, size_t M, size_t K, size_t N,
   devDescriptor.requiredFeatureCount = 1;
   devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data();

-  Context ctx;
-  if (numtype == kf16) {
-    ctx = createContext(
-        {}, {},
-        /*device descriptor, enabling f16 in WGSL*/
-        {
-          .requiredFeatureCount = 1,
-          .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()
-        });
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9).");
-      exit(1);
+  WGPUDawnTogglesDescriptor toggles = {};
+  toggles.chain.sType = WGPUSType_DawnTogglesDescriptor;
+  const char *enableList[] = {"allow_unsafe_apis"};
+  toggles.enabledToggles = enableList;
+  toggles.enabledToggleCount = 1;
+
+  WGPUDeviceDescriptor devDesc = {};
+  devDesc.nextInChain = &toggles.chain;
+  // Keep the feature list in a named array; a temporary's data() pointer would
+  // dangle before createContext() reads the descriptor below.
+  const std::array<WGPUFeatureName, 3> requiredFeatures = {
+      WGPUFeatureName_ShaderF16,
+      WGPUFeatureName_Subgroups,
+      WGPUFeatureName_ChromiumExperimentalSubgroupMatrix};
+  devDesc.requiredFeatureCount = requiredFeatures.size();
+  devDesc.requiredFeatures = requiredFeatures.data();
+  devDesc.uncapturedErrorCallbackInfo = WGPUUncapturedErrorCallbackInfo{
+      .callback = [](WGPUDevice const *device, WGPUErrorType type, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[Uncaptured %d] %.*s\n", (int)type, (int)msg.length, msg.data);
       }
-    if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)");
-      exit(1);
+  };
+  devDesc.deviceLostCallbackInfo = WGPUDeviceLostCallbackInfo{
+      .mode = WGPUCallbackMode_AllowSpontaneous,
+      .callback = [](WGPUDevice const *device, WGPUDeviceLostReason reason, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
       }
-  }
-
-  if (numtype == kf32) {
-    ctx = createContext({}, {}, {});
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
-        ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter or device");
-      // stop execution
-      exit(1);
-    } else {
-      LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  };
+
+  Context ctx = createContext({}, {}, devDesc);
+
+  WGPULoggingCallbackInfo logCb{
+      .callback = [](WGPULoggingType type, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[WGPU %d] %.*s\n", (int)type, (int)msg.length, msg.data);
       }
-  }
+  };
+  wgpuDeviceSetLoggingCallback(ctx.device, logCb);
+
+  if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
+      ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
+    LOG(kDefLog, kError, "Failed to create adapter or device");
+    // stop execution
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  }

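WGPUFeatureName_ChromiumExperimentalSubgroupMatrix is only exposed by Dawn builds that allow unsafe APIs, so an explicit capability check right after context creation can fail with a clearer message than a later device-lost callback. A minimal sketch, assuming the Context struct exposes the adapter handle:

  // Hedged sketch: bail out early if the adapter lacks subgroup-matrix support.
  if (!wgpuAdapterHasFeature(ctx.adapter,
                             WGPUFeatureName_ChromiumExperimentalSubgroupMatrix)) {
    LOG(kDefLog, kError,
        "Adapter lacks subgroup matrix support; use MATMUL_VERSION <= 11 instead");
    exit(1);
  }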
   Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
   Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major
@@ -983,14 +1036,15 @@ const std::string versionToStr(int version){
     case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
     case 10: return "f16: 2D blocktiling with loop unrolling and vectorization (default)";
     case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
-    case 12: return "f32: Subgroup matrix multiply";
+    case 12: return "f16: Subgroup matrix multiply with transpose";
     default: return "Not specified";
   }
 }

 int main() {
+  std::cout << "Starting matmul test..." << std::endl;
   char *version_str = getenv("MATMUL_VERSION");
-  int version = version_str == NULL ? 12 : atoi(version_str);
+  int version = version_str == NULL ? 11 : atoi(version_str);
   // 1 == f32: No-Op
   // 2 == f32: naive matmul
   // 3 == f32: tiling
@@ -1002,8 +1056,8 @@ int main() {
   // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
   // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
   // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
-  // 12 == f32: Subgroup matrix multiply
-  bool enableF16 = version == 10 || version == 11;
+  // 12 == f16: Subgroup matrix multiply with transpose
+  bool enableF16 = version == 10 || version == 11 || version == 12;
   bool transposedInput = version == 9 || version == 11 || version == 12;
   NumType numtype = enableF16 ? kf16 : kf32;
