-
Notifications
You must be signed in to change notification settings - Fork 29
CK Tile Group GEMM gfx1250 #576
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gfx1250
Are you sure you want to change the base?
Changes from all commits
d52075d
2934c99
752e0d3
9ea316d
b55fe29
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,7 @@ namespace grouped_gemm { | |
| // Tile configs: FP16/BF16 | ||
| // ------------------------- | ||
|
|
||
| struct TileCfg_256x256x64 { | ||
| struct TileCfg_256x256x64_MFMA { | ||
| static constexpr ck_tile::index_t M_Tile = 256; | ||
| static constexpr ck_tile::index_t N_Tile = 256; | ||
| static constexpr ck_tile::index_t K_Tile = 64; | ||
|
|
@@ -37,14 +37,37 @@ struct TileCfg_256x256x64 { | |
| static constexpr ck_tile::index_t TilePartitionerM01 = 4; | ||
| }; | ||
|
|
||
| struct TileCfg_256x128x64 : TileCfg_256x256x64 { | ||
| struct TileCfg_256x128x64_MFMA : TileCfg_256x256x64_MFMA { | ||
| static constexpr ck_tile::index_t N_Tile = 128; | ||
| }; | ||
|
|
||
| struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { | ||
| struct TileCfg_256x128x64_MFMA_padding : TileCfg_256x128x64_MFMA { | ||
| static constexpr bool kPadN = true; | ||
| }; | ||
|
|
||
| struct TileCfg_256x256x64_WMMA { | ||
| static constexpr ck_tile::index_t M_Tile = 256; | ||
| static constexpr ck_tile::index_t N_Tile = 256; | ||
| static constexpr ck_tile::index_t K_Tile = 64; | ||
|
|
||
| static constexpr ck_tile::index_t M_Warp = 2; | ||
| static constexpr ck_tile::index_t N_Warp = 2; | ||
| static constexpr ck_tile::index_t K_Warp = 1; | ||
|
|
||
| static constexpr ck_tile::index_t M_Warp_Tile = 16; | ||
| static constexpr ck_tile::index_t N_Warp_Tile = 16; | ||
| static constexpr ck_tile::index_t K_Warp_Tile = 32; | ||
|
|
||
| static constexpr bool kPadM = true; | ||
| static constexpr bool kPadN = true; | ||
| static constexpr bool kPadK = true; | ||
|
Comment on lines
+57
to
+63
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So the difference between TileCfg_256x256x64_MFMA and TileCfg_256x256x64_WMMA is in the M, N, K warp tile sizes and the kPad flags?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The difference is not just the warp tile shape or kPads. MFMA and WMMA are different warp-level MMA instruction paths, so they lower through different warp dispatch/pipeline configurations with different tile and padding requirements. |
||
|
|
||
| static constexpr bool DoubleSmemBuffer = false; | ||
|
|
||
| static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; | ||
| static constexpr ck_tile::index_t TilePartitionerM01 = 4; | ||
| }; | ||
|
|
||
| template <typename AType, | ||
| typename BType, | ||
| typename CType, | ||
|
|
@@ -209,10 +232,11 @@ class GroupedGemmRunner : public RunnerInterface { | |
| runner = std::make_unique<Runner>(); \ | ||
| }) | ||
|
|
||
| bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, | ||
| DType b_dtype, | ||
| DType d_dtype, | ||
| const GroupedGemmRunContext& ctx) { | ||
| template <GPUArch Arch> | ||
| bool ck_tile_grouped_gemm_fp16_dispatch_arch(DType a_dtype, | ||
| DType b_dtype, | ||
| DType d_dtype, | ||
| const GroupedGemmRunContext& ctx) { | ||
| const ck_tile::stream_config s{ctx.stream}; | ||
| std::unique_ptr<RunnerInterface> runner = nullptr; | ||
|
|
||
|
|
@@ -229,13 +253,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, | |
|
|
||
| TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { | ||
| using CType = typename TETypeToCKType<d_te_type>::type; | ||
|
|
||
| if (ctx.N % 256 == 0) { | ||
| MAKE_RUNNER(TileCfg_256x256x64); | ||
| } else if (ctx.N % 128 == 0) { | ||
| MAKE_RUNNER(TileCfg_256x128x64); | ||
|
|
||
| if constexpr (Arch == GPUArch::GFX1250) { | ||
| MAKE_RUNNER(TileCfg_256x256x64_WMMA); | ||
| } else { | ||
| MAKE_RUNNER(TileCfg_256x128x64_padding); | ||
| if (ctx.N % 256 == 0) { | ||
| MAKE_RUNNER(TileCfg_256x256x64_MFMA); | ||
| } else if (ctx.N % 128 == 0) { | ||
| MAKE_RUNNER(TileCfg_256x128x64_MFMA); | ||
| } else { | ||
| MAKE_RUNNER(TileCfg_256x128x64_MFMA_padding); | ||
| } | ||
| } | ||
| }); | ||
| }); | ||
|
|
@@ -249,6 +277,30 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, | |
| return runner->run(s, ctx); | ||
| } | ||
|
|
||
| bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, | ||
| DType b_dtype, | ||
| DType d_dtype, | ||
| const GroupedGemmRunContext& ctx) { | ||
| switch (detect_gpu_arch()) { | ||
| #if defined(__gfx942__) | ||
| case GPUArch::GFX942: | ||
| return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX942>(a_dtype, b_dtype, d_dtype, ctx); | ||
| #endif | ||
| #if defined(__gfx950__) | ||
| case GPUArch::GFX950: | ||
| return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX950>(a_dtype, b_dtype, d_dtype, ctx); | ||
| #endif | ||
| #if defined(__gfx1250__) | ||
| case GPUArch::GFX1250: | ||
| return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX1250>(a_dtype, b_dtype, d_dtype, ctx); | ||
| #endif | ||
|
|
||
| default: | ||
| NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}"); | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| #undef MAKE_RUNNER | ||
|
|
||
| } // namespace grouped_gemm | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Will the whole rocm_libraries checkout be too big? Do we have a way to do a sparse checkout of just this ck subdir?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. The full rocm_libraries checkout is fairly large (~8G locally), while projects/composablekernel alone is much smaller (~167M). Yeah, sparse checkout probably makes sense here, but I am wondering if it would be better handled in a separate PR.