 import deep_ep
 import paddle
 from paddle import nn
-from paddleformers.utils.log import logger
+from abc import abstractmethod
 
 import fastdeploy
 from fastdeploy.config import MoEPhase
 from fastdeploy.utils import singleton
 
 
-@singleton
-class DeepEPEngine:
+class DeepEPEngineBase:
     """
-    A wrapper class for DeepEP engine.
+    Base class for DeepEP engine implementations.
     """
 
     def __init__(
@@ -45,7 +43,7 @@ def __init__(
         group=None,
     ):
         """
-        Initialize the DeepEP engine.
+        Initialize the DeepEP engine base.
        Args:
            group: The MPI group object.
            ep_size: The number of ranks.
@@ -67,42 +65,47 @@ def __init__(
             group = paddle.distributed.new_group(range(ep_size))
         self.group = group
         self.num_local_experts = num_experts // ep_size
-        self.deepep_engine = None  # deepep_engine only calls dispatch/combine
-        self.deepep_engine_low_latency = (
-            None  # deepep_engine_low_latency only calls low_latency_dispatch/low_latency_combine
+        self.deepep_engine = None
+
+    def barrier_all(self):
+        """
+        Synchronize all ranks in the EP group.
+        """
+        if self.deepep_engine is not None:
+            self.deepep_engine.barrier_all()
+
+
+@singleton
+class DeepEPEngineHighThroughput(DeepEPEngineBase):
+    """
+    High-throughput DeepEP engine, used for the prefill phase.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
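+        # Buffer sizing (inferred from the paired low-latency call below):
+        # the second positional arg is the local/NVLink buffer (~1 GB here),
+        # the third is the RDMA buffer (0 bytes); a single QP per rank is
+        # enough for the plain dispatch/combine kernels.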
+        self.deepep_engine = deep_ep.Buffer(
+            self.group,
+            int(1e9),
+            0,
+            num_experts=self.num_experts,
+            low_latency_mode=False,
+            num_qps_per_rank=1,
         )
-        self.init_deepep_engine()
-
-    def init_deepep_engine(self):
-        if self.splitwise_role == "mixed":  # colocated (mixed) deployments initialize both buffers and pick one on demand
-            self.deepep_engine = deep_ep.Buffer(
-                self.group,
-                int(1e9),
-                0,
-                num_experts=self.num_experts,
-                low_latency_mode=False,
-                num_qps_per_rank=1,
-            )
-            logger.info("Initializing Low Latency Buffer")
-            self.get_low_latency_buffer()
-        elif self.moe_phase.phase == "prefill":  # prefill (P) node in a disaggregated deployment
-            self.deepep_engine = deep_ep.Buffer(
-                self.group,
-                int(1e9),
-                0,
-                num_experts=self.num_experts,
-                low_latency_mode=False,
-                num_qps_per_rank=1,
-            )
-        elif self.moe_phase.phase == "decode":  # decode (D) node in a disaggregated deployment
-            logger.info("Initializing Low Latency Buffer")
-            self.get_low_latency_buffer()
-        else:
-            raise ValueError(f"Unknown generation phase {self.moe_phase}")
+
+
+@singleton
+class DeepEPEngineLowLatency(DeepEPEngineBase):
+    """
+    Low-latency DeepEP engine, used for the decode phase.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
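+        # Eagerly allocate the RDMA-backed buffer; @singleton makes this a
+        # one-time, process-wide initialization.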
+        self.get_low_latency_buffer()
 
     def get_low_latency_buffer(self):
         """
-        Get the DeepEP buffer.
+        Initialize the low-latency buffer for the decode phase.
        Args:
            group: The MPI group object.
            num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
@@ -117,23 +120,16 @@ def get_low_latency_buffer(self):
             self.ep_size,
             self.num_experts,
         )
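+        # The hint sizes the RDMA buffer from the four args above: max
+        # dispatch tokens per rank, hidden size, EP size, and expert count.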
-        # Allocate a buffer if none exists or the existing one is too small
-        if (
-            self.deepep_engine_low_latency is None
-            or self.deepep_engine_low_latency.group != self.group
-            or not self.deepep_engine_low_latency.low_latency_mode
-            or self.deepep_engine_low_latency.num_rdma_bytes < num_rdma_bytes
-        ):
-            # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
-            assert self.num_experts % self.ep_size == 0
-            self.deepep_engine_low_latency = deep_ep.Buffer(
-                self.group,
-                0,
-                num_rdma_bytes,
-                self.num_experts,
-                low_latency_mode=True,
-                num_qps_per_rank=self.num_experts // self.ep_size,
-            )
+        # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
+        assert self.num_experts % self.ep_size == 0
+        self.deepep_engine = deep_ep.Buffer(
+            self.group,
+            0,
+            num_rdma_bytes,
+            self.num_experts,
+            low_latency_mode=True,
+            num_qps_per_rank=self.num_experts // self.ep_size,
+        )
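+        # num_qps_per_rank equals self.num_local_experts (num_experts //
+        # ep_size), satisfying the NOTES requirement above.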
 
     def low_latency_dispatch(
         self,
@@ -165,7 +161,7 @@ def low_latency_dispatch(
             handle,
             dispatch_hook,
             valid_token_num,
-        ) = self.deepep_engine_low_latency.low_latency_dispatch(
+        ) = self.deepep_engine.low_latency_dispatch(
             hidden_states,
             moe_in_w4a8_scale,
             topk_idx,
@@ -186,11 +182,10 @@ def low_latency_combine(
         handle,
     ):
         """
-
        Return:
            combined_hidden_states: [num_tokens, hidden_size]
         """
-        combined_hidden_states, combine_hook = self.deepep_engine_low_latency.low_latency_combine(
+        combined_hidden_states, combine_hook = self.deepep_engine.low_latency_combine(
             hidden_states,
             topk_idx,
             topk_weights,
@@ -206,25 +201,24 @@ def clean_low_latency_buffer(self):
         """
         pass
 
-    def barrier_all(self):
-        """
-        barrier_all
-        """
-        if self.deepep_engine is None and self.deepep_engine_low_latency is None:
-            raise ValueError("The DeepEP engine has not been initialized yet.")
-
-        if self.deepep_engine is not None:
-            self.deepep_engine.barrier_all()
-        if self.deepep_engine_low_latency is not None:
-            self.deepep_engine_low_latency.barrier_all()
-        # self.deepep_engine.barrier_all()
-
 
 class XPUEPRunner:
     """
     EPRunnerBase
     """
 
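+    # Shared template-method helper: subclasses pick the engine class via
+    # _get_engine_class(), and this builds it with the runner's settings.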
+    def _init_ep_engine(self, engine_class):
+        self.ep_engine = engine_class(
+            num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank,
+            hidden_size=self.hidden_size,
+            num_experts=self.num_experts + self.redundant_experts_num,
+            ep_size=self.ep_size,
+            ep_rank=self.ep_rank,
+            splitwise_role=self.splitwise_role,
+            moe_phase=self.moe_phase,
+            group=self.ep_group,
+        )
+
     def __init__(
         self,
         top_k: int,
@@ -248,19 +242,17 @@ def __init__(
         self.ep_rank = ep_rank
         self.redundant_experts_num = redundant_experts_num
         self.ep_group = ep_group
+        self.ep_engine = None
         self.init_ep_engine()
 
     def init_ep_engine(self):
-        self.ep_engine = DeepEPEngine(
-            num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank,
-            hidden_size=self.hidden_size,
-            num_experts=self.num_experts + self.redundant_experts_num,
-            ep_size=self.ep_size,
-            ep_rank=self.ep_rank,
-            splitwise_role=self.splitwise_role,
-            moe_phase=self.moe_phase,
-            group=self.ep_group,
-        )
+        """Initialize the EP engine via the subclass-selected engine class."""
+        self._init_ep_engine(self._get_engine_class())
+
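+    # Note (assumption): XPUEPRunner is not declared with abc.ABC, so
+    # @abstractmethod alone will not block instantiation; the explicit
+    # NotImplementedError below is the effective guard.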
+    @abstractmethod
+    def _get_engine_class(self):
+        """Get the engine class to be instantiated."""
+        raise NotImplementedError("Subclasses must implement this method")
 
     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
         """
@@ -346,6 +338,9 @@ def __init__(
             ep_group=ep_group,
         )
 
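+    # Prefill-side runner: binds the high-throughput engine (per the
+    # DeepEPEngineHighThroughput docstring above).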
+    def _get_engine_class(self):
+        return DeepEPEngineHighThroughput
+
     def dispatch(
         self,
         x: paddle.Tensor,
@@ -410,6 +405,9 @@ def __init__(
             ep_group=ep_group,
         )
 
+    def _get_engine_class(self):
+        return DeepEPEngineLowLatency
+
     def dispatch(
         self,
         x: paddle.Tensor,