Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@
[submodule "3rdparty/QoLA"]
path = 3rdparty/QoLA
url = https://github.com/Micky774/QoLA.git
[submodule "rocm_libraries"]
path = 3rdparty/rocm_libraries
url = https://github.com/ROCm/rocm-libraries.git
branch = users/jia/ck/fix_grouped_gemm_quant_mxtype
1 change: 1 addition & 0 deletions 3rdparty/rocm_libraries
Submodule rocm_libraries added at 66b1d1
2 changes: 1 addition & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ set_property(
PROPERTY
COMPILE_OPTIONS "-g0;-dopt=on")
else()
set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel)
set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/rocm_libraries/projects/composablekernel)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Will the whole rocm_libraries too big? Do we have a way to have sparse check out for this ck subdir?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. The full rocm_libraries checkout is fairly large (~8G locally), while projects/composablekernel alone is much smaller (~167M). Yeah, sparse checkout probably makes sense here, but I am wondering if it would be better handled in a separate PR.

target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include)
endif() #USE_CUDA

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,19 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
// FP8 special handling.
//
// A_use/B_use and transA_use/transB_use have already gone through the
// upstream-style grouped GEMM normalization above. This block only rewrites
// that normalized presentation into the CK FP8 preferred NT presentation by selecting
// `columnwise_data` when needed.
// upstream-style grouped GEMM normalization above. CK FP8 grouped GEMM is
// compiled only for the preferred NT presentation:
//
// CK FP8 target presentation:
// A_use: N
// B_use: T
// transA_use = false
// transB_use = true
//
// The outer condition checks whether this NT presentation is possible:
// - A_use is already N, or can be made N using columnwise_data
// - B_use is already T, or can be made T using columnwise_data
// This block rewrites the normalized presentation into that NT form by
// selecting columnwise_data when needed. If the required columnwise_data view
// is unavailable, this CK FP8 backend cannot represent the GEMM in its
// supported layout form, so we fall back instead of compiling/running an
// unsupported layout variant.
//
// Then each operand is rewritten independently only if needed:
// Rewrite cases:
// NN -> rewrite B only
// TN -> rewrite A and B
// NT -> already in target form
Expand All @@ -81,16 +81,23 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
const bool has_a_col = A0_te->has_columnwise_data();
const bool has_b_col = B0_te->has_columnwise_data();

if ((!transA_use || has_a_col) && (transB_use || has_b_col)) {
if (transA_use) {
use_a_colwise_data = true;
transA_use = false;
}
const bool can_make_a_nt = !transA_use || has_a_col;
const bool can_make_b_nt = transB_use || has_b_col;

if (!transB_use) {
use_b_colwise_data = true;
transB_use = true;
}
if (!can_make_a_nt || !can_make_b_nt) {
NVTE_WARN("ck_tile_grouped_gemm: FP8 grouped GEMM requires NT presentation. "
"Missing required columnwise_data for layout rewrite; falling back.");
return false;
}

if (transA_use) {
use_a_colwise_data = true;
transA_use = false;
}

if (!transB_use) {
use_b_colwise_data = true;
transB_use = true;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#pragma once

#include <hip/hip_runtime.h>
#include "common/util/cuda_runtime.h"

#include <array>
#include <type_traits>
Expand Down Expand Up @@ -70,6 +71,28 @@ static inline const transformer_engine::SimpleTensor& scale_inv_view(const trans
return t.scale_inv;
}

enum class GPUArch {
GFX942,
GFX950,
GFX1250,
UNKNOWN
};

static inline GPUArch detect_gpu_arch() {
int arch = cuda::sm_arch(0);

if (arch == 94) {
return GPUArch::GFX942;
}
if (arch == 95) {
return GPUArch::GFX950;
}
if (arch == 1250) {
return GPUArch::GFX1250;
}
return GPUArch::UNKNOWN;
}

struct GroupedGemmRunContext {
const NVTETensor* A = nullptr;
const NVTETensor* B = nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace grouped_gemm {
// Tile configs: FP16/BF16
// -------------------------

struct TileCfg_256x256x64 {
struct TileCfg_256x256x64_MFMA {
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
Expand All @@ -37,14 +37,37 @@ struct TileCfg_256x256x64 {
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
};

struct TileCfg_256x128x64 : TileCfg_256x256x64 {
struct TileCfg_256x128x64_MFMA : TileCfg_256x256x64_MFMA {
static constexpr ck_tile::index_t N_Tile = 128;
};

struct TileCfg_256x128x64_padding : TileCfg_256x128x64 {
struct TileCfg_256x128x64_MFMA_padding : TileCfg_256x128x64_MFMA {
static constexpr bool kPadN = true;
};

struct TileCfg_256x256x64_WMMA {
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;

static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;

static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
Comment on lines +57 to +63
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the difference btw TileCfg_256x256x64_MFMA and TileCfg_256x256x64_WMMA is inside M, N, K warp tile and kPads?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The difference is not just the warp tile shape or kPads. MFMA and WMMA are different warp-level MMA instruction paths, so they lower through different warp dispatch/pipeline configurations with different tile and padding requirements.


static constexpr bool DoubleSmemBuffer = false;

static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
};

template <typename AType,
typename BType,
typename CType,
Expand Down Expand Up @@ -209,10 +232,11 @@ class GroupedGemmRunner : public RunnerInterface {
runner = std::make_unique<Runner>(); \
})

bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
DType b_dtype,
DType d_dtype,
const GroupedGemmRunContext& ctx) {
template <GPUArch Arch>
bool ck_tile_grouped_gemm_fp16_dispatch_arch(DType a_dtype,
DType b_dtype,
DType d_dtype,
const GroupedGemmRunContext& ctx) {
const ck_tile::stream_config s{ctx.stream};
std::unique_ptr<RunnerInterface> runner = nullptr;

Expand All @@ -229,13 +253,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;

if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64);

if constexpr (Arch == GPUArch::GFX1250) {
MAKE_RUNNER(TileCfg_256x256x64_WMMA);
} else {
MAKE_RUNNER(TileCfg_256x128x64_padding);
if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64_MFMA);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64_MFMA);
} else {
MAKE_RUNNER(TileCfg_256x128x64_MFMA_padding);
}
}
});
});
Expand All @@ -249,6 +277,30 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
return runner->run(s, ctx);
}

bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
DType b_dtype,
DType d_dtype,
const GroupedGemmRunContext& ctx) {
switch (detect_gpu_arch()) {
#if defined(__gfx942__)
case GPUArch::GFX942:
return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX942>(a_dtype, b_dtype, d_dtype, ctx);
#endif
#if defined(__gfx950__)
case GPUArch::GFX950:
return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX950>(a_dtype, b_dtype, d_dtype, ctx);
#endif
#if defined(__gfx1250__)
case GPUArch::GFX1250:
return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX1250>(a_dtype, b_dtype, d_dtype, ctx);
#endif

default:
NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}");
return false;
}
}

#undef MAKE_RUNNER

} // namespace grouped_gemm
Expand Down
Loading
Loading