Make moe align fused deterministic#1369

Open
hiworldwzj wants to merge 5 commits into main from wzj_moe

@hiworldwzj

Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request replaces the atomic-based fused MoE alignment kernel with a deterministic version that processes tokens block-by-block per expert block, and adds a corresponding unit test. The review feedback highlights a critical bug in the kernel where using modulo % expert_num on expert_offsets causes duplicate processing of wrapped-around experts when expert_num is not a multiple of EXPERT_BLOCK. To address this, standard out-of-bounds masking is recommended. Additionally, the reviewer suggests updating the unit test configuration to use EXPERT_BLOCK = 4 to ensure this edge case is properly covered.

Comment on lines +229 to +257
+    expert_block_id = tl.program_id(0)
+    expert_offsets = (expert_block_id * EXPERT_BLOCK + tl.arange(0, EXPERT_BLOCK)) % expert_num
+    block_offs = tl.arange(0, BLOCK_SIZE)
+    token_count = tl.full((EXPERT_BLOCK,), 0, tl.int32)

+    for start in tl.range(0, token_num_mul_topk, BLOCK_SIZE, num_stages=NUM_STAGE):
+        raw_offs = start + block_offs
+        valid = raw_offs < token_num_mul_topk
+        load_offs = tl.where(valid, raw_offs, 0)
+        expert_ids = tl.load(topk_ids_ptr + load_offs, mask=valid, other=-1)
+        weights = tl.load(topk_weights_ptr + load_offs, mask=valid, other=0.0)

+        expert_mask = (expert_ids[None, :] == expert_offsets[:, None]) & valid[None, :]
+        expert_hits = tl.where(expert_mask, 1, 0)
+        local_pos = tl.cumsum(expert_hits, axis=1) - 1
+        write_pos = token_count[:, None] + local_pos
+        tl.store(
+            expert_to_token_index_ptr + expert_offsets[:, None] * token_num_mul_topk + write_pos,
+            raw_offs[None, :],
+            mask=expert_mask,
+        )
+        tl.store(
+            expert_to_weight_ptr + expert_offsets[:, None] * token_num_mul_topk + write_pos,
+            weights[None, :],
+            mask=expert_mask,
+        )
+        token_count += tl.sum(expert_hits, axis=1)

-    # Write index and weight in token order
-    tl.store(
-        expert_to_token_index_ptr + expert_ids * (token_num * topk_num) + write_pos,
-        offs,
-        mask=mask,
-    )
-    tl.store(
-        expert_to_weight_ptr + expert_ids * (token_num * topk_num) + write_pos,
-        weights,
-        mask=mask,
-    )
+    tl.store(expert_token_num_ptr + expert_offsets, token_count)

Contributor

critical

Using modulo % expert_num on expert_offsets causes duplicate processing of wrapped-around experts when expert_num is not a multiple of EXPERT_BLOCK. For example, if expert_num = 5 and EXPERT_BLOCK = 4, Block 1 will process experts [4, 0, 1, 2], overwriting the correct results computed by Block 0 for experts 0, 1, 2.
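
The wrap-around in this example can be reproduced on the host without Triton. The sketch below only mimics the index arithmetic of expert_offsets; the helper name wrapped_offsets is hypothetical, and the grid size is assumed to be ceil(expert_num / EXPERT_BLOCK). It shows that more than one program ends up covering the same expert; what those duplicate writes do to the output depends on the kernel itself.

```python
def wrapped_offsets(expert_block_id, expert_block, expert_num):
    """Mimic (block_id * EXPERT_BLOCK + arange(EXPERT_BLOCK)) % expert_num."""
    base = expert_block_id * expert_block
    return [(base + i) % expert_num for i in range(expert_block)]


expert_num, expert_block = 5, 4
# Assumption: one program per expert block, i.e. ceil(expert_num / EXPERT_BLOCK).
num_blocks = -(-expert_num // expert_block)

# Record which programs claim each expert id.
owners = {}
for block_id in range(num_blocks):
    for expert in wrapped_offsets(block_id, expert_block, expert_num):
        owners.setdefault(expert, []).append(block_id)

# Experts that are claimed by more than one program.
duplicated = {e: blocks for e, blocks in owners.items() if len(blocks) > 1}

print(wrapped_offsets(1, expert_block, expert_num))  # [4, 0, 1, 2]
print(duplicated)  # {0: [0, 1], 1: [0, 1], 2: [0, 1]}
```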

To fix this, we should use standard out-of-bounds masking with expert_offsets < expert_num and a safe offset array to avoid out-of-bounds pointer arithmetic.

    expert_block_id = tl.program_id(0)
    expert_offsets = expert_block_id * EXPERT_BLOCK + tl.arange(0, EXPERT_BLOCK)
    expert_offsets_mask = expert_offsets < expert_num
    safe_expert_offsets = tl.where(expert_offsets_mask, expert_offsets, 0)
    block_offs = tl.arange(0, BLOCK_SIZE)
    token_count = tl.full((EXPERT_BLOCK,), 0, tl.int32)

    for start in tl.range(0, token_num_mul_topk, BLOCK_SIZE, num_stages=NUM_STAGE):
        raw_offs = start + block_offs
        valid = raw_offs < token_num_mul_topk
        load_offs = tl.where(valid, raw_offs, 0)
        expert_ids = tl.load(topk_ids_ptr + load_offs, mask=valid, other=-1)
        weights = tl.load(topk_weights_ptr + load_offs, mask=valid, other=0.0)

        expert_mask = (expert_ids[None, :] == safe_expert_offsets[:, None]) & valid[None, :] & expert_offsets_mask[:, None]
        expert_hits = tl.where(expert_mask, 1, 0)
        local_pos = tl.cumsum(expert_hits, axis=1) - 1
        write_pos = token_count[:, None] + local_pos
        tl.store(
            expert_to_token_index_ptr + safe_expert_offsets[:, None] * token_num_mul_topk + write_pos,
            raw_offs[None, :],
            mask=expert_mask,
        )
        tl.store(
            expert_to_weight_ptr + safe_expert_offsets[:, None] * token_num_mul_topk + write_pos,
            weights[None, :],
            mask=expert_mask,
        )
        token_count += tl.sum(expert_hits, axis=1)

    tl.store(expert_token_num_ptr + safe_expert_offsets, token_count, mask=expert_offsets_mask)
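
As a rough host-side model of the masked kernel above, the following pure-Python sketch walks the tokens in order for each expert block and skips lanes whose expert id falls past expert_num. It is not the kernel and makes no claim about performance; align_masked and its list-based outputs are hypothetical stand-ins for the Triton buffers, and the sequential counter plays the role of token_count plus the cumsum-based local_pos.

```python
def align_masked(topk_ids, topk_weights, expert_num, expert_block, block_size):
    """Sequential model of the masked alignment; returns (index, weight, count)."""
    total = len(topk_ids)
    token_index = [[None] * total for _ in range(expert_num)]
    weight = [[None] * total for _ in range(expert_num)]
    token_num = [0] * expert_num

    # Assumption: the grid has ceil(expert_num / EXPERT_BLOCK) programs.
    num_programs = -(-expert_num // expert_block)
    for pid in range(num_programs):
        offsets = [pid * expert_block + lane for lane in range(expert_block)]
        count = [0] * expert_block  # per-lane running token count
        for start in range(0, total, block_size):
            stop = min(start + block_size, total)  # tail mask on tokens
            for pos in range(start, stop):
                for lane, expert in enumerate(offsets):
                    # Lanes past expert_num are masked out, never wrapped.
                    if expert < expert_num and topk_ids[pos] == expert:
                        token_index[expert][count[lane]] = pos
                        weight[expert][count[lane]] = topk_weights[pos]
                        count[lane] += 1
        for lane, expert in enumerate(offsets):
            if expert < expert_num:
                token_num[expert] = count[lane]
    return token_index, weight, token_num


ids = [0, 4, 2, 4, 1, 0, 4, 3]
weights = [float(i) for i in range(len(ids))]
index, wt, num = align_masked(ids, weights, expert_num=5, expert_block=4, block_size=3)
print(num)           # [2, 1, 1, 1, 3]
print(index[4][:3])  # [1, 3, 6]
```

Because each expert is owned by exactly one lane and tokens are visited in order, the result in this model does not depend on EXPERT_BLOCK or BLOCK_SIZE.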

expert_token_num,
topk_ids,
topk_weights,
run_config={"BLOCK_SIZE": 1024, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1},

Contributor

medium

Change EXPERT_BLOCK to 4 (or any value greater than 1) to ensure the test covers cases where expert_num is not a multiple of EXPERT_BLOCK (since expert_num = 5 here). This will help catch wrap-around or out-of-bounds masking bugs.

Suggested change
- run_config={"BLOCK_SIZE": 1024, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1},
+ run_config={"BLOCK_SIZE": 1024, "EXPERT_BLOCK": 4, "num_warps": 8, "NUM_STAGE": 1},
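
To see why EXPERT_BLOCK = 1 cannot exercise the edge case, it may help to count how many lanes of the last expert block fall past expert_num. The helper padded_lanes below is hypothetical and assumes the grid is ceil(expert_num / EXPERT_BLOCK) programs.

```python
def padded_lanes(expert_num, expert_block):
    """Number of lanes in the last expert block that fall past expert_num."""
    num_blocks = -(-expert_num // expert_block)  # ceil division (assumed grid)
    return num_blocks * expert_block - expert_num


# expert_num = 5 as in the test; a non-zero value means the final block is
# only partially filled, so masking (or wrap-around) actually comes into play.
for expert_block in (1, 2, 4, 8):
    print(expert_block, padded_lanes(5, expert_block))
# 1 0
# 2 1
# 4 3
# 8 3
```

With EXPERT_BLOCK = 1 the count is always zero for any expert_num, so a test using that setting never reaches the partially filled block.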
