
Conversation

@jiachengjason

Enabled WMMA-MMQ kernels for RDNA 4 architecture on AMD GPUs

Following a similar approach to #14624.

./build/bin/llama-bench was used to collect the performance results below.

Performance results are against ggml/llama.cpp master up to and including commit 5b180c3.

Build command for the following performance results:
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" cmake -S . -B build -DGGML_HIP=ON -DGGML_CUDA_FORCE_MMQ=OFF -DGGML_HIP_UMA=OFF -DGGML_HIP_ROCWMMA_FATTN=OFF -DGPU_TARGETS="gfx1201" -DGGML_HIP_GRAPHS=OFF -DLLAMA_CURL=OFF -DGGML_CUDA_FORCE_CUBLAS=OFF -DCMAKE_BUILD_TYPE=Release && cmake --build build --config Release -- -j 32

[image: llama-bench performance results table]

Build command for the following performance results:
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" cmake -S . -B build -DGGML_HIP=ON -DGGML_HIP_UMA=OFF -DGGML_HIP_ROCWMMA_FATTN=ON -DGPU_TARGETS=gfx1201 -DGGML_HIP_GRAPHS=OFF -DLLAMA_CURL=OFF -DGGML_CUDA_FORCE_CUBLAS=OFF -DCMAKE_BUILD_TYPE=Release && cmake --build build --config Release -- -j 32

[image: llama-bench performance results table]

github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Nov 11, 2025
@JohannesGaessler (Collaborator)

Can you give me a quick summary of what you would consider to still be missing from this PR for it to be ready for review?

jiachengjason marked this pull request as ready for review on November 11, 2025, 18:15
@jiachengjason (Author)

> Can you give me a quick summary of what you would consider to still be missing from this PR for it to be ready for review?

Hi, I have opened it up for review now, thanks!

Contributor

I think GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 is not used anymore

Collaborator

It's not.

@IMbackK (Collaborator) left a comment

@jiachengjason As mentioned by @slojosic-amd, there is an accidental change in the CMake file.
You are also changing the permissions of build-xcframework.sh by accident with this PR. Please revert these changes.

Since you improved MMQ performance on RDNA4, you should also change

return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
to increase the use of these kernels. As the performance difference between this and BLAS seems quite uneven, I would recommend doing so on a per-datatype and per-shape basis, as is done for MFMA. For CDNA I just selected whatever is faster, but for RDNA4 I would err on the side of using MMQ unless the performance difference is large, since consumer RDNA4 devices have only 16 GiB of VRAM, making the size of the dequantization buffers more relevant.
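
For illustration, a rough sketch of what a per-datatype cutoff for RDNA4 could look like (the helper name, the type list, and the cutoffs are placeholders, not measured results; the actual split has to come from benchmarks):

// hypothetical helper, modeled on the condition quoted above
static bool ggml_cuda_should_use_mmq_rdna4(const ggml_type type, const int64_t ne11) {
    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q4_K:
            return true;                            // assumed: MMQ wins for these regardless of shape
        default:
            return ne11 < MMQ_DP4A_MAX_BATCH_SIZE;  // otherwise keep the existing batch-size cutoff
    }
}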

@IMbackK (Collaborator) commented Nov 12, 2025

Oh, and using scripts/compare-llama-bench.py with llama-bench [...] -o sql | sqlite3 llama-bench.sqlite, as well as test-backend-ops perf, provides a more readable way to compare performance changes, which is useful for finding the right shapes to enable.

IMbackK self-assigned this on Nov 12, 2025
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
option(GGML_HIP_MMQ_WMMA "ggml: enable WMMA MMA for RDNA4 in MMQ" ON)
Collaborator

For now this is fine, but long-term, after the kernels have been fully optimized and tested per datatype, it would be preferable to re-use the FORCE_CUBLAS and FORCE_MMQ options.

@IMbackK (Collaborator), Nov 12, 2025

This does not replace those; it makes it use the dp4a MMQ kernels instead. I added this for CDNA/MFMA because it allows testing for GCN performance regressions on CDNA. Similarly, this allows testing for RDNA1/2 performance regressions on RDNA4.

I would prefer this to be kept.

EDIT: I guess testing for RDNA1/2 performance on RDNA4 is less useful than testing for GCN performance on CDNA, as RDNA4 has more VGPRs and some new VALU instructions compared to RDNA1/2, unlike CDNA/GCN, which have fewer differences.

Collaborator

> I added this for CDNA/MFMA because it allows testing for GCN performance regressions on CDNA.

My experience so far has been that the portability of performance across GPUs is so poor that something like this is of little utility. In the rare cases where emulating old hardware is needed, one should just edit the code temporarily. If options like this are exposed to users they are going to use them, and that increases the amount of work that needs to be put into maintenance. So long-term I still intend to remove those options. My current AMD lineup consists of RDNA2, RDNA3.5, RDNA4, GCN5.1, and CDNA1, and in the coming months I intend to add RDNA3 and CDNA2. I would just test the performance using those GPUs directly.

@IMbackK (Collaborator), Nov 12, 2025

Not everyone has a huge selection of hardware to choose from. Across GCN5.1/gfx906 and CDNA, in my experience the performance portability is extremely close. This is no surprise, as the changes made in CDNA that are relevant to ggml are very slight:

  1. MFMA was added, with a special 256-wide register file usable only by these instructions and by loads and stores.
  2. An instruction was added to load from global memory directly into LDS, but the compiler doesn't generate it.

The only practical difference in the generated assembly is that under register pressure the compiler will spill to the MFMA register space instead of scratch memory, which very slightly reduces the cost of spills.
The CUs themselves are also extremely similar, and cache, local memory, and global memory latency are essentially unchanged.

The picture changes only slightly with CDNA2, where the physical (but not logical) register space between the VALU and MFMA instructions is now shared, meaning the minimum occupancy for a VALU kernel allocating all 256 registers is 2, and packed 32-bit instructions were added. But again, in my experience the performance on CDNA2 predicts the performance on GCN extremely closely.

I don't have much experience with RDNA, and it's true that the changes between RDNA generations are larger.

Collaborator

In any case, to my knowledge we don't have anyone who would be using the GGML_HIP_MMQ_WMMA option with the intent you laid out, so it should be removed. I fundamentally don't want to add extra compilation options unless there is a good reason for them, because that is just one extra variable that potentially needs to be accounted for with bug reports.

Collaborator

Why are you changing this?

Author

It was a mistake; it is reverted now, thanks!

Comment on lines 231 to 233
#if defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_MMQ_WMMA)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_MMQ_WMMA)
Collaborator

This is going to conflict with #17077. Instead of making the availability of AMD WMMA contingent on GGML_HIP_NO_MMQ_WMMA, check that macro in mmq.cuh to decide whether or not to use AMD WMMA.
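
Roughly, the split could look like this (a sketch; where exactly in mmq.cuh the check goes is up to the author):

// common.cuh: keep the availability macro purely architectural
#if defined(GGML_USE_HIP) && defined(RDNA4)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4)

// mmq.cuh: apply the compile-time opt-out where the MMQ path is chosen
#if defined(AMD_WMMA_AVAILABLE) && !defined(GGML_HIP_NO_MMQ_WMMA)
//     ... use the WMMA MMQ kernels ...
#else
//     ... fall back to the dp4a MMQ kernels ...
#endif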

}

if (amd_mfma_available(cc)) {
if (amd_mfma_available(cc)||amd_wmma_available(cc)) {
Collaborator

Make a separate branch for AMD WMMA instead. Since as of right now it's an explicit opt-in via a compilation option, simply return true if and only if the cc is RDNA4 and the compilation option has been enabled.
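
That is, something along these lines (a sketch; the macro and helper names follow the ones used elsewhere in this PR):

static bool amd_wmma_available(const int cc) {
#if !defined(GGML_HIP_NO_MMQ_WMMA)
    return GGML_CUDA_CC_IS_RDNA4(cc);
#else
    GGML_UNUSED(cc);
    return false;
#endif // !defined(GGML_HIP_NO_MMQ_WMMA)
}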

@IMbackK (Collaborator), Nov 12, 2025

@JohannesGaessler I think you misunderstood what the compilation option is supposed to achieve.

@IMbackK (Collaborator), Nov 12, 2025

@jiachengjason Formatting: this should be if (amd_mfma_available(cc) || amd_wmma_available(cc)) (with spaces around ||).

static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE)
Collaborator

Suggested change: keep only a single #if defined(AMD_MFMA_AVAILABLE) line (remove the duplicate).

Comment on lines 157 to 182
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 4) {
return threadIdx.x % 32;
} else if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 32) {
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 4) {
return 2 * (threadIdx.x / 32) + l;
} else if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 32) {
return threadIdx.x % 32;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
Collaborator

Please define only the actually used shapes of 16x4 and 16x16, and use NO_DEVICE_CODE instead of a static_assert, as was recently changed in the surrounding code.
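
For example, roughly (the 16x16 mapping is copied from the diff above; the 16x4 mapping is left as a placeholder since it is not shown here):

static __device__ __forceinline__ int get_j(const int l) {
    if constexpr (I == 16 && J == 16) {
        return threadIdx.x % 16;  // same mapping as the existing 16x16 branch
    } else if constexpr (I == 16 && J == 4) {
        return 0;                 // placeholder: fill in the actual 16x4 mapping
    } else {
        NO_DEVICE_CODE;           // unsupported shape, following the surrounding code's convention
        return 0;
    }
}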

Comment on lines +714 to +743

#elif defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;

using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;

#if defined(RDNA4)
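// the two calls below each consume one half of the A/B fragments and accumulate into the same 16x16 int32 tile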

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
true
);

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[1],
true,
b_vec[1],
acc[0],
true
);
#endif // defined(RDNA4)

Collaborator

This is to my understanding currently unused, so please remove it.

Author

Hi, I believe this is used in the vec_dot_q8_0_q8_1_mma function, which is called for Q4_0, Q5_0, Q8_0, MXFP4, etc.
