Triton RMSNorm Optimizations #593

Open
Micky774 wants to merge 4 commits into dev from zain/rms-opt

Description

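Optimizes the Triton RMSNorm forward and backward kernels and adds an LDS-tiled FP8 transpose path. Across a representative shape sweep, bf16/fp16 runs with no quantization or with FP8 quantization improve by roughly 10%-50% on the shapes that benefit (mostly M >= 4096), while the remaining shapes are roughly unchanged. FP8 transposed outputs improve by roughly 2.5x-7.5x on most shapes.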

Benchmarks generated by this script.

Results

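In the table below, A is the baseline implementation on dev and B is this PR's implementation. A trailing + marks rows where the speedup is above 1.05x, and a trailing - marks the one row where it is below 0.95x.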

kind dtype    quant       M      N       A ms       B ms    A GB/s    B GB/s   speedup
--------------------------------------------------------------------------------------
bwd  bfloat16 none       29  17389     0.0459     0.0486      67.5      63.6      0.94x  -
bwd  bfloat16 none      256  65536     0.0556     0.0568    1813.9    1778.1      0.98x
bwd  bfloat16 none     1024  32768     0.1209     0.1184    1666.6    1701.0      1.02x
bwd  bfloat16 none     2048  16384     0.1134     0.1066    1775.4    1889.3      1.06x  +
bwd  bfloat16 none     4096   1024     0.0265     0.0266     951.1     948.3      1.00x
bwd  bfloat16 none     4096  12288     0.1702     0.1431    1775.1    2110.5      1.19x  +
bwd  bfloat16 none     8192   2048     0.0568     0.0418    1774.2    2406.9      1.36x  +
bwd  bfloat16 none     8192   4096     0.0678     0.0615    2971.8    3275.5      1.10x  +
bwd  bfloat16 none     8192   8192     0.1132     0.1092    3557.5    3689.2      1.04x
bwd  bfloat16 none    16384   4096     0.1191     0.1020    3382.0    3949.9      1.17x  +
bwd  float16  none       29  17389     0.0500     0.0495      61.9      62.5      1.01x
bwd  float16  none      256  65536     0.0554     0.0569    1821.7    1773.1      0.97x
bwd  float16  none     1024  32768     0.1206     0.1191    1671.0    1691.3      1.01x
bwd  float16  none     2048  16384     0.1134     0.1082    1776.0    1862.1      1.05x
bwd  float16  none     4096   1024     0.0271     0.0260     928.7     970.2      1.04x
bwd  float16  none     4096  12288     0.1688     0.1450    1789.0    2083.7      1.16x  +
bwd  float16  none     8192   2048     0.0567     0.0410    1775.4    2456.2      1.38x  +
bwd  float16  none     8192   4096     0.0676     0.0613    2978.9    3284.0      1.10x  +
bwd  float16  none     8192   8192     0.1128     0.1124    3570.1    3584.2      1.00x
bwd  float16  none    16384   4096     0.1184     0.1020    3400.3    3949.9      1.16x  +
fwd  bfloat16 fp8        29  17389     0.2543     0.2548       6.1       6.1      1.00x
fwd  bfloat16 fp8       256  65536     0.0294     0.0294    1714.1    1716.5      1.00x
fwd  bfloat16 fp8      1024  32768     0.0645     0.0644    1561.2    1563.2      1.00x
fwd  bfloat16 fp8      2048  16384     0.0602     0.0602    1671.7    1673.9      1.00x
fwd  bfloat16 fp8      4096   1024     0.0211     0.0195     596.7     646.9      1.08x  +
fwd  bfloat16 fp8      4096  12288     0.1053     0.1051    1434.6    1436.8      1.00x
fwd  bfloat16 fp8      8192   2048     0.0397     0.0362    1269.3    1392.9      1.10x  +
fwd  bfloat16 fp8      8192   4096     0.0465     0.0417    2166.6    2416.1      1.12x  +
fwd  bfloat16 fp8      8192   8192     0.0620     0.0541    3247.9    3723.7      1.15x  +
fwd  bfloat16 fp8     16384   4096     0.0839     0.0737    2399.9    2733.4      1.14x  +
fwd  bfloat16 fp8_t      29  17389     0.2968     0.2580       6.9       8.0      1.15x  +
fwd  bfloat16 fp8_t     256  65536     0.1723     0.0371     390.3    1811.4      4.64x  +
fwd  bfloat16 fp8_t    1024  32768     0.4436     0.0760     302.7    1766.9      5.84x  +
fwd  bfloat16 fp8_t    2048  16384     0.2115     0.0731     634.8    1836.1      2.89x  +
fwd  bfloat16 fp8_t    4096   1024     0.0397     0.0408     422.8     411.3      0.97x
fwd  bfloat16 fp8_t    4096  12288     0.4059     0.1209     496.1    1665.3      3.36x  +
fwd  bfloat16 fp8_t    8192   2048     0.1072     0.0435     626.6    1542.9      2.46x  +
fwd  bfloat16 fp8_t    8192   4096     0.2243     0.0545     598.7    2464.4      4.12x  +
fwd  bfloat16 fp8_t    8192   8192     0.5651     0.0755     475.1    3557.0      7.49x  +
fwd  bfloat16 fp8_t   16384   4096     0.6377     0.0945     421.0    2840.8      6.75x  +
fwd  bfloat16 none       29  17389     0.2291     0.2279       9.0       9.0      1.01x
fwd  bfloat16 none      256  65536     0.0262     0.0255    2566.4    2639.0      1.03x
fwd  bfloat16 none     1024  32768     0.0504     0.0505    2666.5    2660.2      1.00x
fwd  bfloat16 none     2048  16384     0.0550     0.0540    2442.8    2488.1      1.02x
fwd  bfloat16 none     4096   1024     0.0165     0.0138    1016.7    1213.5      1.19x  +
fwd  bfloat16 none     4096  12288     0.0834     0.0834    2414.5    2414.5      1.00x
fwd  bfloat16 none     8192   2048     0.0354     0.0272    1898.9    2468.6      1.30x  +
fwd  bfloat16 none     8192   4096     0.0499     0.0323    2691.6    4159.2      1.55x  +
fwd  bfloat16 none     8192   8192     0.0682     0.0515    3934.3    5215.3      1.33x  +
fwd  bfloat16 none    16384   4096     0.0953     0.0624    2818.0    4300.3      1.53x  +
fwd  float16  fp8        29  17389     0.2571     0.2517       6.0       6.1      1.02x
fwd  float16  fp8       256  65536     0.0294     0.0293    1716.4    1723.5      1.00x
fwd  float16  fp8      1024  32768     0.0643     0.0644    1566.1    1563.2      1.00x
fwd  float16  fp8      2048  16384     0.0602     0.0602    1673.9    1673.9      1.00x
fwd  float16  fp8      4096   1024     0.0213     0.0199     591.1     633.9      1.07x  +
fwd  float16  fp8      4096  12288     0.1052     0.1051    1435.1    1436.8      1.00x
fwd  float16  fp8      8192   2048     0.0406     0.0364    1240.6    1382.2      1.11x  +
fwd  float16  fp8      8192   4096     0.0484     0.0430    2078.9    2344.1      1.13x  +
fwd  float16  fp8      8192   8192     0.0640     0.0559    3146.4    3601.1      1.14x  +
fwd  float16  fp8     16384   4096     0.0879     0.0766    2291.7    2629.2      1.15x  +
fwd  float16  fp8_t      29  17389     0.2939     0.2548       7.0       8.1      1.15x  +
fwd  float16  fp8_t     256  65536     0.1760     0.0373     382.0    1801.7      4.72x  +
fwd  float16  fp8_t    1024  32768     0.4045     0.0760     332.0    1767.9      5.33x  +
fwd  float16  fp8_t    2048  16384     0.2118     0.0731     633.8    1837.1      2.90x  +
fwd  float16  fp8_t    4096   1024     0.0397     0.0404     423.3     415.3      0.98x
fwd  float16  fp8_t    4096  12288     0.4032     0.1209     499.4    1665.3      3.33x  +
fwd  float16  fp8_t    8192   2048     0.1072     0.0438     626.3    1531.6      2.45x  +
fwd  float16  fp8_t    8192   4096     0.2248     0.0559     597.1    2402.6      4.02x  +
fwd  float16  fp8_t    8192   8192     0.5629     0.0774     476.9    3467.0      7.27x  +
fwd  float16  fp8_t   16384   4096     0.6362     0.0975     422.1    2754.5      6.53x  +
fwd  float16  none       29  17389     0.2735     0.2610       7.5       7.9      1.05x
fwd  float16  none      256  65536     0.0260     0.0254    2586.2    2651.5      1.03x
fwd  float16  none     1024  32768     0.0504     0.0504    2666.5    2664.4      1.00x
fwd  float16  none     2048  16384     0.0542     0.0546    2475.2    2459.0      0.99x
fwd  float16  none     4096   1024     0.0164     0.0139    1024.1    1210.1      1.18x  +
fwd  float16  none     4096  12288     0.0839     0.0838    2400.6    2401.8      1.00x
fwd  float16  none     8192   2048     0.0353     0.0276    1903.2    2432.8      1.28x  +
fwd  float16  none     8192   4096     0.0497     0.0331    2702.4    4053.7      1.50x  +
fwd  float16  none     8192   8192     0.0683     0.0517    3932.0    5191.1      1.32x  +
fwd  float16  none    16384   4096     0.0951     0.0628    2824.0    4275.6      1.51x  +
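
As a reading aid for the columns above, here is a minimal sketch. It assumes that speedup is A ms divided by B ms and that GB/s uses decimal gigabytes; neither is stated in this description, and the benchmark script defines the actual formulas. The helper names are illustrative only.

```python
def speedup(a_ms: float, b_ms: float) -> float:
    """Assumed definition: baseline time over PR time."""
    return a_ms / b_ms


def implied_bytes(ms: float, gb_per_s: float) -> float:
    """Bytes moved implied by a time and a bandwidth (1 GB = 1e9 bytes assumed)."""
    return ms * 1e-3 * gb_per_s * 1e9


# Row: fwd bfloat16 none, M = 8192, N = 8192.
a_ms, b_ms, a_gbps = 0.0682, 0.0515, 3934.3

# Recomputed from the rounded ms columns, so it can differ from the
# printed 1.33x in the last digit.
row_speedup = speedup(a_ms, b_ms)

# Consistent with one bf16 read plus one bf16 write of an (M, N) tensor;
# this is an inference from the numbers, not a statement about the script.
expected_bytes = 2 * 8192 * 8192 * 2
ratio = implied_bytes(a_ms, a_gbps) / expected_bytes
```

Because the ms columns are rounded to four decimals, recomputed speedups agree with the printed column only to within about 0.01x.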

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • RMSNorm bwd dg accumulation refactor. Shrinks the dg_tmp partial buffer from (M, N) to (NUM_PRGMS, N) fp32; each program accumulates its partial dg with a read-modify-write (RMW) on HBM. The smaller buffer is L2-resident on typical workloads, which makes the RMW nearly free.
  • Loop-invariant hoisting.
    • Fwd non-blocked path: gamma load + ZERO_CENTERED_GAMMA adjustment + 1/n_cols hoisted outside the persistent row loop.
    • Bwd non-blocked path: same gamma hoist; inv_n_cols hoisted.
    • Bwd both paths: per-row c_scalar = nf*nf*grad_sum*inv_n_cols computed once before the dx/dg loop; dx expression refactored to nf * (dz*g - c_scalar*x), which saves one multiply per element.
  • Autotune wiring for bwd kernels. _rmsnorm_bwd_triton and _rmsnorm_bwd_dg_reduce_triton now follow the impl + autotune-wrapper dispatch pattern already used by the fwd kernel. te_rmsnorm_bwd_triton takes an autotune: bool = True kwarg; when off it uses the previously-hardcoded num_warps=8 + fixed BLOCK_SIZE_M/N=128/64 reduce tile.
  • External LDS-tiled FP8 transpose kernel. New _fp8_transpose_2d_impl (+ autotune wrapper) replaces the in-kernel out_transpose_ptr + cols * stride + row_idx strided byte stores that were uncoalesced (one byte per thread to a different cache line). The new kernel does a coalesced (BLOCK_M, BLOCK_N) read, tl.trans() for LDS-staged transpose, then coalesced strided write. Gated by env var NVTE_RMS_EXTERNAL_TRANSPOSE (default on); set to 0 to fall back to the in-kernel path.
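
The backward-pass changes above can be modeled on the CPU as follows. This is a minimal pure-Python sketch, not the Triton kernel. It assumes that nf is the per-row reciprocal RMS, that grad_sum is the row sum of dz*g*x, and that rows are assigned to programs with a stride of NUM_PRGMS. None of these details is spelled out in this description, and ZERO_CENTERED_GAMMA is ignored.

```python
import math


def rmsnorm_fwd_row(x, g, eps=1e-5):
    """Reference forward for one row: y = x * nf * g."""
    nf = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * nf * w for v, w in zip(x, g)]


def rmsnorm_bwd(x, g, dz, num_prgms, eps=1e-5):
    """Model of the persistent backward loop with a (num_prgms, n) dg buffer."""
    m, n = len(x), len(g)
    inv_n_cols = 1.0 / n  # loop-invariant, hoisted out of the row loop
    dx = [[0.0] * n for _ in range(m)]
    # One partial-dg row per program instead of one per input row.
    dg_tmp = [[0.0] * n for _ in range(num_prgms)]
    for pid in range(num_prgms):
        # Assumed schedule: program pid handles rows pid, pid + num_prgms, ...
        for row in range(pid, m, num_prgms):
            xr, dzr = x[row], dz[row]
            nf = 1.0 / math.sqrt(sum(v * v for v in xr) * inv_n_cols + eps)
            grad_sum = sum(d * w * v for d, w, v in zip(dzr, g, xr))
            # Per-row scalar computed once before the element loop.
            c_scalar = nf * nf * grad_sum * inv_n_cols
            for col in range(n):
                dx[row][col] = nf * (dzr[col] * g[col] - c_scalar * xr[col])
                # Read-modify-write into this program's partial row.
                dg_tmp[pid][col] += dzr[col] * xr[col] * nf
    # Final reduction over programs (the role of the dg reduce kernel).
    dg = [sum(col) for col in zip(*dg_tmp)]
    return dx, dg, dg_tmp
```

In this model the partial buffer holds num_prgms * n fp32 values regardless of M, which is why it can stay cache-resident as M grows.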
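
The access pattern of the tiled transpose can be illustrated with a CPU model. The sketch below is plain Python over a flat byte buffer, not the `_fp8_transpose_2d_impl` kernel; the function name and tile sizes are hypothetical. It only shows why both the read and the write touch contiguous runs of bytes once the transpose happens inside a tile.

```python
def transpose_tiled(src: bytes, m: int, n: int, block_m: int = 4, block_n: int = 4) -> bytes:
    """Transpose a flat row-major (m, n) byte matrix into a flat row-major (n, m) one."""
    dst = bytearray(m * n)
    for m0 in range(0, m, block_m):
        for n0 in range(0, n, block_n):
            rows = range(m0, min(m0 + block_m, m))
            cols = range(n0, min(n0 + block_n, n))
            # Read: each tile row is one contiguous slice of src.
            tile = [src[r * n + cols.start : r * n + cols.stop] for r in rows]
            # Transpose inside the tile (the step tl.trans() performs in the kernel).
            tile_t = [bytes(row[j] for row in tile) for j in range(len(cols))]
            # Write: each transposed tile row is one contiguous slice of dst.
            for j, c in enumerate(cols):
                dst[c * m + rows.start : c * m + rows.stop] = tile_t[j]
    return bytes(dst)
```

By contrast, storing element (row, col) directly to `col * m + row` writes one byte per column into locations m bytes apart, which is the uncoalesced pattern the in-kernel path had.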

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
