97 changes: 76 additions & 21 deletions fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cu
@@ -351,7 +351,7 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
rhs_smem_index,
num_reduced_threads == 1);
} else {
merge_topk<DataType, IndexType, 4>(
merge_topk<DataType, IndexType, TopK>(
smem.routing_scores,
smem.expert_indices,
lhs_smem_index,
@@ -502,7 +502,18 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
num_experts == 16 || num_experts == 32 || num_experts == 128 ||
num_experts == 320);

TORCH_CHECK(top_k == 1 || top_k == 2 || top_k == 4);
// ROCm currently only supports top_k=1. See L562
#ifdef USE_ROCM
TORCH_CHECK(
top_k == 1,
"ROCm currently only supports top_k=1. Requested top_k=",
top_k);
#else
TORCH_CHECK(
top_k == 1 || top_k == 2 || top_k == 4 || top_k == 8,
"top_k must be 1, 2, 4, or 8. Got top_k=",
top_k);
#endif
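// For illustration: on CUDA a request with top_k = 3 trips the second
// check ("top_k must be 1, 2, 4, or 8. Got top_k=3"), while on ROCm any
// top_k != 1 trips the first.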

auto allocate_index_tensor = [&](int size) {
return at::empty(
@@ -549,31 +560,75 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
// Reducing tile size as problem size increases to avoid
// cudaErrorCooperativeLaunchTooLarge.
// TopK > 1 is not supported on AMD yet.
// Expert-specific DISPATCH macros prevent compile-time errors from
// static_assert at L322: NumTokensPerTile % kNumParallelReductionGroups == 0
//
// Each expert count has different divisibility constraints:
// - E=16: kNumParallelReductionGroups=16 → tile ≥16 (never reduces)
// - E=32: kNumParallelReductionGroups=8 → tile ≥8 (max reduction: B/2)
// - E=128: kNumParallelReductionGroups=2 → tile ≥2 (max reduction: B/8)
// - E=320: kNumParallelReductionGroups=1 → tile ≥1 (max reduction: B/16)
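// Worked example (assuming kNumTokensPerTileFewExperts = 16, which the
// bounds above imply): E=320 with top_k=2 gives
// storage_factor = 2 * 320 = 640 > 512, so dispatch selects a tile of
// B / 8 = 2 tokens; 2 % kNumParallelReductionGroups(=1) == 0, so the
// static_assert at L322 still holds.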
#ifndef USE_ROCM
#define DISPATCH(E, B, K, S) \
if (S <= 128) { \
DISPATCH_K(E, B, K); \
} else if (storage_factor <= 256) { \
DISPATCH_K(E, B / 2, K); \
} else if (storage_factor <= 512) { \
DISPATCH_K(E, B / 4, K); \
} else { \
DISPATCH_K(E, B / 8, K); \
}
#define DISPATCH_E_16(B, K, S) \
DISPATCH_K(16, B, K); // E=16: Never reduces (always tile=16)

#define DISPATCH_E_32(B, K, S) \
if (S <= 128) { \
DISPATCH_K(32, B, K); \
} else { \
DISPATCH_K(32, B / 2, K); \
} // E=32: Min tile=8 (B/2)

#define DISPATCH_E_128(B, K, S) \
if (S <= 128) { \
DISPATCH_K(128, B, K); \
} else if (S <= 256) { \
DISPATCH_K(128, B / 2, K); \
} else if (S <= 512) { \
DISPATCH_K(128, B / 4, K); \
} else { \
DISPATCH_K(128, B / 8, K); \
} // E=128: Min tile=2 (B/8)

#define DISPATCH_E_320(B, K, S) \
if (S <= 128) { \
DISPATCH_K(320, B, K); \
} else if (S <= 256) { \
DISPATCH_K(320, B / 2, K); \
} else if (S <= 512) { \
DISPATCH_K(320, B / 4, K); \
} else { \
DISPATCH_K(320, B / 8, K); \
} // E=320: Min tile=2 (B/8)
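// Example threshold walk for E=32: top_k=8 gives storage_factor = 256
// > 128, so DISPATCH_E_32 halves the tile to B / 2, while top_k <= 4
// keeps storage_factor <= 128 and the full tile B. E=32 never needs the
// B / 4 or B / 8 branches, since 8 * 32 = 256 is its largest possible
// factor.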
#else
#define DISPATCH(E, B, K, S) \
TORCH_CHECK(K == 1); \
DISPATCH_EB(E, 8, 1)
// ROCm: Only K=1 supported, fixed tile sizes per expert count
#define DISPATCH_E_16(B, K, S) \
TORCH_CHECK(K == 1); \
DISPATCH_K(16, B, K) // E=16: B=32

#define DISPATCH_E_32(B, K, S) \
TORCH_CHECK(K == 1); \
DISPATCH_K(32, B, K) // E=32: B=32

#define DISPATCH_E_128(B, K, S) \
TORCH_CHECK(K == 1); \
DISPATCH_EB(128, 8, 1) // E=128: B set to 8

#define DISPATCH_E_320(B, K, S) \
TORCH_CHECK(K == 1); \
DISPATCH_EB(320, 8, 1) // E=320: B set to 8
#endif

#define DISPATCH_K(E, B, K) \
if (K == 1) { \
DISPATCH_EB(E, B, 1) \
} else if (K == 2) { \
DISPATCH_EB(E, B, 2) \
} else { \
TORCH_CHECK(K == 4); \
} else if (K == 4) { \
DISPATCH_EB(E, B, 4) \
} else { \
TORCH_CHECK(K == 8); \
DISPATCH_EB(E, B, 8) \
}
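// End-to-end dispatch trace: num_experts=128 with top_k=8 gives
// storage_factor = 1024 > 512, so DISPATCH_E_128 expands to
// DISPATCH_K(128, B / 8, 8), which takes the K == 8 branch and lands in
// DISPATCH_EB(128, B / 8, 8), instantiating
// index_shuffling_kernel<DataType, IndexType, 128, B / 8, 8>.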
#define DISPATCH_EB(E, B, K) \
kernel = (void*)index_shuffling_kernel<DataType, IndexType, E, B, K>; \
@@ -582,14 +637,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
int storage_factor = top_k * num_experts;

if (num_experts == 16) {
DISPATCH_K(16, kNumTokensPerTileFewExperts, top_k)
DISPATCH_E_16(kNumTokensPerTileFewExperts, top_k, storage_factor)
} else if (num_experts == 32) {
DISPATCH_K(32, kNumTokensPerTileFewExperts, top_k)
DISPATCH_E_32(kNumTokensPerTileFewExperts, top_k, storage_factor)
} else if (num_experts == 128) {
DISPATCH(128, kNumTokensPerTileFewExperts, top_k, storage_factor)
DISPATCH_E_128(kNumTokensPerTileFewExperts, top_k, storage_factor)
} else {
TORCH_CHECK(num_experts == 320);
DISPATCH(320, kNumTokensPerTileFewExperts, top_k, storage_factor)
DISPATCH_E_320(kNumTokensPerTileFewExperts, top_k, storage_factor)
}
// This is to avoid build errors (divisibility asserts and local memory
// overflow) on AMD.
119 changes: 105 additions & 14 deletions fbgemm_gpu/experimental/gen_ai/test/moe/shuffling_test.py
@@ -24,6 +24,7 @@
)

from hypothesis import given, settings, strategies as st, Verbosity
from parameterized import parameterized
from pyre_extensions import none_throws

try:
@@ -40,6 +41,45 @@
_MAX_SAMPLES: int = 100


def _generate_tile_boundary_test_cases() -> list[tuple[str, int, int]]:
"""
Generate test cases at tile size threshold boundaries.

For each expert count E and each storage-factor threshold T:
- Compute the target K = T / E
- Round up to the smallest supported K >= target
- Use (E, K) as a test case (deduplicated)

Returns list of (test_name, num_experts, top_k).
"""

# Tile size boundary test configuration (keep in sync with index_shuffling.cu)
_SUPPORTED_EXPERT_COUNTS: list[int] = [16, 32, 128, 320]
_SUPPORTED_TOP_K: list[int] = [1, 2, 4, 8]
_STORAGE_FACTOR_THRESHOLDS: list[int] = [128, 256, 512]

def _find_nearest_valid_k(target_k: float) -> int | None:
"""Find smallest K >= target_k, or None if no such K exists."""
candidates = [k for k in _SUPPORTED_TOP_K if k >= target_k]
return min(candidates) if candidates else None
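# For example, _find_nearest_valid_k(1.6) == 2, while
# _find_nearest_valid_k(16.0) is None since no supported K reaches 16.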

seen = set()
cases = []

for num_experts in _SUPPORTED_EXPERT_COUNTS:
for threshold in _STORAGE_FACTOR_THRESHOLDS:
target_k = threshold / num_experts
nearest_k = _find_nearest_valid_k(target_k)

if nearest_k and (num_experts, nearest_k) not in seen:
seen.add((num_experts, nearest_k))
storage_factor = num_experts * nearest_k
test_name = f"e{num_experts}_k{nearest_k}_s{storage_factor}"
cases.append((test_name, num_experts, nearest_k))

return sorted(cases)
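# For illustration, the loop above yields eight boundary cases:
# E=16 -> K=8 (S=128); E=32 -> K=4 (S=128), K=8 (S=256);
# E=128 -> K=1 (S=128), K=2 (S=256), K=4 (S=512);
# E=320 -> K=1 (S=320), K=2 (S=640). Thresholds whose target K
# exceeds 8 (e.g. E=16 at T=256) contribute no case.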


@unittest.skipIf(open_source, "Tests currently fail in open source")
@unittest.skipIf(
not torch.cuda.is_available(),
@@ -48,20 +88,7 @@
class ShufflingTests(unittest.TestCase):
"""Test shuffling kernels."""

@given(
num_tokens=st.sampled_from(
[1, 3, 123, 128, 1234, 2048, 4567, 4096, 8192, 16384]
),
num_experts=st.sampled_from([16, 32, 128, 320]),
num_local_experts=st.sampled_from([None, 8]),
top_k=st.sampled_from([1, 2, 4] if torch.version.cuda else [1]),
padded=st.sampled_from([True, False]),
rowmajor=st.sampled_from([True, False]),
compiled=st.sampled_from([True, False]),
routing_score_dtype=st.sampled_from([torch.float, torch.bfloat16]),
)
@settings(verbosity=Verbosity.verbose, max_examples=_MAX_SAMPLES, deadline=None)
def test_topk_index_shuffling(
def _run_topk_index_shuffling_test(
self,
num_tokens: int,
num_experts: int,
@@ -211,6 +238,70 @@ def _assert_indices_equal(
start_index = end_index
ref_start_index = ref_end_index

@parameterized.expand(_generate_tile_boundary_test_cases())
def test_topk_index_shuffling_tile_size_boundaries(
self,
name: str,
num_experts: int,
top_k: int,
) -> None:
"""
Test index shuffling at tile size threshold boundaries.

The index shuffling kernel switches the number of tokens per tile (tile size)
depending on the storage factor (a proxy for shared memory pressure on CUDA).
These cases check correctness at the tile size threshold boundaries.
"""
# Skip K>1 on ROCm (not supported)
if top_k > 1 and torch.version.hip:
self.skipTest("ROCm only supports top_k=1")

self._run_topk_index_shuffling_test(
num_tokens=2049,
num_experts=num_experts,
num_local_experts=None,
top_k=top_k,
padded=False,
rowmajor=True,
compiled=False,
routing_score_dtype=torch.float32,
)

@given(
num_tokens=st.sampled_from(
[1, 3, 123, 128, 1234, 2048, 4567, 4096, 8192, 16384]
),
num_experts=st.sampled_from([16, 32, 128, 320]),
num_local_experts=st.sampled_from([None, 8]),
top_k=st.sampled_from([1, 2, 4, 8] if torch.version.cuda else [1]),
padded=st.sampled_from([True, False]),
rowmajor=st.sampled_from([True, False]),
compiled=st.sampled_from([True, False]),
routing_score_dtype=st.sampled_from([torch.float, torch.bfloat16]),
)
@settings(verbosity=Verbosity.verbose, max_examples=_MAX_SAMPLES, deadline=None)
def test_topk_index_shuffling(
self,
num_tokens: int,
num_experts: int,
num_local_experts: Optional[int],
top_k: int,
padded: bool,
rowmajor: bool,
compiled: bool,
routing_score_dtype: torch.dtype,
) -> None:
self._run_topk_index_shuffling_test(
num_tokens=num_tokens,
num_experts=num_experts,
num_local_experts=num_local_experts,
top_k=top_k,
padded=padded,
rowmajor=rowmajor,
compiled=compiled,
routing_score_dtype=routing_score_dtype,
)

@given(
batch_size=st.sampled_from(
[1, 8, 123, 128, 1234, 2048, 4096, 4567, 8192, 16384]