From 6c2818a4d627df5ea2705ea6fd2b361bdeeb921c Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Sun, 29 Mar 2026 01:56:39 -0700 Subject: [PATCH 01/10] streamline group Hadamard ComputeKernel loads Signed-off-by: Cael Ling --- .../group_hadamard_transform.cu | 42 ++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 07813be059..cef9ef154e 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -57,30 +57,9 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; - if (kReturnIdentityAmax) { - ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], - reinterpret_cast(in_sh_ptr) + swizzle_idx); - - mma_m16_n16_k16_b16_b16_b16_noacc( - a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], - b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); - asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" - : "=r"(local_amax_reg) - : "r"(local_amax_reg), "r"(temp_amax_reg)); - } - if (kReturnTransposedAmax) { - // TODO(Frank): This is not efficient, since we could directly load the - // matrix in transposed layout. - if (!kReturnIdentityAmax) { - ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], - reinterpret_cast(in_sh_ptr) + swizzle_idx); - } - - matrix_transpose_m8_n8_b16_inplace(a_frag[0]); - matrix_transpose_m8_n8_b16_inplace(a_frag[1]); - matrix_transpose_m8_n8_b16_inplace(a_frag[2]); - matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); mma_m16_n16_k16_b16_b16_b16_noacc( a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], @@ -91,7 +70,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f } if (kReturnPreRhtAmax) { - if (!kReturnIdentityAmax && !kReturnTransposedAmax) { + if (!kReturnTransposedAmax) { ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], reinterpret_cast(in_sh_ptr) + swizzle_idx); } @@ -109,6 +88,21 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f : "=r"(local_pre_rht_amax_reg) : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg)); } + + if (kReturnIdentityAmax) { + if (kReturnTransposedAmax || + (!kReturnTransposedAmax && !kReturnPreRhtAmax)) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_reg) + : "r"(local_amax_reg), "r"(temp_amax_reg)); + } } template From d4b668a6d2e632dd9250742f550cc72e7debf79e Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Sun, 29 Mar 2026 01:07:37 -0700 Subject: [PATCH 02/10] streamline group Hadamard ComputeKernel loads Signed-off-by: Cael Ling --- .../group_hadamard_transform.cu | 42 ++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 07813be059..cef9ef154e 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -57,30 +57,9 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; - if (kReturnIdentityAmax) { - ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], - reinterpret_cast(in_sh_ptr) + swizzle_idx); - - mma_m16_n16_k16_b16_b16_b16_noacc( - a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], - b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); - asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" - : "=r"(local_amax_reg) - : "r"(local_amax_reg), "r"(temp_amax_reg)); - } - if (kReturnTransposedAmax) { - // TODO(Frank): This is not efficient, since we could directly load the - // matrix in transposed layout. - if (!kReturnIdentityAmax) { - ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], - reinterpret_cast(in_sh_ptr) + swizzle_idx); - } - - matrix_transpose_m8_n8_b16_inplace(a_frag[0]); - matrix_transpose_m8_n8_b16_inplace(a_frag[1]); - matrix_transpose_m8_n8_b16_inplace(a_frag[2]); - matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); mma_m16_n16_k16_b16_b16_b16_noacc( a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], @@ -91,7 +70,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f } if (kReturnPreRhtAmax) { - if (!kReturnIdentityAmax && !kReturnTransposedAmax) { + if (!kReturnTransposedAmax) { ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], reinterpret_cast(in_sh_ptr) + swizzle_idx); } @@ -109,6 +88,21 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f : "=r"(local_pre_rht_amax_reg) : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg)); } + + if (kReturnIdentityAmax) { + if (kReturnTransposedAmax || + (!kReturnTransposedAmax && !kReturnPreRhtAmax)) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_reg) + : "r"(local_amax_reg), "r"(temp_amax_reg)); + } } template From 28fb3208ab04b34af70567b088deaf7b6c501034 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 29 Mar 2026 08:15:45 +0000 Subject: [PATCH 03/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/hadamard_transform/group_hadamard_transform.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index cef9ef154e..749db9f5a7 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -90,8 +90,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f } if (kReturnIdentityAmax) { - if (kReturnTransposedAmax || - (!kReturnTransposedAmax && !kReturnPreRhtAmax)) { + if (kReturnTransposedAmax || (!kReturnTransposedAmax && !kReturnPreRhtAmax)) { ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], reinterpret_cast(in_sh_ptr) + swizzle_idx); } From f101b02d38e28e16652fb28a27f05c56978c0529 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Fri, 27 Mar 2026 00:42:24 -0700 Subject: [PATCH 04/10] Compute swizzle_idx once per thread and pass into ComputeKernel. Signed-off-by: Cael Ling --- .../group_hadamard_transform.cu | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 07813be059..8b7f079072 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -41,19 +41,13 @@ constexpr int kThreadsPerWarp = 32; template __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], - IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + IType* in_sh_ptr, int swizzle_idx, + uint32_t& local_pre_rht_amax_reg, uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) { uint32_t a_frag[4]; // A matrix fragment uint32_t c_frag[4]; // Result fragment - int warp_id = threadIdx.x / kThreadsPerWarp; - int local_rank = (threadIdx.x % kThreadsPerWarp); - - int ld_row_idx = local_rank % kHadamardDimension; - int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; - int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); - uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; @@ -305,6 +299,12 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t uint32_t local_amax_reg = *reinterpret_cast(&local_amax); uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + const int warp_id = threadIdx.x / kThreadsPerWarp; + const int local_rank = threadIdx.x % kThreadsPerWarp; + const int ld_row_idx = local_rank % kHadamardDimension; + const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { int stage = STAGES_X * stage_y + stage_x; @@ -347,7 +347,7 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared From dbed5ee021a420791229bcffc7489a9b723ac60d Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Mon, 30 Mar 2026 16:25:59 -0700 Subject: [PATCH 05/10] Fix kReturnIdentityAmax path Signed-off-by: Cael Ling --- .../common/hadamard_transform/group_hadamard_transform.cu | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 749db9f5a7..a48a83a07b 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -90,10 +90,8 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f } if (kReturnIdentityAmax) { - if (kReturnTransposedAmax || (!kReturnTransposedAmax && !kReturnPreRhtAmax)) { - ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], reinterpret_cast(in_sh_ptr) + swizzle_idx); - } mma_m16_n16_k16_b16_b16_b16_noacc( a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], From 1ef69e8a0e0cedda83ed7f1f5ffe315acb76fcb9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 30 Mar 2026 23:47:22 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/hadamard_transform/group_hadamard_transform.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index a48a83a07b..88636e2ef8 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -91,7 +91,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f if (kReturnIdentityAmax) { ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], - reinterpret_cast(in_sh_ptr) + swizzle_idx); + reinterpret_cast(in_sh_ptr) + swizzle_idx); mma_m16_n16_k16_b16_b16_b16_noacc( a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], From e395bdbde6e419a5f412985812483dde2bc6550b Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Mon, 30 Mar 2026 19:10:54 -0700 Subject: [PATCH 07/10] Refactor the change to other variants Signed-off-by: Cael Ling --- .../graph_safe_group_hadamard_transform.cu | 18 +++++++++--------- .../hadamard_transform/hadamard_transform.cu | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 04e965a9da..8f9a30ac60 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -65,19 +65,13 @@ __device__ __forceinline__ size_t get_current_tensor_id( template __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], - IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + IType* in_sh_ptr, int swizzle_idx, + uint32_t& local_pre_rht_amax_reg, uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) { uint32_t a_frag[4]; // A matrix fragment uint32_t c_frag[4]; // Result fragment - int warp_id = threadIdx.x / kThreadsPerWarp; - int local_rank = (threadIdx.x % kThreadsPerWarp); - - int ld_row_idx = local_rank % kHadamardDimension; - int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; - int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); - uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; @@ -322,6 +316,12 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( uint32_t local_amax_reg = *reinterpret_cast(&local_amax); uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + const int warp_id = threadIdx.x / kThreadsPerWarp; + const int local_rank = threadIdx.x % kThreadsPerWarp; + const int ld_row_idx = local_rank % kHadamardDimension; + const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { int stage = STAGES_X * stage_y + stage_x; @@ -364,7 +364,7 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 4adc836886..4e3c528fd4 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -26,19 +26,13 @@ constexpr int kThreadsPerWarp = 32; template __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], - IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + IType* in_sh_ptr, int swizzle_idx, + uint32_t& local_pre_rht_amax_reg, uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) { uint32_t a_frag[4]; // A matrix fragment uint32_t c_frag[4]; // Result fragment - int warp_id = threadIdx.x / kThreadsPerWarp; - int local_rank = (threadIdx.x % kThreadsPerWarp); - - int ld_row_idx = local_rank % kHadamardDimension; - int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; - int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); - uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; @@ -248,6 +242,12 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor uint32_t local_amax_reg = *reinterpret_cast(&local_amax); uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + const int warp_id = threadIdx.x / kThreadsPerWarp; + const int local_rank = threadIdx.x % kThreadsPerWarp; + const int ld_row_idx = local_rank % kHadamardDimension; + const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { int stage = STAGES_X * stage_y + stage_x; @@ -290,7 +290,7 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared From ee46c42bbe7fc3ab138b15ae733aa9851f3d8ddf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 02:14:22 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../hadamard_transform/graph_safe_group_hadamard_transform.cu | 2 +- .../common/hadamard_transform/hadamard_transform.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 8f9a30ac60..231d522f3a 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -364,7 +364,7 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 4e3c528fd4..216ed1930a 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -290,7 +290,7 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared From 4097137e70389bd94336f0d8b87ccffbad55491b Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Mon, 30 Mar 2026 19:23:57 -0700 Subject: [PATCH 09/10] Refactor the change to other variants Signed-off-by: Cael Ling --- .../graph_safe_group_hadamard_transform.cu | 39 +++++++------------ .../hadamard_transform/hadamard_transform.cu | 39 +++++++------------ 2 files changed, 30 insertions(+), 48 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 04e965a9da..77880a9be7 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -81,30 +81,9 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; - if (kReturnIdentityAmax) { - ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], - reinterpret_cast(in_sh_ptr) + swizzle_idx); - - mma_m16_n16_k16_b16_b16_b16_noacc( - a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], - b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); - asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" - : "=r"(local_amax_reg) - : "r"(local_amax_reg), "r"(temp_amax_reg)); - } - if (kReturnTransposedAmax) { - // TODO(Frank): This is not efficient, since we could directly load the - // matrix in transposed layout. - if (!kReturnIdentityAmax) { - ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], - reinterpret_cast(in_sh_ptr) + swizzle_idx); - } - - matrix_transpose_m8_n8_b16_inplace(a_frag[0]); - matrix_transpose_m8_n8_b16_inplace(a_frag[1]); - matrix_transpose_m8_n8_b16_inplace(a_frag[2]); - matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); mma_m16_n16_k16_b16_b16_b16_noacc( a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], @@ -115,7 +94,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f } if (kReturnPreRhtAmax) { - if (!kReturnIdentityAmax && !kReturnTransposedAmax) { + if (!kReturnTransposedAmax) { ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], reinterpret_cast(in_sh_ptr) + swizzle_idx); } @@ -133,6 +112,18 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f : "=r"(local_pre_rht_amax_reg) : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg)); } + + if (kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_reg) + : "r"(local_amax_reg), "r"(temp_amax_reg)); + } } template diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 4adc836886..a8afae2db0 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -42,30 +42,9 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; - if (kReturnIdentityAmax) { - ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], - reinterpret_cast(in_sh_ptr) + swizzle_idx); - - mma_m16_n16_k16_b16_b16_b16_noacc( - a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], - b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); - asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" - : "=r"(local_amax_reg) - : "r"(local_amax_reg), "r"(temp_amax_reg)); - } - if (kReturnTransposedAmax) { - // TODO(Frank): This is not efficient, since we could directly load the - // matrix in transposed layout. - if (!kReturnIdentityAmax) { - ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], - reinterpret_cast(in_sh_ptr) + swizzle_idx); - } - - matrix_transpose_m8_n8_b16_inplace(a_frag[0]); - matrix_transpose_m8_n8_b16_inplace(a_frag[1]); - matrix_transpose_m8_n8_b16_inplace(a_frag[2]); - matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); mma_m16_n16_k16_b16_b16_b16_noacc( a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], @@ -76,7 +55,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f } if (kReturnPreRhtAmax) { - if (!kReturnIdentityAmax && !kReturnTransposedAmax) { + if (!kReturnTransposedAmax) { ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], reinterpret_cast(in_sh_ptr) + swizzle_idx); } @@ -94,6 +73,18 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f : "=r"(local_pre_rht_amax_reg) : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg)); } + + if (kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_reg) + : "r"(local_amax_reg), "r"(temp_amax_reg)); + } } template From 1b0ec040636dad3ff66aff47a3786bb55584967e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 02:45:28 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../hadamard_transform/graph_safe_group_hadamard_transform.cu | 2 +- .../common/hadamard_transform/hadamard_transform.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 8b2fba8fe2..3ef14c6b11 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -109,7 +109,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f if (kReturnIdentityAmax) { ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], - reinterpret_cast(in_sh_ptr) + swizzle_idx); + reinterpret_cast(in_sh_ptr) + swizzle_idx); mma_m16_n16_k16_b16_b16_b16_noacc( a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 8c72b9c2d3..7a8db9d85c 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -70,7 +70,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f if (kReturnIdentityAmax) { ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], - reinterpret_cast(in_sh_ptr) + swizzle_idx); + reinterpret_cast(in_sh_ptr) + swizzle_idx); mma_m16_n16_k16_b16_b16_b16_noacc( a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],