Skip to content

Commit 43097a5

Browse files
[BugFix] [PD Disaggregation] fix v1 scheduler prefill node profile run & ipc transfer protocol (#5132)
* [fix] fix v1 scheduler profile run for append attention in prefill node * [fix] skip send_signal if kv signal not inited for gpu and xpu * [fix] extend fix to flash_attn & mla_attn * [fix] fix v1 pd run in ipc transfer protocol * [ci] add test for v1 pd profile run using ipc transfer protocol * [style] fix code style check * [style] fix code style again * [fix] fix profile run * [update] remove --num-gpu-blocks-override in example script * [chore] rename forward_meta is_profiling to is_dummy_or_profile_run
1 parent 01c30f6 commit 43097a5

File tree

12 files changed

+510
-92
lines changed

12 files changed

+510
-92
lines changed

custom_ops/gpu_ops/remote_cache_kv_ipc.h

Lines changed: 71 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -18,88 +18,94 @@
1818
#include <stdio.h>
1919
#include <stdlib.h>
2020
#include <string.h>
21+
#include <sys/ipc.h>
2122
#include <sys/mman.h>
23+
#include <sys/msg.h>
2224
#include <sys/stat.h>
2325
#include <sys/types.h>
24-
#include <sys/ipc.h>
25-
#include <sys/msg.h>
2626
#include <unistd.h>
2727

2828
#include "driver_types.h"
29+
#include "msg_utils.h"
2930
#include "paddle/extension.h"
3031
#include "paddle/phi/core/allocator.h"
3132
#include "paddle/phi/core/dense_tensor.h"
32-
#include "msg_utils.h"
3333

3434
struct RemoteCacheKvIpc {
35-
struct save_cache_kv_complete_signal_layerwise_meta_data{
36-
int32_t layer_id=-1;
37-
void * shm_ptr=nullptr;
38-
int shm_fd=-1;
39-
save_cache_kv_complete_signal_layerwise_meta_data(){}
40-
save_cache_kv_complete_signal_layerwise_meta_data(int32_t layer_id_,
41-
void* shm_ptr_,
42-
int shm_fd_)
43-
:layer_id(layer_id_), shm_ptr(shm_ptr_), shm_fd(shm_fd_){
44-
}
45-
};
35+
struct save_cache_kv_complete_signal_layerwise_meta_data {
36+
int32_t layer_id = -1;
37+
void* shm_ptr = nullptr;
38+
int shm_fd = -1;
39+
save_cache_kv_complete_signal_layerwise_meta_data() {}
40+
save_cache_kv_complete_signal_layerwise_meta_data(int32_t layer_id_,
41+
void* shm_ptr_,
42+
int shm_fd_)
43+
: layer_id(layer_id_), shm_ptr(shm_ptr_), shm_fd(shm_fd_) {}
44+
};
4645

47-
struct save_cache_kv_complete_signal_layerwise_meta_data_per_query{
48-
int layer_id_;
49-
int num_layers_;
50-
bool inited = false;
51-
struct msgdatakv msg_sed;
52-
int msgid;
46+
struct save_cache_kv_complete_signal_layerwise_meta_data_per_query {
47+
int layer_id_;
48+
int num_layers_;
49+
bool inited = false;
50+
struct msgdatakv msg_sed;
51+
int msgid;
5352

54-
save_cache_kv_complete_signal_layerwise_meta_data_per_query(){}
53+
save_cache_kv_complete_signal_layerwise_meta_data_per_query() {}
5554

56-
void init(const int *seq_lens_encoder,
57-
const int *seq_lens_decoder,
58-
const int rank,
59-
const int num_layers,
60-
const int real_bsz) {
61-
layer_id_ = 0;
62-
num_layers_ = num_layers;
63-
msg_sed.mtype = 1;
64-
int encoder_count = 0;
65-
for (int i = 0; i < real_bsz; i++) {
66-
if (seq_lens_encoder[i] > 0) {
67-
msg_sed.mtext[3 * encoder_count + 2] = i;
68-
msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i];
69-
msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i];
70-
encoder_count++;
71-
}
72-
}
73-
msg_sed.mtext[0] = encoder_count;
74-
75-
if (!inited) {
76-
// just init once
77-
const int msg_id = 1024 + rank;
78-
key_t key = ftok("/opt/", msg_id);
79-
msgid = msgget(key, IPC_CREAT | 0666);
80-
inited = true;
81-
}
55+
void init(const int* seq_lens_encoder,
56+
const int* seq_lens_decoder,
57+
const int rank,
58+
const int num_layers,
59+
const int real_bsz) {
60+
layer_id_ = 0;
61+
num_layers_ = num_layers;
62+
msg_sed.mtype = 1;
63+
int encoder_count = 0;
64+
for (int i = 0; i < real_bsz; i++) {
65+
if (seq_lens_encoder[i] > 0) {
66+
msg_sed.mtext[3 * encoder_count + 2] = i;
67+
msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i];
68+
msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i];
69+
encoder_count++;
8270
}
71+
}
72+
msg_sed.mtext[0] = encoder_count;
73+
74+
if (!inited) {
75+
// just init once
76+
const int msg_id = 1024 + rank;
77+
key_t key = ftok("/opt/", msg_id);
78+
msgid = msgget(key, IPC_CREAT | 0666);
79+
inited = true;
80+
}
81+
}
8382

84-
void CUDART_CB send_signal() {
85-
msg_sed.mtext[1] = layer_id_;
86-
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) {
87-
printf("kv signal full msg buffer\n");
88-
}
89-
layer_id_ = (layer_id_ + 1);
90-
assert(layer_id_ <= num_layers_);
83+
void CUDART_CB send_signal() {
84+
if (inited) {
85+
msg_sed.mtext[1] = layer_id_;
86+
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) {
87+
printf("kv signal full msg buffer\n");
9188
}
92-
};
89+
layer_id_ = (layer_id_ + 1);
90+
assert(layer_id_ <= num_layers_);
91+
}
92+
}
93+
};
9394

94-
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data kv_complete_signal_meta_data;
95-
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query kv_complete_signal_meta_data_per_query;
96-
static void* kv_complete_signal_identity_ptr;
97-
static bool kv_complete_signal_shmem_opened;
95+
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
96+
kv_complete_signal_meta_data;
97+
static RemoteCacheKvIpc::
98+
save_cache_kv_complete_signal_layerwise_meta_data_per_query
99+
kv_complete_signal_meta_data_per_query;
100+
static void* kv_complete_signal_identity_ptr;
101+
static bool kv_complete_signal_shmem_opened;
98102

99-
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data open_shm_and_get_complete_signal_meta_data(
100-
const int rank_id,
101-
const int device_id,
102-
const bool keep_pd_step_flag);
103-
static void CUDART_CB save_cache_kv_complete_signal_layerwise(void* meta_data);
104-
static void CUDART_CB save_cache_kv_complete_signal_layerwise_per_query(void* meta_data);
103+
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
104+
open_shm_and_get_complete_signal_meta_data(const int rank_id,
105+
const int device_id,
106+
const bool keep_pd_step_flag);
107+
static void CUDART_CB
108+
save_cache_kv_complete_signal_layerwise(void* meta_data);
109+
static void CUDART_CB
110+
save_cache_kv_complete_signal_layerwise_per_query(void* meta_data);
105111
};

custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,14 @@ struct RemoteCacheKvIpc {
7272
}
7373

7474
void send_signal() {
75-
msg_sed.mtext[1] = layer_id_;
76-
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) {
77-
printf("kv signal full msg buffer\n");
75+
if (inited) {
76+
msg_sed.mtext[1] = layer_id_;
77+
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) {
78+
printf("kv signal full msg buffer\n");
79+
}
80+
layer_id_ = (layer_id_ + 1);
81+
assert(layer_id_ <= num_layers_);
7882
}
79-
layer_id_ = (layer_id_ + 1);
80-
assert(layer_id_ <= num_layers_);
8183
}
8284
};
8385

examples/splitwise/start_v1_tp1.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ nohup python -m fastdeploy.entrypoints.openai.api_server \
6868
--cache-transfer-protocol "rdma" \
6969
--rdma-comm-ports "$((P_PORT + 4))" \
7070
--pd-comm-port "$((P_PORT + 5))" \
71-
--num-gpu-blocks-override 2000 \
7271
--router "0.0.0.0:${ROUTER_PORT}" \
7372
2>&1 >${FD_LOG_DIR}/nohup &
7473

fastdeploy/cache_manager/cache_messager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,8 +687,8 @@ def prefill_layerwise_send_cache_thread(self):
687687
for engine_idx, _ in batch_engine_signals:
688688
task = self.idx_cache_task_dict[engine_idx]
689689
if task["status"] == "finished" or ("error" in task["status"]):
690-
target_id = int(task["rdma_ports"][self.rank])
691690
if task["transfer_protocol"] == "ipc":
691+
target_id = int(task["device_ids"][self.rank])
692692
self.messager["ipc"].write_block_by_sync(target_id)
693693
self.engine_worker_queue.finish_send_cache_barrier.wait()
694694
self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])

fastdeploy/engine/args_utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -517,18 +517,6 @@ def __post_init__(self):
517517
f"The number of rdma comm ports must be equal to number of ranks ({self.data_parallel_size=} * {self.tensor_parallel_size=} = {self.data_parallel_size * self.tensor_parallel_size}), but got {len(self.rdma_comm_ports)}."
518518
)
519519

520-
if envs.ENABLE_V1_KVCACHE_SCHEDULER == 1:
521-
if "ipc" in self.cache_transfer_protocol:
522-
# FIXME: support ipc cache transfer protocol
523-
raise NotImplementedError(
524-
"only support rdma cache transfer protocol " "when using ENABLE_V1_KVCACHE_SCHEDULER."
525-
)
526-
# FIXME: fix this bug
527-
if self.splitwise_role == "prefill" and self.num_gpu_blocks_override is None:
528-
raise NotImplementedError(
529-
"please set num_gpu_blocks_override for prefill " "instance using ENABLE_V1_KVCACHE_SCHEDULER."
530-
)
531-
532520
if not current_platform.is_cuda() and not current_platform.is_xpu():
533521
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
534522
if self.guided_decoding_backend != "off":

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,7 @@ def preallocate_resource_in_d(self, request: Request):
10011001
request.need_prefill_tokens + self.config.cache_config.block_size - 1
10021002
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num
10031003
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
1004-
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
1004+
request.block_tables = self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks)
10051005
request.num_computed_tokens = request.need_prefill_tokens
10061006
request.disaggregate_info["block_tables"] = request.block_tables
10071007
allocated_position = self.get_available_position()

fastdeploy/model_executor/forward_meta.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ class ForwardMeta:
140140
block_tables: Optional[paddle.Tensor] = None
141141
# KV caches
142142
caches: Optional[list[paddle.Tensor]] = None
143+
# Flag of profile run
144+
is_dummy_or_profile_run: bool = False
143145

144146
def clear_caches(self):
145147
"""Safely clean up the caches"""

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
178178
# pd_disaggregation
179179
metadata.kv_signal_data_list = [None] * self.num_layers
180180
if self.pd_disaggregation_mode == "per_chunk":
181-
if not self.keep_pd_step_flag:
181+
if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run:
182182
init_kv_signal_per_query(
183183
forward_meta.seq_lens_encoder,
184184
forward_meta.seq_lens_this_time,

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
231231
# pd_disaggregation
232232
metadata.kv_signal_data_list = [None] * self.num_layers
233233
if self.pd_disaggregation_mode == "per_chunk":
234-
if not self.keep_pd_step_flag:
234+
if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run:
235235
init_kv_signal_per_query(
236236
forward_meta.seq_lens_encoder,
237237
forward_meta.seq_lens_this_time,

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
214214
# pd_disaggregation
215215
metadata.kv_signal_data_list = [None] * self.num_layers
216216
if self.pd_disaggregation_mode == "per_chunk":
217-
if not self.keep_pd_step_flag:
217+
if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run:
218218
init_kv_signal_per_query(
219219
forward_meta.seq_lens_encoder,
220220
forward_meta.seq_lens_this_time,

0 commit comments

Comments
 (0)