@@ -35,8 +35,8 @@ namespace device {
namespace {

template <typename GridwiseGemm,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
index_t MaxGroupedGemmGroupsNum,
@@ -66,14 +66,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
if constexpr(EGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
{
#endif
__shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>()];
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};

using EpilogueType =
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;

constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
__shared__ char p_shared[LDS_size];

auto epilogue_args = EpilogueType{};

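The block above is the usual std::conditional dispatch: pick the epilogue type at compile time, then size the LDS allocation from whichever type was selected. A standalone sketch with stand-in types (the struct names and byte counts here are illustrative, not CK's actual epilogue classes):

#include <cstddef>
#include <type_traits>

struct EpilogueCShuffle    { static constexpr std::size_t lds_bytes = 16384; };
struct EpilogueDirectStore { static constexpr std::size_t lds_bytes = 0; };

template <bool UseDirectStore>
struct GemmConfig
{
    // Compile-time epilogue selection, mirroring the kernel above.
    using EpilogueType =
        std::conditional_t<UseDirectStore, EpilogueDirectStore, EpilogueCShuffle>;
    // The LDS buffer is sized for whichever epilogue was chosen.
    static constexpr std::size_t lds_size = EpilogueType::lds_bytes;
};

static_assert(GemmConfig<true>::lds_size == 0, "direct store skips the LDS round trip");
static_assert(GemmConfig<false>::lds_size == 16384, "C-shuffle stages tiles through LDS");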
const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
index_t left = 0;
index_t right = gemms_count;
index_t group_id = index_t((left + right) / 2);

while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ &&
block_args_id < gemm_kernel_args[group_id].BlockEnd_)) &&
left <= right)
@@ -89,14 +99,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
group_id = index_t((left + right) / 2);
}
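The lookup above is a plain binary search over the disjoint, sorted [BlockStart_, BlockEnd_) ranges to find which GEMM group owns the current block id. A host-side sketch with a simplified argument struct (field names assumed; the narrowing branches are the conventional ones):

struct GroupRange { int block_start; int block_end; }; // [start, end)

// Narrow [left, right] until the range owning this block id is found.
// Assumes bid lies inside some range, as the launch configuration guarantees.
int find_group(const GroupRange* ranges, int gemms_count, int bid)
{
    int left = 0, right = gemms_count;
    int group_id = (left + right) / 2;
    while(!(bid >= ranges[group_id].block_start &&
            bid < ranges[group_id].block_end) && left <= right)
    {
        if(bid < ranges[group_id].block_start)
            right = group_id - 1;
        else
            left = group_id + 1;
        group_id = (left + right) / 2;
    }
    return group_id;
}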

const auto num_k_per_block =
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_.GetLength(Number<0>{}) / KBatch;
const auto num_k_per_block = GridwiseGemm::CalculateAK0Padded(
gemm_kernel_args[group_id].a_grid_desc_m_k_.GetLength(Number<1>{}), KBatch);

const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
gemm_kernel_args[group_id].a_grid_desc_m_k_);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
gemm_kernel_args[group_id].b_grid_desc_n_k_);

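Constructing the AK0×M×AK1 view on the device from the stored M×K descriptor keeps each per-group argument small and lets K0 account for the runtime KBatch. A rough sketch of the shape arithmetic (AK1 and the exact padding policy are stand-ins for CK's internals):

// Hedged sketch: a (M, K) view becomes (K0, M, K1) with K covered by K0 * K1.
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

struct ShapeMK    { int M; int K; };
struct ShapeK0MK1 { int K0; int M; int K1; };

// Pad K0 up so it splits evenly across the KBatch slices; each slice
// then covers K0 / KBatch steps along the K0 dimension.
ShapeK0MK1 make_ak0_m_ak1(ShapeMK mk, int AK1, int KBatch)
{
    const int K0 = integer_divide_ceil(mk.K, AK1 * KBatch) * KBatch;
    return {K0, mk.M, AK1};
}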
if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
{

GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(gemm_kernel_args[group_id].block_2_ctile_map_),
@@ -107,8 +122,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
CTranspose,
TailNum>(
p_shared,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
@@ -122,8 +137,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(gemm_kernel_args[group_id].block_2_ctile_map_),
@@ -134,8 +149,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
CTranspose,
TailNum>(
p_shared,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
@@ -147,8 +162,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
}
else
{
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(gemm_kernel_args[group_id].block_2_ctile_map_),
@@ -159,8 +174,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
CTranspose,
TailNum>(
p_shared,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
@@ -242,6 +257,7 @@ template <index_t NDimSpatial,
typename CShuffleBlockTransferScalarPerVector,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool UseThreadTileTransfer = true,
typename AComputeType = ADataType,
typename BComputeType = AComputeType,
index_t MaxTransposeTransferInScalarPerVector = 1,
@@ -266,6 +282,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
"wrong! only implemented for 2D and 3D now");

#ifdef USE_WAVE_TRANSFER_BWD_DATA

static_assert(UseThreadTileTransfer == false &&
(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0),
"Only Filter1x1Stride1Pad0is supported for wavetile transfer");
#endif

// MaxGroupedGemmGroupsNum is used to specify the number of gemm args at compile time. With this
// implementation we can avoid copying data to workspace before kernel launch, since the number of
// groups is a runtime parameter. If the number of groups is larger than MaxGroupedGemmGroupsNum then
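A minimal host-side sketch of that scheme, with illustrative stand-in names rather than the real CK API: the kernel receives a fixed-size array of per-group arguments by value, and launches repeat once per chunk when the group count exceeds the compile-time cap.

#include <algorithm>
#include <array>
#include <cstddef>
#include <vector>

constexpr int MaxGroupedGemmGroupsNum = 16; // compile-time cap on groups per launch

struct GemmArgsLite { int block_start; int block_end; }; // illustrative stand-in

// Stand-in for a kernel launch taking the argument array by value, so no
// explicit copy to a device workspace is required beforehand.
void launch_kernel(std::array<GemmArgsLite, MaxGroupedGemmGroupsNum> /*args*/,
                   int /*groups_in_launch*/) {}

void run_grouped(const std::vector<GemmArgsLite>& all_groups)
{
    for(std::size_t base = 0; base < all_groups.size(); base += MaxGroupedGemmGroupsNum)
    {
        std::array<GemmArgsLite, MaxGroupedGemmGroupsNum> chunk{};
        const std::size_t count =
            std::min<std::size_t>(MaxGroupedGemmGroupsNum, all_groups.size() - base);
        for(std::size_t i = 0; i < count; ++i)
            chunk[i] = all_groups[base + i];
        launch_kernel(chunk, static_cast<int>(count)); // one launch per chunk
    }
}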
@@ -450,10 +474,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
BlkGemmPipelineVer,
AComputeType,
BComputeType,
false, // PermuteA
false, // PermuteB
false, // IsBPreShuffled
true>; // ForceThreadTileTransfer
false,                  // PermuteA
false,                  // PermuteB
false,                  // IsBPreShuffled
UseThreadTileTransfer>; // ForceThreadTileTransfer

#define GridwiseGemmCTransposeTemplateParameters \
ALayout, BLayout, DsLayout, ELayout, Tuple<ADataType>, Tuple<BDataType>, AccDataType, \
@@ -517,8 +541,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
struct GemmArgs
{
GemmArgs() = default;
GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
GemmArgs(AGridDesc_M_K a_grid_desc_m_k,
BGridDesc_N_K b_grid_desc_n_k,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
@@ -527,8 +551,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
index_t BlockStart,
index_t BlockEnd,
bool HasMainKBlockLoop)
: a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1),
b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1),
: a_grid_desc_m_k_(a_grid_desc_m_k),
b_grid_desc_n_k_(b_grid_desc_n_k),

ds_grid_desc_mblock_mperblock_nblock_nperblock_(
ds_grid_desc_mblock_mperblock_nblock_nperblock),
@@ -543,8 +567,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
{
}
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -926,8 +950,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
gemm_kernel_args_[gemms_count_ /
MaxGroupedGemmGroupsNum][gemms_count_ %
MaxGroupedGemmGroupsNum] =
GemmArgs{a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
GemmArgs{a_grid_desc_m_k,
b_grid_desc_n_k,
GridwiseGemmCTranspose::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, MBlock, NBlock),
@@ -1055,10 +1079,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
{
for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++)
{
std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i]
std::cout << "a_grid_desc_m_k_container_" << a_grid_desc_m_k_container_[i]
<< std::endl;

std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i]
std::cout << "b_grid_desc_n_k_container_" << b_grid_desc_n_k_container_[i]
<< std::endl;

static_for<0, NumDTensor, 1>{}([&](auto j) {
@@ -1086,8 +1110,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
std::vector<EGridDesc_M_N> e_grid_desc_m_n_container_;

// tensor descriptor for block-wise copy
std::vector<AGridDesc_AK0_M_AK1> a_grid_desc_ak0_m_ak1_container_;
std::vector<BGridDesc_BK0_N_BK1> b_grid_desc_bk0_n_bk1_container_;

std::vector<DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_;
std::vector<EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
@@ -1233,8 +1256,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
constexpr bool no_main_loop = no_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_bwd_data_wmma_cshuffle_v3<
GridwiseGemmCTranspose,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_M_K,
DeviceOp::BGridDesc_N_K,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
MaxGroupedGemmGroupsNum,
@@ -1785,12 +1808,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
p_ds_grid_dummy[i] = nullptr;
StrideDs_dummy[i] = I0;
});
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
{
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1);
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1);
const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) *
arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2);
const index_t GemmM = arg.a_grid_desc_m_k_container_[i].GetLength(I0);
const index_t GemmN = arg.b_grid_desc_n_k_container_[i].GetLength(I0);
const index_t GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);

// Create gemm arguments with dummy values to check for validity
typename GridwiseGemmCTranspose::Argument gemm_arg{
std::array<const void*, 1>{nullptr}, // p_as_grid