diff --git a/CHANGELOG.md b/CHANGELOG.md index c99fc1d0657..3f248f6d650 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## (Unreleased) Composable Kernel 1.3.0 ### Added +* Added overload of load_tile_transpose that takes reference to output tensor as output parameter +* Use data type from LDS tensor view when determining tile distribution for transpose in the GEMM pipeline * Added preshuffleB support for abquant mode in blockscale GEMM. * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index af0f81e832c..d1c06d43780 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -48,19 +48,19 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, * and an elementwise function. For each A = A0, A1… AN, the elementwise function * is additionally applied during a single read. */ -template -CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, +CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) { - // TODO: Tile windows should works with unknow number of params - // Load element_wise API works only when the input typle is a tuple-tyupe - return tile_window[number<0>{}].load( - tile_window, elementwise, number{}, bool_constant{}); + // TODO: Tile windows should work with unknown number of params + // Load element_wise API works only when the input type is a tuple-type + return tile_windows[number<0>{}].load( + tile_windows, elementwise, number{}, bool_constant{}); } // Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. @@ -85,12 +85,12 @@ template -CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, +CK_TILE_DEVICE void load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, number = {}, bool_constant = {}) { - return tile_window.load(dst_tile, number{}, bool_constant{}); + tile_window.load(dst_tile, number{}, bool_constant{}); } /** @@ -131,7 +131,7 @@ template -CK_TILE_DEVICE auto load_tile_raw(T& tile, +CK_TILE_DEVICE void load_tile_raw(T& tile, const tile_window_linear::distr_encoding_valid, Policy>> -CK_TILE_DEVICE auto load_tile_transpose_with_offset( +CK_TILE_DEVICE void load_tile_transpose_with_offset( + DistributedTensor_& out_tensor, const tile_window_with_static_distribution& __restrict__ tile_window, index_t offset) { - using OutTileDstrEncode = typename OutputTileDistributionTraits< - typename TileDistribution_::DstrEncode, - typename BottomTensorView_::DataType>::TransposedDstrEncode; - auto out_tensor = make_static_distributed_tensor( - make_static_tile_distribution(OutTileDstrEncode{})); auto trans_tensor = tile_window.template load_transpose_with_offset(offset); constexpr auto input_distr = TileDistribution_{}; - constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{}); + constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{}; constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor(); constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor(); @@ -442,8 +440,6 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( number{}, trans_tensor.get_thread_buffer().template get_as(number{})); }); - - return out_tensor; } /** @@ -455,6 +451,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( * element space size and vector length remain consistent between the input and output * distributions. * + * @tparam DistributedTensor_ The type of the tensor containing the transposed tile data. * @tparam BottomTensorView_ The type of the bottom tensor view. * @tparam WindowLengths_ The type representing the window lengths. * @tparam TileDistribution_ The type representing the tile distribution. @@ -462,16 +459,37 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( * @tparam Policy The transpose policy to use (defaults to DefaultTranspose). * the last is SFINAE to ensure the tile distribution encoding is valid. * + * @param out_tensor A statically distributed tensor containing the transposed tile + * data. * @param tile_window The tile window with static distribution to load and transpose. * indexing. * - * @return A statically distributed tensor containing the transposed tile data. - * * @note * - The function uses compile-time checks to ensure the input and output tile distributions * are compatible in terms of element space size and vector length. * - The transpose operation is performed according to the specified Policy. */ +template < + typename DistributedTensor_, + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void +load_tile_transpose(DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window) +{ + load_tile_transpose_with_offset(out_tensor, tile_window, 0); +} + template < typename BottomTensorView_, typename WindowLengths_, @@ -488,7 +506,15 @@ load_tile_transpose(const tile_window_with_static_distribution& __restrict__ tile_window) { - return load_tile_transpose_with_offset(tile_window, 0); + using OutTileDstrEncode = typename OutputTileDistributionTraits< + typename TileDistribution_::DstrEncode, + typename BottomTensorView_::DataType>::TransposedDstrEncode; + auto out_tensor = make_static_distributed_tensor( + make_static_tile_distribution(OutTileDstrEncode{})); + + load_tile_transpose_with_offset(out_tensor, tile_window, 0); + + return out_tensor; } } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index d39da82a627..da90675fdd4 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -182,11 +182,11 @@ struct tile_window_with_static_distribution * The same thread, during vectorized reading, accesses the same set of * data from A0, A1, A2, … AN. */ - template - CK_TILE_DEVICE auto load(const TileWindow_& tile_window, + CK_TILE_DEVICE auto load(const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) const @@ -194,7 +194,7 @@ struct tile_window_with_static_distribution constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); load(dst_tensor, - tile_window, + tile_windows, elementwise, number{}, bool_constant{}); @@ -202,12 +202,12 @@ struct tile_window_with_static_distribution } template CK_TILE_DEVICE void load(DistributedTensor& dst_tensor, - const TileWindow_& tile_window, + const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) const @@ -218,14 +218,14 @@ struct tile_window_with_static_distribution using SFC_Ys = typename Traits::SFC_Ys; constexpr auto tile_dstr = typename Base::TileDstr{}; - constexpr auto sizeOfTuple = TileWindow_::size(); + constexpr auto sizeOfTuple = remove_cvref_t::size(); // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { /// TODO: use structure binding (to be captured later) if compiled in C++20 auto window_adaptor_thread_coord = - tile_window[number<0>{}].pre_computed_coords_[iCoord][I0]; + tile_windows[number<0>{}].pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = - tile_window[number<0>{}].pre_computed_coords_[iCoord][I1]; + tile_windows[number<0>{}].pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { constexpr auto iAccess = number{}; @@ -236,7 +236,7 @@ struct tile_window_with_static_distribution // read from bottom tensor const auto idx_vec_value = generate_tuple( [&](auto jj) { - return tile_window[number{}] + return tile_windows[number{}] .get_bottom_tensor_view() .template get_vectorized_elements( bottom_tensor_thread_coord, diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 00234b20cf9..aa0f632c216 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -8,7 +8,7 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_contraction.hpp b/include/ck_tile/ops/batched_contraction.hpp index 45fa52e5051..9c90db67edd 100644 --- a/include/ck_tile/ops/batched_contraction.hpp +++ b/include/ck_tile/ops/batched_contraction.hpp @@ -6,7 +6,7 @@ #include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" #include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index b23e45c2331..9cac035c445 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -11,7 +11,7 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 94243e674f5..ad7da5c1833 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -3,7 +3,7 @@ #pragma once #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp new file mode 100644 index 00000000000..ada3d93b745 --- /dev/null +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -0,0 +1,56 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +namespace ck_tile { + +template +struct ConverterLoader +{ + template + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src) + { + static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto tmp = load_tile(src); + + // NOTE: we rely on types packing neatly here + using RawSrcType = typename WarpWindow::Base::DataType::type; + constexpr auto PackedSize = numeric_traits::PackedSize; + + using SrcVectorType = ext_vector_t; + using DstVectorType = ext_vector_t; + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + const element_wise::PassThroughPack8 elementwise_op{}; + + elementwise_op(dst.get_thread_buffer().template get_as()(i), + tmp.get_thread_buffer().template get_as()[i]); + }); + } +}; + +template +CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src) +{ + if constexpr(is_packed_type_v) + { + static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); + ConverterLoader::load_interleaved_pk_type(dst, + src); + } + else if constexpr(LoadTranspose) + { + load_tile_transpose(dst, src); + } + else + { + load_tile(dst, src); + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp deleted file mode 100644 index 3f1a3b8f1cd..00000000000 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core/config.hpp" -#include "ck_tile/core/utility/type_traits.hpp" -#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" - -namespace ck_tile { - -template -struct InterleavedPKTypeLoader -{ - template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) - { - const element_wise::PassThroughPack8 elementwise_op{}; - - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); - constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); - - // NOTE: we rely on types packing neatly here - using RawSrcType = typename SrcDataType::type; - constexpr auto PackedSize = numeric_traits::PackedSize; - - using SrcVectorType = ext_vector_t; - using DstVectorType = ext_vector_t; - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); - }); - } -}; - -template -CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) -{ - if constexpr(is_packed_type_v) - { - static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); - InterleavedPKTypeLoader::load_interleaved_pk_type( - dst, src); - } - else if constexpr(LoadTranspose) - { - dst = load_tile_transpose(src); - } - else - { - load_tile(dst, src); - } -} - -} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 5752703ab60..bc72f3b0ba1 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -9,7 +9,7 @@ #include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 433462b22e2..d1b38a8bca6 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -11,7 +11,7 @@ #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 2d3a819e804..e08fac48c7e 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -22,7 +22,7 @@ #include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index eb4aa16d054..0639fa1b36e 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -61,7 +61,7 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 9aeabaa8c22..16212c0d130 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -530,7 +530,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR s_acc = gemm_0(q_reg_tensor, k_reg_tensor); dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr); - dot_reg_tensor = load_tile_transpose(dot_lds_read_window); + load_tile_transpose(dot_reg_tensor, dot_lds_read_window); } if constexpr(is_epilogue) { @@ -634,7 +634,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr); - qt_reg_tensor = load_tile_transpose(qt_lds_read_window); + load_tile_transpose(qt_reg_tensor, qt_lds_read_window); // STAGE 3, P^T@OGrad^T Gemm1 auto pt_reg_tensor = make_static_distributed_tensor( @@ -715,7 +715,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR } if constexpr(is_epilogue) { - ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } if constexpr(is_main_body) @@ -728,7 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR static_for<0, k4_loops, 1>{}([&](auto i_k4) { if constexpr(i_k4 < k4_loops - 1) { - ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } auto kt_reg_tensor_slice = get_slice_tile( // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 3d21928cedf..37b4ae41a3f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -455,10 +455,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR async_load_tile(q_lds_write_window, q_dram_window); async_load_tile(do_lds_write_window, do_dram_window); __builtin_amdgcn_s_waitcnt(0); - qt_reg_tensor = load_tile_transpose(qt_lds_read_window); - q_reg_tensor = load_tile(q_lds_read_window); - dot_reg_tensor = load_tile_transpose(dot_lds_read_window); - do_reg_tensor = load_tile(do_lds_read_window); + load_tile_transpose(qt_reg_tensor, qt_lds_read_window); + q_reg_tensor = load_tile(q_lds_read_window); + load_tile_transpose(dot_reg_tensor, dot_lds_read_window); + do_reg_tensor = load_tile(do_lds_read_window); lse_block_tile = load_tile(lse_dram_window); d_block_tile = load_tile(d_dram_window); @@ -490,9 +490,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR async_load_tile(v_lds_write_window, v_dram_window); move_tile_window(v_dram_window, {kN0, 0}); s_waitcnt(); - k_reg_tensor = load_tile(k_lds_read_window); - v_reg_tensor = load_tile(v_lds_read_window); - kt_reg_tensor = load_tile_transpose(kt_lds_read_window); + k_reg_tensor = load_tile(k_lds_read_window); + v_reg_tensor = load_tile(v_lds_read_window); + load_tile_transpose(kt_reg_tensor, kt_lds_read_window); } if constexpr(is_epilogue) { @@ -668,7 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR block_sync_lds(); if constexpr(is_epilogue) { - ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } if constexpr(is_main_body) @@ -680,7 +680,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR static_for<0, k4_loops, 1>{}([&](auto i_k4) { if constexpr(i_k4 < k4_loops - 1) { - ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } auto kt_reg_tensor_slice = get_slice_tile( // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index c25f57632fa..4cca604ff15 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -718,7 +718,7 @@ struct BlockFmhaFwdV3Pipeline }; auto V_lds_load = [&](auto v_lds_read_idx) { - kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx)); + load_tile_transpose(kv_tile.v_tile, v_lds_window_load(v_lds_read_idx)); }; decltype(m) m_old; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index aab79c52ae9..6bf6d2b5033 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -591,7 +591,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload // loop over along the [V]alue Sequence length move_tile_window(v_lds_read_window, {kK1, 0}); - v_tile = load_tile_transpose(v_lds_read_window); + load_tile_transpose(v_tile, v_lds_read_window); }); // move back to the origin move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index e6802e82dce..60f5bd1c4e3 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -15,7 +15,7 @@ #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 2c3a1611216..8dbf111048e 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -77,7 +77,7 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 79030fcd513..6fb5cf433b1 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" @@ -217,10 +217,8 @@ struct BlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_int4_tile(a_warp_tile_, - a_block_window); - load_int4_tile(b_warp_tile_, - b_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B @@ -289,9 +287,9 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; static constexpr auto ALdsTileDistr = - make_static_tile_distribution(MakeABlockDistributionEncode()); + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; static constexpr auto BLdsTileDistr = - make_static_tile_distribution(MakeBBlockDistributionEncode()); + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); @@ -348,10 +346,8 @@ struct BlockUniversalGemmAsBsCr auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); - load_int4_tile(a_warp_tile_, - a_lds_gemm_window); - load_int4_tile(b_warp_tile_, - b_lds_gemm_window); + load_and_convert_tile(a_warp_tile_, a_lds_gemm_window); + load_and_convert_tile(b_warp_tile_, b_lds_gemm_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 4973d9c9410..358101d1db1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -64,9 +64,7 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - template @@ -74,7 +72,7 @@ struct GemmPipelineAgBgCrImplBase SrcTileWindow& dram_tile_window, const DramTileWindowStep& dram_tile_window_step) const { - load_int4_tile(dst_block_tile, dram_tile_window); + load_and_convert_tile(dst_block_tile, dram_tile_window); move_tile_window(dram_tile_window, dram_tile_window_step); } @@ -109,7 +107,7 @@ struct GemmPipelineAgBgCrImplBase bool_constant = {}) const { if constexpr(LoadTranspose) - dst_block_tile = load_tile_transpose(lds_tile_window); + load_tile_transpose(dst_block_tile, lds_tile_window); else load_tile(dst_block_tile, lds_tile_window); } @@ -237,12 +235,16 @@ struct GemmPipelineAgBgCrImplBase auto a_lds_load_tile_distr = []() { if constexpr(is_a_load_tr) + { return make_static_tile_distribution( typename InputTileDistributionTraits< typename ALdsLoadTileDistr::DstrEncode, - typename Problem::ADataType>::TransposedDstrEncode{}); + typename ALdsTensorView::DataType>::TransposedDstrEncode{}); + } else + { return ALdsLoadTileDistr{}; + } }(); auto a_lds_gemm_window = @@ -313,19 +315,18 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); - using BLdsDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) + { return make_static_tile_distribution( - typename InputTileDistributionTraits::TransposedDstrEncode{}); - + typename InputTileDistributionTraits< + typename BLdsLoadTileDistr::DstrEncode, + typename BLdsTensorView::DataType>::TransposedDstrEncode{}); + } else + { return BLdsLoadTileDistr{}; + } }(); auto b_lds_gemm_window = diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 8074994fdd3..fae37010492 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -112,7 +112,6 @@ struct UniversalGemmBasePolicy using ADataType = OverrideADataType; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = Derived::template GetSmemPackA(); if constexpr(is_a_load_tr) { @@ -246,6 +245,7 @@ struct UniversalGemmBasePolicy } else // A is in RowMajor { + constexpr index_t KPack = Derived::template GetSmemPackA(); constexpr auto DataTypeSize = sizeof(ADataType); constexpr uint64_t MinLdsLayer = 1ULL; constexpr auto MLdsLayer = diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index c9499106de7..93999757b07 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" @@ -627,8 +627,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // // Prefetch A0 Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch( - b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); // Prefill A0 Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); @@ -652,7 +651,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 do { { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); Base::GlobalPrefetch( @@ -666,7 +665,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 HotLoopScheduler(); } { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); Base::GlobalPrefetch( @@ -687,7 +686,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 if constexpr(TailNum == TailNumber::Even) { { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); block_weight_preshuffle( diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 696de378aaf..6aee73cda1d 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -31,7 +31,7 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index 1d9512b7f78..cd5a98de35c 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -253,7 +253,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - load_int4_tile( + load_and_convert_tile( a_warp_tensor(number{}), a_warp_windows(number{})(number{})); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index d79bd31489a..9cefee5d7db 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -268,10 +268,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}) { // If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS - load_int4_tile( - a_warp_tile_, a_block_window); - load_int4_tile( - b_warp_tile_, b_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 22563da498e..8b09530af15 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -248,10 +248,8 @@ struct AQuantBlockUniversalGemmAsBsCr // while ADatatype might not be the same as BDataType at the time of problem // initialization, we can safely use BDataType here because when A would be int4 we will // ensure A is converted to BDataType prior to loading - load_int4_tile( - a_warp_tile_, a_block_window); - load_int4_tile( - b_warp_tile_, b_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B @@ -395,10 +393,8 @@ struct AQuantBlockUniversalGemmAsBsCr auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); - load_int4_tile( - a_warp_tile_, a_lds_gemm_window); - load_int4_tile( - b_warp_tile_, b_lds_gemm_window); + load_and_convert_tile(a_warp_tile_, a_lds_gemm_window); + load_and_convert_tile(b_warp_tile_, b_lds_gemm_window); } // C += A * B with quantization support diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 9d711c48623..1f02e24f4ee 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -258,11 +258,9 @@ struct BQuantBlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_int4_tile( - a_warp_tile_, a_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); // If B datatype were pkint4 it would be converted prior to storing in LDS - load_int4_tile( - b_warp_tile_, b_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index cfd12313e8b..e7d02536734 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -202,20 +202,16 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); } template CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile, const BDramWindow& b_dram_window) { - using DestDataType = typename BBlockTile_::DataType; - using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType; constexpr index_t UnaryOpSize = 8; - load_int4_tile(b_block_tile, b_dram_window); + load_and_convert_tile(b_block_tile, b_dram_window); } template ADramWindow& a_dram_window, const DramTileWindowStep& dram_tile_window_step) { - using DestDataType = typename ABlockTile_::DataType; - using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; constexpr index_t UnaryOpSize = 8; - load_int4_tile(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); move_tile_window(a_dram_window, dram_tile_window_step); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 76d8985fb15..cedc91d5641 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -174,10 +174,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); move_tile_window(a_dram_window, dram_tile_window_step); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index df94eb72731..6eb2540a3b9 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -40,10 +40,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 && - std::is_same_v, - ADataType, - BDataType>; + std::conditional_t, ADataType, BDataType>; static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); using I0 = number<0>; @@ -185,10 +182,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_dram_window); + load_and_convert_tile(b_block_tile, b_dram_window); } template diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index 051b71e2c33..1a77b6a0b4e 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -7,7 +7,7 @@ #include #include "ck_tile/core.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" @@ -361,8 +361,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); // move B window to next flat K @@ -400,8 +400,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe block_sync_lds(); // preload A00,A10 from lds - using ATypeToUse = - mixed_prec_compute_type_from_input_t; using ATileType = decltype(make_static_distributed_tensor(a_warp_tile_distribution)); statically_indexed_array a_warp_tensor; @@ -409,7 +407,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - load_int4_tile( + load_and_convert_tile( a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); @@ -443,8 +441,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -459,7 +457,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - load_int4_tile( + load_and_convert_tile( a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); @@ -472,8 +470,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -504,7 +502,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - load_int4_tile( + load_and_convert_tile( a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); iCounter--; @@ -522,8 +520,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); aq_block_tile_2 = load_tile(aq_copy_dram_window); @@ -544,7 +542,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - load_int4_tile( + load_and_convert_tile( a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index a49279585e8..025ef53dbb0 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -344,8 +344,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); // move B window to next flat K @@ -430,8 +430,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -467,8 +467,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -525,8 +525,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); bq_block_tile_2 = load_tile(bq_copy_dram_window); diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 6743e466131..eeb9b1d8a81 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -12,7 +12,7 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 1d33ebf39d8..07d99890869 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -6,7 +6,7 @@ #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index ebb20aebf47..8f9ab205ac4 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -9,7 +9,7 @@ #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index 469a98c256e..eae0ea14a33 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -6,7 +6,7 @@ #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index 88a3d8a137e..4d37f4fbc12 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -5,7 +5,7 @@ #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/pooling.hpp b/include/ck_tile/ops/pooling.hpp index 3e44122afab..faa77d53273 100644 --- a/include/ck_tile/ops/pooling.hpp +++ b/include/ck_tile/ops/pooling.hpp @@ -7,7 +7,7 @@ #include "ck_tile/ops/pooling/pipeline/pool_problem.hpp" #include "ck_tile/ops/pooling/pipeline/pool_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index 9e31b7bbe26..b5e53283e48 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -14,7 +14,7 @@ #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index ad23a708b79..f271be50068 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -10,7 +10,7 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index 13372f32899..4c2fe9bee43 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -9,7 +9,7 @@ #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index 9cf3e08319f..c79ba06abfe 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -5,7 +5,7 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/sparse_attn.hpp b/include/ck_tile/ops/sparse_attn.hpp index 3ee643d7299..c7c4171874a 100644 --- a/include/ck_tile/ops/sparse_attn.hpp +++ b/include/ck_tile/ops/sparse_attn.hpp @@ -7,7 +7,7 @@ #include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" #include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 090ad0919f5..474ba932270 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -5,7 +5,7 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index 7afce1708b4..066fbf5feea 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -7,7 +7,7 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp"