diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp
index dfdfd53725f..6b635b6a23a 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp
@@ -35,8 +35,8 @@ namespace device {
 namespace {
 template ()];
-    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+
+    using EpilogueType =
+        typename std::conditional::type;
+
+    constexpr index_t LDS_size =
+        GridwiseGemm::template GetSharedMemoryNumberOfByte();
+    __shared__ char p_shared[LDS_size];
+
+    auto epilogue_args = EpilogueType{};
 
     const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
 
     index_t left     = 0;
     index_t right    = gemms_count;
     index_t group_id = index_t((left + right) / 2);
+
     while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ &&
              block_args_id < gemm_kernel_args[group_id].BlockEnd_)) &&
           left <= right)
@@ -89,14 +99,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
         group_id = index_t((left + right) / 2);
     }
 
-    const auto num_k_per_block =
-        gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_.GetLength(Number<0>{}) / KBatch;
+    const auto num_k_per_block = GridwiseGemm::CalculateAK0Padded(
+        gemm_kernel_args[group_id].a_grid_desc_m_k_.GetLength(Number<1>{}), KBatch);
+
+    const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
+        gemm_kernel_args[group_id].a_grid_desc_m_k_);
+    const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
+        gemm_kernel_args[group_id].b_grid_desc_n_k_);
 
     if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
     {
-        GridwiseGemm::template Run(
             p_shared,
-            gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
-            gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
+            a_grid_desc_ak0_m_ak1,
+            b_grid_desc_bk0_n_bk1,
             gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
             gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
             gemm_kernel_args[group_id].block_2_ctile_map_,
@@ -122,8 +137,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
     {
         if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
         {
-            GridwiseGemm::template Run(
                 p_shared,
-                gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
-                gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
+                a_grid_desc_ak0_m_ak1,
+                b_grid_desc_bk0_n_bk1,
                 gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                 gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
                 gemm_kernel_args[group_id].block_2_ctile_map_,
@@ -147,8 +162,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
         }
         else
         {
-            GridwiseGemm::template Run(
                 p_shared,
-                gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
-                gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
+                a_grid_desc_ak0_m_ak1,
+                b_grid_desc_bk0_n_bk1,
                 gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                 gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
                 gemm_kernel_args[group_id].block_2_ctile_map_,
@@ -242,6 +257,7 @@ template ; // ForceThreadTileTransfer
+          false,
+          false,
+          false,
+          UseThreadTileTransfer>;
 
 #define GridwiseGemmCTransposeTemplateParameters                                       \
     ALayout, BLayout, DsLayout, ELayout, Tuple, Tuple, AccDataType,                    \
@@ -517,8 +541,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
     struct GemmArgs
     {
         GemmArgs() = default;
-        GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-                 BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        GemmArgs(AGridDesc_M_K a_grid_desc_m_k,
+                 BGridDesc_N_K b_grid_desc_n_k,
                  DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                      ds_grid_desc_mblock_mperblock_nblock_nperblock,
                  EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
@@ -527,8 +551,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
                  index_t BlockStart,
                  index_t BlockEnd,
                  bool HasMainKBlockLoop)
-            : a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1),
-              b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1),
+            : a_grid_desc_m_k_(a_grid_desc_m_k),
+              b_grid_desc_n_k_(b_grid_desc_n_k),
               ds_grid_desc_mblock_mperblock_nblock_nperblock_(
                   ds_grid_desc_mblock_mperblock_nblock_nperblock),
@@ -543,8 +567,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
         {
         }
 
         // tensor descriptors for block/thread-wise copy
-        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
-        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        AGridDesc_M_K a_grid_desc_m_k_;
+        BGridDesc_N_K b_grid_desc_n_k_;
         DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
             ds_grid_desc_mblock_mperblock_nblock_nperblock_;
         EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -926,8 +950,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
                 gemm_kernel_args_[gemms_count_ / MaxGroupedGemmGroupsNum][gemms_count_ %
                                                                           MaxGroupedGemmGroupsNum] =
-                    GemmArgs{a_grid_desc_ak0_m_ak1,
-                             b_grid_desc_bk0_n_bk1,
+                    GemmArgs{a_grid_desc_m_k,
+                             b_grid_desc_n_k,
                              GridwiseGemmCTranspose::
                                  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                                      ds_grid_desc_m_n, MBlock, NBlock),
@@ -1055,10 +1079,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
         {
             for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++)
             {
-                std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i]
+                std::cout << "a_grid_desc_m_k_container_" << a_grid_desc_m_k_container_[i]
                           << std::endl;
 
-                std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i]
+                std::cout << "b_grid_desc_n_k_container_" << b_grid_desc_n_k_container_[i]
                           << std::endl;
 
                 static_for<0, NumDTensor, 1>{}([&](auto j) {
@@ -1086,8 +1110,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
         std::vector e_grid_desc_m_n_container_;
 
         // tensor descriptor for block-wise copy
-        std::vector a_grid_desc_ak0_m_ak1_container_;
-        std::vector b_grid_desc_bk0_n_bk1_container_;
+
         std::vector
             ds_grid_desc_mblock_mperblock_nblock_nperblock_container_;
         std::vector
@@ -1233,8 +1256,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
                 constexpr bool no_main_loop = no_main_k_block_loop.value;
                 const auto kernel           = kernel_grouped_conv_bwd_data_wmma_cshuffle_v3<
                     GridwiseGemmCTranspose,
-                    DeviceOp::AGridDesc_AK0_M_AK1,
-                    DeviceOp::BGridDesc_BK0_N_BK1,
+                    DeviceOp::AGridDesc_M_K,
+                    DeviceOp::BGridDesc_N_K,
                     DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                     DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                     MaxGroupedGemmGroupsNum,
@@ -1785,12 +1808,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
                 p_ds_grid_dummy[i] = nullptr;
                 StrideDs_dummy[i]  = I0;
             });
-            for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
+            for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
             {
-                const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1);
-                const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1);
-                const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) *
-                                      arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2);
+                const index_t GemmM = arg.a_grid_desc_m_k_container_[i].GetLength(I0);
+                const index_t GemmN = arg.b_grid_desc_n_k_container_[i].GetLength(I0);
+                const index_t GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
+
                 // Create gemm arguments with dummy values to check for validity
                 typename GridwiseGemmCTranspose::Argument gemm_arg{
                     std::array{nullptr}, // p_as_grid
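In the grouped kernel above, each workgroup locates its GEMM by binary-searching the per-group half-open [BlockStart_, BlockEnd_) block ranges. A minimal, self-contained sketch of that lookup follows; `GroupRange` and `find_group` are hypothetical stand-ins for the kernel-argument struct and the in-kernel loop, not CK code:

```cpp
// Sketch of the group lookup performed by the grouped kernel: binary search
// over sorted, contiguous [BlockStart_, BlockEnd_) ranges for the workgroup id.
#include <cassert>
#include <vector>

struct GroupRange // stand-in for the per-group kernel argument
{
    int BlockStart_;
    int BlockEnd_;
};

int find_group(const std::vector<GroupRange>& gemm_kernel_args, int block_args_id)
{
    int left  = 0;
    int right = static_cast<int>(gemm_kernel_args.size()) - 1;
    while(left <= right)
    {
        const int group_id = (left + right) / 2;
        if(block_args_id < gemm_kernel_args[group_id].BlockStart_)
            right = group_id - 1; // target range lies to the left
        else if(block_args_id >= gemm_kernel_args[group_id].BlockEnd_)
            left = group_id + 1; // target range lies to the right
        else
            return group_id; // BlockStart_ <= block_args_id < BlockEnd_
    }
    return -1; // block id not covered by any group
}

int main()
{
    const std::vector<GroupRange> args{{0, 4}, {4, 10}, {10, 16}};
    assert(find_group(args, 0) == 0);
    assert(find_group(args, 7) == 1);
    assert(find_group(args, 15) == 2);
    return 0;
}
```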
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp
index b2ae092c274..181f67fabf3 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp
@@ -41,8 +41,8 @@ namespace tensor_operation {
 namespace device {
 
 template ();
+    using EpilogueType =
+        typename std::conditional::type;
+
+    constexpr index_t LDS_size =
+        GridwiseGemm::template GetSharedMemoryNumberOfByte();
     __shared__ char p_shared[LDS_size];
-    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+    auto epilogue_args = EpilogueType{};
 
-    GridwiseGemm::template Run() || is_NGCDHW_NGKDHW()) || is_same_v);
@@ -293,6 +313,33 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
                                                 batch);
     }
 
+    template <typename Desc_K0_M_K1>
+    static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1)
+    {
+        const auto grid_desc_m_k = transform_tensor_descriptor(
+            desc_k0_m_k1,
+            make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)),
+                       make_merge_transform(
+                           make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))),
+            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        return grid_desc_m_k;
+    }
+
+    template <typename Desc_K0_N_K1>
+    static auto transform_k0_m_k1_to_n_k(const Desc_K0_N_K1& desc_k0_n_k1)
+    {
+        const auto grid_desc_n_k = transform_tensor_descriptor(
+            desc_k0_n_k1,
+            make_tuple(make_pass_through_transform(desc_k0_n_k1.GetLength(I1)),
+                       make_merge_transform(
+                           make_tuple(desc_k0_n_k1.GetLength(I0), desc_k0_n_k1.GetLength(I2)))),
+            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        return grid_desc_n_k;
+    }
+
     using NGCHWTransposeDescType =
         remove_cvref_t({}, {}))>;
@@ -308,9 +355,12 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
 
     using ABCGridDescs = decltype(GetABCGridDesc());
 
-    using AGridDesc_K0_M_K1 = remove_cvref_t;
-    using BGridDesc_K0_N_K1 = remove_cvref_t;
-    using CGridDesc_M_N     = remove_cvref_t;
+    using AGridDesc_M_K_ = remove_cvref_t;
+    using BGridDesc_N_K_ = remove_cvref_t;
+    using CGridDesc_M_N  = remove_cvref_t;
+
+    using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_M_K_{}));
+    using BGridDesc_N_K = decltype(transform_k0_m_k1_to_n_k(BGridDesc_N_K_{}));
 
     using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt;
 
@@ -401,10 +451,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
         BlkGemmPipelineVer,
         ComputeTypeA,
         ComputeTypeB,
-        false, // PermuteA
-        false, // permuteB
-        false, // IsBPreshuffle
-        true>; // ForceThreadTileTransfer
+        false,                  // PermuteA
+        false,                  // permuteB
+        false,                  // IsBPreshuffle
+        UseThreadTileTransfer>; // ForceThreadTileTransfer
 
     // Argument
     using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
@@ -434,8 +484,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
             &max_occupancy,
             kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
                 GridwiseGemm,
-                remove_reference_t,
-                remove_reference_t,
+                remove_reference_t,
+                remove_reference_t,
                 remove_reference_t,
                 ComputePtrOffsetOfStridedBatch,
                 true,
@@ -473,8 +523,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
             : p_a_grid_{p_out_grid},
               p_b_grid_{p_in_grid},
               p_c_grid_{p_wei_grid},
-              a_grid_desc_kbatch_k0_m_k1_{},
-              b_grid_desc_kbatch_k0_n_k1_{},
+              a_grid_desc_kbatch_m_k_{},
+              b_grid_desc_kbatch_n_k_{},
               c_grid_desc_m_n_{},
               c_grid_desc_mblock_mperblock_nblock_nperblock_{},
               compute_ptr_offset_of_batch_{},
@@ -572,16 +622,16 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
                 input_right_pads,
                 k_batch_);
 
-            a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
-            b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
-            c_grid_desc_m_n_            = descs[I2];
+            a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]);
+            b_grid_desc_kbatch_n_k_ = transform_k0_m_k1_to_n_k(descs[I1]);
+            c_grid_desc_m_n_        = descs[I2];
 
             // A/B/C Batch Stride
             compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
             compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
             compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
 
-            const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
-            const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
+            const index_t GemmM = a_grid_desc_kbatch_m_k_.GetLength(I0);
+            const index_t GemmN = b_grid_desc_kbatch_n_k_.GetLength(I0);
 
             c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                 GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -678,8 +728,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
-        BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
+        AGridDesc_M_K a_grid_desc_kbatch_m_k_;
+        BGridDesc_N_K b_grid_desc_kbatch_n_k_;
         CGridDesc_M_N c_grid_desc_m_n_;
         CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -724,17 +774,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
         void ShowInfo(const Argument& arg)
         {
-            std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
-
-            std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
+            std::cout << "arg.a_grid_desc_kbatch_m_k_{" << arg.a_grid_desc_kbatch_m_k_.GetLength(I0)
+                      << ", " << arg.a_grid_desc_kbatch_m_k_.GetLength(I1) << "}" << std::endl;
+
+            std::cout << "arg.b_grid_desc_kbatch_n_k_{" << arg.b_grid_desc_kbatch_n_k_.GetLength(I0)
+                      << ", " << arg.b_grid_desc_kbatch_n_k_.GetLength(I1) << "}" << std::endl;
 
             std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
@@ -744,10 +792,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
         {
             float ave_time = 0;
 
-            const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
-            const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
-            const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
-                                  arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
+            const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0);
+            const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0);
+            const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1);
 
             const ADataType* p_a_grid = arg.p_a_grid_;
             const BDataType* p_b_grid = arg.p_b_grid_;
@@ -839,10 +886,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
                 index_t K_split                  = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock;
                 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
 
-                const auto num_k_per_block =
-                    arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
+                const index_t num_k_per_block = GridwiseGemm::CalculateAK0Padded(
+                    arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_);
 
                 const auto clear_workspace = [&]() {
                     hip_check_error(
                         hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
                 };
@@ -855,11 +906,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
                     typename GridwiseGemm::Argument gemm_arg_ = gemm_arg;
 
                     std::array size_as_buffers;
-                    size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() *
+                    size_as_buffers[0] = arg.a_grid_desc_kbatch_m_k_.GetElementSpaceSize() *
                                          sizeof(ADataType) / GridwiseGemm::APackedSize;
                     std::array size_bs_buffers;
-                    size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() *
+                    size_bs_buffers[0] = arg.b_grid_desc_kbatch_n_k_.GetElementSpaceSize() *
                                          sizeof(BDataType) / GridwiseGemm::BPackedSize;
                     std::array size_ds_buffers;
@@ -889,8 +940,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
                         dim3(BlockSize),
                         0,
                         gemm_arg_,
-                        arg.a_grid_desc_kbatch_k0_m_k1_,
-                        arg.b_grid_desc_kbatch_k0_n_k1_,
+                        arg.a_grid_desc_kbatch_m_k_,
+                        arg.b_grid_desc_kbatch_n_k_,
                         arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                         arg.compute_ptr_offset_of_batch_,
                         num_k_per_block);
@@ -905,8 +956,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
                         dim3(BlockSize),
                         0,
                         gemm_arg,
-                        arg.a_grid_desc_kbatch_k0_m_k1_,
-                        arg.b_grid_desc_kbatch_k0_n_k1_,
+                        arg.a_grid_desc_kbatch_m_k_,
+                        arg.b_grid_desc_kbatch_n_k_,
                         arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                         arg.compute_ptr_offset_of_batch_,
                         num_k_per_block);
@@ -926,8 +977,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
                 {
                     const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
                         GridwiseGemm,
-                        remove_reference_t,
-                        remove_reference_t,
+                        remove_reference_t,
+                        remove_reference_t,
                         remove_reference_t<
                             DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                         ComputePtrOffsetOfStridedBatch,
@@ -940,8 +991,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
                 {
                     const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
                         GridwiseGemm,
-                        remove_reference_t,
-                        remove_reference_t,
+                        remove_reference_t,
+                        remove_reference_t,
                         remove_reference_t<
                             DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                         ComputePtrOffsetOfStridedBatch,
@@ -965,8 +1016,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
                 {
                     const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
                         GridwiseGemm,
-                        remove_reference_t,
-                        remove_reference_t,
+                        remove_reference_t,
+                        remove_reference_t,
                         remove_reference_t<
                             DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                         ComputePtrOffsetOfStridedBatch,
@@ -979,8 +1030,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
                 {
                     const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
                         GridwiseGemm,
-                        remove_reference_t,
-                        remove_reference_t,
+                        remove_reference_t,
+                        remove_reference_t,
                         remove_reference_t<
                             DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                         ComputePtrOffsetOfStridedBatch,
@@ -1042,10 +1093,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
     static bool IsSupportedArgument(const Argument& arg)
     {
-        const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
-        const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
-        const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
-                              arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
+        const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0);
+        const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0);
+        const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1);
 
         typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid
                                                  std::array{nullptr}, // p_bs_grid
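The helper added above merges the (K0, M, K1) descriptor into an (M, K) view with K = K0 * K1, which is why GemmM/GemmK switch from `GetLength(I1)` and `GetLength(I0) * GetLength(I2)` to `GetLength(I0)` and `GetLength(I1)`. A minimal sketch of the dimension bookkeeping, using hypothetical stand-in types rather than CK descriptors:

```cpp
// Plain C++ sketch of transform_k0_m_k1_to_m_k's length math: M passes through,
// K0 and K1 merge into one K of length K0 * K1; element (k0, m, k1) of the old
// view aliases element (m, k0 * K1 + k1) of the new one.
#include <array>
#include <cassert>
#include <cstdint>

struct DescK0MK1 { std::array<int64_t, 3> lengths; }; // {K0, M, K1}
struct DescMK    { std::array<int64_t, 2> lengths; }; // {M, K}

DescMK transform_k0_m_k1_to_m_k(const DescK0MK1& d)
{
    // pass-through M (index 1), merge K0 (index 0) with K1 (index 2)
    return DescMK{{d.lengths[1], d.lengths[0] * d.lengths[2]}};
}

int main()
{
    const DescK0MK1 a{{4, 128, 8}};         // K0 = 4, M = 128, K1 = 8
    const DescMK m_k = transform_k0_m_k1_to_m_k(a);
    assert(m_k.lengths[0] == 128);          // GemmM is now GetLength(I0)
    assert(m_k.lengths[1] == 32);           // GemmK is now GetLength(I1) = K0 * K1
    return 0;
}
```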
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp
new file mode 100644
index 00000000000..6bc0ff8b4f0
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp
@@ -0,0 +1,82 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+#define USE_WAVE_TRANSFER_BWD_DATA
+
+using BF16 = ck::bhalf_t;
+using F16  = ck::half_t;
+using F32  = float;
+using BF8  = ck::bf8_t;
+using F8   = ck::f8_t;
+
+using Empty_Tuple = ck::Tuple<>;
+
+template 
+using S = ck::Sequence;
+
+using namespace ck::tensor_layout::convolution;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default;
+
+static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
+    ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
+
+template 
+using device_grouped_conv_bwd_data_wmma_cshufflev3_bf16_wave_transfer_instances = std::tuple<
+    // clang-format off
+    // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
+    // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
+    // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
+    // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    // generic instance
+    DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>,
+
+    DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>,
+    DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>
+    // clang-format on
+    >;
+
+template 
+using device_grouped_conv_bwd_data_wmma_cshufflev3_f16_wave_transfer_instances = std::tuple<
+    // clang-format off
+    // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
+    // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
+    // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
+    // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    // generic instance
+    DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>,
+
+    DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>,
+    DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>
+    // clang-format on
+    >;
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp
new file mode 100644
index 00000000000..f74bbbef87e
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp
@@ -0,0 +1,89 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#define USE_WAVE_TRANSFER_BWD_WEI
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using namespace ck::tensor_layout::convolution;
+
+using BF16 = ck::bhalf_t;
+using F16  = ck::half_t;
+using F32  = float;
+
+#ifdef CK_ENABLE_FP8
+using F8 = ck::f8_t;
+#endif
+
+#ifdef CK_ENABLE_BF8
+using BF8 = ck::bf8_t;
+#endif
+
+using Empty_Tuple = ck::Tuple<>;
+
+template 
+using S = ck::Sequence;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto ConvBwdWeightDefault =
+    ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
+
+static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
+    ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
+
+template 
+using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances = std::tuple<
+    // clang-format off
+    //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm|
+    //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline|
+    //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version|
+    //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
+    DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, false>,
+    DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion, false>
+    // clang-format on
+    >;
+
+template 
+using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances = std::tuple<
+    // clang-format off
+    //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm|
+    //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline|
+    //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version|
+    //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
+    DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, false>,
+    DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion, false>
+    // clang-format on
+    >;
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
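Both new headers feed the same registration mechanism: each `*_instances` alias is a `std::tuple` of concrete device-op types, and `add_device_operation_instances` default-constructs one object per tuple element into the caller's vector. A simplified re-implementation of that pattern (illustrative only, not the CK source):

```cpp
// Simplified sketch of the add_device_operation_instances mechanism that the
// wave-transfer instance headers plug into.
#include <memory>
#include <tuple>
#include <vector>

template <typename BaseOp, typename... Ops>
void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& instances,
                                    const std::tuple<Ops...>&)
{
    // default-construct one instance per tuple element and append it
    (instances.push_back(std::make_unique<Ops>()), ...);
}

// usage sketch with dummy types standing in for the device operations
struct BaseOp { virtual ~BaseOp() = default; };
struct OpA : BaseOp {};
struct OpB : BaseOp {};

int main()
{
    std::vector<std::unique_ptr<BaseOp>> instances;
    add_device_operation_instances(instances, std::tuple<OpA, OpB>{});
    return instances.size() == 2 ? 0 : 1;
}
```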
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
index f784b6ea510..2d1a8eb4495 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
@@ -450,6 +450,8 @@ struct DeviceOperationInstanceFactory<
                 op_ptrs);
             add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
                 op_ptrs);
+            add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_f16_instances(
+                op_ptrs);
         }
 #endif
 #ifdef CK_ENABLE_BF16
@@ -461,6 +463,9 @@ struct DeviceOperationInstanceFactory<
                 op_ptrs);
             add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
                 op_ptrs);
+
+            add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_bf16_instances(
+                op_ptrs);
         }
 #endif
 #ifdef CK_ENABLE_INT8
@@ -520,6 +525,8 @@ struct DeviceOperationInstanceFactory<
                 op_ptrs);
             add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
                 op_ptrs);
+            add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_f16_instances(
+                op_ptrs);
         }
 #endif
 #ifdef CK_ENABLE_BF16
@@ -531,6 +538,8 @@ struct DeviceOperationInstanceFactory<
                 op_ptrs);
             add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
                 op_ptrs);
+            add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
+                op_ptrs);
         }
 #endif
 #ifdef CK_ENABLE_INT8
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc
index 40b659a87f5..3ee3b54ac8c 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc
@@ -80,6 +80,20 @@ void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
                                                     PassThrough,
                                                     PassThrough>>>& instances);
 
+void add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_f16_instances(
+    std::vector>>& instances);
+
 void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
     std::vector>>& instances);
 
+void add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_bf16_instances(
+    std::vector>>& instances);
+
 #endif
 // conv3dbwdData
@@ -326,6 +354,20 @@ void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_
     PassThrough,
     PassThrough,
     PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
+    std::vector>>& instances);
 #endif
 #ifdef CK_ENABLE_FP16
 void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_instances(
@@ -355,6 +397,20 @@ void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_i
     PassThrough,
     PassThrough,
     PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_f16_instances(
+    std::vector>>& instances);
 #endif
 } // namespace instance
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
index c07dc71ac56..9ac16834f0f 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
@@ -393,9 +393,6 @@ struct DeviceOperationInstanceFactory>>& instances);
 
+void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances(
+    std::vector>>& instances);
+
 void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
     std::vector>>& instances);
 
+void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances(
+    std::vector>>& instances);
+
 void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
     std::vector>>& instances);
 
+void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances(
+    std::vector>>& instances);
+
 void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
     std::vector>>& instances);
 
+void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances(
+    std::vector>>& instances);
+
 void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
     std::vector>>& instances);
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
new file mode 100644
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_bf16_instances(
+    std::vector>>& instances)
+{
+
+    // 2. Filter1x1Stride1Pad0
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_data_wmma_cshufflev3_bf16_wave_transfer_instances<
+            2,
+            NHWGK,
+            GKYXC,
+            Empty_Tuple,
+            NHWGC,
+            ConvBwdDataFilter1x1Stride1Pad0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp
new file mode 100644
index 00000000000..a47837768d5
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp
@@ -0,0 +1,42 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv2d_bwd_data_wmma_v3_wave_transfer_nhwgk_gkyxc_nhwgc_f16_instances(
+    std::vector>>& instances)
+{
+
+    // 2. Filter1x1Stride1Pad0
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_data_wmma_cshufflev3_f16_wave_transfer_instances<
+            2,
+            NHWGK,
+            GKYXC,
+            Empty_Tuple,
+            NHWGC,
+            ConvBwdDataFilter1x1Stride1Pad0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
index 268835d5bfd..59a30d4ea9c 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
@@ -79,6 +79,8 @@ list(APPEND GROUPED_CONV2D_BWD_WEIGHT
     wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
     wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp
     wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp
+    wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp
+    wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp
 )
 
 add_instance_library(device_grouped_conv2d_bwd_weight_instance ${GROUPED_CONV2D_BWD_WEIGHT})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp
new file mode 100644
index 00000000000..126d7444376
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp
@@ -0,0 +1,39 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
+void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances(
+    std::vector>>& instances)
+{
+    // Filter1x1Stride1Pad0
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances<
+            2,
+            NHWGC,
+            GKYXC,
+            NHWGK,
+            ConvBwdWeightFilter1x1Stride1Pad0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp
new file mode 100644
index 00000000000..0c41d911055
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp
@@ -0,0 +1,39 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
+void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances(
+    std::vector>>& instances)
+{
+    // Filter1x1Stride1Pad0
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances<
+            2,
+            NHWGC,
+            GKYXC,
+            NHWGK,
+            ConvBwdWeightFilter1x1Stride1Pad0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt
index 01ff4095d74..46327941390 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt
@@ -45,6 +45,10 @@ set(GROUPED_CONV3D_BWD_DATA
     wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp
     wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
+
+    wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
 )
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
new file mode 100644
index 00000000000..9f45f9527ea
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -0,0 +1,42 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
+    std::vector>>& instances)
+{
+
+    // 2. Filter1x1Stride1Pad0
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_data_wmma_cshufflev3_bf16_wave_transfer_instances<
+            3,
+            NDHWGK,
+            GKZYXC,
+            Empty_Tuple,
+            NDHWGC,
+            ConvBwdDataFilter1x1Stride1Pad0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
new file mode 100644
index 00000000000..a237027bcf5
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
@@ -0,0 +1,42 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_conv3d_bwd_data_wmma_v3_wave_transfer_ndhwgk_gkzyxc_ndhwgc_f16_instances(
+    std::vector>>& instances)
+{
+
+    // 2. Filter1x1Stride1Pad0
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_data_wmma_cshufflev3_f16_wave_transfer_instances<
+            3,
+            NDHWGK,
+            GKZYXC,
+            Empty_Tuple,
+            NDHWGC,
+            ConvBwdDataFilter1x1Stride1Pad0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
index b246b87178e..e7d8403812a 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
@@ -73,6 +73,8 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT
     wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
     wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp
     wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp
+    wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp
+    wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp
 )
 
 if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp
new file mode 100644
index 00000000000..5d6be45bfb5
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp
@@ -0,0 +1,39 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
+void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances(
+    std::vector>>& instances)
+{
+    // Filter1x1Stride1Pad0
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances<
+            3,
+            NDHWGC,
+            GKZYXC,
+            NDHWGK,
+            ConvBwdWeightFilter1x1Stride1Pad0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp
new file mode 100644
index 00000000000..200d150bebe
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp
@@ -0,0 +1,39 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
+void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances(
+    std::vector>>& instances)
+{
+    // Filter1x1Stride1Pad0
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances<
+            3,
+            NDHWGC,
+            GKZYXC,
+            NDHWGK,
+            ConvBwdWeightFilter1x1Stride1Pad0>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
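Stepping back to the kernel changes in the two device files: both kernels now select their epilogue type at compile time via std::conditional, presumably keyed on the wave/thread-tile transfer choice, and size the LDS allocation from the selected type instead of hard-coding EpilogueCShuffle. A minimal sketch of that pattern; the type names and sizes below are hypothetical stand-ins, as the real condition and epilogue types live in the gridwise GEMM:

```cpp
// Stand-in sketch of compile-time epilogue selection plus LDS sizing.
#include <cstddef>
#include <type_traits>

struct EpilogueCShuffle     { static constexpr std::size_t lds_bytes = 16384; };
struct EpilogueWaveTransfer { static constexpr std::size_t lds_bytes = 8192; };

template <bool UseThreadTileTransfer>
using EpilogueType =
    typename std::conditional<UseThreadTileTransfer, EpilogueWaveTransfer, EpilogueCShuffle>::type;

template <typename Epilogue>
constexpr std::size_t get_shared_memory_number_of_byte()
{
    // stand-in for GridwiseGemm::GetSharedMemoryNumberOfByte<EpilogueType>()
    return Epilogue::lds_bytes;
}

static_assert(get_shared_memory_number_of_byte<EpilogueType<true>>() == 8192,
              "wave-transfer epilogue selected");
static_assert(get_shared_memory_number_of_byte<EpilogueType<false>>() == 16384,
              "cshuffle epilogue selected");
```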