HIP: WMMA-MMQ kernels for RDNA 4 #17156
base: master

Changes from all commits:
6ee3f8b
59d0c47
f91615c
21db114
d9249de
c62d844
d545fe9
ggml/CMakeLists.txt
```diff
@@ -210,6 +210,7 @@ option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental,
 option(GGML_HIP_NO_VMM         "ggml: do not try to use HIP VMM" ON)
 option(GGML_HIP_ROCWMMA_FATTN  "ggml: enable rocWMMA for FlashAttention" OFF)
 option(GGML_HIP_MMQ_MFMA       "ggml: enable MFMA MMA for CDNA in MMQ" ON)
+option(GGML_HIP_MMQ_WMMA       "ggml: enable WMMA MMA for RDNA4 in MMQ" ON)
```
Collaborator

For now this is fine but long-term, after the kernels have been fully optimized and tested per datatype, it would be preferable to re-use the existing options.
Collaborator

This does not replace those, but makes the code use the dp4a MMQ kernels instead. I added this for CDNA/MFMA because it allows testing for GCN performance regressions on CDNA. Similarly, this allows testing for RDNA1/2 performance regressions on RDNA4. I would prefer this to be kept.

EDIT: I guess testing for RDNA1/2 performance on RDNA4 is less useful than testing for GCN performance on CDNA, as RDNA4 has more VGPRs and some new VALU instructions compared to RDNA1/2, unlike CDNA/GCN, which have fewer differences.
Collaborator

My experience so far has been that the portability of performance across GPUs is so poor that something like this is of little utility. In the rare cases where emulating old hardware is needed, one should just edit the code temporarily. If options like this are exposed to users they are going to use them, and that increases the amount of work that needs to be put into maintenance. So long-term I still intend to remove those options. My current AMD lineup consists of RDNA2, RDNA3.5, RDNA4, GCN5.1, and CDNA1, and in the next months I intend to add RDNA3 and CDNA2. I would just test the performance using those GPUs directly.
Collaborator

Not everyone has a huge selection of hardware to choose from. Across GCN5.1/gfx906 and CDNA, in my experience the performance portability is extremely close. This is no surprise, as the changes made to CDNA that are relevant to ggml are very slight: the only practical difference in the generated assembly is that under register pressure the compiler will spill to the MFMA register space instead of scratch memory, which very slightly reduces the cost of spills. The picture changes only slightly with CDNA2, where the physical (but not logical) register space between the VALU and MFMA instructions is now shared, meaning the minimum occupancy for a VALU kernel allocating all 256 registers is 2, and packed 32-bit instructions were added. But again, in my experience the performance on CDNA2 predicts the performance on GCN extremely closely. I don't have much experience with RDNA, and it's true that the changes between RDNA generations are larger.
Collaborator

In any case, to my knowledge we don't have anyone who would be using the option.
```diff
 option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
 option(GGML_MUSA_GRAPHS        "ggml: use MUSA graph, experimental, unstable" OFF)
 option(GGML_MUSA_MUDNN_COPY    "ggml: enable muDNN for accelerated copy" OFF)
```
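If the new option behaves like its MFMA counterpart (my reading of the diff; the PR itself does not document this), RDNA4 users can opt out of the WMMA kernels at configure time with something like `cmake -B build -DGGML_HIP=ON -DGGML_HIP_MMQ_WMMA=OFF`, which would fall back to the dp4a MMQ path discussed in the thread above.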
ggml/src/ggml-cuda/mma.cuh
```diff
@@ -73,7 +73,7 @@ namespace ggml_cuda_mma {
         static constexpr int I = I_;
         static constexpr int J = J_;

-#if defined(GGML_USE_HIP)
+#if defined(AMD_MFMA_AVAILABLE)
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};
```
```diff
@@ -149,6 +149,28 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
+#elif defined(AMD_WMMA_AVAILABLE) // adjusted the mapping for RDNA 4
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return 8 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
 #else
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
```
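To make the new mapping concrete, here is a small host-side sketch of my own (not code from the PR) that replays `get_i`/`get_j` for a wave32: each of the 32 lanes owns `ne = 16 * 16 / 32 = 8` elements of the 16x16 tile, with lanes 0-15 holding rows 0-7 of column `lane % 16` and lanes 16-31 holding rows 8-15.

```cpp
#include <cstdio>

// Recomputes the RDNA4 WMMA mapping above on the host: i = 8*(lane/16) + l,
// j = lane % 16. With 32 lanes x 8 elements = 256 mappings onto a 16x16 tile,
// full coverage implies every element is owned by exactly one (lane, l) pair.
int main() {
    bool covered[16][16] = {};
    for (int lane = 0; lane < 32; ++lane) {
        for (int l = 0; l < 8; ++l) {
            covered[8 * (lane / 16) + l][lane % 16] = true; // get_i(l), get_j(l)
        }
    }
    int n = 0;
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            n += covered[i][j];
        }
    }
    printf("covered %d of 256 elements\n", n); // prints 256
    return 0;
}
```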
```diff
@@ -353,6 +375,20 @@ namespace ggml_cuda_mma {
             const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
             xi[0] = xs[0];
         }
+#elif defined(AMD_WMMA_AVAILABLE)
+        if constexpr (I == 16 && J == 4) {
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+        } else {
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+
+            const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
+            xi[1] = xs1[0];
+        }
 #else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
```
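As a sanity check on the addressing (again my own sketch, not part of the diff): a `tile<16, 4, int>` holds `ne = 16 * 4 / 32 = 2` ints per lane, so the first branch fetches them with a single 8-byte load, while the wider tiles hold 4 ints per lane and need the second `xs1` load. For the 16x4 case the offsets work out as follows:

```cpp
#include <cstdio>

// For tile<16, 4, int>: lane n loads the two consecutive ints at
// (n % 16) * stride + 2 * (n / 16), i.e. lanes 0-15 fetch the first
// int64 chunk of each row and lanes 16-31 the second. stride is in ints.
int main() {
    const int stride = 4; // assume a densely packed 16x4 tile
    for (int lane = 0; lane < 32; ++lane) {
        const int off = (lane % 16) * stride + 2 * (lane / 16);
        printf("lane %2d -> ints [%2d, %2d]\n", lane, off, off + 1);
    }
    return 0;
}
```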
```diff
@@ -665,6 +701,36 @@ namespace ggml_cuda_mma {
                                               acc[0],
                                               0, 0, 0);
 #endif // defined(CDNA3)
+#elif defined(AMD_WMMA_AVAILABLE)
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
+
+        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+        int32x8_t * acc = (int32x8_t *) D.x;
+
+#if defined(RDNA4)
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            true
+        );
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[1],
+            true,
+            b_vec[1],
+            acc[0],
+            true
+        );
+#endif // defined(RDNA4)
```
Comment on lines +704 to +733

Collaborator

This is to my understanding currently unused, so please remove it.

Author

Hi, I believe this is used in the vec_dot_q8_0_q8_1_mma function, which is called for Q4_0, Q5_0, Q8_0, MXFP4, etc.
```diff
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
```
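For readers unfamiliar with the gfx12 builtin: as I understand the underlying LLVM intrinsic `llvm.amdgcn.wmma.i32.16x16x16.iu8`, the first and third boolean arguments mark A and B as signed, and the trailing boolean requests saturating accumulation. Each call performs one 16x16x16 int8 step, so the two chained calls above realize the tile's K = 32 reduction in two halves. A scalar sketch of what a single call computes (my reading of the intrinsic name, not code from the PR; saturation omitted):

```cpp
#include <cstdint>

// D = A * B + C for one 16x16x16 step with signed 8-bit inputs and 32-bit
// accumulation; on RDNA4 a single v_wmma_i32_16x16x16_iu8 issue computes
// this across the wave (the intrinsic accumulates in place, so C and D are
// the same registers; per-lane operand layouts aside).
static void wmma_i32_16x16x16_iu8_ref(int32_t D[16][16],
                                      const int8_t A[16][16],
                                      const int8_t B[16][16],
                                      const int32_t C[16][16]) {
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            int32_t acc = C[i][j];
            for (int k = 0; k < 16; ++k) {
                acc += (int32_t) A[i][k] * (int32_t) B[k][j];
            }
            D[i][j] = acc;
        }
    }
}
```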
```diff
@@ -691,6 +757,7 @@ namespace ggml_cuda_mma {
                                               acc[0],
                                               0, 0, 0);
 #endif // defined(CDNA3)
+
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
```
```diff
@@ -735,4 +802,31 @@ namespace ggml_cuda_mma {
         mma(D16[1], A16[1], B);
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
+#if defined(AMD_WMMA_AVAILABLE)
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
+
+        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+        int32x8_t * acc = (int32x8_t *) D.x;
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            false
+        );
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif
+    }
 }
```
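Compared with the K = 32 path earlier in the file, this overload consumes a single K = 16 slice per operand (a `tile<16, 4, int>` packs 16x16 int8 values), so one WMMA issue suffices; note it also passes `false` for the trailing clamp flag, unlike the two chained calls above. Per the author's reply, it backs `vec_dot_q8_0_q8_1_mma`. A hypothetical call site (names illustrative only, not from the PR):

```cpp
// Hypothetical device-side usage: one K = 16 int8 MMA step accumulating
// into a 16x16 int tile, mapping to one v_wmma_i32_16x16x16_iu8 on RDNA4.
__device__ void q8_dot_step(ggml_cuda_mma::tile<16, 16, int> & acc,
                            const ggml_cuda_mma::tile<16, 4, int> & a,
                            const ggml_cuda_mma::tile<16, 4, int> & b) {
    ggml_cuda_mma::mma(acc, a, b);
}
```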
ggml/src/ggml-cuda/mmq.cu
```diff
@@ -290,11 +290,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }

-    if (amd_mfma_available(cc)) {
+    if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
         // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
         // performs better but is currently suffering from a crash on this architecture.
         // TODO: Revisit when hipblaslt is fixed on CDNA3
-        if (GGML_CUDA_CC_IS_CDNA3(cc)) {
+        if (GGML_CUDA_CC_IS_CDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
```
Collaborator

could we have
```diff
             return true;
         }
         if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
```
```diff
@@ -306,5 +306,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }

-    return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
```
Collaborator

Why are you changing this?
Author

It was a mistake; it is reverted now. Thanks!
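Taken together, my condensed reading of the `ggml_cuda_should_use_mmq` change for RDNA4 (a stubbed sketch under stated assumptions, not code from the PR):

```cpp
#include <cstdint>

// With the WMMA kernels compiled in, RDNA4 now behaves like CDNA3 and always
// prefers MMQ over rocBLAS/hipBLASLt, regardless of batch size or quantization
// type. Stubs stand in for the real ggml helpers.
static bool amd_wmma_available_stub(int /*cc*/) { return true; } // RDNA4 + GGML_HIP_MMQ_WMMA=ON
static bool is_rdna4_stub(int /*cc*/)           { return true; }

static bool should_use_mmq_rdna4(int cc, int64_t ne11) {
    if (amd_wmma_available_stub(cc) && is_rdna4_stub(cc)) {
        return true; // unconditional, mirroring the existing CDNA3 special case
    }
    (void) ne11; // other architectures keep the ne11 / quant-type heuristics
    return false;
}
```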