
Commit ca6dba2

refactor: change host KV cache memory layout from layer-wise to block-wise.
1 parent 9833d63 commit ca6dba2

File tree: 5 files changed, +142 -132 lines
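
The commit inverts the indexing of host_kv_caches_: previously one K/V
tensor pair per model layer, with a host block living at a fixed offset
inside every layer's tensor; now one pair per host block, each tensor
shaped [layers, token, head, dim], so a block is a single contiguous
unit. A minimal sketch of the two addressing schemes (all identifiers
hypothetical, not the xllm API):

    #include <cstdint>

    // Layer-wise (old): one large tensor per layer; a logical block is
    // an offset inside each layer's tensor, scattered across layers.
    char* layerwise_addr(char** caches_by_layer, int layer,
                         uint64_t block_id, uint64_t bytes_per_layer_block) {
      return caches_by_layer[layer] + block_id * bytes_per_layer_block;
    }

    // Block-wise (new): one tensor per block, shaped [layers, ...]; a
    // layer is an offset inside the block's contiguous buffer.
    char* blockwise_addr(char** caches_by_block, uint64_t block_id,
                         int layer, uint64_t bytes_per_layer_block) {
      return caches_by_block[block_id] + layer * bytes_per_layer_block;
    }

Block-wise addressing is what lets KVCacheStore below describe a block
transfer as two contiguous slices (K and V) rather than 2 * n_layers
scattered ones.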

xllm/core/framework/kv_cache/kv_cache_store.cpp

Lines changed: 31 additions & 53 deletions

@@ -43,29 +43,24 @@ bool KVCacheStore::init(const StoreConfig& config,
   }
   client_ptr_ = client_opt.value();

-  auto key_tensor_one_layer = host_kv_caches_->at(0).get_k_cache();
-  auto value_tensor_one_layer = host_kv_caches_->at(0).get_v_cache();
+  auto k_tensor_one_block = host_kv_caches_->at(0).get_k_cache();
+  auto v_tensor_one_block = host_kv_caches_->at(0).get_v_cache();

-  key_cache_size_per_layer_ =
-      key_tensor_one_layer[0].numel() * key_tensor_one_layer[0].element_size();
-  value_cache_size_per_layer_ = value_tensor_one_layer[0].numel() *
-                                value_tensor_one_layer[0].element_size();
+  k_cache_size_per_block_ =
+      k_tensor_one_block.numel() * k_tensor_one_block.element_size();
+  v_cache_size_per_block_ =
+      v_tensor_one_block.numel() * v_tensor_one_block.element_size();

-  auto key_cache_host_size =
-      key_tensor_one_layer.numel() * key_tensor_one_layer.element_size();
-  auto value_cache_host_size =
-      value_tensor_one_layer.numel() * value_tensor_one_layer.element_size();
-
-  LOG(INFO) << "key_cache_size_per_layer: " << key_cache_size_per_layer_;
-  LOG(INFO) << "value_cache_size_per_layer: " << value_cache_size_per_layer_;
+  LOG(INFO) << "k_cache_size_per_block: " << k_cache_size_per_block_;
+  LOG(INFO) << "v_cache_size_per_block: " << v_cache_size_per_block_;

   if (config_.protocol == "rdma") {
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
+    for (int block = 0; block < host_kv_caches_->size(); block++) {
       void* key_cache = static_cast<char*>(
-          host_kv_caches_->at(layer).get_k_cache().data_ptr());
+          host_kv_caches_->at(block).get_k_cache().data_ptr());

       auto register_k_result = client_ptr_->RegisterLocalMemory(
-          key_cache, key_cache_host_size, "cpu:0", false, false);
+          key_cache, k_cache_size_per_block_, "cpu:0", false, false);

       if (!register_k_result.has_value()) {
         LOG(ERROR) << "Failed to register local memory for key cache: "
@@ -74,10 +69,10 @@ bool KVCacheStore::init(const StoreConfig& config,
       }

       void* value_cache = static_cast<char*>(
-          host_kv_caches_->at(layer).get_v_cache().data_ptr());
+          host_kv_caches_->at(block).get_v_cache().data_ptr());

       auto register_v_result = client_ptr_->RegisterLocalMemory(
-          value_cache, value_cache_host_size, "cpu:0", false, false);
+          value_cache, v_cache_size_per_block_, "cpu:0", false, false);

       if (!register_v_result.has_value()) {
         LOG(ERROR) << "Failed to register local memory for value cache: "
@@ -119,23 +114,14 @@ uint32_t KVCacheStore::batch_put(
     str_keys.emplace_back(str_key);

-    std::vector<mooncake::Slice> slice;
-    slice.reserve(host_kv_caches_->size() * 2);
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
-      void* key_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_k_cache().data_ptr()) +
-          block_info.dst_block_id * key_cache_size_per_layer_;
-      slice.emplace_back(mooncake::Slice{key_cache, key_cache_size_per_layer_});
-
-      void* value_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_v_cache().data_ptr()) +
-          block_info.dst_block_id * value_cache_size_per_layer_;
-      slice.emplace_back(
-          mooncake::Slice{value_cache, value_cache_size_per_layer_});
-    }
-    slices.emplace_back(std::move(slice));
+    void* k_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
+    void* v_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_v_cache().data_ptr();
+
+    slices.emplace_back(std::vector<mooncake::Slice>{
+        mooncake::Slice{k_cache, k_cache_size_per_block_},
+        mooncake::Slice{v_cache, v_cache_size_per_block_}});
   }

   if (str_keys.size() == 0) {
@@ -177,24 +163,16 @@ uint32_t KVCacheStore::batch_get(
     str_keys.emplace_back(str_key);

-    slices.insert(std::make_pair(str_key, std::vector<mooncake::Slice>()));
-
-    slices[str_key].reserve(host_kv_caches_->size() * 2);
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
-      void* key_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_k_cache().data_ptr()) +
-          block_info.dst_block_id * key_cache_size_per_layer_;
-      slices[str_key].emplace_back(
-          mooncake::Slice{key_cache, key_cache_size_per_layer_});
-
-      void* value_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_v_cache().data_ptr()) +
-          block_info.dst_block_id * value_cache_size_per_layer_;
-      slices[str_key].emplace_back(
-          mooncake::Slice{value_cache, value_cache_size_per_layer_});
-    }
+    void* k_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
+    void* v_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_v_cache().data_ptr();
+
+    slices.insert(
+        std::make_pair(str_key,
+                       std::vector<mooncake::Slice>{
+                           mooncake::Slice{k_cache, k_cache_size_per_block_},
+                           mooncake::Slice{v_cache, v_cache_size_per_block_}}));
   }

   if (str_keys.size() == 0) {
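
With a block contiguous in host memory, batch_put and batch_get shrink
from 2 * n_layers slices per block to exactly two, and init() can
register each block tensor with RDMA in full: k_cache_size_per_block_
is the whole tensor's byte size, which is why the separate
key_cache_host_size / value_cache_host_size quantities disappear. A
hedged sketch of the per-block slice construction, assuming
mooncake::Slice is the {pointer, size} aggregate used above (helper
name hypothetical):

    #include <cstdint>
    #include <vector>

    std::vector<mooncake::Slice> make_block_slices(void* k_cache,
                                                   void* v_cache,
                                                   uint64_t k_bytes,
                                                   uint64_t v_bytes) {
      // One contiguous K slice plus one contiguous V slice per block.
      return {mooncake::Slice{k_cache, k_bytes},
              mooncake::Slice{v_cache, v_bytes}};
    }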

xllm/core/framework/kv_cache/kv_cache_store.h

Lines changed: 2 additions & 2 deletions

@@ -69,8 +69,8 @@ class KVCacheStore {

   std::vector<xllm::KVCache>* host_kv_caches_;

-  uint64_t key_cache_size_per_layer_;
-  uint64_t value_cache_size_per_layer_;
+  uint64_t k_cache_size_per_block_;
+  uint64_t v_cache_size_per_block_;

   std::shared_ptr<mooncake::Client> client_ptr_;
 };

xllm/core/framework/request/sequence.h

Lines changed: 2 additions & 2 deletions

@@ -239,10 +239,10 @@ class Sequence final {

   void sync_result() {
     if (futures_.has_value()) {
-      auto success_cnt = host_kv_state_.num_kv_blocks();
+      uint32_t success_cnt = host_kv_state_.num_kv_blocks();
       for (auto& future : futures_.value()) {
         if (future.isReady()) {
-          success_cnt = std::min(success_cnt, size_t(future.value()));
+          success_cnt = std::min(success_cnt, future.value());
         } else {
           return;
         }
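
The dropped size_t cast works because std::min deduces one template
type from both arguments; declaring success_cnt as uint32_t to match
the uint32_t counts the store's batch operations return makes the
arguments agree. A standalone illustration (hypothetical snippet):

    #include <algorithm>
    #include <cstdint>

    uint32_t clamp_success(uint32_t success_cnt, uint32_t ready_value) {
      // Both arguments are uint32_t, so std::min<uint32_t> is deduced.
      // With mixed size_t / uint32_t arguments this call would fail to
      // compile without a cast or an explicit template argument.
      return std::min(success_cnt, ready_value);
    }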

xllm/core/runtime/worker_impl.cpp

Lines changed: 103 additions & 75 deletions

@@ -142,19 +142,20 @@ bool WorkerImpl::allocate_host_kv_cache(
   CHECK(model_ != nullptr) << "Model is not initialized.";
   CHECK(host_kv_caches_.empty()) << "KV caches are already initialized.";
+  CHECK(device_kv_cache_shape[0][0] == device_kv_cache_shape[1][0]);

   std::vector<std::vector<int64_t>> host_kv_cache_shape = device_kv_cache_shape;
-  for (auto& shape : host_kv_cache_shape) {
-    if (!shape.empty()) {
-      shape[0] *= options_.host_blocks_factor();
-    }
-  }
-
-  // create a KVCache for each layer
   const int64_t num_layers = context_.get_model_args().n_layers();
-  host_kv_caches_.reserve(num_layers);
-  for (int64_t i = 0; i < num_layers; ++i) {
-    torch::Tensor key_cache, value_cache, index_cache;
+  int64_t host_block_size =
+      device_kv_cache_shape[0][0] * options_.host_blocks_factor();
+  host_kv_cache_shape[0][0] = num_layers;
+  host_kv_cache_shape[1][0] = num_layers;
+
+  // create one KVCache per host block, each shaped [layers, token, head, dim]
+  host_kv_caches_.reserve(host_block_size);
+
+  for (int64_t i = 0; i < host_block_size; ++i) {
+    torch::Tensor key_cache, value_cache;
     key_cache = torch::empty(host_kv_cache_shape[0],
                              torch::dtype(dtype_).device(torch::kCPU))
                     .pin_memory();
@@ -163,8 +164,7 @@ bool WorkerImpl::allocate_host_kv_cache(
                     .pin_memory();
     host_kv_caches_.emplace_back(key_cache, value_cache);
   }
-  LOG(INFO) << "Initializing host k cache size: " << host_kv_cache_shape[0][0];
-  LOG(INFO) << "Initializing host v cache size: " << host_kv_cache_shape[1][0];
+  LOG(INFO) << "Initializing host kv block size: " << host_block_size;

   int32_t device_id = device_.index();
   h2d_attrs_.dstLoc.id = device_id;
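
Host capacity is now counted in blocks: the device block count
(device_kv_cache_shape[0][0], checked above to match between K and V)
scaled by host_blocks_factor(), with one pinned [layers, token, head,
dim] K/V tensor pair allocated per block. A self-contained sketch of
that allocation under those assumptions (helper name hypothetical):

    #include <torch/torch.h>
    #include <utility>
    #include <vector>

    // One pinned K/V tensor pair per host block, each shaped
    // [num_layers, tokens_per_block, heads, head_dim].
    std::vector<std::pair<torch::Tensor, torch::Tensor>> alloc_host_blocks(
        int64_t host_block_count, int64_t num_layers,
        const std::vector<int64_t>& layer_shape,  // {token, head, dim}
        torch::Dtype dtype) {
      std::vector<int64_t> shape{num_layers};
      shape.insert(shape.end(), layer_shape.begin(), layer_shape.end());

      std::vector<std::pair<torch::Tensor, torch::Tensor>> blocks;
      blocks.reserve(host_block_count);
      for (int64_t i = 0; i < host_block_count; ++i) {
        auto k = torch::empty(shape, torch::dtype(dtype).device(torch::kCPU))
                     .pin_memory();
        auto v = torch::empty(shape, torch::dtype(dtype).device(torch::kCPU))
                     .pin_memory();
        blocks.emplace_back(std::move(k), std::move(v));
      }
      return blocks;
    }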
@@ -701,22 +701,8 @@ uint32_t WorkerImpl::transfer_kv_blocks(
   switch (block_transfer_info[0].transfer_type) {
     case TransferType::G2H: {
-      folly::Promise<uint32_t> promise;
-      auto future = promise.getSemiFuture();
-
-      batchget_threadpool_.schedule(
-          [this, &block_transfer_info, promise = std::move(promise)]() mutable {
-            promise.setValue(
-                KVCacheStore::get_instance().batch_get(block_transfer_info));
-          });
-
-      try {
-        auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-        return std::move(future).wait(timeout);
-      } catch (const folly::FutureTimeout& e) {
-        LOG(WARNING) << "BatchGet operation timed out";
-        return 0;
-      }
+      Slice<BlockTransferInfo> info_slice{block_transfer_info};
+      return load_from_store(info_slice);
     }
     case TransferType::D2G:
       return offload_kv_blocks(block_transfer_info);
@@ -806,23 +792,7 @@ uint32_t WorkerImpl::offload_kv_blocks(
                           promise = std::move(promise),
                           slice = std::move(slice)]() mutable {
       bool ret = d2h_batch_copy(slice);
-      uint32_t success_cnt = 0;
-
-      folly::Promise<uint32_t> store_promise;
-      auto future = store_promise.getSemiFuture();
-
-      batchput_threadpool_.schedule(
-          [this, &slice, promise = std::move(store_promise)]() mutable {
-            promise.setValue(KVCacheStore::get_instance().batch_put(slice));
-          });
-
-      try {
-        auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-        success_cnt = std::move(future).wait(timeout);
-      } catch (const folly::FutureTimeout& e) {
-        LOG(WARNING) << "BatchPut operation timed out";
-      }
-
+      auto success_cnt = offload_to_store(slice);
       if (success_cnt != slice.size()) {
         LOG(WARNING) << "KVCacheStore not all put success: " << success_cnt
                      << "/" << slice.size();
@@ -908,6 +878,7 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::d2h_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -917,26 +888,25 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t attrs_indexes[1] = {0};
   size_t fail_index;
   uint32_t curr_index = 0;
-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();

+  for (const auto& info : block_transfer_info) {
+    auto dst_k_cache = host_kv_caches_.at(info.dst_block_id).get_k_cache();
+    auto dst_v_cache = host_kv_caches_.at(info.dst_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[layer_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
+
      curr_index++;

-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[layer_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
+
       curr_index++;
     }
   }
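
Device caches remain layer-major while host caches are now block-major,
so the copy loops invert: blocks outside, layers inside, with the host
side indexed by layer within a block's tensor. A minimal sketch of the
resulting D2H descriptor list for the K cache (V mirrors it, and
h2d_batch_copy below simply swaps the source and destination roles);
all names are illustrative:

    #include <cstddef>
    #include <utility>
    #include <vector>

    struct CopyDesc { const void* src; void* dst; size_t bytes; };

    std::vector<CopyDesc> build_d2h_k_descs(
        const std::vector<const char*>& device_k_by_layer,  // [num_layers]
        const std::vector<char*>& host_k_by_block,  // [num_host_blocks]
        const std::vector<std::pair<int, int>>& transfers,  // {src, dst}
        size_t bytes_per_layer_block, int num_layers) {
      std::vector<CopyDesc> descs;
      descs.reserve(transfers.size() * num_layers);
      for (const auto& [src_block, dst_block] : transfers) {  // blocks
        for (int layer = 0; layer < num_layers; ++layer) {    // layers
          // Read one layer's slice of the source device block, write it
          // at the layer offset inside the destination host block.
          descs.push_back(
              {device_k_by_layer[layer] + src_block * bytes_per_layer_block,
               host_k_by_block[dst_block] + layer * bytes_per_layer_block,
               bytes_per_layer_block});
        }
      }
      return descs;
    }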
@@ -974,6 +944,7 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::h2d_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -984,24 +955,21 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t fail_index;
   uint32_t curr_index = 0;

-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+  for (const auto& info : block_transfer_info) {
+    auto src_k_cache = host_kv_caches_.at(info.src_block_id).get_k_cache();
+    auto src_v_cache = host_kv_caches_.at(info.src_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
       curr_index++;

-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
       curr_index++;
     }
@@ -1035,4 +1003,64 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   return false;
 }

+uint32_t WorkerImpl::offload_to_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return block_transfer_info.size();
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchput_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_put(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchPut operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
+uint32_t WorkerImpl::load_from_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return 0;
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchget_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_get(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchGet operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
 }  // namespace xllm
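
The extracted helpers also rework the timeout handling: rather than a
blocking wait(timeout) paired with a FutureTimeout catch, the future is
attached to an executor with via(), bounded with within(), and both the
value and the timeout are folded into a single thenTry() continuation,
so no exception escapes. A standalone sketch of the same pattern,
assuming folly futures (the caller would schedule batch_put/batch_get
and pass the resulting SemiFuture):

    #include <folly/executors/GlobalExecutor.h>
    #include <folly/futures/Future.h>
    #include <chrono>
    #include <cstdint>

    uint32_t wait_with_timeout(folly::SemiFuture<uint32_t>&& future,
                               std::chrono::seconds timeout) {
      // within() turns a late result into an errored Try; thenTry()
      // maps value and error alike to a count.
      return std::move(future)
          .via(folly::getGlobalCPUExecutor())
          .within(timeout)
          .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
            return t.hasValue() ? t.value() : 0u;
          })
          .get();
    }

Note the disabled-store fast paths differ deliberately:
offload_to_store() reports every block as put (nothing needs storing),
while load_from_store() reports zero blocks loaded.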
