Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 266 additions & 60 deletions lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,110 +214,316 @@ def moe_align1(


@triton.jit
def moe_align_fused_kernel(
def moe_align_fused_small_token_kernel(
topk_ids_ptr, # [token_num, topk]
topk_weights_ptr, # [token_num, topk]
expert_to_token_index_ptr, # [expert_num, token_num * topk]
expert_to_weight_ptr, # [expert_num, token_num * topk]
expert_token_num_ptr, # [expert_num]
token_num,
expert_num: tl.constexpr,
topk_num: tl.constexpr,
token_num_mul_topk,
BLOCK_SIZE: tl.constexpr,
INIT_EXPERT_TOKEN_NUM_IN_KERNEL: tl.constexpr,
BLOCK_EXPERT: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
token_block = tl.program_id(0)
if INIT_EXPERT_TOKEN_NUM_IN_KERNEL:
expert_offs = tl.arange(0, BLOCK_EXPERT)
tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num)
tl.debug_barrier()

offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < token_num * topk_num

expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0)
weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0)

# 用 atomic_add 给 expert 分配写位置
write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask)
expert_id = tl.program_id(0)
block_offs = tl.arange(0, BLOCK_SIZE)
token_count = tl.full((), 0, tl.int32)

for start in tl.range(0, token_num_mul_topk, BLOCK_SIZE, num_stages=NUM_STAGE):
raw_offs = start + block_offs
valid = raw_offs < token_num_mul_topk
load_offs = tl.where(valid, raw_offs, 0)
expert_ids = tl.load(topk_ids_ptr + load_offs, mask=valid, other=-1)
weights = tl.load(topk_weights_ptr + load_offs, mask=valid, other=0.0)

expert_mask = (expert_ids == expert_id) & valid
expert_hits = tl.where(expert_mask, 1, 0)
write_pos = token_count + tl.cumsum(expert_hits, axis=0) - 1
tl.store(
expert_to_token_index_ptr + expert_id * token_num_mul_topk + write_pos,
raw_offs,
mask=expert_mask,
)
tl.store(
expert_to_weight_ptr + expert_id * token_num_mul_topk + write_pos,
weights,
mask=expert_mask,
)
token_count += tl.sum(expert_hits, axis=0)

# 按 token 顺序写 index 和 weight
tl.store(
expert_to_token_index_ptr + expert_ids * (token_num * topk_num) + write_pos,
offs,
mask=mask,
)
tl.store(
expert_to_weight_ptr + expert_ids * (token_num * topk_num) + write_pos,
weights,
mask=mask,
)
tl.store(expert_token_num_ptr + expert_id, token_count)


def _get_moe_align_fused_static_key(
expert_token_num: torch.Tensor,
topk_weights: torch.Tensor,
) -> dict:
topk_num = topk_weights.shape[1]
expert_num = expert_token_num.shape[0]
return {
"topk_num": topk_num,
"expert_num": expert_num,
}


def _get_moe_align_fused_configs():
return [
@autotune(
kernel_name="moe_align_fused_small:v2",
configs_gen_func=lambda: [
{
"BLOCK_SIZE": bt,
"num_warps": nw,
"NUM_STAGE": ns,
}
for ns in [1, 2, 4, 6]
for nw in [1, 2, 4, 8]
for bt in [128, 256, 512, 1024, 2048]
]
for bt in [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
],
static_key_func=_get_moe_align_fused_static_key,
run_key_func=lambda topk_ids: topk_ids.shape[0],
mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"],
)
def _moe_align_fused_small_token(
expert_to_token_index,
expert_to_weight,
expert_token_num,
topk_ids,
topk_weights,
run_config: Optional[dict] = None,
):
if run_config is None:
token_num = topk_ids.shape[0]
if token_num <= 2:
run_config = {"BLOCK_SIZE": 16, "num_warps": 1, "NUM_STAGE": 1}
elif token_num <= 8:
run_config = {"BLOCK_SIZE": 64, "num_warps": 1, "NUM_STAGE": 1}
elif token_num <= 16:
run_config = {"BLOCK_SIZE": 128, "num_warps": 1, "NUM_STAGE": 1}
elif token_num < 32:
run_config = {"BLOCK_SIZE": 256, "num_warps": 2, "NUM_STAGE": 1}
elif token_num <= 64:
run_config = {"BLOCK_SIZE": 512, "num_warps": 4, "NUM_STAGE": 1}
elif token_num <= 128:
run_config = {"BLOCK_SIZE": 1024, "num_warps": 8, "NUM_STAGE": 1}
elif token_num <= 192:
run_config = {"BLOCK_SIZE": 512, "num_warps": 4, "NUM_STAGE": 1}
else:
run_config = {"BLOCK_SIZE": 2048, "num_warps": 8, "NUM_STAGE": 1}
token_num_mul_topk = topk_ids.numel()
expert_num = expert_token_num.shape[0]
block_size = run_config["BLOCK_SIZE"]

moe_align_fused_small_token_kernel[(expert_num,)](
topk_ids,
topk_weights,
expert_to_token_index,
expert_to_weight,
expert_token_num,
token_num_mul_topk,
BLOCK_SIZE=block_size,
NUM_STAGE=run_config["NUM_STAGE"],
num_warps=run_config["num_warps"],
)
return expert_to_token_index, expert_to_weight, expert_token_num


@triton.jit
def moe_align_fused_record_sorted_segment_start_kernel(
sorted_expert_ids_ptr, # [token_num * topk]
segment_start_ptr, # [expert_num]
expert_token_num_ptr, # [expert_num]
token_num_mul_topk,
expert_num,
BLOCK_SIZE: tl.constexpr,
):
block_id = tl.program_id(0)
sorted_pos = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
valid = sorted_pos < token_num_mul_topk
expert_id = tl.load(sorted_expert_ids_ptr + sorted_pos, mask=valid, other=-1)

if block_id == 0:
init_offs = tl.arange(0, BLOCK_SIZE)
for start in tl.range(0, expert_num, BLOCK_SIZE):
expert_offs = start + init_offs
tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num)

# A segment starts when the sorted expert id changes.
prev_pos = sorted_pos - 1
prev_expert_id = tl.load(sorted_expert_ids_ptr + prev_pos, mask=valid & (sorted_pos > 0), other=-1)
is_segment_start = valid & ((sorted_pos == 0) | (expert_id != prev_expert_id))
tl.store(segment_start_ptr + expert_id, sorted_pos, mask=is_segment_start)


@triton.jit
def moe_align_fused_sorted_scatter_kernel(
sorted_expert_ids_ptr, # [token_num * topk]
topk_weights_ptr, # [token_num, topk]
sorted_token_index_ptr, # [token_num * topk]
segment_start_ptr, # [expert_num]
expert_token_num_ptr, # [expert_num]
expert_to_token_index_ptr, # [expert_num, token_num * topk]
expert_to_weight_ptr, # [expert_num, token_num * topk]
token_num_mul_topk,
BLOCK_SIZE: tl.constexpr,
):
block_id = tl.program_id(0)
sorted_pos = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
valid = sorted_pos < token_num_mul_topk

# sorted_token_index maps a sorted slot back to the original flattened topk slot.
token_index = tl.load(sorted_token_index_ptr + sorted_pos, mask=valid, other=0)
expert_id = tl.load(sorted_expert_ids_ptr + sorted_pos, mask=valid, other=0)

# Each expert owns a contiguous sorted segment, so local row offset is
# the sorted position minus that expert's segment start.
expert_start = tl.load(segment_start_ptr + expert_id, mask=valid, other=0)
expert_pos = sorted_pos - expert_start
weight = tl.load(topk_weights_ptr + token_index, mask=valid, other=0.0)

# A segment end gives the final token count for this expert.
next_pos = sorted_pos + 1
next_expert_id = tl.load(sorted_expert_ids_ptr + next_pos, mask=valid & (next_pos < token_num_mul_topk), other=-1)
is_segment_end = valid & ((next_pos == token_num_mul_topk) | (expert_id != next_expert_id))
tl.store(expert_token_num_ptr + expert_id, next_pos - expert_start, mask=is_segment_end)

tl.store(
expert_to_token_index_ptr + expert_id * token_num_mul_topk + expert_pos,
token_index,
mask=valid,
)
tl.store(
expert_to_weight_ptr + expert_id * token_num_mul_topk + expert_pos,
weight,
mask=valid,
)


def _make_moe_align_fused_sorted_configs():
configs = []
for scatter_block_size in [128, 256, 512]:
for scatter_num_warps in [4, 8]:
for record_block_size in [128, 256, 512, 1024]:
for record_num_warps in [4, 8]:
configs.append(
{
"SCATTER_BLOCK_SIZE": scatter_block_size,
"SCATTER_NUM_WARPS": scatter_num_warps,
"RECORD_BLOCK_SIZE": record_block_size,
"RECORD_NUM_WARPS": record_num_warps,
}
)
return configs


@autotune(
kernel_name="moe_align_fused:v1",
configs_gen_func=_get_moe_align_fused_configs,
kernel_name="moe_align_fused_sorted:v2",
configs_gen_func=_make_moe_align_fused_sorted_configs,
static_key_func=_get_moe_align_fused_static_key,
run_key_func=lambda topk_ids: topk_ids.shape[0],
mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"],
)
def moe_align_fused(
expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None
def _moe_align_fused_sorted_token(
expert_to_token_index,
expert_to_weight,
expert_token_num,
topk_ids,
topk_weights,
run_config: Optional[dict] = None,
):
token_num, topk_num = topk_ids.shape
if run_config is None:
run_config = {}
BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256)
num_warps = run_config.get("num_warps", 4)
run_config = {
"SCATTER_BLOCK_SIZE": 128,
"SCATTER_NUM_WARPS": 4,
"RECORD_BLOCK_SIZE": 256,
"RECORD_NUM_WARPS": 8,
}

# For small inputs the align kernel has a single program, so it can initialize
# expert_token_num itself and avoid an extra zero_ kernel launch.
token_num_mul_topk = topk_ids.numel()
expert_num = expert_token_num.shape[0]
init_expert_token_num_in_kernel = token_num * topk_num <= BLOCK_SIZE
if not init_expert_token_num_in_kernel:
# Multiple align programs may update expert_token_num concurrently; initialize
# it before launch to avoid races between clear and atomic_add.
expert_token_num.zero_()

grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),)
moe_align_fused_kernel[grid](
record_grid = (triton.cdiv(token_num_mul_topk, run_config["RECORD_BLOCK_SIZE"]),)
scatter_grid = (triton.cdiv(token_num_mul_topk, run_config["SCATTER_BLOCK_SIZE"]),)

flat_topk_ids = topk_ids.view(-1)
sorted_expert_ids, sorted_token_index = torch.sort(flat_topk_ids, stable=True)
segment_start = torch.empty_like(expert_token_num)

# Stable sort makes every expert's tokens contiguous. Record each segment
# start and clear empty-expert counts, then scatter compact expert rows.
moe_align_fused_record_sorted_segment_start_kernel[record_grid](
sorted_expert_ids,
segment_start,
expert_token_num,
token_num_mul_topk,
expert_num,
BLOCK_SIZE=run_config["RECORD_BLOCK_SIZE"],
num_warps=run_config["RECORD_NUM_WARPS"],
)
moe_align_fused_sorted_scatter_kernel[scatter_grid](
sorted_expert_ids,
topk_weights,
sorted_token_index,
segment_start,
expert_token_num,
expert_to_token_index,
expert_to_weight,
token_num_mul_topk,
BLOCK_SIZE=run_config["SCATTER_BLOCK_SIZE"],
num_warps=run_config["SCATTER_NUM_WARPS"],
)
return expert_to_token_index, expert_to_weight, expert_token_num


@triton.jit
def moe_align_fused_atomic_kernel(
topk_ids_ptr, # [token_num, topk]
topk_weights_ptr, # [token_num, topk]
expert_to_token_index_ptr, # [expert_num, token_num * topk]
expert_to_weight_ptr, # [expert_num, token_num * topk]
expert_token_num_ptr, # [expert_num]
token_num_mul_topk,
BLOCK_SIZE: tl.constexpr,
):
block_id = tl.program_id(0)
offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
valid = offs < token_num_mul_topk
expert_id = tl.load(topk_ids_ptr + offs, mask=valid, other=0)
weight = tl.load(topk_weights_ptr + offs, mask=valid, other=0.0)
write_pos = tl.atomic_add(expert_token_num_ptr + expert_id, 1, mask=valid)
tl.store(expert_to_token_index_ptr + expert_id * token_num_mul_topk + write_pos, offs, mask=valid)
tl.store(expert_to_weight_ptr + expert_id * token_num_mul_topk + write_pos, weight, mask=valid)


def _moe_align_fused_atomic_token(
expert_to_token_index,
expert_to_weight,
expert_token_num,
topk_ids,
topk_weights,
):
token_num_mul_topk = topk_ids.numel()
expert_token_num.zero_()
moe_align_fused_atomic_kernel[(triton.cdiv(token_num_mul_topk, 128),)](
topk_ids,
topk_weights,
expert_to_token_index,
expert_to_weight,
expert_token_num,
token_num,
expert_num,
topk_num,
BLOCK_SIZE=BLOCK_SIZE,
INIT_EXPERT_TOKEN_NUM_IN_KERNEL=init_expert_token_num_in_kernel,
BLOCK_EXPERT=triton.next_power_of_2(expert_num),
num_warps=num_warps,
token_num_mul_topk,
BLOCK_SIZE=128,
num_warps=4,
)
return expert_to_token_index, expert_to_weight, expert_token_num


def moe_align_fused(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights):
token_num = topk_ids.shape[0]
if token_num <= 128:
_moe_align_fused_small_token(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights)
else:
# Expert rows may be unordered, but grouped matmul reuses this same
# mapping for up/down projections and writes back to original topk slots.
_moe_align_fused_atomic_token(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights)
return expert_to_token_index, expert_to_weight, expert_token_num


@triton.jit
def moe_align2_kernel(
experts_token_num_ptr, # [expert_num,]
Expand Down
Loading
Loading