Bitmap topk#3009
Conversation
Signed-off-by: tdophung <tdophung@nvidia.com>
Without this XLA_FFI_REGISTER_ENUM_ATTR_DECODING the FFI handler templates cannot instantiate AttrDecoding<JAXX_Routing_Map_Format>, breaking the JAX build in router.cpp. Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
…g the routing map type enum Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile SummaryThis PR adds BITMAP_U8 as a new output format for the routing map in the fused topk router kernels, alongside the existing BYTEMAP format. A new
Confidence Score: 5/5The change is additive and backward-compatible: the V1 API is preserved as a BYTEMAP-delegating wrapper, and BYTEMAP remains the default everywhere. Both CUDA kernel paths use correct memory stride arithmetic, proper __syncwarp synchronization before reading the shmem bitmap accumulator, and the little-endian uint32 to uint8 reinterpret_cast is safe on all CUDA devices. The Python layer correctly saves the flat routing_map for backward and reshapes it for callers. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["fused_topk_with_score_function(logits, ..., routing_map_format)"] --> B{routing_map_format?}
B -->|BYTEMAP| C["allocate bool[T, E]"]
B -->|BITMAP_U8| D["allocate uint8[T, ceil(E/8)]"]
C --> E["V2 forward (BYTEMAP)"]
D --> F["V2 forward (BITMAP_U8)"]
E --> G["CUDA kernel: write 0/1 bytes to global routing_map"]
F --> H["CUDA kernel: atomicOr into shmem uint32, byte-copy to global uint8"]
G --> I["Return probs[T,E], routing_map bool[T,E]"]
H --> J["Return probs[T,E], routing_map uint8[T,ceil(E/8)]"]
I --> K["Backward: read routing_map[pos+i] != 0"]
J --> L["Backward: read (bitmap_row[i/8] >> i%8) & 1"]
Reviews (3): Last reviewed commit: "Merge branch 'main' into bitmap_topk" | Re-trigger Greptile |
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
LGTM, thanks! I reviewed core and JAX changes but not PyTorch
|
/te_ci |
Description
Add a new path to our topk kernel to output the routing map in bitmap format instead of bytemap alone. The default still stay at bytemap so no regression for existing consumers downstream of this op. However, since the op now requires an additional arg to specify the routing map type (bytemap or bitmap), we introduce a V2 of the API to accomplish this, while keeping the original API the same not to break customers.
This helps NCCL EP not have to do the token_indices (sparse format) conversion to bitmap format for comms later.
Fixes #2999
Type of change
Changes
Checklist: