From d0210b3251cc89bd14d0765af29f2349ad24178d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 24 Jun 2026 05:13:54 +0000 Subject: [PATCH 1/8] Make moe align fused deterministic --- .../fused_moe/grouped_fused_moe.py | 101 ++++++++---------- .../fused_moe/test_grouped_fused_moe.py | 33 ++++++ 2 files changed, 80 insertions(+), 54 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 1f7175f72..bc0d7bf03 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -214,45 +214,43 @@ def moe_align1( @triton.jit -def moe_align_fused_kernel( +def moe_align_fused_deterministic_kernel( topk_ids_ptr, # [token_num, topk] topk_weights_ptr, # [token_num, topk] expert_to_token_index_ptr, # [expert_num, token_num * topk] expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] - token_num, - expert_num: tl.constexpr, - topk_num: tl.constexpr, + token_num_mul_topk, BLOCK_SIZE: tl.constexpr, - INIT_EXPERT_TOKEN_NUM_IN_KERNEL: tl.constexpr, - BLOCK_EXPERT: tl.constexpr, + NUM_STAGE: tl.constexpr, ): - token_block = tl.program_id(0) - if INIT_EXPERT_TOKEN_NUM_IN_KERNEL: - expert_offs = tl.arange(0, BLOCK_EXPERT) - tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) - tl.debug_barrier() - - offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offs < token_num * topk_num - - expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0) - weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0) - - # 用 atomic_add 给 expert 分配写位置 - write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask) + expert_id = tl.program_id(0) + block_offs = tl.arange(0, BLOCK_SIZE) + token_count = tl.full((), 0, tl.int32) + + for start in tl.range(0, token_num_mul_topk, BLOCK_SIZE, num_stages=NUM_STAGE): + raw_offs = start + block_offs + valid = raw_offs < token_num_mul_topk + load_offs = tl.where(valid, raw_offs, 0) + expert_ids = tl.load(topk_ids_ptr + load_offs) + weights = tl.load(topk_weights_ptr + load_offs) + + expert_mask = (expert_ids == expert_id) & valid + local_pos = tl.cumsum(tl.where(expert_mask, 1, 0)) - 1 + write_pos = token_count + local_pos + tl.store( + expert_to_token_index_ptr + expert_id * token_num_mul_topk + write_pos, + raw_offs, + mask=expert_mask, + ) + tl.store( + expert_to_weight_ptr + expert_id * token_num_mul_topk + write_pos, + weights, + mask=expert_mask, + ) + token_count += tl.sum(tl.where(expert_mask, 1, 0), axis=0) - # 按 token 顺序写 index 和 weight - tl.store( - expert_to_token_index_ptr + expert_ids * (token_num * topk_num) + write_pos, - offs, - mask=mask, - ) - tl.store( - expert_to_weight_ptr + expert_ids * (token_num * topk_num) + write_pos, - weights, - mask=mask, - ) + tl.store(expert_token_num_ptr + expert_id, token_count) def _get_moe_align_fused_static_key( @@ -269,14 +267,16 @@ def _get_moe_align_fused_configs(): { "BLOCK_SIZE": bt, "num_warps": nw, + "NUM_STAGE": ns, } + for ns in [1, 2, 4, 6] for nw in [1, 2, 4, 8] - for bt in [128, 256, 512, 1024, 2048] + for bt in [32, 64, 128, 256, 512, 1024, 2048] ] @autotune( - kernel_name="moe_align_fused:v1", + kernel_name="moe_align_fused:v2", configs_gen_func=_get_moe_align_fused_configs, static_key_func=_get_moe_align_fused_static_key, run_key_func=lambda topk_ids: topk_ids.shape[0], @@ -286,34 +286,27 @@ def moe_align_fused( expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None ): token_num, topk_num = topk_ids.shape + expert_num = expert_token_num.shape[0] + token_num_mul_topk = token_num * topk_num + if run_config is None: - run_config = {} - BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256) - num_warps = run_config.get("num_warps", 4) + if token_num_mul_topk <= 256: + run_config = {"BLOCK_SIZE": 256, "num_warps": 4, "NUM_STAGE": 4} + elif token_num_mul_topk <= 4096: + run_config = {"BLOCK_SIZE": 512, "num_warps": 8, "NUM_STAGE": 1} + else: + run_config = {"BLOCK_SIZE": 1024, "num_warps": 8, "NUM_STAGE": 1} - # For small inputs the align kernel has a single program, so it can initialize - # expert_token_num itself and avoid an extra zero_ kernel launch. - expert_num = expert_token_num.shape[0] - init_expert_token_num_in_kernel = token_num * topk_num <= BLOCK_SIZE - if not init_expert_token_num_in_kernel: - # Multiple align programs may update expert_token_num concurrently; initialize - # it before launch to avoid races between clear and atomic_add. - expert_token_num.zero_() - - grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),) - moe_align_fused_kernel[grid]( + moe_align_fused_deterministic_kernel[(expert_num,)]( topk_ids, topk_weights, expert_to_token_index, expert_to_weight, expert_token_num, - token_num, - expert_num, - topk_num, - BLOCK_SIZE=BLOCK_SIZE, - INIT_EXPERT_TOKEN_NUM_IN_KERNEL=init_expert_token_num_in_kernel, - BLOCK_EXPERT=triton.next_power_of_2(expert_num), - num_warps=num_warps, + token_num_mul_topk, + BLOCK_SIZE=run_config["BLOCK_SIZE"], + NUM_STAGE=run_config["NUM_STAGE"], + num_warps=run_config["num_warps"], ) return expert_to_token_index, expert_to_weight, expert_token_num diff --git a/unit_tests/common/fused_moe/test_grouped_fused_moe.py b/unit_tests/common/fused_moe/test_grouped_fused_moe.py index 9c08cfc1a..9d3a85e75 100644 --- a/unit_tests/common/fused_moe/test_grouped_fused_moe.py +++ b/unit_tests/common/fused_moe/test_grouped_fused_moe.py @@ -4,6 +4,7 @@ import triton from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import ( moe_align, + moe_align_fused, moe_align1, moe_align2, grouped_matmul, @@ -74,6 +75,38 @@ def test_moe_align1(): assert torch.equal(experts_info, true_experts_info) +def test_moe_align_fused(): + expert_num = 5 + topk_ids = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 1, 4]], dtype=torch.int32, device="cuda") + topk_weights = torch.tensor([[0.3, 0.7, 0.1], [0.2, 0.8, 0.4], [0.5, 0.6, 0.9]], dtype=torch.float32, device="cuda") + expert_to_token_index = torch.empty((expert_num, topk_ids.numel()), dtype=torch.int32, device="cuda") + expert_to_weight = torch.empty((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") + expert_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") + + moe_align_fused( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + run_config={"BLOCK_SIZE": 1024, "num_warps": 8, "NUM_STAGE": 1}, + ) + torch.cuda.synchronize() + + true_expert_token_num = torch.tensor([2, 3, 1, 2, 1], device="cuda", dtype=torch.int32) + assert torch.equal(expert_token_num, true_expert_token_num) + + flat_topk_ids = topk_ids.flatten() + flat_topk_weights = topk_weights.flatten() + for expert_id in range(expert_num): + mask = flat_topk_ids == expert_id + true_index = torch.nonzero(mask, as_tuple=False).flatten().to(torch.int32) + true_weight = flat_topk_weights[mask] + token_num = true_expert_token_num[expert_id] + assert torch.equal(expert_to_token_index[expert_id, :token_num], true_index) + assert torch.allclose(expert_to_weight[expert_id, :token_num], true_weight) + + def test_moe_align2(): experts_token_num = torch.zeros((4,), dtype=torch.int32, device="cuda") From 189ed93809ff874270cf43acb5a3186a67d2ddf4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 25 Jun 2026 02:59:16 +0000 Subject: [PATCH 2/8] fix --- .../fused_moe/grouped_fused_moe.py | 57 ++++++++++++------- .../fused_moe/test_grouped_fused_moe.py | 2 +- 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index bc0d7bf03..4f962f18c 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -220,45 +220,52 @@ def moe_align_fused_deterministic_kernel( expert_to_token_index_ptr, # [expert_num, token_num * topk] expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] + expert_num, token_num_mul_topk, BLOCK_SIZE: tl.constexpr, + EXPERT_BLOCK: tl.constexpr, NUM_STAGE: tl.constexpr, ): - expert_id = tl.program_id(0) + expert_block_id = tl.program_id(0) + expert_offsets = (expert_block_id * EXPERT_BLOCK + tl.arange(0, EXPERT_BLOCK)) % expert_num block_offs = tl.arange(0, BLOCK_SIZE) - token_count = tl.full((), 0, tl.int32) + token_count = tl.full((EXPERT_BLOCK,), 0, tl.int32) for start in tl.range(0, token_num_mul_topk, BLOCK_SIZE, num_stages=NUM_STAGE): raw_offs = start + block_offs valid = raw_offs < token_num_mul_topk load_offs = tl.where(valid, raw_offs, 0) - expert_ids = tl.load(topk_ids_ptr + load_offs) - weights = tl.load(topk_weights_ptr + load_offs) + expert_ids = tl.load(topk_ids_ptr + load_offs, mask=valid, other=-1) + weights = tl.load(topk_weights_ptr + load_offs, mask=valid, other=0.0) - expert_mask = (expert_ids == expert_id) & valid - local_pos = tl.cumsum(tl.where(expert_mask, 1, 0)) - 1 - write_pos = token_count + local_pos + expert_mask = (expert_ids[None, :] == expert_offsets[:, None]) & valid[None, :] + expert_hits = tl.where(expert_mask, 1, 0) + local_pos = tl.cumsum(expert_hits, axis=1) - 1 + write_pos = token_count[:, None] + local_pos tl.store( - expert_to_token_index_ptr + expert_id * token_num_mul_topk + write_pos, - raw_offs, + expert_to_token_index_ptr + expert_offsets[:, None] * token_num_mul_topk + write_pos, + raw_offs[None, :], mask=expert_mask, ) tl.store( - expert_to_weight_ptr + expert_id * token_num_mul_topk + write_pos, - weights, + expert_to_weight_ptr + expert_offsets[:, None] * token_num_mul_topk + write_pos, + weights[None, :], mask=expert_mask, ) - token_count += tl.sum(tl.where(expert_mask, 1, 0), axis=0) + token_count += tl.sum(expert_hits, axis=1) - tl.store(expert_token_num_ptr + expert_id, token_count) + tl.store(expert_token_num_ptr + expert_offsets, token_count) def _get_moe_align_fused_static_key( + expert_token_num: torch.Tensor, topk_weights: torch.Tensor, ) -> dict: topk_num = topk_weights.shape[1] + expert_num = expert_token_num.shape[0] return { "topk_num": topk_num, + "expert_num": expert_num, } @@ -266,17 +273,19 @@ def _get_moe_align_fused_configs(): return [ { "BLOCK_SIZE": bt, + "EXPERT_BLOCK": eb, "num_warps": nw, "NUM_STAGE": ns, } for ns in [1, 2, 4, 6] for nw in [1, 2, 4, 8] + for eb in [1, 4, 8, 16, 32] for bt in [32, 64, 128, 256, 512, 1024, 2048] ] @autotune( - kernel_name="moe_align_fused:v2", + kernel_name="moe_align_fused:v3", configs_gen_func=_get_moe_align_fused_configs, static_key_func=_get_moe_align_fused_static_key, run_key_func=lambda topk_ids: topk_ids.shape[0], @@ -290,21 +299,29 @@ def moe_align_fused( token_num_mul_topk = token_num * topk_num if run_config is None: - if token_num_mul_topk <= 256: - run_config = {"BLOCK_SIZE": 256, "num_warps": 4, "NUM_STAGE": 4} - elif token_num_mul_topk <= 4096: - run_config = {"BLOCK_SIZE": 512, "num_warps": 8, "NUM_STAGE": 1} + if token_num_mul_topk <= 128: + run_config = {"BLOCK_SIZE": 128, "EXPERT_BLOCK": 8, "num_warps": 1, "NUM_STAGE": 1} + elif token_num_mul_topk <= 256: + run_config = {"BLOCK_SIZE": 128, "EXPERT_BLOCK": 8, "num_warps": 4, "NUM_STAGE": 4} + elif token_num_mul_topk <= 512: + run_config = {"BLOCK_SIZE": 256, "EXPERT_BLOCK": 4, "num_warps": 8, "NUM_STAGE": 1} + elif token_num_mul_topk <= 1536: + run_config = {"BLOCK_SIZE": 2048, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1} + elif token_num_mul_topk <= 3072: + run_config = {"BLOCK_SIZE": 1024, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1} else: - run_config = {"BLOCK_SIZE": 1024, "num_warps": 8, "NUM_STAGE": 1} + run_config = {"BLOCK_SIZE": 2048, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1} - moe_align_fused_deterministic_kernel[(expert_num,)]( + moe_align_fused_deterministic_kernel[(triton.cdiv(expert_num, run_config["EXPERT_BLOCK"]),)]( topk_ids, topk_weights, expert_to_token_index, expert_to_weight, expert_token_num, + expert_num, token_num_mul_topk, BLOCK_SIZE=run_config["BLOCK_SIZE"], + EXPERT_BLOCK=run_config["EXPERT_BLOCK"], NUM_STAGE=run_config["NUM_STAGE"], num_warps=run_config["num_warps"], ) diff --git a/unit_tests/common/fused_moe/test_grouped_fused_moe.py b/unit_tests/common/fused_moe/test_grouped_fused_moe.py index 9d3a85e75..1f8e8f366 100644 --- a/unit_tests/common/fused_moe/test_grouped_fused_moe.py +++ b/unit_tests/common/fused_moe/test_grouped_fused_moe.py @@ -89,7 +89,7 @@ def test_moe_align_fused(): expert_token_num, topk_ids, topk_weights, - run_config={"BLOCK_SIZE": 1024, "num_warps": 8, "NUM_STAGE": 1}, + run_config={"BLOCK_SIZE": 1024, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1}, ) torch.cuda.synchronize() From ba710495b9ff3dd046d546304b4ba90caada1c94 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 25 Jun 2026 08:53:58 +0000 Subject: [PATCH 3/8] fix --- .../fused_moe/grouped_fused_moe.py | 173 +++++++++++++----- .../fused_moe/test_grouped_fused_moe.py | 78 +++++--- 2 files changed, 185 insertions(+), 66 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 4f962f18c..ca20c3273 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -214,22 +214,19 @@ def moe_align1( @triton.jit -def moe_align_fused_deterministic_kernel( +def moe_align_fused_small_token_kernel( topk_ids_ptr, # [token_num, topk] topk_weights_ptr, # [token_num, topk] expert_to_token_index_ptr, # [expert_num, token_num * topk] expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] - expert_num, token_num_mul_topk, BLOCK_SIZE: tl.constexpr, - EXPERT_BLOCK: tl.constexpr, NUM_STAGE: tl.constexpr, ): - expert_block_id = tl.program_id(0) - expert_offsets = (expert_block_id * EXPERT_BLOCK + tl.arange(0, EXPERT_BLOCK)) % expert_num + expert_id = tl.program_id(0) block_offs = tl.arange(0, BLOCK_SIZE) - token_count = tl.full((EXPERT_BLOCK,), 0, tl.int32) + token_count = tl.full((), 0, tl.int32) for start in tl.range(0, token_num_mul_topk, BLOCK_SIZE, num_stages=NUM_STAGE): raw_offs = start + block_offs @@ -238,23 +235,22 @@ def moe_align_fused_deterministic_kernel( expert_ids = tl.load(topk_ids_ptr + load_offs, mask=valid, other=-1) weights = tl.load(topk_weights_ptr + load_offs, mask=valid, other=0.0) - expert_mask = (expert_ids[None, :] == expert_offsets[:, None]) & valid[None, :] + expert_mask = (expert_ids == expert_id) & valid expert_hits = tl.where(expert_mask, 1, 0) - local_pos = tl.cumsum(expert_hits, axis=1) - 1 - write_pos = token_count[:, None] + local_pos + write_pos = token_count + tl.cumsum(expert_hits, axis=0) - 1 tl.store( - expert_to_token_index_ptr + expert_offsets[:, None] * token_num_mul_topk + write_pos, - raw_offs[None, :], + expert_to_token_index_ptr + expert_id * token_num_mul_topk + write_pos, + raw_offs, mask=expert_mask, ) tl.store( - expert_to_weight_ptr + expert_offsets[:, None] * token_num_mul_topk + write_pos, - weights[None, :], + expert_to_weight_ptr + expert_id * token_num_mul_topk + write_pos, + weights, mask=expert_mask, ) - token_count += tl.sum(expert_hits, axis=1) + token_count += tl.sum(expert_hits, axis=0) - tl.store(expert_token_num_ptr + expert_offsets, token_count) + tl.store(expert_token_num_ptr + expert_id, token_count) def _get_moe_align_fused_static_key( @@ -269,65 +265,152 @@ def _get_moe_align_fused_static_key( } -def _get_moe_align_fused_configs(): - return [ +@autotune( + kernel_name="moe_align_fused_small:v2", + configs_gen_func=lambda: [ { "BLOCK_SIZE": bt, - "EXPERT_BLOCK": eb, "num_warps": nw, "NUM_STAGE": ns, } for ns in [1, 2, 4, 6] for nw in [1, 2, 4, 8] - for eb in [1, 4, 8, 16, 32] - for bt in [32, 64, 128, 256, 512, 1024, 2048] - ] + for bt in [8, 16, 32, 64, 128, 256, 512, 1024, 2048] + ], + static_key_func=_get_moe_align_fused_static_key, + run_key_func=lambda topk_ids: topk_ids.shape[0], + mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], +) +def _moe_align_fused_small_token( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + run_config: Optional[dict] = None, +): + if run_config is None: + token_num = topk_ids.shape[0] + if token_num <= 2: + run_config = {"BLOCK_SIZE": 16, "num_warps": 1, "NUM_STAGE": 1} + elif token_num <= 8: + run_config = {"BLOCK_SIZE": 64, "num_warps": 1, "NUM_STAGE": 1} + elif token_num <= 16: + run_config = {"BLOCK_SIZE": 128, "num_warps": 1, "NUM_STAGE": 1} + elif token_num < 32: + run_config = {"BLOCK_SIZE": 256, "num_warps": 2, "NUM_STAGE": 1} + elif token_num <= 64: + run_config = {"BLOCK_SIZE": 512, "num_warps": 4, "NUM_STAGE": 1} + else: + run_config = {"BLOCK_SIZE": 1024, "num_warps": 8, "NUM_STAGE": 1} + token_num_mul_topk = topk_ids.numel() + expert_num = expert_token_num.shape[0] + block_size = run_config["BLOCK_SIZE"] + + moe_align_fused_small_token_kernel[(expert_num,)]( + topk_ids, + topk_weights, + expert_to_token_index, + expert_to_weight, + expert_token_num, + token_num_mul_topk, + BLOCK_SIZE=block_size, + NUM_STAGE=run_config["NUM_STAGE"], + num_warps=run_config["num_warps"], + ) + return expert_to_token_index, expert_to_weight, expert_token_num + + +@triton.jit +def moe_align_fused_large_atomic_kernel( + topk_ids_ptr, # [token_num, topk] + topk_weights_ptr, # [token_num, topk] + expert_to_token_index_ptr, # [expert_num, token_num * topk] + expert_to_weight_ptr, # [expert_num, token_num * topk] + expert_token_num_ptr, # [expert_num] + token_num_mul_topk, + BLOCK_SIZE: tl.constexpr, +): + block_id = tl.program_id(0) + offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + valid = offs < token_num_mul_topk + expert_ids = tl.load(topk_ids_ptr + offs, mask=valid, other=-1) + weights = tl.load(topk_weights_ptr + offs, mask=valid, other=0.0) + write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=valid) + tl.store( + expert_to_token_index_ptr + expert_ids * token_num_mul_topk + write_pos, + offs, + mask=valid, + ) + tl.store( + expert_to_weight_ptr + expert_ids * token_num_mul_topk + write_pos, + weights, + mask=valid, + ) @autotune( - kernel_name="moe_align_fused:v3", - configs_gen_func=_get_moe_align_fused_configs, + kernel_name="moe_align_fused_large:v1", + configs_gen_func=lambda: [ + { + "BLOCK_SIZE": bt, + "num_warps": nw, + } + for nw in [1, 2, 4, 8] + for bt in [128, 256, 512, 1024, 2048] + ], static_key_func=_get_moe_align_fused_static_key, run_key_func=lambda topk_ids: topk_ids.shape[0], mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], ) -def moe_align_fused( - expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None +def _moe_align_fused_large_token( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + run_config: Optional[dict] = None, ): - token_num, topk_num = topk_ids.shape - expert_num = expert_token_num.shape[0] - token_num_mul_topk = token_num * topk_num - if run_config is None: - if token_num_mul_topk <= 128: - run_config = {"BLOCK_SIZE": 128, "EXPERT_BLOCK": 8, "num_warps": 1, "NUM_STAGE": 1} - elif token_num_mul_topk <= 256: - run_config = {"BLOCK_SIZE": 128, "EXPERT_BLOCK": 8, "num_warps": 4, "NUM_STAGE": 4} - elif token_num_mul_topk <= 512: - run_config = {"BLOCK_SIZE": 256, "EXPERT_BLOCK": 4, "num_warps": 8, "NUM_STAGE": 1} - elif token_num_mul_topk <= 1536: - run_config = {"BLOCK_SIZE": 2048, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1} - elif token_num_mul_topk <= 3072: - run_config = {"BLOCK_SIZE": 1024, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1} + token_num = topk_ids.shape[0] + if token_num <= 128: + run_config = {"BLOCK_SIZE": 128, "num_warps": 4} + elif token_num <= 8192: + run_config = {"BLOCK_SIZE": 128, "num_warps": 8} else: - run_config = {"BLOCK_SIZE": 2048, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1} + run_config = {"BLOCK_SIZE": 128, "num_warps": 1} - moe_align_fused_deterministic_kernel[(triton.cdiv(expert_num, run_config["EXPERT_BLOCK"]),)]( + token_num_mul_topk = topk_ids.numel() + block_size = run_config["BLOCK_SIZE"] + expert_token_num.zero_() + moe_align_fused_large_atomic_kernel[(triton.cdiv(token_num_mul_topk, block_size),)]( topk_ids, topk_weights, expert_to_token_index, expert_to_weight, expert_token_num, - expert_num, token_num_mul_topk, - BLOCK_SIZE=run_config["BLOCK_SIZE"], - EXPERT_BLOCK=run_config["EXPERT_BLOCK"], - NUM_STAGE=run_config["NUM_STAGE"], + BLOCK_SIZE=block_size, num_warps=run_config["num_warps"], ) return expert_to_token_index, expert_to_weight, expert_token_num +def moe_align_fused( + expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None +): + token_num = topk_ids.shape[0] + if token_num <= 128: + _moe_align_fused_small_token( + expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config + ) + else: + _moe_align_fused_large_token( + expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config + ) + return expert_to_token_index, expert_to_weight, expert_token_num + + @triton.jit def moe_align2_kernel( experts_token_num_ptr, # [expert_num,] diff --git a/unit_tests/common/fused_moe/test_grouped_fused_moe.py b/unit_tests/common/fused_moe/test_grouped_fused_moe.py index 1f8e8f366..87247ac11 100644 --- a/unit_tests/common/fused_moe/test_grouped_fused_moe.py +++ b/unit_tests/common/fused_moe/test_grouped_fused_moe.py @@ -75,10 +75,7 @@ def test_moe_align1(): assert torch.equal(experts_info, true_experts_info) -def test_moe_align_fused(): - expert_num = 5 - topk_ids = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 1, 4]], dtype=torch.int32, device="cuda") - topk_weights = torch.tensor([[0.3, 0.7, 0.1], [0.2, 0.8, 0.4], [0.5, 0.6, 0.9]], dtype=torch.float32, device="cuda") +def _check_moe_align_fused(topk_ids, topk_weights, expert_num, run_config=None, ordered=True): expert_to_token_index = torch.empty((expert_num, topk_ids.numel()), dtype=torch.int32, device="cuda") expert_to_weight = torch.empty((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") expert_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") @@ -89,22 +86,60 @@ def test_moe_align_fused(): expert_token_num, topk_ids, topk_weights, - run_config={"BLOCK_SIZE": 1024, "EXPERT_BLOCK": 1, "num_warps": 8, "NUM_STAGE": 1}, + run_config=run_config, ) torch.cuda.synchronize() - true_expert_token_num = torch.tensor([2, 3, 1, 2, 1], device="cuda", dtype=torch.int32) - assert torch.equal(expert_token_num, true_expert_token_num) - flat_topk_ids = topk_ids.flatten() flat_topk_weights = topk_weights.flatten() - for expert_id in range(expert_num): - mask = flat_topk_ids == expert_id - true_index = torch.nonzero(mask, as_tuple=False).flatten().to(torch.int32) - true_weight = flat_topk_weights[mask] - token_num = true_expert_token_num[expert_id] - assert torch.equal(expert_to_token_index[expert_id, :token_num], true_index) - assert torch.allclose(expert_to_weight[expert_id, :token_num], true_weight) + expected_token_num = torch.bincount(flat_topk_ids, minlength=expert_num).to(torch.int32) + assert torch.equal(expert_token_num, expected_token_num) + + for expert_id, token_num in enumerate(expected_token_num.tolist()): + expected_index = torch.nonzero(flat_topk_ids == expert_id, as_tuple=False).flatten() + expected_weight = flat_topk_weights[expected_index] + expected_index = expected_index.to(torch.int32) + token_index = expert_to_token_index[expert_id, :token_num] + token_weight = expert_to_weight[expert_id, :token_num] + + if not ordered: + order = torch.argsort(token_index) + token_index = token_index[order] + token_weight = token_weight[order] + + assert torch.equal(token_index, expected_index) + assert torch.allclose(token_weight, expected_weight) + + +def test_moe_align_fused_small_token(): + expert_num = 5 + small_topk_ids = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 1, 4]], dtype=torch.int32, device="cuda") + small_topk_weights = torch.tensor( + [[0.3, 0.7, 0.1], [0.2, 0.8, 0.4], [0.5, 0.6, 0.9]], dtype=torch.float32, device="cuda" + ) + _check_moe_align_fused(small_topk_ids, small_topk_weights, expert_num) + + small_many_topk_ids = torch.arange(128 * 17, dtype=torch.int32, device="cuda").reshape(128, 17) % expert_num + small_many_topk_weights = torch.arange(small_many_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape( + 128, 17 + ) + _check_moe_align_fused(small_many_topk_ids, small_many_topk_weights, expert_num) + + +def test_moe_align_fused_large_token(): + expert_num = 5 + + base_topk_ids = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 1, 4], [2, 0, 4]], dtype=torch.int32, device="cuda") + large_topk_ids = base_topk_ids.repeat(33, 1)[:129].contiguous() + large_topk_weights = torch.arange(large_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(129, 3) + _check_moe_align_fused(large_topk_ids, large_topk_weights, expert_num, ordered=False) + _check_moe_align_fused( + large_topk_ids, + large_topk_weights, + expert_num, + run_config={"BLOCK_SIZE": 128, "num_warps": 4}, + ordered=False, + ) def test_moe_align2(): @@ -116,13 +151,14 @@ def test_moe_align2(): experts_token_num[3] = 16 mblocks_to_tuple_info = moe_align2(100, experts_token_num, block_m=16) + expected_expert_ids = torch.tensor([0, 2, 2, 2, 2, 3, -1, -1, -1, -1], device="cuda", dtype=torch.int32) + valid_blocks = expected_expert_ids != -1 + assert mblocks_to_tuple_info.shape[0] == triton.cdiv(100 + 4 * (16 - 1), 16) - assert torch.allclose( - mblocks_to_tuple_info[:, 0], - torch.tensor([0, 2, 2, 2, 2, 3, -1, -1, -1, -1], device="cuda", dtype=torch.int32), - ) - assert torch.allclose( - mblocks_to_tuple_info[:, 1], torch.tensor([0, 0, 1, 2, 3, 0, 0, 0, 0, 0], device="cuda", dtype=torch.int32) + assert torch.equal(mblocks_to_tuple_info[:, 0], expected_expert_ids) + assert torch.equal( + mblocks_to_tuple_info[valid_blocks, 1], + torch.tensor([0, 0, 1, 2, 3, 0], device="cuda", dtype=torch.int32), ) From f76b555d8f4771889b152a827f0d7a54a6d9db13 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 26 Jun 2026 07:59:24 +0000 Subject: [PATCH 4/8] fix --- .../fused_moe/grouped_fused_moe.py | 163 +++++++++++++----- .../fused_moe/test_grouped_fused_moe.py | 42 ++++- 2 files changed, 149 insertions(+), 56 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index ca20c3273..fd259e09e 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -301,8 +301,12 @@ def _moe_align_fused_small_token( run_config = {"BLOCK_SIZE": 256, "num_warps": 2, "NUM_STAGE": 1} elif token_num <= 64: run_config = {"BLOCK_SIZE": 512, "num_warps": 4, "NUM_STAGE": 1} - else: + elif token_num <= 128: run_config = {"BLOCK_SIZE": 1024, "num_warps": 8, "NUM_STAGE": 1} + elif token_num <= 192: + run_config = {"BLOCK_SIZE": 512, "num_warps": 4, "NUM_STAGE": 1} + else: + run_config = {"BLOCK_SIZE": 2048, "num_warps": 8, "NUM_STAGE": 1} token_num_mul_topk = topk_ids.numel() expert_num = expert_token_num.shape[0] block_size = run_config["BLOCK_SIZE"] @@ -322,48 +326,101 @@ def _moe_align_fused_small_token( @triton.jit -def moe_align_fused_large_atomic_kernel( - topk_ids_ptr, # [token_num, topk] +def moe_align_fused_record_sorted_segment_start_kernel( + sorted_expert_ids_ptr, # [token_num * topk] + segment_start_ptr, # [expert_num] + expert_token_num_ptr, # [expert_num] + token_num_mul_topk, + expert_num, + BLOCK_SIZE: tl.constexpr, +): + block_id = tl.program_id(0) + sorted_pos = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + valid = sorted_pos < token_num_mul_topk + expert_id = tl.load(sorted_expert_ids_ptr + sorted_pos, mask=valid, other=-1) + + if block_id == 0: + init_offs = tl.arange(0, BLOCK_SIZE) + for start in tl.range(0, expert_num, BLOCK_SIZE): + expert_offs = start + init_offs + tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) + + # A segment starts when the sorted expert id changes. + prev_pos = sorted_pos - 1 + prev_expert_id = tl.load(sorted_expert_ids_ptr + prev_pos, mask=valid & (sorted_pos > 0), other=-1) + is_segment_start = valid & ((sorted_pos == 0) | (expert_id != prev_expert_id)) + tl.store(segment_start_ptr + expert_id, sorted_pos, mask=is_segment_start) + + +@triton.jit +def moe_align_fused_sorted_scatter_kernel( + sorted_expert_ids_ptr, # [token_num * topk] topk_weights_ptr, # [token_num, topk] + sorted_token_index_ptr, # [token_num * topk] + segment_start_ptr, # [expert_num] + expert_token_num_ptr, # [expert_num] expert_to_token_index_ptr, # [expert_num, token_num * topk] expert_to_weight_ptr, # [expert_num, token_num * topk] - expert_token_num_ptr, # [expert_num] token_num_mul_topk, BLOCK_SIZE: tl.constexpr, ): block_id = tl.program_id(0) - offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - valid = offs < token_num_mul_topk - expert_ids = tl.load(topk_ids_ptr + offs, mask=valid, other=-1) - weights = tl.load(topk_weights_ptr + offs, mask=valid, other=0.0) - write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=valid) + sorted_pos = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + valid = sorted_pos < token_num_mul_topk + + # sorted_token_index maps a sorted slot back to the original flattened topk slot. + token_index = tl.load(sorted_token_index_ptr + sorted_pos, mask=valid, other=0) + expert_id = tl.load(sorted_expert_ids_ptr + sorted_pos, mask=valid, other=0) + + # Each expert owns a contiguous sorted segment, so local row offset is + # the sorted position minus that expert's segment start. + expert_start = tl.load(segment_start_ptr + expert_id, mask=valid, other=0) + expert_pos = sorted_pos - expert_start + weight = tl.load(topk_weights_ptr + token_index, mask=valid, other=0.0) + + # A segment end gives the final token count for this expert. + next_pos = sorted_pos + 1 + next_expert_id = tl.load(sorted_expert_ids_ptr + next_pos, mask=valid & (next_pos < token_num_mul_topk), other=-1) + is_segment_end = valid & ((next_pos == token_num_mul_topk) | (expert_id != next_expert_id)) + tl.store(expert_token_num_ptr + expert_id, next_pos - expert_start, mask=is_segment_end) + tl.store( - expert_to_token_index_ptr + expert_ids * token_num_mul_topk + write_pos, - offs, + expert_to_token_index_ptr + expert_id * token_num_mul_topk + expert_pos, + token_index, mask=valid, ) tl.store( - expert_to_weight_ptr + expert_ids * token_num_mul_topk + write_pos, - weights, + expert_to_weight_ptr + expert_id * token_num_mul_topk + expert_pos, + weight, mask=valid, ) +def _make_moe_align_fused_sorted_configs(): + configs = [] + for scatter_block_size in [128, 256, 512]: + for scatter_num_warps in [4, 8]: + for record_block_size in [128, 256, 512, 1024]: + for record_num_warps in [4, 8]: + configs.append( + { + "SCATTER_BLOCK_SIZE": scatter_block_size, + "SCATTER_NUM_WARPS": scatter_num_warps, + "RECORD_BLOCK_SIZE": record_block_size, + "RECORD_NUM_WARPS": record_num_warps, + } + ) + return configs + + @autotune( - kernel_name="moe_align_fused_large:v1", - configs_gen_func=lambda: [ - { - "BLOCK_SIZE": bt, - "num_warps": nw, - } - for nw in [1, 2, 4, 8] - for bt in [128, 256, 512, 1024, 2048] - ], + kernel_name="moe_align_fused_sorted:v2", + configs_gen_func=_make_moe_align_fused_sorted_configs, static_key_func=_get_moe_align_fused_static_key, run_key_func=lambda topk_ids: topk_ids.shape[0], mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], ) -def _moe_align_fused_large_token( +def _moe_align_fused_sorted_token( expert_to_token_index, expert_to_weight, expert_token_num, @@ -372,42 +429,54 @@ def _moe_align_fused_large_token( run_config: Optional[dict] = None, ): if run_config is None: - token_num = topk_ids.shape[0] - if token_num <= 128: - run_config = {"BLOCK_SIZE": 128, "num_warps": 4} - elif token_num <= 8192: - run_config = {"BLOCK_SIZE": 128, "num_warps": 8} - else: - run_config = {"BLOCK_SIZE": 128, "num_warps": 1} + run_config = { + "SCATTER_BLOCK_SIZE": 128, + "SCATTER_NUM_WARPS": 4, + "RECORD_BLOCK_SIZE": 256, + "RECORD_NUM_WARPS": 8, + } token_num_mul_topk = topk_ids.numel() - block_size = run_config["BLOCK_SIZE"] - expert_token_num.zero_() - moe_align_fused_large_atomic_kernel[(triton.cdiv(token_num_mul_topk, block_size),)]( - topk_ids, + expert_num = expert_token_num.shape[0] + record_grid = (triton.cdiv(token_num_mul_topk, run_config["RECORD_BLOCK_SIZE"]),) + scatter_grid = (triton.cdiv(token_num_mul_topk, run_config["SCATTER_BLOCK_SIZE"]),) + + flat_topk_ids = topk_ids.view(-1) + sorted_expert_ids, sorted_token_index = torch.sort(flat_topk_ids, stable=True) + segment_start = torch.empty_like(expert_token_num) + + # Stable sort makes every expert's tokens contiguous. Record each segment + # start and clear empty-expert counts, then scatter compact expert rows. + moe_align_fused_record_sorted_segment_start_kernel[record_grid]( + sorted_expert_ids, + segment_start, + expert_token_num, + token_num_mul_topk, + expert_num, + BLOCK_SIZE=run_config["RECORD_BLOCK_SIZE"], + num_warps=run_config["RECORD_NUM_WARPS"], + ) + moe_align_fused_sorted_scatter_kernel[scatter_grid]( + sorted_expert_ids, topk_weights, + sorted_token_index, + segment_start, + expert_token_num, expert_to_token_index, expert_to_weight, - expert_token_num, token_num_mul_topk, - BLOCK_SIZE=block_size, - num_warps=run_config["num_warps"], + BLOCK_SIZE=run_config["SCATTER_BLOCK_SIZE"], + num_warps=run_config["SCATTER_NUM_WARPS"], ) return expert_to_token_index, expert_to_weight, expert_token_num -def moe_align_fused( - expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None -): +def moe_align_fused(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights): token_num = topk_ids.shape[0] - if token_num <= 128: - _moe_align_fused_small_token( - expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config - ) + if token_num <= 5120: + _moe_align_fused_small_token(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights) else: - _moe_align_fused_large_token( - expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config - ) + _moe_align_fused_sorted_token(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights) return expert_to_token_index, expert_to_weight, expert_token_num diff --git a/unit_tests/common/fused_moe/test_grouped_fused_moe.py b/unit_tests/common/fused_moe/test_grouped_fused_moe.py index 87247ac11..e37dd69eb 100644 --- a/unit_tests/common/fused_moe/test_grouped_fused_moe.py +++ b/unit_tests/common/fused_moe/test_grouped_fused_moe.py @@ -75,7 +75,7 @@ def test_moe_align1(): assert torch.equal(experts_info, true_experts_info) -def _check_moe_align_fused(topk_ids, topk_weights, expert_num, run_config=None, ordered=True): +def _check_moe_align_fused(topk_ids, topk_weights, expert_num, ordered=True): expert_to_token_index = torch.empty((expert_num, topk_ids.numel()), dtype=torch.int32, device="cuda") expert_to_weight = torch.empty((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") expert_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") @@ -86,7 +86,6 @@ def _check_moe_align_fused(topk_ids, topk_weights, expert_num, run_config=None, expert_token_num, topk_ids, topk_weights, - run_config=run_config, ) torch.cuda.synchronize() @@ -132,14 +131,39 @@ def test_moe_align_fused_large_token(): base_topk_ids = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 1, 4], [2, 0, 4]], dtype=torch.int32, device="cuda") large_topk_ids = base_topk_ids.repeat(33, 1)[:129].contiguous() large_topk_weights = torch.arange(large_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(129, 3) - _check_moe_align_fused(large_topk_ids, large_topk_weights, expert_num, ordered=False) - _check_moe_align_fused( - large_topk_ids, - large_topk_weights, - expert_num, - run_config={"BLOCK_SIZE": 128, "num_warps": 4}, - ordered=False, + _check_moe_align_fused(large_topk_ids, large_topk_weights, expert_num) + + sorted_path_topk_ids = base_topk_ids.repeat(1281, 1)[:5121].contiguous() + sorted_path_topk_weights = torch.arange(sorted_path_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape( + 5121, 3 ) + _check_moe_align_fused(sorted_path_topk_ids, sorted_path_topk_weights, expert_num) + + sparse_expert_num = 257 + sparse_topk_ids = base_topk_ids.repeat(1281, 1)[:5121].contiguous() + sparse_topk_weights = torch.arange(sparse_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(5121, 3) + _check_moe_align_fused(sparse_topk_ids, sparse_topk_weights, sparse_expert_num) + + +def test_moe_align_fused_large_token_is_deterministic(): + expert_num = 257 + topk_ids = torch.arange(5121 * 8, dtype=torch.int32, device="cuda").reshape(5121, 8) % expert_num + topk_weights = torch.arange(topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(5121, 8) + + expert_to_token_index_0 = torch.full((expert_num, topk_ids.numel()), -1, dtype=torch.int32, device="cuda") + expert_to_weight_0 = torch.zeros((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") + expert_token_num_0 = torch.empty((expert_num,), dtype=torch.int32, device="cuda") + moe_align_fused(expert_to_token_index_0, expert_to_weight_0, expert_token_num_0, topk_ids, topk_weights) + + expert_to_token_index_1 = torch.full((expert_num, topk_ids.numel()), -1, dtype=torch.int32, device="cuda") + expert_to_weight_1 = torch.zeros((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") + expert_token_num_1 = torch.empty((expert_num,), dtype=torch.int32, device="cuda") + moe_align_fused(expert_to_token_index_1, expert_to_weight_1, expert_token_num_1, topk_ids, topk_weights) + torch.cuda.synchronize() + + assert torch.equal(expert_token_num_0, expert_token_num_1) + assert torch.equal(expert_to_token_index_0, expert_to_token_index_1) + assert torch.allclose(expert_to_weight_0, expert_to_weight_1) def test_moe_align2(): From 8eb8275cca92cfe6984835ef5cc9b77f650b0cee Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 26 Jun 2026 09:24:47 +0000 Subject: [PATCH 5/8] fix --- .../fused_moe/grouped_fused_moe.py | 48 ++++++++++++- .../fused_moe/test_grouped_fused_moe.py | 69 +++++++++++++++---- 2 files changed, 100 insertions(+), 17 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index fd259e09e..14c4d5f6b 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -471,12 +471,56 @@ def _moe_align_fused_sorted_token( return expert_to_token_index, expert_to_weight, expert_token_num +@triton.jit +def moe_align_fused_atomic_kernel( + topk_ids_ptr, # [token_num, topk] + topk_weights_ptr, # [token_num, topk] + expert_to_token_index_ptr, # [expert_num, token_num * topk] + expert_to_weight_ptr, # [expert_num, token_num * topk] + expert_token_num_ptr, # [expert_num] + token_num_mul_topk, + BLOCK_SIZE: tl.constexpr, +): + block_id = tl.program_id(0) + offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + valid = offs < token_num_mul_topk + expert_id = tl.load(topk_ids_ptr + offs, mask=valid, other=0) + weight = tl.load(topk_weights_ptr + offs, mask=valid, other=0.0) + write_pos = tl.atomic_add(expert_token_num_ptr + expert_id, 1, mask=valid) + tl.store(expert_to_token_index_ptr + expert_id * token_num_mul_topk + write_pos, offs, mask=valid) + tl.store(expert_to_weight_ptr + expert_id * token_num_mul_topk + write_pos, weight, mask=valid) + + +def _moe_align_fused_atomic_token( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, +): + token_num_mul_topk = topk_ids.numel() + expert_token_num.zero_() + moe_align_fused_atomic_kernel[(triton.cdiv(token_num_mul_topk, 128),)]( + topk_ids, + topk_weights, + expert_to_token_index, + expert_to_weight, + expert_token_num, + token_num_mul_topk, + BLOCK_SIZE=128, + num_warps=4, + ) + return expert_to_token_index, expert_to_weight, expert_token_num + + def moe_align_fused(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights): token_num = topk_ids.shape[0] - if token_num <= 5120: + if token_num <= 128: _moe_align_fused_small_token(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights) else: - _moe_align_fused_sorted_token(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights) + # Expert rows may be unordered, but grouped matmul reuses this same + # mapping for up/down projections and writes back to original topk slots. + _moe_align_fused_atomic_token(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights) return expert_to_token_index, expert_to_weight, expert_token_num diff --git a/unit_tests/common/fused_moe/test_grouped_fused_moe.py b/unit_tests/common/fused_moe/test_grouped_fused_moe.py index e37dd69eb..9e6dec8f8 100644 --- a/unit_tests/common/fused_moe/test_grouped_fused_moe.py +++ b/unit_tests/common/fused_moe/test_grouped_fused_moe.py @@ -3,6 +3,8 @@ import pytest import triton from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import ( + _moe_align_fused_atomic_token, + fused_experts_impl, moe_align, moe_align_fused, moe_align1, @@ -131,39 +133,76 @@ def test_moe_align_fused_large_token(): base_topk_ids = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 1, 4], [2, 0, 4]], dtype=torch.int32, device="cuda") large_topk_ids = base_topk_ids.repeat(33, 1)[:129].contiguous() large_topk_weights = torch.arange(large_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(129, 3) - _check_moe_align_fused(large_topk_ids, large_topk_weights, expert_num) + _check_moe_align_fused(large_topk_ids, large_topk_weights, expert_num, ordered=False) sorted_path_topk_ids = base_topk_ids.repeat(1281, 1)[:5121].contiguous() sorted_path_topk_weights = torch.arange(sorted_path_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape( 5121, 3 ) - _check_moe_align_fused(sorted_path_topk_ids, sorted_path_topk_weights, expert_num) + _check_moe_align_fused(sorted_path_topk_ids, sorted_path_topk_weights, expert_num, ordered=False) sparse_expert_num = 257 sparse_topk_ids = base_topk_ids.repeat(1281, 1)[:5121].contiguous() sparse_topk_weights = torch.arange(sparse_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(5121, 3) - _check_moe_align_fused(sparse_topk_ids, sparse_topk_weights, sparse_expert_num) + _check_moe_align_fused(sparse_topk_ids, sparse_topk_weights, sparse_expert_num, ordered=False) -def test_moe_align_fused_large_token_is_deterministic(): +def test_moe_align_fused_large_token_unordered(): expert_num = 257 topk_ids = torch.arange(5121 * 8, dtype=torch.int32, device="cuda").reshape(5121, 8) % expert_num topk_weights = torch.arange(topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(5121, 8) + _check_moe_align_fused(topk_ids, topk_weights, expert_num, ordered=False) - expert_to_token_index_0 = torch.full((expert_num, topk_ids.numel()), -1, dtype=torch.int32, device="cuda") - expert_to_weight_0 = torch.zeros((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") - expert_token_num_0 = torch.empty((expert_num,), dtype=torch.int32, device="cuda") - moe_align_fused(expert_to_token_index_0, expert_to_weight_0, expert_token_num_0, topk_ids, topk_weights) - expert_to_token_index_1 = torch.full((expert_num, topk_ids.numel()), -1, dtype=torch.int32, device="cuda") - expert_to_weight_1 = torch.zeros((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") - expert_token_num_1 = torch.empty((expert_num,), dtype=torch.int32, device="cuda") - moe_align_fused(expert_to_token_index_1, expert_to_weight_1, expert_token_num_1, topk_ids, topk_weights) +def test_moe_align_fused_atomic_token_unordered(): + expert_num = 9 + topk_ids = torch.arange(257 * 4, dtype=torch.int32, device="cuda").reshape(257, 4) % expert_num + topk_weights = torch.arange(topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(257, 4) + expert_to_token_index = torch.empty((expert_num, topk_ids.numel()), dtype=torch.int32, device="cuda") + expert_to_weight = torch.empty((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") + expert_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") + + _moe_align_fused_atomic_token( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + ) + torch.cuda.synchronize() + + flat_topk_ids = topk_ids.flatten() + flat_topk_weights = topk_weights.flatten() + expected_token_num = torch.bincount(flat_topk_ids, minlength=expert_num).to(torch.int32) + assert torch.equal(expert_token_num, expected_token_num) + + for expert_id, token_num in enumerate(expected_token_num.tolist()): + expected_index = torch.nonzero(flat_topk_ids == expert_id, as_tuple=False).flatten().to(torch.int32) + expected_weight = flat_topk_weights[expected_index] + token_index = expert_to_token_index[expert_id, :token_num] + token_weight = expert_to_weight[expert_id, :token_num] + order = torch.argsort(token_index) + assert torch.equal(token_index[order], expected_index) + assert torch.allclose(token_weight[order], expected_weight) + + +def test_fused_experts_atomic_align_path_is_deterministic(): + token_num = 129 + expert_num = 9 + hidden_size = 64 + intermediate_size = 128 + topk = 4 + hidden_states = torch.randn((token_num, hidden_size), dtype=torch.bfloat16, device="cuda") / 10 + w1 = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device="cuda") / 10 + w2 = torch.randn((expert_num, hidden_size, intermediate_size // 2), dtype=torch.bfloat16, device="cuda") / 10 + topk_ids = torch.arange(token_num * topk, dtype=torch.int32, device="cuda").reshape(token_num, topk) % expert_num + topk_weights = torch.softmax(torch.randn((token_num, topk), dtype=torch.float32, device="cuda"), dim=-1) + + out_0 = fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids) + out_1 = fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids) torch.cuda.synchronize() - assert torch.equal(expert_token_num_0, expert_token_num_1) - assert torch.equal(expert_to_token_index_0, expert_to_token_index_1) - assert torch.allclose(expert_to_weight_0, expert_to_weight_1) + assert torch.equal(out_0, out_1) def test_moe_align2(): From 3155d47316521ad48a8fc24ceccd357101115655 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 29 Jun 2026 01:48:30 +0000 Subject: [PATCH 6/8] fix --- .../fused_moe/grouped_fused_moe.py | 146 ------------------ .../fused_moe/test_grouped_fused_moe.py | 17 +- 2 files changed, 14 insertions(+), 149 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 14c4d5f6b..468773375 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -325,152 +325,6 @@ def _moe_align_fused_small_token( return expert_to_token_index, expert_to_weight, expert_token_num -@triton.jit -def moe_align_fused_record_sorted_segment_start_kernel( - sorted_expert_ids_ptr, # [token_num * topk] - segment_start_ptr, # [expert_num] - expert_token_num_ptr, # [expert_num] - token_num_mul_topk, - expert_num, - BLOCK_SIZE: tl.constexpr, -): - block_id = tl.program_id(0) - sorted_pos = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - valid = sorted_pos < token_num_mul_topk - expert_id = tl.load(sorted_expert_ids_ptr + sorted_pos, mask=valid, other=-1) - - if block_id == 0: - init_offs = tl.arange(0, BLOCK_SIZE) - for start in tl.range(0, expert_num, BLOCK_SIZE): - expert_offs = start + init_offs - tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) - - # A segment starts when the sorted expert id changes. - prev_pos = sorted_pos - 1 - prev_expert_id = tl.load(sorted_expert_ids_ptr + prev_pos, mask=valid & (sorted_pos > 0), other=-1) - is_segment_start = valid & ((sorted_pos == 0) | (expert_id != prev_expert_id)) - tl.store(segment_start_ptr + expert_id, sorted_pos, mask=is_segment_start) - - -@triton.jit -def moe_align_fused_sorted_scatter_kernel( - sorted_expert_ids_ptr, # [token_num * topk] - topk_weights_ptr, # [token_num, topk] - sorted_token_index_ptr, # [token_num * topk] - segment_start_ptr, # [expert_num] - expert_token_num_ptr, # [expert_num] - expert_to_token_index_ptr, # [expert_num, token_num * topk] - expert_to_weight_ptr, # [expert_num, token_num * topk] - token_num_mul_topk, - BLOCK_SIZE: tl.constexpr, -): - block_id = tl.program_id(0) - sorted_pos = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - valid = sorted_pos < token_num_mul_topk - - # sorted_token_index maps a sorted slot back to the original flattened topk slot. - token_index = tl.load(sorted_token_index_ptr + sorted_pos, mask=valid, other=0) - expert_id = tl.load(sorted_expert_ids_ptr + sorted_pos, mask=valid, other=0) - - # Each expert owns a contiguous sorted segment, so local row offset is - # the sorted position minus that expert's segment start. - expert_start = tl.load(segment_start_ptr + expert_id, mask=valid, other=0) - expert_pos = sorted_pos - expert_start - weight = tl.load(topk_weights_ptr + token_index, mask=valid, other=0.0) - - # A segment end gives the final token count for this expert. - next_pos = sorted_pos + 1 - next_expert_id = tl.load(sorted_expert_ids_ptr + next_pos, mask=valid & (next_pos < token_num_mul_topk), other=-1) - is_segment_end = valid & ((next_pos == token_num_mul_topk) | (expert_id != next_expert_id)) - tl.store(expert_token_num_ptr + expert_id, next_pos - expert_start, mask=is_segment_end) - - tl.store( - expert_to_token_index_ptr + expert_id * token_num_mul_topk + expert_pos, - token_index, - mask=valid, - ) - tl.store( - expert_to_weight_ptr + expert_id * token_num_mul_topk + expert_pos, - weight, - mask=valid, - ) - - -def _make_moe_align_fused_sorted_configs(): - configs = [] - for scatter_block_size in [128, 256, 512]: - for scatter_num_warps in [4, 8]: - for record_block_size in [128, 256, 512, 1024]: - for record_num_warps in [4, 8]: - configs.append( - { - "SCATTER_BLOCK_SIZE": scatter_block_size, - "SCATTER_NUM_WARPS": scatter_num_warps, - "RECORD_BLOCK_SIZE": record_block_size, - "RECORD_NUM_WARPS": record_num_warps, - } - ) - return configs - - -@autotune( - kernel_name="moe_align_fused_sorted:v2", - configs_gen_func=_make_moe_align_fused_sorted_configs, - static_key_func=_get_moe_align_fused_static_key, - run_key_func=lambda topk_ids: topk_ids.shape[0], - mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], -) -def _moe_align_fused_sorted_token( - expert_to_token_index, - expert_to_weight, - expert_token_num, - topk_ids, - topk_weights, - run_config: Optional[dict] = None, -): - if run_config is None: - run_config = { - "SCATTER_BLOCK_SIZE": 128, - "SCATTER_NUM_WARPS": 4, - "RECORD_BLOCK_SIZE": 256, - "RECORD_NUM_WARPS": 8, - } - - token_num_mul_topk = topk_ids.numel() - expert_num = expert_token_num.shape[0] - record_grid = (triton.cdiv(token_num_mul_topk, run_config["RECORD_BLOCK_SIZE"]),) - scatter_grid = (triton.cdiv(token_num_mul_topk, run_config["SCATTER_BLOCK_SIZE"]),) - - flat_topk_ids = topk_ids.view(-1) - sorted_expert_ids, sorted_token_index = torch.sort(flat_topk_ids, stable=True) - segment_start = torch.empty_like(expert_token_num) - - # Stable sort makes every expert's tokens contiguous. Record each segment - # start and clear empty-expert counts, then scatter compact expert rows. - moe_align_fused_record_sorted_segment_start_kernel[record_grid]( - sorted_expert_ids, - segment_start, - expert_token_num, - token_num_mul_topk, - expert_num, - BLOCK_SIZE=run_config["RECORD_BLOCK_SIZE"], - num_warps=run_config["RECORD_NUM_WARPS"], - ) - moe_align_fused_sorted_scatter_kernel[scatter_grid]( - sorted_expert_ids, - topk_weights, - sorted_token_index, - segment_start, - expert_token_num, - expert_to_token_index, - expert_to_weight, - token_num_mul_topk, - BLOCK_SIZE=run_config["SCATTER_BLOCK_SIZE"], - num_warps=run_config["SCATTER_NUM_WARPS"], - ) - return expert_to_token_index, expert_to_weight, expert_token_num - - @triton.jit def moe_align_fused_atomic_kernel( topk_ids_ptr, # [token_num, topk] diff --git a/unit_tests/common/fused_moe/test_grouped_fused_moe.py b/unit_tests/common/fused_moe/test_grouped_fused_moe.py index 9e6dec8f8..0376d01ee 100644 --- a/unit_tests/common/fused_moe/test_grouped_fused_moe.py +++ b/unit_tests/common/fused_moe/test_grouped_fused_moe.py @@ -135,11 +135,22 @@ def test_moe_align_fused_large_token(): large_topk_weights = torch.arange(large_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(129, 3) _check_moe_align_fused(large_topk_ids, large_topk_weights, expert_num, ordered=False) - sorted_path_topk_ids = base_topk_ids.repeat(1281, 1)[:5121].contiguous() - sorted_path_topk_weights = torch.arange(sorted_path_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape( + medium_topk_ids = base_topk_ids.repeat(1024, 1).contiguous() + medium_topk_weights = torch.arange(medium_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(4096, 3) + _check_moe_align_fused(medium_topk_ids, medium_topk_weights, expert_num, ordered=False) + + shared_expert_num = 257 + shared_routing = torch.arange(512 * 7, dtype=torch.int32, device="cuda").reshape(512, 7) % 256 + shared_last = torch.full((512, 1), 256, dtype=torch.int32, device="cuda") + shared_topk_ids = torch.cat([shared_routing, shared_last], dim=1).contiguous() + shared_topk_weights = torch.arange(shared_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(512, 8) + _check_moe_align_fused(shared_topk_ids, shared_topk_weights, shared_expert_num, ordered=False) + + large_atomic_topk_ids = base_topk_ids.repeat(1281, 1)[:5121].contiguous() + large_atomic_topk_weights = torch.arange(large_atomic_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape( 5121, 3 ) - _check_moe_align_fused(sorted_path_topk_ids, sorted_path_topk_weights, expert_num, ordered=False) + _check_moe_align_fused(large_atomic_topk_ids, large_atomic_topk_weights, expert_num, ordered=False) sparse_expert_num = 257 sparse_topk_ids = base_topk_ids.repeat(1281, 1)[:5121].contiguous() From 8e2a6d186aba2fd56f3c256a72dafe99e39127d4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 29 Jun 2026 01:54:14 +0000 Subject: [PATCH 7/8] fix --- .../fused_moe/grouped_fused_moe.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 468773375..c6eeb781d 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -345,24 +345,42 @@ def moe_align_fused_atomic_kernel( tl.store(expert_to_weight_ptr + expert_id * token_num_mul_topk + write_pos, weight, mask=valid) +@autotune( + kernel_name="moe_align_fused_atomic:v1", + configs_gen_func=lambda: [ + { + "BLOCK_SIZE": block_size, + "num_warps": num_warps, + } + for num_warps in [1, 2, 4, 8] + for block_size in [128, 256, 512, 1024, 2048] + ], + static_key_func=_get_moe_align_fused_static_key, + run_key_func=lambda topk_ids: topk_ids.shape[0], + mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], +) def _moe_align_fused_atomic_token( expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, + run_config: Optional[dict] = None, ): + if run_config is None: + run_config = {"BLOCK_SIZE": 128, "num_warps": 4} + token_num_mul_topk = topk_ids.numel() expert_token_num.zero_() - moe_align_fused_atomic_kernel[(triton.cdiv(token_num_mul_topk, 128),)]( + moe_align_fused_atomic_kernel[(triton.cdiv(token_num_mul_topk, run_config["BLOCK_SIZE"]),)]( topk_ids, topk_weights, expert_to_token_index, expert_to_weight, expert_token_num, token_num_mul_topk, - BLOCK_SIZE=128, - num_warps=4, + BLOCK_SIZE=run_config["BLOCK_SIZE"], + num_warps=run_config["num_warps"], ) return expert_to_token_index, expert_to_weight, expert_token_num From 10e191a0611da40c9a9afe6bd6391080f6cdf7ab Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 29 Jun 2026 02:29:03 +0000 Subject: [PATCH 8/8] fix --- ...ghts_dtype=torch.float32}_NVIDIA_H200.json | 50 ++++++++ .../{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json | 8 ++ .../{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json | 8 ++ .../{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json | 8 ++ .../{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json | 7 ++ ...ARLEN=true,REVERSE=false}_NVIDIA_H200.json | 38 ++++++ ...H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json | 7 ++ ...8,a_dtype=torch.bfloat16}_NVIDIA_H200.json | 50 ++++++++ ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 50 ++++++++ ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++++++++++++++++++ ...num=9,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++++++++++++++++++ ...xpert_num=257,topk_num=9}_NVIDIA_H200.json | 22 ++++ ...xpert_num=257,topk_num=9}_NVIDIA_H200.json | 37 ++++++ ...orch.bfloat16,topk_num=9}_NVIDIA_H200.json | 74 ++++++++++++ ...M=4,dtype=torch.bfloat16}_NVIDIA_H200.json | 50 ++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 74 ++++++++++++ 16 files changed, 703 insertions(+) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/append_fused_shared_experts:v1/{has_shared_expert_gate=true,num_fused_shared_experts=1,topk_ids_dtype=torch.int64,topk_num=8,topk_weights_dtype=torch.float32}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=3072,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=3072,N=256,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_atomic:v1/{expert_num=257,topk_num=9}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_small:v2/{expert_num=257,topk_num=9}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/append_fused_shared_experts:v1/{has_shared_expert_gate=true,num_fused_shared_experts=1,topk_ids_dtype=torch.int64,topk_num=8,topk_weights_dtype=torch.float32}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/append_fused_shared_experts:v1/{has_shared_expert_gate=true,num_fused_shared_experts=1,topk_ids_dtype=torch.int64,topk_num=8,topk_weights_dtype=torch.float32}_NVIDIA_H200.json new file mode 100644 index 000000000..f2a0176db --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/append_fused_shared_experts:v1/{has_shared_expert_gate=true,num_fused_shared_experts=1,topk_ids_dtype=torch.int64,topk_num=8,topk_weights_dtype=torch.float32}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "100": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "1024": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "128": { + "BLOCK_TOKEN": 4, + "num_warps": 8 + }, + "16": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "16384": { + "BLOCK_TOKEN": 32, + "num_warps": 2 + }, + "2048": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "256": { + "BLOCK_TOKEN": 4, + "num_warps": 8 + }, + "32": { + "BLOCK_TOKEN": 4, + "num_warps": 8 + }, + "4096": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "64": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "8": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..5497f5e6c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..5497f5e6c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..e9918f6ad --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "2": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..d037521c3 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2": { + "BV": 32, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 000000000..e922d888f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 4 + }, + "100": { + "num_warps": 4 + }, + "1024": { + "num_warps": 4 + }, + "128": { + "num_warps": 4 + }, + "16": { + "num_warps": 4 + }, + "16384": { + "num_warps": 4 + }, + "2048": { + "num_warps": 4 + }, + "256": { + "num_warps": 4 + }, + "32": { + "num_warps": 4 + }, + "4096": { + "num_warps": 4 + }, + "64": { + "num_warps": 4 + }, + "8": { + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 000000000..9459f41fa --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2": { + "BK": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..dff8ac4d0 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 4, + "num_warps": 2 + }, + "100": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "16": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "256": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "32": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "8": { + "BLK_HEADS": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..1fcfa30e9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1024": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "131072": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "16384": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "2048": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "256": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "32768": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "512": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "64": { + "BLOCK_N": 512, + "num_warps": 4 + }, + "8": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "800": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "8192": { + "BLOCK_N": 128, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=3072,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=3072,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..978754ec9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=3072,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1152": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "144": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "147456": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "18432": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2304": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "288": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "36864": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "576": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "72": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "9": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "900": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "9216": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=3072,N=256,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=3072,N=256,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..2af76b549 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=3072,N=256,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_atomic:v1/{expert_num=257,topk_num=9}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_atomic:v1/{expert_num=257,topk_num=9}_NVIDIA_H200.json new file mode 100644 index 000000000..271548dd6 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_atomic:v1/{expert_num=257,topk_num=9}_NVIDIA_H200.json @@ -0,0 +1,22 @@ +{ + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16384": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_small:v2/{expert_num=257,topk_num=9}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_small:v2/{expert_num=257,topk_num=9}_NVIDIA_H200.json new file mode 100644 index 000000000..42cc5c38a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_small:v2/{expert_num=257,topk_num=9}_NVIDIA_H200.json @@ -0,0 +1,37 @@ +{ + "1": { + "BLOCK_SIZE": 32, + "NUM_STAGE": 2, + "num_warps": 2 + }, + "100": { + "BLOCK_SIZE": 1024, + "NUM_STAGE": 4, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE": 256, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "16": { + "BLOCK_SIZE": 256, + "NUM_STAGE": 2, + "num_warps": 2 + }, + "32": { + "BLOCK_SIZE": 512, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE": 256, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "8": { + "BLOCK_SIZE": 128, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json new file mode 100644 index 000000000..bd78d49c4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 8, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 16 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 4, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..f1882ef5d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 4, + "num_warps": 1 + }, + "100": { + "num_stages": 4, + "num_warps": 1 + }, + "1024": { + "num_stages": 4, + "num_warps": 1 + }, + "128": { + "num_stages": 4, + "num_warps": 1 + }, + "16": { + "num_stages": 4, + "num_warps": 1 + }, + "16384": { + "num_stages": 4, + "num_warps": 2 + }, + "2048": { + "num_stages": 3, + "num_warps": 2 + }, + "256": { + "num_stages": 4, + "num_warps": 1 + }, + "32": { + "num_stages": 4, + "num_warps": 1 + }, + "4096": { + "num_stages": 5, + "num_warps": 1 + }, + "64": { + "num_stages": 4, + "num_warps": 1 + }, + "8": { + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..c3cabb161 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1152": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "144": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "147456": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "18432": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2304": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "288": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "36864": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "576": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "72": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "9": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "900": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "9216": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file