115 changes: 68 additions & 47 deletions fastdeploy/model_executor/layers/backends/xpu/moe/ep.py
@@ -19,17 +19,15 @@
import deep_ep
import paddle
from paddle import nn
from paddleformers.utils.log import logger

import fastdeploy
from fastdeploy.config import MoEPhase
from fastdeploy.utils import singleton


@singleton
class DeepEPEngine:
class DeepEPEngineBase:
"""
A wrapper class for DeepEP engine.
Base class for DeepEP engine implementations.
"""

def __init__(
@@ -45,7 +43,7 @@ def __init__(
group=None,
):
"""
Initialize the DeepEP engine.
Initialize the DeepEP engine base.
Args:
group: The MPI group object.
ep_size: The number of ranks.
@@ -68,27 +66,48 @@ def __init__(
self.group = group
self.num_local_experts = num_experts // ep_size
self.deepep_engine = None
self.init_deepep_engine()

def init_deepep_engine(self):
if self.splitwise_role == "mixed" or self.moe_phase.phase == "prefill":
self.deepep_engine = deep_ep.Buffer(
self.group,
int(1e9),
0,
num_experts=self.num_experts,
low_latency_mode=False,
num_qps_per_rank=1,
)
elif self.moe_phase.phase == "decode":
logger.info("Initializing Low Latency Buffer")
self.get_low_latency_buffer()

def barrier_all(self):
"""
Synchronize all ranks via the underlying DeepEP buffer.
"""
if self.deepep_engine is not None:
self.deepep_engine.barrier_all()
else:
raise ValueError(f"Unknown generation phase {self.moe_phase}")
raise RuntimeError("The deepep engine has not been initialized yet.")


@singleton
class DeepEPEngineHighThroughput(DeepEPEngineBase):
"""
High-throughput DeepEP engine variant for the prefill phase.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.deepep_engine = deep_ep.Buffer(
self.group,
int(1e9),
0,
num_experts=self.num_experts,
low_latency_mode=False,
num_qps_per_rank=1,
)


@singleton
class DeepEPEngineLowLatency(DeepEPEngineBase):
"""
Low-latency DeepEP engine variant for the decode phase.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.get_low_latency_buffer()

def get_low_latency_buffer(self):
"""
Get the DeepEP buffer.
Initialize the low-latency buffer for the decode phase.
Args:
group: The MPI group object.
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
@@ -103,23 +122,16 @@ def get_low_latency_buffer(self):
self.ep_size,
self.num_experts,
)
# Allocate a buffer if not existed or not enough buffer size
if (
self.deepep_engine is None
or self.deepep_engine.group != self.group
or not self.deepep_engine.low_latency_mode
or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
):
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
assert self.num_experts % self.ep_size == 0
self.deepep_engine = deep_ep.Buffer(
self.group,
0,
num_rdma_bytes,
self.num_experts,
low_latency_mode=True,
num_qps_per_rank=self.num_experts // self.num_ranks,
)
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
assert self.num_experts % self.ep_size == 0
self.deepep_engine = deep_ep.Buffer(
self.group,
0,
num_rdma_bytes,
self.num_experts,
low_latency_mode=True,
num_qps_per_rank=self.num_experts // self.ep_size,
)

def low_latency_dispatch(
self,
@@ -172,7 +184,6 @@ def low_latency_combine(
handle,
):
"""

Return:
combined_hidden_states: [num_tokens, hidden_size]
"""
@@ -192,12 +203,6 @@ def clean_low_latency_buffer(self):
"""
pass

def barrier_all(self):
"""
barrier_all
"""
self.deepep_engine.barrier_all()


class XPUEPRunner:
"""
@@ -227,10 +232,15 @@ def __init__(
self.ep_rank = ep_rank
self.redundant_experts_num = redundant_experts_num
self.ep_group = ep_group
self.ep_engine = None
self.init_ep_engine()

def init_ep_engine(self):
self.ep_engine = DeepEPEngine(
"""Initialize the EP engine with default implementation"""
self._init_ep_engine(self._get_engine_class())

def _init_ep_engine(self, engine_class):
self.ep_engine = engine_class(
num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank,
hidden_size=self.hidden_size,
num_experts=self.num_experts + self.redundant_experts_num,
@@ -241,6 +251,11 @@ def init_ep_engine(self):
group=self.ep_group,
)

@abstractmethod
def _get_engine_class(self):
"""Get the engine class to be initialized"""
raise NotImplementedError("Subclasses must implement this method")

def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
"""
moe_select
@@ -325,6 +340,9 @@ def __init__(
ep_group=ep_group,
)

def _get_engine_class(self):
return DeepEPEngineHighThroughput

def dispatch(
self,
x: paddle.Tensor,
@@ -389,6 +407,9 @@ def __init__(
ep_group=ep_group,
)

def _get_engine_class(self):
return DeepEPEngineLowLatency

def dispatch(
self,
x: paddle.Tensor,
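A minimal sketch of the engine-selection pattern introduced above: the runner defers engine construction to an abstract `_get_engine_class()` hook, and each phase-specific runner returns its engine variant. The class names, constructor arguments, and stand-in buffers below are illustrative only, not the real deep_ep or FastDeploy signatures.

```python
# Simplified sketch of the engine-selection pattern; names, arguments, and the
# stand-in buffers are illustrative, not real deep_ep/FastDeploy APIs.
from abc import ABC, abstractmethod


class EngineBase:
    """Shared state only; concrete subclasses allocate their own buffer."""

    def __init__(self, num_experts: int, ep_size: int):
        self.num_experts = num_experts
        self.ep_size = ep_size
        self.buffer = None

    def barrier_all(self):
        if self.buffer is None:
            raise RuntimeError("The engine has not been initialized yet.")


class HighThroughputEngine(EngineBase):
    """Prefill-phase variant (low_latency_mode=False in the real engine)."""

    def __init__(self, num_experts: int, ep_size: int):
        super().__init__(num_experts, ep_size)
        self.buffer = object()  # stands in for a deep_ep-style buffer


class LowLatencyEngine(EngineBase):
    """Decode-phase variant; QPs per rank must match the local expert count."""

    def __init__(self, num_experts: int, ep_size: int):
        super().__init__(num_experts, ep_size)
        assert num_experts % ep_size == 0
        self.buffer = object()  # stands in for a deep_ep-style buffer


class Runner(ABC):
    def __init__(self, num_experts: int, ep_size: int):
        # The base runner builds whatever engine class the subclass names.
        self.engine = self._get_engine_class()(num_experts, ep_size)

    @abstractmethod
    def _get_engine_class(self):
        """Return the engine class for this runner's phase."""


class PrefillRunner(Runner):
    def _get_engine_class(self):
        return HighThroughputEngine


class DecodeRunner(Runner):
    def _get_engine_class(self):
        return LowLatencyEngine


assert isinstance(PrefillRunner(64, 8).engine, HighThroughputEngine)
assert isinstance(DecodeRunner(64, 8).engine, LowLatencyEngine)
```

Splitting the single `@singleton` DeepEPEngine into two singleton subclasses presumably lets a mixed deployment hold both buffers at once, rather than switching one engine between phases.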
70 changes: 70 additions & 0 deletions fastdeploy/worker/xpu_model_runner.py
@@ -408,6 +408,19 @@ def __init__(
# Forward meta store the global meta information of the forward
self.forward_meta: ForwardMeta = None

# Initialize shared memory and barrier for only_decode optimization
if self.fd_config.parallel_config.use_ep:
import multiprocessing
from threading import Barrier

# Create shared memory list for all processes
group_size = self.parallel_config.expert_parallel_size
self.shared_only_decode_list = [multiprocessing.Value("i", 0) for _ in range(group_size)]
self.shared_not_need_stop_list = [multiprocessing.Value("i", 0) for _ in range(group_size)]

# Create barrier for synchronization with timeout
self.decode_barrier = Barrier(group_size, timeout=10.0)

self.pd_disaggregation_mode: str = self.fd_config.parallel_config.pd_disaggregation_mode

def exist_prefill(self):
@@ -419,6 +432,55 @@ def exist_prefill(self):
else:
return 0

def only_decode(self):
"""
Check whether the current step is decode-only, using shared memory and a barrier across all devices.
"""
# Use shared memory to avoid d2h copy
if hasattr(self, "shared_only_decode_list") and self.fd_config.parallel_config.use_ep:
try:
world_size = self.parallel_config.expert_parallel_size
rank = self.rank % world_size

# Combined check in one Barrier round
no_need_stop = self.not_need_stop()
self.shared_not_need_stop_list[rank].value = 1 if not no_need_stop else 0
self.shared_only_decode_list[rank].value = self.forward_meta.len_info_cpu[0] <= 0

# Single Barrier for both checks
self.decode_barrier.wait()

if_all_device_empty = all(p.value == 1 for p in self.shared_not_need_stop_list)
if_only_decode = all(p.value for p in self.shared_only_decode_list)

# Single Barrier for reset
self.decode_barrier.wait()

return False if if_all_device_empty else if_only_decode
except Exception as e:
logger.warning(f"Shared memory only_decode failed: {e}, fallback to original implementation")

# Fallback to original implementation
if_only_decode = True
prefill_exists = None
if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
no_need_stop_list = []
no_need_stop = self.not_need_stop()
paddle.distributed.all_gather_object(no_need_stop_list, not no_need_stop)
if_all_device_empty = all(no_need_stop_list)
if if_all_device_empty:
if_only_decode = False
else:
only_decode_batch_list = []
prefill_exists = self.exist_prefill()
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
if_only_decode = all(only_decode_batch_list)

if_only_decode = if_only_decode and not (
prefill_exists if prefill_exists is not None else self.exist_prefill()
)
return if_only_decode

def insert_tasks_v1(self, req_dicts: List[Request]):
"""
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
@@ -936,8 +998,16 @@ def _prepare_inputs(self, is_dummy_run=False) -> None:
self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"]
self.forward_meta.attn_backend = self.attn_backends[0]
self.initialize_attention_backend()

if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
self.forward_meta.kv_signal_sender = self.kv_signal_sender

if_only_decode = self.only_decode()
if (
self.fd_config.scheduler_config.splitwise_role == "mixed"
): # Centralized (mixed) scenario: the phase defaults to "prefill"; at runtime it is switched here according to the current batch type.
self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill"

# Get sampling metadata
# TODO(lilujia): sync with GPU
self.sampling_metadata = SamplingMetadata(
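For context on the shared-memory path in `only_decode()` above, here is a standalone sketch of the flag-exchange pattern: each rank writes its local state into a shared `multiprocessing.Value`, all ranks meet at a barrier before reading, and a second barrier keeps any rank from overwriting the flags before everyone has read them. The driver code, rank count, and the use of `multiprocessing.Barrier` (rather than `threading.Barrier`) are assumptions made for a self-contained multi-process demo, not the runner's actual wiring.

```python
# Standalone sketch of the barrier-synchronized flag exchange; the process
# setup below is illustrative and does not mirror the FastDeploy runner.
import multiprocessing as mp


def only_decode_step(rank: int, flags, barrier, has_prefill: bool) -> None:
    # Each rank publishes whether its local batch is decode-only.
    flags[rank].value = 0 if has_prefill else 1
    barrier.wait()  # all flags are written before anyone reads
    group_only_decode = all(f.value == 1 for f in flags)
    barrier.wait()  # all reads are done before flags may be reused
    print(f"rank {rank}: only_decode={group_only_decode}")


if __name__ == "__main__":
    world_size = 4
    flags = [mp.Value("i", 0) for _ in range(world_size)]
    barrier = mp.Barrier(world_size)
    # Rank 2 still has a prefill batch, so the group is not decode-only.
    local_has_prefill = [False, False, True, False]
    procs = [
        mp.Process(target=only_decode_step, args=(r, flags, barrier, local_has_prefill[r]))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```

The two `wait()` calls mirror the two barrier rounds in the diff; the fallback branch uses `paddle.distributed.all_gather_object` instead, which the added comment notes involves an extra device-to-host copy.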
6 changes: 6 additions & 0 deletions scripts/run_ci_xpu.sh
@@ -317,6 +317,7 @@ export BKCL_PCIE_RING=1
export XSHMEM_MODE=1
export XSHMEM_QP_NUM_PER_RANK=32
export BKCL_RDMA_VERBS=1
export MOE_FFN_USE_DENSE_INPUT=1

wget -q https://paddle-qa.bj.bcebos.com/xpu_third_party/xDeepEP.tar.gz
tar -xzf xDeepEP.tar.gz
@@ -383,6 +384,7 @@ unset BKCL_PCIE_RING
unset XSHMEM_MODE
unset XSHMEM_QP_NUM_PER_RANK
unset BKCL_RDMA_VERBS
unset MOE_FFN_USE_DENSE_INPUT
stop_processes >kill.log 2>&1

if [ ${ep_online_exit_code} -ne 0 ]; then
@@ -412,6 +414,7 @@ export BKCL_PCIE_RING=1
export XSHMEM_MODE=1
export XSHMEM_QP_NUM_PER_RANK=32
export BKCL_RDMA_VERBS=1
export MOE_FFN_USE_DENSE_INPUT=1

export port_num=$((8188 + XPU_ID * 100))
# Start the service
@@ -469,6 +472,7 @@ unset BKCL_PCIE_RING
unset XSHMEM_MODE
unset XSHMEM_QP_NUM_PER_RANK
unset BKCL_RDMA_VERBS
unset MOE_FFN_USE_DENSE_INPUT
stop_processes >kill.log 2>&1

if [ ${ep_online_exit_code} -ne 0 ]; then
@@ -499,6 +503,7 @@ export BKCL_PCIE_RING=1
export XSHMEM_MODE=1
export XSHMEM_QP_NUM_PER_RANK=32
export BKCL_RDMA_VERBS=1
export MOE_FFN_USE_DENSE_INPUT=1

export port_num=$((8188 + XPU_ID * 100))
# Start the service
@@ -558,6 +563,7 @@ unset BKCL_PCIE_RING
unset XSHMEM_MODE
unset XSHMEM_QP_NUM_PER_RANK
unset BKCL_RDMA_VERBS
unset MOE_FFN_USE_DENSE_INPUT
stop_processes >kill.log 2>&1

if [ ${ep_online_exit_code} -ne 0 ]; then