From ce6e86542c7567fabca3fbe4b27506efadc596f3 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 20 May 2026 15:23:36 +0000 Subject: [PATCH 1/3] speed up nvte_multi_padding / nvte_multi_unpadding --- transformer_engine/common/util/padding.cu | 100 +++++++++------------- 1 file changed, 40 insertions(+), 60 deletions(-) diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index 835923828..37793aee5 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -42,6 +42,22 @@ struct MultiPaddingArgs { int num_tensors; }; +// Binary search to find which tensor owns a given block id. +// block_range is a sorted prefix sum array with (num_tensors + 1) entries. +__device__ __forceinline__ int find_tensor_for_block(const int* block_range, int num_tensors, + int bid) { + int lo = 0, hi = num_tensors - 1; + while (lo < hi) { + int mid = (lo + hi) / 2; + if (block_range[mid + 1] <= bid) { + lo = mid + 1; + } else { + hi = mid; + } + } + return lo; +} + template __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiPaddingArgs args) { using Vec = Vec; @@ -65,15 +81,13 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; // Find tensor corresponding to block - int tensor_id = 0; - while (args.block_range[tensor_id + 1] <= bid) { - ++tensor_id; - } + const int tensor_id = find_tensor_for_block(args.block_range, args.num_tensors, bid); const Type* input = reinterpret_cast(args.input_list[tensor_id]); Type* output = reinterpret_cast(args.output_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int padded_num_rows = args.padded_num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; + const bool inplace = (input == output); // Find position of tile within tensor const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; @@ -83,10 +97,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP const int tile_row = tile_id_m * tile_dim_m; const int tile_col = tile_id_n * tile_dim_n; - // Load input and store to registers - // Note: Each thread loads n_iterations subtiles, casts to output - // type, and transposes in registers. - Type local_zero = static_cast(0.f); + // Process subtiles with vectorized loads/stores #pragma unroll for (int iter = 0; iter < n_iterations; ++iter) { const int i1 = tidy + iter * bdimy; @@ -95,33 +106,21 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP for (int i2 = 0; i2 < nvec; ++i2) { const int row = tile_row + i1 * nvec + i2; const int col = tile_col + j1 * nvec; - Vec local_input; - Vec local_output; - local_input.clear(); - if (row < num_rows) { - for (int j2 = 0; j2 < nvec; ++j2) { - if (col + j2 < row_length) { - local_input.data.elt[j2] = input[static_cast(row) * row_length + col + j2]; - } - } - } -#pragma unroll - for (int j2 = 0; j2 < nvec; ++j2) { - local_output.data.elt[j2] = local_input.data.elt[j2]; - } + const int remaining = row_length - col; if (row < num_rows) { - for (int j2 = 0; j2 < nvec; ++j2) { - if (col + j2 < row_length) { - output[static_cast(row) * row_length + col + j2] = local_output.data.elt[j2]; - } + // Valid data row: skip copy when in-place + if (!inplace) { + const size_t offset = static_cast(row) * row_length + col; + Vec v; + v.load_from_elts(input, offset, remaining > 0 ? min(remaining, nvec) : 0); + v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0); } } else if (row < padded_num_rows) { - // padding - for (int j2 = 0; j2 < nvec; ++j2) { - if (col + j2 < row_length) { - output[static_cast(row) * row_length + col + j2] = local_zero; - } - } + // Padding row: fill with zeros + const size_t offset = static_cast(row) * row_length + col; + Vec v; + v.clear(); + v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0); } } } @@ -150,14 +149,12 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; // Find tensor corresponding to block - int tensor_id = 0; - while (args.block_range[tensor_id + 1] <= bid) { - ++tensor_id; - } + const int tensor_id = find_tensor_for_block(args.block_range, args.num_tensors, bid); const Type* input = reinterpret_cast(args.input_list[tensor_id]); Type* output = reinterpret_cast(args.output_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; + const bool inplace = (input == output); // Find position of tile within tensor const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; @@ -167,10 +164,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult const int tile_row = tile_id_m * tile_dim_m; const int tile_col = tile_id_n * tile_dim_n; - // Load input and store to registers - // Note: Each thread loads n_iterations subtiles, casts to output - // type, and transposes in registers. - Type local_zero = static_cast(0.f); + // Process subtiles with vectorized loads/stores #pragma unroll for (int iter = 0; iter < n_iterations; ++iter) { const int i1 = tidy + iter * bdimy; @@ -179,26 +173,12 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult for (int i2 = 0; i2 < nvec; ++i2) { const int row = tile_row + i1 * nvec + i2; const int col = tile_col + j1 * nvec; - Vec local_input; - Vec local_output; - local_input.clear(); - if (row < num_rows) { - for (int j2 = 0; j2 < nvec; ++j2) { - if (col + j2 < row_length) { - local_input.data.elt[j2] = input[static_cast(row) * row_length + col + j2]; - } - } - } -#pragma unroll - for (int j2 = 0; j2 < nvec; ++j2) { - local_output.data.elt[j2] = local_input.data.elt[j2]; - } - if (row < num_rows) { - for (int j2 = 0; j2 < nvec; ++j2) { - if (col + j2 < row_length) { - output[static_cast(row) * row_length + col + j2] = local_output.data.elt[j2]; - } - } + if (row < num_rows && !inplace) { + const int remaining = row_length - col; + const size_t offset = static_cast(row) * row_length + col; + Vec v; + v.load_from_elts(input, offset, remaining > 0 ? min(remaining, nvec) : 0); + v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0); } } } From a470ecbf8127f8cb8c6d964bef1ebd4fdcc5ee46 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 20 May 2026 18:37:56 +0000 Subject: [PATCH 2/3] factor out binary search --- transformer_engine/common/util/padding.cu | 21 +++---------------- .../common/util/rocm_device_utils.cuh | 17 +++++++++++++++ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index 37793aee5..4e725b25d 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -13,6 +13,7 @@ #include "../common.h" #include "../utils.cuh" +#include "rocm_device_utils.cuh" namespace transformer_engine { @@ -42,22 +43,6 @@ struct MultiPaddingArgs { int num_tensors; }; -// Binary search to find which tensor owns a given block id. -// block_range is a sorted prefix sum array with (num_tensors + 1) entries. -__device__ __forceinline__ int find_tensor_for_block(const int* block_range, int num_tensors, - int bid) { - int lo = 0, hi = num_tensors - 1; - while (lo < hi) { - int mid = (lo + hi) / 2; - if (block_range[mid + 1] <= bid) { - lo = mid + 1; - } else { - hi = mid; - } - } - return lo; -} - template __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiPaddingArgs args) { using Vec = Vec; @@ -81,7 +66,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; // Find tensor corresponding to block - const int tensor_id = find_tensor_for_block(args.block_range, args.num_tensors, bid); + const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid); const Type* input = reinterpret_cast(args.input_list[tensor_id]); Type* output = reinterpret_cast(args.output_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; @@ -149,7 +134,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; // Find tensor corresponding to block - const int tensor_id = find_tensor_for_block(args.block_range, args.num_tensors, bid); + const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid); const Type* input = reinterpret_cast(args.input_list[tensor_id]); Type* output = reinterpret_cast(args.output_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; diff --git a/transformer_engine/common/util/rocm_device_utils.cuh b/transformer_engine/common/util/rocm_device_utils.cuh index 0d2b4c658..89c49b533 100644 --- a/transformer_engine/common/util/rocm_device_utils.cuh +++ b/transformer_engine/common/util/rocm_device_utils.cuh @@ -118,6 +118,23 @@ __device__ __forceinline__ void rocm_atomicMaxFloat(float *addr, float val) { atomicMax(reinterpret_cast(addr), __float_as_int(val)); } +// Binary search on a sorted array. +// Returns the largest index i in [0, n) such that arr[i] <= val. +// Precondition: arr is sorted in non-decreasing order and arr[0] <= val. +template +__device__ __forceinline__ int rocm_upper_bound(const T* arr, int n, T val) { + int lo = 0, hi = n - 1; + while (lo < hi) { + int mid = (lo + hi + 1) / 2; + if (arr[mid] <= val) { + lo = mid; + } else { + hi = mid - 1; + } + } + return lo; +} + template __device__ __forceinline__ float rocm_block_reduce_max(float val, int warp_id) { __shared__ float staging[WARPS]; From 5f011ae91432e6fb070416b0e55b1041c25bd18b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 20 May 2026 20:01:03 +0000 Subject: [PATCH 3/3] guard --- transformer_engine/common/util/padding.cu | 107 +++++++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index 4e725b25d..45b5ee2f0 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -1,4 +1,6 @@ /************************************************************************* +* This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -13,7 +15,9 @@ #include "../common.h" #include "../utils.cuh" -#include "rocm_device_utils.cuh" +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_device_utils.cuh" // for rocm_upper_bound() +#endif namespace transformer_engine { @@ -66,13 +70,22 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; // Find tensor corresponding to block +#ifdef __HIP_PLATFORM_AMD__ const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid); +#else + int tensor_id = 0; + while (args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } +#endif const Type* input = reinterpret_cast(args.input_list[tensor_id]); Type* output = reinterpret_cast(args.output_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int padded_num_rows = args.padded_num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; +#ifdef __HIP_PLATFORM_AMD__ const bool inplace = (input == output); +#endif // Find position of tile within tensor const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; @@ -82,6 +95,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP const int tile_row = tile_id_m * tile_dim_m; const int tile_col = tile_id_n * tile_dim_n; +#ifdef __HIP_PLATFORM_AMD__ // Process subtiles with vectorized loads/stores #pragma unroll for (int iter = 0; iter < n_iterations; ++iter) { @@ -109,6 +123,50 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP } } } +#else // !__HIP_PLATFORM_AMD__ + // Load input and store to registers + // Note: Each thread loads n_iterations subtiles, casts to output + // type, and transposes in registers. + Type local_zero = static_cast(0.f); +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + const int col = tile_col + j1 * nvec; + Vec local_input; + Vec local_output; + local_input.clear(); + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + local_input.data.elt[j2] = input[static_cast(row) * row_length + col + j2]; + } + } + } +#pragma unroll + for (int j2 = 0; j2 < nvec; ++j2) { + local_output.data.elt[j2] = local_input.data.elt[j2]; + } + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + output[static_cast(row) * row_length + col + j2] = local_output.data.elt[j2]; + } + } + } else if (row < padded_num_rows) { + // padding + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + output[static_cast(row) * row_length + col + j2] = local_zero; + } + } + } + } + } +#endif // __HIP_PLATFORM_AMD__ } template @@ -134,12 +192,21 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; // Find tensor corresponding to block +#ifdef __HIP_PLATFORM_AMD__ const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid); +#else + int tensor_id = 0; + while (args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } +#endif const Type* input = reinterpret_cast(args.input_list[tensor_id]); Type* output = reinterpret_cast(args.output_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; +#ifdef __HIP_PLATFORM_AMD__ const bool inplace = (input == output); +#endif // Find position of tile within tensor const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; @@ -149,6 +216,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult const int tile_row = tile_id_m * tile_dim_m; const int tile_col = tile_id_n * tile_dim_n; +#ifdef __HIP_PLATFORM_AMD__ // Process subtiles with vectorized loads/stores #pragma unroll for (int iter = 0; iter < n_iterations; ++iter) { @@ -167,6 +235,43 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult } } } +#else // !__HIP_PLATFORM_AMD__ + // Load input and store to registers + // Note: Each thread loads n_iterations subtiles, casts to output + // type, and transposes in registers. + Type local_zero = static_cast(0.f); +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + const int col = tile_col + j1 * nvec; + Vec local_input; + Vec local_output; + local_input.clear(); + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + local_input.data.elt[j2] = input[static_cast(row) * row_length + col + j2]; + } + } + } +#pragma unroll + for (int j2 = 0; j2 < nvec; ++j2) { + local_output.data.elt[j2] = local_input.data.elt[j2]; + } + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + output[static_cast(row) * row_length + col + j2] = local_output.data.elt[j2]; + } + } + } + } + } +#endif // __HIP_PLATFORM_AMD__ } } // namespace