6 changes: 3 additions & 3 deletions CMakeLists.txt
@@ -277,16 +277,16 @@ endif()
# define the macro with the current value (0 or 1)
add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA})

-if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" AND NOT FORCE_DISABLE_WMMA)
+if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" AND NOT FORCE_DISABLE_WMMA AND DEFINED CK_ENABLE_FP8)
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
add_definitions(-DCK_USE_WMMA_FP8)
set(CK_USE_WMMA_FP8 "ON")
endif()
-if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950")
+if (SUPPORTED_GPU_TARGETS MATCHES "gfx12|gfx950" AND DEFINED CK_ENABLE_FP8)
add_definitions(-DCK_USE_OCP_FP8)
set(CK_USE_OCP_FP8 "ON")
endif()
-if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx94")
+if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a|gfx94" AND DEFINED CK_ENABLE_FP8)
add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON")
endif()
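For context, after this change the three FP8-related macros are emitted only when CK_ENABLE_FP8 is defined at configure time, and the gfx-target alternatives are merged into single regexes. A minimal C++ sketch of how such compile definitions are typically consumed downstream (illustrative only; the trait name and comments are not CK's actual code):

// Illustrative sketch, not part of the PR: selecting the FP8 flavor that the
// build system enabled above. The trait name ExampleFp8Traits is hypothetical.
#if defined(CK_USE_OCP_FP8)
struct ExampleFp8Traits { static constexpr bool is_ocp = true;  }; // gfx12 / gfx950: OCP encodings (e4m3 / e5m2)
#elif defined(CK_USE_FNUZ_FP8)
struct ExampleFp8Traits { static constexpr bool is_ocp = false; }; // gfx90a / gfx94x: FNUZ encodings
#else
// Neither macro defined: FP8 paths are compiled out (CK_ENABLE_FP8 unset).
#endif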
209 changes: 209 additions & 0 deletions include/ck/tensor_description/multi_index_transform.hpp
@@ -1745,6 +1745,215 @@ struct ConvBwdDataImplicitGemmOutTransform
}
};

Review discussion on this struct:

Contributor: Is there a way to avoid duplication with the struct above? AFAIK, these functions are only used in two places, so I would recommend always calling the MG variant with a default value (and ignoring the GStep return value) to avoid duplicating the rest of the logic.

Contributor Author: I did the unification for the generator helper function, which was low-hanging fruit. With the struct I am not sure it is a good idea. It can be generalized, but since the dimensions do not match it is slightly more complicated. For the time being I would leave it like this, since I do not see the long-term strategy for it. If this is the new method, a more general infrastructure will be needed to support all convolution variants; if that is not the case, having this dirty variant seems fine to me. @bartekxk, do you have an opinion on this?

/**
 * @brief Transformation struct for convolution backward data output indices to GEMM indices.
 *
 * This struct maps the output tensor indices (N, Ho, Wo, G, K) of the convolution
 * backward data operation to the corresponding indices (K0, M, K1) used in the
 * implicit GEMM computation. Compared to ConvBwdDataImplicitGemmOutTransform above,
 * this MG variant additionally folds NumGroupsToMerge groups into the GEMM M
 * dimension, which is why the lower index carries a G component. It encapsulates the
 * parameters and transformation logic required to perform the index conversion
 * efficiently.
 */
struct ConvBwdDataImplicitGemmOutTransformMG
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};

using LowerIndex = MultiIndex<5>; // N, Ho, Wo, G, K
using UpperIndex = MultiIndex<3>; // K0, M, K1

index_t N_, Ho_, Wo_, K_;
index_t XDot_;
index_t HTilde_, WTilde_;
index_t WTildeSlice_NumGroupsToMerge_, TildeSlice_NumGroupsToMerge_, NumGroupsToMerge_;
index_t IHTildeSliceBegin_, IWTildeSliceBegin_;
index_t HRatio_, WRatio_;
index_t XDotSlice_K_;
index_t MPad_, KPad_;
    Tuple<index_t, index_t, index_t> up_lengths_; // K0, MPadded, K1

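    // Precomputed on the host: magic multiplier/shift pairs let DoMagicDivision
    // replace each integer division in the per-element index math below with a
    // multiply-and-shift.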
Tuple<index_t, index_t, index_t, index_t, index_t> low_lengths_magic_divisor_multiplier_;
Tuple<index_t, index_t, index_t, index_t, index_t> low_lengths_magic_divisor_shift_;

__host__ __device__ ConvBwdDataImplicitGemmOutTransformMG() = default;

__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransformMG(
index_t N,
index_t Ho,
index_t Wo,
index_t K,
index_t XDot,
index_t HTilde,
index_t WTilde,
index_t NumGroupsToMerge,
index_t WTildeSlice_NumGroupsToMerge,
index_t HWTildeSlice_NumGroupsToMerge,
index_t IHTildeSliceBegin,
index_t IWTildeSliceBegin,
index_t HRatio,
index_t WRatio,
index_t XDotSlice_K,
index_t K0,
index_t MPadded,
index_t K1,
index_t MPad,
index_t KPad)
: N_{N},
Ho_{Ho},
Wo_{Wo},
K_{K},
XDot_{XDot},
HTilde_{HTilde},
WTilde_{WTilde},
WTildeSlice_NumGroupsToMerge_{WTildeSlice_NumGroupsToMerge},
TildeSlice_NumGroupsToMerge_{HWTildeSlice_NumGroupsToMerge},
          NumGroupsToMerge_{NumGroupsToMerge},
IHTildeSliceBegin_{IHTildeSliceBegin},
IWTildeSliceBegin_{IWTildeSliceBegin},
HRatio_{HRatio},
WRatio_{WRatio},
XDotSlice_K_{XDotSlice_K},
MPad_{MPad},
KPad_{KPad},
up_lengths_{make_tuple(K0, MPadded, K1)},
low_lengths_magic_divisor_multiplier_{
MagicDivision::CalculateMagicMultiplier(XDotSlice_K_),
MagicDivision::CalculateMagicMultiplier(K_),
MagicDivision::CalculateMagicMultiplier(TildeSlice_NumGroupsToMerge_),
MagicDivision::CalculateMagicMultiplier(WTildeSlice_NumGroupsToMerge_),
MagicDivision::CalculateMagicMultiplier(NumGroupsToMerge_)},
low_lengths_magic_divisor_shift_{
MagicDivision::CalculateMagicShift(XDotSlice_K_),
MagicDivision::CalculateMagicShift(K_),
MagicDivision::CalculateMagicShift(TildeSlice_NumGroupsToMerge_),
MagicDivision::CalculateMagicShift(WTildeSlice_NumGroupsToMerge_),
MagicDivision::CalculateMagicShift(NumGroupsToMerge_)}
{
}

__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 5; }

__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; }

__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

template <typename UpIdx>
__host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const
{
index_t NStep{0}, GStep{0}, HStep{0}, WStep{0};
// Merge
// NStep = M_id / (TildeSlice * NumGroupsToMerge)
NStep = MagicDivision::DoMagicDivision(idx_up[I1],
this->low_lengths_magic_divisor_multiplier_[I2],
this->low_lengths_magic_divisor_shift_[I2]);

HStep = idx_up[I1] - NStep * TildeSlice_NumGroupsToMerge_;
// HStep = HStep / (WTildeSlice * NumGroupsToMerge)
HStep = MagicDivision::DoMagicDivision(HStep,
this->low_lengths_magic_divisor_multiplier_[I3],
this->low_lengths_magic_divisor_shift_[I3]);

WStep = idx_up[I1] - NStep * TildeSlice_NumGroupsToMerge_ -
HStep * WTildeSlice_NumGroupsToMerge_;
// WStep = WStep / NumGroupsToMerge
WStep = MagicDivision::DoMagicDivision(WStep,
this->low_lengths_magic_divisor_multiplier_[I4],
this->low_lengths_magic_divisor_shift_[I4]);
GStep = idx_up[I1] - NStep * TildeSlice_NumGroupsToMerge_ -
HStep * WTildeSlice_NumGroupsToMerge_ - WStep * NumGroupsToMerge_;

// Slice
HStep += IHTildeSliceBegin_;
WStep += IWTildeSliceBegin_;

return make_tuple(NStep, HStep, WStep, GStep, 0);
}

template <typename UpIdx>
__host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const
{
// UnMerge
// K_idx <- K0_idx * K1 + K1_idx
index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
// Merge
// YStep = K_idx / (XDotSlice * K)
index_t YStep =
MagicDivision::DoMagicDivision(K_idx,
this->low_lengths_magic_divisor_multiplier_[I0],
this->low_lengths_magic_divisor_shift_[I0]);

index_t KStep = K_idx - YStep * XDotSlice_K_;
        // XStep = KStep / K
index_t XStep =
MagicDivision::DoMagicDivision(KStep,
this->low_lengths_magic_divisor_multiplier_[I1],
this->low_lengths_magic_divisor_shift_[I1]);
KStep -= XStep * K_;

// Embed
YStep *= HRatio_;
XStep *= WRatio_;

return make_tuple(0, YStep, XStep, 0, KStep);
}

template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
}

template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& /* idx_diff_up */,
LowIdx& idx_low,
const UpIdx& idx_up,
Number<Hack>) const
{
LowIdx low_old = idx_low;
idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
idx_diff_low = idx_low - low_old;
}

__host__ __device__ static constexpr bool IsLinearTransform() { return false; }

__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}

template <typename UpIdx>
__host__ __device__ constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
{
// Padding
        index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
        const index_t M_idx = idx_up[I1];

        bool pad_valid = M_idx < up_lengths_[I1] - MPad_ &&
                         K_idx < up_lengths_[I0] * up_lengths_[I2] - KPad_;
return pad_valid;
}

__host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; }

__host__ __device__ void Print() const
{
printf("{");
printf("ConvBwdDataImplicitGemmOutTransformMG, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("}");
}
};
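As a reading aid, the following is a plain-division reference of the two decompositions implemented above with MagicDivision. This is a sketch for illustration only, not part of the PR; the function and parameter names are hypothetical and merely mirror the struct members.

#include <tuple>

// Reference version of CalculateLowerIndexN: M_id -> (N, HTilde, WTilde, G),
// with the slice begin offsets applied to the H and W coordinates.
std::tuple<int, int, int, int> decompose_m(int m_id,
                                           int tilde_slice_g,  // HTildeSlice * WTildeSlice * NumGroupsToMerge
                                           int wtilde_slice_g, // WTildeSlice * NumGroupsToMerge
                                           int num_groups,     // NumGroupsToMerge
                                           int ih_begin,       // IHTildeSliceBegin
                                           int iw_begin)       // IWTildeSliceBegin
{
    const int n = m_id / tilde_slice_g;
    int rem     = m_id % tilde_slice_g;
    const int h = rem / wtilde_slice_g;
    rem         = rem % wtilde_slice_g;
    const int w = rem / num_groups;
    const int g = rem % num_groups;
    return {n, h + ih_begin, w + iw_begin, g};
}

// Reference version of CalculateLowerIndexK: (K0, K1) -> (YDot, XDot, K), with
// the YDot/XDot steps embedded into the Ho/Wo coordinates via the ratios.
std::tuple<int, int, int> decompose_k(int k0_id, int k1_id, int K1,
                                      int xdotslice_k, // XDotSlice * K
                                      int K,
                                      int h_ratio,     // -ConvDilationH / GcdStrideDilationH
                                      int w_ratio)     // -ConvDilationW / GcdStrideDilationW
{
    const int k_idx = k0_id * K1 + k1_id; // un-merge (K0, K1)
    const int y     = k_idx / xdotslice_k;
    const int rem   = k_idx % xdotslice_k;
    const int x     = rem / K;
    const int k     = rem % K;
    return {y * h_ratio, x * w_ratio, k};
}

CalculateLowerIndex sums the two contributions, giving the lower index (n, (h + ih_begin) + y*h_ratio, (w + iw_begin) + x*w_ratio, g, k) in (N, Ho, Wo, G, K) coordinates.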

template <typename LowerIndex>
struct Freeze
{
67 changes: 47 additions & 20 deletions include/ck/tensor_description/multi_index_transform_helper.hpp
@@ -94,6 +94,7 @@ __host__ __device__ constexpr auto make_unmerge_transform(
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}

+template <index_t NumGroupsToMerge = 1>
__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
index_t Ho,
index_t Wo,
@@ -118,33 +119,59 @@ __host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
index_t GemmKPerBlock)
{
// Calculate padding
-    const auto MRaw = N * HTildeSlice * WTildeSlice;
+    const auto MRaw = NumGroupsToMerge * N * HTildeSlice * WTildeSlice;
const auto MPadded = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = MPadded - MRaw;

const auto KRaw = YDotSlice * XDotSlice * K;
const auto KPadded = math::integer_divide_ceil(KRaw, GemmKPerBlock) * GemmKPerBlock;
const auto KPad = KPadded - KRaw;

-    return ConvBwdDataImplicitGemmOutTransform{N,
-                                               Ho,
-                                               Wo,
-                                               K,
-                                               XDot,
-                                               HTilde,
-                                               WTilde,
-                                               WTildeSlice,
-                                               HTildeSlice * WTildeSlice,
-                                               IHTildeSliceBegin,
-                                               IWTildeSliceBegin,
-                                               -ConvDilationH / GcdStrideDilationH,
-                                               -ConvDilationW / GcdStrideDilationW,
-                                               XDotSlice * K,
-                                               K0,
-                                               MPadded,
-                                               K1,
-                                               MPad,
-                                               KPad};
+    if constexpr(NumGroupsToMerge == 1)
+    {
+        return ConvBwdDataImplicitGemmOutTransform{N,
+                                                   Ho,
+                                                   Wo,
+                                                   K,
+                                                   XDot,
+                                                   HTilde,
+                                                   WTilde,
+                                                   WTildeSlice,
+                                                   HTildeSlice * WTildeSlice,
+                                                   IHTildeSliceBegin,
+                                                   IWTildeSliceBegin,
+                                                   -ConvDilationH / GcdStrideDilationH,
+                                                   -ConvDilationW / GcdStrideDilationW,
+                                                   XDotSlice * K,
+                                                   K0,
+                                                   MPadded,
+                                                   K1,
+                                                   MPad,
+                                                   KPad};
+    }
+    else
+    {
+        return ConvBwdDataImplicitGemmOutTransformMG{N,
+                                                     Ho,
+                                                     Wo,
+                                                     K,
+                                                     XDot,
+                                                     HTilde,
+                                                     WTilde,
+                                                     NumGroupsToMerge,
+                                                     WTildeSlice * NumGroupsToMerge,
+                                                     HTildeSlice * WTildeSlice * NumGroupsToMerge,
+                                                     IHTildeSliceBegin,
+                                                     IWTildeSliceBegin,
+                                                     -ConvDilationH / GcdStrideDilationH,
+                                                     -ConvDilationW / GcdStrideDilationW,
+                                                     XDotSlice * K,
+                                                     K0,
+                                                     MPadded,
+                                                     K1,
+                                                     MPad,
+                                                     KPad};
+    }
}
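A note on the dispatch above: because the two branches return different types, the selection must happen at compile time via if constexpr on the NumGroupsToMerge template parameter rather than a runtime branch. Hypothetical call sites follow (the full argument list is collapsed in this diff, so it is elided as /* ... */):

// Default template argument keeps existing callers unchanged:
//   auto t  = make_conv_bwd_data_out_transform(/* ... */);     // NumGroupsToMerge = 1
//                                                              // -> ConvBwdDataImplicitGemmOutTransform
// Grouped-merge callers opt in explicitly:
//   auto tg = make_conv_bwd_data_out_transform<4>(/* ... */);  // -> ConvBwdDataImplicitGemmOutTransformMG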

template <typename LowerIndex>