Make moe align fused deterministic#1369
Conversation
There was a problem hiding this comment.
Code Review
This pull request replaces the atomic-based fused MoE alignment kernel with a deterministic version that processes tokens block-by-block per expert block, and adds a corresponding unit test. The review feedback highlights a critical bug in the kernel where using modulo % expert_num on expert_offsets causes duplicate processing of wrapped-around experts when expert_num is not a multiple of EXPERT_BLOCK. To address this, standard out-of-bounds masking is recommended. Additionally, the reviewer suggests updating the unit test configuration to use EXPERT_BLOCK = 4 to ensure this edge case is properly covered.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| expert_block_id = tl.program_id(0) | ||
| expert_offsets = (expert_block_id * EXPERT_BLOCK + tl.arange(0, EXPERT_BLOCK)) % expert_num | ||
| block_offs = tl.arange(0, BLOCK_SIZE) | ||
| token_count = tl.full((EXPERT_BLOCK,), 0, tl.int32) | ||
|
|
||
| for start in tl.range(0, token_num_mul_topk, BLOCK_SIZE, num_stages=NUM_STAGE): | ||
| raw_offs = start + block_offs | ||
| valid = raw_offs < token_num_mul_topk | ||
| load_offs = tl.where(valid, raw_offs, 0) | ||
| expert_ids = tl.load(topk_ids_ptr + load_offs, mask=valid, other=-1) | ||
| weights = tl.load(topk_weights_ptr + load_offs, mask=valid, other=0.0) | ||
|
|
||
| expert_mask = (expert_ids[None, :] == expert_offsets[:, None]) & valid[None, :] | ||
| expert_hits = tl.where(expert_mask, 1, 0) | ||
| local_pos = tl.cumsum(expert_hits, axis=1) - 1 | ||
| write_pos = token_count[:, None] + local_pos | ||
| tl.store( | ||
| expert_to_token_index_ptr + expert_offsets[:, None] * token_num_mul_topk + write_pos, | ||
| raw_offs[None, :], | ||
| mask=expert_mask, | ||
| ) | ||
| tl.store( | ||
| expert_to_weight_ptr + expert_offsets[:, None] * token_num_mul_topk + write_pos, | ||
| weights[None, :], | ||
| mask=expert_mask, | ||
| ) | ||
| token_count += tl.sum(expert_hits, axis=1) | ||
|
|
||
| # 按 token 顺序写 index 和 weight | ||
| tl.store( | ||
| expert_to_token_index_ptr + expert_ids * (token_num * topk_num) + write_pos, | ||
| offs, | ||
| mask=mask, | ||
| ) | ||
| tl.store( | ||
| expert_to_weight_ptr + expert_ids * (token_num * topk_num) + write_pos, | ||
| weights, | ||
| mask=mask, | ||
| ) | ||
| tl.store(expert_token_num_ptr + expert_offsets, token_count) |
There was a problem hiding this comment.
Using modulo % expert_num on expert_offsets causes duplicate processing of wrapped-around experts when expert_num is not a multiple of EXPERT_BLOCK. For example, if expert_num = 5 and EXPERT_BLOCK = 4, Block 1 will process experts [4, 0, 1, 2], overwriting the correct results computed by Block 0 for experts 0, 1, 2.
To fix this, we should use standard out-of-bounds masking with expert_offsets < expert_num and a safe offset array to avoid out-of-bounds pointer arithmetic.
expert_block_id = tl.program_id(0)
expert_offsets = expert_block_id * EXPERT_BLOCK + tl.arange(0, EXPERT_BLOCK)
expert_offsets_mask = expert_offsets < expert_num
safe_expert_offsets = tl.where(expert_offsets_mask, expert_offsets, 0)
block_offs = tl.arange(0, BLOCK_SIZE)
token_count = tl.full((EXPERT_BLOCK,), 0, tl.int32)
for start in tl.range(0, token_num_mul_topk, BLOCK_SIZE, num_stages=NUM_STAGE):
raw_offs = start + block_offs
valid = raw_offs < token_num_mul_topk
load_offs = tl.where(valid, raw_offs, 0)
expert_ids = tl.load(topk_ids_ptr + load_offs, mask=valid, other=-1)
weights = tl.load(topk_weights_ptr + load_offs, mask=valid, other=0.0)
expert_mask = (expert_ids[None, :] == safe_expert_offsets[:, None]) & valid[None, :] & expert_offsets_mask[:, None]
expert_hits = tl.where(expert_mask, 1, 0)
local_pos = tl.cumsum(expert_hits, axis=1) - 1
write_pos = token_count[:, None] + local_pos
tl.store(
expert_to_token_index_ptr + safe_expert_offsets[:, None] * token_num_mul_topk + write_pos,
raw_offs[None, :],
mask=expert_mask,
)
tl.store(
expert_to_weight_ptr + safe_expert_offsets[:, None] * token_num_mul_topk + write_pos,
weights[None, :],
mask=expert_mask,
)
token_count += tl.sum(expert_hits, axis=1)
tl.store(expert_token_num_ptr + safe_expert_offsets, token_count, mask=expert_offsets_mask)| expert_token_num, | ||
| topk_ids, | ||
| topk_weights, | ||
| run_config={"BLOCK_SIZE": 1024, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1}, |
There was a problem hiding this comment.
Change EXPERT_BLOCK to 4 (or any value greater than 1) to ensure the test covers cases where expert_num is not a multiple of EXPERT_BLOCK (since expert_num = 5 here). This will help catch wrap-around or out-of-bounds masking bugs.
| run_config={"BLOCK_SIZE": 1024, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1}, | |
| run_config={"BLOCK_SIZE": 1024, "EXPERT_BLOCK": 4, "num_warps": 8, "NUM_STAGE": 1}, |
No description provided.