
Commit fbf18b0

refactor DeepEPEngine
1 parent e8fdeb5 commit fbf18b0

File tree

  • fastdeploy/model_executor/layers/backends/xpu/moe

1 file changed: +78 -80 lines changed

fastdeploy/model_executor/layers/backends/xpu/moe/ep.py

Lines changed: 78 additions & 80 deletions
@@ -19,17 +19,15 @@
 import deep_ep
 import paddle
 from paddle import nn
-from paddleformers.utils.log import logger
 
 import fastdeploy
 from fastdeploy.config import MoEPhase
 from fastdeploy.utils import singleton
 
 
-@singleton
-class DeepEPEngine:
+class DeepEPEngineBase:
     """
-    A wrapper class for DeepEP engine.
+    Base class for DeepEP engine implementations.
     """
 
     def __init__(
@@ -45,7 +43,7 @@ def __init__(
         group=None,
     ):
         """
-        Initialize the DeepEP engine.
+        Initialize the DeepEP engine base.
         Args:
             group: The MPI group object.
             ep_size: The number of ranks.
@@ -67,42 +65,47 @@ def __init__(
             group = paddle.distributed.new_group(range(ep_size))
         self.group = group
         self.num_local_experts = num_experts // ep_size
-        self.deepep_engine = None  # deepep_engine only calls dispatch / combine
-        self.deepep_engine_low_latency = (
-            None  # deepep_engine_low_latency only calls low_latency_dispatch / low_latency_combine
+        self.deepep_engine = None
+
+    def barrier_all(self):
+        """
+        barrier_all
+        """
+        if self.deepep_engine is not None:
+            self.deepep_engine.barrier_all()
+
+
+@singleton
+class DeepEPEngineHighThroughput(DeepEPEngineBase):
+    """
+    High throughput version of DeepEP engine for prefill phase.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.deepep_engine = deep_ep.Buffer(
+            self.group,
+            int(1e9),
+            0,
+            num_experts=self.num_experts,
+            low_latency_mode=False,
+            num_qps_per_rank=1,
         )
-        self.init_deepep_engine()
-
-    def init_deepep_engine(self):
-        if self.splitwise_role == "mixed":  # colocated (mixed) deployments initialize both buffers and use them as needed
-            self.deepep_engine = deep_ep.Buffer(
-                self.group,
-                int(1e9),
-                0,
-                num_experts=self.num_experts,
-                low_latency_mode=False,
-                num_qps_per_rank=1,
-            )
-            logger.info("Initializing Low Latency Buffer")
-            self.get_low_latency_buffer()
-        elif self.moe_phase.phase == "prefill":  # prefill (P) node in a disaggregated deployment
-            self.deepep_engine = deep_ep.Buffer(
-                self.group,
-                int(1e9),
-                0,
-                num_experts=self.num_experts,
-                low_latency_mode=False,
-                num_qps_per_rank=1,
-            )
-        elif self.moe_phase.phase == "decode":  # decode (D) node in a disaggregated deployment
-            logger.info("Initializing Low Latency Buffer")
-            self.get_low_latency_buffer()
-        else:
-            raise ValueError(f"Unknown generation phase {self.moe_phase}")
+
+
+@singleton
+class DeepEPEngineLowLatency(DeepEPEngineBase):
+    """
+    Low latency version of DeepEP engine for decode phase.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.get_low_latency_buffer()
 
     def get_low_latency_buffer(self):
         """
-        Get the DeepEP buffer.
+        Initialize low latency buffer for decode phase.
         Args:
             group: The MPI group object.
             num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
@@ -117,23 +120,16 @@ def get_low_latency_buffer(self):
             self.ep_size,
             self.num_experts,
         )
-        # Allocate a buffer if not existed or not enough buffer size
-        if (
-            self.deepep_engine_low_latency is None
-            or self.deepep_engine_low_latency.group != self.group
-            or not self.deepep_engine_low_latency.low_latency_mode
-            or self.deepep_engine_low_latency.num_rdma_bytes < num_rdma_bytes
-        ):
-            # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
-            assert self.num_experts % self.ep_size == 0
-            self.deepep_engine_low_latency = deep_ep.Buffer(
-                self.group,
-                0,
-                num_rdma_bytes,
-                self.num_experts,
-                low_latency_mode=True,
-                num_qps_per_rank=self.num_experts // self.ep_size,
-            )
+        # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
+        assert self.num_experts % self.ep_size == 0
+        self.deepep_engine = deep_ep.Buffer(
+            self.group,
+            0,
+            num_rdma_bytes,
+            self.num_experts,
+            low_latency_mode=True,
+            num_qps_per_rank=self.num_experts // self.ep_size,
+        )
 
     def low_latency_dispatch(
         self,
@@ -165,7 +161,7 @@ def low_latency_dispatch(
             handle,
             dispatch_hook,
             valid_token_num,
-        ) = self.deepep_engine_low_latency.low_latency_dispatch(
+        ) = self.deepep_engine.low_latency_dispatch(
             hidden_states,
             moe_in_w4a8_scale,
             topk_idx,
@@ -186,11 +182,10 @@ def low_latency_combine(
         handle,
     ):
         """
-
         Return:
             combined_hidden_states: [num_tokens, hidden_size]
         """
-        combined_hidden_states, combine_hook = self.deepep_engine_low_latency.low_latency_combine(
+        combined_hidden_states, combine_hook = self.deepep_engine.low_latency_combine(
            hidden_states,
            topk_idx,
            topk_weights,
@@ -206,25 +201,24 @@ def clean_low_latency_buffer(self):
         """
         pass
 
-    def barrier_all(self):
-        """
-        barrier_all
-        """
-        if self.deepep_engine is None and self.deepep_engine_low_latency is None:
-            raise ValueError("The DeepEP engine has not been initialized yet.")
-
-        if self.deepep_engine is not None:
-            self.deepep_engine.barrier_all()
-        if self.deepep_engine_low_latency is not None:
-            self.deepep_engine_low_latency.barrier_all()
-        # self.deepep_engine.barrier_all()
-
 
 class XPUEPRunner:
     """
     EPRunnerBase
     """
 
+    def _init_ep_engine(self, engine_class):
+        self.ep_engine = engine_class(
+            num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank,
+            hidden_size=self.hidden_size,
+            num_experts=self.num_experts + self.redundant_experts_num,
+            ep_size=self.ep_size,
+            ep_rank=self.ep_rank,
+            splitwise_role=self.splitwise_role,
+            moe_phase=self.moe_phase,
+            group=self.ep_group,
+        )
+
     def __init__(
         self,
         top_k: int,
@@ -248,19 +242,17 @@ def __init__(
         self.ep_rank = ep_rank
         self.redundant_experts_num = redundant_experts_num
         self.ep_group = ep_group
+        self.ep_engine = None
         self.init_ep_engine()
 
     def init_ep_engine(self):
-        self.ep_engine = DeepEPEngine(
-            num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank,
-            hidden_size=self.hidden_size,
-            num_experts=self.num_experts + self.redundant_experts_num,
-            ep_size=self.ep_size,
-            ep_rank=self.ep_rank,
-            splitwise_role=self.splitwise_role,
-            moe_phase=self.moe_phase,
-            group=self.ep_group,
-        )
+        """Initialize the EP engine with default implementation"""
+        self._init_ep_engine(self._get_engine_class())
+
+    @abstractmethod
+    def _get_engine_class(self):
+        """Get the engine class to be initialized"""
+        raise NotImplementedError("Subclasses must implement this method")
 
     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
         """
@@ -346,6 +338,9 @@ def __init__(
             ep_group=ep_group,
         )
 
+    def _get_engine_class(self):
+        return DeepEPEngineHighThroughput
+
     def dispatch(
         self,
         x: paddle.Tensor,
@@ -410,6 +405,9 @@ def __init__(
             ep_group=ep_group,
         )
 
+    def _get_engine_class(self):
+        return DeepEPEngineLowLatency
+
     def dispatch(
         self,
         x: paddle.Tensor,
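
For reference, a minimal sketch of how the refactored pieces fit together after this change, based only on the diff above. The subclass names `PrefillRunner` and `DecodeRunner` are illustrative stand-ins for the existing prefill/decode runner classes whose bodies are not shown in full here; everything else follows the diff.

```python
# Illustrative sketch only; runner subclass names are assumed.
class PrefillRunner(XPUEPRunner):
    def _get_engine_class(self):
        # Prefill phase: high-throughput buffer (low_latency_mode=False).
        return DeepEPEngineHighThroughput


class DecodeRunner(XPUEPRunner):
    def _get_engine_class(self):
        # Decode phase: low-latency RDMA buffer (low_latency_mode=True).
        return DeepEPEngineLowLatency


# XPUEPRunner.__init__ calls init_ep_engine(), which now resolves to
#     self._init_ep_engine(self._get_engine_class())
# so the selected singleton engine is constructed with the runner's EP
# configuration (ep_size, ep_rank, num_experts + redundant_experts_num,
# splitwise_role, moe_phase, group).
```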
