
Commit b78b7a9

Levi-JQ committed: optimization of kimi-k2
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Parent: 52abd47

File tree: 3 files changed (+42, -93 lines)

vllm_ascend/torchair/ops/torchair_fused_moe.py

Lines changed: 14 additions & 32 deletions

@@ -857,38 +857,20 @@ def apply(
         shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
-        global_redundant_expert_num = get_ascend_config(
-        ).init_redundancy_expert
-        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=
-                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; should the third output be output
-                # y2_flag=False, # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = torchair_select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+            router_logits,
+            k=top_k,  # topk currently is 8
+            bias=e_score_correction_bias,
+            k_group=topk_group,  # fix: 4
+            group_count=num_expert_group,  # fix 8
+            group_select_mode=
+            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+            # out_flag=False, # todo new api; should the third output be output
+            # y2_flag=False, # old api; should the third output be output
+            routed_scaling_factor=1,
+            eps=float(1e-20))

         topk_weights = topk_weights.to(x.dtype)
         # this is a naive implementation for experts load balance so as
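
For context, the retained routing step now always goes through the fused Ascend gating op. Below is a minimal standalone sketch of that call, assuming an Ascend environment where torch_npu provides npu_moe_gating_top_k; the tensor shapes, dtypes, and device handling are illustrative (a 256-expert DeepSeek-V3/R1-style router) and not taken from the repository.

import torch
import torch_npu  # Ascend extension; assumed available on an NPU host

# Illustrative shapes: 4 tokens routed over 256 experts.
num_tokens, num_experts = 4, 256
router_logits = torch.randn(num_tokens, num_experts).npu()
e_score_correction_bias = torch.zeros(num_experts).npu()

# Same fixed parameters as in the diff above: 8 experts per token, picked from
# the best 4 of 8 expert groups, sigmoid scoring with bias correction.
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
    router_logits,
    k=8,                   # top_k
    bias=e_score_correction_bias,
    k_group=4,             # topk_group
    group_count=8,         # num_expert_group
    group_select_mode=1,   # group score = sum of the group's top-2 experts
    renorm=0,              # normalize scores before the top-k selection
    norm_type=1,           # sigmoid scoring
    routed_scaling_factor=1,
    eps=float(1e-20))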

vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py

Lines changed: 14 additions & 29 deletions

@@ -27,7 +27,6 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
 from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
     torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor

@@ -322,34 +321,20 @@ def apply(
         assert router_logits.shape[
             1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

-        if global_num_experts == 256:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=
-                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; should the third output be output
-                # y2_flag=False, # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = torchair_select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+            router_logits,
+            k=top_k,  # topk currently is 8
+            bias=e_score_correction_bias,
+            k_group=topk_group,  # fix: 4
+            group_count=num_expert_group,  # fix 8
+            group_select_mode=
+            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+            # out_flag=False, # todo new api; should the third output be output
+            # y2_flag=False, # old api; should the third output be output
+            routed_scaling_factor=1,
+            eps=float(1e-20))

         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
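
The inline comments on the fused op (norm_type=1, group_select_mode=1, the fixed k/k_group/group_count values) are terse, so the following is an illustrative, non-optimized PyTorch reference of the grouped top-k routing they describe. This is an approximation written for this note, not the NPU kernel: it assumes sigmoid scoring with an additive correction bias, a group score equal to the sum of each group's top-2 biased scores, selection of the k_group best groups, and a final top-k over the surviving experts with weights taken from the unbiased scores; the actual kernel's normalization details may differ.

import torch

def grouped_topk_reference(router_logits: torch.Tensor,
                           bias: torch.Tensor,
                           k: int = 8,
                           k_group: int = 4,
                           group_count: int = 8,
                           routed_scaling_factor: float = 1.0,
                           eps: float = 1e-20):
    """Illustrative approximation of grouped top-k routing (not the NPU kernel)."""
    num_tokens, num_experts = router_logits.shape
    # norm_type=1: sigmoid scoring; the correction bias only influences selection.
    scores = torch.sigmoid(router_logits)
    biased = scores + bias  # bias broadcasts over the token dimension
    # group_select_mode=1: score each group by the sum of its top-2 experts.
    grouped = biased.view(num_tokens, group_count, num_experts // group_count)
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    # Keep the k_group best groups; mask out experts from the other groups.
    keep = torch.zeros_like(group_scores).scatter_(
        1, group_scores.topk(k_group, dim=-1).indices, 1.0)
    mask = keep.unsqueeze(-1).expand_as(grouped).reshape(num_tokens, num_experts)
    masked = biased.masked_fill(mask == 0, float("-inf"))
    # Final top-k over the surviving experts; weights use the unbiased scores.
    topk_ids = masked.topk(k, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)
    topk_weights = topk_weights / (topk_weights.sum(-1, keepdim=True) + eps)
    return topk_weights * routed_scaling_factor, topk_ids

# Example on CPU with random logits for a 256-expert router:
weights, ids = grouped_topk_reference(torch.randn(4, 256), torch.zeros(256))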

vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py

Lines changed: 14 additions & 32 deletions

@@ -25,7 +25,6 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
 from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
                                         super_kernel)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,

@@ -938,8 +937,6 @@ def apply(
         assert router_logits.shape[
             1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

-        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-
         fused_moe_state = get_forward_context().fused_moe_state
         if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
             fused_moe_state = FusedMoEState.All2All

@@ -948,35 +945,20 @@ def apply(
         with super_kernel(prefix,
                           "stream-fusion=1",
                           enabled=running_in_super_kernel):
-            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-            if is_deepseek_v3_r1:
-                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                    router_logits,
-                    k=top_k,  # topk currently is 8
-                    bias=e_score_correction_bias,
-                    k_group=topk_group,  # fix: 4
-                    group_count=num_expert_group,  # fix 8
-                    group_select_mode=
-                    1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                    renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                    norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                    # out_flag=False, # todo new api; should the third output be output
-                    # y2_flag=False, # old api; should the third output be output
-                    routed_scaling_factor=1,
-                    eps=float(1e-20))
-            else:
-                topk_weights, topk_ids = torchair_select_experts(
-                    hidden_states=x,
-                    router_logits=router_logits,
-                    top_k=top_k,
-                    use_grouped_topk=use_grouped_topk,
-                    renormalize=renormalize,
-                    topk_group=topk_group,
-                    num_expert_group=num_expert_group,
-                    custom_routing_function=custom_routing_function,
-                    scoring_func=scoring_func,
-                    e_score_correction_bias=e_score_correction_bias,
-                )
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                router_logits,
+                k=top_k,  # topk currently is 8
+                bias=e_score_correction_bias,
+                k_group=topk_group,  # fix: 4
+                group_count=num_expert_group,  # fix 8
+                group_select_mode=
+                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+                # out_flag=False, # todo new api; should the third output be output
+                # y2_flag=False, # old api; should the third output be output
+                routed_scaling_factor=1,
+                eps=float(1e-20))

         if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
             with npu_stream_switch("moe_secondary", 0):
