diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py b/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py index 71c2dd600ff..b49c4240e9f 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py +++ b/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py @@ -19,17 +19,15 @@ import deep_ep import paddle from paddle import nn -from paddleformers.utils.log import logger import fastdeploy from fastdeploy.config import MoEPhase from fastdeploy.utils import singleton -@singleton -class DeepEPEngine: +class DeepEPEngineBase: """ - A wrapper class for DeepEP engine. + Base class for DeepEP engine implementations. """ def __init__( @@ -45,7 +43,7 @@ def __init__( group=None, ): """ - Initialize the DeepEP engine. + Initialize the DeepEP engine base. Args: group: The MPI group object. ep_size: The number of ranks. @@ -68,27 +66,48 @@ def __init__( self.group = group self.num_local_experts = num_experts // ep_size self.deepep_engine = None - self.init_deepep_engine() - - def init_deepep_engine(self): - if self.splitwise_role == "mixed" or self.moe_phase.phase == "prefill": - self.deepep_engine = deep_ep.Buffer( - self.group, - int(1e9), - 0, - num_experts=self.num_experts, - low_latency_mode=False, - num_qps_per_rank=1, - ) - elif self.moe_phase.phase == "decode": - logger.info("Initializing Low Latency Buffer") - self.get_low_latency_buffer() + + def barrier_all(self): + """ + barrier_all + """ + if self.deepep_engine is not None: + self.deepep_engine.barrier_all() else: - raise ValueError(f"Unknown generation phase {self.moe_phase}") + raise RuntimeError("The deepep engine has not been initialized yet.") + + +@singleton +class DeepEPEngineHighThroughput(DeepEPEngineBase): + """ + High throughput version of DeepEP engine for prefill phase. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.deepep_engine = deep_ep.Buffer( + self.group, + int(1e9), + 0, + num_experts=self.num_experts, + low_latency_mode=False, + num_qps_per_rank=1, + ) + + +@singleton +class DeepEPEngineLowLatency(DeepEPEngineBase): + """ + Low latency version of DeepEP engine for decode phase. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.get_low_latency_buffer() def get_low_latency_buffer(self): """ - Get the DeepEP buffer. + Initialize low latency buffer for decode phase. Args: group: The MPI group object. num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch. @@ -103,23 +122,16 @@ def get_low_latency_buffer(self): self.ep_size, self.num_experts, ) - # Allocate a buffer if not existed or not enough buffer size - if ( - self.deepep_engine is None - or self.deepep_engine.group != self.group - or not self.deepep_engine.low_latency_mode - or self.deepep_engine.num_rdma_bytes < num_rdma_bytes - ): - # NOTES: for best performance, the QP number **must** be equal to the number of the local experts - assert self.num_experts % self.ep_size == 0 - self.deepep_engine = deep_ep.Buffer( - self.group, - 0, - num_rdma_bytes, - self.num_experts, - low_latency_mode=True, - num_qps_per_rank=self.num_experts // self.num_ranks, - ) + # NOTES: for best performance, the QP number **must** be equal to the number of the local experts + assert self.num_experts % self.ep_size == 0 + self.deepep_engine = deep_ep.Buffer( + self.group, + 0, + num_rdma_bytes, + self.num_experts, + low_latency_mode=True, + num_qps_per_rank=self.num_experts // self.ep_size, + ) def low_latency_dispatch( self, @@ -172,7 +184,6 @@ def low_latency_combine( handle, ): """ - Return: combined_hidden_states: [num_tokens, hidden_size] """ @@ -192,12 +203,6 @@ def clean_low_latency_buffer(self): """ pass - def barrier_all(self): - """ - barrier_all - """ - self.deepep_engine.barrier_all() - class XPUEPRunner: """ @@ -227,10 +232,15 @@ def __init__( self.ep_rank = ep_rank self.redundant_experts_num = redundant_experts_num self.ep_group = ep_group + self.ep_engine = None self.init_ep_engine() def init_ep_engine(self): - self.ep_engine = DeepEPEngine( + """Initialize the EP engine with default implementation""" + self._init_ep_engine(self._get_engine_class()) + + def _init_ep_engine(self, engine_class): + self.ep_engine = engine_class( num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank, hidden_size=self.hidden_size, num_experts=self.num_experts + self.redundant_experts_num, @@ -241,6 +251,11 @@ def init_ep_engine(self): group=self.ep_group, ) + @abstractmethod + def _get_engine_class(self): + """Get the engine class to be initialized""" + raise NotImplementedError("Subclasses must implement this method") + def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): """ moe_select @@ -325,6 +340,9 @@ def __init__( ep_group=ep_group, ) + def _get_engine_class(self): + return DeepEPEngineHighThroughput + def dispatch( self, x: paddle.Tensor, @@ -389,6 +407,9 @@ def __init__( ep_group=ep_group, ) + def _get_engine_class(self): + return DeepEPEngineLowLatency + def dispatch( self, x: paddle.Tensor, diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index b60eb8cdfc9..f371ead5963 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -408,6 +408,19 @@ def __init__( # Forward meta store the global meta information of the forward self.forward_meta: ForwardMeta = None + # Initialize shared memory and barrier for only_decode optimization + if self.fd_config.parallel_config.use_ep: + import multiprocessing + from threading import Barrier + + # Create shared memory list for all processes + group_size = self.parallel_config.expert_parallel_size + self.shared_only_decode_list = [multiprocessing.Value("i", 0) for _ in range(group_size)] + self.shared_not_need_stop_list = [multiprocessing.Value("i", 0) for _ in range(group_size)] + + # Create barrier for synchronization with timeout + self.decode_barrier = Barrier(group_size, timeout=10.0) + self.pd_disaggregation_mode: str = self.fd_config.parallel_config.pd_disaggregation_mode def exist_prefill(self): @@ -419,6 +432,55 @@ def exist_prefill(self): else: return 0 + def only_decode(self): + """ + check whether decode only using shared memory and barrier for all devices + """ + # Use shared memory to avoid d2h copy + if hasattr(self, "shared_only_decode_list") and self.fd_config.parallel_config.use_ep: + try: + world_size = self.parallel_config.expert_parallel_size + rank = self.rank % world_size + + # Combined check in one Barrier round + no_need_stop = self.not_need_stop() + self.shared_not_need_stop_list[rank].value = 1 if not no_need_stop else 0 + self.shared_only_decode_list[rank].value = self.forward_meta.len_info_cpu[0] <= 0 + + # Single Barrier for both checks + self.decode_barrier.wait() + + if_all_device_empty = all(p.value == 1 for p in self.shared_not_need_stop_list) + if_only_decode = all(p.value for p in self.shared_only_decode_list) + + # Single Barrier for reset + self.decode_barrier.wait() + + return False if if_all_device_empty else if_only_decode + except Exception as e: + logger.warning(f"Shared memory only_decode failed: {e}, fallback to original implementation") + + # Fallback to original implementation + if_only_decode = True + prefill_exists = None + if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": + no_need_stop_list = [] + no_need_stop = self.not_need_stop() + paddle.distributed.all_gather_object(no_need_stop_list, not no_need_stop) + if_all_device_empty = all(no_need_stop_list) + if if_all_device_empty: + if_only_decode = False + else: + only_decode_batch_list = [] + prefill_exists = self.exist_prefill() + paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) + if_only_decode = all(only_decode_batch_list) + + if_only_decode = if_only_decode and not ( + prefill_exists if prefill_exists is not None else self.exist_prefill() + ) + return if_only_decode + def insert_tasks_v1(self, req_dicts: List[Request]): """ Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1 @@ -936,8 +998,16 @@ def _prepare_inputs(self, is_dummy_run=False) -> None: self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"] self.forward_meta.attn_backend = self.attn_backends[0] self.initialize_attention_backend() + if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query": self.forward_meta.kv_signal_sender = self.kv_signal_sender + + if_only_decode = self.only_decode() + if ( + self.fd_config.scheduler_config.splitwise_role == "mixed" + ): # Centralized scenario: the phase is initialized as "prefill" by default. During inference runtime, different types of batches can achieve phase switching at this point. + self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill" + # Get sampling metadata # TODU(lilujia): sync with GPU self.sampling_metadata = SamplingMetadata( diff --git a/scripts/run_ci_xpu.sh b/scripts/run_ci_xpu.sh index f4c217b3543..1741bd871ee 100644 --- a/scripts/run_ci_xpu.sh +++ b/scripts/run_ci_xpu.sh @@ -317,6 +317,7 @@ export BKCL_PCIE_RING=1 export XSHMEM_MODE=1 export XSHMEM_QP_NUM_PER_RANK=32 export BKCL_RDMA_VERBS=1 +export MOE_FFN_USE_DENSE_INPUT=1 wget -q https://paddle-qa.bj.bcebos.com/xpu_third_party/xDeepEP.tar.gz tar -xzf xDeepEP.tar.gz @@ -383,6 +384,7 @@ unset BKCL_PCIE_RING unset XSHMEM_MODE unset XSHMEM_QP_NUM_PER_RANK unset BKCL_RDMA_VERBS +unset MOE_FFN_USE_DENSE_INPUT stop_processes >kill.log 2>&1 if [ ${ep_online_exit_code} -ne 0 ]; then @@ -412,6 +414,7 @@ export BKCL_PCIE_RING=1 export XSHMEM_MODE=1 export XSHMEM_QP_NUM_PER_RANK=32 export BKCL_RDMA_VERBS=1 +export MOE_FFN_USE_DENSE_INPUT=1 export port_num=$((8188 + XPU_ID * 100)) # 启动服务 @@ -469,6 +472,7 @@ unset BKCL_PCIE_RING unset XSHMEM_MODE unset XSHMEM_QP_NUM_PER_RANK unset BKCL_RDMA_VERBS +unset MOE_FFN_USE_DENSE_INPUT stop_processes >kill.log 2>&1 if [ ${ep_online_exit_code} -ne 0 ]; then @@ -499,6 +503,7 @@ export BKCL_PCIE_RING=1 export XSHMEM_MODE=1 export XSHMEM_QP_NUM_PER_RANK=32 export BKCL_RDMA_VERBS=1 +export MOE_FFN_USE_DENSE_INPUT=1 export port_num=$((8188 + XPU_ID * 100)) # 启动服务 @@ -558,6 +563,7 @@ unset BKCL_PCIE_RING unset XSHMEM_MODE unset XSHMEM_QP_NUM_PER_RANK unset BKCL_RDMA_VERBS +unset MOE_FFN_USE_DENSE_INPUT stop_processes >kill.log 2>&1 if [ ${ep_online_exit_code} -ne 0 ]; then