Skip to content

Bitmap topk#3009

Open
tdophung wants to merge 11 commits into
NVIDIA:mainfrom
tdophung:bitmap_topk
Open

Bitmap topk#3009
tdophung wants to merge 11 commits into
NVIDIA:mainfrom
tdophung:bitmap_topk

Conversation

@tdophung
Copy link
Copy Markdown
Collaborator

@tdophung tdophung commented May 18, 2026

Description

Add a new path to our topk kernel to output the routing map in bitmap format instead of bytemap alone. The default still stay at bytemap so no regression for existing consumers downstream of this op. However, since the op now requires an additional arg to specify the routing map type (bytemap or bitmap), we introduce a V2 of the API to accomplish this, while keeping the original API the same not to break customers.

This helps NCCL EP not have to do the token_indices (sparse format) conversion to bitmap format for comms later.

Fixes #2999

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Change in the kernel for an if path that does atomicOr for all expert indices shifted by bit position to create the expert indices -> bitmap conversion.
  • Plumb the arg for routing map mode (byte map or bitmap) through pytorch and jax primitives

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

tdophung and others added 5 commits May 18, 2026 11:29
Signed-off-by: tdophung <tdophung@nvidia.com>
Without this XLA_FFI_REGISTER_ENUM_ATTR_DECODING the FFI handler
templates cannot instantiate AttrDecoding<JAXX_Routing_Map_Format>,
breaking the JAX build in router.cpp.

Signed-off-by: tdophung <tdophung@nvidia.com>
…g the routing map type enum

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung marked this pull request as ready for review May 20, 2026 20:42
@tdophung tdophung requested a review from phu0ngng May 20, 2026 20:43
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 20, 2026

Greptile Summary

This PR adds BITMAP_U8 as a new output format for the routing map in the fused topk router kernels, alongside the existing BYTEMAP format. A new NVTERoutingMapFormat enum and V2 C API are introduced; the original V1 API is preserved as a BYTEMAP-defaulting wrapper for ABI compatibility.

  • CUDA kernels: Both forward and backward kernels are templatized on NVTERoutingMapFormat; the BITMAP_U8 path accumulates bits into a per-warp uint32 shmem buffer via atomicOr, then byte-copies to the global uint8 output row.
  • Python bindings (PyTorch and JAX): routing_map_format parameter plumbed through autograd functions; string, int, and enum inputs are all validated. The output routing_map shape is correctly reshaped to match the caller's input leading dims.
  • Tests: Both PyTorch and JAX test suites gain bitmap-vs-bytemap parity tests covering forward output, shape/dtype, and backward gradient identity.

Confidence Score: 5/5

The change is additive and backward-compatible: the V1 API is preserved as a BYTEMAP-delegating wrapper, and BYTEMAP remains the default everywhere.

Both CUDA kernel paths use correct memory stride arithmetic, proper __syncwarp synchronization before reading the shmem bitmap accumulator, and the little-endian uint32 to uint8 reinterpret_cast is safe on all CUDA devices. The Python layer correctly saves the flat routing_map for backward and reshapes it for callers.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/common/fused_router/fused_topk_with_score_function.cu Templatized both forward and backward kernels on NVTERoutingMapFormat; BITMAP_U8 path correctly uses per-warp shmem atomicOr accumulation with __syncwarp before readback. Shape validation added.
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Same BITMAP_U8 kernel changes as the topk kernel; AUX_LOSS_FORWARD_DISPATCH macro cleanly separates the two format paths.
transformer_engine/common/include/transformer_engine/fused_router.h New NVTERoutingMapFormat enum and V2 function declarations added; V1 functions correctly marked deprecated and delegate to V2 with BYTEMAP.
transformer_engine/pytorch/router.py routing_map_format correctly plumbed through FusedTopkScoreFunction and FusedComputeScoresForMoEAuxLoss; routing_map saved flat for backward and correctly reshaped for the caller.
transformer_engine/jax/cpp_extensions/router.py routing_map_format added to FwdPrimitive and BwdPrimitive static args; abstract correctly computes BITMAP_U8 routing_map shape; shardy sharding rules correctly handle both formats.
transformer_engine/pytorch/csrc/extensions/router.cpp allocate_routing_map helper correctly chooses bool[T,E] vs uint8[T,ceil(E/8)] output shape; both fwd functions upgraded to V2 API.
transformer_engine/jax/csrc/extensions/router.cpp JAX FFI handler correctly constructs routing_map_shape from format before wrapping in TensorWrapper; both forward and backward switch to V2 API.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["fused_topk_with_score_function(logits, ..., routing_map_format)"] --> B{routing_map_format?}
    B -->|BYTEMAP| C["allocate bool[T, E]"]
    B -->|BITMAP_U8| D["allocate uint8[T, ceil(E/8)]"]
    C --> E["V2 forward (BYTEMAP)"]
    D --> F["V2 forward (BITMAP_U8)"]
    E --> G["CUDA kernel: write 0/1 bytes to global routing_map"]
    F --> H["CUDA kernel: atomicOr into shmem uint32, byte-copy to global uint8"]
    G --> I["Return probs[T,E], routing_map bool[T,E]"]
    H --> J["Return probs[T,E], routing_map uint8[T,ceil(E/8)]"]
    I --> K["Backward: read routing_map[pos+i] != 0"]
    J --> L["Backward: read (bitmap_row[i/8] >> i%8) & 1"]
Loading

Reviews (3): Last reviewed commit: "Merge branch 'main' into bitmap_topk" | Re-trigger Greptile

Comment thread transformer_engine/jax/router.py
Comment thread transformer_engine/pytorch/router.py Outdated
Comment thread transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Outdated
Comment thread transformer_engine/common/include/transformer_engine/fused_router.h Outdated
Comment thread transformer_engine/common/include/transformer_engine/fused_router.h Outdated
Comment thread transformer_engine/common/fused_router/fused_topk_with_score_function.cu Outdated
Comment thread transformer_engine/jax/cpp_extensions/router.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/misc.h
Copy link
Copy Markdown
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks! I reviewed core and JAX changes but not PyTorch

@tdophung
Copy link
Copy Markdown
Collaborator Author

/te_ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Output routing map from fused_topk_with_scores in bitmap format

2 participants