Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions transformer_engine/common/util/padding.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
Expand All @@ -13,6 +15,9 @@

#include "../common.h"
#include "../utils.cuh"
#ifdef __HIP_PLATFORM_AMD__
#include "rocm_device_utils.cuh" // for rocm_upper_bound()
#endif

namespace transformer_engine {

Expand Down Expand Up @@ -65,15 +70,22 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;

// Find tensor corresponding to block
#ifdef __HIP_PLATFORM_AMD__
const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid);
#else
int tensor_id = 0;
while (args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
#endif
const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int padded_num_rows = args.padded_num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
#ifdef __HIP_PLATFORM_AMD__
const bool inplace = (input == output);
#endif

// Find position of tile within tensor
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
Expand All @@ -83,6 +95,35 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
const int tile_row = tile_id_m * tile_dim_m;
const int tile_col = tile_id_n * tile_dim_n;

#ifdef __HIP_PLATFORM_AMD__
// Process subtiles with vectorized loads/stores
#pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy;
const int j1 = tidx;
#pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2;
const int col = tile_col + j1 * nvec;
const int remaining = row_length - col;
if (row < num_rows) {
// Valid data row: skip copy when in-place
if (!inplace) {
const size_t offset = static_cast<size_t>(row) * row_length + col;
Vec v;
v.load_from_elts(input, offset, remaining > 0 ? min(remaining, nvec) : 0);
v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0);
}
} else if (row < padded_num_rows) {
// Padding row: fill with zeros
const size_t offset = static_cast<size_t>(row) * row_length + col;
Vec v;
v.clear();
v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0);
}
}
}
#else // !__HIP_PLATFORM_AMD__
// Load input and store to registers
// Note: Each thread loads n_iterations subtiles, casts to output
// type, and transposes in registers.
Expand Down Expand Up @@ -125,6 +166,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
}
}
}
#endif // __HIP_PLATFORM_AMD__
}

template <int nvec, typename Type>
Expand All @@ -150,14 +192,21 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;

// Find tensor corresponding to block
#ifdef __HIP_PLATFORM_AMD__
const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid);
#else
int tensor_id = 0;
while (args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
#endif
const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
#ifdef __HIP_PLATFORM_AMD__
const bool inplace = (input == output);
#endif

// Find position of tile within tensor
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
Expand All @@ -167,6 +216,26 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
const int tile_row = tile_id_m * tile_dim_m;
const int tile_col = tile_id_n * tile_dim_n;

#ifdef __HIP_PLATFORM_AMD__
// Process subtiles with vectorized loads/stores
#pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy;
const int j1 = tidx;
#pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2;
const int col = tile_col + j1 * nvec;
if (row < num_rows && !inplace) {
const int remaining = row_length - col;
const size_t offset = static_cast<size_t>(row) * row_length + col;
Vec v;
v.load_from_elts(input, offset, remaining > 0 ? min(remaining, nvec) : 0);
v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0);
}
}
}
#else // !__HIP_PLATFORM_AMD__
// Load input and store to registers
// Note: Each thread loads n_iterations subtiles, casts to output
// type, and transposes in registers.
Expand Down Expand Up @@ -202,6 +271,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
}
}
}
#endif // __HIP_PLATFORM_AMD__
}

} // namespace
Expand Down
17 changes: 17 additions & 0 deletions transformer_engine/common/util/rocm_device_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,23 @@ __device__ __forceinline__ void rocm_atomicMaxFloat(float *addr, float val) {
atomicMax(reinterpret_cast<int*>(addr), __float_as_int(val));
}

// Binary search on a sorted array.
// Returns the largest index i in [0, n) such that arr[i] <= val.
// Precondition: arr is sorted in non-decreasing order and arr[0] <= val.
template <typename T>
__device__ __forceinline__ int rocm_upper_bound(const T* arr, int n, T val) {
int lo = 0, hi = n - 1;
while (lo < hi) {
int mid = (lo + hi + 1) / 2;
if (arr[mid] <= val) {
lo = mid;
} else {
hi = mid - 1;
}
}
return lo;
}

template <int WARPS>
__device__ __forceinline__ float rocm_block_reduce_max(float val, int warp_id) {
__shared__ float staging[WARPS];
Expand Down
Loading