@@ -351,7 +351,7 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
351351 rhs_smem_index,
352352 num_reduced_threads == 1 );
353353 } else {
354- merge_topk<DataType, IndexType, 4 >(
354+ merge_topk<DataType, IndexType, TopK >(
355355 smem.routing_scores ,
356356 smem.expert_indices ,
357357 lhs_smem_index,
@@ -502,7 +502,18 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
502502 num_experts == 16 || num_experts == 32 || num_experts == 128 ||
503503 num_experts == 320 );
504504
505- TORCH_CHECK (top_k == 1 || top_k == 2 || top_k == 4 );
505+ // ROCm currently only supports top_k=1. See L562
506+ #ifdef USE_ROCM
507+ TORCH_CHECK (
508+ top_k == 1 ,
509+ " ROCm currently only supports top_k=1. Requested top_k=" ,
510+ top_k);
511+ #else
512+ TORCH_CHECK (
513+ top_k == 1 || top_k == 2 || top_k == 4 || top_k == 8 ,
514+ " top_k must be 1, 2, 4, or 8. Got top_k=" ,
515+ top_k);
516+ #endif
506517
507518 auto allocate_index_tensor = [&](int size) {
508519 return at::empty (
@@ -549,31 +560,75 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
549560// Reducing tile size as problem size increases to avoid
550561// cudaErrorCooperativeLaunchTooLarge.
551562// TopK > 1 is not supported on AMD yet.
563+ // Expert-specific DISPATCH macros prevent compile-time errors from
564+ // static_assert at L322: NumTokensPerTile % kNumParallelReductionGroups == 0
565+ //
566+ // Each expert count has different divisibility constraints:
567+ // - E=16: kNumParallelReductionGroups=16 → tile ≥16 (never reduces)
568+ // - E=32: kNumParallelReductionGroups=8 → tile ≥8 (max reduction: B/2)
569+ // - E=128: kNumParallelReductionGroups=2 → tile ≥2 (max reduction: B/8)
570+ // - E=320: kNumParallelReductionGroups=1 → tile ≥1 (max reduction: B/16)
552571#ifndef USE_ROCM
553- #define DISPATCH (E, B, K, S ) \
554- if (S <= 128 ) { \
555- DISPATCH_K (E, B, K); \
556- } else if (storage_factor <= 256 ) { \
557- DISPATCH_K (E, B / 2 , K); \
558- } else if (storage_factor <= 512 ) { \
559- DISPATCH_K (E, B / 4 , K); \
560- } else { \
561- DISPATCH_K (E, B / 8 , K); \
562- }
572+ #define DISPATCH_E_16 (B, K, S ) \
573+ DISPATCH_K (16 , B, K); // E=16: Never reduces (always tile=16)
574+
575+ #define DISPATCH_E_32 (B, K, S ) \
576+ if (S <= 128 ) { \
577+ DISPATCH_K (32 , B, K); \
578+ } else { \
579+ DISPATCH_K (32 , B / 2 , K); \
580+ } // E=32: Min tile=8 (B/2)
581+
582+ #define DISPATCH_E_128 (B, K, S ) \
583+ if (S <= 128 ) { \
584+ DISPATCH_K (128 , B, K); \
585+ } else if (S <= 256 ) { \
586+ DISPATCH_K (128 , B / 2 , K); \
587+ } else if (S <= 512 ) { \
588+ DISPATCH_K (128 , B / 4 , K); \
589+ } else { \
590+ DISPATCH_K (128 , B / 8 , K); \
591+ } // E=128: Min tile=2 (B/8)
592+
593+ #define DISPATCH_E_320 (B, K, S ) \
594+ if (S <= 128 ) { \
595+ DISPATCH_K (320 , B, K); \
596+ } else if (S <= 256 ) { \
597+ DISPATCH_K (320 , B / 2 , K); \
598+ } else if (S <= 512 ) { \
599+ DISPATCH_K (320 , B / 4 , K); \
600+ } else { \
601+ DISPATCH_K (320 , B / 8 , K); \
602+ } // E=320: Min tile=2 (B/8)
563603#else
564- #define DISPATCH (E, B, K, S ) \
565- TORCH_CHECK (K == 1 ); \
566- DISPATCH_EB (E, 8 , 1 )
604+ // ROCm: Only K=1 supported, fixed tile sizes per expert count
605+ #define DISPATCH_E_16 (B, K, S ) \
606+ TORCH_CHECK (K == 1 ); \
607+ DISPATCH_K (16 , B, K) // E=16: B=32
608+
609+ #define DISPATCH_E_32 (B, K, S ) \
610+ TORCH_CHECK (K == 1 ); \
611+ DISPATCH_K (32 , B, K) // E=32: B=32
612+
613+ #define DISPATCH_E_128 (B, K, S ) \
614+ TORCH_CHECK (K == 1 ); \
615+ DISPATCH_EB (128 , 8 , 1 ) // E=128: B set to 8
616+
617+ #define DISPATCH_E_320 (B, K, S ) \
618+ TORCH_CHECK (K == 1 ); \
619+ DISPATCH_EB (320 , 8 , 1 ) // E=320: B set to 8
567620#endif
568621
569622#define DISPATCH_K (E, B, K ) \
570623 if (K == 1 ) { \
571624 DISPATCH_EB (E, B, 1 ) \
572625 } else if (K == 2 ) { \
573626 DISPATCH_EB (E, B, 2 ) \
574- } else { \
575- TORCH_CHECK (K == 4 ); \
627+ } else if (K == 4 ) { \
576628 DISPATCH_EB (E, B, 4 ) \
629+ } else { \
630+ TORCH_CHECK (K == 8 ); \
631+ DISPATCH_EB (E, B, 8 ) \
577632 }
578633#define DISPATCH_EB (E, B, K ) \
579634 kernel = (void *)index_shuffling_kernel<DataType, IndexType, E, B, K>; \
@@ -582,14 +637,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
582637 int storage_factor = top_k * num_experts;
583638
584639 if (num_experts == 16 ) {
585- DISPATCH_K ( 16 , kNumTokensPerTileFewExperts , top_k)
640+ DISPATCH_E_16 ( kNumTokensPerTileFewExperts , top_k, storage_factor )
586641 } else if (num_experts == 32 ) {
587- DISPATCH_K ( 32 , kNumTokensPerTileFewExperts , top_k)
642+ DISPATCH_E_32 ( kNumTokensPerTileFewExperts , top_k, storage_factor )
588643 } else if (num_experts == 128 ) {
589- DISPATCH ( 128 , kNumTokensPerTileFewExperts , top_k, storage_factor)
644+ DISPATCH_E_128 ( kNumTokensPerTileFewExperts , top_k, storage_factor)
590645 } else {
591646 TORCH_CHECK (num_experts == 320 );
592- DISPATCH ( 320 , kNumTokensPerTileFewExperts , top_k, storage_factor)
647+ DISPATCH_E_320 ( kNumTokensPerTileFewExperts , top_k, storage_factor)
593648 }
594649 // This is to avoid build errors (divisibility asserts and local memory
595650 // overflow) on AMD.
0 commit comments