Commit 34553b9

[Performance] Support FP8 flashinfer TRTLLM MOE on Qwen3 and Qwen-3next (#27492)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
1 parent: b039bfd

File tree: 7 files changed, +78 -30 lines changed

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 21 additions & 0 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
+from enum import IntEnum
 from typing import Optional, Union
 
 import torch
@@ -91,6 +92,26 @@ def _quant_flags_to_group_shape(
     return a_shape, w_shape
 
 
+# The type of method in top-K routing.
+# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
+class RoutingMethodType(IntEnum):
+    # Default: Softmax -> TopK
+    Default = 0
+    # Renormalize: TopK -> Softmax
+    Renormalize = 1
+    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
+    # -> Top8 experts from the Top4 groups
+    DeepSeekV3 = 2
+    # Llama4: Top1 -> Sigmoid
+    Llama4 = 3
+    # RenormalizeNaive: Softmax -> TopK -> Renormalize
+    RenormalizeNaive = 4
+    # TopK: TopK (no softmax)
+    TopK = 5
+    # Unspecified
+    Unspecified = 6
+
+
 @dataclass
 class FusedMoEQuantDesc:
     """

vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py

Lines changed: 12 additions & 14 deletions
@@ -3,6 +3,7 @@
 
 import torch
 
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     calculate_tile_tokens_dim,
@@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8(
     w2_weight_scale_inv: torch.Tensor,
     global_num_experts: int,
     top_k: int,
-    num_expert_group: int,
-    topk_group: int,
+    num_expert_group: int | None,
+    topk_group: int | None,
     intermediate_size: int,
     expert_offset: int,
     local_num_experts: int,
     block_shape: list[int],
-    routed_scaling: float = 1.0,
+    routing_method_type: int = RoutingMethodType.DeepSeekV3,
+    routed_scaling: float | None = 1.0,
 ) -> torch.Tensor:
     from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
 
+    topk_group = topk_group if topk_group is not None else 0
     assert top_k <= global_num_experts
-    assert top_k <= 8
-    assert topk_group <= 4
-    assert global_num_experts > num_expert_group
-    assert global_num_experts % num_expert_group == 0
+    assert top_k <= 10
     assert global_num_experts % 4 == 0
-    assert top_k < (topk_group * global_num_experts / num_expert_group)
     assert block_shape == [128, 128]
-    # Routing kernel expects #experts <= #threads (256)
-    assert global_num_experts <= 256
+    # Routing kernel expects #experts <= #threads (512)
+    assert global_num_experts <= 512
 
     a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
     # NOTE: scales of hidden states have to be transposed!
@@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8(
         local_expert_offset=expert_offset,
         local_num_experts=local_num_experts,
         routed_scaling_factor=routed_scaling,
-        tile_tokens_dim=calculate_tile_tokens_dim(
-            x.shape[0], top_k, global_num_experts
-        ),
-        routing_method_type=2,  # DeepSeek-styled routing method
+        tile_tokens_dim=None,
+        routing_method_type=routing_method_type,
         use_shuffled_weight=False,
     )

@@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
     expert_offset: int,
     local_num_experts: int,
     block_shape: list[int],
+    routing_method_type: int,
     routed_scaling: float = 1.0,
 ) -> torch.Tensor:
     return torch.empty_like(x)
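
The _fake variant lets torch.compile and meta-device tracing infer output shapes without launching the FlashInfer kernel, which is why its signature must track the real op and now also carries routing_method_type. A self-contained illustration of the pattern (hypothetical demo:: op, not the actual vLLM registration code):

import torch

@torch.library.custom_op("demo::noop_moe", mutates_args=())
def noop_moe(x: torch.Tensor, routing_method_type: int) -> torch.Tensor:
    # A real implementation would dispatch to a fused kernel here.
    return x.clone()

@noop_moe.register_fake
def _(x: torch.Tensor, routing_method_type: int) -> torch.Tensor:
    # Shape-only stand-in, mirroring flashinfer_fused_moe_blockscale_fp8_fake.
    return torch.empty_like(x)

out = noop_moe(torch.randn(2, 4), 1)  # routing_method_type=1 (Renormalize)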

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 20 additions & 0 deletions
@@ -31,6 +31,7 @@
     FusedMoEConfig,
     FusedMoEParallelConfig,
     FusedMoEQuantConfig,
+    RoutingMethodType,
     biased_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
@@ -1213,6 +1214,7 @@ def __init__(
         zero_expert_type: str | None = None,
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
         n_shared_experts: int | None = None,
+        routing_method_type: int | None = None,
     ):
         super().__init__()

@@ -1397,6 +1399,24 @@ def __init__(
                 "Only softmax scoring function is supported for non-grouped topk."
             )
 
+        # TODO: Better logic to determine the routing method type
+        if routing_method_type is not None:
+            self.routing_method_type = routing_method_type
+        else:
+            if scoring_func == "sigmoid":
+                if self.use_grouped_topk:
+                    self.routing_method_type = RoutingMethodType.DeepSeekV3
+                elif self.top_k == 1:
+                    self.routing_method_type = RoutingMethodType.Llama4
+            elif self.scoring_func == "softmax":
+                self.routing_method_type = (
+                    RoutingMethodType.Renormalize
+                    if not self.renormalize
+                    else RoutingMethodType.RenormalizeNaive
+                )
+            else:
+                self.routing_method_type = RoutingMethodType.TopK
+
         self.moe_config: FusedMoEConfig = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
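
The defaulting branch reads as a small decision table. The same mapping as a standalone sketch (hypothetical helper, not in the commit; unlike the in-tree branch, it falls back to TopK when sigmoid scoring matches neither grouped top-k nor top-1):

from vllm.model_executor.layers.fused_moe.config import RoutingMethodType

def infer_routing_method(
    scoring_func: str, use_grouped_topk: bool, top_k: int, renormalize: bool
) -> RoutingMethodType:
    # Sigmoid scoring: DeepSeekV3-style grouped routing or Llama4-style top-1.
    if scoring_func == "sigmoid":
        if use_grouped_topk:
            return RoutingMethodType.DeepSeekV3
        if top_k == 1:
            return RoutingMethodType.Llama4
    # Softmax scoring: one of the two renormalizing variants.
    elif scoring_func == "softmax":
        return (
            RoutingMethodType.Renormalize
            if not renormalize
            else RoutingMethodType.RenormalizeNaive
        )
    # Fallback: plain top-k routing without softmax.
    return RoutingMethodType.TopK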

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 7 additions & 7 deletions
@@ -28,6 +28,7 @@
 )
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
+    RoutingMethodType,
     fp8_w8a8_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
@@ -1222,22 +1223,20 @@ def apply(
             assert activation == "silu", (
                 f"Expected 'silu' activation but got {activation}"
             )
-            assert scoring_func == "sigmoid", (
-                f"Expected 'sigmoid' scoring func but got {scoring_func}"
-            )
+
             if self.block_quant:
                 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
 
-                assert (
-                    renormalize and use_grouped_topk and custom_routing_function is None
-                )
                 e_score_correction_bias = (
                     e_score_correction_bias.to(x.dtype)
                     if e_score_correction_bias is not None
                     else None
                 )
+                routing_method_type = layer.routing_method_type
                 return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
-                    routing_logits=router_logits.to(torch.float32),
+                    routing_logits=router_logits.to(torch.float32)
+                    if routing_method_type == RoutingMethodType.DeepSeekV3
+                    else router_logits,
                     routing_bias=e_score_correction_bias,
                     x=x,
                     w13_weight=layer.w13_weight,
@@ -1252,6 +1251,7 @@ def apply(
                     expert_offset=layer.ep_rank * layer.local_num_experts,
                     local_num_experts=layer.local_num_experts,
                     block_shape=self.weight_block_size,
+                    routing_method_type=routing_method_type,
                     routed_scaling=routed_scaling_factor,
                 )
             else:
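
The net effect at the call site: the correction bias is aligned with the activation dtype, and only DeepSeekV3-style routing casts the router logits to float32; every other method passes them through in their native dtype. A toy illustration with stand-in tensors (not vLLM objects):

import torch

x = torch.randn(4, 16, dtype=torch.bfloat16)             # activations
router_logits = torch.randn(4, 8, dtype=torch.bfloat16)  # router output
e_score_correction_bias = torch.randn(8, dtype=torch.float32)

# Align the bias dtype with the activations, as in Fp8MoEMethod.apply.
bias = e_score_correction_bias.to(x.dtype)

# Cast logits to float32 only for DeepSeekV3-style routing.
is_deepseek_v3 = False  # stand-in for: routing_method_type == RoutingMethodType.DeepSeekV3
routing_logits = router_logits.to(torch.float32) if is_deepseek_v3 else router_logits

print(bias.dtype, routing_logits.dtype)  # torch.bfloat16 torch.bfloat16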

vllm/model_executor/layers/quantization/utils/flashinfer_utils.py

Lines changed: 14 additions & 9 deletions
@@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum):
 
 
 def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
+    from flashinfer import next_positive_power_of_2
+
     # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
     # TODO: Revert this to dynamic calculation once a new version of FlashInfer
     # with the necessary kernels is released.
     tile_tokens_dim = 8
 
-    # from flashinfer import next_positive_power_of_2
-
-    # # Guess tokens per expert assuming perfect expert distribution first.
-    # num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # # And pad the number to the next power of 2.
-    # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # # Cap to 8-64 tokens per CTA tile as it's the range supported by the
-    # # kernel.
-    # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    # A factor considering tokens are not perfectly balanced among experts.
+    imbalance_factor = 1.3
+    # Calculate the number of tokens per expert
+    # assuming perfect distribution.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # Apply the imbalance factor.
+    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile
+    # as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
 
     return tile_tokens_dim
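
A worked example of the restored heuristic, with a pure-Python stand-in for next_positive_power_of_2 (FlashInfer's own implementation may differ in details):

def next_positive_power_of_2(n: int) -> int:
    # Smallest power of two >= max(n, 1).
    return 1 if n <= 0 else 1 << (n - 1).bit_length()

def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Mirrors calculate_tile_tokens_dim above.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    num_tokens_per_expert = int(num_tokens_per_expert * 1.3)  # imbalance factor
    return min(max(next_positive_power_of_2(num_tokens_per_expert), 8), 64)

# 512 tokens, top_k=8, 128 experts: 32 tokens/expert -> 41 after the 1.3x
# imbalance factor -> padded to the next power of 2 (64) -> within the [8, 64] cap.
assert tile_tokens_dim(512, 8, 128) == 64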

vllm/model_executor/models/qwen3_moe.py

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -171,6 +172,7 @@ def __init__(
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
+            routing_method_type=RoutingMethodType.Renormalize,
         )
 
         self.gate = ReplicatedLinear(

vllm/model_executor/models/qwen3_next.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@
     fused_recurrent_gated_delta_rule,
 )
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
 from vllm.model_executor.layers.layernorm import (
     GemmaRMSNorm as Qwen3NextRMSNorm,
 )
@@ -173,6 +174,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
+            routing_method_type=RoutingMethodType.Renormalize,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
