diff --git a/test/ck_tile/epilogue/CMakeLists.txt b/test/ck_tile/epilogue/CMakeLists.txt index 2b3ffe33cc8..b408d79509c 100644 --- a/test/ck_tile/epilogue/CMakeLists.txt +++ b/test/ck_tile/epilogue/CMakeLists.txt @@ -1,4 +1,17 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -add_gtest_executable(test_ck_tile_cshuffle_epilogue test_cshuffle_epilogue.cpp) +add_gtest_executable(test_ck_tile_cshuffle_epilogue_fp16 test_cshuffle_epilogue_fp16.cpp) +add_gtest_executable(test_ck_tile_cshuffle_epilogue_fp8 test_cshuffle_epilogue_fp8.cpp) +add_gtest_executable(test_ck_tile_cshuffle_epilogue_scale test_cshuffle_epilogue_scale.cpp) + +if(CK_USE_OCP_FP8) + target_compile_options(test_ck_tile_cshuffle_epilogue_fp8 PRIVATE -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx950") + add_gtest_executable(test_ck_tile_cshuffle_epilogue_fp8_gfx950 test_cshuffle_epilogue_fp8_gfx950.cpp) + if(CK_USE_OCP_FP8) + target_compile_options(test_ck_tile_cshuffle_epilogue_fp8_gfx950 PRIVATE -DCK_TILE_USE_OCP_FP8) + endif() +endif() diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue.cpp b/test/ck_tile/epilogue/test_cshuffle_epilogue.cpp deleted file mode 100644 index 9fbe883e320..00000000000 --- a/test/ck_tile/epilogue/test_cshuffle_epilogue.cpp +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_cshuffle_epilogue_util.hpp" -#include -#include - -using namespace ck_tile; - -class CShuffleEpilogueTest : public ::testing::Test -{ - protected: - void SetUp() override {} -}; - -TEST_F(CShuffleEpilogueTest, BasicHalfTest) -{ - // Basic test configuration with half_t data types - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using ODataType = ck_tile::half_t; - - constexpr index_t kMPerBlock = 256; - constexpr index_t kNPerBlock = 256; - constexpr index_t MWave = 2; - constexpr index_t NWave = 2; - constexpr index_t MPerXdl = 32; - constexpr index_t NPerXdl = 32; - constexpr index_t KPerXdl = 8; - - using TestProblem = SimpleCShuffleEpilogueProblem; - - auto result = run_cshuffle_epilogue_test(ScaleType::None); - EXPECT_FLOAT_EQ(result[0], 2.0F) << "Basic CShuffleEpilogue test failed"; -} - -TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale) -{ - // Basic test configuration with half_t data types - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using ODataType = ck_tile::half_t; - - constexpr index_t kMPerBlock = 256; - constexpr index_t kNPerBlock = 256; - constexpr index_t MWave = 2; - constexpr index_t NWave = 2; - constexpr index_t MPerXdl = 32; - constexpr index_t NPerXdl = 32; - constexpr index_t KPerXdl = 8; - - using TestProblem = SimpleCShuffleEpilogueProblem; - - auto result = - run_cshuffle_epilogue_test(ScaleType::RowCol); - EXPECT_FLOAT_EQ(result[0], 2.0F) << "RowCol CShuffleEpilogue test failed: first element not 2"; - EXPECT_FLOAT_EQ(result[1], 4.0F) - << "RowCol CShuffleEpilogue test failed: second element not 2*2"; -} - -TEST_F(CShuffleEpilogueTest, BasicHalfTestWithTensorScale) -{ - // Basic test configuration with half_t data types - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using ODataType = ck_tile::half_t; - - constexpr index_t kMPerBlock = 256; - constexpr index_t kNPerBlock = 256; - constexpr index_t MWave = 2; - constexpr index_t NWave = 2; - constexpr index_t MPerXdl = 32; - constexpr index_t NPerXdl = 32; - constexpr index_t KPerXdl = 8; - - using TestProblem = SimpleCShuffleEpilogueProblem; - - auto result = - run_cshuffle_epilogue_test(ScaleType::Tensor); - EXPECT_FLOAT_EQ(result[0], 4.0F) - << "TensorScale CShuffleEpilogue test failed: first element not 2*2=4"; -} - -int main(int argc, char** argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_common.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_common.hpp new file mode 100644 index 00000000000..75bc60c7509 --- /dev/null +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_common.hpp @@ -0,0 +1,175 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "test_cshuffle_epilogue_util.hpp" +#include +#include +#include +#include +#include + +// Test configuration template for parameterized tests +// MfmaDataType is used for MFMA instruction selection (determines valid KPerXdl values) +// ODataType is the output data type +template +struct TileConfig +{ + using DataType = ODataType_; + using MfmaDataType = MfmaDataType_; + static constexpr ck_tile::index_t kMPerBlock = MPerBlock_; + static constexpr ck_tile::index_t kNPerBlock = NPerBlock_; + static constexpr ck_tile::index_t MWave = MWave_; + static constexpr ck_tile::index_t NWave = NWave_; + static constexpr ck_tile::index_t MPerXdl = MPerXdl_; + static constexpr ck_tile::index_t NPerXdl = NPerXdl_; + static constexpr ck_tile::index_t KPerXdl = KPerXdl_; +}; + +// Helper to construct SimpleCShuffleEpilogueProblem from TileConfig +// Uses MfmaDataType for MFMA input types (A/B) and DataType for output +template +using MakeProblem = ck_tile::SimpleCShuffleEpilogueProblem; + +// Verification helper: check that output contains valid data from the epilogue shuffle +// The C-shuffle epilogue broadcasts thread-local values to multiple output locations, +// so we verify: no NaN/zeros, reasonable value range, and at least kBlockSize unique values +// (since each thread generates unique values) +template +void verify_permutation_output(const std::vector& sorted_vals) +{ + constexpr size_t expected_size = static_cast(kMPerBlock * kNPerBlock); + + // Verify output size matches expected + ASSERT_EQ(sorted_vals.size(), expected_size) << "CShuffleEpilogue output size mismatch"; + + // Verify no NaN values + for(size_t i = 0; i < sorted_vals.size(); ++i) + { + ASSERT_FALSE(std::isnan(sorted_vals[i])) + << "CShuffleEpilogue output contains NaN at index " << i; + } + + // Verify all values are positive (no zeros from unwritten memory) + EXPECT_GT(sorted_vals.front(), 0.0f) << "CShuffleEpilogue output contains zero values"; + + // Count unique values and track occurrence counts for uniformity check + std::vector occurrence_counts; + size_t current_count = 1; + for(size_t i = 1; i < sorted_vals.size(); ++i) + { + if(std::abs(sorted_vals[i] - sorted_vals[i - 1]) > ck_tile::verification::kScaleEpsilon) + { + occurrence_counts.push_back(current_count); + current_count = 1; + } + else + { + ++current_count; + } + } + occurrence_counts.push_back(current_count); // Don't forget the last value + + const size_t num_unique = occurrence_counts.size(); + + // Each thread generates unique values, so we expect at least kBlockSize unique values + // This verifies that all threads contributed to the output + EXPECT_GE(num_unique, static_cast(kBlockSize)) + << "CShuffleEpilogue output has fewer unique values (" << num_unique + << ") than threads per block (" << kBlockSize << ")"; + + // Check if distribution is uniform (all values appear same number of times) + const size_t first_count = occurrence_counts[0]; + bool is_uniform = true; + size_t min_count = first_count; + size_t max_count = first_count; + + for(size_t count : occurrence_counts) + { + if(count != first_count) + { + is_uniform = false; + } + min_count = std::min(min_count, count); + max_count = std::max(max_count, count); + } + + if(is_uniform) + { + // Uniform distribution: verify exact counts + const size_t expected_count = expected_size / num_unique; + EXPECT_EQ(first_count, expected_count) << "Uniform distribution but count " << first_count + << " != expected " << expected_count; + EXPECT_EQ(expected_size % num_unique, 0u) + << "Output size " << expected_size << " not evenly divisible by " << num_unique; + } + else + { + // Non-uniform distribution: log for investigation + std::cout << " [INFO] Non-uniform distribution detected: " << num_unique + << " unique values, counts range [" << min_count << ", " << max_count << "]" + << std::endl; + } +} + +// Type-parameterized test fixture +template +class CShuffleEpilogueTypedTest : public ::testing::Test +{ +}; + +TYPED_TEST_SUITE_P(CShuffleEpilogueTypedTest); + +TYPED_TEST_P(CShuffleEpilogueTypedTest, BasicTest) +{ + using Config = TypeParam; + using DataType = typename Config::DataType; + + constexpr ck_tile::index_t kMPerBlock = Config::kMPerBlock; + constexpr ck_tile::index_t kNPerBlock = Config::kNPerBlock; + + using TestProblem = MakeProblem; + constexpr ck_tile::index_t kBlockSize = TestProblem::kBlockSize; + + auto test_result = ck_tile::run_cshuffle_epilogue_test( + ck_tile::ScaleType::None); + + // Convert output to sorted vector and verify + auto output_vals = ck_tile::convert_and_sort_output(test_result.output); + verify_permutation_output(output_vals); +} + +REGISTER_TYPED_TEST_SUITE_P(CShuffleEpilogueTypedTest, BasicTest); + +// Allow this test suite to be included without instantiation (e.g., in scale tests) +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CShuffleEpilogueTypedTest); + +// Macro to instantiate typed test suites with suppressed clang warnings +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define CK_INSTANTIATE_TYPED_TEST_SUITE(Prefix, Suite, Types) \ + _Pragma("clang diagnostic push") \ + _Pragma("clang diagnostic ignored \"-Wused-but-marked-unused\"") \ + INSTANTIATE_TYPED_TEST_SUITE_P(Prefix, Suite, Types); \ + _Pragma("clang diagnostic pop") diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_fp16.cpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp16.cpp new file mode 100644 index 00000000000..c33de3a4bd5 --- /dev/null +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp16.cpp @@ -0,0 +1,27 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_cshuffle_epilogue_common.hpp" + +using namespace ck_tile; + +// Half precision test configurations +using HalfConfig_256x256_2x2x1_32x32x8 = TileConfig; +using HalfConfig_128x128_1x4x1_16x16x16 = TileConfig; +using HalfConfig_128x128_2x2x1_16x16x16 = TileConfig; +using HalfConfig_128x128_4x1x1_16x16x16 = TileConfig; +using HalfConfig_128x128_2x2x1_32x32x16 = TileConfig; + +using HalfTestTypes = ::testing::Types; + +CK_INSTANTIATE_TYPED_TEST_SUITE(FP16, CShuffleEpilogueTypedTest, HalfTestTypes) + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8.cpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8.cpp new file mode 100644 index 00000000000..1a95bb68e42 --- /dev/null +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8.cpp @@ -0,0 +1,32 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_cshuffle_epilogue_common.hpp" + +using namespace ck_tile; + +// FP8 MFMA tile configurations with half_t output +// Using half_t output avoids FP8 range limitations while testing FP8-specific tile sizes +using FP8Config_128x128_2x2x1_16x16x16 = TileConfig; +using FP8Config_128x128_1x4x1_16x16x16 = TileConfig; +using FP8Config_128x128_4x1x1_16x16x16 = TileConfig; +using FP8Config_128x128_2x2x1_32x32x16 = TileConfig; +using FP8Config_128x128_2x2x1_16x16x32 = TileConfig; +using FP8Config_128x128_2x2x1_32x32x32 = TileConfig; +using FP8Config_128x128_2x2x1_16x16x64 = TileConfig; + +using FP8TestTypes = ::testing::Types; + +CK_INSTANTIATE_TYPED_TEST_SUITE(FP8, CShuffleEpilogueTypedTest, FP8TestTypes) + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8_gfx950.cpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8_gfx950.cpp new file mode 100644 index 00000000000..1956e51635b --- /dev/null +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8_gfx950.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_cshuffle_epilogue_common.hpp" + +using namespace ck_tile; + +// FP8 MFMA tile configurations for gfx950-specific tile sizes with half_t output +// Using half_t output avoids FP8 range limitations while testing FP8-specific tile sizes +// 2x2 warp layout +using FP8Config_128x128_2x2x1_16x16x128 = TileConfig; +using FP8Config_128x128_2x2x1_32x32x64 = TileConfig; +// 1x4 warp layout +using FP8Config_128x128_1x4x1_16x16x128 = TileConfig; +using FP8Config_128x128_1x4x1_32x32x64 = TileConfig; +// 4x1 warp layout +using FP8Config_128x128_4x1x1_16x16x128 = TileConfig; +using FP8Config_128x128_4x1x1_32x32x64 = TileConfig; + +using FP8Gfx950TestTypes = ::testing::Types; + +CK_INSTANTIATE_TYPED_TEST_SUITE(FP8Gfx950, CShuffleEpilogueTypedTest, FP8Gfx950TestTypes) + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_scale.cpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_scale.cpp new file mode 100644 index 00000000000..1018f7cd196 --- /dev/null +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_scale.cpp @@ -0,0 +1,100 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_cshuffle_epilogue_common.hpp" +#include +#include +#include +#include + +using namespace ck_tile; + +// Half precision test configuration for scale tests +using HalfConfig = TileConfig; +using ScaleTestProblem = MakeProblem; + +class CShuffleEpilogueScaleTest : public ::testing::Test +{ +}; + +TEST_F(CShuffleEpilogueScaleTest, HalfTestWithRowColScale) +{ + // Run both unscaled and scaled tests + auto results = run_scale_comparison_test(); + + // With RowCol scaling, column kScaledColIndex is scaled by kTestScaleFactor + // while other columns are scaled by kIdentityScale. + // Verify scaling behavior for the first MPerXdl * MWave rows. + const index_t rows_to_check = + std::min(HalfConfig::kMPerBlock, HalfConfig::MPerXdl * HalfConfig::MWave); + + constexpr index_t kUnscaledCol = 0; + constexpr index_t kScaledCol = verification::kScaledColIndex; + + size_t col0_unchanged_count = 0; + size_t col1_scaled_count = 0; + + for(index_t row = 0; row < rows_to_check; ++row) + { + const size_t col0_idx = static_cast(row * HalfConfig::kNPerBlock + kUnscaledCol); + const size_t col1_idx = static_cast(row * HalfConfig::kNPerBlock + kScaledCol); + + const auto unscaled_col0 = type_convert(results.unscaled.output.mData[col0_idx]); + const auto scaled_col0 = type_convert(results.scaled.output.mData[col0_idx]); + const auto unscaled_col1 = type_convert(results.unscaled.output.mData[col1_idx]); + const auto scaled_col1 = type_convert(results.scaled.output.mData[col1_idx]); + + // Count rows where column 0 is unchanged (scale = kIdentityScale) + if(std::abs(scaled_col0 - unscaled_col0) < verification::kScaleEpsilon) + { + col0_unchanged_count++; + } + + // Count rows where column 1 is scaled by kTestScaleFactor + const float expected_scaled = unscaled_col1 * verification::kTestScaleFactor; + if(std::abs(scaled_col1 - expected_scaled) < verification::kScaleEpsilon) + { + col1_scaled_count++; + } + } + + // All rows must have correct scaling + EXPECT_EQ(col0_unchanged_count, static_cast(rows_to_check)) + << "RowCol: not all rows have unchanged col0"; + EXPECT_EQ(col1_scaled_count, static_cast(rows_to_check)) + << "RowCol: not all rows have scaled col1"; +} + +TEST_F(CShuffleEpilogueScaleTest, HalfTestWithTensorScale) +{ + // Run both unscaled and scaled tests + auto results = run_scale_comparison_test(); + + // Convert both to sorted vectors using helper + auto unscaled_vals = convert_and_sort_output(results.unscaled.output); + auto scaled_vals = convert_and_sort_output(results.scaled.output); + + // With Tensor scaling (m_scale=kTestScaleFactor, n_scale=kIdentityScale), + // all values should be scaled by kTestScaleFactor + EXPECT_EQ(unscaled_vals.size(), scaled_vals.size()) << "Tensor scale: output sizes differ"; + + for(size_t i = 0; i < unscaled_vals.size(); ++i) + { + const float expected = unscaled_vals[i] * verification::kTestScaleFactor; + EXPECT_NEAR(scaled_vals[i], expected, verification::kScaleEpsilon) + << "Tensor scale: sorted scaled[" << i << "]=" << scaled_vals[i] << " should be " + << verification::kTestScaleFactor << "x " << unscaled_vals[i]; + } +} + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp index 0572115201f..1fb6ff0d8a8 100644 --- a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp @@ -4,21 +4,36 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/device_memory.hpp" #include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/host_tensor.hpp" #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/elementwise.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include #include #include #include #include +#include #include #include namespace ck_tile { +// Verification and test constants +namespace verification { +constexpr float kScaleEpsilon = 0.001F; + +// Scale factors used in tests - these must match between allocation and verification +constexpr float kTestScaleFactor = + 2.0F; // Scale factor applied in RowCol (to column 1) and Tensor tests +constexpr float kIdentityScale = 1.0F; // Identity scale (no change) +constexpr index_t kScaledColIndex = 1; // Column index that gets scaled in RowCol tests +} // namespace verification + enum class ScaleType { None, @@ -46,9 +61,17 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res auto acc_tile = make_static_distributed_tensor(lds_distribution_encode); - // Fill acc_tile with a simple pattern - auto& acc_buffer = acc_tile.get_thread_buffer(); - acc_buffer[0] = 2.0F; + // Fill acc_tile with unique integer values per thread and buffer index + // Values are exactly representable in the 32-bit float accumulator + auto& acc_buffer = acc_tile.get_thread_buffer(); + const index_t thread_id = threadIdx.x; + const index_t buffer_size = acc_buffer.size(); + for(index_t i = 0; i < buffer_size; i++) + { + // Generate unique value: 1-based to avoid zero + const index_t unique_val = thread_id * buffer_size + i + 1; + acc_buffer[i] = static_cast(unique_val); + } // Create output tensor view auto output_tensor_view = @@ -123,6 +146,80 @@ using SimpleCShuffleEpilogueProblem = false // isCTransposed >; +// Result struct containing output for verification +template +struct CShuffleEpilogueTestResult +{ + HostTensor output; +}; + +// Launch kernel with RowCol scaling +template +void launch_kernel_with_rowcol_scale(typename Problem::ODataType* device_output, + dim3 gridSize, + dim3 blockSize) +{ + HostTensor h_m_scale({M}); + HostTensor h_n_scale({N}); + for(index_t i = 0; i < M; ++i) + { + h_m_scale.mData[i] = verification::kIdentityScale; + } + for(index_t i = 0; i < N; ++i) + { + h_n_scale.mData[i] = verification::kIdentityScale; + } + h_n_scale.mData[verification::kScaledColIndex] = verification::kTestScaleFactor; + + DeviceMem m_scale_buf(h_m_scale.get_element_space_size_in_bytes()); + DeviceMem n_scale_buf(h_n_scale.get_element_space_size_in_bytes()); + m_scale_buf.ToDevice(h_m_scale.data()); + n_scale_buf.ToDevice(h_n_scale.data()); + + test_cshuffle_epilogue_kernel + <<>>(device_output, + static_cast(m_scale_buf.GetDeviceBuffer()), + static_cast(n_scale_buf.GetDeviceBuffer())); + HIP_CHECK_ERROR(hipGetLastError()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); +} + +// Launch kernel with Tensor scaling +template +void launch_kernel_with_tensor_scale(typename Problem::ODataType* device_output, + dim3 gridSize, + dim3 blockSize) +{ + HostTensor h_m_scale({1}); + HostTensor h_n_scale({1}); + h_m_scale.mData[0] = verification::kTestScaleFactor; + h_n_scale.mData[0] = verification::kIdentityScale; + + DeviceMem m_scale_buf(h_m_scale.get_element_space_size_in_bytes()); + DeviceMem n_scale_buf(h_n_scale.get_element_space_size_in_bytes()); + m_scale_buf.ToDevice(h_m_scale.data()); + n_scale_buf.ToDevice(h_n_scale.data()); + + test_cshuffle_epilogue_kernel + <<>>(device_output, + static_cast(m_scale_buf.GetDeviceBuffer()), + static_cast(n_scale_buf.GetDeviceBuffer())); + HIP_CHECK_ERROR(hipGetLastError()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); +} + +// Launch kernel without scaling +template +void launch_kernel_without_scale(typename Problem::ODataType* device_output, + dim3 gridSize, + dim3 blockSize) +{ + test_cshuffle_epilogue_kernel + <<>>(device_output, nullptr, nullptr); + HIP_CHECK_ERROR(hipGetLastError()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); +} + template auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None) { @@ -130,76 +227,77 @@ auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None) constexpr index_t kMPerBlock = Problem::kMPerBlock; constexpr index_t kNPerBlock = Problem::kNPerBlock; - index_t kBlockSize = ck_tile::is_wave32() ? Problem::kBlockSize / 2 : Problem::kBlockSize; + const index_t kBlockSize = ck_tile::is_wave32() ? Problem::kBlockSize / 2 : Problem::kBlockSize; std::cout << "Running CShuffleEpilogue test with M=" << M << ", N=" << N << ", MPerBlock=" << kMPerBlock << ", NPerBlock=" << kNPerBlock << ", BlockSize=" << kBlockSize << std::endl; // Allocate host memory - const size_t output_size = M * N; - - std::vector host_output(output_size, static_cast(0)); + HostTensor host_output({M, N}); + host_output.SetZero(); // Allocate device memory - ODataType* device_output; + DeviceMem device_output_buf(host_output.get_element_space_size_in_bytes()); + device_output_buf.ToDevice(host_output.data()); + ODataType* device_output = static_cast(device_output_buf.GetDeviceBuffer()); - HIP_CHECK_ERROR(hipMalloc(&device_output, output_size * sizeof(ODataType))); - - HIP_CHECK_ERROR(hipMemcpy( - device_output, host_output.data(), output_size * sizeof(ODataType), hipMemcpyHostToDevice)); - - // Launch kernel + // Launch kernel with appropriate scale configuration dim3 gridSize(1, 1, 1); dim3 blockSize(kBlockSize, 1, 1); - if(scale == ScaleType::RowCol) - { - float* m_scale; - float* n_scale; - std::vector h_m_scale(M, 1.0F); - std::vector h_n_scale(N, 1.0F); - h_n_scale[1] = 2.0F; // multiply one col only with 2 - HIP_CHECK_ERROR(hipMalloc(&m_scale, M * sizeof(float))); - HIP_CHECK_ERROR(hipMalloc(&n_scale, N * sizeof(float))); - HIP_CHECK_ERROR( - hipMemcpy(m_scale, h_m_scale.data(), M * sizeof(float), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR( - hipMemcpy(n_scale, h_n_scale.data(), N * sizeof(float), hipMemcpyHostToDevice)); - test_cshuffle_epilogue_kernel - <<>>(device_output, m_scale, n_scale); - } - else if(scale == ScaleType::Tensor) + switch(scale) { - float* m_scale; - float* n_scale; - std::vector h_m_scale(1, 2.0F); - std::vector h_n_scale(1, 1.0F); - HIP_CHECK_ERROR(hipMalloc(&m_scale, sizeof(float))); - HIP_CHECK_ERROR(hipMalloc(&n_scale, sizeof(float))); - HIP_CHECK_ERROR(hipMemcpy(m_scale, h_m_scale.data(), sizeof(float), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(n_scale, h_n_scale.data(), sizeof(float), hipMemcpyHostToDevice)); - test_cshuffle_epilogue_kernel - <<>>(device_output, m_scale, n_scale); + case ScaleType::RowCol: + launch_kernel_with_rowcol_scale(device_output, gridSize, blockSize); + break; + case ScaleType::Tensor: + launch_kernel_with_tensor_scale(device_output, gridSize, blockSize); + break; + case ScaleType::None: + launch_kernel_without_scale(device_output, gridSize, blockSize); + break; } - else + + // Copy results back + device_output_buf.FromDevice(host_output.data()); + + return CShuffleEpilogueTestResult{std::move(host_output)}; +} + +// Convert output values to sorted float vector for verification +// Uses float as intermediate to preserve precision for floating-point comparison +template +std::vector convert_and_sort_output(const HostTensor& output) +{ + std::vector result; + result.reserve(output.get_element_size()); + for(size_t i = 0; i < output.get_element_size(); ++i) { - test_cshuffle_epilogue_kernel - <<>>(device_output, nullptr, nullptr); + result.push_back(type_convert(output.mData[i])); } + std::sort(result.begin(), result.end()); + return result; +} - // Check for kernel launch errors - HIP_CHECK_ERROR(hipGetLastError()); - HIP_CHECK_ERROR(hipDeviceSynchronize()); +// Result pair for scale comparison tests +template +struct ScaleComparisonResult +{ + CShuffleEpilogueTestResult unscaled; + CShuffleEpilogueTestResult scaled; +}; - // Copy results back - HIP_CHECK_ERROR(hipMemcpy( - host_output.data(), device_output, output_size * sizeof(ODataType), hipMemcpyDeviceToHost)); +// Run both unscaled and scaled tests for comparison +template +auto run_scale_comparison_test() +{ + using ODataType = typename Problem::ODataType; - // Cleanup - HIP_CHECK_ERROR(hipFree(device_output)); + auto unscaled = run_cshuffle_epilogue_test(ScaleType::None); + auto scaled = run_cshuffle_epilogue_test(ScaleMode); - return host_output; + return ScaleComparisonResult{std::move(unscaled), std::move(scaled)}; } } // namespace ck_tile