@@ -142,19 +142,20 @@ bool WorkerImpl::allocate_host_kv_cache(
 
   CHECK(model_ != nullptr) << "Model is not initialized.";
   CHECK(host_kv_caches_.empty()) << "KV caches are already initialized.";
+  CHECK(device_kv_cache_shape[0][0] == device_kv_cache_shape[1][0]);
 
   std::vector<std::vector<int64_t>> host_kv_cache_shape = device_kv_cache_shape;
-  for (auto& shape : host_kv_cache_shape) {
-    if (!shape.empty()) {
-      shape[0] *= options_.host_blocks_factor();
-    }
-  }
-
-  // create a KVCache for each layer
   const int64_t num_layers = context_.get_model_args().n_layers();
-  host_kv_caches_.reserve(num_layers);
-  for (int64_t i = 0; i < num_layers; ++i) {
-    torch::Tensor key_cache, value_cache, index_cache;
+  int64_t host_block_size =
+      device_kv_cache_shape[0][0] * options_.host_blocks_factor();
+  host_kv_cache_shape[0][0] = num_layers;
+  host_kv_cache_shape[1][0] = num_layers;
+
+  // create host KV caches: host_block_size blocks of [layers, token, head, dim]
+  host_kv_caches_.reserve(host_block_size);
+
+  for (int64_t i = 0; i < host_block_size; ++i) {
+    torch::Tensor key_cache, value_cache;
     key_cache = torch::empty(host_kv_cache_shape[0],
                              torch::dtype(dtype_).device(torch::kCPU))
                     .pin_memory();
@@ -163,8 +164,7 @@ bool WorkerImpl::allocate_host_kv_cache(
                     .pin_memory();
     host_kv_caches_.emplace_back(key_cache, value_cache);
   }
-  LOG(INFO) << "Initializing host k cache size: " << host_kv_cache_shape[0][0];
-  LOG(INFO) << "Initializing host v cache size: " << host_kv_cache_shape[1][0];
+  LOG(INFO) << "Initializing host kv block size: " << host_block_size;
 
   int32_t device_id = device_.index();
   h2d_attrs_.dstLoc.id = device_id;
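Note on the new host-side layout (a sketch inferred from this diff, not code in the PR): the device caches remain one K/V tensor per layer whose first dimension indexes device blocks, while each host cache entry now packs all layers of one block, so the first dimension of a host tensor indexes layers. Assuming the shapes implied by the CHECK added above:

// device: kv_caches_[layer]      -> K/V tensor of shape [num_device_blocks, token, head, dim]
// host:   host_kv_caches_[block] -> K/V tensor of shape [num_layers,        token, head, dim]
int64_t num_device_blocks = device_kv_cache_shape[0][0];
int64_t host_block_size = num_device_blocks * options_.host_blocks_factor();
// Copying device block `src` of layer `l` into host block `dst` therefore means:
//   host_kv_caches_[dst].get_k_cache()[l]  <-  kv_caches_[l].get_k_cache()[src]
// which is exactly the indexing the d2h/h2d copy loops below switch to.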
@@ -701,22 +701,8 @@ uint32_t WorkerImpl::transfer_kv_blocks(
 
   switch (block_transfer_info[0].transfer_type) {
     case TransferType::G2H: {
-      folly::Promise<uint32_t> promise;
-      auto future = promise.getSemiFuture();
-
-      batchget_threadpool_.schedule(
-          [this, &block_transfer_info, promise = std::move(promise)]() mutable {
-            promise.setValue(
-                KVCacheStore::get_instance().batch_get(block_transfer_info));
-          });
-
-      try {
-        auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-        return std::move(future).wait(timeout);
-      } catch (const folly::FutureTimeout& e) {
-        LOG(WARNING) << "BatchGet operation timed out";
-        return 0;
-      }
+      Slice<BlockTransferInfo> info_slice{block_transfer_info};
+      return load_from_store(info_slice);
     }
     case TransferType::D2G:
       return offload_kv_blocks(block_transfer_info);
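The G2H branch now just wraps the transfer infos in a Slice and defers to the new load_from_store() helper added at the end of this diff. Slice<T> is a project type; from its use here and in the copy helpers (constructed from a container, iterated with a range-for, queried with size()) it behaves like a non-owning view. A minimal sketch of the interface this code appears to rely on, written as an assumption for readers outside the codebase rather than the project's actual definition:

#include <cstddef>
#include <vector>

template <typename T>
class Slice {
 public:
  Slice(std::vector<T>& v) : data_(v.data()), size_(v.size()) {}  // view, no copy
  size_t size() const { return size_; }
  T& operator[](size_t i) const { return data_[i]; }
  T* begin() const { return data_; }
  T* end() const { return data_ + size_; }

 private:
  T* data_ = nullptr;
  size_t size_ = 0;
};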
@@ -806,23 +792,7 @@ uint32_t WorkerImpl::offload_kv_blocks(
        promise = std::move(promise),
        slice = std::move(slice)]() mutable {
         bool ret = d2h_batch_copy(slice);
-        uint32_t success_cnt = 0;
-
-        folly::Promise<uint32_t> store_promise;
-        auto future = store_promise.getSemiFuture();
-
-        batchput_threadpool_.schedule(
-            [this, &slice, promise = std::move(store_promise)]() mutable {
-              promise.setValue(KVCacheStore::get_instance().batch_put(slice));
-            });
-
-        try {
-          auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-          success_cnt = std::move(future).wait(timeout);
-        } catch (const folly::FutureTimeout& e) {
-          LOG(WARNING) << "BatchPut operation timed out";
-        }
-
+        auto success_cnt = offload_to_store(slice);
         if (success_cnt != slice.size()) {
           LOG(WARNING) << "KVCacheStore not all put success: " << success_cnt
                        << "/" << slice.size();
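Worth noting about the consolidated offload path (editor's note): the lambda first performs the batched device-to-host copy, then hands the pinned host blocks to offload_to_store() and compares the returned count against the slice size. As the helper added at the end of this diff shows, offload_to_store() returns block_transfer_info.size() when the KV cache store is disabled, so this completeness check still passes without the store. Condensed, the flow is:

// Condensed flow of the offload lambda (names from this diff; not the exact code).
bool copied = d2h_batch_copy(slice);        // device blocks -> pinned host blocks
uint32_t stored = offload_to_store(slice);  // host blocks -> KVCacheStore
// stored == slice.size() when options_.enable_kvcache_store() is false
if (stored != slice.size()) {
  LOG(WARNING) << "KVCacheStore not all put success: " << stored << "/"
               << slice.size();
}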
@@ -908,6 +878,7 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::d2h_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -917,26 +888,25 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t attrs_indexes[1] = {0};
   size_t fail_index;
   uint32_t curr_index = 0;
-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();
 
+  for (const auto& info : block_transfer_info) {
+    auto dst_k_cache = host_kv_caches_.at(info.dst_block_id).get_k_cache();
+    auto dst_v_cache = host_kv_caches_.at(info.dst_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[layer_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
+
       curr_index++;
 
-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[layer_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
+
       curr_index++;
     }
   }
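To make the reordered loops easier to follow (editor's sketch; nothing below is part of the PR): the flat srcs/dsts/copy_size arrays are filled block-major and layer-minor, with the key and value entries interleaved, so each (block, layer) pair owns two consecutive slots. A small self-contained helper expressing that slot arithmetic:

#include <cstddef>
#include <cstdint>

// Slot in srcs/dsts/copy_size for a given transfer-info index and layer,
// matching the curr_index progression of the loops above.
inline size_t batch_slot(size_t block_idx,
                         int64_t layer_id,
                         int64_t num_layers,
                         bool is_value) {
  return (block_idx * static_cast<size_t>(num_layers) +
          static_cast<size_t>(layer_id)) *
             2 +
         (is_value ? 1 : 0);
}

For example, with num_layers == 32, the value entry for the third transfer info's layer 5 lands in slot batch_slot(2, 5, 32, true) == (2 * 32 + 5) * 2 + 1 == 139, which is exactly where curr_index would be at that point.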
@@ -974,6 +944,7 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::h2d_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -984,24 +955,21 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t fail_index;
   uint32_t curr_index = 0;
 
-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+  for (const auto& info : block_transfer_info) {
+    auto src_k_cache = host_kv_caches_.at(info.src_block_id).get_k_cache();
+    auto src_v_cache = host_kv_caches_.at(info.src_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
       curr_index++;
 
-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
       curr_index++;
     }
@@ -1035,4 +1003,64 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   return false;
 }
 
+uint32_t WorkerImpl::offload_to_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return block_transfer_info.size();
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchput_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_put(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchPut operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
+uint32_t WorkerImpl::load_from_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return 0;
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchget_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_get(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchGet operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
 }  // namespace xllm
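Both new helpers share one pattern: schedule the blocking KVCacheStore call on a dedicated thread pool, then bound the wait with folly's within()/thenTry() chain instead of the earlier try/catch around a timed wait. Below is a minimal standalone sketch of that pattern; run_bounded is a hypothetical helper, and the global CPU executor merely stands in for batchput_threadpool_/batchget_threadpool_:

#include <chrono>
#include <cstdint>
#include <utility>

#include <folly/executors/GlobalExecutor.h>
#include <folly/futures/Future.h>

// Run `work` asynchronously, wait at most `timeout`, and fall back to
// `fallback` if the future times out or fails (the helpers above log a
// warning and return 0 in that case).
template <typename Fn>
uint32_t run_bounded(Fn&& work, std::chrono::seconds timeout, uint32_t fallback) {
  folly::Promise<uint32_t> promise;
  auto future = promise.getSemiFuture();

  folly::getGlobalCPUExecutor()->add(
      [work = std::forward<Fn>(work), promise = std::move(promise)]() mutable {
        promise.setValue(work());
      });

  return std::move(future)
      .via(folly::getGlobalCPUExecutor())
      .within(timeout)  // completes the chain with a timeout error if the call is too slow
      .thenTry([fallback](folly::Try<uint32_t>&& t) {
        return t.hasValue() ? t.value() : fallback;
      })
      .get();  // block the calling thread, as offload_to_store()/load_from_store() do
}

Called as run_bounded(put_fn, std::chrono::seconds(KVSTORE_TIMEOUT), 0u), this mirrors the helpers above, minus their early return when enable_kvcache_store() is false.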