[WIP] Grouped GEMM with ck_tile #434
base: dev
Conversation
This reverts commit 86fbbac.
    delay_wgrad_compute,
):
    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
    if IS_HIP_EXTENSION:
Is our CK grouped GEMM a drop-in replacement for the NV upstream CUTLASS grouped GEMM? If so, we can share the same env var. It's analogous to cublaslt vs hipblaslt...
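If the CK path really is a drop-in replacement, the dispatch could hinge on the one existing flag. A minimal sketch of that idea, not the PR's actual code (the helper name use_vendor_grouped_gemm is hypothetical):

#include <cstdlib>
#include <cstring>

// One flag selects the vendor grouped-GEMM path on either platform, the way a
// single setting covers cublaslt vs hipblaslt: on CUDA builds it would enable
// the CUTLASS path, on ROCm builds the ck_tile path.
inline bool use_vendor_grouped_gemm() {
  const char *flag = std::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM");
  return flag != nullptr && std::strcmp(flag, "1") == 0;
}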
struct TileCfg_basic {
  // Block-level tile sizes (per work-group).
  static constexpr ck_tile::index_t M_Tile = 256;
  static constexpr ck_tile::index_t N_Tile = 128;
  static constexpr ck_tile::index_t K_Tile = 64;

  // Warp layout within a block.
  static constexpr ck_tile::index_t M_Warp = 2;
  static constexpr ck_tile::index_t N_Warp = 2;
  static constexpr ck_tile::index_t K_Warp = 1;

  // Per-warp tile sizes.
  static constexpr ck_tile::index_t M_Warp_Tile = 32;
  static constexpr ck_tile::index_t N_Warp_Tile = 32;
  static constexpr ck_tile::index_t K_Warp_Tile = 16;

  // Pad M/N/K so problem sizes need not be tile-divisible.
  static constexpr bool kPadM = true;
  static constexpr bool kPadN = true;
  static constexpr bool kPadK = true;

  static constexpr bool DoubleSmemBuffer = false;

  static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
  static constexpr ck_tile::index_t TilePartitionerM01 = 1;
};

// Composes the ck_tile grouped-GEMM kernel type from the tile configuration.
template <typename AType, typename BType, typename CType,
          typename ALayout, typename BLayout, typename CLayout,
          typename TileCfg, ck_tile::memory_operation_enum MemOp,
          typename AccType = float>
class Runner {
 public:
  using GemmShape = ck_tile::TileGemmShape<
      ck_tile::sequence<TileCfg::M_Tile, TileCfg::N_Tile, TileCfg::K_Tile>,
      ck_tile::sequence<TileCfg::M_Warp, TileCfg::N_Warp, TileCfg::K_Warp>,
      ck_tile::sequence<TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile>>;

  using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
      GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>;

  using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits<
      TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK,
      TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>;

  static constexpr ck_tile::GemmPipelineScheduler Scheduler =
      ck_tile::GemmPipelineScheduler::Intrawave;

  using Problem = ck_tile::UniversalGemmPipelineProblem<
      AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>;

  using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;

  using Epilogue = ck_tile::CShuffleEpilogue<
      ck_tile::CShuffleEpilogueProblem<
          AType, BType, ck_tile::tuple<>, AccType,
          CType, ck_tile::tuple<>, CLayout,
          ck_tile::element_wise::PassThrough,
          Partitioner::MPerBlock, Partitioner::NPerBlock,
          TileCfg::M_Warp, TileCfg::N_Warp,
          TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile,
          Problem::TransposeC, MemOp>>;

  using Kernel = ck_tile::GroupedGemmKernel<Partitioner, Pipeline, Epilogue>;
};
Is this code taken from the CK repo? If so, can you add a comment pointing to the reference?
std::vector<ck_tile::GroupedGemmHostArgs<0>> descs;
descs.reserve(group_num);
Why not put group_num in the descs vector's definition instead of calling reserve()?
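For illustration, the two variants this question contrasts; Desc is a stand-in for ck_tile::GroupedGemmHostArgs<0>, which would need to be cheaply default-constructible for the sized-constructor form to make sense:

#include <vector>

struct Desc { int m = 0, n = 0, k = 0; };  // placeholder element type

void build_descs(int group_num) {
  // Variant A (as in the PR): no elements yet, append as each group is built.
  std::vector<Desc> a;
  a.reserve(group_num);
  for (int i = 0; i < group_num; ++i) a.push_back(Desc{i, i, i});

  // Variant B (the suggestion): size known up front, assign by index.
  std::vector<Desc> b(group_num);
  for (int i = 0; i < group_num; ++i) b[i] = Desc{i, i, i};
}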
using R = Runner<T, T, T, ALayout, BLayout, CLayout, TileCfg_basic, MemOp>;
using Kernel = typename R::Kernel;
This R alias is not used anywhere else.
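If R only exists to name the kernel, the two aliases could be collapsed into one; a sketch assuming the same enclosing template context as the original lines:

// Sketch only: fold the intermediate alias away, keeping a single using-declaration.
using Kernel =
    typename Runner<T, T, T, ALayout, BLayout, CLayout, TileCfg_basic, MemOp>::Kernel;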
if (a.shape.size() != 2 || b.shape.size() != 2 || d.shape.size() != 2) {
  NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D.");
  return false;
}
Does grouped GEMM support generalized matrices from higher-dimensional tensors? Regular GEMM supports that, and TE treats the last dim as the column dimension and flattens all other dims into the row dimension:
TransformerEngine/transformer_engine/common/common.h
Lines 238 to 262 in 9d6b0e5
size_t flat_first_dim() const {
  const auto &full_shape = shape();
  size_t ret = 1;
  if (!full_shape.empty()) {
    for (size_t i = 0; i < full_shape.size() - 1; i++) {
      ret *= full_shape[i];
    }
  }
  return ret;
}
/*! Matrix width after tensor is flattened to 2D
 *
 * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
 * as a (D1*D2*...*D(n-1), Dn) matrix.
 */
size_t flat_last_dim() const {
  const auto &full_shape = shape();
  if (full_shape.empty()) {
    return 1;
  } else {
    return full_shape.back();
  }
}
};
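A small sketch of what a flattening-based check could look like, mirroring the flat_first_dim()/flat_last_dim() helpers quoted above; the function below is illustrative, not part of TE or of this PR:

#include <cstddef>
#include <utility>
#include <vector>

// Reinterpret a (D1, D2, ..., Dn) shape as a (D1*...*D(n-1), Dn) matrix,
// so higher-dimensional inputs could pass a 2D-style dimension check.
inline std::pair<std::size_t, std::size_t>
flatten_to_2d(const std::vector<std::size_t> &shape) {
  std::size_t rows = 1;
  for (std::size_t i = 0; i + 1 < shape.size(); ++i) rows *= shape[i];
  const std::size_t cols = shape.empty() ? 1 : shape.back();
  return {rows, cols};
}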
  }
}

bool grouped_gemm_ck_tile(const NVTETensor* A,
Why do we overload this function? In cublaslt_gemm.cu it is only called with this signature. Perhaps we can rename the grouped_gemm_ck_tile on line 255.
      transformer_engine::getenv<bool>("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", false);

  auto is_supported_dtype = [&]() -> bool {
    auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
Is it possible that num_group is 0, making the A[0] access invalid?
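A sketch of the shape of the guard being asked about, not the PR's actual code; the names has_first_group and num_gemms are hypothetical stand-ins for whatever the surrounding function uses:

#include <cstddef>

// Check the group count before dereferencing A[0]; an empty group list can
// then be treated as "nothing to do" rather than an out-of-bounds read.
template <typename TensorPtr>
bool has_first_group(const TensorPtr *A, std::size_t num_gemms) {
  if (num_gemms == 0) {
    return false;  // nothing to inspect; caller can fall back or no-op
  }
  return A[0] != nullptr;  // only now is indexing A[0] valid
}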
set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel)

target_include_directories(transformer_engine
    BEFORE PRIVATE
Why use the BEFORE keyword in this target_include_directories? Is it because CMake would not find the correct header files without prioritizing the CK include dirs?
target_include_directories(transformer_engine PUBLIC
                           "${CMAKE_CURRENT_SOURCE_DIR}/include")

set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel)
CMAKE_SOURCE_DIR --> CMAKE_CURRENT_SOURCE_DIR? Not sure whether other upstream libs will ever depend on us, but let's make it future-proof.
| #include "common/util/cuda_runtime.h" | ||
| #include "common/util/system.h" | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include "cutlass_grouped_gemm.cuh" |
NV upstream put their cutlass_grouped_gemm in a separate .cu file and compiled it separately. Maybe we can follow their structure for better isolation (to avoid macros defined by CK contaminating our cublaslt_gemm.cu).
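A rough sketch of that layout: a tiny header is the only thing cublaslt_gemm.cu includes, while the CK headers (and any macros they define) stay inside a separately compiled translation unit. The file names and signature below are illustrative, not the PR's actual interface:

// ck_grouped_gemm.h (illustrative)
#pragma once
#include <cstddef>

namespace transformer_engine {

// Declaration only; the definition would live in ck_grouped_gemm.cu, the sole
// translation unit that includes ck_tile headers, so CK macros cannot leak
// into cublaslt_gemm.cu.
bool grouped_gemm_ck_tile_impl(const void *const *A, const void *const *B,
                               void *const *D, std::size_t num_gemms);

}  // namespace transformer_engine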
Description
See https://github.com/ROCm/frameworks-internal/issues/13792 for context.
TODOs:
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: