4 changes: 4 additions & 0 deletions build-xcframework.sh
mode change 100755 → 100644
Collaborator: Why are you changing this?

Author: It was a mistake; it's reverted now, thanks!

@@ -454,6 +454,8 @@ cmake -B build-visionos -G Xcode \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DLLAMA_CURL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos --config Release -- -quiet

@@ -468,6 +470,8 @@ cmake -B build-visionos-sim -G Xcode \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DLLAMA_CURL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos-sim --config Release -- -quiet

1 change: 1 addition & 0 deletions ggml/CMakeLists.txt
@@ -210,6 +210,7 @@ option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental,
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
option(GGML_HIP_MMQ_WMMA "ggml: enable WMMA MMA for RDNA4 in MMQ" ON)
Collaborator: For now this is fine, but long-term, after the kernels have been fully optimized and tested per datatype, it would be preferable to re-use the FORCE_CUBLAS and FORCE_MMQ options.

Collaborator (@IMbackK, Nov 12, 2025): This does not replace those, but makes it use the dp4a MMQ kernels instead. I added this for CDNA/MFMA because it allows testing for GCN performance regressions on CDNA. Similarly, this allows testing for RDNA1/2 performance regressions on RDNA4.

I would prefer this to be kept.

EDIT: I guess testing for RDNA1/2 performance on RDNA4 is less useful than testing for GCN performance on CDNA, as RDNA4 has more VGPRs and some new VALU instructions compared to RDNA1/2, unlike CDNA/GCN, which differ less.
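
For context, a minimal sketch of how such an opt-out could be wired at compile time. It assumes that GGML_HIP_MMQ_WMMA=OFF maps to a GGML_HIP_NO_MMQ_WMMA compile definition, analogous to the existing MFMA opt-out; that define is an assumption for illustration, not code from this PR:

// Hypothetical gating, assuming the CMake option GGML_HIP_MMQ_WMMA=OFF
// adds a GGML_HIP_NO_MMQ_WMMA compile definition (not shown in this diff):
#if defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_MMQ_WMMA)
#define AMD_WMMA_AVAILABLE // when gated off, MMQ falls back to the dp4a kernels
#endif // defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_MMQ_WMMA)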

Collaborator:

> I added this for cdna/mfma because it allows testing for GCN performance regressions on CDNA.

My experience so far has been that the portability of performance across GPUs is so poor that something like this is of little utility. In the rare cases where emulating old hardware is needed, one should just edit the code temporarily. If options like this are exposed to users, they are going to use them, and that increases the amount of work that needs to be put into maintenance. So long-term I still intend to remove those options. My current AMD lineup consists of RDNA2, RDNA3.5, RDNA4, GCN5.1, and CDNA1, and in the next months I intend to add RDNA3 and CDNA2. I would just test the performance using those GPUs directly.

Collaborator (@IMbackK, Nov 12, 2025): Not everyone has a huge selection of hardware to choose from. Across GCN5.1/gfx906 and CDNA, in my experience the performance portability is extremely close. This is no surprise, as the changes made in CDNA that are relevant to ggml are very slight:

  1. MFMA was added, with a special 256-wide register file usable only by these instructions and by loads and stores.
  2. An instruction was added to load from global memory directly into LDS, but the compiler doesn't generate it.

The only practical difference in the generated assembly is that under register pressure the compiler will spill to the MFMA register space instead of scratch memory, which very slightly reduces the cost of spills.
The CUs themselves are also extremely similar, and cache, local memory, and global memory latency are essentially unchanged.

The picture changes only slightly with CDNA2, where the physical (but not logical) register space between the VALU and MFMA instructions is now shared, meaning the minimum occupancy for a VALU kernel allocating all 256 registers is 2, and packed 32-bit instructions were added; but again, in my experience the performance on CDNA2 predicts the performance on GCN extremely closely.

I don't have much experience with RDNA, and it's true that the changes between RDNA generations are larger.
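
A quick back-of-the-envelope for the occupancy figure above, as a sketch based only on the numbers stated in the comment (two 256-entry files merged into one 512-entry physical file per SIMD):

// Hypothetical illustration of the CDNA2 occupancy arithmetic described above.
#include <cstdio>

int main() {
    const int physical_vgprs = 512; // assumed: merged VALU+MFMA register file
    const int kernel_vgprs   = 256; // a VALU kernel allocating the full logical file
    printf("min waves per SIMD: %d\n", physical_vgprs / kernel_vgprs); // -> 2
    return 0;
}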

Collaborator: In any case, to my knowledge we don't have anyone who would be using the GGML_HIP_MMQ_WMMA option with the intent you laid out, so it should be removed. I fundamentally don't want to add extra compilation options unless there is a good reason for them, because each one is an extra variable that potentially needs to be accounted for in bug reports.

option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF)
8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -228,6 +228,9 @@ static const char * cu_get_error_str(CUresult err) {
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#define VOLTA_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if defined(GGML_USE_HIP) && defined(RDNA4)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4)

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define TURING_MMA_AVAILABLE
@@ -287,6 +290,11 @@ static bool volta_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
}

static bool amd_wmma_available(const int cc) {
return GGML_CUDA_CC_IS_RDNA4(cc);
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool turing_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}
96 changes: 95 additions & 1 deletion ggml/src/ggml-cuda/mma.cuh
@@ -73,7 +73,7 @@ namespace ggml_cuda_mma {
static constexpr int I = I_;
static constexpr int J = J_;

- #if defined(GGML_USE_HIP)
+ #if defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = I * J / 64;
T x[ne] = {0};

@@ -149,6 +149,28 @@ namespace ggml_cuda_mma {
return -1;
}
}
#elif defined(AMD_WMMA_AVAILABLE) // adjusted the mapping for RDNA4

static constexpr int ne = I * J / 32;
T x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#else
static constexpr int ne = I * J / 32;
T x[ne] = {0};
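
To make the RDNA4 mapping above concrete, here is a small host-side sketch that replays get_i/get_j for the 16x16 tile (illustrative only; lane stands in for threadIdx.x within a 32-wide wave):

#include <cstdio>

// Each of the 32 lanes holds ne = 16*16/32 = 8 values: lanes 0..15 cover
// rows 0..7 of their column, lanes 16..31 cover rows 8..15.
int main() {
    for (int lane = 0; lane < 32; ++lane) {    // threadIdx.x in a wave32
        for (int l = 0; l < 8; ++l) {          // l in [0, ne)
            const int i = 8 * (lane / 16) + l; // get_i(l)
            const int j = lane % 16;           // get_j(l)
            printf("lane %2d, l %d -> (i=%2d, j=%2d)\n", lane, l, i, j);
        }
    }
    return 0;
}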
@@ -353,6 +375,20 @@ namespace ggml_cuda_mma {
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}

#elif defined(AMD_WMMA_AVAILABLE)
if constexpr (I == 16 && J == 4) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
} else {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];

const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
}
#else
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
@@ -665,6 +701,36 @@ namespace ggml_cuda_mma {
acc[0],
0, 0, 0);
#endif // defined(CDNA3)

#elif defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;

using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;

#if defined(RDNA4)

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
true
);

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[1],
true,
b_vec[1],
acc[0],
true
);
#endif // defined(RDNA4)

Collaborator (on lines +704 to +733): This is to my understanding currently unused, so please remove it.

Author: Hi, I believe this is used in the vec_dot_q8_0_q8_1_mma function, which is called for Q4_0, Q5_0, Q8_0, MXFP4, etc.
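
For reference, the net arithmetic of the two chained __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12 calls above can be written in scalar form roughly as follows. This is a sketch assuming a 16x32 A tile, a 32x16 B tile, and a 16x16 int32 accumulator; the signedness and clamp flags of the builtin are elided:

#include <cstdint>

// Scalar sketch of the chained 16x16x16 WMMA calls: each call accumulates one
// K=16 slice into D, so the pair forms a 16x16x32 int8 MMA.
static void wmma_i8_16x16x32_reference(int32_t D[16][16],
                                       const int8_t A[16][32],
                                       const int8_t B[32][16]) {
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            for (int k = 0; k < 32; ++k) { // k < 16: first call; k >= 16: second
                D[i][j] += int32_t(A[i][k]) * int32_t(B[k][j]);
            }
        }
    }
}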

#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@@ -691,6 +757,7 @@ namespace ggml_cuda_mma {
acc[0],
0, 0, 0);
#endif // defined(CDNA3)

#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@@ -735,4 +802,31 @@ namespace ggml_cuda_mma {
mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
}

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
#if defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;

using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
false
);
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // defined(AMD_WMMA_AVAILABLE)
}
}

6 changes: 3 additions & 3 deletions ggml/src/ggml-cuda/mmq.cu
@@ -290,11 +290,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

- if (amd_mfma_available(cc)) {
+ if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
// As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
// performs better but is currently suffering from a crash on this architecture.
// TODO: Revisit when hipblaslt is fixed on CDNA3
- if (GGML_CUDA_CC_IS_CDNA3(cc)) {
+ if (GGML_CUDA_CC_IS_CDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
Collaborator: Could we have test-backend-ops perf -o MUL_MAT results for this PR and for master, to better see whether always enabling this is the way to go?

return true;
}
if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
@@ -306,5 +306,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}

- return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+ return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}