Skip to content

Commit 5e779f1

Browse files
author
Levi-JQ
committed
fix ut no feature
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
1 parent 3c32524 commit 5e779f1

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

tests/ut/torchair/ops/test_torchair_fused_moe.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -383,26 +383,30 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
383383
else:
384384
assert result.shape == x.shape
385385

386-
@patch('torch_npu.npu_moe_gating_top_k')
387386
@pytest.mark.parametrize("others_param", [16, 1, 4])
388387
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
389-
mock_moe_env, others_param, mock_topk):
388+
mock_moe_env, others_param):
390389
"""
391390
1 test use_select_experts and use fused_expters_with_mc2
392391
2 test use_select_experts and fused_experts_with_all2all_buffer
393392
3 test use_select_experts and fused_experts_with_all2all
394393
4 test use_select_experts and fused_experts
395394
"""
396-
mock_topk.return_value = (torch.randn(8,
397-
2), torch.randint(0, 8,
398-
(8, 2)), None)
399395
ep_size = others_param
400396
is_prefill = False
401397
forward_context = MagicMock(
402398
fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
403-
with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
404-
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
399+
if ep_size == 1:
400+
top_k_return = (torch.randn(16, 2), torch.randint(0, 16,
401+
(16, 2)), None)
402+
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])
403+
else:
404+
top_k_return = (torch.randn(8, 2), torch.randint(0, 8,
405+
(8, 2)), None)
405406
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
407+
with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
408+
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
409+
patch('torch_npu.npu_moe_gating_top_k', return_value=top_k_return):
406410
moe_method.ep_size = ep_size
407411
x = torch.randn(8, 2, 2)
408412
if ep_size == 1:

0 commit comments

Comments
 (0)