diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 1f7175f72..c6eeb781d 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -214,110 +214,188 @@ def moe_align1( @triton.jit -def moe_align_fused_kernel( +def moe_align_fused_small_token_kernel( topk_ids_ptr, # [token_num, topk] topk_weights_ptr, # [token_num, topk] expert_to_token_index_ptr, # [expert_num, token_num * topk] expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] - token_num, - expert_num: tl.constexpr, - topk_num: tl.constexpr, + token_num_mul_topk, BLOCK_SIZE: tl.constexpr, - INIT_EXPERT_TOKEN_NUM_IN_KERNEL: tl.constexpr, - BLOCK_EXPERT: tl.constexpr, + NUM_STAGE: tl.constexpr, ): - token_block = tl.program_id(0) - if INIT_EXPERT_TOKEN_NUM_IN_KERNEL: - expert_offs = tl.arange(0, BLOCK_EXPERT) - tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) - tl.debug_barrier() - - offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offs < token_num * topk_num - - expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0) - weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0) - - # 用 atomic_add 给 expert 分配写位置 - write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask) + expert_id = tl.program_id(0) + block_offs = tl.arange(0, BLOCK_SIZE) + token_count = tl.full((), 0, tl.int32) + + for start in tl.range(0, token_num_mul_topk, BLOCK_SIZE, num_stages=NUM_STAGE): + raw_offs = start + block_offs + valid = raw_offs < token_num_mul_topk + load_offs = tl.where(valid, raw_offs, 0) + expert_ids = tl.load(topk_ids_ptr + load_offs, mask=valid, other=-1) + weights = tl.load(topk_weights_ptr + load_offs, mask=valid, other=0.0) + + expert_mask = (expert_ids == expert_id) & valid + expert_hits = tl.where(expert_mask, 1, 0) + write_pos = token_count + tl.cumsum(expert_hits, axis=0) - 1 + tl.store( + expert_to_token_index_ptr + expert_id * token_num_mul_topk + write_pos, + raw_offs, + mask=expert_mask, + ) + tl.store( + expert_to_weight_ptr + expert_id * token_num_mul_topk + write_pos, + weights, + mask=expert_mask, + ) + token_count += tl.sum(expert_hits, axis=0) - # 按 token 顺序写 index 和 weight - tl.store( - expert_to_token_index_ptr + expert_ids * (token_num * topk_num) + write_pos, - offs, - mask=mask, - ) - tl.store( - expert_to_weight_ptr + expert_ids * (token_num * topk_num) + write_pos, - weights, - mask=mask, - ) + tl.store(expert_token_num_ptr + expert_id, token_count) def _get_moe_align_fused_static_key( + expert_token_num: torch.Tensor, topk_weights: torch.Tensor, ) -> dict: topk_num = topk_weights.shape[1] + expert_num = expert_token_num.shape[0] return { "topk_num": topk_num, + "expert_num": expert_num, } -def _get_moe_align_fused_configs(): - return [ +@autotune( + kernel_name="moe_align_fused_small:v2", + configs_gen_func=lambda: [ { "BLOCK_SIZE": bt, "num_warps": nw, + "NUM_STAGE": ns, } + for ns in [1, 2, 4, 6] for nw in [1, 2, 4, 8] - for bt in [128, 256, 512, 1024, 2048] - ] + for bt in [8, 16, 32, 64, 128, 256, 512, 1024, 2048] + ], + static_key_func=_get_moe_align_fused_static_key, + run_key_func=lambda topk_ids: topk_ids.shape[0], + mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], +) +def _moe_align_fused_small_token( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + run_config: Optional[dict] = None, +): + if run_config is None: + token_num = topk_ids.shape[0] + if token_num <= 2: + run_config = {"BLOCK_SIZE": 16, "num_warps": 1, "NUM_STAGE": 1} + elif token_num <= 8: + run_config = {"BLOCK_SIZE": 64, "num_warps": 1, "NUM_STAGE": 1} + elif token_num <= 16: + run_config = {"BLOCK_SIZE": 128, "num_warps": 1, "NUM_STAGE": 1} + elif token_num < 32: + run_config = {"BLOCK_SIZE": 256, "num_warps": 2, "NUM_STAGE": 1} + elif token_num <= 64: + run_config = {"BLOCK_SIZE": 512, "num_warps": 4, "NUM_STAGE": 1} + elif token_num <= 128: + run_config = {"BLOCK_SIZE": 1024, "num_warps": 8, "NUM_STAGE": 1} + elif token_num <= 192: + run_config = {"BLOCK_SIZE": 512, "num_warps": 4, "NUM_STAGE": 1} + else: + run_config = {"BLOCK_SIZE": 2048, "num_warps": 8, "NUM_STAGE": 1} + token_num_mul_topk = topk_ids.numel() + expert_num = expert_token_num.shape[0] + block_size = run_config["BLOCK_SIZE"] + + moe_align_fused_small_token_kernel[(expert_num,)]( + topk_ids, + topk_weights, + expert_to_token_index, + expert_to_weight, + expert_token_num, + token_num_mul_topk, + BLOCK_SIZE=block_size, + NUM_STAGE=run_config["NUM_STAGE"], + num_warps=run_config["num_warps"], + ) + return expert_to_token_index, expert_to_weight, expert_token_num + + +@triton.jit +def moe_align_fused_atomic_kernel( + topk_ids_ptr, # [token_num, topk] + topk_weights_ptr, # [token_num, topk] + expert_to_token_index_ptr, # [expert_num, token_num * topk] + expert_to_weight_ptr, # [expert_num, token_num * topk] + expert_token_num_ptr, # [expert_num] + token_num_mul_topk, + BLOCK_SIZE: tl.constexpr, +): + block_id = tl.program_id(0) + offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + valid = offs < token_num_mul_topk + expert_id = tl.load(topk_ids_ptr + offs, mask=valid, other=0) + weight = tl.load(topk_weights_ptr + offs, mask=valid, other=0.0) + write_pos = tl.atomic_add(expert_token_num_ptr + expert_id, 1, mask=valid) + tl.store(expert_to_token_index_ptr + expert_id * token_num_mul_topk + write_pos, offs, mask=valid) + tl.store(expert_to_weight_ptr + expert_id * token_num_mul_topk + write_pos, weight, mask=valid) @autotune( - kernel_name="moe_align_fused:v1", - configs_gen_func=_get_moe_align_fused_configs, + kernel_name="moe_align_fused_atomic:v1", + configs_gen_func=lambda: [ + { + "BLOCK_SIZE": block_size, + "num_warps": num_warps, + } + for num_warps in [1, 2, 4, 8] + for block_size in [128, 256, 512, 1024, 2048] + ], static_key_func=_get_moe_align_fused_static_key, run_key_func=lambda topk_ids: topk_ids.shape[0], mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], ) -def moe_align_fused( - expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None +def _moe_align_fused_atomic_token( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + run_config: Optional[dict] = None, ): - token_num, topk_num = topk_ids.shape if run_config is None: - run_config = {} - BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256) - num_warps = run_config.get("num_warps", 4) + run_config = {"BLOCK_SIZE": 128, "num_warps": 4} - # For small inputs the align kernel has a single program, so it can initialize - # expert_token_num itself and avoid an extra zero_ kernel launch. - expert_num = expert_token_num.shape[0] - init_expert_token_num_in_kernel = token_num * topk_num <= BLOCK_SIZE - if not init_expert_token_num_in_kernel: - # Multiple align programs may update expert_token_num concurrently; initialize - # it before launch to avoid races between clear and atomic_add. - expert_token_num.zero_() - - grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),) - moe_align_fused_kernel[grid]( + token_num_mul_topk = topk_ids.numel() + expert_token_num.zero_() + moe_align_fused_atomic_kernel[(triton.cdiv(token_num_mul_topk, run_config["BLOCK_SIZE"]),)]( topk_ids, topk_weights, expert_to_token_index, expert_to_weight, expert_token_num, - token_num, - expert_num, - topk_num, - BLOCK_SIZE=BLOCK_SIZE, - INIT_EXPERT_TOKEN_NUM_IN_KERNEL=init_expert_token_num_in_kernel, - BLOCK_EXPERT=triton.next_power_of_2(expert_num), - num_warps=num_warps, + token_num_mul_topk, + BLOCK_SIZE=run_config["BLOCK_SIZE"], + num_warps=run_config["num_warps"], ) return expert_to_token_index, expert_to_weight, expert_token_num +def moe_align_fused(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights): + token_num = topk_ids.shape[0] + if token_num <= 128: + _moe_align_fused_small_token(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights) + else: + # Expert rows may be unordered, but grouped matmul reuses this same + # mapping for up/down projections and writes back to original topk slots. + _moe_align_fused_atomic_token(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights) + return expert_to_token_index, expert_to_weight, expert_token_num + + @triton.jit def moe_align2_kernel( experts_token_num_ptr, # [expert_num,] diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/append_fused_shared_experts:v1/{has_shared_expert_gate=true,num_fused_shared_experts=1,topk_ids_dtype=torch.int64,topk_num=8,topk_weights_dtype=torch.float32}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/append_fused_shared_experts:v1/{has_shared_expert_gate=true,num_fused_shared_experts=1,topk_ids_dtype=torch.int64,topk_num=8,topk_weights_dtype=torch.float32}_NVIDIA_H200.json new file mode 100644 index 000000000..f2a0176db --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/append_fused_shared_experts:v1/{has_shared_expert_gate=true,num_fused_shared_experts=1,topk_ids_dtype=torch.int64,topk_num=8,topk_weights_dtype=torch.float32}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "100": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "1024": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "128": { + "BLOCK_TOKEN": 4, + "num_warps": 8 + }, + "16": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "16384": { + "BLOCK_TOKEN": 32, + "num_warps": 2 + }, + "2048": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "256": { + "BLOCK_TOKEN": 4, + "num_warps": 8 + }, + "32": { + "BLOCK_TOKEN": 4, + "num_warps": 8 + }, + "4096": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "64": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "8": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..5497f5e6c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..5497f5e6c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..e9918f6ad --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "2": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..d037521c3 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2": { + "BV": 32, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 000000000..e922d888f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 4 + }, + "100": { + "num_warps": 4 + }, + "1024": { + "num_warps": 4 + }, + "128": { + "num_warps": 4 + }, + "16": { + "num_warps": 4 + }, + "16384": { + "num_warps": 4 + }, + "2048": { + "num_warps": 4 + }, + "256": { + "num_warps": 4 + }, + "32": { + "num_warps": 4 + }, + "4096": { + "num_warps": 4 + }, + "64": { + "num_warps": 4 + }, + "8": { + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 000000000..9459f41fa --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2": { + "BK": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..dff8ac4d0 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 4, + "num_warps": 2 + }, + "100": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "16": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "256": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "32": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "8": { + "BLK_HEADS": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..1fcfa30e9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1024": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "131072": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "16384": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "2048": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "256": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "32768": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "512": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "64": { + "BLOCK_N": 512, + "num_warps": 4 + }, + "8": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "800": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "8192": { + "BLOCK_N": 128, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=3072,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=3072,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..978754ec9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=3072,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1152": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "144": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "147456": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "18432": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2304": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "288": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "36864": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "576": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "72": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "9": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "900": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "9216": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=3072,N=256,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=3072,N=256,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..2af76b549 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=3072,N=256,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_atomic:v1/{expert_num=257,topk_num=9}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_atomic:v1/{expert_num=257,topk_num=9}_NVIDIA_H200.json new file mode 100644 index 000000000..271548dd6 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_atomic:v1/{expert_num=257,topk_num=9}_NVIDIA_H200.json @@ -0,0 +1,22 @@ +{ + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16384": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_small:v2/{expert_num=257,topk_num=9}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_small:v2/{expert_num=257,topk_num=9}_NVIDIA_H200.json new file mode 100644 index 000000000..42cc5c38a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_small:v2/{expert_num=257,topk_num=9}_NVIDIA_H200.json @@ -0,0 +1,37 @@ +{ + "1": { + "BLOCK_SIZE": 32, + "NUM_STAGE": 2, + "num_warps": 2 + }, + "100": { + "BLOCK_SIZE": 1024, + "NUM_STAGE": 4, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE": 256, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "16": { + "BLOCK_SIZE": 256, + "NUM_STAGE": 2, + "num_warps": 2 + }, + "32": { + "BLOCK_SIZE": 512, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE": 256, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "8": { + "BLOCK_SIZE": 128, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json new file mode 100644 index 000000000..bd78d49c4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 8, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 16 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 4, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..f1882ef5d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 4, + "num_warps": 1 + }, + "100": { + "num_stages": 4, + "num_warps": 1 + }, + "1024": { + "num_stages": 4, + "num_warps": 1 + }, + "128": { + "num_stages": 4, + "num_warps": 1 + }, + "16": { + "num_stages": 4, + "num_warps": 1 + }, + "16384": { + "num_stages": 4, + "num_warps": 2 + }, + "2048": { + "num_stages": 3, + "num_warps": 2 + }, + "256": { + "num_stages": 4, + "num_warps": 1 + }, + "32": { + "num_stages": 4, + "num_warps": 1 + }, + "4096": { + "num_stages": 5, + "num_warps": 1 + }, + "64": { + "num_stages": 4, + "num_warps": 1 + }, + "8": { + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..c3cabb161 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1152": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "144": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "147456": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "18432": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2304": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "288": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "36864": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "576": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "72": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "9": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "900": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "9216": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/unit_tests/common/fused_moe/test_grouped_fused_moe.py b/unit_tests/common/fused_moe/test_grouped_fused_moe.py index 9c08cfc1a..0376d01ee 100644 --- a/unit_tests/common/fused_moe/test_grouped_fused_moe.py +++ b/unit_tests/common/fused_moe/test_grouped_fused_moe.py @@ -3,7 +3,10 @@ import pytest import triton from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import ( + _moe_align_fused_atomic_token, + fused_experts_impl, moe_align, + moe_align_fused, moe_align1, moe_align2, grouped_matmul, @@ -74,6 +77,145 @@ def test_moe_align1(): assert torch.equal(experts_info, true_experts_info) +def _check_moe_align_fused(topk_ids, topk_weights, expert_num, ordered=True): + expert_to_token_index = torch.empty((expert_num, topk_ids.numel()), dtype=torch.int32, device="cuda") + expert_to_weight = torch.empty((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") + expert_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") + + moe_align_fused( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + ) + torch.cuda.synchronize() + + flat_topk_ids = topk_ids.flatten() + flat_topk_weights = topk_weights.flatten() + expected_token_num = torch.bincount(flat_topk_ids, minlength=expert_num).to(torch.int32) + assert torch.equal(expert_token_num, expected_token_num) + + for expert_id, token_num in enumerate(expected_token_num.tolist()): + expected_index = torch.nonzero(flat_topk_ids == expert_id, as_tuple=False).flatten() + expected_weight = flat_topk_weights[expected_index] + expected_index = expected_index.to(torch.int32) + token_index = expert_to_token_index[expert_id, :token_num] + token_weight = expert_to_weight[expert_id, :token_num] + + if not ordered: + order = torch.argsort(token_index) + token_index = token_index[order] + token_weight = token_weight[order] + + assert torch.equal(token_index, expected_index) + assert torch.allclose(token_weight, expected_weight) + + +def test_moe_align_fused_small_token(): + expert_num = 5 + small_topk_ids = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 1, 4]], dtype=torch.int32, device="cuda") + small_topk_weights = torch.tensor( + [[0.3, 0.7, 0.1], [0.2, 0.8, 0.4], [0.5, 0.6, 0.9]], dtype=torch.float32, device="cuda" + ) + _check_moe_align_fused(small_topk_ids, small_topk_weights, expert_num) + + small_many_topk_ids = torch.arange(128 * 17, dtype=torch.int32, device="cuda").reshape(128, 17) % expert_num + small_many_topk_weights = torch.arange(small_many_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape( + 128, 17 + ) + _check_moe_align_fused(small_many_topk_ids, small_many_topk_weights, expert_num) + + +def test_moe_align_fused_large_token(): + expert_num = 5 + + base_topk_ids = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 1, 4], [2, 0, 4]], dtype=torch.int32, device="cuda") + large_topk_ids = base_topk_ids.repeat(33, 1)[:129].contiguous() + large_topk_weights = torch.arange(large_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(129, 3) + _check_moe_align_fused(large_topk_ids, large_topk_weights, expert_num, ordered=False) + + medium_topk_ids = base_topk_ids.repeat(1024, 1).contiguous() + medium_topk_weights = torch.arange(medium_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(4096, 3) + _check_moe_align_fused(medium_topk_ids, medium_topk_weights, expert_num, ordered=False) + + shared_expert_num = 257 + shared_routing = torch.arange(512 * 7, dtype=torch.int32, device="cuda").reshape(512, 7) % 256 + shared_last = torch.full((512, 1), 256, dtype=torch.int32, device="cuda") + shared_topk_ids = torch.cat([shared_routing, shared_last], dim=1).contiguous() + shared_topk_weights = torch.arange(shared_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(512, 8) + _check_moe_align_fused(shared_topk_ids, shared_topk_weights, shared_expert_num, ordered=False) + + large_atomic_topk_ids = base_topk_ids.repeat(1281, 1)[:5121].contiguous() + large_atomic_topk_weights = torch.arange(large_atomic_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape( + 5121, 3 + ) + _check_moe_align_fused(large_atomic_topk_ids, large_atomic_topk_weights, expert_num, ordered=False) + + sparse_expert_num = 257 + sparse_topk_ids = base_topk_ids.repeat(1281, 1)[:5121].contiguous() + sparse_topk_weights = torch.arange(sparse_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(5121, 3) + _check_moe_align_fused(sparse_topk_ids, sparse_topk_weights, sparse_expert_num, ordered=False) + + +def test_moe_align_fused_large_token_unordered(): + expert_num = 257 + topk_ids = torch.arange(5121 * 8, dtype=torch.int32, device="cuda").reshape(5121, 8) % expert_num + topk_weights = torch.arange(topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(5121, 8) + _check_moe_align_fused(topk_ids, topk_weights, expert_num, ordered=False) + + +def test_moe_align_fused_atomic_token_unordered(): + expert_num = 9 + topk_ids = torch.arange(257 * 4, dtype=torch.int32, device="cuda").reshape(257, 4) % expert_num + topk_weights = torch.arange(topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(257, 4) + expert_to_token_index = torch.empty((expert_num, topk_ids.numel()), dtype=torch.int32, device="cuda") + expert_to_weight = torch.empty((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") + expert_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") + + _moe_align_fused_atomic_token( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + ) + torch.cuda.synchronize() + + flat_topk_ids = topk_ids.flatten() + flat_topk_weights = topk_weights.flatten() + expected_token_num = torch.bincount(flat_topk_ids, minlength=expert_num).to(torch.int32) + assert torch.equal(expert_token_num, expected_token_num) + + for expert_id, token_num in enumerate(expected_token_num.tolist()): + expected_index = torch.nonzero(flat_topk_ids == expert_id, as_tuple=False).flatten().to(torch.int32) + expected_weight = flat_topk_weights[expected_index] + token_index = expert_to_token_index[expert_id, :token_num] + token_weight = expert_to_weight[expert_id, :token_num] + order = torch.argsort(token_index) + assert torch.equal(token_index[order], expected_index) + assert torch.allclose(token_weight[order], expected_weight) + + +def test_fused_experts_atomic_align_path_is_deterministic(): + token_num = 129 + expert_num = 9 + hidden_size = 64 + intermediate_size = 128 + topk = 4 + hidden_states = torch.randn((token_num, hidden_size), dtype=torch.bfloat16, device="cuda") / 10 + w1 = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device="cuda") / 10 + w2 = torch.randn((expert_num, hidden_size, intermediate_size // 2), dtype=torch.bfloat16, device="cuda") / 10 + topk_ids = torch.arange(token_num * topk, dtype=torch.int32, device="cuda").reshape(token_num, topk) % expert_num + topk_weights = torch.softmax(torch.randn((token_num, topk), dtype=torch.float32, device="cuda"), dim=-1) + + out_0 = fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids) + out_1 = fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids) + torch.cuda.synchronize() + + assert torch.equal(out_0, out_1) + + def test_moe_align2(): experts_token_num = torch.zeros((4,), dtype=torch.int32, device="cuda") @@ -83,13 +225,14 @@ def test_moe_align2(): experts_token_num[3] = 16 mblocks_to_tuple_info = moe_align2(100, experts_token_num, block_m=16) + expected_expert_ids = torch.tensor([0, 2, 2, 2, 2, 3, -1, -1, -1, -1], device="cuda", dtype=torch.int32) + valid_blocks = expected_expert_ids != -1 + assert mblocks_to_tuple_info.shape[0] == triton.cdiv(100 + 4 * (16 - 1), 16) - assert torch.allclose( - mblocks_to_tuple_info[:, 0], - torch.tensor([0, 2, 2, 2, 2, 3, -1, -1, -1, -1], device="cuda", dtype=torch.int32), - ) - assert torch.allclose( - mblocks_to_tuple_info[:, 1], torch.tensor([0, 0, 1, 2, 3, 0, 0, 0, 0, 0], device="cuda", dtype=torch.int32) + assert torch.equal(mblocks_to_tuple_info[:, 0], expected_expert_ids) + assert torch.equal( + mblocks_to_tuple_info[valid_blocks, 1], + torch.tensor([0, 0, 1, 2, 3, 0], device="cuda", dtype=torch.int32), )