diff --git a/CMakeLists.txt b/CMakeLists.txt
index cd7121b39db..7a0835f35e5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -277,16 +277,16 @@ endif()
 # define the macro with the current value (0 or 1)
 add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA})
 
-if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" AND NOT FORCE_DISABLE_WMMA)
+if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" AND NOT FORCE_DISABLE_WMMA AND DEFINED CK_ENABLE_FP8)
     message(STATUS "Enabling WMMA FP8 gemms on native architectures")
     add_definitions(-DCK_USE_WMMA_FP8)
     set(CK_USE_WMMA_FP8 "ON")
 endif()
-if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950")
+if (SUPPORTED_GPU_TARGETS MATCHES "gfx12|gfx950" AND DEFINED CK_ENABLE_FP8)
     add_definitions(-DCK_USE_OCP_FP8)
     set(CK_USE_OCP_FP8 "ON")
 endif()
-if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx94")
+if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a|gfx94" AND DEFINED CK_ENABLE_FP8)
    add_definitions(-DCK_USE_FNUZ_FP8)
    set(CK_USE_FNUZ_FP8 "ON")
 endif()
diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp
index 19a47487328..a922a216ba8 100644
--- a/include/ck/tensor_description/multi_index_transform.hpp
+++ b/include/ck/tensor_description/multi_index_transform.hpp
@@ -1745,6 +1745,215 @@ struct ConvBwdDataImplicitGemmOutTransform
     }
 };
 
+/**
+ * @brief Transformation struct for convolution backward data output indices to GEMM indices.
+ *
+ * This struct is responsible for mapping the output tensor indices (N, Ho, Wo, G, K) from the
+ * convolution backward data operation to the corresponding indices (K0, M, K1) used in the
+ * implicit GEMM computation. It encapsulates the necessary parameters and transformation logic
+ * required to efficiently perform the index conversion.
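+ *
+ * A sketch of the mapping implemented below (as can be read off CalculateLowerIndexN and
+ * CalculateLowerIndexK; the lower-case symbol names are illustrative only): the GEMM M index
+ * decomposes, via magic-number division, as
+ *
+ *   M = ((n * HTildeSlice + htilde) * WTildeSlice + wtilde) * NumGroupsToMerge + g
+ *
+ * and the un-merged GEMM K index (K_idx = K0_idx * K1 + K1_idx) decomposes as
+ *
+ *   K_idx = (ydot * XDotSlice + xdot) * K + k.
+ *
+ * The slice offsets (IHTildeSliceBegin_, IWTildeSliceBegin_) and the embed ratios
+ * (HRatio_, WRatio_) are then applied to land on the lower (N, Ho, Wo, G, K) coordinate.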
+ */
+struct ConvBwdDataImplicitGemmOutTransformMG
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+
+    using LowerIndex = MultiIndex<5>; // N, Ho, Wo, G, K
+    using UpperIndex = MultiIndex<3>; // K0, M, K1
+
+    index_t N_, Ho_, Wo_, K_;
+    index_t XDot_;
+    index_t HTilde_, WTilde_;
+    index_t WTildeSlice_NumGroupsToMerge_, TildeSlice_NumGroupsToMerge_, NumGroupsToMerge_;
+    index_t IHTildeSliceBegin_, IWTildeSliceBegin_;
+    index_t HRatio_, WRatio_;
+    index_t XDotSlice_K_;
+    index_t MPad_, KPad_;
+    Tuple<index_t, index_t, index_t> up_lengths_; // K0_, MPadded, K1_
+
+    Tuple<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t> low_lengths_magic_divisor_multiplier_;
+    Tuple<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t> low_lengths_magic_divisor_shift_;
+
+    __host__ __device__ ConvBwdDataImplicitGemmOutTransformMG() = default;
+
+    __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransformMG(
+        index_t N,
+        index_t Ho,
+        index_t Wo,
+        index_t K,
+        index_t XDot,
+        index_t HTilde,
+        index_t WTilde,
+        index_t NumGroupsToMerge,
+        index_t WTildeSlice_NumGroupsToMerge,
+        index_t HWTildeSlice_NumGroupsToMerge,
+        index_t IHTildeSliceBegin,
+        index_t IWTildeSliceBegin,
+        index_t HRatio,
+        index_t WRatio,
+        index_t XDotSlice_K,
+        index_t K0,
+        index_t MPadded,
+        index_t K1,
+        index_t MPad,
+        index_t KPad)
+        : N_{N},
+          Ho_{Ho},
+          Wo_{Wo},
+          K_{K},
+          XDot_{XDot},
+          HTilde_{HTilde},
+          WTilde_{WTilde},
+          WTildeSlice_NumGroupsToMerge_{WTildeSlice_NumGroupsToMerge},
+          TildeSlice_NumGroupsToMerge_{HWTildeSlice_NumGroupsToMerge},
+          NumGroupsToMerge_{NumGroupsToMerge},
+          IHTildeSliceBegin_{IHTildeSliceBegin},
+          IWTildeSliceBegin_{IWTildeSliceBegin},
+          HRatio_{HRatio},
+          WRatio_{WRatio},
+          XDotSlice_K_{XDotSlice_K},
+          MPad_{MPad},
+          KPad_{KPad},
+          up_lengths_{make_tuple(K0, MPadded, K1)},
+          low_lengths_magic_divisor_multiplier_{
+              MagicDivision::CalculateMagicMultiplier(XDotSlice_K_),
+              MagicDivision::CalculateMagicMultiplier(K_),
+              MagicDivision::CalculateMagicMultiplier(TildeSlice_NumGroupsToMerge_),
+              MagicDivision::CalculateMagicMultiplier(WTildeSlice_NumGroupsToMerge_),
+              MagicDivision::CalculateMagicMultiplier(NumGroupsToMerge_)},
+          low_lengths_magic_divisor_shift_{
+              MagicDivision::CalculateMagicShift(XDotSlice_K_),
+              MagicDivision::CalculateMagicShift(K_),
+              MagicDivision::CalculateMagicShift(TildeSlice_NumGroupsToMerge_),
+              MagicDivision::CalculateMagicShift(WTildeSlice_NumGroupsToMerge_),
+              MagicDivision::CalculateMagicShift(NumGroupsToMerge_)}
+    {
+    }
+
+    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 5; }
+
+    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; }
+
+    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
+
+    template <typename UpIdx>
+    __host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const
+    {
+        index_t NStep{0}, GStep{0}, HStep{0}, WStep{0};
+        // Merge
+        // NStep = M_id / (TildeSlice * NumGroupsToMerge)
+        NStep = MagicDivision::DoMagicDivision(idx_up[I1],
+                                               this->low_lengths_magic_divisor_multiplier_[I2],
+                                               this->low_lengths_magic_divisor_shift_[I2]);
+
+        HStep = idx_up[I1] - NStep * TildeSlice_NumGroupsToMerge_;
+        // HStep = HStep / (WTildeSlice * NumGroupsToMerge)
+        HStep = MagicDivision::DoMagicDivision(HStep,
+                                               this->low_lengths_magic_divisor_multiplier_[I3],
+                                               this->low_lengths_magic_divisor_shift_[I3]);
+
+        WStep = idx_up[I1] - NStep * TildeSlice_NumGroupsToMerge_ -
+                HStep * WTildeSlice_NumGroupsToMerge_;
+        // WStep = WStep / NumGroupsToMerge
+        WStep = MagicDivision::DoMagicDivision(WStep,
+                                               this->low_lengths_magic_divisor_multiplier_[I4],
+                                               this->low_lengths_magic_divisor_shift_[I4]);
+        GStep = idx_up[I1] - NStep * TildeSlice_NumGroupsToMerge_ -
+                HStep * WTildeSlice_NumGroupsToMerge_ - WStep * NumGroupsToMerge_;
+
+        // Slice
+        HStep += IHTildeSliceBegin_;
+        WStep += IWTildeSliceBegin_;
+
+        return make_tuple(NStep, HStep, WStep, GStep, 0);
+    }
+
+    template <typename UpIdx>
+    __host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const
+    {
+        // UnMerge
+        // K_idx <- K0_idx * K1 + K1_idx
+        index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
+        // Merge
+        // YStep = K_idx / (XDotSlice * K)
+        index_t YStep =
+            MagicDivision::DoMagicDivision(K_idx,
+                                           this->low_lengths_magic_divisor_multiplier_[I0],
+                                           this->low_lengths_magic_divisor_shift_[I0]);
+
+        index_t KStep = K_idx - YStep * XDotSlice_K_;
+        // XStep = KStep / K
+        index_t XStep =
+            MagicDivision::DoMagicDivision(KStep,
+                                           this->low_lengths_magic_divisor_multiplier_[I1],
+                                           this->low_lengths_magic_divisor_shift_[I1]);
+        KStep -= XStep * K_;
+
+        // Embed
+        YStep *= HRatio_;
+        XStep *= WRatio_;
+
+        return make_tuple(0, YStep, XStep, 0, KStep);
+    }
+
+    template <typename LowIdx, typename UpIdx>
+    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
+                                                           const UpIdx& idx_up) const
+    {
+        idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
+    }
+
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx, index_t Hack>
+    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
+                                              const UpIdxDiff& /* idx_diff_up */,
+                                              LowIdx& idx_low,
+                                              const UpIdx& idx_up,
+                                              Number<Hack>) const
+    {
+        LowIdx low_old = idx_low;
+        idx_low        = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
+        idx_diff_low   = idx_low - low_old;
+    }
+
+    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
+
+    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
+    {
+        return true;
+    }
+
+    template <typename UpIdx>
+    __host__ __device__ constexpr bool
+    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
+    {
+        // Padding
+        const index_t K_idx = idx_up[Number<0>{}] * up_lengths_[Number<2>{}] + idx_up[Number<2>{}];
+        const index_t M_idx = idx_up[Number<1>{}];
+
+        bool pad_valid = M_idx < up_lengths_[Number<1>{}] - MPad_ &&
+                         K_idx < up_lengths_[Number<0>{}] * up_lengths_[Number<2>{}] - KPad_;
+        return pad_valid;
+    }
+
+    __host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; }
+
+    __host__ __device__ void Print() const
+    {
+        printf("{");
+        printf("ConvBwdDataImplicitGemmOutTransformMG, ");
+        printf("up_lengths_");
+        print_multi_index(up_lengths_);
+        printf("}");
+    }
+};
+
 template <typename LowerIndex>
 struct Freeze
 {
diff --git a/include/ck/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp
index 129df612757..9db64c52af5 100644
--- a/include/ck/tensor_description/multi_index_transform_helper.hpp
+++ b/include/ck/tensor_description/multi_index_transform_helper.hpp
@@ -94,6 +94,7 @@ __host__ __device__ constexpr auto make_unmerge_transform(
     return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
 }
 
+template <index_t NumGroupsToMerge = 1>
 __host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
                                                                     index_t Ho,
                                                                     index_t Wo,
@@ -118,7 +119,7 @@ __host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
                                                                     index_t GemmKPerBlock)
 {
     // Calculate padding
-    const auto MRaw    = N * HTildeSlice * WTildeSlice;
+    const auto MRaw    = NumGroupsToMerge * N * HTildeSlice * WTildeSlice;
     const auto MPadded = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
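+    // e.g. (hypothetical sizes, for illustration only): NumGroupsToMerge = 4, N = 1 and
+    // HTildeSlice = WTildeSlice = 7 give MRaw = 196; with MPerBlock = 128 this pads to
+    // MPadded = 256, i.e. MPad = 60.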
    const auto MPad    = MPadded - MRaw;
@@ -126,25 +127,51 @@ __host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
     const auto KPadded = math::integer_divide_ceil(KRaw, GemmKPerBlock) * GemmKPerBlock;
     const auto KPad    = KPadded - KRaw;
 
-    return ConvBwdDataImplicitGemmOutTransform{N,
-                                               Ho,
-                                               Wo,
-                                               K,
-                                               XDot,
-                                               HTilde,
-                                               WTilde,
-                                               WTildeSlice,
-                                               HTildeSlice * WTildeSlice,
-                                               IHTildeSliceBegin,
-                                               IWTildeSliceBegin,
-                                               -ConvDilationH / GcdStrideDilationH,
-                                               -ConvDilationW / GcdStrideDilationW,
-                                               XDotSlice * K,
-                                               K0,
-                                               MPadded,
-                                               K1,
-                                               MPad,
-                                               KPad};
+    if constexpr(NumGroupsToMerge == 1)
+    {
+        return ConvBwdDataImplicitGemmOutTransform{N,
+                                                   Ho,
+                                                   Wo,
+                                                   K,
+                                                   XDot,
+                                                   HTilde,
+                                                   WTilde,
+                                                   WTildeSlice,
+                                                   HTildeSlice * WTildeSlice,
+                                                   IHTildeSliceBegin,
+                                                   IWTildeSliceBegin,
+                                                   -ConvDilationH / GcdStrideDilationH,
+                                                   -ConvDilationW / GcdStrideDilationW,
+                                                   XDotSlice * K,
+                                                   K0,
+                                                   MPadded,
+                                                   K1,
+                                                   MPad,
+                                                   KPad};
+    }
+    else
+    {
+        return ConvBwdDataImplicitGemmOutTransformMG{N,
+                                                     Ho,
+                                                     Wo,
+                                                     K,
+                                                     XDot,
+                                                     HTilde,
+                                                     WTilde,
+                                                     NumGroupsToMerge,
+                                                     WTildeSlice * NumGroupsToMerge,
+                                                     HTildeSlice * WTildeSlice * NumGroupsToMerge,
+                                                     IHTildeSliceBegin,
+                                                     IWTildeSliceBegin,
+                                                     -ConvDilationH / GcdStrideDilationH,
+                                                     -ConvDilationW / GcdStrideDilationW,
+                                                     XDotSlice * K,
+                                                     K0,
+                                                     MPadded,
+                                                     K1,
+                                                     MPad,
+                                                     KPad};
+    }
 }
 
 template
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
index b324845c3eb..fa410fd2f95 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
@@ -77,7 +77,8 @@ template
-          bool CTranspose>
+          bool CTranspose,
+          index_t NumGroupsToMerge>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -101,7 +102,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 {
     // offset base pointer for each work-group
     const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
-    const index_t g_idx         = __builtin_amdgcn_readfirstlane(blockIdx.y);
+    const index_t g_idx         = __builtin_amdgcn_readfirstlane(blockIdx.y * NumGroupsToMerge);
 
     const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
 
@@ -287,7 +288,8 @@ template
-          index_t MaxTransposeTransferOutScalarPerVector = 1>
+          index_t MaxTransposeTransferOutScalarPerVector = 1,
+          index_t NumGroupsToMerge                       = 1>
 struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     : public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
                                                ALayout,
                                                BLayout,
                                                DsLayout,
                                                ELayout,
                                                ADataType,
                                                BDataType,
                                                DsDataType,
                                                EDataType,
                                                AElementwiseOperation,
                                                BElementwiseOperation,
                                                CDEElementwiseOperation>
 {
+    static_assert(NumGroupsToMerge >= 1);
+
     // MaxGroupedGemmGroupsNum is used to specify number of gemm args in compile time. With this
     // implementation we can avoid copy data to workspace before kernel launch since number of
     // groups is runtime parameter.
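+    // Illustrative example for the merged-groups path (numbers are hypothetical, not
+    // checked here): with ConvG = 8, ConvC = ConvK = 1 and NumGroupsToMerge = 4, the
+    // launcher below uses gdy = num_group_ / NumGroupsToMerge = 2 rows of work-groups,
+    // and the row with blockIdx.y = 1 starts at g_idx = 1 * 4 = 4. IsSupportedArgument
+    // rejects ConvC != 1, ConvK != 1 and any ConvG not divisible by NumGroupsToMerge.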
    // If number of groups is larger than MaxGroupedGemmGroupsNum then
@@ -387,7 +391,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                                           true, /*SplitConvN*/
                                           ABDataType,
                                           EDataType,
-                                          1,
+                                          NumGroupsToMerge,
                                           index_t,
                                           CTranspose>;
@@ -418,7 +422,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                                           true, /*SplitConvN*/
                                           ABDataType,
                                           DDataType,
-                                          1, /*index_t NumGroupsToMerge = 1,*/
+                                          NumGroupsToMerge, /*index_t NumGroupsToMerge =
+                                                               1,*/
                                           index_t, /* typename IndexType = */
                                           CTranspose>;
         return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
@@ -887,7 +892,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                                           true, /*SplitConvN*/
                                           ABDataType,
                                           DDataType,
-                                          1,
+                                          NumGroupsToMerge,
                                           index_t,
                                           CTranspose>;
         ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
@@ -1147,7 +1152,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         {
             float ave_time = 0;
 
-            const index_t gdy = arg.num_group_;
+            const index_t gdy = arg.num_group_ / NumGroupsToMerge;
             const index_t gdz = arg.num_workgroups_per_Conv_N_ * arg.k_batch_;
 
             const ADataType* p_a_grid = arg.p_a_grid_;
@@ -1219,7 +1224,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                                                  ElementOp,
                                                  has_main_loop,
                                                  no_main_loop,
-                                                 CTranspose>;
+                                                 CTranspose,
+                                                 NumGroupsToMerge>;
 
                 return launch_and_time_kernel_with_preprocess(
                     stream_config,
@@ -1258,7 +1264,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                                                  ElementOp,
                                                  has_main_loop,
                                                  no_main_loop,
-                                                 CTranspose>;
+                                                 CTranspose,
+                                                 NumGroupsToMerge>;
 
                 return launch_and_time_kernel_with_preprocess(
                     stream_config,
@@ -1557,6 +1564,18 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             }
         }
 
+        if constexpr(NumGroupsToMerge > 1)
+        {
+            if(!(ConvC == 1 && ConvK == 1))
+            {
+                return false;
+            }
+            if(ConvG % NumGroupsToMerge != 0)
+            {
+                return false;
+            }
+        }
+
         // vector load for A matrix from global memory to LDS
         if constexpr(is_same_v || is_same_v ||
diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
index 03dc0efeb5a..a26810154e8 100644
--- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
@@ -164,6 +172,12 @@ struct TransformConvBwdDataToGemm_v1
           static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorA_)},
       NStrideTensorC_{
           static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorC_)},
+      GStrideTensorA_{
+          static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GStrideTensorA_)},
+      GStrideTensorB_{
+          static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GStrideTensorB_)},
+      GStrideTensorC_{
+          static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GStrideTensorC_)},
       ConvStrideD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideD_)},
       ConvStrideH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideH_)},
       ConvStrideW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideW_)},
@@ -233,6 +239,9 @@ struct TransformConvBwdDataToGemm_v1
       KStrideTensorB_{b_g_k_c_xs_strides[I1]},
       NStrideTensorA_{a_g_n_k_wos_strides[I1]},
       NStrideTensorC_{c_g_n_c_wis_strides[I1]},
+      GStrideTensorA_{a_g_n_k_wos_strides[I0]},
+      GStrideTensorB_{b_g_k_c_xs_strides[I0]},
+      GStrideTensorC_{c_g_n_c_wis_strides[I0]},
       ConvStrideH_{conv_filter_strides[HIdx - NonSpatialDimsNum]},
       ConvStrideW_{conv_filter_strides[WIdx - NonSpatialDimsNum]},
       ConvDilationH_{conv_filter_dilations[HIdx - NonSpatialDimsNum]},
@@ -508,15 +517,36 @@ struct
TransformConvBwdDataToGemm_v1 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { - - return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), - make_tuple(WoStride_, KStrideTensorA_)); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), + make_tuple(WoStride_, KStrideTensorA_)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_ * Ho_ * Wo_, NumGroupsToMerge, K_), + make_tuple(WoStride_, GStrideTensorA_, KStrideTensorA_)); + } } else { - return make_naive_tensor_descriptor( - make_tuple(N_, Ho_, Wo_, K_), - make_tuple(NStrideTensorA_, HoStride_, WoStride_, KStrideTensorA_)); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, K_), + make_tuple(NStrideTensorA_, HoStride_, WoStride_, KStrideTensorA_)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_), + make_tuple(NStrideTensorA_, + HoStride_, + WoStride_, + GStrideTensorA_, + KStrideTensorA_)); + } } } else if constexpr(is_same_v) @@ -525,19 +555,45 @@ struct TransformConvBwdDataToGemm_v1 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { - - return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), - make_tuple(WoStride_, KStrideTensorA_)); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), + make_tuple(WoStride_, KStrideTensorA_)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_ * Do_ * Ho_ * Wo_, NumGroupsToMerge, K_), + make_tuple(WoStride_, GStrideTensorA_, KStrideTensorA_)); + } } else { - return make_naive_tensor_descriptor( - make_tuple(N_, Do_, Ho_, Wo_, K_), - make_tuple(NStrideTensorA_, DoStride_, HoStride_, WoStride_, KStrideTensorA_)); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, K_), + make_tuple( + NStrideTensorA_, DoStride_, HoStride_, WoStride_, KStrideTensorA_)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_), + make_tuple(NStrideTensorA_, + DoStride_, + HoStride_, + WoStride_, + GStrideTensorA_, + KStrideTensorA_)); + } } } else if constexpr(is_same_v) { + // implement on demand + static_assert(NumGroupsToMerge == 1, "Merge group doesn't support GNHWK layout."); + // assume packed if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: @@ -552,6 +608,9 @@ struct TransformConvBwdDataToGemm_v1 } else if constexpr(is_same_v) { + // implement on demand + static_assert(NumGroupsToMerge == 1, "Merge group doesn't support GNDHWK layout."); + // assume packed if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: @@ -566,6 +625,9 @@ struct TransformConvBwdDataToGemm_v1 } else if constexpr(is_same_v) { + // implement on demand + static_assert(NumGroupsToMerge == 1, "Merge group doesn't support NGKHW layout."); + // assume packed static_assert(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: @@ -583,6 +645,9 @@ struct TransformConvBwdDataToGemm_v1 } else if constexpr(is_same_v) { + // implement on demand + static_assert(NumGroupsToMerge == 1, "Merge group doesn't support NGKDHW layout."); + // assume packed static_assert(ConvBwdDataSpecialization == 
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: @@ -607,14 +672,29 @@ struct TransformConvBwdDataToGemm_v1 __host__ __device__ auto MakeWeiGridDesc() const { - if constexpr(is_same_v) { - return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_)); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_)); + } + else + { + return make_naive_tensor_descriptor_packed( + make_tuple(NumGroupsToMerge, K_, Y_, X_, C_)); + } } else if constexpr(is_same_v) { - return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_)); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_)); + } + else + { + return make_naive_tensor_descriptor_packed( + make_tuple(NumGroupsToMerge, K_, Z_, Y_, X_, C_)); + } } else { @@ -624,21 +704,49 @@ struct TransformConvBwdDataToGemm_v1 __host__ __device__ auto MakeInGridDesc() const { - if constexpr(is_same_v || is_same_v || is_same_v) { - return make_naive_tensor_descriptor( - make_tuple(N_, Hi_, Wi_, C_), - make_tuple(NStrideTensorC_, HiStride_, WiStride_, CStrideTensorC_)); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorC_, HiStride_, WiStride_, CStrideTensorC_)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_, 1), + make_tuple(NStrideTensorC_, + HiStride_, + WiStride_, + GStrideTensorC_, + CStrideTensorC_, + GStrideTensorC_)); + } } else if constexpr(is_same_v || is_same_v) { - return make_naive_tensor_descriptor( - make_tuple(N_, Di_, Hi_, Wi_, C_), - make_tuple(NStrideTensorC_, DiStride_, HiStride_, WiStride_, CStrideTensorC_)); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorC_, DiStride_, HiStride_, WiStride_, CStrideTensorC_)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_, 1), + make_tuple(NStrideTensorC_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorC_, + CStrideTensorC_, + GStrideTensorC_)); + } } else { @@ -665,25 +773,60 @@ struct TransformConvBwdDataToGemm_v1 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { - const index_t K0PerBlock = GemmKPerBlock / AK1; - const index_t AK0 = - math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock; - - // A: output tensor - const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_), - make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + if constexpr(NumGroupsToMerge == 1) + { + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock; - const auto out_gemmak0_gemmm_gemmak1_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - out_gemmak0_gemmmraw_gemmak1_grid_desc, - make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1), - Sequence{}); + // A: output tensor + const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_), + make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))), + 
make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmak0_gemmmraw_gemmak1_grid_desc, + make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1), + Sequence{}); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else + { + const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N_ * Do_ * Ho_ * Wo_, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmm_grid_desc, + make_tuple(GemmMPerBlock, GemmKPerBlock), + Sequence{}); + + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I1), + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple( + make_pass_through_transform(out_gemmk_gemmm_padded_grid_desc.GetLength(I0)), + make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); - return out_gemmak0_gemmm_gemmak1_grid_desc; + return out_gemmak0_gemmmraw_gemmak1_grid_desc; + } } else { @@ -713,145 +856,137 @@ struct TransformConvBwdDataToGemm_v1 if constexpr(NDimSpatial == 2) { - const index_t K0PerBlock = GemmKPerBlock / AK1; - const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_, - AK1 * K0PerBlock * batch_k_) * - K0PerBlock; + if constexpr(NumGroupsToMerge == 1) + { + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_, + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; #if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0 - // A: output tensor - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple( - make_pass_through_transform(N_), - make_embed_transform(make_tuple(YDot_, HTilde_), - make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, make_tuple(make_pass_through_transform(N_), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, 
IWTildeSliceBegin, WTildeSlice), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{})); - - const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), - make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto out_gemmk_gemmm_padded_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - out_gemmk_gemmmraw_grid_desc, - make_tuple(GemmKPerBlock, GemmMPerBlock), - Sequence{}); + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform( + make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( - out_gemmk_gemmm_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), - make_pass_through_transform( - out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - return out_gemmak0_gemmm_gemmak1_grid_desc; + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmmraw_grid_desc, + make_tuple(GemmKPerBlock, GemmMPerBlock), + Sequence{}); + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), + 
make_pass_through_transform( + out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return out_gemmak0_gemmm_gemmak1_grid_desc; #else - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_conv_bwd_data_out_transform(N_, - Ho_, - Wo_, - K_, - YDot_, - XDot_, - HTilde_, - WTilde_, - ConvDilationH_, - ConvDilationW_, - HTildeSlice, - WTildeSlice, - YDotSlice, - XDotSlice, - IHTildeSliceBegin, - IWTildeSliceBegin, - GcdStrideDilationH_, - GcdStrideDilationW_, - AK0 * batch_k_, - AK1, - GemmMPerBlock, - GemmKPerBlock)), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0, 1, 2>{})); - - return out_n_hop_wop_k_grid_desc_final; -#endif - } - else if constexpr(NDimSpatial == 3) - { - // A: output tensor - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Do_, I0, I0), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = - transform_tensor_descriptor( + const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor( out_n_hop_wop_k_grid_desc, + make_tuple(make_conv_bwd_data_out_transform(N_, + Ho_, + Wo_, + K_, + YDot_, + XDot_, + HTilde_, + WTilde_, + ConvDilationH_, + ConvDilationW_, + HTildeSlice, + WTildeSlice, + YDotSlice, + XDotSlice, + IHTildeSliceBegin, + IWTildeSliceBegin, + GcdStrideDilationH_, + GcdStrideDilationW_, + AK0 * batch_k_, + AK1, + GemmMPerBlock, + GemmKPerBlock)), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0, 1, 2>{})); + + return out_n_hop_wop_k_grid_desc_final; +#endif + } + else + { + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_, + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, make_tuple(make_pass_through_transform(N_), - make_embed_transform( - make_tuple(ZDot_, DTilde_), - make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), - make_embed_transform( - make_tuple(YDot_, HTilde_), - make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), - make_embed_transform( - make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + 
make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(K_)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -859,23 +994,45 @@ struct TransformConvBwdDataToGemm_v1 Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{})); + +#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0 + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform( + make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5>{}, + Sequence<6>{})); - const auto - out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor( - out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, make_tuple( make_pass_through_transform(N_), - make_slice_transform(ZDot_, I0, ZDotSlice), - make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), make_slice_transform(YDot_, I0, YDotSlice), make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), make_slice_transform(XDot_, I0, XDotSlice), make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(K_)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -883,47 +1040,293 @@ struct TransformConvBwdDataToGemm_v1 Sequence<3>{}, Sequence<4>{}, Sequence<5>{}, - Sequence<6>{}, - Sequence<7>{}), + Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}, - Sequence<6>{}, + Sequence<6>{})); + + const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_merge_transform( + make_tuple(N_, HTildeSlice, WTildeSlice, NumGroupsToMerge))), + make_tuple(Sequence<1, 3, 6>{}, Sequence<0, 2, 4, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmmraw_grid_desc, + make_tuple(GemmKPerBlock, GemmMPerBlock), + Sequence{}); + + const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), + make_pass_through_transform( + out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); +#else + const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_conv_bwd_data_out_transform(N_, + Ho_, + Wo_, + K_, + YDot_, + XDot_, + HTilde_, + WTilde_, + ConvDilationH_, + ConvDilationW_, + HTildeSlice, + WTildeSlice, + YDotSlice, + XDotSlice, + IHTildeSliceBegin, + IWTildeSliceBegin, + 
GcdStrideDilationH_, + GcdStrideDilationW_, + AK0 * batch_k_, + AK1, + GemmMPerBlock, + GemmKPerBlock)), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0, 1, 2>{})); +#endif + return out_n_hop_wop_k_grid_desc_final; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(NumGroupsToMerge == 1) + { + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = + transform_tensor_descriptor(out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Do_, I0, I0), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform( + make_tuple(ZDot_, DTilde_), + make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), + make_embed_transform( + make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, Sequence<7>{})); - const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( - out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, - make_tuple( - make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), - make_merge_transform( - make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), - make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmk_gemmm_padded_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - out_gemmk_gemmmraw_grid_desc, - make_tuple(GemmKPerBlock, GemmMPerBlock), - Sequence{}); - - const index_t K0PerBlock = GemmKPerBlock / AK1; - const index_t AK0 = - math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0), - AK1 * K0PerBlock * batch_k_) * - K0PerBlock; - - const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( - out_gemmk_gemmm_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), - make_pass_through_transform( - out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + 
Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{})); + + const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_merge_transform( + make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmmraw_grid_desc, + make_tuple(GemmKPerBlock, GemmMPerBlock), + Sequence{}); + + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0), + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), + make_pass_through_transform( + out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else + { + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Do_, I0, I0), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); - return out_gemmak0_gemmm_gemmak1_grid_desc; + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform( + make_tuple(ZDot_, DTilde_), + make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), + make_embed_transform( + make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{}, + Sequence<8>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + 
Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}, + Sequence<8>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}, + Sequence<8>{})); + + const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple( + N_, DTildeSlice, HTildeSlice, WTildeSlice, NumGroupsToMerge))), + make_tuple(Sequence<1, 3, 5, 8>{}, Sequence<0, 2, 4, 6, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmmraw_grid_desc, + make_tuple(GemmKPerBlock, GemmMPerBlock), + Sequence{}); + + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0), + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), + make_pass_through_transform( + out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } } else { @@ -947,24 +1350,33 @@ struct TransformConvBwdDataToGemm_v1 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { + const auto wei_gemmnraw_gemmkraw_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(NumGroupsToMerge, K_, C_)), + make_tuple(make_pass_through_transform(K_), + make_merge_transform(make_tuple(NumGroupsToMerge, C_))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto wei_gemmnraw_gemmkraw_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmnraw_gemmkraw_grid_desc, + make_tuple(GemmKPerBlock, GemmNPerBlock), + Sequence{}); + const index_t K0PerBlock = GemmKPerBlock / BK1; const index_t BK0 = - math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock; + math::integer_divide_ceil(wei_gemmnraw_gemmkraw_padded_grid_desc.GetLength(I0), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; // B: weight tensor - const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K_, C_)), + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_gemmnraw_gemmkraw_padded_grid_desc, make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), - make_pass_through_transform(C_)), + make_pass_through_transform( + wei_gemmnraw_gemmkraw_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1)); - - const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - wei_gemmbk0_gemmnraw_gemmbk1_grid_desc, - make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1), - Sequence{}); return wei_gemmbk0_gemmn_gemmbk1_grid_desc; } @@ -984,152 +1396,329 @@ struct TransformConvBwdDataToGemm_v1 // B weight tensor if constexpr(NDimSpatial == 2) { - 
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_grid_desc, - make_tuple( - make_pass_through_transform(K_), - make_embed_transform(make_tuple(YDot_, YTilde_), - make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( - wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxYTilde_), - make_freeze_transform(IdxXTilde_), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<2>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<>{}, - Sequence<>{}, - Sequence<3>{})); - - const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( - wei_k_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), - make_pass_through_transform(C_)), - make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto wei_gemmk_gemmn_padded_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - wei_gemmk_gemmnraw_grid_desc, - make_tuple(GemmKPerBlock, GemmNPerBlock), - Sequence{}); - - const index_t K0PerBlock = GemmKPerBlock / BK1; - const index_t BK0 = - math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), - BK1 * K0PerBlock * batch_k_) * - K0PerBlock; - - const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( - wei_gemmk_gemmn_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), - make_pass_through_transform( - wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + if constexpr(NumGroupsToMerge == 1) + { + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform( + make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - return wei_gemmbk0_gemmn_gemmbk1_grid_desc; - } - else if constexpr(NDimSpatial == 3) - { - const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = - transform_tensor_descriptor( - wei_grid_desc, + const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, make_tuple(make_pass_through_transform(K_), - make_embed_transform( - make_tuple(ZDot_, ZTilde_), - make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), - make_embed_transform( - make_tuple(YDot_, YTilde_), - make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), - make_embed_transform( - make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + 
make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, - Sequence<2>{}, Sequence<3>{}, - Sequence<4>{}), + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); + Sequence<1>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<3>{})); - const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor( - wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_slice_transform(ZDot_, I0, ZDotSlice), + const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto wei_gemmk_gemmn_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmk_gemmnraw_grid_desc, + make_tuple(GemmKPerBlock, GemmNPerBlock), + Sequence{}); + + const index_t K0PerBlock = GemmKPerBlock / BK1; + const index_t BK0 = + math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + } + else + { + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_embed_transform( + make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), make_slice_transform(YDot_, I0, YDotSlice), make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxZTilde_), make_freeze_transform(IdxYTilde_), make_freeze_transform(IdxXTilde_), make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, - Sequence<3>{}, - Sequence<5>{}, Sequence<2>{}, Sequence<4>{}, - Sequence<6>{}, - Sequence<7>{}), + Sequence<3>{}, + Sequence<5>{}, + Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<>{}, Sequence<>{}, - Sequence<>{}, Sequence<4>{})); - const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( - wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc, - make_tuple( - make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, 
K_)), - make_pass_through_transform(C_)), - make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(K_, YDotSlice, XDotSlice)), + make_merge_transform(make_tuple(NumGroupsToMerge, C_))), + make_tuple(Sequence<1, 2, 3>{}, Sequence<0, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto wei_gemmk_gemmn_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmk_gemmnraw_grid_desc, + make_tuple(GemmKPerBlock, GemmNPerBlock), + Sequence{}); + + const index_t K0PerBlock = GemmKPerBlock / BK1; + const index_t BK0 = + math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform( + make_tuple(ZDot_, ZTilde_), + make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), + make_embed_transform( + make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); - const auto wei_gemmk_gemmn_padded_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - wei_gemmk_gemmnraw_grid_desc, - make_tuple(GemmKPerBlock, GemmNPerBlock), - Sequence{}); - - const index_t K0PerBlock = GemmKPerBlock / BK1; - const index_t BK0 = - math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), - BK1 * K0PerBlock * batch_k_) * - K0PerBlock; + const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor( + wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxZTilde_), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); - const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( - wei_gemmk_gemmn_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), - make_pass_through_transform( - wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), - 
make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( + wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto wei_gemmk_gemmn_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmk_gemmnraw_grid_desc, + make_tuple(GemmKPerBlock, GemmNPerBlock), + Sequence{}); + + const index_t K0PerBlock = GemmKPerBlock / BK1; + const index_t BK0 = + math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return wei_gemmbk0_gemm_gemmbk1_grid_desc; + } + else + { + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_embed_transform( + make_tuple(ZDot_, ZTilde_), + make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), + make_embed_transform( + make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 3>{}, + Sequence<4, 5>{}, + Sequence<6, 7>{}, + Sequence<8>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxZTilde_), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<6>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<7>{}, + Sequence<8>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<5>{})); - return wei_gemmbk0_gemm_gemmbk1_grid_desc; + const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(K_, ZDotSlice, YDotSlice, XDotSlice)), + make_merge_transform(make_tuple(NumGroupsToMerge, C_))), + make_tuple(Sequence<1, 2, 3, 4>{}, Sequence<0, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto wei_gemmk_gemmn_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmk_gemmnraw_grid_desc, + make_tuple(GemmKPerBlock, GemmNPerBlock), + Sequence{}); + + const index_t K0PerBlock = GemmKPerBlock / BK1; + const 
index_t BK0 = + math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + } } else { @@ -1161,75 +1750,239 @@ struct TransformConvBwdDataToGemm_v1 // C: input tensor if constexpr(NDimSpatial == 2) { - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple( - make_pass_through_transform(N_), - make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)), - make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_freeze_transform(I0), - make_freeze_transform(I0), - make_merge_transform(make_tuple(N_, Ho_, Wo_)), - make_pass_through_transform(C_)), - make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), - make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmn_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - in_gemmmraw_gemmnraw_grid_desc, - make_tuple(GemmMPerBlock, GemmNPerBlock), - Sequence{}); + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)), + make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - return in_gemmm_gemmn_grid_desc; + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple( + Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else + { + const auto in_n_y_x_mg_c_mgPad_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pass_through_transform(Ho_), + make_pass_through_transform(Wo_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto 
in_n_y_ho_x_wo_mgxor_c_mgxor_grid_desc = transform_tensor_descriptor( + in_n_y_x_mg_c_mgPad_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)), + make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3, 5>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 7>{}, + Sequence<6>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power + // of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || + NumGroupsToMerge == 4 || NumGroupsToMerge == 8 || + NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_mgxor_c_mgxor_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(C_, NumGroupsToMerge))), + make_tuple( + Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4, 5>{}, Sequence<6, 7>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } } else if constexpr(NDimSpatial == 3) { + if constexpr(NumGroupsToMerge == 1) + { + // C: input tensor + const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(I1, Do_), make_tuple(I1, ConvStrideD_)), + make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)), + make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); - // C: input tensor - const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple( - make_pass_through_transform(N_), - make_embed_transform(make_tuple(I1, Do_), make_tuple(I1, ConvStrideD_)), - make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)), - make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_x_do_y_ho_x_wo_c_grid_desc, - make_tuple(make_freeze_transform(I0), - make_freeze_transform(I0), - make_freeze_transform(I0), - make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), - make_pass_through_transform(C_)), - make_tuple(Sequence<1>{}, - Sequence<3>{}, - Sequence<5>{}, - Sequence<0, 2, 4, 6>{}, - Sequence<7>{}), - make_tuple( - Sequence<>{}, Sequence<>{}, Sequence<>{}, 
Sequence<0>{}, Sequence<1>{})); + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_x_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<0, 2, 4, 6>{}, + Sequence<7>{}), + make_tuple(Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<0>{}, + Sequence<1>{})); - const auto in_gemmm_gemmn_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - in_gemmmraw_gemmnraw_grid_desc, - make_tuple(GemmMPerBlock, GemmNPerBlock), - Sequence{}); + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); - return in_gemmm_gemmn_grid_desc; + return in_gemmm_gemmn_grid_desc; + } + else + { + const auto in_n_y_x_mg_c_mgPad_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pass_through_transform(Do_), + make_pass_through_transform(Ho_), + make_pass_through_transform(Wo_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); + + const auto in_n_y_ho_x_wo_mgxor_c_mgxor_grid_desc = transform_tensor_descriptor( + in_n_y_x_mg_c_mgPad_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(I1, Do_), make_tuple(I1, ConvStrideD_)), + make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)), + make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 6>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7, 9>{}, + Sequence<8>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power + // of 2. 
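To make the invariant behind the comment above concrete: for a power-of-two group count, the xor of two in-range indices always stays in range, so no modulo is needed after the xor transform, and it is zero exactly on the diagonal g0 == g1, so only diagonal group pairs address real data. A minimal standalone C++ sketch (editorial illustration, not part of this patch), with N standing in for NumGroupsToMerge:

#include <cassert>
#include <cstdint>

int main()
{
    constexpr uint32_t N = 8; // stand-in for NumGroupsToMerge; must be a power of two
    static_assert((N & (N - 1)) == 0, "xor trick assumes a power of two");

    for(uint32_t g0 = 0; g0 < N; ++g0)
        for(uint32_t g1 = 0; g1 < N; ++g1)
        {
            const uint32_t low = g0 ^ g1;
            assert(low < N);                  // stays in range, no modulo required
            assert((low == 0) == (g0 == g1)); // zero exactly on the diagonal
        }
    return 0;
}

For a non-power-of-two count the range guarantee breaks, e.g. with N = 6 the pair (5, 3) gives 5 ^ 3 == 6, which is out of range; hence the static_assert below admits only powers of two.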
+ static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || + NumGroupsToMerge == 4 || NumGroupsToMerge == 8 || + NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_mgxor_c_mgxor_grid_desc, + make_tuple( + make_freeze_transform(I0), + make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(C_, NumGroupsToMerge))), + make_tuple(Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<0, 2, 4, 6, 7>{}, + Sequence<8, 9>{}), + make_tuple(Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<0>{}, + Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } } else { @@ -1261,88 +2014,177 @@ struct TransformConvBwdDataToGemm_v1 // C: input tensor if constexpr(NDimSpatial == 2) { - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = - transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(YTilde_, HTilde_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxYTilde_), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<>{}, - Sequence<2>{}, - Sequence<3>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + 
make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto in_gemmm_gemmn_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - in_gemmmraw_gemmnraw_grid_desc, - make_tuple(GemmMPerBlock, GemmNPerBlock), - Sequence{}); + const auto in_n_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); - return in_gemmm_gemmn_grid_desc; + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else + { + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power + // of 2. 
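Complementing the note above, the pad transform is what supplies the "trash" slots that off-diagonal accesses fall into: make_pad_transform(1, 0, NumGroupsToMerge - 1) widens the length-1 dimension to NumGroupsToMerge slots, of which only slot 0 maps to real data. Assuming the usual semantics of the xor transform here (the second lower index becomes g0 ^ g1), only diagonal pairs hit slot 0. A small illustrative sketch (editorial, not part of this patch), with N again standing in for NumGroupsToMerge:

#include <iostream>

int main()
{
    constexpr int N = 4; // stand-in for NumGroupsToMerge (power of two)
    // Slot 0 of the padded dimension holds the real element; slots 1..N-1 are
    // the right padding created by make_pad_transform(1, 0, N - 1).
    for(int g0 = 0; g0 < N; ++g0)
    {
        for(int g1 = 0; g1 < N; ++g1)
        {
            const int slot = g0 ^ g1;             // lower index after the xor transform
            std::cout << (slot == 0 ? 'D' : '.'); // 'D': real data, '.': padding
        }
        std::cout << '\n';
    }
    // Prints 'D' exactly on the diagonal: one real access per (g0, g1) row,
    // matching "we need only the matrices on the diagonal" above.
    return 0;
}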
+ static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || + NumGroupsToMerge == 4 || NumGroupsToMerge == 8 || + NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3, 5>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 7>{}, + Sequence<6>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform( + make_tuple(N_, HTildeSlice, WTildeSlice, NumGroupsToMerge)), + make_merge_transform(make_tuple(C_, NumGroupsToMerge))), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } } else if(NDimSpatial == 3) { - const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Di_, InLeftPadD_, InRightPadD_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = - transform_tensor_descriptor( - in_n_dip_hip_wip_c_grid_desc, + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(ZTilde_, DTilde_), - make_tuple(ConvDilationD_, ConvStrideD_)), - make_embed_transform(make_tuple(YTilde_, HTilde_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, 
InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -1350,53 +2192,189 @@ struct TransformConvBwdDataToGemm_v1 Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{})); - const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = - transform_tensor_descriptor( - in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxZTilde_), - make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), - make_freeze_transform(IdxYTilde_), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZTilde_, DTilde_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_freeze_transform(IdxZTilde_), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform( + make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + return in_gemmm_gemmn_grid_desc; + } + else + { + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), make_tuple(Sequence<0>{}, 
Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}, - Sequence<6>{}, - Sequence<7>{}), + Sequence<6>{}), make_tuple(Sequence<0>{}, - Sequence<>{}, Sequence<1>{}, - Sequence<>{}, Sequence<2>{}, - Sequence<>{}, Sequence<3>{}, - Sequence<4>{})); + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power + // of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || + NumGroupsToMerge == 4 || NumGroupsToMerge == 8 || + NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZTilde_, DTilde_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 6>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7, 9>{}, + Sequence<8>{})); - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, - make_tuple( - make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_n_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_freeze_transform(IdxZTilde_), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}, + Sequence<8>{}, + Sequence<9>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); - const auto in_gemmm_gemmn_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - in_gemmmraw_gemmnraw_grid_desc, - make_tuple(GemmMPerBlock, GemmNPerBlock), - Sequence{}); - return in_gemmm_gemmn_grid_desc; + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple( + N_, DTildeSlice, HTildeSlice, WTildeSlice, NumGroupsToMerge)), + make_merge_transform(make_tuple(C_, NumGroupsToMerge))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto 
in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } } else { @@ -1514,6 +2492,7 @@ struct TransformConvBwdDataToGemm_v1 IndexType DoStride_, HoStride_, WoStride_; IndexType CStrideTensorB_, CStrideTensorC_, KStrideTensorA_, KStrideTensorB_; IndexType NStrideTensorA_, NStrideTensorC_; + IndexType GStrideTensorA_, GStrideTensorB_, GStrideTensorC_; IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index 745f8cbd321..a6ee9e19849 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -566,6 +566,89 @@ using device_grouped_conv_bwd_data_xdl_input_fp16_comp_bf8f8_instances = // clang-format on >; +template +using device_grouped_conv_bwd_data_xdl_f32_merged_groups_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| LoopScheduler| ACompute| BCompute| MaxTranspose| MaxTranspose| NumGroups| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| | Type| Type| TransferInScalar| TransferOutScalar| ToMerge| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | PerVector| PerVector| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, LoopScheduler::Default, F32, F32, 2, 4, 16>, + 
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, LoopScheduler::Default, F32, F32, 4, 2, 16>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, LoopScheduler::Default, F32, F32, 2, 2, 16>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, LoopScheduler::Default, F32, F32, 1, 2, 16>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, LoopScheduler::Default, F32, F32, 2, 1, 16>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, LoopScheduler::Default, F32, F32, 2, 1, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, LoopScheduler::Default, F32, F32, 1, 2, 8> + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_xdl_f16_merged_groups_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| LoopScheduler| ACompute| BCompute| MaxTranspose| MaxTranspose| NumGroups| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| | Type| Type| TransferInScalar| TransferOutScalar| ToMerge| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | PerVector| PerVector| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), F16, F16, 2, 2, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), F16, F16, 4, 4, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), F16, F16, 1, 2, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), F16, F16, 1, 4, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), F16, F16, 2, 1, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), F16, F16, 4, 1, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 
1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), F16, F16, 2, 1, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 128, 32, 8, 8, 16, 16, 1, 8, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), F16, F16, 4, 1, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 4, 1, 8>, 1, make_default_loop_scheduler(), F16, F16, 1, 1, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 32, 128, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 4, 1, 8>, 1, make_default_loop_scheduler(), F16, F16, 1, 1, 2> + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_xdl_bf16_merged_groups_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| LoopScheduler| ACompute| BCompute| MaxTranspose| MaxTranspose| NumGroups| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| | Type| Type| TransferInScalar| TransferOutScalar| ToMerge| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | PerVector| PerVector| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, F32, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), 
BF16, BF16, 2, 2, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, F32, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), BF16, BF16, 4, 4, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, F32, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), BF16, BF16, 1, 2, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, F32, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), BF16, BF16, 1, 4, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, F32, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), BF16, BF16, 2, 1, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, F32, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), BF16, BF16, 4, 1, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, F32, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), BF16, BF16, 2, 1, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, F32, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 128, 32, 8, 8, 16, 16, 1, 8, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, make_default_loop_scheduler(), BF16, BF16, 4, 1, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, F32, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 4, 1, 8>, 1, make_default_loop_scheduler(), BF16, BF16, 1, 1, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, F32, Empty_Tuple, BF16, PassThrough, PassThrough, 
PassThrough, ConvSpec, true, true, 1, 256, 32, 128, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 1, 2, 1, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 2, 1, 1, 1, S<1, 4, 1, 8>, 1, make_default_loop_scheduler(), BF16, BF16, 1, 1, 2> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index 3633015a785..25b2ae5e38d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -384,6 +384,58 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_part2_in // clang-format on >; +template +using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f32_merge_instances = std::tuple< + // clang-format off + //#############################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| Compute| Compute| Transpose| Transpose| + //#############################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| TypeA| TypeB| TransferSrc| TransferDst| + //#############################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | | | ScalarPerVector| ScalarPerVector| + //#############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 256, 32, 8, 32, 32, 4, 4, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F32, F32, 1, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 256, 32, 8, 32, 32, 4, 4, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F32, F32, 1, 4>, 
+ DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 256, 32, 8, 32, 32, 4, 4, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F32, F32, 4, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 256, 32, 8, 32, 32, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F32, F32, 2, 2>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 32, 8, 32, 32, 2, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F32, F32, 1, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 32, 8, 32, 32, 2, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F32, F32, 1, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 32, 8, 32, 32, 2, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F32, F32, 1, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 32, 8, 32, 32, 2, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F32, F32, 4, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 32, 8, 32, 32, 2, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F32, F32, 4, 4> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f32_merge_instances = std::tuple< + // clang-format off + //#############################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| Compute| Compute| Transpose| Transpose| + //#############################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| TypeA| TypeB| TransferSrc| TransferDst| + //#############################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | | | ScalarPerVector| ScalarPerVector| + //#############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F32, F32, 1, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F32, F32, 1, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F32, F32, 2, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F32, F32, 2, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F32, F32, 4, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F32, F32, 4, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F32, F32, 4, 4>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 64, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 2, 0, S<4, 16, 
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
index 944e68f1927..0519b288936 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
@@ -137,6 +137,8 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
 //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // Instances with NumGroupsPerBatch > 1
+        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 2>,
+        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 4>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32>
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
index 6dd8758eb7e..3c0544ee281 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
@@ -367,6 +367,9 @@ struct DeviceOperationInstanceFactory
     ...>>& instances);
+void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<...>>>& instances);
+
 void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<...>>>& instances);
+void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<...>>>& instances);
+
 #endif
 #ifdef CK_ENABLE_TF32
 void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<...>>>& instances);
+
+void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<...>>>& instances);
+
+void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<...>>>& instances);
+
 #endif
 // conv3d backward weight
 #ifdef CK_ENABLE_BF16
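All of the declarations above share one factory pattern: an add_device_..._instances function appends concrete device operations to a caller-owned vector, and selection among them happens later, per problem size. A minimal self-contained sketch of that pattern follows; DeviceOpBase and the other names are illustrative stand-ins, not CK's real types:

    #include <memory>
    #include <vector>

    // Stand-in for an interface such as DeviceGroupedConvBwdWeight; the real one
    // is templated on layouts, data types and elementwise ops, and its
    // IsSupportedArgument takes a full argument descriptor.
    struct DeviceOpBase
    {
        virtual ~DeviceOpBase()                  = default;
        virtual bool IsSupportedArgument() const = 0; // simplified signature
    };

    struct TwoStageXdlF32 final : DeviceOpBase
    {
        bool IsSupportedArgument() const override { return true; }
    };

    // Shape of an add_device_..._instances function: it only registers candidates.
    void add_two_stage_f32_instances(std::vector<std::unique_ptr<DeviceOpBase>>& instances)
    {
        instances.push_back(std::make_unique<TwoStageXdlF32>());
    }

    // Callers then pick the first instance that reports support for their problem.
    const DeviceOpBase* pick_first_supported(const std::vector<std::unique_ptr<DeviceOpBase>>& ops)
    {
        for(const auto& op : ops)
            if(op->IsSupportedArgument())
                return op.get();
        return nullptr;
    }

    int main()
    {
        std::vector<std::unique_ptr<DeviceOpBase>> ops;
        add_two_stage_f32_instances(ops);
        return pick_first_supported(ops) ? 0 : 1;
    }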
diff --git a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt
index b7d7ed46288..979343acdfe 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt
@@ -3,12 +3,5 @@
 # ONLY XDL_KERNELS
 add_instance_library(device_gemm_streamk_instance
-    # device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp
-    # device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp
-    # device_gemm_xdl_streamk_f32_f32_f32_km_kn_mn_instance.cpp
-    # device_gemm_xdl_streamk_f32_f32_f32_km_nk_mn_instance.cpp
     device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instance.cpp
-    # device_gemm_xdl_streamk_f16_f16_f16_mk_nk_mn_instance.cpp
-    # device_gemm_xdl_streamk_f16_f16_f16_km_kn_mn_instance.cpp
-    # device_gemm_xdl_streamk_f16_f16_f16_km_nk_mn_instance.cpp
 )
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
index 09d7c9b66f7..4ba022e1de2 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
@@ -41,6 +41,25 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
                                                               Empty_Tuple,
                                                               NHWGC,
                                                               ConvBwdDataFilter1x1Stride1Pad0>{});
+
+    // 1. Default
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_data_xdl_bf16_merged_groups_instances<2,
+                                                                      NHWGK,
+                                                                      GKYXC,
+                                                                      Empty_Tuple,
+                                                                      NHWGC,
+                                                                      ConvBwdDataDefault>{});
+    // 2. Filter1x1Stride1Pad0
+    add_device_operation_instances(instances,
+                                   device_grouped_conv_bwd_data_xdl_bf16_merged_groups_instances<
+                                       2,
+                                       NHWGK,
+                                       GKYXC,
+                                       Empty_Tuple,
+                                       NHWGC,
+                                       ConvBwdDataFilter1x1Stride1Pad0>{});
 }
 
 } // namespace instance
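The merged-groups instances registered above (and the forward instances earlier in this diff with NumGroupsPerBatch of 2, 4, 8, 16 and 32) let one workgroup cover several convolution groups, which helps when the per-group GEMM is small. Purely as a hedged illustration: if one assumes a merge factor is usable only when it divides the group count G (the library's real gate is each instance's IsSupportedArgument check), picking the largest factor would look like this:

    #include <array>

    // Hypothetical helper; the candidate factors mirror the NumGroupsPerBatch
    // values instantiated in this diff, but the divisibility rule is an assumed
    // simplification of the per-instance supportability check.
    int pick_merge_factor(int G)
    {
        constexpr std::array<int, 5> factors{32, 16, 8, 4, 2};
        for(int f : factors)
            if(G % f == 0)
                return f;
        return 1; // fall back to the regular, non-merged instances
    }

    int main() { return pick_merge_factor(32) == 32 ? 0 : 1; }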
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
index 085b6aaaf52..7de0e8c29a1 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
@@ -41,6 +41,25 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
                                                              Empty_Tuple,
                                                              NHWGC,
                                                              ConvBwdDataFilter1x1Stride1Pad0>{});
+
+    // 1. Default
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_data_xdl_f16_merged_groups_instances<2,
+                                                                     NHWGK,
+                                                                     GKYXC,
+                                                                     Empty_Tuple,
+                                                                     NHWGC,
+                                                                     ConvBwdDataDefault>{});
+    // 2. Filter1x1Stride1Pad0
+    add_device_operation_instances(instances,
+                                   device_grouped_conv_bwd_data_xdl_f16_merged_groups_instances<
+                                       2,
+                                       NHWGK,
+                                       GKYXC,
+                                       Empty_Tuple,
+                                       NHWGC,
+                                       ConvBwdDataFilter1x1Stride1Pad0>{});
 }
 
 } // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
index 02a0eeb517f..6130651ae49 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
@@ -41,6 +41,16 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
                                                              Empty_Tuple,
                                                              NHWGC,
                                                              ConvBwdDataFilter1x1Stride1Pad0>{});
+
+    // 3. Merged groups
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_data_xdl_f32_merged_groups_instances<2,
+                                                                     NHWGK,
+                                                                     GKYXC,
+                                                                     Empty_Tuple,
+                                                                     NHWGC,
+                                                                     ConvBwdDataDefault>{});
 }
 
 } // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
index ec9e7da3911..312f3ebb570 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
@@ -44,6 +44,7 @@ set(GROUPED_CONV2D_BWD_WEIGHT
     xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instance.cpp
     xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instance.cpp
     xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instance.cpp
+    xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
     xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp
     xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
new file mode 100644
index 00000000000..3adc8bd22c5
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
@@ -0,0 +1,51 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
+void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<...>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f32_merge_instances<
+            2,
+            NHWGC,
+            GKYXC,
+            NHWGK,
+            ConvBwdWeightDefault,
+            BlockGemmPipelineScheduler::Intrawave,
+            BlockGemmPipelineVersion::v1>{});
+
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f32_merge_instances<
+            2,
+            NHWGC,
+            GKYXC,
+            NHWGK,
+            ConvBwdWeightDefault,
+            BlockGemmPipelineScheduler::Intrawave,
+            BlockGemmPipelineVersion::v2>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
index 60530647942..b41cf56680b 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -41,6 +41,25 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
                                                               Empty_Tuple,
                                                               NDHWGC,
                                                               ConvBwdDataFilter1x1Stride1Pad0>{});
+
+    // 1. Default
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_data_xdl_bf16_merged_groups_instances<3,
+                                                                      NDHWGK,
+                                                                      GKZYXC,
+                                                                      Empty_Tuple,
+                                                                      NDHWGC,
+                                                                      ConvBwdDataDefault>{});
+    // 2. Filter1x1Stride1Pad0
+    add_device_operation_instances(instances,
+                                   device_grouped_conv_bwd_data_xdl_bf16_merged_groups_instances<
+                                       3,
+                                       NDHWGK,
+                                       GKZYXC,
+                                       Empty_Tuple,
+                                       NDHWGC,
+                                       ConvBwdDataFilter1x1Stride1Pad0>{});
 }
 
 } // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
index 33a12f3bc5e..ad1ad09059c 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
@@ -41,6 +41,25 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
                                                              Empty_Tuple,
                                                              NDHWGC,
                                                              ConvBwdDataFilter1x1Stride1Pad0>{});
+
+    // 1. Default
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_data_xdl_f16_merged_groups_instances<3,
+                                                                     NDHWGK,
+                                                                     GKZYXC,
+                                                                     Empty_Tuple,
+                                                                     NDHWGC,
+                                                                     ConvBwdDataDefault>{});
+    // 2. Filter1x1Stride1Pad0
+    add_device_operation_instances(instances,
+                                   device_grouped_conv_bwd_data_xdl_f16_merged_groups_instances<
+                                       3,
+                                       NDHWGK,
+                                       GKZYXC,
+                                       Empty_Tuple,
+                                       NDHWGC,
+                                       ConvBwdDataFilter1x1Stride1Pad0>{});
 }
 
 } // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
index 23d41f69625..59cbadde7c0 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
@@ -41,6 +41,16 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
                                                              Empty_Tuple,
                                                              NDHWGC,
                                                              ConvBwdDataFilter1x1Stride1Pad0>{});
+
+    // 3. Merged groups
+    add_device_operation_instances(
+        instances,
+        device_grouped_conv_bwd_data_xdl_f32_merged_groups_instances<3,
+                                                                     NDHWGK,
+                                                                     GKZYXC,
+                                                                     Empty_Tuple,
+                                                                     NDHWGC,
+                                                                     ConvBwdDataDefault>{});
 }
 
 } // namespace instance
diff --git a/library/src/utility/convolution_parameter.cpp b/library/src/utility/convolution_parameter.cpp
index 3a8c1f0155d..158c55bce9c 100644
--- a/library/src/utility/convolution_parameter.cpp
+++ b/library/src/utility/convolution_parameter.cpp
@@ -63,6 +63,13 @@ ConvParam::ConvParam(ck::index_t n_dim,
             (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
                 conv_filter_strides_[i] +
             1;
+
+        if(output_spatial_lengths_[i] <= 0)
+        {
+            throw std::runtime_error(
+                "ConvParam::ConvParam: "
+                "the given input would result in output dimension less than 1!");
+        }
     }
 }
 
@@ -113,6 +120,13 @@ ConvParam::ConvParam(ck::long_index_t n_dim,
             (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
                 conv_filter_strides_[i] +
             1;
+
+        if(output_spatial_lengths_[i] <= 0)
+        {
+            throw std::runtime_error(
+                "ConvParam::ConvParam: "
+                "the given input would result in output dimension less than 1!");
+        }
     }
 }
 
@@ -220,6 +234,7 @@ std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p)
        << "\nN: " << p.N_ << "\nK: " << p.K_ << "\nC: " << p.C_
        << "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_
        << "\ninput_spatial_lengths: " << p.input_spatial_lengths_
+       << "\noutput_spatial_lengths: " << p.output_spatial_lengths_
        << "\nconv_filter_strides: " << p.conv_filter_strides_
        << "\nconv_filter_dilations: " << p.conv_filter_dilations_
        << "\ninput_left_pads: " << p.input_left_pads_
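The checks added to both ConvParam constructors above reject shapes whose computed output length would be non-positive, instead of silently producing a degenerate tensor. A standalone restatement of that rule; x_eff is computed outside the hunk, so the usual effective-filter-size formula (x - 1) * dilation + 1 is assumed here:

    #include <stdexcept>

    // Mirrors the output-length computation and the new validation above.
    long output_spatial_length(
        long input, long filter, long stride, long dilation, long left_pad, long right_pad)
    {
        const long x_eff = (filter - 1) * dilation + 1; // assumed definition of x_eff
        const long out   = (input + left_pad + right_pad - x_eff) / stride + 1;
        if(out <= 0)
        {
            // e.g. input = 2, filter = 5, unit stride/dilation, no padding:
            // x_eff = 5, out = (2 - 5) / 1 + 1 = -2, so the shape is rejected.
            throw std::runtime_error(
                "ConvParam::ConvParam: "
                "the given input would result in output dimension less than 1!");
        }
        return out;
    }

    int main() { return output_spatial_length(224, 3, 1, 1, 1, 1) == 224 ? 0 : 1; }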