diff --git a/tests/pytorch/triton_kernels/test_norms.py b/tests/pytorch/triton_kernels/test_norms.py index a4f11ba36..487d40524 100644 --- a/tests/pytorch/triton_kernels/test_norms.py +++ b/tests/pytorch/triton_kernels/test_norms.py @@ -298,7 +298,13 @@ def test_norm_triton( zero_centered_gamma=zero_centered_gamma, ) - triton_bwd_outs = triton_bwd_func(*args["triton"]) + # te_rmsnorm_bwd_triton accepts an `autotune` kwarg; te_layernorm_bwd_triton does not. + # Honor the same NVTE_TEST_TRITON_AUTOTUNE env toggle as the fwd path so + # default test runs avoid the autotune compile/sweep cost. + if norm == "rms": + triton_bwd_outs = triton_bwd_func(*args["triton"], autotune=autotune) + else: + triton_bwd_outs = triton_bwd_func(*args["triton"]) if norm == "layer": dx_triton, dgamma_triton, dbeta_triton = triton_bwd_outs diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index ed4002f2c..30f65a097 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -1,6 +1,7 @@ # Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information +import os import torch import triton import warnings @@ -20,8 +21,22 @@ _rmsnorm_fwd_triton, _rmsnorm_fwd_triton_impl, _rmsnorm_bwd_triton, + _rmsnorm_bwd_triton_impl, _rmsnorm_bwd_dg_reduce_triton, + _rmsnorm_bwd_dg_reduce_triton_impl, + _fp8_transpose_2d_triton, + _fp8_transpose_2d_impl, ) + +# Use the external LDS-tiled byte transpose instead of the in-kernel strided +# stores. Default on -- the in-kernel path is uncoalesced and bottlenecks +# every fp8_t shape. Set NVTE_RMS_EXTERNAL_TRANSPOSE=0 to fall back. +_USE_EXTERNAL_TRANSPOSE = os.environ.get("NVTE_RMS_EXTERNAL_TRANSPOSE", "1") == "1" + +_fp8_transpose_kernels = { + True: _fp8_transpose_2d_triton, + False: _fp8_transpose_2d_impl, +} from .layernorm import ( _layernorm_fwd_triton, _layernorm_fwd_triton_impl, @@ -41,6 +56,16 @@ False: _layernorm_fwd_triton_impl, } } + +_rmsnorm_bwd_kernels = { + True: _rmsnorm_bwd_triton, + False: _rmsnorm_bwd_triton_impl, +} + +_rmsnorm_bwd_dg_reduce_kernels = { + True: _rmsnorm_bwd_dg_reduce_triton, + False: _rmsnorm_bwd_dg_reduce_triton_impl, +} # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd def te_rmsnorm_fwd_triton( input: torch.Tensor, @@ -152,6 +177,10 @@ def _te_norm_fwd_triton( out_transpose_ptr = None out_transpose_stride = None FP8_MAX = None + # When True, skip in-kernel strided transpose stores and dispatch a + # separate LDS-tiled transpose kernel after the main fwd. Only applies + # to the rms path for now. + use_external_transpose = False if IS_FP8: MAKE_TRANSPOSE = quantizer.columnwise_usage amax = ( @@ -170,8 +199,11 @@ def _te_norm_fwd_triton( dtype=out._data.dtype, device=device ) out._transpose_invalid = False - out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) - out_transpose_stride = out._transpose.stride(0) + use_external_transpose = _USE_EXTERNAL_TRANSPOSE and kernel == 'rms' + if not use_external_transpose: + # In-kernel strided transpose path; main kernel does the writes. + out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) + out_transpose_stride = out._transpose.stride(0) grid_fwd = lambda meta: (NUM_PRGMS,) kernel_func = _norm_kernels[kernel][autotune] @@ -195,7 +227,9 @@ def _te_norm_fwd_triton( BLOCK_SIZE=BLOCK_SIZE, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, - MAKE_TRANSPOSE=MAKE_TRANSPOSE, + # Gate the in-kernel strided transpose stores off when we'll do the + # transpose externally via the LDS-tiled kernel. + MAKE_TRANSPOSE=(MAKE_TRANSPOSE and not use_external_transpose), ) if kernel == 'layer': kwargs["APPLY_ATOMIC"]=APPLY_ATOMIC @@ -216,6 +250,29 @@ def _te_norm_fwd_triton( kernel_func[grid_fwd](**kwargs) + if use_external_transpose: + # out._data: (N rows, H cols) row-major uint8; out._transpose: (H, N). + transpose_kernel = _fp8_transpose_kernels[autotune] + if autotune: + grid_t = lambda meta: ( + triton.cdiv(N, meta['BLOCK_M']), + triton.cdiv(H, meta['BLOCK_N']), + ) + transpose_kernel[grid_t]( + out._data, out._transpose, + N, H, + out._data.stride(0), out._transpose.stride(0), + ) + else: + BLOCK_M, BLOCK_N = 64, 64 + grid_t = (triton.cdiv(N, BLOCK_M), triton.cdiv(H, BLOCK_N)) + transpose_kernel[grid_t]( + out._data, out._transpose, + N, H, + out._data.stride(0), out._transpose.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + # Reduce and find amax if "not APPLY_ATOMIC" is True for layernorm. if IS_FP8 and not APPLY_ATOMIC: _layernorm_fwd_reduce_triton[(triton.cdiv(N, ATOMIC_REDUCTION_BLOCK_SIZE),)]( @@ -234,7 +291,7 @@ def _te_norm_fwd_triton( # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd -def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): +def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, autotune: bool = True): # may take non-contiguous inputs dz_ = dz.contiguous() x_ = x.contiguous() @@ -248,25 +305,72 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): blk_size = block_size(x_) USE_BLOCKED = use_blocked(x_) NUM_PRGMS = num_programs(x_, sm_margin) - need_reduction = N > 1 - dg_tmp_rows = x_.shape[0] if use_blocked(x_) else num_programs(x_, sm_margin) - dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None + # dg accumulation strategy: + # * Large M (rows_per_program > 1): per-program partial buffer of shape + # (NUM_PRGMS, N) accumulated via HBM RMW. Buffer is small, L2-resident, + # RMW near-free; reduce kernel then sums NUM_PRGMS rows. + # * Small M (rows_per_program == 1, i.e. NUM_PRGMS == M): RMW would just + # be load+add+store of a slot only written once. Fall back to pure + # per-row writes into (M, N) and skip the zero-init. + # * Non-blocked path always writes via in-register accumulator (no RMW). + rows_per_program_gt_1 = NUM_PRGMS < M + DG_RMW = USE_BLOCKED and rows_per_program_gt_1 + need_reduction = NUM_PRGMS > 1 + if need_reduction: + if DG_RMW: + # RMW requires zero-init. + dg_tmp = torch.zeros(NUM_PRGMS, N, device=x.device, dtype=torch.float32, requires_grad=False) + elif USE_BLOCKED: + # Pure per-row writes; rows are M. + dg_tmp = torch.empty(M, N, device=x.device, dtype=torch.float32, requires_grad=False) + else: + # Non-blocked: each program writes its slot unconditionally. + dg_tmp = torch.empty(NUM_PRGMS, N, device=x.device, dtype=torch.float32, requires_grad=False) + else: + dg_tmp = None input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) * x_.dtype.itemsize % 16 == 0) grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(0) * dz_.dtype.itemsize % 16 == 0) dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(0) * dx.dtype.itemsize % 16 == 0) dg_target = dg_tmp if need_reduction else dgamma dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) * dg_target.dtype.itemsize % 16 == 0) + grid_bwd = lambda meta: (NUM_PRGMS, ) - _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, - x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, - USE_BLOCKED, NUM_PRGMS, input_aligned_16, grad_output_aligned_16, - dx_aligned_16, dg_aligned_16, num_warps=8) + bwd_kernel = _rmsnorm_bwd_kernels[autotune] + bwd_kwargs = dict( + n_rows=M, n_cols=N, + ZERO_CENTERED_GAMMA=zero_centered_gamma, + BLOCK_SIZE=blk_size, + USE_BLOCKED=USE_BLOCKED, NUM_PRGMS=NUM_PRGMS, + INPUT_ALIGNED_16=input_aligned_16, + GRAD_OUTPUT_ALIGNED_16=grad_output_aligned_16, + DX_ALIGNED_16=dx_aligned_16, + DG_ALIGNED_16=dg_aligned_16, + DG_RMW=DG_RMW, + ) + if not autotune: + bwd_kwargs["num_warps"] = 8 + bwd_kernel[grid_bwd]( + dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, + x_.stride(0), dz_.stride(0), + **bwd_kwargs, + ) if need_reduction: - grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] - _rmsnorm_bwd_dg_reduce_triton[grid_reduce](dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], - BLOCK_SIZE_M=128, BLOCK_SIZE_N=64) + reduce_kernel = _rmsnorm_bwd_dg_reduce_kernels[autotune] + if autotune: + grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] + reduce_kernel[grid_reduce]( + dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], + ) + else: + # Match the previously-hardcoded tile when autotune is disabled. + BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 64 + grid_reduce = (triton.cdiv(N, BLOCK_SIZE_N),) + reduce_kernel[grid_reduce]( + dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, + ) return dx, dgamma diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 5ecb48eb7..93b6cd49e 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -140,22 +140,24 @@ def _rmsnorm_fwd_triton_impl( else: mask = col_offsets < n_cols + # gamma is invariant across rows -- load + ZERO_CENTERED adjustment once per program. + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + if (ZERO_CENTERED_GAMMA): + g += 1 + inv_n_cols = 1.0 / n_cols for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets if INPUT_ALIGNED_16: input_ptrs = tl.multiple_of(input_ptrs, (16, )) row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) - g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) row_norm = row * row row_norm = tl.sum(row_norm, axis=-1) - norm_factor = tl.math.rsqrt((row_norm / n_cols) + epsilon) + norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) # Store rsigma (norm_factor) rsigma_output_ptr = rsigma_ptr + row_idx tl.store(rsigma_output_ptr, norm_factor) - if (ZERO_CENTERED_GAMMA): - g += 1 rms_norm = row * norm_factor * g output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets @@ -181,29 +183,45 @@ def _rmsnorm_fwd_triton_impl( _rmsnorm_fwd_triton = autotune_dec(_rmsnorm_fwd_triton_impl) @triton.jit -def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr, input_row_stride, output_row_stride, +def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr, input_row_stride, output_row_stride, n_rows, n_cols, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, INPUT_ALIGNED_16: tl.constexpr, GRAD_OUTPUT_ALIGNED_16: tl.constexpr, - DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr): + DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr, + DG_RMW: tl.constexpr = True): + # `dg_ptr` points to a fp32 partial buffer that will be summed by the + # reduce kernel afterwards. Two storage modes (selected by the launcher): + # + # DG_RMW=True -> (NUM_PRGMS, n_cols), pre-zeroed. Each program owns one + # slot and accumulates its assigned rows via HBM RMW. + # L2-resident partial buffer makes the RMW near-free + # when rows_per_program >> 1. + # + # DG_RMW=False -> (n_rows, n_cols), uninitialized. Each program writes + # one row per `row_idx` it processes (no RMW). Used + # when n_rows <= NUM_PRGMS so RMW would just be wasted + # load+add+store on a slot that's only written once. row_start = tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) + inv_n_cols = 1.0 / n_cols # tl.assume(input_row_stride >= 0) # tl.assume(output_row_stride >= 0) # tl.assume(row_start >= 0) if USE_BLOCKED: - for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=1): + # Per-program partial dg slot when accumulating with RMW. + if DG_RMW: + prgm_dg_ptr = dg_ptr + row_start * n_cols + for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): row_input_ptr = input_ptr + row_idx * input_row_stride row_grad_output_ptr = grad_output_ptr + row_idx * output_row_stride row_dx_ptr = dx_ptr + row_idx * input_row_stride - row_dg_ptr = dg_ptr + row_idx * input_row_stride + # Per-row dg slot for pure-write mode. + if not DG_RMW: + row_dg_ptr = dg_ptr + row_idx * n_cols # Compute gradients sum of all colums for each row n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1 - # older version of triton doesn't accept below init - # comment out for now to make it compatible with triton 3.1 - # grad_sum: tl.float32 = 0.0 grad_sum = 0.0 for blk_idx in tl.range(0, n_cols_blks, num_stages=2): cols = blk_idx * BLOCK_SIZE + col_offsets @@ -238,6 +256,9 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d # Load r_sigma norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) + # Precomputed per-row invariant: dx = nf * (dz*g - c*x) where + # c = nf*nf * grad_sum / n_cols + c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols for blk_idx in tl.range(0, n_cols_blks, num_stages=2): cols = blk_idx * BLOCK_SIZE + col_offsets @@ -256,19 +277,26 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d g = tl.load(g_ptrs).to(tl.float32) if (ZERO_CENTERED_GAMMA): g += 1. - grad_input = grad_output * norm_factor * g - (norm_factor * norm_factor * norm_factor) * x * (grad_sum / - n_cols) + grad_input = norm_factor * (grad_output * g - c_scalar * x) dx_ptrs = row_dx_ptr + cols if DX_ALIGNED_16: dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty)) + # Accumulate (RMW) or write (pure) this row's dg contribution. dg = grad_output * x * norm_factor - dg_ptrs = row_dg_ptr + cols + if DG_RMW: + dg_ptrs = prgm_dg_ptr + cols + else: + dg_ptrs = row_dg_ptr + cols if DG_ALIGNED_16: dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) - tl.store(dg_ptrs, dg.to(tl.float32)) + if DG_RMW: + partial_dg = tl.load(dg_ptrs) + tl.store(dg_ptrs, partial_dg + dg) + else: + tl.store(dg_ptrs, dg) # Handle remainder cols = n_cols_blks * BLOCK_SIZE + col_offsets @@ -282,8 +310,7 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32) if (ZERO_CENTERED_GAMMA): g += 1. - grad_input = grad_output * norm_factor * g - (norm_factor * norm_factor * norm_factor) * x * (grad_sum / - n_cols) + grad_input = norm_factor * (grad_output * g - c_scalar * x) dx_ptrs = row_dx_ptr + cols if DX_ALIGNED_16: @@ -291,15 +318,26 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) dg = grad_output * x * norm_factor - dg_ptrs = row_dg_ptr + cols + if DG_RMW: + dg_ptrs = prgm_dg_ptr + cols + else: + dg_ptrs = row_dg_ptr + cols if DG_ALIGNED_16: dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) - tl.store(dg_ptrs, dg.to(tl.float32), mask=mask) + if DG_RMW: + partial_dg = tl.load(dg_ptrs, mask=mask, other=0.0) + tl.store(dg_ptrs, partial_dg + dg, mask=mask) + else: + tl.store(dg_ptrs, dg, mask=mask) else: mask = col_offsets < n_cols dg_col_redux = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + if (ZERO_CENTERED_GAMMA): + g += 1. + for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets @@ -314,25 +352,33 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) - g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) - if (ZERO_CENTERED_GAMMA): - g += 1. norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) grad_sum = tl.sum(grad_output * x * g, axis=0) + # Precomputed per-row invariant: dx = nf * (dz*g - c*x) where + # c = nf*nf * grad_sum / n_cols + c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols - grad_input = grad_output * norm_factor * g - (norm_factor * norm_factor * norm_factor) * x * (grad_sum / - n_cols) + grad_input = norm_factor * (grad_output * g - c_scalar * x) tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) dg = grad_output * x * norm_factor dg_col_redux += dg.to(tl.float32) - tl.store(dg_ptr + tl.program_id(0) * input_row_stride + col_offsets, dg_col_redux, mask=mask) + tl.store(dg_ptr + row_start * n_cols + col_offsets, dg_col_redux, mask=mask) + + +# Autotune wrapper. Mirrors the fwd autotune layout so callers can toggle +# autotune via the same flag. +_rmsnorm_bwd_triton = triton.autotune( + configs=get_autotune_config(), + key=['n_rows', 'n_cols'], + use_cuda_graph=True, +)(_rmsnorm_bwd_triton_impl) @triton.jit -def _rmsnorm_bwd_dg_reduce_triton(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n_cols, BLOCK_SIZE_M: tl.constexpr, +def _rmsnorm_bwd_dg_reduce_triton_impl(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n_cols, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): # we want parallelism in N direction # if N is small, we will just use one CU, @@ -349,3 +395,87 @@ def _rmsnorm_bwd_dg_reduce_triton(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n sum_dg = tl.sum(acc, axis=0) tl.store(dg_out_ptr + cols, sum_dg.to(dg_out_ptr.type.element_ty), mask=cols < n_cols) + +def _get_dg_reduce_configs(): + # n_rows is NUM_PRGMS so the M dimension is small. + # The reduce kernel is <1% of bwd cost, so a tight 6-config sweep is plenty; + # bigger sweeps just pay first-call compile tax for marginal gain. + return [ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64}, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128}, num_warps=8), + ] + + +_rmsnorm_bwd_dg_reduce_triton = triton.autotune( + configs=_get_dg_reduce_configs(), + key=['n_rows', 'n_cols'], + use_cuda_graph=True, +)(_rmsnorm_bwd_dg_reduce_triton_impl) + + +# --------------------------------------------------------------------------- # +# External LDS-tiled byte transpose +# +# Replaces the in-kernel `out_transpose_ptr + cols * stride + row_idx` strided +# stores that the main fwd kernel emits when MAKE_TRANSPOSE=True. Those writes +# are uncoalesced (1 byte/thread to a different cache line each). +# +# This kernel reads a (BLOCK_M, BLOCK_N) tile coalesced from the row-major +# fp8 output, transposes it through LDS via `tl.trans`, and writes the +# (BLOCK_N, BLOCK_M) tile coalesced to the column-major transpose buffer. +# +# Operates on the raw uint8 storage so the fp8 dtype is irrelevant to +# correctness. +# --------------------------------------------------------------------------- # +@triton.jit +def _fp8_transpose_2d_impl( + src_ptr, # uint8 ptr, (n_rows, n_cols) row-major + dst_ptr, # uint8 ptr, (n_cols, n_rows) row-major + n_rows, n_cols, + src_stride, # element stride of src row dim (== n_cols when contig) + dst_stride, # element stride of dst row dim (== n_rows when contig) + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # Coalesced read of (BLOCK_M, BLOCK_N) tile (innermost dim = cols). + src_offs = rm[:, None] * src_stride + cn[None, :] + src_mask = (rm[:, None] < n_rows) & (cn[None, :] < n_cols) + tile = tl.load(src_ptr + src_offs, mask=src_mask, other=0) + + # LDS-staged transpose -> (BLOCK_N, BLOCK_M). + tile_t = tl.trans(tile) + + # Coalesced write of (BLOCK_N, BLOCK_M) tile (innermost dim = rows). + dst_offs = cn[:, None] * dst_stride + rm[None, :] + dst_mask = (cn[:, None] < n_cols) & (rm[None, :] < n_rows) + tl.store(dst_ptr + dst_offs, tile_t, mask=dst_mask) + + +def _get_fp8_transpose_configs(): + # 1 B/elem on AMD CDNA3. Keep tile <= ~16 KB so the LDS staging buffer + # fits with room for double-buffering. tl.trans handles bank conflicts. + return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64}, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128}, num_warps=8), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_warps=8), + ] + + +_fp8_transpose_2d_triton = triton.autotune( + configs=_get_fp8_transpose_configs(), + key=['n_rows', 'n_cols'], + use_cuda_graph=True, +)(_fp8_transpose_2d_impl)