
Conversation

@matthiasdiener (Contributor) commented Jan 28, 2026

Description

See https://github.com/ROCm/frameworks-internal/issues/13792 for context.

TODOs:

  • Enable tests in test_numerics.py
  • Make kernels selectable & tunable
  • Handle gelu/bias (or make sure these are not passed in)
  • Performance analysis and improvements
  • More tests

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@matthiasdiener matthiasdiener self-assigned this Jan 28, 2026
@matthiasdiener matthiasdiener changed the title [WIP] proof-of-concept: grouped GEMM with ck_tile [WIP] Grouped GEMM with ck_tile Jan 29, 2026
delay_wgrad_compute,
):
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
if IS_HIP_EXTENSION:
Collaborator commented:
Is our CK grouped GEMM a drop-in replacement for NV upstream's CUTLASS grouped GEMM? If so, we can share the same env var. It's analogous to cublasLt vs. hipblasLt...
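A minimal sketch of what sharing one env var across both backends could look like (hypothetical helper, not the PR's code — `grouped_gemm_backend_enabled` is an illustrative name; TE's actual code uses `transformer_engine::getenv<bool>`):

```cpp
#include <cassert>
#include <cstdlib>
#include <string>

// Hypothetical sketch: gate both the NV CUTLASS path and the ROCm CK path
// behind the same env var, so callers (e.g. the Python test that sets
// NVTE_USE_CUTLASS_GROUPED_GEMM) don't need a platform-specific switch.
inline bool grouped_gemm_backend_enabled() {
  const char* v = std::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM");
  return v != nullptr && std::string(v) == "1";
}
```

The platform-specific dispatch (CUTLASS vs. CK) would then happen behind this single check, mirroring the cublasLt/hipblasLt pattern.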

Comment on lines +21 to +81
struct TileCfg_basic {
  static constexpr ck_tile::index_t M_Tile = 256;
  static constexpr ck_tile::index_t N_Tile = 128;
  static constexpr ck_tile::index_t K_Tile = 64;

  static constexpr ck_tile::index_t M_Warp = 2;
  static constexpr ck_tile::index_t N_Warp = 2;
  static constexpr ck_tile::index_t K_Warp = 1;

  static constexpr ck_tile::index_t M_Warp_Tile = 32;
  static constexpr ck_tile::index_t N_Warp_Tile = 32;
  static constexpr ck_tile::index_t K_Warp_Tile = 16;

  static constexpr bool kPadM = true;
  static constexpr bool kPadN = true;
  static constexpr bool kPadK = true;

  static constexpr bool DoubleSmemBuffer = false;

  static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
  static constexpr ck_tile::index_t TilePartitionerM01 = 1;
};

template <typename AType, typename BType, typename CType,
          typename ALayout, typename BLayout, typename CLayout,
          typename TileCfg, ck_tile::memory_operation_enum MemOp,
          typename AccType = float>
class Runner {
 public:
  using GemmShape = ck_tile::TileGemmShape<
      ck_tile::sequence<TileCfg::M_Tile, TileCfg::N_Tile, TileCfg::K_Tile>,
      ck_tile::sequence<TileCfg::M_Warp, TileCfg::N_Warp, TileCfg::K_Warp>,
      ck_tile::sequence<TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile>>;

  using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
      GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>;

  using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits<
      TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK,
      TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>;

  static constexpr ck_tile::GemmPipelineScheduler Scheduler =
      ck_tile::GemmPipelineScheduler::Intrawave;

  using Problem = ck_tile::UniversalGemmPipelineProblem<
      AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>;

  using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;

  using Epilogue = ck_tile::CShuffleEpilogue<
      ck_tile::CShuffleEpilogueProblem<
          AType, BType, ck_tile::tuple<>, AccType,
          CType, ck_tile::tuple<>, CLayout,
          ck_tile::element_wise::PassThrough,
          Partitioner::MPerBlock, Partitioner::NPerBlock,
          TileCfg::M_Warp, TileCfg::N_Warp,
          TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile,
          Problem::TransposeC, MemOp>>;

  using Kernel = ck_tile::GroupedGemmKernel<Partitioner, Pipeline, Epilogue>;
};
Collaborator commented:
Is this code from the CK repo? If so, can you add a comment pointing to the reference?

Comment on lines +120 to +121
std::vector<ck_tile::GroupedGemmHostArgs<0>> descs;
descs.reserve(group_num);
Collaborator commented:
Why not pass group_num to the descs vector's constructor instead?
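For context, the two options differ in initialization semantics; a standalone illustration (the `Desc` struct and `make_descs_reserve` helper here are hypothetical stand-ins for the PR's `GroupedGemmHostArgs` descriptors):

```cpp
#include <cassert>
#include <vector>

// Stand-in for a per-group GEMM descriptor.
struct Desc { int m = 0, n = 0, k = 0; };

// reserve() only allocates capacity (size stays 0), so each descriptor is
// constructed exactly once by push_back. Passing group_num to the vector's
// constructor instead would value-initialize group_num elements up front,
// which would then have to be assigned rather than constructed in place.
std::vector<Desc> make_descs_reserve(int group_num) {
  std::vector<Desc> descs;
  descs.reserve(group_num);          // capacity == group_num, size == 0
  for (int i = 0; i < group_num; ++i)
    descs.push_back(Desc{i, i, i});  // construct each entry in place
  return descs;
}
```

For a trivially-constructible descriptor type the difference is mostly stylistic; for non-trivial types, reserve + push_back avoids the redundant default construction.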

Comment on lines +111 to +112
using R = Runner<T, T, T, ALayout, BLayout, CLayout, TileCfg_basic, MemOp>;
using Kernel = typename R::Kernel;
Collaborator commented:
This R is not used anywhere else

Comment on lines +128 to +131
if (a.shape.size() != 2 || b.shape.size() != 2 || d.shape.size() != 2) {
NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D.");
return false;
}
Collaborator commented:
Does grouped GEMM support generalized matrices from higher-dimensional tensors? Regular GEMM supports that, and TE treats the last dimension as the column with all other dimensions flattened into the row:

  size_t flat_first_dim() const {
    const auto &full_shape = shape();
    size_t ret = 1;
    if (!full_shape.empty()) {
      for (size_t i = 0; i < full_shape.size() - 1; i++) {
        ret *= full_shape[i];
      }
    }
    return ret;
  }
  /*! Matrix width after tensor is flattened to 2D
   *
   * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
   * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
  size_t flat_last_dim() const {
    const auto &full_shape = shape();
    if (full_shape.empty()) {
      return 1;
    } else {
      return full_shape.back();
    }
  }
};

}
}
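A self-contained restatement of that flattening convention, operating on a plain shape vector rather than TE's Tensor class:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// A tensor of shape (D1, ..., Dn) is viewed as a (D1*...*D(n-1), Dn) matrix:
// the product of all leading dimensions forms the row count, the last
// dimension forms the column count. Empty shapes flatten to a 1x1 matrix.
size_t flat_first_dim(const std::vector<size_t>& shape) {
  size_t ret = 1;
  for (size_t i = 0; i + 1 < shape.size(); ++i) ret *= shape[i];
  return ret;
}

size_t flat_last_dim(const std::vector<size_t>& shape) {
  return shape.empty() ? 1 : shape.back();
}
```

So a (2, 3, 4) tensor is treated as a 6x4 matrix; if the CK path only accepts strictly 2D groups, callers with higher-rank tensors would need this flattening applied before dispatch.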

bool grouped_gemm_ck_tile(const NVTETensor* A,
Collaborator commented:
Why do we overload this function? In cublaslt_gemm.cu it is only called with this signature. Perhaps we can rename the grouped_gemm_ck_tile in line 255 instead.

transformer_engine::getenv<bool>("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", false);

auto is_supported_dtype = [&]() -> bool {
auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
Collaborator commented:

Is it possible that num_group == 0, so the A[0] access is invalid?
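The concern can be addressed with an early-out before touching the first group. A hypothetical sketch — `Tensor` and `first_group_dtype` are illustrative stand-ins, not TE's NVTETensor API:

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Stand-in for a tensor handle carrying a dtype tag.
struct Tensor { int dtype; };

// Guard against an empty group list before indexing A[0]; returns nullopt
// when there are no groups, so the caller can fall back gracefully.
std::optional<int> first_group_dtype(const std::vector<Tensor>& A) {
  if (A.empty()) return std::nullopt;  // num_group == 0: nothing to inspect
  return A[0].dtype;
}
```

In the PR's code the equivalent would be checking num_gemms before calling convertNVTETensorCheck(A[0]) inside the is_supported_dtype lambda.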

set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel)

target_include_directories(transformer_engine
BEFORE PRIVATE
Collaborator commented:
Why use the keyword BEFORE in this target_include_directories? Is it because CMake would otherwise not find the correct header files without prioritizing the CK include dirs?

target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")

set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel)
Collaborator commented:
CMAKE_SOURCE_DIR --> CMAKE_CURRENT_SOURCE_DIR? Not sure whether other upstream libs will depend on us, but let's make it future-proof.

#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#ifndef __HIP_PLATFORM_AMD__
#include "cutlass_grouped_gemm.cuh"
Collaborator commented:
NV upstream put their cutlass_grouped_gemm in a separate .cu file and compiled it separately. Maybe we can follow their structure for better isolation (avoiding CK-defined macros contaminating our cublaslt_gemm.cu).
