
Commit 65413cc

metastableB authored and facebook-github-bot committed
Adding support for top_k=8 to index_shuffling moe kernel
Summary:

- The existing `index_shuffling` kernel in fbgemm supports `top_k` values of `{1, 2, 4}`; this diff adds support for `top_k=8`.
- Refactors the tiling/dispatch logic to enable the `top_k=8` case.
- Adds a unit test for k=8, plus an additional test that explicitly exercises the tile size boundaries.

Reviewed By: jasonjk-park

Differential Revision: D86728049
1 parent 44f943c commit 65413cc
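
To see the effect of the new `top_k=8` case at a glance: the dispatch in this diff keys on `storage_factor = top_k * num_experts`, and k=8 is what first pushes E=32 out of the no-reduction bucket. A quick hand-derived check (the 128/256/512 thresholds come from the diff below; `B` is the base tokens-per-tile constant; this is an illustration, not fbgemm code):

# Hand-derived: which tile-size bucket each expert count lands in at top_k=8,
# under the 128/256/512 storage-factor thresholds from the diff below.
for num_experts in (16, 32, 128, 320):
    s = 8 * num_experts  # storage_factor = top_k * num_experts
    bucket = "B" if s <= 128 else "B/2" if s <= 256 else "B/4" if s <= 512 else "B/8"
    print(f"E={num_experts:3d}: storage_factor={s:4d} -> tile {bucket}")
# E=16: 128 -> B;  E=32: 256 -> B/2;  E=128: 1024 -> B/8;  E=320: 2560 -> B/8

With top_k capped at 4, E=16 and E=32 never exceeded a storage factor of 128, which is why the old code dispatched them without any tile reduction.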

File tree

2 files changed: +181 -35 lines changed

fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cu

Lines changed: 76 additions & 21 deletions
@@ -351,7 +351,7 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
           rhs_smem_index,
           num_reduced_threads == 1);
     } else {
-      merge_topk<DataType, IndexType, 4>(
+      merge_topk<DataType, IndexType, TopK>(
           smem.routing_scores,
           smem.expert_indices,
           lhs_smem_index,
@@ -502,7 +502,18 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
       num_experts == 16 || num_experts == 32 || num_experts == 128 ||
       num_experts == 320);

-  TORCH_CHECK(top_k == 1 || top_k == 2 || top_k == 4);
+  // ROCm currently only supports top_k=1. See L562
+#ifdef USE_ROCM
+  TORCH_CHECK(
+      top_k == 1,
+      "ROCm currently only supports top_k=1. Requested top_k=",
+      top_k);
+#else
+  TORCH_CHECK(
+      top_k == 1 || top_k == 2 || top_k == 4 || top_k == 8,
+      "top_k must be 1, 2, 4, or 8. Got top_k=",
+      top_k);
+#endif

   auto allocate_index_tensor = [&](int size) {
     return at::empty(
@@ -549,31 +560,75 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
   // Reducing tile size as problem size increases to avoid
   // cudaErrorCooperativeLaunchTooLarge.
   // TopK > 1 is not supported on AMD yet.
+  // Expert-specific DISPATCH macros prevent compile-time errors from
+  // static_assert at L322: NumTokensPerTile % kNumParallelReductionGroups == 0
+  //
+  // Each expert count has different divisibility constraints:
+  // - E=16:  kNumParallelReductionGroups=16 → tile ≥16 (never reduces)
+  // - E=32:  kNumParallelReductionGroups=8  → tile ≥8  (max reduction: B/2)
+  // - E=128: kNumParallelReductionGroups=2  → tile ≥2  (max reduction: B/8)
+  // - E=320: kNumParallelReductionGroups=1  → tile ≥1  (max reduction: B/16)
 #ifndef USE_ROCM
-#define DISPATCH(E, B, K, S)          \
-  if (S <= 128) {                     \
-    DISPATCH_K(E, B, K);              \
-  } else if (storage_factor <= 256) { \
-    DISPATCH_K(E, B / 2, K);          \
-  } else if (storage_factor <= 512) { \
-    DISPATCH_K(E, B / 4, K);          \
-  } else {                            \
-    DISPATCH_K(E, B / 8, K);          \
-  }
+#define DISPATCH_E_16(B, K, S) \
+  DISPATCH_K(16, B, K); // E=16: Never reduces (always tile=16)
+
+#define DISPATCH_E_32(B, K, S) \
+  if (S <= 128) {              \
+    DISPATCH_K(32, B, K);      \
+  } else {                     \
+    DISPATCH_K(32, B / 2, K);  \
+  } // E=32: Min tile=8 (B/2)
+
+#define DISPATCH_E_128(B, K, S) \
+  if (S <= 128) {               \
+    DISPATCH_K(128, B, K);      \
+  } else if (S <= 256) {        \
+    DISPATCH_K(128, B / 2, K);  \
+  } else if (S <= 512) {        \
+    DISPATCH_K(128, B / 4, K);  \
+  } else {                      \
+    DISPATCH_K(128, B / 8, K);  \
+  } // E=128: Min tile=2 (B/8)
+
+#define DISPATCH_E_320(B, K, S) \
+  if (S <= 128) {               \
+    DISPATCH_K(320, B, K);      \
+  } else if (S <= 256) {        \
+    DISPATCH_K(320, B / 2, K);  \
+  } else if (S <= 512) {        \
+    DISPATCH_K(320, B / 4, K);  \
+  } else {                      \
+    DISPATCH_K(320, B / 8, K);  \
+  } // E=320: Min tile=2 (B/8)
 #else
-#define DISPATCH(E, B, K, S) \
-  TORCH_CHECK(K == 1);       \
-  DISPATCH_EB(E, 8, 1)
+// ROCm: Only K=1 supported, fixed tile sizes per expert count
+#define DISPATCH_E_16(B, K, S) \
+  TORCH_CHECK(K == 1);         \
+  DISPATCH_K(16, B, K) // E=16: B=32
+
+#define DISPATCH_E_32(B, K, S) \
+  TORCH_CHECK(K == 1);         \
+  DISPATCH_K(32, B, K) // E=32: B=32
+
+#define DISPATCH_E_128(B, K, S) \
+  TORCH_CHECK(K == 1);          \
+  DISPATCH_EB(128, 8, 1) // E=128: B set to 8
+
+#define DISPATCH_E_320(B, K, S) \
+  TORCH_CHECK(K == 1);          \
+  DISPATCH_EB(320, 8, 1) // E=320: B set to 8
 #endif

 #define DISPATCH_K(E, B, K) \
   if (K == 1) {             \
     DISPATCH_EB(E, B, 1)    \
   } else if (K == 2) {      \
     DISPATCH_EB(E, B, 2)    \
-  } else {                  \
-    TORCH_CHECK(K == 4);    \
+  } else if (K == 4) {      \
     DISPATCH_EB(E, B, 4)    \
+  } else {                  \
+    TORCH_CHECK(K == 8);    \
+    DISPATCH_EB(E, B, 8)    \
   }
 #define DISPATCH_EB(E, B, K) \
   kernel = (void*)index_shuffling_kernel<DataType, IndexType, E, B, K>; \
@@ -582,14 +637,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
   int storage_factor = top_k * num_experts;

   if (num_experts == 16) {
-    DISPATCH_K(16, kNumTokensPerTileFewExperts, top_k)
+    DISPATCH_E_16(kNumTokensPerTileFewExperts, top_k, storage_factor)
   } else if (num_experts == 32) {
-    DISPATCH_K(32, kNumTokensPerTileFewExperts, top_k)
+    DISPATCH_E_32(kNumTokensPerTileFewExperts, top_k, storage_factor)
   } else if (num_experts == 128) {
-    DISPATCH(128, kNumTokensPerTileFewExperts, top_k, storage_factor)
+    DISPATCH_E_128(kNumTokensPerTileFewExperts, top_k, storage_factor)
   } else {
     TORCH_CHECK(num_experts == 320);
-    DISPATCH(320, kNumTokensPerTileFewExperts, top_k, storage_factor)
+    DISPATCH_E_320(kNumTokensPerTileFewExperts, top_k, storage_factor)
   }
   // This is to avoid build errors (divisibility asserts and local memory
   // overflow) on AMD.
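
The per-expert macros above are, in effect, a clamped threshold ladder. Here is a minimal Python sketch of that selection logic, not the fbgemm source: `base_tile` stands in for `kNumTokensPerTileFewExperts` (whose concrete value is not shown in this diff), and the reduction-group table is copied from the comment block in the hunk:

# Minimal sketch (not fbgemm source) of the tile-size selection performed by
# the DISPATCH_E_* macros. base_tile plays the role of
# kNumTokensPerTileFewExperts; its value is an assumption here.
REDUCTION_GROUPS = {16: 16, 32: 8, 128: 2, 320: 1}  # kNumParallelReductionGroups

def select_tile_size(num_experts: int, top_k: int, base_tile: int) -> int:
    s = top_k * num_experts  # storage_factor, as computed in index_shuffling_torch
    if num_experts == 16:
        divisor = 1  # never reduces: s <= 128 for every supported K
    elif num_experts == 32:
        divisor = 1 if s <= 128 else 2  # at most B/2
    elif s <= 128:  # E=128 and E=320 use the full 128/256/512 ladder
        divisor = 1
    elif s <= 256:
        divisor = 2
    elif s <= 512:
        divisor = 4
    else:
        divisor = 8
    tile = base_tile // divisor
    # Mirrors the kernel's static_assert:
    # NumTokensPerTile % kNumParallelReductionGroups == 0
    assert tile % REDUCTION_GROUPS[num_experts] == 0
    return tile

The reason this cannot remain one generic DISPATCH macro in C++: the preprocessor emits every branch, so a generic macro instantiates the B/4 and B/8 kernels even for E=16 and E=32, where those tile sizes trip the static_assert at compile time. Splitting the macro per expert count keeps only the branches whose reduced tile still divides evenly.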

fbgemm_gpu/experimental/gen_ai/test/moe/shuffling_test.py

Lines changed: 105 additions & 14 deletions
@@ -24,6 +24,7 @@
 )

 from hypothesis import given, settings, strategies as st, Verbosity
+from parameterized import parameterized
 from pyre_extensions import none_throws

 try:
@@ -40,6 +41,45 @@
 _MAX_SAMPLES: int = 100


+def _generate_tile_boundary_test_cases() -> list[tuple[str, int, int]]:
+    """
+    Generate test cases at tile size threshold boundaries.
+
+    For each expert E, for each threshold T:
+    - Calculate K = T / E
+    - Round to smallest valid K >= target
+    - Use (E, K) as test case
+
+    Returns list of (test_name, num_experts, top_k).
+    """
+
+    # Tile size boundary test configuration (sync with index_shuffling.cu)
+    _SUPPORTED_EXPERT_COUNTS: list[int] = [16, 32, 128, 320]
+    _SUPPORTED_TOP_K: list[int] = [1, 2, 4, 8]
+    _STORAGE_FACTOR_THRESHOLDS: list[int] = [128, 256, 512]
+
+    def _find_nearest_valid_k(target_k: float) -> int | None:
+        """Find smallest K >= target_k, or None if no such K exists."""
+        candidates = [k for k in _SUPPORTED_TOP_K if k >= target_k]
+        return min(candidates) if candidates else None
+
+    seen = set()
+    cases = []
+
+    for num_experts in _SUPPORTED_EXPERT_COUNTS:
+        for threshold in _STORAGE_FACTOR_THRESHOLDS:
+            target_k = threshold / num_experts
+            nearest_k = _find_nearest_valid_k(target_k)
+
+            if nearest_k and (num_experts, nearest_k) not in seen:
+                seen.add((num_experts, nearest_k))
+                storage_factor = num_experts * nearest_k
+                test_name = f"e{num_experts}_k{nearest_k}_s{storage_factor}"
+                cases.append((test_name, num_experts, nearest_k))
+
+    return sorted(cases)
+
+
 @unittest.skipIf(open_source, "Tests currently fail in open source")
 @unittest.skipIf(
     not torch.cuda.is_available(),
@@ -48,20 +88,7 @@
 class ShufflingTests(unittest.TestCase):
     """Test shuffling kernels."""

-    @given(
-        num_tokens=st.sampled_from(
-            [1, 3, 123, 128, 1234, 2048, 4567, 4096, 8192, 16384]
-        ),
-        num_experts=st.sampled_from([16, 32, 128, 320]),
-        num_local_experts=st.sampled_from([None, 8]),
-        top_k=st.sampled_from([1, 2, 4] if torch.version.cuda else [1]),
-        padded=st.sampled_from([True, False]),
-        rowmajor=st.sampled_from([True, False]),
-        compiled=st.sampled_from([True, False]),
-        routing_score_dtype=st.sampled_from([torch.float, torch.bfloat16]),
-    )
-    @settings(verbosity=Verbosity.verbose, max_examples=_MAX_SAMPLES, deadline=None)
-    def test_topk_index_shuffling(
+    def _run_topk_index_shuffling_test(
         self,
         num_tokens: int,
         num_experts: int,
@@ -211,6 +238,70 @@ def _assert_indices_equal(
             start_index = end_index
             ref_start_index = ref_end_index

+    @parameterized.expand(_generate_tile_boundary_test_cases())
+    def test_topk_index_shuffling_tile_size_boundaries(
+        self,
+        name: str,
+        num_experts: int,
+        top_k: int,
+    ) -> None:
+        """
+        Test index shuffling at tile size threshold boundaries.
+
+        The index shuffling kernel switches the number of tokens per tile
+        (tile size) depending on the storage factor (shared memory pressure
+        on CUDA). These test cases test for correctness at the tile size
+        threshold boundaries.
+        """
+        # Skip K>1 on ROCm (not supported)
+        if top_k > 1 and torch.version.hip:
+            self.skipTest("ROCm only supports top_k=1")
+
+        self._run_topk_index_shuffling_test(
+            num_tokens=2049,
+            num_experts=num_experts,
+            num_local_experts=None,
+            top_k=top_k,
+            padded=False,
+            rowmajor=True,
+            compiled=False,
+            routing_score_dtype=torch.float32,
+        )
+
+    @given(
+        num_tokens=st.sampled_from(
+            [1, 3, 123, 128, 1234, 2048, 4567, 4096, 8192, 16384]
+        ),
+        num_experts=st.sampled_from([16, 32, 128, 320]),
+        num_local_experts=st.sampled_from([None, 8]),
+        top_k=st.sampled_from([1, 2, 4, 8] if torch.version.cuda else [1]),
+        padded=st.sampled_from([True, False]),
+        rowmajor=st.sampled_from([True, False]),
+        compiled=st.sampled_from([True, False]),
+        routing_score_dtype=st.sampled_from([torch.float, torch.bfloat16]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=_MAX_SAMPLES, deadline=None)
+    def test_topk_index_shuffling(
+        self,
+        num_tokens: int,
+        num_experts: int,
+        num_local_experts: Optional[int],
+        top_k: int,
+        padded: bool,
+        rowmajor: bool,
+        compiled: bool,
+        routing_score_dtype: torch.dtype,
+    ) -> None:
+        self._run_topk_index_shuffling_test(
+            num_tokens=num_tokens,
+            num_experts=num_experts,
+            num_local_experts=num_local_experts,
+            top_k=top_k,
+            padded=padded,
+            rowmajor=rowmajor,
+            compiled=compiled,
+            routing_score_dtype=routing_score_dtype,
+        )
+
     @given(
         batch_size=st.sampled_from(
             [1, 8, 123, 128, 1234, 2048, 4096, 4567, 8192, 16384]
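
For reference, the parameter sets `_generate_tile_boundary_test_cases` produces, derived by hand from the helper above (thresholds 128/256/512 against valid K in {1, 2, 4, 8}; `sorted` orders the tuples by the name string):

# Hand-derived output of _generate_tile_boundary_test_cases() above.
# E=16 at thresholds 256/512 and E=32 at 512 have no valid K <= 8, and
# E=320 at threshold 256 duplicates its K=1 case, leaving eight cases.
EXPECTED_BOUNDARY_CASES = [
    ("e128_k1_s128", 128, 1),
    ("e128_k2_s256", 128, 2),
    ("e128_k4_s512", 128, 4),
    ("e16_k8_s128", 16, 8),
    ("e320_k1_s320", 320, 1),
    ("e320_k2_s640", 320, 2),
    ("e32_k4_s128", 32, 4),
    ("e32_k8_s256", 32, 8),
]

Each case lands on, or just past, a storage-factor threshold; broad k=8 coverage across expert counts still comes from the hypothesis-driven test, which now samples top_k=8 on CUDA.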
