
Commit 21b2f6f

Fix the WGSL code of subgroup-matrix multiplication
1 parent 4ef6361 commit 21b2f6f

File tree

examples/matmul/Makefile
examples/matmul/run.cpp
gpu.hpp

3 files changed: +138 -80 lines

examples/matmul/Makefile

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ run: ./build/$(TARGET)
 	$(LIBSPEC) && ./build/$(TARGET)
 
 debug: run.cpp
-	mkdir -p build && $(CXX) $(FLAGS) -g -fsanitize=address -fno-omit-frame-pointer -Wall -o ./build/$(TARGET)
+	mkdir -p build && $(CXX) $(FLAGS) -g -fsanitize=address -fno-omit-frame-pointer -fasynchronous-unwind-tables -Wall -o ./build/$(TARGET)
 
 run_with_metal_profiler: ./build/$(TARGET)_with_metal_profiler
 	$(LIBSPEC) && export METAL_CAPTURE_ENABLED=1 && ./build/$(TARGET)_with_metal_profiler

examples/matmul/run.cpp

Lines changed: 131 additions & 74 deletions
@@ -3,6 +3,8 @@
 #include <future>
 #include <random>
 #include <cstdlib>
+#include <exception>
+#include <iostream>
 
 #include "gpu.hpp" // createContext, createTensor, createKernel, dispatchKernel,
                    // wait, resetCommandBuffer, toCPU
@@ -615,64 +617,79 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
 
 inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
                                  const size_t K, const size_t N,
+                                 const size_t TM, const size_t TN,
+                                 const Shape &workgroupSize = {256, 1, 1},
                                  NumType precision = kf32) {
   std::string codeString(shaderTemplate);
   replaceAll(codeString, {{"{{precision}}", toString(precision)},
                           {"{{M}}", toString(M)},
                           {"{{K}}", toString(K)},
-                          {"{{N}}", toString(N)}});
-  return {codeString, {256, 1, 1}, precision};
+                          {"{{N}}", toString(N)},
+                          {"{{TM}}", toString(TM)},
+                          {"{{TN}}", toString(TN)}
+                          });
+  return {codeString, workgroupSize, precision};
 }
 
-
-
 // ─────────────────────────────────────────────────────────────────────────────
 // Optimised WGSL matrix‑multiply kernel using subgroupMatrixLoad/Store
 // and subgroupMatrixMultiplyAccumulate
 // ─────────────────────────────────────────────────────────────────────────────
 const char* kShaderSubgroupMatrixMultiply = R"(
+enable subgroups;
 enable chromium_experimental_subgroup_matrix;
 
-@group(0) @binding(0) var<storage, read> A: array<{{precision}}>;
-@group(0) @binding(1) var<storage, read> B: array<{{precision}}>;
-@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(@builtin(workgroup_id) wg: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>,
+        @builtin(subgroup_id) sgid: u32,
+        @builtin(subgroup_size) ssize: u32) {
 
-// Each workgroup computes one 16x16 tile of C.
-@compute @workgroup_size(256, 1, 1)
-fn main(@builtin(workgroup_id) groupID: vec3<u32>) {
+  let rowStart: u32 = wg.x * 8u * {{TM}};
+  let colStart: u32 = wg.y * 8u * {{TN}};
 
-  let tileRow = groupID.y;
-  let tileCol = groupID.x;
+  if (rowStart >= u32({{M}}) || colStart >= u32({{N}})) { return; }
 
-  let outRowStart = tileRow * 16u;
-  let outColStart = tileCol * 16u;
+  let baseA: u32 = rowStart * {{K}};
+  let baseB: u32 = colStart;
+  let cBase: u32 = rowStart * {{N}} + colStart;
 
-  if (outRowStart >= {{M}} || outColStart >= {{N}}) {
-    return;
-  }
+  var Ax: array<subgroup_matrix_left<{{precision}}, 8, 8>, {{TM}}>;
+  var Bx: array<subgroup_matrix_right<{{precision}}, 8, 8>, {{TN}}>;
 
-  var acc: subgroup_matrix_result<{{precision}}, 16, 16>;
+  // TM x TN accumulators (8x8 each)
+  var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;
 
-  let kTiles = ({{K}} + 15u) / 16u;
+  for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
+    workgroupBarrier();
+    for (var i: u32 = 0; i < {{TM}}; i++) {
+      Ax[i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + i * 8u*{{K}} + k, false, {{K}});
+    }
 
-  // Load the first tile and multiply to initialize accumulator
-  let a_tile_0 = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}}, true, {{K}});
-  let b_tile_0 = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, outColStart, true, {{N}});
-  acc = subgroupMatrixMultiply<{{precision}}>(a_tile_0, b_tile_0);
+    for (var i: u32 = 0; i < {{TN}}; i++) {
+      Bx[i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k*{{N}} + 8u * i, false, {{N}});
+    }
 
-  // Loop over the rest of the K-dimension
-  for (var kTile: u32 = 1u; kTile < kTiles; kTile = kTile + 1u) {
-    let k = kTile * 16u;
-    let a_tile = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}} + k, true, {{K}});
-    let b_tile = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, k * {{N}} + outColStart, true, {{N}});
-    acc = subgroupMatrixMultiplyAccumulate(a_tile, b_tile, acc);
+    for (var i: u32 = 0; i < {{TM}}; i++) {
+      for (var j: u32 = 0; j < {{TN}}; j++) {
+        accxx[i+j*{{TM}}] = subgroupMatrixMultiplyAccumulate(Ax[i], Bx[j], accxx[i+j*{{TM}}]);
+      }
   }
+  }
 
-  subgroupMatrixStore(C, outRowStart * {{N}} + outColStart, acc, true, {{N}});
+  workgroupBarrier();
+  for (var i: u32 = 0; i < {{TM}}; i++) {
+    for (var j: u32 = 0; j < {{TN}}; j++) {
+      subgroupMatrixStore(&C, cBase + i * 8u * {{N}} + 8u * j, accxx[i+j*{{TM}}], false, {{N}});
+    }
+  }
 }
 )";
 
-
 /**
  * @brief No-Op shader with matmul bindings for performance testing
  */
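
For reference (not part of the diff): the rewritten shader tiles C into (8*TM) x (8*TN) blocks, with wg.x selecting a block of rows and wg.y a block of columns; the K dimension is walked in steps of 8 while TM*TN accumulators of 8x8 subgroup matrices (indexed accxx[i + j*TM]) collect the partial products. The plain C++ sketch below mirrors that blocking and index arithmetic on the CPU, assuming row-major A (MxK), B (KxN) and C (MxN) and a zero-initialized C; the function name and layout assumptions are illustrative only.

    #include <cstddef>
    #include <vector>

    // CPU reference of the kernel's blocking: the "workgroup" (wgX, wgY) adds its
    // (8*TM) x (8*TN) tile of A*B into C. C must be zeroed by the caller; the GPU
    // kernel instead accumulates in subgroup-matrix registers and stores once.
    template <std::size_t TM, std::size_t TN>
    void tileReference(const std::vector<float> &A, const std::vector<float> &B,
                       std::vector<float> &C, std::size_t M, std::size_t K,
                       std::size_t N, std::size_t wgX, std::size_t wgY) {
      const std::size_t rowStart = wgX * 8 * TM;          // wg.x -> rows of C
      const std::size_t colStart = wgY * 8 * TN;          // wg.y -> columns of C
      if (rowStart >= M || colStart >= N) return;
      const std::size_t baseA = rowStart * K;             // first row of the A panel
      const std::size_t baseB = colStart;                 // first column of the B panel
      const std::size_t cBase = rowStart * N + colStart;  // top-left corner of the C tile
      for (std::size_t k = 0; k < K; k += 8) {            // same K step as the shader
        for (std::size_t i = 0; i < TM; ++i) {            // 8x8 block row
          for (std::size_t j = 0; j < TN; ++j) {          // 8x8 block column
            for (std::size_t r = 0; r < 8; ++r) {
              for (std::size_t c = 0; c < 8; ++c) {
                float acc = 0.0f;
                for (std::size_t kk = 0; kk < 8; ++kk) {
                  acc += A[baseA + i * 8 * K + k + r * K + kk] *   // offsets match the
                         B[baseB + k * N + 8 * j + kk * N + c];    // subgroupMatrixLoad calls
                }
                C[cBase + i * 8 * N + 8 * j + r * N + c] += acc;   // matches subgroupMatrixStore
              }
            }
          }
        }
      }
    }
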
@@ -743,26 +760,30 @@ Kernel selectMatmul(Context &ctx, int version,
                     const Bindings</* input, weights, output */ 3> &bindings,
                     size_t M, size_t K, size_t N, NumType numtype) {
   Kernel kernel;
+  CompilationInfo info;
   if (version == 1) {
     Shape wgSize = {256, 1, 1};
     Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
     KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 2) {
     Shape wgSize = {16, 16, 1};
     LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
     KernelCode matmul =
         createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
+                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize),
+                          NoParam{}, &info);
   } else if (version == 3) {
     static constexpr size_t tileSize = 16;
     KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
                                       /*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
     kernel =
         createKernel(ctx, matmul, bindings,
-                     /* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
+                     /* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}),
+                     NoParam{}, &info);
   } else if (version == 4 || version == 6) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 4;
@@ -781,7 +802,8 @@ Kernel selectMatmul(Context &ctx, int version,
                       numtype,
                       /*Loop unrolling*/ version == 6 ? true: false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 5 || version == 7) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -799,7 +821,8 @@ Kernel selectMatmul(Context &ctx, int version,
                       numtype,
                       /*Loop unrolling*/ version == 7 ? true: false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 8 || version == 10) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -817,7 +840,8 @@ Kernel selectMatmul(Context &ctx, int version,
                       numtype,
                       /*Loop unrolling*/ true);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 9 || version == 11) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -834,18 +858,37 @@ Kernel selectMatmul(Context &ctx, int version,
                       /*wgSize*/ wgSize,
                       numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 12) {
     // f32: Subgroup matrix multiply
-    Shape wgSize = {256, 1, 1}; // One subgroup per workgroup
-    Shape nWorkgroups = {cdiv(N, 16), cdiv(M, 16), 1};
+    static constexpr size_t TM = 2;
+    static constexpr size_t TN = 4;
+    Shape wgSize = {64, 1, 1}; // One subgroup per workgroup
+    Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN), 1};
     LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
     LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
     LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
-    KernelCode matmul =
-        createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, numtype);
-    kernel = createKernel(ctx, matmul, bindings, nWorkgroups);
+    KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, wgSize, numtype);
+    kernel = createKernel(ctx, matmul, bindings, nWorkgroups,
+                          NoParam{}, &info);
+  }
+
+  if (info.status != WGPUCompilationInfoRequestStatus_Success) {
+    LOG(kDefLog, kError, "Failed to compile shader");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kError, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Shader compiled successfully");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kInfo, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
   }
+
   return kernel;
 }
 
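
For reference (not part of the diff): with TM = 2 and TN = 4 each workgroup of version 12 now writes a 16 x 32 tile of C, and wg.x indexes rows where the old dispatch ({cdiv(N, 16), cdiv(M, 16), 1}) put columns first. A standalone check of the new dispatch arithmetic, with illustrative matrix sizes and a local cdiv standing in for the gpu.hpp helper:

    #include <cstddef>
    #include <cstdio>

    int main() {
      const std::size_t TM = 2, TN = 4;      // constants from selectMatmul, version 12
      const std::size_t M = 4096, N = 4096;  // example sizes, not taken from the commit
      auto cdiv = [](std::size_t a, std::size_t b) { return (a + b - 1) / b; };
      const std::size_t wgX = cdiv(M, 8 * TM); // rows    -> wg.x: 4096 / 16 = 256
      const std::size_t wgY = cdiv(N, 8 * TN); // columns -> wg.y: 4096 / 32 = 128
      std::printf("nWorkgroups = {%zu, %zu, 1}\n", wgX, wgY);
      return 0;
    }
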
@@ -866,36 +909,49 @@ void runTest(int version, size_t M, size_t K, size_t N,
   devDescriptor.requiredFeatureCount = 1;
   devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data();
 
-  Context ctx;
-  if (numtype == kf16) {
-    ctx = createContext(
-        {}, {},
-        /*device descriptor, enabling f16 in WGSL*/
-        {
-            .requiredFeatureCount = 1,
-            .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()
-        });
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9).");
-      exit(1);
+  WGPUDawnTogglesDescriptor toggles = {};
+  toggles.chain.sType = WGPUSType_DawnTogglesDescriptor;
+  const char* enableList[] = {"allow_unsafe_apis"};
+  toggles.enabledToggles = enableList;
+  toggles.enabledToggleCount = 1;
+
+  WGPUDeviceDescriptor devDesc = {};
+  devDesc.nextInChain = &toggles.chain;
+  devDesc.requiredFeatureCount = 3,
+  devDesc.requiredFeatures = std::array{
+      WGPUFeatureName_ShaderF16,
+      WGPUFeatureName_Subgroups,
+      WGPUFeatureName_ChromiumExperimentalSubgroupMatrix
+  }.data();
+  devDesc.uncapturedErrorCallbackInfo = WGPUUncapturedErrorCallbackInfo {
+    .callback = [](WGPUDevice const * device, WGPUErrorType type, WGPUStringView msg, void*, void*) {
+      LOG(kDefLog, kError, "[Uncaptured %d] %.*s\n", (int)type, (int)msg.length, msg.data);
     }
-    if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)");
-      exit(1);
+  };
+  devDesc.deviceLostCallbackInfo = WGPUDeviceLostCallbackInfo {
+    .mode = WGPUCallbackMode_AllowSpontaneous,
+    .callback = [](WGPUDevice const * device, WGPUDeviceLostReason reason, WGPUStringView msg, void*, void*) {
+      LOG(kDefLog, kError, "[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
     }
-  }
-
-  if (numtype == kf32) {
-    ctx = createContext({}, {}, {});
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
-        ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter or device");
-      // stop execution
-      exit(1);
-    } else {
-      LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  };
+
+  Context ctx = createContext({}, {}, devDesc);
+
+  WGPULoggingCallbackInfo logCb{
+    .callback = [](WGPULoggingType type, WGPUStringView msg, void*, void*) {
+      LOG(kDefLog, kError, "[WGPU %d] %.*s\n", (int)type, (int)msg.length, msg.data);
     }
-  }
+  };
+  wgpuDeviceSetLoggingCallback(ctx.device, logCb);
+
+  if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
+      ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
+    LOG(kDefLog, kError, "Failed to create adapter or device");
+    // stop execution
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  }
 
   Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
   Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major
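
For reference (not part of the diff): WGPUFeatureName_ChromiumExperimentalSubgroupMatrix is an experimental feature, which is why the descriptor above chains the Dawn "allow_unsafe_apis" toggle before requesting it. A sketch of a friendlier diagnostic than a device-lost message, assuming the webgpu.h in use provides wgpuAdapterHasFeature and that gpu.hpp's Context exposes the adapter as ctx.adapter:

    // Sketch only: after createContext, ask the adapter whether the experimental
    // feature exists and point the user at an f32 fallback if it does not.
    if (!wgpuAdapterHasFeature(ctx.adapter,
                               WGPUFeatureName_ChromiumExperimentalSubgroupMatrix)) {
      LOG(kDefLog, kError,
          "Adapter does not expose chromium-experimental-subgroup-matrix; "
          "try an f32 kernel instead (e.g. export MATMUL_VERSION=9).");
      exit(1);
    }
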
@@ -983,14 +1039,15 @@ const std::string versionToStr(int version){
     case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
     case 10: return "f16: 2D blocktiling with loop unrolling and vectorization (default)";
     case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
-    case 12: return "f32: Subgroup matrix multiply";
+    case 12: return "f16: Subgroup matrix multiply with transpose";
     default: return "Not specified";
   }
 }
 
 int main() {
+  std::cout << "Starting matmul test..." << std::endl;
   char* version_str = getenv("MATMUL_VERSION");
-  int version = version_str == NULL ? 12 : atoi(version_str);
+  int version = version_str == NULL ? 11 : atoi(version_str);
   // 1 == f32: No-Op
   // 2 == f32: naive matmul
   // 3 == f32: tiling
@@ -1002,8 +1059,8 @@ int main() {
   // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
   // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
   // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
-  // 12 == f32: Subgroup matrix multiply
-  bool enableF16 = version == 10 || version ==11;
+  // 12 == f16: Subgroup matrix multiply with transpose
+  bool enableF16 = version == 10 || version ==11 || version == 12;
   bool transposedInput = version == 9 || version == 11 || version == 12;
   NumType numtype = enableF16 ? kf16 : kf32;
 
gpu.hpp

Lines changed: 6 additions & 5 deletions
@@ -412,7 +412,7 @@ struct KernelCode {
     }
     replaceAll(data, "{{workgroupSize}}", toString(workgroupSize));
     replaceAll(data, "{{precision}}", toString(precision));
-    LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
+    LOG(kDefLog, kTrace, "Shader code:\n%s", data.c_str());
   }
 
   /**
@@ -438,7 +438,7 @@ struct KernelCode {
     replaceAll(data, "{{workgroupSize}}", toString(workgroupSize));
     replaceAll(data, "{{precision}}", toString(precision));
     replaceAll(data, "{{totalWorkgroups}}", toString(totalWorkgroups));
-    LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
+    LOG(kDefLog, kTrace, "Shader code:\n%s", data.c_str());
   }
 
   /**
@@ -464,7 +464,7 @@ struct KernelCode {
     replaceAll(data, "{{workgroupSize}}", toString({workgroupSize, 1, 1}));
     replaceAll(data, "{{precision}}", toString(precision));
    replaceAll(data, "{{totalWorkgroups}}", toString(totalWorkgroups));
-    LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
+    LOG(kDefLog, kTrace, "Shader code:\n%s", data.c_str());
  }
 
   std::string data;
@@ -1309,6 +1309,7 @@ createContextAsync(const WGPUInstanceDescriptor &desc = {},
     ctx.device = wait(ctx, deviceFuture);
     ctx.deviceStatus = WGPURequestDeviceStatus_Success;
   } catch (const std::exception &ex) {
+    LOG(kDefLog, kTrace, "requestDeviceAsync: %s", ex.what());
     promise->set_exception(std::make_exception_ptr(ex));
    return promise->get_future();
  }
@@ -1594,7 +1595,7 @@ inline void bufferMapCallback(WGPUMapAsyncStatus status, WGPUStringView message,
 * and a promise to signal completion.
 * @param userdata2 Unused.
 */
-inline void queueWorkDoneCallback(WGPUQueueWorkDoneStatus status,
+inline void queueWorkDoneCallback(WGPUQueueWorkDoneStatus status, WGPUStringView message,
                                   void *userdata1, void * /*userdata2*/) {
   const CallbackData *cbData = static_cast<CallbackData *>(userdata1);
   // Ensure the queue work finished successfully.
@@ -2837,7 +2838,7 @@ Kernel createKernel(Context &ctx, const KernelCode &code,
 * when the work is done.
 * @param userdata2 Unused.
 */
-inline void dispatchKernelCallback(WGPUQueueWorkDoneStatus status,
+inline void dispatchKernelCallback(WGPUQueueWorkDoneStatus status, WGPUStringView message,
                                    void *userdata1, void * /*userdata2*/) {
   // Cast the userdata pointer back to our heap‑allocated promise.
   auto *p = reinterpret_cast<std::promise<void> *>(userdata1);
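
For reference (not part of the diff): the two signature changes above track the newer webgpu.h WGPUQueueWorkDoneCallback shape, which passes a WGPUStringView message alongside the status. A minimal sketch of surfacing that message on failure, in the same %.*s style as the callbacks added in run.cpp; the helper name is hypothetical and not part of gpu.hpp:

    inline void logQueueWorkDone(WGPUQueueWorkDoneStatus status, WGPUStringView message) {
      if (status != WGPUQueueWorkDoneStatus_Success) {
        // WGPUStringView is not NUL-terminated, so print it with an explicit length.
        LOG(kDefLog, kError, "Queue work failed (%d): %.*s", (int)status,
            (int)message.length, message.data);
      }
    }
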
