From 6ff3d8b36c82fc7503e1eb438cdd7c2e80208681 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 21 Jan 2026 15:22:37 +0000 Subject: [PATCH 01/12] Added bwd_data wave tile transfer support --- ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 119 ++++++++++-------- ...wd_data_wmma_v3_wave_tranfer_instances.hpp | 83 ++++++++++++ .../gpu/grouped_convolution_backward_data.hpp | 9 ++ ...grouped_convolution_backward_data_wmma.inc | 56 +++++++++ .../grouped_conv2d_bwd_data/CMakeLists.txt | 5 +- ...ansfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 42 +++++++ ...ransfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 42 +++++++ .../grouped_conv3d_bwd_data/CMakeLists.txt | 4 + ...fer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 42 +++++++ ...sfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 42 +++++++ 10 files changed, 392 insertions(+), 52 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index dfdfd53725f..8e1a07f185b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -35,8 +35,8 @@ namespace device { namespace { template ()]; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + auto epilogue_args = EpilogueType{}; const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); index_t left = 0; index_t right = gemms_count; index_t group_id = index_t((left + right) / 2); + while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && left <= right) @@ -90,13 +100,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) } const auto num_k_per_block = - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_.GetLength(Number<0>{}) / KBatch; + gemm_kernel_args[group_id].a_grid_desc_m_k_.GetLength(Number<0>{}) / KBatch; if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) { - GridwiseGemm::template Run( p_shared, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].a_grid_desc_m_k_, + gemm_kernel_args[group_id].b_grid_desc_n_k_, gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].block_2_ctile_map_, @@ -122,8 +132,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { if(gemm_kernel_args[group_id].HasMainKBlockLoop_) { - GridwiseGemm::template Run( p_shared, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].a_grid_desc_m_k_, + gemm_kernel_args[group_id].b_grid_desc_n_k_, gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].block_2_ctile_map_, @@ -147,8 +157,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) } else { - GridwiseGemm::template Run( p_shared, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].a_grid_desc_m_k_, + gemm_kernel_args[group_id].b_grid_desc_n_k_, gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].block_2_ctile_map_, @@ -242,6 +252,7 @@ template ; // ForceThreadTileTransfer + false, + false, + false, + UseThreadTileTransfer>; #define GridwiseGemmCTransposeTemplateParameters \ ALayout, BLayout, DsLayout, ELayout, Tuple, Tuple, AccDataType, \ @@ -494,13 +511,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); - using AGridDesc_AK0_M_AK1 = remove_cvref_t>; - using BGridDesc_BK0_N_BK1 = remove_cvref_t>; - using DsGridDesc_M_N = remove_cvref_t>; - using EGridDesc_M_N = remove_cvref_t>; + using AGridDesc_M_K_ = remove_cvref_t>; + using BGridDesc_N_K_ = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; - using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); - using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_M_K_{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_N_K_{})); // Note: here we can call gridwise functions with dummy arguments, // just to create the alias @@ -517,8 +534,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 struct GemmArgs { GemmArgs() = default; - GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + GemmArgs(AGridDesc_M_K_ a_grid_desc_m_k, + BGridDesc_N_K_ b_grid_desc_n_k, DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock @@ -527,8 +544,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 index_t BlockStart, index_t BlockEnd, bool HasMainKBlockLoop) - : a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1), - b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1), + : a_grid_desc_m_k_(a_grid_desc_m_k), + b_grid_desc_n_k_(b_grid_desc_n_k), ds_grid_desc_mblock_mperblock_nblock_nperblock_( ds_grid_desc_mblock_mperblock_nblock_nperblock), @@ -543,8 +560,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 { } // tensor descriptors for block/thread-wise copy - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + AGridDesc_M_K_ 
a_grid_desc_m_k_; + BGridDesc_N_K_ b_grid_desc_n_k_; DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -1055,10 +1072,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 { for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++) { - std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i] + std::cout << "a_grid_desc_m_k_container_" << a_grid_desc_m_k_container_[i] << std::endl; - std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i] + std::cout << "b_grid_desc_n_k_container_" << b_grid_desc_n_k_container_[i] << std::endl; static_for<0, NumDTensor, 1>{}([&](auto j) { @@ -1086,8 +1103,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 std::vector e_grid_desc_m_n_container_; // tensor descriptor for block-wise copy - std::vector a_grid_desc_ak0_m_ak1_container_; - std::vector b_grid_desc_bk0_n_bk1_container_; + // std::vector a_grid_desc_m_k_container_; + // std::vector b_grid_desc_n_k_container_; std::vector ds_grid_desc_mblock_mperblock_nblock_nperblock_container_; std::vector @@ -1233,8 +1250,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 constexpr bool no_main_loop = no_main_k_block_loop.value; const auto kernel = kernel_grouped_conv_bwd_data_wmma_cshuffle_v3< GridwiseGemmCTranspose, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc_M_K, + DeviceOp::BGridDesc_N_K, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, MaxGroupedGemmGroupsNum, @@ -1785,12 +1802,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 p_ds_grid_dummy[i] = nullptr; StrideDs_dummy[i] = I0; }); - for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++) { - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); - const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * - arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + const index_t GemmM = arg.a_grid_desc_m_k_container_[i].GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_container_[i].GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1); + // Create gemm arguments with dummy values to check for validity typename GridwiseGemmCTranspose::Argument gemm_arg{ std::array{nullptr}, // p_as_grid diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp new file mode 100644 index 00000000000..fa613bd8361 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp @@ -0,0 +1,83 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#define USE_WAVE_TRANSFER + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using BF8 = ck::bf8_t; +using F8 = ck::f8_t; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; +template +using device_grouped_conv_bwd_data_wmma_cshufflev3_bf16_wave_transfer_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_wmma_cshufflev3_f16_wave_transfer_instances = 
std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // generic instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> + + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index f784b6ea510..2d1a8eb4495 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -450,6 +450,8 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_f16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -461,6 +463,9 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( op_ptrs); + + add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_bf16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_INT8 @@ -520,6 +525,8 @@ struct DeviceOperationInstanceFactory< op_ptrs); 
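// The wave-transfer registrations added in these hunks flow into the same
// op_ptrs list as the existing instances, so callers keep using the factory
// unchanged. Usage sketch (illustrative; template arguments elided, GetInstances
// per the surrounding CK factory convention):
//   auto op_ptrs = DeviceOperationInstanceFactory<
//       DeviceGroupedConvBwdDataMultipleD</*layouts, types, PassThrough ops*/>>::GetInstances();
//   // op_ptrs now also contains the wmma_v3 wave-transfer instances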
add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -531,6 +538,8 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_INT8 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc index 40b659a87f5..3ee3b54ac8c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc @@ -80,6 +80,20 @@ void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances); + #endif // conv3dbwdData @@ -326,6 +354,20 @@ void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_ PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_instances( @@ -355,6 +397,20 @@ void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_i PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 19e27cf173b..55becc73ab9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -49,5 +49,8 @@ add_instance_library( wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp - + + + wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..a5027772e98 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances) +{ + + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_cshufflev3_bf16_wave_transfer_instances< + 2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 00000000000..91516e81df4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances) +{ + + // 2. 
Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_cshufflev3_f16_wave_transfer_instances< + 2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index 01ff4095d74..46327941390 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -45,6 +45,10 @@ set(GROUPED_CONV3D_BWD_DATA wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + + + wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..9a57d84a434 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector>>& instances) +{ + + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_cshufflev3_bf16_wave_transfer_instances< + 3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 00000000000..80152c2bce5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances) +{ + + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_cshufflev3_f16_wave_transfer_instances< + 3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 2cbc4ce2a38d15eaa24eb1a07538287f915364a2 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 27 Jan 2026 10:16:22 +0000 Subject: [PATCH 02/12] Temp changes of bwd_wei(not working) --- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 140 ++++++++++-------- 1 file changed, 78 insertions(+), 62 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index b2ae092c274..6d34a44f629 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -41,8 +41,8 @@ namespace tensor_operation { namespace device { template (); + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; - GridwiseGemm::template Run()); - using AGridDesc_K0_M_K1 = remove_cvref_t; - using BGridDesc_K0_N_K1 = remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; + using AGridDesc_M_K = remove_cvref_t; + using BGridDesc_N_K = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt; @@ -401,10 +415,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, - false, // PermuteA - false, // permuteB - false, // IsBPreshuffle - true>; // ForceThreadTileTransfer + false, // PermuteA + false, // permuteB + false, // IsBPreshuffle + UseThreadTileTransfer>; // ForceThreadTileTransfer // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -434,8 +448,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 &max_occupancy, kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t, ComputePtrOffsetOfStridedBatch, true, @@ -473,8 +487,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 : p_a_grid_{p_out_grid}, p_b_grid_{p_in_grid}, p_c_grid_{p_wei_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, + a_grid_desc_kbatch_m_k_{}, + b_grid_desc_kbatch_n_k_{}, c_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{}, @@ -572,16 +586,16 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 input_right_pads, k_batch_); - a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; - b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; - 
c_grid_desc_m_n_ = descs[I2]; + a_grid_desc_kbatch_m_k_ = descs[I0]; + b_grid_desc_kbatch_n_k_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; - const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmM = a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = b_grid_desc_kbatch_n_k_.GetLength(I0); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -678,8 +692,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + AGridDesc_M_K a_grid_desc_kbatch_m_k_; + BGridDesc_N_K b_grid_desc_kbatch_n_k_; CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -724,17 +738,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 void ShowInfo(const Argument& arg) { - std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.a_grid_desc_kbatch_m_k_{" << arg.a_grid_desc_kbatch_m_k_.GetLength(I0) + << ", " << arg.a_grid_desc_kbatch_m_k_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_m_k_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_m_k_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_n_k_{" << arg.b_grid_desc_kbatch_n_k_.GetLength(I0) + << ", " << arg.b_grid_desc_kbatch_n_k_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_n_k_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_n_k_.GetLength(I3) << "}" << std::endl; std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; @@ -744,10 +756,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { float ave_time = 0; - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; @@ -839,9 +850,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - const auto num_k_per_block = - 
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + std::cout << "K0 value is:" << (GridwiseGemm::CalculateAK0Padded( arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}),arg.k_batch_)) << std::endl; + const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}),arg.k_batch_))/gemm_arg.KBatch; const auto clear_workspace = [&]() { hip_check_error( hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); @@ -855,11 +866,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; std::array size_as_buffers; - size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * + size_as_buffers[0] = arg.a_grid_desc_kbatch_m_k_.GetElementSpaceSize() * sizeof(ADataType) / GridwiseGemm::APackedSize; std::array size_bs_buffers; - size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * + size_bs_buffers[0] = arg.b_grid_desc_kbatch_n_k_.GetElementSpaceSize() * sizeof(BDataType) / GridwiseGemm::BPackedSize; std::array size_ds_buffers; @@ -889,8 +900,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_kbatch_m_k_, + arg.b_grid_desc_kbatch_n_k_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -905,8 +916,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_kbatch_m_k_, + arg.b_grid_desc_kbatch_n_k_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -926,8 +937,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -940,8 +951,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -965,8 +976,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -979,8 +990,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -1042,10 +1053,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - 
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif + const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid std::array{nullptr}, // p_bs_grid From cc395ff4fc44d20c81f881673886ab312f78a0c5 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 27 Jan 2026 20:14:53 +0000 Subject: [PATCH 03/12] wave tile support for bwd_data and bwd_wei --- ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 77 ++++++++++++------- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 42 ++++++++-- 2 files changed, 85 insertions(+), 34 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index 8e1a07f185b..0cb9ff53ee4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -98,15 +98,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) } group_id = index_t((left + right) / 2); } + const auto num_k_per_block = - gemm_kernel_args[group_id].a_grid_desc_m_k_.GetLength(Number<0>{}) / KBatch; + GridwiseGemm::CalculateAK0Padded(gemm_kernel_args[group_id].a_grid_desc_m_k_.GetLength(Number<1>{}),KBatch); + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(gemm_kernel_args[group_id].a_grid_desc_m_k_); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(gemm_kernel_args[group_id].b_grid_desc_n_k_); if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) { - GridwiseGemm::template Run( p_shared, - gemm_kernel_args[group_id].a_grid_desc_m_k_, - gemm_kernel_args[group_id].b_grid_desc_n_k_, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].block_2_ctile_map_, @@ -132,8 +136,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { if(gemm_kernel_args[group_id].HasMainKBlockLoop_) { - GridwiseGemm::template Run( p_shared, - gemm_kernel_args[group_id].a_grid_desc_m_k_, - gemm_kernel_args[group_id].b_grid_desc_n_k_, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].block_2_ctile_map_, @@ -157,8 +161,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) } else { - GridwiseGemm::template Run( p_shared, - gemm_kernel_args[group_id].a_grid_desc_m_k_, - gemm_kernel_args[group_id].b_grid_desc_n_k_, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].block_2_ctile_map_, @@ -374,9 +378,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 static auto GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) { - 
const auto a_grid_desc_m_k = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1(); + const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1(); - const auto b_grid_desc_n_k = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1(); + const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1(); const auto ds_grid_desc_m_n = generate_tuple( [&](auto i) { @@ -409,11 +413,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 if constexpr(CTranspose) { - return make_tuple(b_grid_desc_n_k, a_grid_desc_m_k, ds_grid_desc_m_n, e_grid_desc_m_n); + return make_tuple( + b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n); } else { - return make_tuple(a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n); + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); } } @@ -507,17 +513,31 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 return grid_desc_m_k; } + template + static auto transform_k0_n_k1_to_n_k(const Desc_K0_N_K1& desc_k0_n_k1) + { + const auto grid_desc_n_k = transform_tensor_descriptor( + desc_k0_n_k1, + make_tuple(make_pass_through_transform(desc_k0_n_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_n_k1.GetLength(I0), desc_k0_n_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_n_k; + } + // Note: the dummy function is used just to create the alias constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); - using AGridDesc_M_K_ = remove_cvref_t>; - using BGridDesc_N_K_ = remove_cvref_t>; + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; using DsGridDesc_M_N = remove_cvref_t>; using EGridDesc_M_N = remove_cvref_t>; - using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_M_K_{})); - using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_N_K_{})); + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); + using BGridDesc_N_K = decltype(transform_k0_n_k1_to_n_k(BGridDesc_BK0_N_BK1{})); // Note: here we can call gridwise functions with dummy arguments, // just to create the alias @@ -534,8 +554,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 struct GemmArgs { GemmArgs() = default; - GemmArgs(AGridDesc_M_K_ a_grid_desc_m_k, - BGridDesc_N_K_ b_grid_desc_n_k, + GemmArgs(AGridDesc_M_K a_grid_desc_m_k, + BGridDesc_N_K b_grid_desc_n_k, DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock @@ -560,8 +580,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 { } // tensor descriptors for block/thread-wise copy - AGridDesc_M_K_ a_grid_desc_m_k_; - BGridDesc_N_K_ b_grid_desc_n_k_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -943,8 +963,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 gemm_kernel_args_[gemms_count_ / MaxGroupedGemmGroupsNum][gemms_count_ % MaxGroupedGemmGroupsNum] = - GemmArgs{a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, + GemmArgs{a_grid_desc_m_k, + b_grid_desc_n_k, GridwiseGemmCTranspose:: 
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, MBlock, NBlock), @@ -1103,8 +1123,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 std::vector e_grid_desc_m_n_container_; // tensor descriptor for block-wise copy - // std::vector a_grid_desc_m_k_container_; - // std::vector b_grid_desc_n_k_container_; + std::vector ds_grid_desc_mblock_mperblock_nblock_nperblock_container_; std::vector diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 6d34a44f629..eb516378b49 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -307,6 +307,33 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 batch); } + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } + template + static auto transform_k0_m_k1_to_n_k(const Desc_K0_N_K1& desc_k0_n_k1) + { + const auto grid_desc_n_k = transform_tensor_descriptor( + desc_k0_n_k1, + make_tuple(make_pass_through_transform(desc_k0_n_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_n_k1.GetLength(I0), desc_k0_n_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_n_k; + } + using NGCHWTransposeDescType = remove_cvref_t({}, {}))>; @@ -322,10 +349,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 using ABCGridDescs = decltype(GetABCGridDesc()); - using AGridDesc_M_K = remove_cvref_t; - using BGridDesc_N_K = remove_cvref_t; + using AGridDesc_M_K_ = remove_cvref_t; + using BGridDesc_N_K_ = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; + + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_M_K_{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_n_k(BGridDesc_N_K_{})); + + using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt; using GridwiseInOutTranspose = @@ -586,8 +618,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 input_right_pads, k_batch_); - a_grid_desc_kbatch_m_k_ = descs[I0]; - b_grid_desc_kbatch_n_k_ = descs[I1]; + a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]); + b_grid_desc_kbatch_n_k_ =transform_k0_m_k1_to_n_k(descs[I1]); c_grid_desc_m_n_ = descs[I2]; // A/B/C Batch Stride @@ -852,7 +884,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 std::cout << "K0 value is:" << (GridwiseGemm::CalculateAK0Padded( arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}),arg.k_batch_)) << std::endl; - const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}),arg.k_batch_))/gemm_arg.KBatch; + const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}),arg.k_batch_)); const auto clear_workspace = [&]() { hip_check_error( hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); From 21e9dc2ef2cbd305f7624e228f79a9f36bdbc746 Mon Sep 17 00:00:00 
2001 From: apoorva Date: Mon, 2 Feb 2026 11:38:37 +0000 Subject: [PATCH 04/12] Refactored and fixed formatting of bwd_data instances --- ..._conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 15 +-------------- ..._bwd_data_wmma_v3_wave_transfer_instances.hpp} | 0 2 files changed, 1 insertion(+), 14 deletions(-) rename library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/{device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp => device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp} (100%) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index 0cb9ff53ee4..f9af4097aae 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -513,19 +513,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 return grid_desc_m_k; } - template - static auto transform_k0_n_k1_to_n_k(const Desc_K0_N_K1& desc_k0_n_k1) - { - const auto grid_desc_n_k = transform_tensor_descriptor( - desc_k0_n_k1, - make_tuple(make_pass_through_transform(desc_k0_n_k1.GetLength(I1)), - make_merge_transform( - make_tuple(desc_k0_n_k1.GetLength(I0), desc_k0_n_k1.GetLength(I2)))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return grid_desc_n_k; - } // Note: the dummy function is used just to create the alias constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; @@ -537,7 +524,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 using EGridDesc_M_N = remove_cvref_t>; using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); - using BGridDesc_N_K = decltype(transform_k0_n_k1_to_n_k(BGridDesc_BK0_N_BK1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); // Note: here we can call gridwise functions with dummy arguments, // just to create the alias diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp similarity index 100% rename from library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp rename to library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp From 55204c3ce0412e7c9a04a2ffb41fc27c3f8a53c4 Mon Sep 17 00:00:00 2001 From: apoorva Date: Mon, 2 Feb 2026 11:38:56 +0000 Subject: [PATCH 05/12] Added instances and fixed test failures in bwd_wei --- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 139 ++++++++++++------ ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 11 +- ...d_data_wmma_v3_wave_transfer_instances.hpp | 21 ++- ..._weight_v3_wmma_wave_transfer_instance.hpp | 93 ++++++++++++ .../grouped_convolution_backward_weight.hpp | 71 +++++++++ ...ouped_convolution_backward_weight_wmma.inc | 66 +++++++++ .../grouped_conv2d_bwd_weight/CMakeLists.txt | 2 + ...kyxc_nhwgk_bf16_wave_transfer_instance.cpp | 40 +++++ ...gkyxc_nhwgk_f16_wave_transfer_instance.cpp | 40 +++++ 
.../grouped_conv3d_bwd_weight/CMakeLists.txt | 2 + ...yxc_ndhwgk_bf16_wave_transfer_instance.cpp | 40 +++++ ...zyxc_ndhwgk_f16_wave_transfer_instance.cpp | 40 +++++ 12 files changed, 507 insertions(+), 58 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index f662ff834f4..5a73d86ab6b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -29,6 +29,11 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" + + +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + #ifdef CK_EXPERIMENTAL_BUILDER #include "ck_tile/builder/reflect/description.hpp" #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" @@ -39,8 +44,8 @@ namespace tensor_operation { namespace device { template (); + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; - GridwiseGemm::template Run struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 @@ -164,6 +181,15 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 ComputeTypeA, ComputeTypeB> { + +#if defined USE_WAVE + + static_assert(UseThreadTileTransfer==false && + (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0 + ),"Only Filter1x1Stride1Pad0is supported for wavetile transfer" + ); +#endif + using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3; using ADataType = OutDataType; @@ -275,6 +301,20 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 params, batch); } + + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); 
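+        // The merge above views the (K0, M, K1) descriptor as (M, K) with K = K0 * K1;
+        // K1 stays the fastest-varying axis, so a flat k index decomposes as
+        //   k0 = k / K1;  k1 = k % K1;   // i.e. (m, k) <-> (k0, m, k1)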
+ + return grid_desc_m_k; + } using ABCGridDescs = decltype(GetABCGridDesc()); @@ -282,6 +322,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_K0_M_K1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_K0_N_K1{})); + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, @@ -334,7 +377,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 false, // permuteA false, // permuteB false, // IsBPreShuffled - true>; // ForceThreadTileTransfer + UseThreadTileTransfer>; // ForceThreadTileTransfer static constexpr auto MakeElementwiseInputSequence() { @@ -592,8 +635,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 p_b_grid_{p_in_grid}, p_ds_grid_{}, p_e_grid_{p_wei_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, + a_grid_desc_kbatch_m_k_{}, + b_grid_desc_kbatch_n_k_{}, ce_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{}, @@ -687,8 +730,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_k_c_xs_strides[i][0]; }); - a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; - b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]); + b_grid_desc_kbatch_n_k_ =transform_k0_m_k1_to_m_k(descs[I1]); ce_grid_desc_m_n_ = descs[I2]; ds_grid_descs_tuple_ = @@ -707,8 +750,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 index_t{1}, std::multiplies<>{}); - const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmM = a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = b_grid_desc_kbatch_n_k_.GetLength(I0); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -726,8 +769,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const BDataType* p_b_grid_; DsGridPointerTuple p_ds_grid_; EDataType* p_e_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + AGridDesc_M_K a_grid_desc_kbatch_m_k_; + BGridDesc_N_K b_grid_desc_kbatch_n_k_; CGridDesc_M_N ce_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; DsGridDesc_M_N ds_grid_descs_tuple_; @@ -784,10 +827,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { float ave_time = 0; - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); + AccDataType* p_e_grid = type_convert(arg.p_workspace_); @@ -817,8 +860,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - const auto num_k_per_block = - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / 
gemm_arg.KBatch; + const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}),arg.k_batch_)); const auto clear_workspace = [&]() { hip_check_error(hipMemsetAsync( @@ -831,11 +873,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; std::array size_as_buffers; - size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * + size_as_buffers[0] = arg.a_grid_desc_kbatch_m_k_.GetElementSpaceSize() * sizeof(ADataType) / GridwiseGemm::APackedSize; std::array size_bs_buffers; - size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * + size_bs_buffers[0] = arg.b_grid_desc_kbatch_n_k_.GetElementSpaceSize() * sizeof(BDataType) / GridwiseGemm::BPackedSize; std::array size_ds_buffers; @@ -865,8 +907,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_kbatch_m_k_, + arg.b_grid_desc_kbatch_n_k_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -881,8 +923,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_kbatch_m_k_, + arg.b_grid_desc_kbatch_n_k_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -903,8 +945,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -918,8 +960,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -944,8 +986,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -959,8 +1001,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -1030,10 +1072,17 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif + + const index_t GemmM = 
arg.a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); + typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid std::array{nullptr}, // p_bs_grid diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index eb516378b49..ac23bcd3254 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -181,11 +181,18 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static_assert(is_same_v); using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffleV3; - + using ADataType = OutDataType; using BDataType = InDataType; using CDataType = WeiDataType; - +// // static const auto F1S1 = ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; +// #if defined USE_WAVE + +// static_assert(UseThreadTileTransfer==false && +// (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0 +// ),"Only Filter1x1Stride1Pad0is supported for wavetile transfer" +// ); +// #endif // If NGCHW then ADataType must be equal to BDataType static_assert(!(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) || diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp index fa613bd8361..4c0f9e52765 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp @@ -43,11 +43,11 @@ template using device_grouped_conv_bwd_data_wmma_cshufflev3_bf16_wave_transfer_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| 
PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // generic instance + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, @@ -63,12 +63,11 @@ template using device_grouped_conv_bwd_data_wmma_cshufflev3_f16_wave_transfer_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - - // generic instance + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp new file mode 100644 index 00000000000..cd43e558a08 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp @@ -0,0 +1,93 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#define USE_WAVE +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + + + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances = std::tuple< + // clang-format off + //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | + //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 
8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> + + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances = std::tuple< + // clang-format off + //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | + //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, BF16, BF16, BF16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, BF16, BF16, BF16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> + + //clang-format on + >; + + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index c07dc71ac56..7ecb6262128 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -12,6 +12,11 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#if defined USE_WAVE +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp" + +#endif + #ifdef DL_KERNELS #include "grouped_convolution_backward_weight_dl.inc" #endif @@ -957,7 +962,73 @@ struct DeviceOperationInstanceFactory +struct 
DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD< + NumDimSpatial, + InLayout, + WeiLayout, + OutLayout, + DsLayout, + InDataType, + WeiDataType, + OutDataType, + DsDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ComputeTypeA, + ComputeTypeB>> +{ + using DeviceOp = + DeviceGroupedConvBwdWeightMultipleD<NumDimSpatial, InLayout, WeiLayout, OutLayout, DsLayout, InDataType, WeiDataType, OutDataType, DsDataType, PassThrough, PassThrough, PassThrough, ComputeTypeA, ComputeTypeB>; + static auto GetInstances() + { + std::vector<std::unique_ptr<DeviceOp>> op_ptrs; + +#ifdef CK_ENABLE_BF16 + add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( + op_ptrs); +#endif +#ifdef CK_ENABLE_FP16 + add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances( + op_ptrs); +#endif + return op_ptrs; + + } + + }; +#endif } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc index 06247019f13..09977754a3d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc @@ -114,6 +114,72 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf PassThrough>>>& instances); #endif +#if defined USE_WAVE +#ifdef CK_ENABLE_BF16 + +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( + std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<2, NHWGC, GKYXC, NHWGK, Tuple<>, + BF16, + BF16, + BF16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( + std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3, NDHWGC, GKZYXC, NDHWGK, Tuple<>, + BF16, + BF16, + BF16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +#endif + +#ifdef CK_ENABLE_FP16 + +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( + std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<2, NHWGC, GKYXC, NHWGK, Tuple<>, + F16, + F16, + F16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances( + std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3, NDHWGC, GKZYXC, NDHWGK, Tuple<>, + F16, + F16, + F16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +#endif +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 268835d5bfd..59a30d4ea9c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -79,6 +79,8 @@ list(APPEND GROUPED_CONV2D_BWD_WEIGHT wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp + 
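+    # Wave tile transfer instances; currently limited to the Filter1x1Stride1Pad0
+    # specialization (enforced by a static_assert in the device operation).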
wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp ) add_instance_library(device_grouped_conv2d_bwd_weight_instance ${GROUPED_CONV2D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp new file mode 100644 index 00000000000..9b99f910bd4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( + std::vector, + BF16, + BF16, + BF16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp new file mode 100644 index 00000000000..4f84fce872b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( + std::vector, + F16, + F16, + F16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + // 1. 
Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index b246b87178e..e7d8403812a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -73,6 +73,8 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp ) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp new file mode 100644 index 00000000000..0fb351d88d5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( + std::vector, + BF16, + BF16, + BF16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + // 1. 
Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp new file mode 100644 index 00000000000..6f4db11b4d3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances( + std::vector, + F16, + F16, + F16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 1ea372365516bb4f9f265baba66674d69b0a9f5f Mon Sep 17 00:00:00 2001 From: apoorva Date: Mon, 2 Feb 2026 11:40:37 +0000 Subject: [PATCH 06/12] Fixed clang format --- ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 18 ++++---- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 40 ++++++++-------- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 46 ++++++++++--------- ..._weight_v3_wmma_wave_transfer_instance.hpp | 5 +- .../grouped_convolution_backward_weight.hpp | 14 +++--- ...kyxc_nhwgk_bf16_wave_transfer_instance.cpp | 11 +++-- ...gkyxc_nhwgk_f16_wave_transfer_instance.cpp | 33 ++++++------- ...yxc_ndhwgk_bf16_wave_transfer_instance.cpp | 33 ++++++------- ...zyxc_ndhwgk_f16_wave_transfer_instance.cpp | 33 ++++++------- 9 files changed, 115 insertions(+), 118 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index f9af4097aae..f4f64f6e771 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -98,13 +98,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) } group_id = index_t((left + right) / 2); } - - const auto num_k_per_block = - GridwiseGemm::CalculateAK0Padded(gemm_kernel_args[group_id].a_grid_desc_m_k_.GetLength(Number<1>{}),KBatch); + 
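+            // num_k_per_block: the number of AK0 slices owned by one split-K batch.
+            // CalculateAK0Padded pads K up to the block's split-K granularity before
+            // dividing, so every batch sees an identical slice count.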
const auto num_k_per_block = GridwiseGemm::CalculateAK0Padded( + gemm_kernel_args[group_id].a_grid_desc_m_k_.GetLength(Number<1>{}), KBatch); - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(gemm_kernel_args[group_id].a_grid_desc_m_k_); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(gemm_kernel_args[group_id].b_grid_desc_n_k_); + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + gemm_kernel_args[group_id].a_grid_desc_m_k_); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + gemm_kernel_args[group_id].b_grid_desc_n_k_); if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) { @@ -173,7 +174,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) CTranspose, TailNum>( p_shared, - a_grid_desc_ak0_m_ak1, + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, @@ -513,15 +514,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 return grid_desc_m_k; } - // Note: the dummy function is used just to create the alias constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); using AGridDesc_AK0_M_AK1 = remove_cvref_t>; using BGridDesc_BK0_N_BK1 = remove_cvref_t>; - using DsGridDesc_M_N = remove_cvref_t>; - using EGridDesc_M_N = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 5a73d86ab6b..e943361268d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -29,8 +29,6 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" - - #include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" @@ -56,7 +54,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d( + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d( typename GridwiseGemm::Argument karg, const AGridDesc_M_K a_grid_desc_m_k, const BGridDesc_N_K b_grid_desc_n_k, @@ -87,9 +85,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); - - - GridwiseGemm::template Run struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 @@ -184,10 +181,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 #if defined USE_WAVE - static_assert(UseThreadTileTransfer==false && - (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0 - ),"Only Filter1x1Stride1Pad0 is supported for wave tile transfer" - ); + static_assert(UseThreadTileTransfer == false && + 
(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0), + "Only Filter1x1Stride1Pad0 is supported for wave tile transfer"); #endif using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3; @@ -301,8 +298,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 params, batch); } - - template <typename Desc_K0_M_K1> + + template <typename Desc_K0_M_K1> static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) { const auto grid_desc_m_k = transform_tensor_descriptor( @@ -323,7 +320,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 using CGridDesc_M_N = remove_cvref_t; using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_K0_M_K1{})); - using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_K0_N_K1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_K0_N_K1{})); using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< tensor_layout::gemm::ColumnMajor, @@ -374,9 +371,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, - false, // permuteA - false, // permuteB - false, // IsBPreShuffled + false, // permuteA + false, // permuteB + false, // IsBPreShuffled UseThreadTileTransfer>; // ForceThreadTileTransfer static constexpr auto MakeElementwiseInputSequence() @@ -731,8 +728,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 }); a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]); - b_grid_desc_kbatch_n_k_ =transform_k0_m_k1_to_m_k(descs[I1]); - ce_grid_desc_m_n_ = descs[I2]; + b_grid_desc_kbatch_n_k_ = transform_k0_m_k1_to_m_k(descs[I1]); + ce_grid_desc_m_n_ = descs[I2]; ds_grid_descs_tuple_ = MakeDsGridDescriptor_M_N(ds_g_k_c_xs_lengths, ds_g_k_c_xs_strides); @@ -831,7 +828,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); - AccDataType* p_e_grid = type_convert(arg.p_workspace_); // Convolution kernel dispatch @@ -860,7 +856,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}),arg.k_batch_)); + const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( + arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_)); const auto clear_workspace = [&]() { hip_check_error(hipMemsetAsync( @@ -1083,7 +1080,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); - typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid std::array{nullptr}, // p_bs_grid std::array{}, // p_ds_grid diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index ac23bcd3254..e60f95d9cb7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -84,9 +84,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
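// Note: the kernel receives flat (M, K) / (N, K) descriptors and re-derives the
// blocked (AK0, M, AK1) / (BK0, N, BK1) copy descriptors on device, so only the
// flat views cross the kernel ABI.
const auto b_grid_desc_bk0_n_bk1 = 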
GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); - - - GridwiseGemm::template Run); using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffleV3; - + using ADataType = OutDataType; using BDataType = InDataType; using CDataType = WeiDataType; -// // static const auto F1S1 = ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; -// #if defined USE_WAVE - -// static_assert(UseThreadTileTransfer==false && -// (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0 -// ),"Only Filter1x1Stride1Pad0is supported for wavetile transfer" -// ); -// #endif + // // static const auto F1S1 = + // ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + // #if defined USE_WAVE + + // static_assert(UseThreadTileTransfer==false && + // (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0 + // ),"Only Filter1x1Stride1Pad0is supported for wavetile transfer" + // ); + // #endif // If NGCHW then ADataType must be equal to BDataType static_assert(!(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) || @@ -314,7 +314,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 batch); } - template + template static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) { const auto grid_desc_m_k = transform_tensor_descriptor( @@ -327,7 +327,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 return grid_desc_m_k; } - template + template static auto transform_k0_m_k1_to_n_k(const Desc_K0_N_K1& desc_k0_n_k1) { const auto grid_desc_n_k = transform_tensor_descriptor( @@ -358,12 +358,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 using AGridDesc_M_K_ = remove_cvref_t; using BGridDesc_N_K_ = remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; - using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_M_K_{})); - using BGridDesc_N_K = decltype(transform_k0_m_k1_to_n_k(BGridDesc_N_K_{})); - + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_n_k(BGridDesc_N_K_{})); using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt; @@ -626,7 +624,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 k_batch_); a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]); - b_grid_desc_kbatch_n_k_ =transform_k0_m_k1_to_n_k(descs[I1]); + b_grid_desc_kbatch_n_k_ = transform_k0_m_k1_to_n_k(descs[I1]); c_grid_desc_m_n_ = descs[I2]; // A/B/C Batch Stride @@ -889,10 +887,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - std::cout << "K0 value is:" << (GridwiseGemm::CalculateAK0Padded( arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}),arg.k_batch_)) << std::endl; + std::cout << "K0 value is:" + << (GridwiseGemm::CalculateAK0Padded( + arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_)) + << std::endl; - const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}),arg.k_batch_)); - const auto clear_workspace = [&]() { + const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( + arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_)); + const auto clear_workspace = [&]() { hip_check_error( hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); }; diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp index cd43e558a08..5ca4f4d83d8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp @@ -8,7 +8,6 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - namespace ck { namespace tensor_operation { namespace device { @@ -28,8 +27,6 @@ using F8 = ck::f8_t; using BF8 = ck::bf8_t; #endif - - using Empty_Tuple = ck::Tuple<>; template @@ -60,7 +57,7 @@ using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instanc DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> - + // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 7ecb6262128..c31279d0025 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -1013,21 +1013,19 @@ struct DeviceOperationInstanceFactory< #ifdef CK_ENABLE_BF16 add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( - op_ptrs); + op_ptrs); add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( - op_ptrs); + op_ptrs); #endif #ifdef CK_ENABLE_FP16 add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( - op_ptrs); + op_ptrs); add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances( - op_ptrs); + op_ptrs); #endif return op_ptrs; - - } - - }; + } +}; #endif } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp index 9b99f910bd4..36312b814e0 --- 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp @@ -27,11 +27,12 @@ void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_trans // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances<2, - NHWGC, - GKYXC, - NHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp index 4f84fce872b..60287b459c2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp @@ -12,26 +12,27 @@ namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( std::vector, - F16, - F16, - F16, - Tuple<>, - PassThrough, - PassThrough, - PassThrough>>>& instances) + NHWGC, + GKYXC, + NHWGK, + Tuple<>, + F16, + F16, + F16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances) { // 1. 
Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances<2, - NHWGC, - GKYXC, - NHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp index 0fb351d88d5..eee14e32364 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp @@ -12,26 +12,27 @@ namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( std::vector, - BF16, - BF16, - BF16, - Tuple<>, - PassThrough, - PassThrough, - PassThrough>>>& instances) + NDHWGC, + GKZYXC, + NDHWGK, + Tuple<>, + BF16, + BF16, + BF16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances) { // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp index 6f4db11b4d3..0eeef111a28 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp @@ -12,26 +12,27 @@ namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances( std::vector, - F16, - F16, - F16, - Tuple<>, - PassThrough, - PassThrough, - PassThrough>>>& instances) + NDHWGC, + GKZYXC, + NDHWGK, + Tuple<>, + F16, + F16, + F16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances) { // 1. 
Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance From 845e14d7305c7acb8dda5b54ea0b6d11fafbafa1 Mon Sep 17 00:00:00 2001 From: apoorva Date: Mon, 2 Feb 2026 12:58:49 +0000 Subject: [PATCH 07/12] Reverted unused device impl and updated macros --- ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 2 +- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 2 +- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 185 ++++++------------ ...d_data_wmma_v3_wave_transfer_instances.hpp | 2 +- ..._weight_v3_wmma_wave_transfer_instance.hpp | 2 +- .../grouped_convolution_backward_weight.hpp | 4 +- ...ouped_convolution_backward_weight_wmma.inc | 2 +- 7 files changed, 71 insertions(+), 128 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index f4f64f6e771..6b635b6a23a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -282,7 +282,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 static_assert(NDimSpatial == 2 || NDimSpatial == 3, "wrong! only implemented for 2D and 3D now"); -#ifdef USE_WAVE_TRANSFER +#ifdef USE_WAVE_TRANSFER_BWD_DATA static_assert(UseThreadTileTransfer == false && (ConvBackwardDataSpecialization == diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index e943361268d..37283dfdb36 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -179,7 +179,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 ComputeTypeB> { -#if defined USE_WAVE +#if defined USE_WAVE_TRANSFER_BWD_WEI static_assert(UseThreadTileTransfer == false && (ConvBackwardWeightSpecialization == diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index e60f95d9cb7..b2ae092c274 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -41,8 +41,8 @@ namespace tensor_operation { namespace device { template ::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); __shared__ char p_shared[LDS_size]; - auto epilogue_args = EpilogueType{}; + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - const auto a_grid_desc_ak0_m_ak1 = - GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); - - 
const auto b_grid_desc_bk0_n_bk1 = - GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); - - GridwiseGemm::template Run() || is_NGCDHW_NGKDHW()) || @@ -314,33 +293,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 batch); } - template - static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) - { - const auto grid_desc_m_k = transform_tensor_descriptor( - desc_k0_m_k1, - make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), - make_merge_transform( - make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return grid_desc_m_k; - } - template - static auto transform_k0_m_k1_to_n_k(const Desc_K0_N_K1& desc_k0_n_k1) - { - const auto grid_desc_n_k = transform_tensor_descriptor( - desc_k0_n_k1, - make_tuple(make_pass_through_transform(desc_k0_n_k1.GetLength(I1)), - make_merge_transform( - make_tuple(desc_k0_n_k1.GetLength(I0), desc_k0_n_k1.GetLength(I2)))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return grid_desc_n_k; - } - using NGCHWTransposeDescType = remove_cvref_t({}, {}))>; @@ -356,12 +308,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 using ABCGridDescs = decltype(GetABCGridDesc()); - using AGridDesc_M_K_ = remove_cvref_t; - using BGridDesc_N_K_ = remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; - - using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_M_K_{})); - using BGridDesc_N_K = decltype(transform_k0_m_k1_to_n_k(BGridDesc_N_K_{})); + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt; @@ -452,10 +401,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, - false, // PermuteA - false, // permuteB - false, // IsBPreshuffle - UseThreadTileTransfer>; // ForceThreadTileTransfer + false, // PermuteA + false, // permuteB + false, // IsBPreshuffle + true>; // ForceThreadTileTransfer // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -485,8 +434,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 &max_occupancy, kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t, ComputePtrOffsetOfStridedBatch, true, @@ -524,8 +473,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 : p_a_grid_{p_out_grid}, p_b_grid_{p_in_grid}, p_c_grid_{p_wei_grid}, - a_grid_desc_kbatch_m_k_{}, - b_grid_desc_kbatch_n_k_{}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, c_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{}, @@ -623,16 +572,16 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 input_right_pads, k_batch_); - a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]); - b_grid_desc_kbatch_n_k_ = transform_k0_m_k1_to_n_k(descs[I1]); - c_grid_desc_m_n_ = descs[I2]; + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; - const index_t GemmM = 
a_grid_desc_kbatch_m_k_.GetLength(I0); - const index_t GemmN = b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -729,8 +678,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; - AGridDesc_M_K a_grid_desc_kbatch_m_k_; - BGridDesc_N_K b_grid_desc_kbatch_n_k_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -775,15 +724,17 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 void ShowInfo(const Argument& arg) { - std::cout << "arg.a_grid_desc_kbatch_m_k_{" << arg.a_grid_desc_kbatch_m_k_.GetLength(I0) - << ", " << arg.a_grid_desc_kbatch_m_k_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_m_k_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_m_k_.GetLength(I3) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_kbatch_n_k_{" << arg.b_grid_desc_kbatch_n_k_.GetLength(I0) - << ", " << arg.b_grid_desc_kbatch_n_k_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_n_k_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_n_k_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; @@ -793,9 +744,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { float ave_time = 0; - const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); - const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); - const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; @@ -887,14 +839,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - std::cout << "K0 value is:" - << (GridwiseGemm::CalculateAK0Padded( - arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_)) - << std::endl; + const auto num_k_per_block = + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; - const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( - arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_)); - const auto clear_workspace = [&]() { + const auto clear_workspace = 
[&]() { hip_check_error( hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); }; @@ -907,11 +855,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; std::array size_as_buffers; - size_as_buffers[0] = arg.a_grid_desc_kbatch_m_k_.GetElementSpaceSize() * + size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) / GridwiseGemm::APackedSize; std::array size_bs_buffers; - size_bs_buffers[0] = arg.b_grid_desc_kbatch_n_k_.GetElementSpaceSize() * + size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) / GridwiseGemm::BPackedSize; std::array size_ds_buffers; @@ -941,8 +889,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_kbatch_m_k_, - arg.b_grid_desc_kbatch_n_k_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -957,8 +905,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_kbatch_m_k_, - arg.b_grid_desc_kbatch_n_k_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -978,8 +926,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -992,8 +940,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -1017,8 +965,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -1031,8 +979,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -1094,15 +1042,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif - const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); - const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); - const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); typename GridwiseGemm::Argument 
gemm_arg{std::array{nullptr}, // p_as_grid std::array{nullptr}, // p_bs_grid diff --git
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp
index 4c0f9e52765..6bc0ff8b4f0 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp
@@ -14,7 +14,7 @@ namespace tensor_operation { namespace device { namespace instance { -#define USE_WAVE_TRANSFER +#define USE_WAVE_TRANSFER_BWD_DATA using BF16 = ck::bhalf_t; using F16 = ck::half_t; diff --git
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp
index 5ca4f4d83d8..4984ac1cec8 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp
@@ -1,6 +1,6 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#define USE_WAVE +#define USE_WAVE_TRANSFER_BWD_WEI #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" diff --git
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
index c31279d0025..56f511bc894 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
@@ -12,7 +12,7 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#if defined USE_WAVE +#if defined USE_WAVE_TRANSFER_BWD_WEI #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp" #endif @@ -962,7 +962,7 @@ struct DeviceOperationInstanceFactory>>& instances); #endif -#if defined USE_WAVE +#if defined USE_WAVE_TRANSFER_BWD_WEI #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( From ae834e1a68c79aa3b21d76d6976431d77ff46b7b Mon Sep 17 00:00:00 2001 From: apoorva Date: Mon, 2 Feb 2026 14:04:58 +0000 Subject: [PATCH 08/12] Update include paths for the renamed wave transfer instance files ---
...ta_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 2 +-
...ata_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 2 +-
...wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 2 +-
..._wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index a5027772e98..8ac4f078cce 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 91516e81df4..a47837768d5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 9a57d84a434..9f45f9527ea 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp" namespace ck { namespace tensor_operation { diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index 80152c2bce5..a237027bcf5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_tranfer_instances.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp" namespace ck { namespace tensor_operation { From 41653735b8670eade4666ff18e0cafed90ee78b6 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 3 Feb 2026 08:50:36 +0000 Subject: [PATCH 09/12] Fixing merge conflict --- ...grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 37283dfdb36..88cdef7548b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -1069,13 +1069,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif - const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); From d514c8fca8b3b0e749b8c7ced06debbcb93c9235 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 3 Feb 2026 13:04:31 +0000 Subject: [PATCH 10/12] Revert "Reverted unused device impl and updated macros" This reverts commit 845e14d7305c7acb8dda5b54ea0b6d11fafbafa1. 
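The macros restored by this revert gate a compile-time constraint on the wave-tile transfer path. For reference, a minimal sketch of the guarded check as it exists in the bwd-weight device implementation after this revert (simplified; the in-tree assert carries the full template context of the series):

    // Sketch only, not part of the diff below: the guard that the USE_WAVE
    // macro enables. Wave-tile transfer is restricted to 1x1 filters with
    // stride 1 and no padding.
    #if defined(USE_WAVE)
    static_assert(UseThreadTileTransfer == false &&
                      ConvBackwardWeightSpecialization ==
                          ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0,
                  "Only Filter1x1Stride1Pad0 is supported for wave-tile transfer");
    #endif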
--- ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 2 +- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 2 +- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 185 ++++++++++++------ ...d_data_wmma_v3_wave_transfer_instances.hpp | 2 +- ..._weight_v3_wmma_wave_transfer_instance.hpp | 2 +- .../grouped_convolution_backward_weight.hpp | 4 +- ...ouped_convolution_backward_weight_wmma.inc | 2 +- 7 files changed, 128 insertions(+), 71 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index 6b635b6a23a..f4f64f6e771 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -282,7 +282,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 static_assert(NDimSpatial == 2 || NDimSpatial == 3, "wrong! only implemented for 2D and 3D now"); -#ifdef USE_WAVE_TRANSFER_BWD_DATA +#ifdef USE_WAVE_TRANSFER static_assert(UseThreadTileTransfer == false && (ConvBackwardDataSpecialization == diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 88cdef7548b..208855149eb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -179,7 +179,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 ComputeTypeB> { -#if defined USE_WAVE_TRANSFER_BWD_WEI +#if defined USE_WAVE static_assert(UseThreadTileTransfer == false && (ConvBackwardWeightSpecialization == diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index b2ae092c274..e60f95d9cb7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -41,8 +41,8 @@ namespace tensor_operation { namespace device { template (); + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; - GridwiseGemm::template Run() || is_NGCDHW_NGKDHW()) || @@ -293,6 +314,33 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 batch); } + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } + template + static auto transform_k0_m_k1_to_n_k(const Desc_K0_N_K1& desc_k0_n_k1) + { + const auto grid_desc_n_k = transform_tensor_descriptor( + desc_k0_n_k1, + 
make_tuple(make_pass_through_transform(desc_k0_n_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_n_k1.GetLength(I0), desc_k0_n_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_n_k; + } + using NGCHWTransposeDescType = remove_cvref_t({}, {}))>; @@ -308,9 +356,12 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 using ABCGridDescs = decltype(GetABCGridDesc()); - using AGridDesc_K0_M_K1 = remove_cvref_t; - using BGridDesc_K0_N_K1 = remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; + using AGridDesc_M_K_ = remove_cvref_t; + using BGridDesc_N_K_ = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_M_K_{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_n_k(BGridDesc_N_K_{})); using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt; @@ -401,10 +452,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, - false, // PermuteA - false, // permuteB - false, // IsBPreshuffle - true>; // ForceThreadTileTransfer + false, // PermuteA + false, // permuteB + false, // IsBPreshuffle + UseThreadTileTransfer>; // ForceThreadTileTransfer // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -434,8 +485,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 &max_occupancy, kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t, ComputePtrOffsetOfStridedBatch, true, @@ -473,8 +524,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 : p_a_grid_{p_out_grid}, p_b_grid_{p_in_grid}, p_c_grid_{p_wei_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, + a_grid_desc_kbatch_m_k_{}, + b_grid_desc_kbatch_n_k_{}, c_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{}, @@ -572,16 +623,16 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 input_right_pads, k_batch_); - a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; - b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; - c_grid_desc_m_n_ = descs[I2]; + a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]); + b_grid_desc_kbatch_n_k_ = transform_k0_m_k1_to_n_k(descs[I1]); + c_grid_desc_m_n_ = descs[I2]; // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; - const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmM = a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = b_grid_desc_kbatch_n_k_.GetLength(I0); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -678,8 +729,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + AGridDesc_M_K a_grid_desc_kbatch_m_k_; + BGridDesc_N_K b_grid_desc_kbatch_n_k_; CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -724,17 +775,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 
void ShowInfo(const Argument& arg) { - std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.a_grid_desc_kbatch_m_k_{" << arg.a_grid_desc_kbatch_m_k_.GetLength(I0) + << ", " << arg.a_grid_desc_kbatch_m_k_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_m_k_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_m_k_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_n_k_{" << arg.b_grid_desc_kbatch_n_k_.GetLength(I0) + << ", " << arg.b_grid_desc_kbatch_n_k_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_n_k_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_n_k_.GetLength(I3) << "}" << std::endl; std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; @@ -744,10 +793,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { float ave_time = 0; - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; @@ -839,10 +887,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - const auto num_k_per_block = - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + std::cout << "K0 value is:" + << (GridwiseGemm::CalculateAK0Padded( + arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_)) + << std::endl; - const auto clear_workspace = [&]() { + const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( + arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_)); + const auto clear_workspace = [&]() { hip_check_error( hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); }; @@ -855,11 +907,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; std::array size_as_buffers; - size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * + size_as_buffers[0] = arg.a_grid_desc_kbatch_m_k_.GetElementSpaceSize() * sizeof(ADataType) / GridwiseGemm::APackedSize; std::array size_bs_buffers; - size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * + size_bs_buffers[0] = arg.b_grid_desc_kbatch_n_k_.GetElementSpaceSize() * sizeof(BDataType) / GridwiseGemm::BPackedSize; std::array size_ds_buffers; @@ -889,8 +941,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_kbatch_k0_m_k1_, - 
arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_kbatch_m_k_, + arg.b_grid_desc_kbatch_n_k_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -905,8 +957,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_kbatch_m_k_, + arg.b_grid_desc_kbatch_n_k_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -926,8 +978,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -940,8 +992,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -965,8 +1017,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -979,8 +1031,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -1042,10 +1094,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif + const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid std::array{nullptr}, // p_bs_grid diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp index 6bc0ff8b4f0..4c0f9e52765 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp @@ -14,7 +14,7 @@ namespace tensor_operation { namespace device { namespace instance { -#define USE_WAVE_TRANSFER_BWD_DATA +#define USE_WAVE_TRANSFER using BF16 
= ck::bhalf_t; using F16 = ck::half_t; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp index 4984ac1cec8..5ca4f4d83d8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp @@ -1,6 +1,6 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#define USE_WAVE_TRANSFER_BWD_WEI +#define USE_WAVE #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 56f511bc894..c31279d0025 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -12,7 +12,7 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#if defined USE_WAVE_TRANSFER_BWD_WEI +#if defined USE_WAVE #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp" #endif @@ -962,7 +962,7 @@ struct DeviceOperationInstanceFactory>>& instances); #endif -#if defined USE_WAVE_TRANSFER_BWD_WEI +#if defined USE_WAVE #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( From 3ef29783783d7f694e5593220e2eedf247a65774 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 3 Feb 2026 13:08:09 +0000 Subject: [PATCH 11/12] Revert "Reverted unused device impl and updated macros" --- ...vice_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 2 +- ...ce_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 2 +- ..._grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp | 2 +- ...grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp | 2 +- .../gpu/grouped_convolution_backward_weight.hpp | 4 ++-- .../gpu/grouped_convolution_backward_weight_wmma.inc | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index f4f64f6e771..6b635b6a23a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -282,7 +282,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 static_assert(NDimSpatial == 2 || NDimSpatial == 3, "wrong! 
only implemented for 2D and 3D now"); -#ifdef USE_WAVE_TRANSFER +#ifdef USE_WAVE_TRANSFER_BWD_DATA static_assert(UseThreadTileTransfer == false && (ConvBackwardDataSpecialization == diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 208855149eb..88cdef7548b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -179,7 +179,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 ComputeTypeB> { -#if defined USE_WAVE +#if defined USE_WAVE_TRANSFER_BWD_WEI static_assert(UseThreadTileTransfer == false && (ConvBackwardWeightSpecialization == diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp index 4c0f9e52765..6bc0ff8b4f0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp @@ -14,7 +14,7 @@ namespace tensor_operation { namespace device { namespace instance { -#define USE_WAVE_TRANSFER +#define USE_WAVE_TRANSFER_BWD_DATA using BF16 = ck::bhalf_t; using F16 = ck::half_t; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp index 5ca4f4d83d8..4984ac1cec8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp @@ -1,6 +1,6 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#define USE_WAVE +#define USE_WAVE_TRANSFER_BWD_WEI #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" diff --git
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
index c31279d0025..56f511bc894 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
@@ -12,7 +12,7 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#if defined USE_WAVE +#if defined USE_WAVE_TRANSFER_BWD_WEI #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp" #endif @@ -962,7 +962,7 @@ struct DeviceOperationInstanceFactory>>& instances); #endif -#if defined USE_WAVE +#if defined USE_WAVE_TRANSFER_BWD_WEI #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( From 92e8a1035ea87443697acd2a36a85f4083e4c763 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 3 Feb 2026 15:00:11 +0000 Subject: [PATCH 12/12] Build fix: remove multi-D functionality ---
...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 134 +++++++-----------
...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 25 ++--
..._weight_v3_wmma_wave_transfer_instance.hpp | 35 +++--
.../grouped_convolution_backward_weight.hpp | 83 ++---------
...ouped_convolution_backward_weight_wmma.inc | 97 ++++++-------
...kyxc_nhwgk_bf16_wave_transfer_instance.cpp | 4 +-
...gkyxc_nhwgk_f16_wave_transfer_instance.cpp | 10 +-
...yxc_ndhwgk_bf16_wave_transfer_instance.cpp | 4 +-
...zyxc_ndhwgk_f16_wave_transfer_instance.cpp | 4 +- 9 files changed, 129 insertions(+), 267 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 88cdef7548b..f662ff834f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -29,9 +29,6 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" -#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" - #ifdef CK_EXPERIMENTAL_BUILDER #include "ck_tile/builder/reflect/description.hpp" #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" @@ -42,8 +39,8 @@ namespace tensor_operation { namespace device { template ::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); - __shared__ char p_shared[LDS_size]; - - auto epilogue_args = EpilogueType{}; - const auto a_grid_desc_ak0_m_ak1 = - GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; - const auto b_grid_desc_bk0_n_bk1 = - GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+ auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - GridwiseGemm::template Run struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 @@ -178,15 +164,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 ComputeTypeA, ComputeTypeB> { - -#if defined USE_WAVE_TRANSFER_BWD_WEI - - static_assert(UseThreadTileTransfer == false && - (ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0), - "Only Filter1x1Stride1Pad0is supported for wavetile transfer"); -#endif - using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3; using ADataType = OutDataType; @@ -299,29 +276,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 batch); } - template - static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) - { - const auto grid_desc_m_k = transform_tensor_descriptor( - desc_k0_m_k1, - make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), - make_merge_transform( - make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return grid_desc_m_k; - } - using ABCGridDescs = decltype(GetABCGridDesc()); using AGridDesc_K0_M_K1 = remove_cvref_t; using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_K0_M_K1{})); - using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_K0_N_K1{})); - using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, @@ -371,10 +331,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, - false, // permuteA - false, // permuteB - false, // IsBPreShuffled - UseThreadTileTransfer>; // ForceThreadTileTransfer + false, // permuteA + false, // permuteB + false, // IsBPreShuffled + true>; // ForceThreadTileTransfer static constexpr auto MakeElementwiseInputSequence() { @@ -632,8 +592,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 p_b_grid_{p_in_grid}, p_ds_grid_{}, p_e_grid_{p_wei_grid}, - a_grid_desc_kbatch_m_k_{}, - b_grid_desc_kbatch_n_k_{}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, ce_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{}, @@ -727,9 +687,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_k_c_xs_strides[i][0]; }); - a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]); - b_grid_desc_kbatch_n_k_ = transform_k0_m_k1_to_m_k(descs[I1]); - ce_grid_desc_m_n_ = descs[I2]; + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; ds_grid_descs_tuple_ = MakeDsGridDescriptor_M_N(ds_g_k_c_xs_lengths, ds_g_k_c_xs_strides); @@ -747,8 +707,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 index_t{1}, std::multiplies<>{}); - const index_t GemmM = a_grid_desc_kbatch_m_k_.GetLength(I0); - const index_t GemmN = b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -766,8 +726,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const BDataType* p_b_grid_; 
DsGridPointerTuple p_ds_grid_; EDataType* p_e_grid_; - AGridDesc_M_K a_grid_desc_kbatch_m_k_; - BGridDesc_N_K b_grid_desc_kbatch_n_k_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; CGridDesc_M_N ce_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; DsGridDesc_M_N ds_grid_descs_tuple_; @@ -824,9 +784,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { float ave_time = 0; - const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); - const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); - const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); AccDataType* p_e_grid = type_convert(arg.p_workspace_); @@ -856,8 +817,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( - arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_)); + const auto num_k_per_block = + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; const auto clear_workspace = [&]() { hip_check_error(hipMemsetAsync( @@ -870,11 +831,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; std::array size_as_buffers; - size_as_buffers[0] = arg.a_grid_desc_kbatch_m_k_.GetElementSpaceSize() * + size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) / GridwiseGemm::APackedSize; std::array size_bs_buffers; - size_bs_buffers[0] = arg.b_grid_desc_kbatch_n_k_.GetElementSpaceSize() * + size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) / GridwiseGemm::BPackedSize; std::array size_ds_buffers; @@ -904,8 +865,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_kbatch_m_k_, - arg.b_grid_desc_kbatch_n_k_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -920,8 +881,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_kbatch_m_k_, - arg.b_grid_desc_kbatch_n_k_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -942,8 +903,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -957,8 +918,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< 
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -983,8 +944,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -998,8 +959,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -1069,9 +1030,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { - const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); - const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); - const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid std::array{nullptr}, // p_bs_grid diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index e60f95d9cb7..181f67fabf3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -184,16 +184,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 using ADataType = OutDataType; using BDataType = InDataType; using CDataType = WeiDataType; - // // static const auto F1S1 = - // ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - // #if defined USE_WAVE - - // static_assert(UseThreadTileTransfer==false && - // (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0 - // ),"Only Filter1x1Stride1Pad0is supported for wavetile transfer" - // ); - // #endif - // If NGCHW then ADataType must be equal to BDataType + + #if defined USE_WAVE_TRANSFER_BWD_WEI + + static_assert(UseThreadTileTransfer==false && + (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0 + ),"Only Filter1x1Stride1Pad0 is supported for wave-tile transfer" + ); + #endif + // If NGCHW then ADataType must be equal to BDataType static_assert(!(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) || is_same_v); @@ -1094,12 +1093,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); diff --git
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp index 4984ac1cec8..f74bbbef87e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp @@ -3,11 +3,12 @@ #define USE_WAVE_TRANSFER_BWD_WEI #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + namespace ck { namespace tensor_operation { namespace device { @@ -27,6 +28,8 @@ using F8 = ck::f8_t; using BF8 = ck::bf8_t; #endif + + using Empty_Tuple = ck::Tuple<>; template @@ -49,15 +52,13 @@ template using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances = std::tuple< // clang-format off - //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| - //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | - //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | - //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | - // generic instance - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, - - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, 
PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> - + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, false>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion, false> + // clang-format on >; @@ -70,14 +71,12 @@ template using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances = std::tuple< // clang-format off - //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| - //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | - 
-    //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version |
-    //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
-    // generic instance
-    DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, BF16, BF16, BF16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>,
-
-    DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, BF16, BF16, BF16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>
+    //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm|
+    //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline|
+    //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version|
+    //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
+    DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, false>,
+    DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion, false>
     // clang-format on
     >;
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
index 56f511bc894..9ac16834f0f 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
@@ -12,11 +12,6 @@
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
 
-#if defined USE_WAVE_TRANSFER_BWD_WEI
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp"
-
-#endif
-
 #ifdef DL_KERNELS
 #include "grouped_convolution_backward_weight_dl.inc"
 #endif
@@ -398,9 +393,6 @@ struct DeviceOperationInstanceFactory
-#if defined USE_WAVE_TRANSFER_BWD_WEI
-template
-struct DeviceOperationInstanceFactory<
-    ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD<
-        NumDimSpatial,
-        InLayout,
-        WeiLayout,
-        OutLayout,
-        DsLayout,
-        InDataType,
-        WeiDataType,
-        OutDataType,
-        DsDataType,
-        ck::tensor_operation::element_wise::PassThrough,
-        ck::tensor_operation::element_wise::PassThrough,
-        ck::tensor_operation::element_wise::PassThrough,
-        ComputeTypeA,
-        ComputeTypeB>>
-{
-    using DeviceOp = DeviceGroupedConvBwdWeightMultipleD<NumDimSpatial, InLayout, WeiLayout, OutLayout, DsLayout, InDataType, WeiDataType, OutDataType, DsDataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ComputeTypeA, ComputeTypeB>;
-
-    static auto GetInstances()
-    {
-        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-#ifdef CK_ENABLE_BF16
-        add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances(
-            op_ptrs);
-        add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances(
-            op_ptrs);
-#endif
-#ifdef CK_ENABLE_FP16
-        add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances(
-            op_ptrs);
-        add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances(
-            op_ptrs);
-#endif
-        return op_ptrs;
-    }
-};
-#endif
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc
index 33f4165c28c..ba44bf5b69e 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc
@@ -22,6 +22,18 @@ void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
     PassThrough,
     PassThrough>>>& instances);
 
+void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+
 void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
 
+void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+
 void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
 
-void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
+void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NDHWGC, GKZYXC, NDHWGK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NDHWGC, GKZYXC, NDHWGK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
-void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
+#ifdef CK_ENABLE_BF16
+void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NDHWGC, GKZYXC, NDHWGK, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
-#endif
-
-#if defined USE_WAVE_TRANSFER_BWD_WEI
-#ifdef CK_ENABLE_BF16
-
-void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<2, NHWGC, GKYXC, NHWGK, Tuple<>, BF16, BF16, BF16, Tuple<>, PassThrough, PassThrough, PassThrough>>>& instances);
 
 void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3, NDHWGC, GKZYXC, NDHWGK, Tuple<>, BF16, BF16, BF16, Tuple<>, PassThrough, PassThrough, PassThrough>>>& instances);
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NDHWGC, GKZYXC, NDHWGK, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NDHWGC, GKZYXC, NDHWGK, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
-
-#endif
 #endif
 
 } // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp
index 36312b814e0..126d7444376 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp
@@ -11,15 +11,13 @@ namespace instance {
 
 // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
 void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<2, NHWGC, GKYXC, NHWGK, Tuple<>, BF16, BF16, BF16, Tuple<>, PassThrough, PassThrough, PassThrough>>>& instances)
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances)
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp
index 60287b459c2..0c41d911055 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp
@@ -11,15 +11,13 @@ namespace instance {
 
 // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
 void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<2, NHWGC, GKYXC, NHWGK, Tuple<>, F16, F16, F16, Tuple<>, PassThrough, PassThrough, PassThrough>>>& instances)
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances)
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp
index eee14e32364..5d6be45bfb5 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp
@@ -11,15 +11,13 @@ namespace instance {
 
 // Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, g, k]
 void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3, NDHWGC, GKZYXC, NDHWGK, Tuple<>, BF16, BF16, BF16, Tuple<>, PassThrough, PassThrough, PassThrough>>>& instances)
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NDHWGC, GKZYXC, NDHWGK, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances)
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp
index 0eeef111a28..200d150bebe 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp
@@ -11,15 +11,13 @@ namespace instance {
 
 // Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, g, k]
 void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3, NDHWGC, GKZYXC, NDHWGK, Tuple<>, F16, F16, F16, Tuple<>, PassThrough, PassThrough, PassThrough>>>& instances)
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NDHWGC, GKZYXC, NDHWGK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances)
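
Usage note (reviewer sketch, not part of the patch): with these instances registered, callers keep going through the normal instance-factory path, so nothing outside the library changes. The snippet below is a minimal sketch of how the new wave-transfer instances would surface; it assumes the DeviceGroupedConvBwdWeight interface and the DeviceOperationInstanceFactory specialization declared in grouped_convolution_backward_weight.hpp, and DeviceOp plus the main() harness are illustrative only.

// Minimal sketch: enumerate the registered 2D NHWGC/GKYXC/NHWGK f16
// bwd-weight operations. With USE_WAVE_TRANSFER_BWD_WEI in effect, the
// DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 wave-transfer instances added
// by this patch are expected to show up in the list.
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"

int main()
{
    namespace ctc     = ck::tensor_layout::convolution;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Mirrors the declarations added in grouped_convolution_backward_weight_wmma.inc:
    // 2D, NHWGC/GKYXC/NHWGK layouts, f16 in/wei/out, PassThrough element-wise ops.
    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<2,
                                                                              ctc::NHWGC,
                                                                              ctc::GKYXC,
                                                                              ctc::NHWGK,
                                                                              ck::half_t,
                                                                              ck::half_t,
                                                                              ck::half_t,
                                                                              PassThrough,
                                                                              PassThrough,
                                                                              PassThrough>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances\n";
    for(const auto& op : op_ptrs)
    {
        // Prints the kernel configuration string of each registered instance.
        std::cout << op->GetTypeString() << '\n';
    }
    return 0;
}

Selecting and running one of these instances for a concrete problem size then follows the usual CK pattern: build an argument with MakeArgumentPointer, filter with IsSupportedArgument, and launch through MakeInvokerPointer.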