Skip to content

[Common] Optimize fused router forward/backward kernels#3012

Open
harryzhou2000 wants to merge 15 commits into
NVIDIA:mainfrom
harryzhou2000:hhanyu/router_fix_p3R
Open

[Common] Optimize fused router forward/backward kernels#3012
harryzhou2000 wants to merge 15 commits into
NVIDIA:mainfrom
harryzhou2000:hhanyu/router_fix_p3R

Conversation

@harryzhou2000
Copy link
Copy Markdown
Member

@harryzhou2000 harryzhou2000 commented May 19, 2026

Summary

Optimizes the fused router CUDA kernels introduced in #2821 (fused_topk_with_score_function and fused_score_for_moe_aux_loss). Achieves significant bandwidth improvements for large expert counts and topk values while preserving identical performance for smaller configurations (e.g., E=256, topk=4).

Key results (B300, float32, 8192 tokens):

  • Forward (E=2304, K=36, softmax): 673 → 964 GB/s (+43%)
  • Backward (E=2304, K=36, softmax): 543 → 2766 GB/s (+410%)
  • Forward (E=512, K=4): no regression (±0.3%)

Changes

Forward kernels

  • Persistent grid with async double-buffered prefetch: RawAsyncLoader<T> uses cp.async (sm_80+) for non-blocking global→shmem loads. Occupancy-aware grid sizing (compute_persistent_grid) keeps all SMs saturated across multiple rounds.
  • Packed 8-bit radix histogram: Reduces radix topk register usage from 32 to 4 registers by packing 16 bucket counts into 4×u32 with 8-bit fields. Eliminates local memory spill at large E.
  • Compile-time score function dispatch: ScoreFunc template parameter with if constexpr removes runtime branches from the hot loop.
  • Simple kernel path for small topk: When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), dispatches to a lightweight kernel matching the original structure — no async loader, no persistent grid — avoiding scheduling overhead that dominates at small K.

Backward kernels

  • Two-pass fused design: Pass 1 accumulates warp-level sums via register reduction + warp_allreduce_sum. Pass 2 computes per-element gradients using scalar helpers. Eliminates the comp_buf shared memory buffer (saves E × warps × 4 bytes per block).
  • Double-buffered async loading: All backward inputs (grad, activation, mask) loaded through RawAsyncLoader with always-on double buffering.

Infrastructure

  • async_loader.h: RawAsyncLoader<T>, compute_persistent_grid(), choose_num_buffers(), vectorized global store/fill helpers.
  • NVTE_RADIX_TOPK_THRESHOLD env var (default 8): configurable naive↔radix crossover.
  • Templated warp_reduce_on_shmem<T, ReduceFuncType> eliminates function-pointer overhead.

Hardening

  • Host-side: num_tokens * num_experts <= INT_MAX, topk ∈ [1, E], topk % group_topk == 0
  • Device-side: assert(data_size <= kMaxExpertsRadixTopk) in radix path
  • Correct cudaDevAttrMaxSharedMemoryPerMultiprocessor for buffer-count decision
  • Fix: single-buffer prefetch clobber when shmem is too tight for double buffering

Compatibility

  • No regression for small configs: The simple forward kernel path is an exact replica of the original kernel structure, ensuring E=256/topk=4 (common in standard MoE) performs identically.
  • All existing tests pass: 891/891 test_fused_router.py tests pass, 117 skipped (fp8/multi-node).
  • No API changes: Same Python/C++ interface, same output semantics.
  • Tunable: Set NVTE_RADIX_TOPK_THRESHOLD=0 to force radix everywhere, or =16 to use naive for topk<16.

Performance (B300 SXM6, sm_103, float32, 8192 tokens)

Effective bandwidth (GB/s) is computed as the minimum bytes that must be transferred to/from global memory for one kernel invocation, divided by the measured wall time. For example, the topk forward kernel reads logits (T×E×dtype) and writes probs (T×E×dtype), routing_map (T×E×1), and intermediate_output (T×E×4). This metric captures how well the kernel utilizes memory bandwidth — higher is better, with the device peak around 8 TB/s on B300. Config format is num_experts/topk.

Full benchmark table (softmax)
kernel pass config before after
topk fprop 512/4 1779 1784 (+0.3%)
topk fprop 512/8 798 904 (+13%)
topk fprop 512/22 514 924 (+80%)
topk fprop 512/36 499 908 (+82%)
topk fprop 2304/4 1803 1802 (0%)
topk fprop 2304/8 660 993 (+51%)
topk fprop 2304/22 602 972 (+61%)
topk fprop 2304/36 673 964 (+43%)
topk bprop 512/22 3391 5362 (+58%)
topk bprop 2304/36 543 2766 (+410%)
aux_loss fprop 512/22 519 896 (+73%)
aux_loss fprop 2304/36 645 891 (+38%)
aux_loss bprop 512/22 5289 6155 (+16%)
aux_loss bprop 2304/36 2272 4201 (+85%)
Full benchmark table (sigmoid)
kernel pass config before after
topk fprop 512/4 1728 1736 (+0.5%)
topk fprop 512/22 470 891 (+90%)
topk fprop 2304/36 639 798 (+25%)
topk bprop 512/22 3169 4398 (+39%)
topk bprop 2304/36 533 2274 (+327%)
aux_loss fprop 512/22 475 912 (+92%)
aux_loss fprop 2304/36 598 867 (+45%)
aux_loss bprop 2304/36 1965 2757 (+40%)

@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p3R branch 2 times, most recently from 14a302c to a805f38 Compare May 19, 2026 10:22
@harryzhou2000 harryzhou2000 marked this pull request as ready for review May 20, 2026 08:29
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 20, 2026

Greptile Summary

This PR optimizes the fused router CUDA kernels by introducing persistent grids, double-buffered cp.async loads, packed 8-bit radix histograms, compile-time score-function dispatch, and a two-pass backward design that eliminates the comp_buf shared-memory buffer. Previous review concerns (premature shmem check, hardcoded backward buffer count) are addressed.

  • Forward: A new simple kernel handles topk below the threshold; the optimized path uses RawAsyncLoader with double buffering and a 4-register packed radix histogram.
  • Backward: comp_buf shmem is eliminated; all inputs loaded via RawAsyncLoader; two-pass design avoids shmem for intermediate sums.
  • Infrastructure: async_loader.h adds RawAsyncLoader, choose_num_buffers, compute_persistent_grid, and vectorized global store/fill helpers.

Confidence Score: 5/5

Safe to merge; all 891 tests pass and the kernel logic is correct throughout.

The implementation is carefully crafted: premature shmem-check and hardcoded backward buffer count from the prior review are both fixed. The async-loader pipeline, packed radix histogram, and two-pass backward gradient math are all correct. The behavioral change in the topk pre-softmax backward (non-selected experts now receive the mathematically correct softmax-coupling gradient rather than zero) does not affect test outcomes.

fused_topk_with_score_function.cu - the pre-softmax backward behavioral change and the choose_num_buffers argument pattern are worth a second look before landing.

Important Files Changed

Filename Overview
transformer_engine/common/fused_router/async_loader.h New header: RawAsyncLoader with double-buffered cp.async, compute_persistent_grid, choose_num_buffers, vectorized store/fill helpers. Unaligned fallback correctly omits cp_async_commit. Wait uses wait_prior(0), correct since at most one async group is in flight per loader.
transformer_engine/common/fused_router/fused_topk_with_score_function.cu Major refactor: simple forward kernel (small topk), optimized kernel (async loader + persistent grid + radix topk), restructured backward (two-pass, eliminates comp_buf). Premature shmem-check fixed. Undocumented behavioral change: pre-softmax backward now writes softmax-coupling gradients to non-selected expert logits instead of zeroing them.
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Same structural improvements as topk file. Premature shmem check fixed. choose_num_buffers semantic mismatch present in backward launcher. No pre-softmax coupling issue since aux_loss always applies softmax to all experts.
transformer_engine/common/fused_router/utils.h warp_reduce_on_shmem templated on ReduceFuncType; scalar helpers added; kMaxExpertsRadixTopk=8160 constraint and assert added; radix histogram packs 16 bucket counts into 4 u32 registers using 8-bit fields, correctly guarded against overflow.
transformer_engine/common/fused_router/fused_moe_aux_loss.cu Minimal change: updates warp_reduce_on_shmem call-site to the new compile-time-dispatch template syntax. Correct and consistent.

Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment thread transformer_engine/common/fused_router/async_loader.h
@tdophung tdophung self-assigned this May 20, 2026
Replace multi-loop preprocess (separate clear/load/score/save/bias loops)
with single fused loops per score function in all 4 kernel paths (topk
forward, topk backward, aux_loss forward, aux_loss backward).

Replace multi-pass backward (array-based helpers + comp_buf shmem) with
a two-pass approach using scalar helpers:
  Pass 1: reduction — warp-level sums via warp_allreduce_sum()
  Pass 2: element-wise — scalar gradient computation → write to global

Add scalar helpers to utils.h: sigmoid_scalar, sqrtsoftplus_scalar,
sigmoid_bwd_scalar, sqrtsoftplus_bwd_scalar, normalize_bwd_scalar,
softmax_bwd_scalar.

Remove dead array helpers from utils.h: apply_sigmoid_on_float,
apply_sigmoid_bwd_on_float, apply_sqrtsoftplus_on_float,
apply_sqrtsoftplus_bwd_on_float, apply_softmax_bwd_on_float,
masked_warp_reduce_on_shmem.

Backward shmem reduced by E×W×sizeof(float) per kernel (comp_buf
eliminated).  Net -226 lines across 3 files.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Add async_loader.h with:
  - RawAsyncLoader<T>: cp.async on sm_80+, int4 fallback on sm_70,
    stores data in original type (no conversion during copy)
  - compute_persistent_grid(): occupancy-based grid sizing
  - choose_num_buffers(): shmem-aware 1-vs-2 buffer decision
  - vec_fill_global(), vec_store_global(): vectorized output helpers

Forward kernels (topk + aux_loss):
  - Logits loaded via RawAsyncLoader with double-buffered prefetch
  - Persistent grid replaces 1-shot grid launch
  - DataType→CompType conversion during compute, not during load
  - vec_fill_global for clearing probs/routing_map

Backward kernels (topk + aux_loss):
  - All inputs loaded via RawAsyncLoader (topk: 3 loaders for
    grad/act/mask; aux_loss: 2 loaders for grad/act)
  - Always double-buffered (kBwdNumBuffers=2, kAuxBwdNumBuffers=2)
  - Persistent grid with occupancy-based sizing

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Replace counts[16] + total_counts[16] (32 registers) with 4 packed u32
registers using 8-bit fields (4 counters per register).  Eliminates
massive register spill to local memory on large kernels (81% of L1
traffic on E=2304, K=36).

Add kMaxExpertsRadixTopk constant (8160 = 255 * 32) and runtime checks
in both forward launchers to guard against 8-bit overflow.  All current
MoE configurations (max E=2304) are well within this limit.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
…dispatch

Replace runtime score_function parameter in all 4 kernel __global__
functions with template int ScoreFunc (0=sigmoid, 1=softmax,
2=sqrtsoftplus).  All score_function branches now use if constexpr,
eliminating dead-code register pressure and branch overhead.

Forward launchers dispatch on TopkFunc × ScoreFunc = 6 instantiations
per DataType.  Backward launchers dispatch on ScoreFunc = 3
instantiations per DataType.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Fix broken topk < 0 threshold (radix was always selected, naive
unreachable).  Replace with configurable NVTE_RADIX_TOPK_THRESHOLD
env var (default 0, i.e. always use radix).  Set to 16 to restore
the old naive-for-small-K behavior.

Uses the standard TE pattern: static local + getenv (read once,
cached for process lifetime).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
When choose_num_buffers() returns 1 (shmem too tight for double
buffering, e.g. E=1024 with group_topk scratch), buf_[0] and buf_[1]
alias the same memory.  The prefetch via start_load(next_buf()) then
overwrites the current buffer while compute is still reading it.

Fix: guard the prefetch on num_buffers > 1.  When single-buffered,
load the current round's data at the top of each iteration instead.
The first round's load_current is still issued before the loop.

Backward kernels are unaffected (always kBwdNumBuffers=2).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Code review fixes:

- C1: choose_num_buffers() now queries cudaDevAttrMaxSharedMemoryPerMultiprocessor
  (per-SM budget) instead of cudaDevAttrMaxSharedMemoryPerBlockOptin (per-block
  max).  These coincide on Hopper/Blackwell but differ on Ampere.

- H3: Remove dead fallback branch in choose_num_buffers() — since
  total_double >= total_single always, blocks_single >= blocks_double,
  so the old ternary always returned 1 anyway.

- H4/M8: Add host-side NVTE_CHECK in all 4 launchers:
  - num_experts > 0
  - topk in [1, num_experts]
  - (int64_t)num_tokens * num_experts <= INT_MAX (kernel uses int offsets)

- M9: Assert topk % group_topk == 0 when group_topk > 0.

- H6: Add device-side assert(data_size <= kMaxExpertsRadixTopk) in
  radix_topk_and_mask() — zero cost in release (NDEBUG), catches
  8-bit histogram overflow in debug builds.

- L1: Fix stale comments claiming default threshold is 16 (it is 0).
- L4: Fix typo 'hanlded' -> 'handled'.
- L8: Remove unused topk parameter from aux loss backward kernel.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Move the duplicated static function from both .cu files into utils.h
as an inline function.  Each TU gets its own static local (read-once
per TU), which is safe since environment variables are immutable
during process lifetime.  Documented this in a NOTE comment.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Replace runtime function-pointer dispatch with compile-time if constexpr.
Eliminates indirect call overhead in the reduction loop and warp shuffle
butterfly, allowing the compiler to emit straight-line arithmetic.

Removes the now-unused max<T>() and sum<T>() helper functions.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), use a lightweight
forward kernel that avoids the async loader and persistent grid overhead.
The simple kernel loads logits directly from global memory to shmem and
uses Naive iterative-argmax topk — matching the baseline structure that
was faster for small K due to lower launch/scheduling overhead.

The optimized path (async loader + persistent grid + radix topk) remains
the default for topk >= 8 where the compute savings dominate.

Both topk and aux_loss forward kernels get the simple variant.
Backward kernels are unchanged (always use the optimized path).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Use 0.0f instead of 0 to avoid ambiguity between __nv_bfloat16(float)
and __nv_bfloat16(double) constructors on older CUDA toolkits.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p3R branch from 9a7cb7e to 3bab7cb Compare May 21, 2026 03:03
logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk,
scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output);
// Optimized path: async loader + persistent grid + radix topk.
NVTE_CHECK(num_experts <= kMaxExpertsRadixTopk,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless I'm misreading the code, this is going to hard-fail for num_experts > 8160.

@ptrendx Is this a restriction we can accept?

If not, I would recommend augmenting the algorithm selection (naive topk vs. radix-sort), and just simply fall back on the naive option whenever num_experts > kMaxExpertsRadixTopk. It may not be performant but it's probably better than a hard failure.

logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
intermediate_output);
// Optimized path: async loader + persistent grid + radix topk.
NVTE_CHECK(num_experts <= kMaxExpertsRadixTopk,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as fused_topk_with_score_function. We should probably fall back on the naive path here instead of hard-failing when num_experts > 8160.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants