|
18 | 18 | #include <stdio.h> |
19 | 19 | #include <stdlib.h> |
20 | 20 | #include <string.h> |
| 21 | +#include <sys/ipc.h> |
21 | 22 | #include <sys/mman.h> |
| 23 | +#include <sys/msg.h> |
22 | 24 | #include <sys/stat.h> |
23 | 25 | #include <sys/types.h> |
24 | | -#include <sys/ipc.h> |
25 | | -#include <sys/msg.h> |
26 | 26 | #include <unistd.h> |
27 | 27 |
|
28 | 28 | #include "driver_types.h" |
| 29 | +#include "msg_utils.h" |
29 | 30 | #include "paddle/extension.h" |
30 | 31 | #include "paddle/phi/core/allocator.h" |
31 | 32 | #include "paddle/phi/core/dense_tensor.h" |
32 | | -#include "msg_utils.h" |
33 | 33 |
|
34 | 34 | struct RemoteCacheKvIpc { |
35 | | - struct save_cache_kv_complete_signal_layerwise_meta_data{ |
36 | | - int32_t layer_id=-1; |
37 | | - void * shm_ptr=nullptr; |
38 | | - int shm_fd=-1; |
39 | | - save_cache_kv_complete_signal_layerwise_meta_data(){} |
40 | | - save_cache_kv_complete_signal_layerwise_meta_data(int32_t layer_id_, |
41 | | - void* shm_ptr_, |
42 | | - int shm_fd_) |
43 | | - :layer_id(layer_id_), shm_ptr(shm_ptr_), shm_fd(shm_fd_){ |
44 | | - } |
45 | | - }; |
| 35 | + struct save_cache_kv_complete_signal_layerwise_meta_data { |
| 36 | + int32_t layer_id = -1; |
| 37 | + void* shm_ptr = nullptr; |
| 38 | + int shm_fd = -1; |
| 39 | + save_cache_kv_complete_signal_layerwise_meta_data() {} |
| 40 | + save_cache_kv_complete_signal_layerwise_meta_data(int32_t layer_id_, |
| 41 | + void* shm_ptr_, |
| 42 | + int shm_fd_) |
| 43 | + : layer_id(layer_id_), shm_ptr(shm_ptr_), shm_fd(shm_fd_) {} |
| 44 | + }; |
46 | 45 |
|
47 | | - struct save_cache_kv_complete_signal_layerwise_meta_data_per_query{ |
48 | | - int layer_id_; |
49 | | - int num_layers_; |
50 | | - bool inited = false; |
51 | | - struct msgdatakv msg_sed; |
52 | | - int msgid; |
| 46 | + struct save_cache_kv_complete_signal_layerwise_meta_data_per_query { |
| 47 | + int layer_id_; |
| 48 | + int num_layers_; |
| 49 | + bool inited = false; |
| 50 | + struct msgdatakv msg_sed; |
| 51 | + int msgid; |
53 | 52 |
|
54 | | - save_cache_kv_complete_signal_layerwise_meta_data_per_query(){} |
| 53 | + save_cache_kv_complete_signal_layerwise_meta_data_per_query() {} |
55 | 54 |
|
56 | | - void init(const int *seq_lens_encoder, |
57 | | - const int *seq_lens_decoder, |
58 | | - const int rank, |
59 | | - const int num_layers, |
60 | | - const int real_bsz) { |
61 | | - layer_id_ = 0; |
62 | | - num_layers_ = num_layers; |
63 | | - msg_sed.mtype = 1; |
64 | | - int encoder_count = 0; |
65 | | - for (int i = 0; i < real_bsz; i++) { |
66 | | - if (seq_lens_encoder[i] > 0) { |
67 | | - msg_sed.mtext[3 * encoder_count + 2] = i; |
68 | | - msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i]; |
69 | | - msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i]; |
70 | | - encoder_count++; |
71 | | - } |
72 | | - } |
73 | | - msg_sed.mtext[0] = encoder_count; |
74 | | - |
75 | | - if (!inited) { |
76 | | - // just init once |
77 | | - const int msg_id = 1024 + rank; |
78 | | - key_t key = ftok("/opt/", msg_id); |
79 | | - msgid = msgget(key, IPC_CREAT | 0666); |
80 | | - inited = true; |
81 | | - } |
| 55 | + void init(const int* seq_lens_encoder, |
| 56 | + const int* seq_lens_decoder, |
| 57 | + const int rank, |
| 58 | + const int num_layers, |
| 59 | + const int real_bsz) { |
| 60 | + layer_id_ = 0; |
| 61 | + num_layers_ = num_layers; |
| 62 | + msg_sed.mtype = 1; |
| 63 | + int encoder_count = 0; |
| 64 | + for (int i = 0; i < real_bsz; i++) { |
| 65 | + if (seq_lens_encoder[i] > 0) { |
| 66 | + msg_sed.mtext[3 * encoder_count + 2] = i; |
| 67 | + msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i]; |
| 68 | + msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i]; |
| 69 | + encoder_count++; |
82 | 70 | } |
| 71 | + } |
| 72 | + msg_sed.mtext[0] = encoder_count; |
| 73 | + |
| 74 | + if (!inited) { |
| 75 | + // just init once |
| 76 | + const int msg_id = 1024 + rank; |
| 77 | + key_t key = ftok("/opt/", msg_id); |
| 78 | + msgid = msgget(key, IPC_CREAT | 0666); |
| 79 | + inited = true; |
| 80 | + } |
| 81 | + } |
83 | 82 |
|
84 | | - void CUDART_CB send_signal() { |
85 | | - msg_sed.mtext[1] = layer_id_; |
86 | | - if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) { |
87 | | - printf("kv signal full msg buffer\n"); |
88 | | - } |
89 | | - layer_id_ = (layer_id_ + 1); |
90 | | - assert(layer_id_ <= num_layers_); |
| 83 | + void CUDART_CB send_signal() { |
| 84 | + if (inited) { |
| 85 | + msg_sed.mtext[1] = layer_id_; |
| 86 | + if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) { |
| 87 | + printf("kv signal full msg buffer\n"); |
91 | 88 | } |
92 | | - }; |
| 89 | + layer_id_ = (layer_id_ + 1); |
| 90 | + assert(layer_id_ <= num_layers_); |
| 91 | + } |
| 92 | + } |
| 93 | + }; |
93 | 94 |
|
94 | | - static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data kv_complete_signal_meta_data; |
95 | | - static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query kv_complete_signal_meta_data_per_query; |
96 | | - static void* kv_complete_signal_identity_ptr; |
97 | | - static bool kv_complete_signal_shmem_opened; |
| 95 | + static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data |
| 96 | + kv_complete_signal_meta_data; |
| 97 | + static RemoteCacheKvIpc:: |
| 98 | + save_cache_kv_complete_signal_layerwise_meta_data_per_query |
| 99 | + kv_complete_signal_meta_data_per_query; |
| 100 | + static void* kv_complete_signal_identity_ptr; |
| 101 | + static bool kv_complete_signal_shmem_opened; |
98 | 102 |
|
99 | | - static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data open_shm_and_get_complete_signal_meta_data( |
100 | | - const int rank_id, |
101 | | - const int device_id, |
102 | | - const bool keep_pd_step_flag); |
103 | | - static void CUDART_CB save_cache_kv_complete_signal_layerwise(void* meta_data); |
104 | | - static void CUDART_CB save_cache_kv_complete_signal_layerwise_per_query(void* meta_data); |
| 103 | + static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data |
| 104 | + open_shm_and_get_complete_signal_meta_data(const int rank_id, |
| 105 | + const int device_id, |
| 106 | + const bool keep_pd_step_flag); |
| 107 | + static void CUDART_CB |
| 108 | + save_cache_kv_complete_signal_layerwise(void* meta_data); |
| 109 | + static void CUDART_CB |
| 110 | + save_cache_kv_complete_signal_layerwise_per_query(void* meta_data); |
105 | 111 | }; |
0 commit comments