[Common] Fix fused router for large top-K and expert counts #2821

harryzhou2000 wants to merge 8 commits into NVIDIA:main
Conversation
Greptile Summary

This PR extends the fused MoE router to handle large expert counts (up to 2304) and any top-K value by introducing a warp-level radix-selection algorithm.

Confidence Score: 5/5. Safe to merge — all previously identified P1 issues are resolved and no new correctness bugs were found. All prior P1 findings (unchecked CUDA attribute calls, untested Radix code path, dead dtype parameter) have been addressed. The radix selection algorithm is mathematically sound (phase 1 iterative narrowing + phase 2 ballot-prefix gather), the NaN-guarded warp butterfly in apply_softmax_on_float is correct, and the shmem capacity check fires before any kernel launch. The only remaining dead-code artifact (ordered_uint_to_float) is a P2 cleanup item that does not affect correctness.

Important Files Changed: transformer_engine/common/utils.cuh — ordered_uint_to_float is defined but never called; may produce an unused-function warning depending on compiler flags.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Kernel Launcher\nforward / backward] --> B{topk < 16?}
    B -- Yes --> C[TopkFuncType::Naive\nO K² × E]
    B -- No --> D[TopkFuncType::Radix\nO E, independent of K]
    A --> E[check_shared_memory_capacity]
    A --> F[cudaFuncSetAttribute\nExpand dynamic shmem]
    F --> G[Kernel Launch]
    D --> H[Phase 1: Radix Selection\n8 passes × 4-bit nibble]
    H --> H1[Per-pass: count per bucket\nwarp_allreduce_sum × 16]
    H1 --> H2[Scan buckets DESC\nFind bucket containing K-th element]
    H2 --> H3{More passes?}
    H3 -- Yes --> H1
    H3 -- No --> I[desired = exact uint32 of K-th value]
    I --> J[Phase 2: Gather]
    J --> J1[Pass A: ballot elements > desired\nExclusive prefix via __popc]
    J1 --> J2[Pass B: ballot elements == desired\nTie-break by ascending index]
    J2 --> K[topk_indices / topk_scores filled]
    C --> L[naive_topk_and_mask\nExisting O K²×E algorithm]
```
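The Phase 2 gather in the flowchart can be emulated on the host with Python integers standing in for warp ballot masks and `bin(...).count("1")` standing in for `__popc`. This is an illustrative sketch of the ballot-prefix technique only; `gather_topk` and its signature are hypothetical, not the kernel's code:

```python
def gather_topk(values, kth_value, k):
    """Emulate the two ballot passes with bitmasks: pass A places
    elements strictly greater than the K-th value at an exclusive
    prefix popcount offset; pass B fills the remaining slots with
    ties on the K-th value itself, broken by ascending index."""
    ballot_gt = sum(1 << i for i, v in enumerate(values) if v > kth_value)
    ballot_eq = sum(1 << i for i, v in enumerate(values) if v == kth_value)
    n_gt = bin(ballot_gt).count("1")
    out = [None] * k
    for i, v in enumerate(values):
        mask = (1 << i) - 1                       # lanes below lane i
        if (ballot_gt >> i) & 1:
            out[bin(ballot_gt & mask).count("1")] = (i, v)
        elif (ballot_eq >> i) & 1:
            pos = n_gt + bin(ballot_eq & mask).count("1")
            if pos < k:                           # drop surplus ties
                out[pos] = (i, v)
    return out  # (expert_index, score) pairs, value DESC / index ASC
```

On the device the same positions come from one `__ballot_sync` plus `__popc` per pass, with no shared-memory traffic.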
Reviews (3) — last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Resolved review threads:
- transformer_engine/common/fused_router/fused_topk_with_score_function.cu (outdated)
- transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu (outdated)
- transformer_engine/common/fused_router/fused_topk_with_score_function.cu
Commits
- …r of experts: expanding shared memory when needed; switch to radix topk selection when topk is large; test_fused_router.py updated with large num experts and tolerances refined for different cases. Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
- [pre-commit.ci] auto fixes (for more information, see https://pre-commit.ci). Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
- Added return value check of cudaFuncSetAttribute in transformer_engine/common/fused_router/fused_topk_with_score_function.cu; added dtype-dependent eps in tests/pytorch/test_fused_router.py; removed unneeded code in transformer_engine/common/fused_router/utils.h (PR bot suggestions). Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
- [pre-commit.ci] auto fixes (for more information, see https://pre-commit.ci). Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
- Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
- [pre-commit.ci] auto fixes (for more information, see https://pre-commit.ci). Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
- Cleaned up raw warp operations; added comments; added shared_memory check; added return code check. Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
- [pre-commit.ci] auto fixes (for more information, see https://pre-commit.ci). Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Description

Fixed fused router support for large `topk` and `num_experts`. `num_experts <= 2304` and any `topk` value are now supported with reasonable performance. Current benchmarks show the fused top-K forward kernel is faster than the PyTorch implementation at `topk=32`, roughly 8x faster than before this optimization.
Type of change
Changes
- Switched top-K selection to a warp-level radix algorithm at the `topk >= 16` boundary.
- Added `cudaFuncSetAttribute` calls in both forward and backward kernel launchers to avoid silent failures when the expert count exceeds the default 48 KB dynamic shared memory limit.
- Rewrote `apply_softmax_on_float` to use a numerically stable online max+sum accumulation (two-pass → single-pass) with NaN-safe warp reduction, eliminating shared-memory round-trips.
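The online max+sum accumulation can be sketched as scalar Python; the kernel performs this per lane before a warp butterfly reduction, so this is a serialized illustration of the recurrence, not the kernel code:

```python
import math

def online_softmax(xs):
    """Single-pass softmax via online max+sum accumulation: carry a
    running max m and rescale the running sum s whenever m grows, so
    every exp() argument stays <= 0 and nothing overflows."""
    m, s = float("-inf"), 0.0
    for x in xs:
        m_new = max(m, x)
        # rescale the old sum to the new max, then add the new term
        s = s * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(x - m) / s for x in xs]
```

The two-pass version would first compute the max, then the sum; the rescaling step folds both into one sweep over the scores.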
Details

- Radix top-K selection (`utils.h`): Implements a 4-bit radix selection algorithm (8 passes over float32) that finds the K-th largest value in O(E/32) per warp, independent of K. Phase 1 narrows the bit pattern of the K-th value via histogram counting; Phase 2 gathers elements into output arrays with deterministic tie-breaking (value DESC, index ASC) matching `torch.topk` behavior.
- Dispatch logic (`fused_topk_with_score_function.cu`, `fused_score_for_moe_aux_loss.cu`): Template parameter `TopkFuncType` (Naive/Radix) is selected at launch time based on `topk < 16`. Both forward and backward kernels now call `cudaFuncSetAttribute` to request the required dynamic shared memory size before launch.
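The phase-1 narrowing can be sketched serially in host-side Python. `float_to_ordered_uint` is the standard radix-sort bit trick for making float order match unsigned-integer order; all function names here are illustrative, not the PR's exact helpers, and the kernel parallelizes the histogram across warp lanes:

```python
import struct

def float_to_ordered_uint(x):
    """Map a float32's bits to a uint32 whose unsigned order matches
    the float order (standard radix trick; the kernel's exact
    transform may differ)."""
    (u,) = struct.unpack("<I", struct.pack("<f", x))
    return (u ^ 0xFFFFFFFF) if (u & 0x80000000) else (u | 0x80000000)

def ordered_uint_to_float(u):
    """Inverse transform, to read the selected key back as a float."""
    u = (u ^ 0x80000000) if (u & 0x80000000) else (u ^ 0xFFFFFFFF)
    return struct.unpack("<f", struct.pack("<I", u))[0]

def radix_select_kth_largest(values, k):
    """Phase 1, serialized: 8 passes over 4-bit nibbles, MSB first.
    Each pass histograms the candidates into 16 buckets, scans the
    buckets in descending order to find the one holding the K-th
    largest element, and narrows the candidate set to that bucket."""
    candidates = [float_to_ordered_uint(v) for v in values]
    desired, remaining = 0, k
    for p in range(7, -1, -1):          # nibble index, MSB first
        shift = 4 * p
        counts = [0] * 16
        for key in candidates:
            counts[(key >> shift) & 0xF] += 1
        for nib in range(15, -1, -1):   # scan buckets DESC
            if counts[nib] >= remaining:
                desired |= nib << shift
                candidates = [key for key in candidates
                              if (key >> shift) & 0xF == nib]
                break
            remaining -= counts[nib]    # skip buckets above the K-th
    return desired                      # exact uint32 of the K-th value
```

Because each pass fixes 4 more bits, exactly 8 passes pin down the full 32-bit pattern regardless of K, which is why the cost is O(E) rather than O(K²·E).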
- Tests (`test_fused_router.py`): Added `num_experts=1024` to all parametrized test cases. Added a `_get_tolerances()` helper that scales `atol`/`rtol` with expert count to account for O(N * eps) accumulation divergence between fused and reference implementations.
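The tolerance-scaling idea can be sketched as follows; the helper name, constants, and exact scaling rule here are assumptions for illustration, and the PR's `_get_tolerances()` may differ:

```python
def get_tolerances(num_experts, dtype_eps):
    """Illustrative sketch (not the PR's exact code): rounding error
    accumulated over a reduction grows roughly linearly with the
    reduction length, so atol is scaled as O(num_experts * eps)
    while rtol stays a small dtype-dependent constant."""
    atol = num_experts * dtype_eps
    rtol = 16 * dtype_eps
    return atol, rtol
```

For example, with float32 (eps = 2**-23) and 1024 experts this yields atol = 2**-13 ≈ 1.2e-4, loose enough to absorb the different summation orders of the fused and reference paths.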
Checklist: