diff --git a/fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cu b/fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cu
index 036729eea1..c790c235ee 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cu
@@ -351,7 +351,7 @@ __global__ void index_shuffling_kernel(Params params) {
         rhs_smem_index,
         num_reduced_threads == 1);
   } else {
-    merge_topk(
+    merge_topk(
        smem.routing_scores,
        smem.expert_indices,
        lhs_smem_index,
@@ -502,7 +502,18 @@ std::tuple index_shuffling_torch(
       num_experts == 16 || num_experts == 32 || num_experts == 128 ||
       num_experts == 320);
-  TORCH_CHECK(top_k == 1 || top_k == 2 || top_k == 4);
+  // ROCm currently only supports top_k=1. See L562
+#ifdef USE_ROCM
+  TORCH_CHECK(
+      top_k == 1,
+      "ROCm currently only supports top_k=1. Requested top_k=",
+      top_k);
+#else
+  TORCH_CHECK(
+      top_k == 1 || top_k == 2 || top_k == 4 || top_k == 8,
+      "top_k must be 1, 2, 4, or 8. Got top_k=",
+      top_k);
+#endif
 
   auto allocate_index_tensor = [&](int size) {
     return at::empty(
@@ -549,21 +560,63 @@ std::tuple index_shuffling_torch(
 // Reducing tile size as problem size increases to avoid
 // cudaErrorCooperativeLaunchTooLarge.
 // TopK > 1 is not supported on AMD yet.
+// Expert-specific DISPATCH macros prevent compile-time errors from
+// static_assert at L322: NumTokensPerTile % kNumParallelReductionGroups == 0
+//
+// Each expert count has different divisibility constraints:
+// - E=16:  kNumParallelReductionGroups=16 → tile ≥16 (never reduces)
+// - E=32:  kNumParallelReductionGroups=8  → tile ≥8  (max reduction: B/2)
+// - E=128: kNumParallelReductionGroups=2  → tile ≥2  (max reduction: B/8)
+// - E=320: kNumParallelReductionGroups=1  → tile ≥1  (max reduction: B/16)
 #ifndef USE_ROCM
-#define DISPATCH(E, B, K, S)          \
-  if (S <= 128) {                     \
-    DISPATCH_K(E, B, K);              \
-  } else if (storage_factor <= 256) { \
-    DISPATCH_K(E, B / 2, K);          \
-  } else if (storage_factor <= 512) { \
-    DISPATCH_K(E, B / 4, K);          \
-  } else {                            \
-    DISPATCH_K(E, B / 8, K);          \
-  }
+#define DISPATCH_E_16(B, K, S) \
+  DISPATCH_K(16, B, K); // E=16: Never reduces (always tile=16)
+
+#define DISPATCH_E_32(B, K, S) \
+  if (S <= 128) {              \
+    DISPATCH_K(32, B, K);      \
+  } else {                     \
+    DISPATCH_K(32, B / 2, K);  \
+  } // E=32: Min tile=8 (B/2)
+
+#define DISPATCH_E_128(B, K, S) \
+  if (S <= 128) {               \
+    DISPATCH_K(128, B, K);      \
+  } else if (S <= 256) {        \
+    DISPATCH_K(128, B / 2, K);  \
+  } else if (S <= 512) {        \
+    DISPATCH_K(128, B / 4, K);  \
+  } else {                      \
+    DISPATCH_K(128, B / 8, K);  \
+  } // E=128: Min tile=2 (B/8)
+
+#define DISPATCH_E_320(B, K, S) \
+  if (S <= 128) {               \
+    DISPATCH_K(320, B, K);      \
+  } else if (S <= 256) {        \
+    DISPATCH_K(320, B / 2, K);  \
+  } else if (S <= 512) {        \
+    DISPATCH_K(320, B / 4, K);  \
+  } else {                      \
+    DISPATCH_K(320, B / 8, K);  \
+  } // E=320: Min tile=2 (B/8)
 #else
-#define DISPATCH(E, B, K, S) \
-  TORCH_CHECK(K == 1);       \
-  DISPATCH_EB(E, 8, 1)
+// ROCm: Only K=1 supported, fixed tile sizes per expert count
+#define DISPATCH_E_16(B, K, S) \
+  TORCH_CHECK(K == 1);         \
+  DISPATCH_K(16, B, K) // E=16: B=32
+
+#define DISPATCH_E_32(B, K, S) \
+  TORCH_CHECK(K == 1);         \
+  DISPATCH_K(32, B, K) // E=32: B=32
+
+#define DISPATCH_E_128(B, K, S) \
+  TORCH_CHECK(K == 1);          \
+  DISPATCH_EB(128, 8, 1) // E=128: B set to 8
+
+#define DISPATCH_E_320(B, K, S) \
+  TORCH_CHECK(K == 1);          \
+  DISPATCH_EB(320, 8, 1) // E=320: B set to 8
 #endif
@@ -571,9 +624,11 @@ std::tuple index_shuffling_torch(
 #define DISPATCH_K(E, B, K) \
   if (K == 1) {             \
     DISPATCH_EB(E, B, 1)    \
   } else if (K == 2) {      \
     DISPATCH_EB(E, B, 2)    \
-  } else {                  \
-    TORCH_CHECK(K == 4);    \
+  } else if (K == 4) {      \
     DISPATCH_EB(E, B, 4)    \
+  } else {                  \
+    TORCH_CHECK(K == 8);    \
+    DISPATCH_EB(E, B, 8)    \
   }
@@ -582,14 +637,14 @@
 #define DISPATCH_EB(E, B, K) \
   kernel = (void*)index_shuffling_kernel; \
 
   int storage_factor = top_k * num_experts;
 
   if (num_experts == 16) {
-    DISPATCH_K(16, kNumTokensPerTileFewExperts, top_k)
+    DISPATCH_E_16(kNumTokensPerTileFewExperts, top_k, storage_factor)
   } else if (num_experts == 32) {
-    DISPATCH_K(32, kNumTokensPerTileFewExperts, top_k)
+    DISPATCH_E_32(kNumTokensPerTileFewExperts, top_k, storage_factor)
   } else if (num_experts == 128) {
-    DISPATCH(128, kNumTokensPerTileFewExperts, top_k, storage_factor)
+    DISPATCH_E_128(kNumTokensPerTileFewExperts, top_k, storage_factor)
   } else {
     TORCH_CHECK(num_experts == 320);
-    DISPATCH(320, kNumTokensPerTileFewExperts, top_k, storage_factor)
+    DISPATCH_E_320(kNumTokensPerTileFewExperts, top_k, storage_factor)
   }
   // This is to avoid build errors (divisibility asserts and local memory
   // overflow) on AMD.
diff --git a/fbgemm_gpu/experimental/gen_ai/test/moe/shuffling_test.py b/fbgemm_gpu/experimental/gen_ai/test/moe/shuffling_test.py
index 0d6c040a5a..1b2c31e884 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/moe/shuffling_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/moe/shuffling_test.py
@@ -24,6 +24,7 @@
 )
 
 from hypothesis import given, settings, strategies as st, Verbosity
+from parameterized import parameterized
 from pyre_extensions import none_throws
 
 try:
@@ -40,6 +41,45 @@
 _MAX_SAMPLES: int = 100
 
 
+def _generate_tile_boundary_test_cases() -> list[tuple[str, int, int]]:
+    """
+    Generate test cases at tile size threshold boundaries.
+
+    For each expert count E and each threshold T:
+    - Compute the target K = T / E
+    - Round up to the smallest valid K >= target
+    - Use (E, K) as a test case
+
+    Returns a list of (test_name, num_experts, top_k) tuples.
+ """ + + # Tile size boundary test configuration (sync with index_shuffling.cu) + _SUPPORTED_EXPERT_COUNTS: list[int] = [16, 32, 128, 320] + _SUPPORTED_TOP_K: list[int] = [1, 2, 4, 8] + _STORAGE_FACTOR_THRESHOLDS: list[int] = [128, 256, 512] + + def _find_nearest_valid_k(target_k: float) -> int | None: + """Find smallest K >= target_k, or None if no such K exists.""" + candidates = [k for k in _SUPPORTED_TOP_K if k >= target_k] + return min(candidates) if candidates else None + + seen = set() + cases = [] + + for num_experts in _SUPPORTED_EXPERT_COUNTS: + for threshold in _STORAGE_FACTOR_THRESHOLDS: + target_k = threshold / num_experts + nearest_k = _find_nearest_valid_k(target_k) + + if nearest_k and (num_experts, nearest_k) not in seen: + seen.add((num_experts, nearest_k)) + storage_factor = num_experts * nearest_k + test_name = f"e{num_experts}_k{nearest_k}_s{storage_factor}" + cases.append((test_name, num_experts, nearest_k)) + + return sorted(cases) + + @unittest.skipIf(open_source, "Tests currently fail in open source") @unittest.skipIf( not torch.cuda.is_available(), @@ -48,20 +88,7 @@ class ShufflingTests(unittest.TestCase): """Test shuffling kernels.""" - @given( - num_tokens=st.sampled_from( - [1, 3, 123, 128, 1234, 2048, 4567, 4096, 8192, 16384] - ), - num_experts=st.sampled_from([16, 32, 128, 320]), - num_local_experts=st.sampled_from([None, 8]), - top_k=st.sampled_from([1, 2, 4] if torch.version.cuda else [1]), - padded=st.sampled_from([True, False]), - rowmajor=st.sampled_from([True, False]), - compiled=st.sampled_from([True, False]), - routing_score_dtype=st.sampled_from([torch.float, torch.bfloat16]), - ) - @settings(verbosity=Verbosity.verbose, max_examples=_MAX_SAMPLES, deadline=None) - def test_topk_index_shuffling( + def _run_topk_index_shuffling_test( self, num_tokens: int, num_experts: int, @@ -211,6 +238,70 @@ def _assert_indices_equal( start_index = end_index ref_start_index = ref_end_index + @parameterized.expand(_generate_tile_boundary_test_cases()) + def test_topk_index_shuffling_tile_size_boundaries( + self, + name: str, + num_experts: int, + top_k: int, + ) -> None: + """ + Test index shuffling at tile size threshold boundaries. + + The index shuffling kernel switches the number of tokens per tile (tile size) + depending on the storage factor (shared memory pressure on CUDA). These test cases + test for correctness at the tile size threshold boundaries. 
+ """ + # Skip K>1 on ROCm (not supported) + if top_k > 1 and torch.version.hip: + self.skipTest("ROCm only supports top_k=1") + + self._run_topk_index_shuffling_test( + num_tokens=2049, + num_experts=num_experts, + num_local_experts=None, + top_k=top_k, + padded=False, + rowmajor=True, + compiled=False, + routing_score_dtype=torch.float32, + ) + + @given( + num_tokens=st.sampled_from( + [1, 3, 123, 128, 1234, 2048, 4567, 4096, 8192, 16384] + ), + num_experts=st.sampled_from([16, 32, 128, 320]), + num_local_experts=st.sampled_from([None, 8]), + top_k=st.sampled_from([1, 2, 4, 8] if torch.version.cuda else [1]), + padded=st.sampled_from([True, False]), + rowmajor=st.sampled_from([True, False]), + compiled=st.sampled_from([True, False]), + routing_score_dtype=st.sampled_from([torch.float, torch.bfloat16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=_MAX_SAMPLES, deadline=None) + def test_topk_index_shuffling( + self, + num_tokens: int, + num_experts: int, + num_local_experts: Optional[int], + top_k: int, + padded: bool, + rowmajor: bool, + compiled: bool, + routing_score_dtype: torch.dtype, + ) -> None: + self._run_topk_index_shuffling_test( + num_tokens=num_tokens, + num_experts=num_experts, + num_local_experts=num_local_experts, + top_k=top_k, + padded=padded, + rowmajor=rowmajor, + compiled=compiled, + routing_score_dtype=routing_score_dtype, + ) + @given( batch_size=st.sampled_from( [1, 8, 123, 128, 1234, 2048, 4096, 4567, 8192, 16384]