
[Common] Fix fused router for large top-K and expert counts#2821

Open
harryzhou2000 wants to merge 8 commits into NVIDIA:main from harryzhou2000:hhanyu/router_fix_p2

Conversation

@harryzhou2000
Member

Description

Fixes fused router support for large topk and num_expert values. num_expert <= 2304 and any topk are now supported with reasonable performance.

Current benchmarks show the fused topk forward kernel is faster than PyTorch at topk=32, which is roughly 8x faster than before this optimization.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Fix fused MoE router kernels to support large top-K values and large expert counts (1024+) by adding a warp-level radix-selection top-K algorithm (O(E), independent of K) alongside the existing naive O(K^2 * E) implementation, dispatched at the topk >= 16 boundary.
  • Expand dynamic shared memory allocation via cudaFuncSetAttribute in both forward and backward
    kernel launchers to avoid silent failures when expert count exceeds the default 48 KB limit.
  • Rewrite apply_softmax_on_float to use a numerically stable online max+sum accumulation
    (two-pass → single-pass) with NaN-safe warp reduction, eliminating shared-memory round-trips.
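The single-pass online max+sum scheme from the last bullet can be sketched as a scalar host version. This is illustrative only: the actual kernel operates warp-parallel with NaN-safe reductions, and the function name here is made up.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Online softmax: maintain a running max m and a running sum s of
// exp(x - m). When a new element raises the max, rescale the accumulated
// sum once. A single pass over the data replaces the classic two-pass
// (find max, then sum exponentials) scheme.
std::vector<float> online_softmax(const std::vector<float>& x) {
  float m = -INFINITY, s = 0.0f;
  for (float v : x) {
    if (v > m) {
      s = s * std::exp(m - v) + 1.0f;  // rescale the old sum to the new max
      m = v;
    } else {
      s += std::exp(v - m);
    }
  }
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) out[i] = std::exp(x[i] - m) / s;
  return out;
}
```

Subtracting the running max keeps every exponent non-positive, so the computation never overflows even for large logits.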

Details

Radix top-K selection (utils.h):
Implements a 4-bit radix selection algorithm (8 passes over float32) that finds the K-th largest
value in O(E/32) per warp, independent of K. Phase 1 narrows the bit pattern of the K-th value
via histogram counting; Phase 2 gathers elements into output arrays with deterministic tie-breaking
(value DESC, index ASC) matching torch.topk behavior.
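The Phase 1 narrowing loop can be illustrated with a host-side scalar sketch (a stand-in for the warp-level version; float keys are assumed to have already been mapped to order-preserving uint32 values):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Find the K-th largest value (1-based K) among uint32 keys using 4-bit
// radix selection: 8 passes from the most significant nibble down. Each
// pass histograms the current nibble among still-candidate elements, then
// scans buckets in descending order to find the one holding the K-th
// element, narrowing the known prefix of its bit pattern.
uint32_t radix_kth_largest(const std::vector<uint32_t>& keys, int k) {
  uint32_t prefix = 0, prefix_mask = 0;
  for (int shift = 28; shift >= 0; shift -= 4) {
    int hist[16] = {0};
    for (uint32_t key : keys)
      if ((key & prefix_mask) == prefix)  // still a candidate
        hist[(key >> shift) & 0xF]++;
    for (int b = 15; b >= 0; --b) {       // scan buckets DESC
      if (k <= hist[b]) {                 // the K-th element lives here
        prefix |= uint32_t(b) << shift;   // lock in this nibble
        break;
      }
      k -= hist[b];                       // skip larger buckets
    }
    prefix_mask |= uint32_t(0xF) << shift;
  }
  return prefix;  // exact bit pattern of the K-th largest value
}
```

The work per pass is O(E) regardless of K, which is why the cost stays flat as top-K grows.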
Dispatch logic (fused_topk_with_score_function.cu, fused_score_for_moe_aux_loss.cu):
Template parameter TopkFuncType (Naive/Radix) is selected at launch time based on
topk < 16. Both forward kernels and backward kernels now call cudaFuncSetAttribute to
request the required dynamic shared memory size before launch.
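The runtime-to-template dispatch can be sketched as follows (an illustrative stand-in, not the actual launcher; the real code also sets cudaFuncAttributeMaxDynamicSharedMemorySize via cudaFuncSetAttribute before launching):

```cpp
#include <cassert>
#include <string>

// Illustrative stand-in for the TopkFuncType dispatch: the boundary is
// checked once at launch time, and each branch instantiates a function
// specialized on a compile-time selection strategy.
enum class TopkFuncType { Naive, Radix };

template <TopkFuncType Func>
const char* run_topk_kernel(int topk, int num_experts) {
  // A real kernel would perform the top-K selection here; this sketch
  // just reports which code path was instantiated.
  (void)topk;
  (void)num_experts;
  return Func == TopkFuncType::Naive ? "naive" : "radix";
}

const char* launch_topk(int topk, int num_experts) {
  // Small K: the O(K^2 * E) naive scan wins; large K: O(E) radix selection.
  if (topk < 16)
    return run_topk_kernel<TopkFuncType::Naive>(topk, num_experts);
  return run_topk_kernel<TopkFuncType::Radix>(topk, num_experts);
}
```

Resolving the strategy at launch time keeps the per-element inner loops free of runtime branching on the selection method.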
Tests (test_fused_router.py):

  • Add num_experts=1024 to all parametrized test cases.
  • Add _get_tolerances() helper that scales atol/rtol with expert count to account for
    O(N * eps) accumulation divergence between fused and reference implementations.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@harryzhou2000 harryzhou2000 marked this pull request as ready for review April 1, 2026 14:44
@greptile-apps
Contributor

greptile-apps bot commented Apr 1, 2026

Greptile Summary

This PR extends the fused MoE router to handle large expert counts (up to 2304) and any top-K value by introducing a warp-level radix-selection algorithm (radix_topk_and_mask) dispatched when topk >= 16, expanding dynamic shared memory allocation via cudaFuncSetAttribute, and rewriting apply_softmax_on_float to use a numerically stable single-pass online max+sum accumulation. Previous review concerns (unchecked cudaFuncSetAttribute returns, missing Radix path test coverage, and the dead dtype parameter in _get_tolerances) have all been addressed in this revision.

Confidence Score: 5/5

Safe to merge — all previously identified P1 issues are resolved and no new correctness bugs were found.

All prior P1 findings (unchecked CUDA attribute calls, untested Radix code path, dead dtype parameter) have been addressed. The radix selection algorithm is mathematically sound (phase 1 iterative narrowing + phase 2 ballot-prefix gather), the NaN-guarded warp butterfly in apply_softmax_on_float is correct, and the shmem capacity check fires before any kernel launch. The only remaining dead-code artifact (ordered_uint_to_float) is a P2 cleanup item that does not affect correctness.
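The "warp butterfly" referenced above is an XOR-shuffle allreduce. A host simulation of the plain butterfly pattern (without the NaN guard) looks like this; the device version would use __shfl_xor_sync:

```cpp
#include <array>
#include <cassert>

// Host simulation of a warp-wide butterfly allreduce: after log2(32) = 5
// rounds of XOR-pair exchanges, every lane holds the full sum, with no
// shared-memory round-trip required.
std::array<float, 32> butterfly_allreduce_sum(std::array<float, 32> lanes) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    std::array<float, 32> next;
    for (int lane = 0; lane < 32; ++lane)
      next[lane] = lanes[lane] + lanes[lane ^ offset];  // __shfl_xor_sync analog
    lanes = next;
  }
  return lanes;
}
```

Because every lane ends up with the reduced value, no broadcast step is needed afterward.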

transformer_engine/common/utils.cuh — ordered_uint_to_float is defined but never called; may produce an unused-function warning depending on compiler flags.

Important Files Changed

Filename Overview
transformer_engine/common/fused_router/utils.h Adds TopkFuncType enum, radix_topk_and_mask (full O(E) warp-level radix selection), topk_and_mask dispatch template, and rewrites apply_softmax_on_float to single-pass online max+sum; algorithm looks correct but ordered_uint_to_float is still dead code
transformer_engine/common/utils.cuh Adds float_to_ordered_uint, ordered_uint_to_float (unused), and warp_allreduce_sum; float_to_ordered_uint correctly preserves descending sort order for all IEEE 754 floats
transformer_engine/common/fused_router/fused_topk_with_score_function.cu Adds TopkFunc template parameter to forward kernel and dispatches Naive vs Radix at topk < 16; adds NVTE_CHECK_CUDA-wrapped cudaFuncSetAttribute for both forward and backward kernels; removes stale utils.cuh include
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Same dispatch pattern as fused_topk; all cudaFuncSetAttribute calls properly checked; backward launcher also updated with shmem capacity check
tests/pytorch/test_fused_router.py Adds _get_tolerances helper, num_experts=1024, topk=16/32 to cover the Radix path, and skip guards for invalid combinations; now exercises both Naive and Radix dispatch branches
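The order-preserving mapping behind float_to_ordered_uint is commonly implemented with the sign-flip bit trick; a host-side sketch follows (the exact in-tree implementation is not shown in this PR summary, so treat the details as an assumption):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Map a float to a uint32 whose unsigned integer ordering matches the
// float's numeric ordering: flip all bits for negative floats, flip only
// the sign bit for non-negative ones.
uint32_t float_to_ordered_uint(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));  // bit-cast without undefined behavior
  return (u & 0x80000000u) ? ~u : (u | 0x80000000u);
}
```

With this mapping, radix selection can compare keys as plain unsigned integers and still produce the same order as comparing the original floats.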

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Kernel Launcher\nforward / backward] --> B{topk < 16?}
    B -- Yes --> C[TopkFuncType::Naive\nO K² × E ]
    B -- No --> D[TopkFuncType::Radix\nO E, independent of K]
    A --> E[check_shared_memory_capacity]
    A --> F[cudaFuncSetAttribute\nExpand dynamic shmem]
    F --> G[Kernel Launch]

    D --> H[Phase 1: Radix Selection\n8 passes × 4-bit nibble]
    H --> H1[Per-pass: count per bucket\nwarp_allreduce_sum × 16]
    H1 --> H2[Scan buckets DESC\nFind bucket containing K-th element]
    H2 --> H3{More passes?}
    H3 -- Yes --> H1
    H3 -- No --> I[desired = exact uint32 of K-th value]

    I --> J[Phase 2: Gather]
    J --> J1[Pass A: ballot elements > desired\nExclusive prefix via __popc]
    J1 --> J2[Pass B: ballot elements == desired\nTie-break by ascending index]
    J2 --> K[topk_indices / topk_scores filled]

    C --> L[naive_topk_and_mask\nExisting O K²×E algorithm]

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

@harryzhou2000 harryzhou2000 marked this pull request as draft April 1, 2026 15:08
@harryzhou2000 harryzhou2000 marked this pull request as ready for review April 1, 2026 15:14
@harryzhou2000 harryzhou2000 changed the title Fix fused router for large top-K and expert counts [Common] Fix fused router for large top-K and expert counts Apr 2, 2026
@harryzhou2000 harryzhou2000 marked this pull request as draft April 3, 2026 02:53
@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p2 branch from dc4d438 to ee33ea2 Compare April 3, 2026 07:19
harryzhou2000 and others added 8 commits April 3, 2026 15:26
…r of experts

- expanding shared memory when needed
- switch to radix topk selection when topk is large
- test_fused_router.py updated with large num experts and tolerances refined for different cases

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
added return value check of cudaFuncSetAttribute in transformer_engine/common/fused_router/fused_topk_with_score_function.cu
added dtype dependent eps in tests/pytorch/test_fused_router.py
removed unneeded code in transformer_engine/common/fused_router/utils.h

pr bot suggestions

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
cleaned up raw warp operations
added comments
added shared_memory check
added return code check

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p2 branch from ee33ea2 to fab73d1 Compare April 3, 2026 07:30
@harryzhou2000 harryzhou2000 marked this pull request as ready for review April 3, 2026 07:31
