From 4b39845b04504a8c840ff7a39318087902cc62cb Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 20 May 2026 19:19:28 +0000 Subject: [PATCH 1/5] Initial optimizations --- .../pytorch/triton_kernels/norms_common.py | 71 +++++++++++--- .../pytorch/triton_kernels/rmsnorm.py | 93 +++++++++++++------ 2 files changed, 127 insertions(+), 37 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index ed4002f2c..d9d4db330 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -20,7 +20,9 @@ _rmsnorm_fwd_triton, _rmsnorm_fwd_triton_impl, _rmsnorm_bwd_triton, + _rmsnorm_bwd_triton_impl, _rmsnorm_bwd_dg_reduce_triton, + _rmsnorm_bwd_dg_reduce_triton_impl, ) from .layernorm import ( _layernorm_fwd_triton, @@ -41,6 +43,16 @@ False: _layernorm_fwd_triton_impl, } } + +_rmsnorm_bwd_kernels = { + True: _rmsnorm_bwd_triton, + False: _rmsnorm_bwd_triton_impl, +} + +_rmsnorm_bwd_dg_reduce_kernels = { + True: _rmsnorm_bwd_dg_reduce_triton, + False: _rmsnorm_bwd_dg_reduce_triton_impl, +} # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd def te_rmsnorm_fwd_triton( input: torch.Tensor, @@ -234,7 +246,7 @@ def _te_norm_fwd_triton( # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd -def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): +def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, autotune: bool = True): # may take non-contiguous inputs dz_ = dz.contiguous() x_ = x.contiguous() @@ -248,25 +260,62 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): blk_size = block_size(x_) USE_BLOCKED = use_blocked(x_) NUM_PRGMS = num_programs(x_, sm_margin) - need_reduction = N > 1 - dg_tmp_rows = x_.shape[0] if use_blocked(x_) else num_programs(x_, sm_margin) - dg_tmp = torch.empty(dg_tmp_rows, 
N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None + # Both blocked and non-blocked paths now accumulate per-program (NUM_PRGMS rows) + # rather than per-input-row (M rows). For typical workloads (NUM_PRGMS ~= 144 on + # MI300X vs M up to 32k), this shrinks the partial buffer by ~100x and keeps it + # L2-resident, turning the bwd dg RMW into a near-free op vs going to HBM. + need_reduction = NUM_PRGMS > 1 + # Blocked path uses HBM RMW so the buffer must be zero-initialized. + # Non-blocked path writes once per program; empty is fine. + if need_reduction: + if USE_BLOCKED: + dg_tmp = torch.zeros(NUM_PRGMS, N, device=x.device, dtype=torch.float32, requires_grad=False) + else: + dg_tmp = torch.empty(NUM_PRGMS, N, device=x.device, dtype=torch.float32, requires_grad=False) + else: + dg_tmp = None input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) * x_.dtype.itemsize % 16 == 0) grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(0) * dz_.dtype.itemsize % 16 == 0) dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(0) * dx.dtype.itemsize % 16 == 0) dg_target = dg_tmp if need_reduction else dgamma dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) * dg_target.dtype.itemsize % 16 == 0) + grid_bwd = lambda meta: (NUM_PRGMS, ) - _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, - x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, - USE_BLOCKED, NUM_PRGMS, input_aligned_16, grad_output_aligned_16, - dx_aligned_16, dg_aligned_16, num_warps=8) + bwd_kernel = _rmsnorm_bwd_kernels[autotune] + bwd_kwargs = dict( + n_rows=M, n_cols=N, + ZERO_CENTERED_GAMMA=zero_centered_gamma, + BLOCK_SIZE=blk_size, + USE_BLOCKED=USE_BLOCKED, NUM_PRGMS=NUM_PRGMS, + INPUT_ALIGNED_16=input_aligned_16, + GRAD_OUTPUT_ALIGNED_16=grad_output_aligned_16, + DX_ALIGNED_16=dx_aligned_16, + DG_ALIGNED_16=dg_aligned_16, + ) + if not autotune: + 
bwd_kwargs["num_warps"] = 8 + bwd_kernel[grid_bwd]( + dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, + x_.stride(0), dz_.stride(0), + **bwd_kwargs, + ) if need_reduction: - grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] - _rmsnorm_bwd_dg_reduce_triton[grid_reduce](dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], - BLOCK_SIZE_M=128, BLOCK_SIZE_N=64) + reduce_kernel = _rmsnorm_bwd_dg_reduce_kernels[autotune] + if autotune: + grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] + reduce_kernel[grid_reduce]( + dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], + ) + else: + # Match the previously-hardcoded tile when autotune is disabled. + BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 64 + grid_reduce = (triton.cdiv(N, BLOCK_SIZE_N),) + reduce_kernel[grid_reduce]( + dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, + ) return dx, dgamma diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 5ecb48eb7..fcebeea3e 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -140,22 +140,24 @@ def _rmsnorm_fwd_triton_impl( else: mask = col_offsets < n_cols + # gamma is invariant across rows -- load + ZERO_CENTERED adjustment once per program. 
+ g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + if (ZERO_CENTERED_GAMMA): + g += 1 + inv_n_cols = 1.0 / n_cols for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets if INPUT_ALIGNED_16: input_ptrs = tl.multiple_of(input_ptrs, (16, )) row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) - g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) row_norm = row * row row_norm = tl.sum(row_norm, axis=-1) - norm_factor = tl.math.rsqrt((row_norm / n_cols) + epsilon) + norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) # Store rsigma (norm_factor) rsigma_output_ptr = rsigma_ptr + row_idx tl.store(rsigma_output_ptr, norm_factor) - if (ZERO_CENTERED_GAMMA): - g += 1 rms_norm = row * norm_factor * g output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets @@ -181,29 +183,33 @@ def _rmsnorm_fwd_triton_impl( _rmsnorm_fwd_triton = autotune_dec(_rmsnorm_fwd_triton_impl) @triton.jit -def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr, input_row_stride, output_row_stride, +def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr, input_row_stride, output_row_stride, n_rows, n_cols, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, INPUT_ALIGNED_16: tl.constexpr, GRAD_OUTPUT_ALIGNED_16: tl.constexpr, DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr): + # `dg_ptr` points to a (NUM_PRGMS, n_cols) fp32 partial buffer (pre-zeroed + # by the launcher). Each program accumulates its assigned rows into its + # own slot and a separate reduce kernel finalizes dgamma. 
row_start = tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) + # Precomputed per-row invariant: dx = nf * (dz*g - c*x) where + # c = nf*nf * grad_sum / n_cols + inv_n_cols = 1.0 / n_cols # tl.assume(input_row_stride >= 0) # tl.assume(output_row_stride >= 0) # tl.assume(row_start >= 0) if USE_BLOCKED: - for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=1): + # Per-program partial dg slot in the (NUM_PRGMS, n_cols) scratch. + prgm_dg_ptr = dg_ptr + row_start * n_cols + for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): row_input_ptr = input_ptr + row_idx * input_row_stride row_grad_output_ptr = grad_output_ptr + row_idx * output_row_stride row_dx_ptr = dx_ptr + row_idx * input_row_stride - row_dg_ptr = dg_ptr + row_idx * input_row_stride # Compute gradients sum of all colums for each row n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1 - # older version of triton doesn't accept below init - # comment out for now to make it compatible with triton 3.1 - # grad_sum: tl.float32 = 0.0 grad_sum = 0.0 for blk_idx in tl.range(0, n_cols_blks, num_stages=2): cols = blk_idx * BLOCK_SIZE + col_offsets @@ -236,8 +242,9 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d g += 1. grad_sum += tl.sum(grad_output * x * g, axis=0) - # Load r_sigma + # Load r_sigma; hoist per-row invariants used in dx. norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) + c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols for blk_idx in tl.range(0, n_cols_blks, num_stages=2): cols = blk_idx * BLOCK_SIZE + col_offsets @@ -256,19 +263,20 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d g = tl.load(g_ptrs).to(tl.float32) if (ZERO_CENTERED_GAMMA): g += 1. 
- grad_input = grad_output * norm_factor * g - (norm_factor * norm_factor * norm_factor) * x * (grad_sum / - n_cols) + grad_input = norm_factor * (grad_output * g - c_scalar * x) dx_ptrs = row_dx_ptr + cols if DX_ALIGNED_16: dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty)) + # Accumulate this row's dg contribution into per-program slot. dg = grad_output * x * norm_factor - dg_ptrs = row_dg_ptr + cols + dg_ptrs = prgm_dg_ptr + cols if DG_ALIGNED_16: dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) - tl.store(dg_ptrs, dg.to(tl.float32)) + partial_dg = tl.load(dg_ptrs) + tl.store(dg_ptrs, partial_dg + dg) # Handle remainder cols = n_cols_blks * BLOCK_SIZE + col_offsets @@ -282,8 +290,7 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32) if (ZERO_CENTERED_GAMMA): g += 1. - grad_input = grad_output * norm_factor * g - (norm_factor * norm_factor * norm_factor) * x * (grad_sum / - n_cols) + grad_input = norm_factor * (grad_output * g - c_scalar * x) dx_ptrs = row_dx_ptr + cols if DX_ALIGNED_16: @@ -291,15 +298,22 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) dg = grad_output * x * norm_factor - dg_ptrs = row_dg_ptr + cols + dg_ptrs = prgm_dg_ptr + cols if DG_ALIGNED_16: dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) - tl.store(dg_ptrs, dg.to(tl.float32), mask=mask) + partial_dg = tl.load(dg_ptrs, mask=mask, other=0.0) + tl.store(dg_ptrs, partial_dg + dg, mask=mask) else: mask = col_offsets < n_cols dg_col_redux = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + # Hoist gamma load + ZERO_CENTERED adjustment outside the row loop + # since gamma is invariant across rows. + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + if (ZERO_CENTERED_GAMMA): + g += 1. 
+ for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets @@ -314,25 +328,31 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) - g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) - if (ZERO_CENTERED_GAMMA): - g += 1. norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) grad_sum = tl.sum(grad_output * x * g, axis=0) + c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols - grad_input = grad_output * norm_factor * g - (norm_factor * norm_factor * norm_factor) * x * (grad_sum / - n_cols) + grad_input = norm_factor * (grad_output * g - c_scalar * x) tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) dg = grad_output * x * norm_factor dg_col_redux += dg.to(tl.float32) - tl.store(dg_ptr + tl.program_id(0) * input_row_stride + col_offsets, dg_col_redux, mask=mask) + tl.store(dg_ptr + row_start * n_cols + col_offsets, dg_col_redux, mask=mask) + + +# Autotune wrapper. Mirrors the fwd autotune layout so callers can toggle +# autotune via the same flag. 
+_rmsnorm_bwd_triton = triton.autotune( + configs=get_autotune_config(), + key=['n_rows', 'n_cols'], + use_cuda_graph=True, reset_to_zero=['dg_ptr'], # dg_ptr is accumulated via RMW in the blocked path; zero it before each autotune trial +)(_rmsnorm_bwd_triton_impl) @triton.jit -def _rmsnorm_bwd_dg_reduce_triton(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n_cols, BLOCK_SIZE_M: tl.constexpr, +def _rmsnorm_bwd_dg_reduce_triton_impl(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n_cols, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): # we want parallelism in N direction # if N is small, we will just use one CU, @@ -349,3 +369,24 @@ def _rmsnorm_bwd_dg_reduce_triton(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n sum_dg = tl.sum(acc, axis=0) tl.store(dg_out_ptr + cols, sum_dg.to(dg_out_ptr.type.element_ty), mask=cols < n_cols) + +def _get_dg_reduce_configs(): + # n_rows is NUM_PRGMS (<=~144 on MI300X) so the M dimension is small. + # The reduce kernel is <1% of bwd cost, so a tight 6-config sweep is plenty; + # bigger sweeps just pay first-call compile tax for marginal gain. + return [ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64}, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128}, num_warps=8), + ] + + +_rmsnorm_bwd_dg_reduce_triton = triton.autotune( + configs=_get_dg_reduce_configs(), + key=['n_rows', 'n_cols'], + use_cuda_graph=True, +)(_rmsnorm_bwd_dg_reduce_triton_impl) + From baaaec9c66c480bced378c2aa1d1f01c7512ac0b Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 20 May 2026 19:39:59 +0000 Subject: [PATCH 2/5] Updated rmsnorm kernel w/ RMW accumulation pattern and autotuning --- .../pytorch/triton_kernels/norms_common.py | 24 ++++++--- .../pytorch/triton_kernels/rmsnorm.py | 52 ++++++++++++++----- 2 files changed, 56 insertions(+), 20
deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index d9d4db330..fe5134ecf 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -260,17 +260,26 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, blk_size = block_size(x_) USE_BLOCKED = use_blocked(x_) NUM_PRGMS = num_programs(x_, sm_margin) - # Both blocked and non-blocked paths now accumulate per-program (NUM_PRGMS rows) - # rather than per-input-row (M rows). For typical workloads (NUM_PRGMS ~= 144 on - # MI300X vs M up to 32k), this shrinks the partial buffer by ~100x and keeps it - # L2-resident, turning the bwd dg RMW into a near-free op vs going to HBM. + # dg accumulation strategy: + # * Large M (rows_per_program > 1): per-program partial buffer of shape + # (NUM_PRGMS, N) accumulated via HBM RMW. Buffer is small, L2-resident, + # RMW near-free; reduce kernel then sums NUM_PRGMS rows. + # * Small M (rows_per_program == 1, i.e. NUM_PRGMS == M): RMW would just + # be load+add+store of a slot only written once. Fall back to pure + # per-row writes into (M, N) and skip the zero-init. + # * Non-blocked path always writes via in-register accumulator (no RMW). + rows_per_program_gt_1 = NUM_PRGMS < M + DG_RMW = USE_BLOCKED and rows_per_program_gt_1 need_reduction = NUM_PRGMS > 1 - # Blocked path uses HBM RMW so the buffer must be zero-initialized. - # Non-blocked path writes once per program; empty is fine. if need_reduction: - if USE_BLOCKED: + if DG_RMW: + # RMW requires zero-init. dg_tmp = torch.zeros(NUM_PRGMS, N, device=x.device, dtype=torch.float32, requires_grad=False) + elif USE_BLOCKED: + # Pure per-row writes; rows are M. + dg_tmp = torch.empty(M, N, device=x.device, dtype=torch.float32, requires_grad=False) else: + # Non-blocked: each program writes its slot unconditionally. 
dg_tmp = torch.empty(NUM_PRGMS, N, device=x.device, dtype=torch.float32, requires_grad=False) else: dg_tmp = None @@ -292,6 +301,7 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, GRAD_OUTPUT_ALIGNED_16=grad_output_aligned_16, DX_ALIGNED_16=dx_aligned_16, DG_ALIGNED_16=dg_aligned_16, + DG_RMW=DG_RMW, ) if not autotune: bwd_kwargs["num_warps"] = 8 diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index fcebeea3e..da53c57e5 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -187,10 +187,20 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p n_rows, n_cols, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, INPUT_ALIGNED_16: tl.constexpr, GRAD_OUTPUT_ALIGNED_16: tl.constexpr, - DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr): - # `dg_ptr` points to a (NUM_PRGMS, n_cols) fp32 partial buffer (pre-zeroed - # by the launcher). Each program accumulates its assigned rows into its - # own slot and a separate reduce kernel finalizes dgamma. + DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr, + DG_RMW: tl.constexpr = True): + # `dg_ptr` points to a fp32 partial buffer that will be summed by the + # reduce kernel afterwards. Two storage modes (selected by the launcher): + # + # DG_RMW=True -> (NUM_PRGMS, n_cols), pre-zeroed. Each program owns one + # slot and accumulates its assigned rows via HBM RMW. + # L2-resident partial buffer makes the RMW near-free + # when rows_per_program >> 1. + # + # DG_RMW=False -> (n_rows, n_cols), uninitialized. Each program writes + # one row per `row_idx` it processes (no RMW). Used + # when n_rows <= NUM_PRGMS so RMW would just be wasted + # load+add+store on a slot that's only written once. 
row_start = tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) # Precomputed per-row invariant: dx = nf * (dz*g - c*x) where @@ -201,12 +211,16 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p # tl.assume(row_start >= 0) if USE_BLOCKED: - # Per-program partial dg slot in the (NUM_PRGMS, n_cols) scratch. - prgm_dg_ptr = dg_ptr + row_start * n_cols + # Per-program partial dg slot when accumulating with RMW. + if DG_RMW: + prgm_dg_ptr = dg_ptr + row_start * n_cols for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): row_input_ptr = input_ptr + row_idx * input_row_stride row_grad_output_ptr = grad_output_ptr + row_idx * output_row_stride row_dx_ptr = dx_ptr + row_idx * input_row_stride + # Per-row dg slot for pure-write mode. + if not DG_RMW: + row_dg_ptr = dg_ptr + row_idx * n_cols # Compute gradients sum of all colums for each row n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1 @@ -270,13 +284,19 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty)) - # Accumulate this row's dg contribution into per-program slot. + # Accumulate (RMW) or write (pure) this row's dg contribution. 
dg = grad_output * x * norm_factor - dg_ptrs = prgm_dg_ptr + cols + if DG_RMW: + dg_ptrs = prgm_dg_ptr + cols + else: + dg_ptrs = row_dg_ptr + cols if DG_ALIGNED_16: dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) - partial_dg = tl.load(dg_ptrs) - tl.store(dg_ptrs, partial_dg + dg) + if DG_RMW: + partial_dg = tl.load(dg_ptrs) + tl.store(dg_ptrs, partial_dg + dg) + else: + tl.store(dg_ptrs, dg) # Handle remainder cols = n_cols_blks * BLOCK_SIZE + col_offsets @@ -298,11 +318,17 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) dg = grad_output * x * norm_factor - dg_ptrs = prgm_dg_ptr + cols + if DG_RMW: + dg_ptrs = prgm_dg_ptr + cols + else: + dg_ptrs = row_dg_ptr + cols if DG_ALIGNED_16: dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) - partial_dg = tl.load(dg_ptrs, mask=mask, other=0.0) - tl.store(dg_ptrs, partial_dg + dg, mask=mask) + if DG_RMW: + partial_dg = tl.load(dg_ptrs, mask=mask, other=0.0) + tl.store(dg_ptrs, partial_dg + dg, mask=mask) + else: + tl.store(dg_ptrs, dg, mask=mask) else: mask = col_offsets < n_cols From de2f7d836f3679d18150c74b74667e9d542cbb05 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 20 May 2026 19:53:30 +0000 Subject: [PATCH 3/5] Added external transpose kernel for LDS optimized transpose --- .../pytorch/triton_kernels/norms_common.py | 51 ++++++++++++++- .../pytorch/triton_kernels/rmsnorm.py | 64 +++++++++++++++++++ 2 files changed, 112 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index fe5134ecf..30f65a097 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -1,6 +1,7 @@ # Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. 
See LICENSE for more information +import os import torch import triton import warnings @@ -23,7 +24,19 @@ _rmsnorm_bwd_triton_impl, _rmsnorm_bwd_dg_reduce_triton, _rmsnorm_bwd_dg_reduce_triton_impl, + _fp8_transpose_2d_triton, + _fp8_transpose_2d_impl, ) + +# Use the external LDS-tiled byte transpose instead of the in-kernel strided +# stores. Default on -- the in-kernel path is uncoalesced and bottlenecks +# every fp8_t shape. Set NVTE_RMS_EXTERNAL_TRANSPOSE=0 to fall back. +_USE_EXTERNAL_TRANSPOSE = os.environ.get("NVTE_RMS_EXTERNAL_TRANSPOSE", "1") == "1" + +_fp8_transpose_kernels = { + True: _fp8_transpose_2d_triton, + False: _fp8_transpose_2d_impl, +} from .layernorm import ( _layernorm_fwd_triton, _layernorm_fwd_triton_impl, @@ -164,6 +177,10 @@ def _te_norm_fwd_triton( out_transpose_ptr = None out_transpose_stride = None FP8_MAX = None + # When True, skip in-kernel strided transpose stores and dispatch a + # separate LDS-tiled transpose kernel after the main fwd. Only applies + # to the rms path for now. + use_external_transpose = False if IS_FP8: MAKE_TRANSPOSE = quantizer.columnwise_usage amax = ( @@ -182,8 +199,11 @@ def _te_norm_fwd_triton( dtype=out._data.dtype, device=device ) out._transpose_invalid = False - out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) - out_transpose_stride = out._transpose.stride(0) + use_external_transpose = _USE_EXTERNAL_TRANSPOSE and kernel == 'rms' + if not use_external_transpose: + # In-kernel strided transpose path; main kernel does the writes. + out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) + out_transpose_stride = out._transpose.stride(0) grid_fwd = lambda meta: (NUM_PRGMS,) kernel_func = _norm_kernels[kernel][autotune] @@ -207,7 +227,9 @@ def _te_norm_fwd_triton( BLOCK_SIZE=BLOCK_SIZE, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, - MAKE_TRANSPOSE=MAKE_TRANSPOSE, + # Gate the in-kernel strided transpose stores off when we'll do the + # transpose externally via the LDS-tiled kernel. 
+ MAKE_TRANSPOSE=(MAKE_TRANSPOSE and not use_external_transpose), ) if kernel == 'layer': kwargs["APPLY_ATOMIC"]=APPLY_ATOMIC @@ -228,6 +250,29 @@ def _te_norm_fwd_triton( kernel_func[grid_fwd](**kwargs) + if use_external_transpose: + # out._data: (N rows, H cols) row-major uint8; out._transpose: (H, N). + transpose_kernel = _fp8_transpose_kernels[autotune] + if autotune: + grid_t = lambda meta: ( + triton.cdiv(N, meta['BLOCK_M']), + triton.cdiv(H, meta['BLOCK_N']), + ) + transpose_kernel[grid_t]( + out._data, out._transpose, + N, H, + out._data.stride(0), out._transpose.stride(0), + ) + else: + BLOCK_M, BLOCK_N = 64, 64 + grid_t = (triton.cdiv(N, BLOCK_M), triton.cdiv(H, BLOCK_N)) + transpose_kernel[grid_t]( + out._data, out._transpose, + N, H, + out._data.stride(0), out._transpose.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + # Reduce and find amax if "not APPLY_ATOMIC" is True for layernorm. if IS_FP8 and not APPLY_ATOMIC: _layernorm_fwd_reduce_triton[(triton.cdiv(N, ATOMIC_REDUCTION_BLOCK_SIZE),)]( diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index da53c57e5..f4f9bea7c 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -416,3 +416,67 @@ def _get_dg_reduce_configs(): use_cuda_graph=True, )(_rmsnorm_bwd_dg_reduce_triton_impl) + +# --------------------------------------------------------------------------- # +# External LDS-tiled byte transpose +# +# Replaces the in-kernel `out_transpose_ptr + cols * stride + row_idx` strided +# stores that the main fwd kernel emits when MAKE_TRANSPOSE=True. Those writes +# are uncoalesced (1 byte/thread to a different cache line each), which is why +# every fp8_t shape in the bench sat at ~1.00x. 
+# +# This kernel reads a (BLOCK_M, BLOCK_N) tile coalesced from the row-major +# fp8 output, transposes it through LDS via `tl.trans`, and writes the +# (BLOCK_N, BLOCK_M) tile coalesced to the column-major transpose buffer. +# +# Operates on the raw uint8 storage so the fp8 dtype is irrelevant to +# correctness. +# --------------------------------------------------------------------------- # +@triton.jit +def _fp8_transpose_2d_impl( + src_ptr, # uint8 ptr, (n_rows, n_cols) row-major + dst_ptr, # uint8 ptr, (n_cols, n_rows) row-major + n_rows, n_cols, + src_stride, # element stride of src row dim (== n_cols when contig) + dst_stride, # element stride of dst row dim (== n_rows when contig) + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # Coalesced read of (BLOCK_M, BLOCK_N) tile (innermost dim = cols). + src_offs = rm[:, None] * src_stride + cn[None, :] + src_mask = (rm[:, None] < n_rows) & (cn[None, :] < n_cols) + tile = tl.load(src_ptr + src_offs, mask=src_mask, other=0) + + # LDS-staged transpose -> (BLOCK_N, BLOCK_M). + tile_t = tl.trans(tile) + + # Coalesced write of (BLOCK_N, BLOCK_M) tile (innermost dim = rows). + dst_offs = cn[:, None] * dst_stride + rm[None, :] + dst_mask = (cn[:, None] < n_cols) & (rm[None, :] < n_rows) + tl.store(dst_ptr + dst_offs, tile_t, mask=dst_mask) + + +def _get_fp8_transpose_configs(): + # 1 B/elem on AMD CDNA3. Keep tile <= ~16 KB so the LDS staging buffer + # fits with room for double-buffering. tl.trans handles bank conflicts. 
+ return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64}, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128}, num_warps=8), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_warps=8), + ] + + +_fp8_transpose_2d_triton = triton.autotune( + configs=_get_fp8_transpose_configs(), + key=['n_rows', 'n_cols'], + use_cuda_graph=True, +)(_fp8_transpose_2d_impl) From 12c680ae2855155b20b57e929d4945927e8ad213 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 20 May 2026 20:04:24 +0000 Subject: [PATCH 4/5] Update test to account for new autotuning --- tests/pytorch/triton_kernels/test_norms.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/triton_kernels/test_norms.py b/tests/pytorch/triton_kernels/test_norms.py index a4f11ba36..487d40524 100644 --- a/tests/pytorch/triton_kernels/test_norms.py +++ b/tests/pytorch/triton_kernels/test_norms.py @@ -298,7 +298,13 @@ def test_norm_triton( zero_centered_gamma=zero_centered_gamma, ) - triton_bwd_outs = triton_bwd_func(*args["triton"]) + # te_rmsnorm_bwd_triton accepts an `autotune` kwarg; te_layernorm_bwd_triton does not. + # Honor the same NVTE_TEST_TRITON_AUTOTUNE env toggle as the fwd path so + # default test runs avoid the autotune compile/sweep cost. 
+ if norm == "rms": + triton_bwd_outs = triton_bwd_func(*args["triton"], autotune=autotune) + else: + triton_bwd_outs = triton_bwd_func(*args["triton"]) if norm == "layer": dx_triton, dgamma_triton, dbeta_triton = triton_bwd_outs From 5f2a9932fa0dba8c91c8d52d0f4ee4cb1030d92f Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 21 May 2026 14:03:21 +0000 Subject: [PATCH 5/5] Trim comments --- .../pytorch/triton_kernels/rmsnorm.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index f4f9bea7c..93b6cd49e 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -203,8 +203,6 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p # load+add+store on a slot that's only written once. row_start = tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) - # Precomputed per-row invariant: dx = nf * (dz*g - c*x) where - # c = nf*nf * grad_sum / n_cols inv_n_cols = 1.0 / n_cols # tl.assume(input_row_stride >= 0) # tl.assume(output_row_stride >= 0) @@ -256,8 +254,10 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p g += 1. grad_sum += tl.sum(grad_output * x * g, axis=0) - # Load r_sigma; hoist per-row invariants used in dx. + # Load r_sigma norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) + # Precomputed per-row invariant: dx = nf * (dz*g - c*x) where + # c = nf*nf * grad_sum / n_cols c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols for blk_idx in tl.range(0, n_cols_blks, num_stages=2): @@ -334,8 +334,6 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p mask = col_offsets < n_cols dg_col_redux = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) - # Hoist gamma load + ZERO_CENTERED adjustment outside the row loop - # since gamma is invariant across rows. 
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) if (ZERO_CENTERED_GAMMA): g += 1. @@ -357,6 +355,8 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) grad_sum = tl.sum(grad_output * x * g, axis=0) + # Precomputed per-row invariant: dx = nf * (dz*g - c*x) where + # c = nf*nf * grad_sum / n_cols c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols grad_input = norm_factor * (grad_output * g - c_scalar * x) @@ -397,7 +397,7 @@ def _rmsnorm_bwd_dg_reduce_triton_impl(dg_in_ptr, dg_out_ptr, dg_in_stride, n_ro def _get_dg_reduce_configs(): - # n_rows is NUM_PRGMS (<=~144 on MI300X) so the M dimension is small. + # n_rows is NUM_PRGMS so the M dimension is small. # The reduce kernel is <1% of bwd cost, so a tight 6-config sweep is plenty; # bigger sweeps just pay first-call compile tax for marginal gain. return [ @@ -422,8 +422,7 @@ def _get_dg_reduce_configs(): # # Replaces the in-kernel `out_transpose_ptr + cols * stride + row_idx` strided # stores that the main fwd kernel emits when MAKE_TRANSPOSE=True. Those writes -# are uncoalesced (1 byte/thread to a different cache line each), which is why -# every fp8_t shape in the bench sat at ~1.00x. +# are uncoalesced (1 byte/thread to a different cache line each). # # This kernel reads a (BLOCK_M, BLOCK_N) tile coalesced from the row-major # fp8 output, transposes it through LDS via `tl.trans`, and writes the