Skip to content

Commit 2a32a61

Browse files
committed
feat: add page-aligned tensor creator for host KV cache.
1 parent 52a10df commit 2a32a61

File tree

19 files changed

+204
-126
lines changed

19 files changed

+204
-126
lines changed

third_party/dependencies.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ if [ -d "yalantinglibs" ]; then
5959
fi
6060

6161
# Clone yalantinglibs
62-
echo "Cloning yalantinglibs from https://github.com/alibaba/yalantinglibs.git"
63-
git clone https://github.com/alibaba/yalantinglibs.git
62+
echo "Cloning yalantinglibs from https://gitcode.com/gh_mirrors/ya/yalantinglibs.git"
63+
git clone https://gitcode.com/gh_mirrors/ya/yalantinglibs.git
6464
check_success "Failed to clone yalantinglibs"
6565

6666
# Build and install yalantinglibs

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,9 @@ void CommChannel::transfer_kv_blocks(
312312
const std::vector<BlockTransferInfo>& block_transfer_info,
313313
folly::Promise<uint32_t>& promise) {
314314
proto::BlockTransferInfos pb_block_transfer_info;
315-
if (!block_transfer_info_to_proto(
316-
0x0, block_transfer_info, &pb_block_transfer_info)) {
315+
if (!block_transfer_info_to_proto(block_transfer_info,
316+
&pb_block_transfer_info)) {
317+
LOG(ERROR) << "transfer_kv_blocks fail: create proto fail!";
317318
promise.setValue(0);
318319
return;
319320
}
@@ -330,6 +331,8 @@ void CommChannel::transfer_kv_blocks(
330331
proto::BlockTransferInfos pb_block_transfer_info;
331332
if (!block_transfer_info_to_proto(
332333
batch_id, block_transfer_info, &pb_block_transfer_info)) {
334+
LOG(ERROR) << "transfer_kv_blocks with batch id " << batch_id
335+
<< " fail: create proto fail!";
333336
return;
334337
}
335338
brpc::Controller cntl;
@@ -351,11 +354,7 @@ class ClientStreamReceiver : public brpc::StreamInputHandler {
351354

352355
~ClientStreamReceiver() {
353356
if (!promise_set_.exchange(true)) {
354-
try {
355-
close_promise_.set_value();
356-
} catch (const std::exception& e) {
357-
LOG(WARNING) << "Exception in destructor: " << e.what();
358-
}
357+
close_promise_.set_value();
359358
}
360359
}
361360

@@ -400,8 +399,9 @@ void CommChannel::prefetch_from_storage(
400399
const std::vector<BlockTransferInfo>& block_transfer_info,
401400
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
402401
proto::BlockTransferInfos pb_block_transfer_info;
403-
if (!block_transfer_info_to_proto(
404-
0x0, block_transfer_info, &pb_block_transfer_info)) {
402+
if (!block_transfer_info_to_proto(block_transfer_info,
403+
&pb_block_transfer_info)) {
404+
LOG(ERROR) << "prefetch_from_storage fail: create proto fail!";
405405
return;
406406
}
407407
ClientStreamReceiver receiver(flag, success_cnt);
@@ -420,6 +420,7 @@ void CommChannel::prefetch_from_storage(
420420

421421
if (cntl.Failed()) {
422422
LOG(ERROR) << "Fail to connect stream, " << cntl.ErrorText();
423+
return;
423424
}
424425

425426
receiver.get_close_future().wait();

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ void WorkerService::TransferBlocks(
425425
std::vector<BlockTransferInfo> block_transfer_info;
426426
uint64_t batch_id = proto_to_block_transfer_info(*req, block_transfer_info);
427427

428-
if (batch_id == 0x0) {
428+
if (batch_id == UNINITIALIZED_BATCH_ID) {
429429
resp->set_success_cnt(worker_->transfer_kv_blocks(block_transfer_info));
430430
} else {
431431
worker_->transfer_kv_blocks(batch_id, std::move(block_transfer_info));
@@ -441,11 +441,7 @@ class ServerStreamHandler : public brpc::StreamInputHandler {
441441
public:
442442
~ServerStreamHandler() {
443443
if (!promise_set_.exchange(true)) {
444-
try {
445-
close_promise_.set_value();
446-
} catch (const std::exception& e) {
447-
LOG(WARNING) << "Exception in destructor: " << e.what();
448-
}
444+
close_promise_.set_value();
449445
}
450446
}
451447

xllm/core/framework/batch/batch.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ namespace xllm {
3434

3535
struct ModelArgs;
3636

37+
static uint64_t batch_counter_ = 1;
38+
constexpr uint64_t UNINITIALIZED_BATCH_ID = 0x0;
39+
3740
class Batch {
3841
public:
3942
Batch() = default;
@@ -56,8 +59,12 @@ class Batch {
5659
}
5760

5861
void set_batch_id() {
59-
if (batch_id_ == 0x0) {
60-
batch_id_ = absl::ToUnixMicros(absl::Now());
62+
if (batch_id_ == UNINITIALIZED_BATCH_ID) {
63+
batch_id_ = batch_counter_;
64+
batch_counter_++;
65+
if (batch_counter_ == UINT64_MAX) {
66+
batch_counter_ = 1;
67+
}
6168
}
6269
}
6370

@@ -138,7 +145,7 @@ class Batch {
138145
// all sequences in this batch are in prefill stage
139146
bool all_seqs_in_prefill_ = false;
140147

141-
uint64_t batch_id_ = 0x0;
148+
uint64_t batch_id_ = UNINITIALIZED_BATCH_ID;
142149
};
143150

144151
} // namespace xllm

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class BatchInputBuilder {
159159

160160
// thread pool for multithreaded processing, not owned
161161
ThreadPool* thread_pool_ = nullptr;
162-
uint64_t batch_id_ = 0x0;
162+
uint64_t batch_id_;
163163
};
164164

165165
} // namespace xllm

xllm/core/framework/block/block_manager_impl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ void BlockManagerImpl::deallocate(const Slice<Block>& blocks) {
7070
for (const auto& block : blocks) {
7171
// the block is not shared by other sequence
7272
if (block.is_valid() && block.ref_count() <= 2) {
73-
if (num_used_blocks_ > 0) {
74-
num_used_blocks_.fetch_sub(1, std::memory_order_relaxed);
75-
} else {
73+
auto origin_num_used_blocks =
74+
num_used_blocks_.fetch_sub(1, std::memory_order_relaxed);
75+
if (origin_num_used_blocks < 0) {
7676
LOG(ERROR) << "num_used_blocks_==0 cannot fetch_sub for id:"
7777
<< block.id()
7878
<< ", total block size: " << num_total_blocks();
@@ -84,7 +84,7 @@ void BlockManagerImpl::deallocate(const Slice<Block>& blocks) {
8484
error_msg.append(std::to_string(id)).append(" ");
8585
}
8686
}
87-
LOG(ERROR) << error_msg;
87+
LOG(FATAL) << error_msg;
8888
}
8989
}
9090
}

xllm/core/framework/block/block_manager_pool.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,9 @@ void BlockManagerPool::set_offload_callback(
156156
device_block_mgr_ptr = block_managers_[i].get()](
157157
std::vector<folly::Try<uint32_t>>&& results) {
158158
for (auto&& result : results) {
159-
try {
160-
if (result.value() != host_blocks.size()) {
161-
LOG(FATAL) << "Offload copy fail, expected "
162-
<< host_blocks.size() << ", got " << result.value();
163-
}
164-
} catch (const std::exception& e) {
165-
LOG(FATAL) << "Offload copy fail! Exception caught: " << e.what();
159+
if (result.value() != host_blocks.size()) {
160+
LOG(FATAL) << "Offload copy fail, expected " << host_blocks.size()
161+
<< ", got " << result.value();
166162
}
167163
}
168164
host_block_mgr_ptr->cache(host_blocks);

xllm/core/framework/kv_cache/kv_cache_store.cpp

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,30 +55,18 @@ bool KVCacheStore::init(const StoreConfig& config,
5555
LOG(INFO) << "v_cache_size_per_block: " << v_cache_size_per_block_;
5656

5757
if (config_.protocol == "rdma") {
58-
for (int block = 0; block < host_kv_caches_->size(); block++) {
59-
void* key_cache = static_cast<char*>(
60-
host_kv_caches_->at(block).get_k_cache().data_ptr());
61-
62-
auto register_k_result = client_ptr_->RegisterLocalMemory(
63-
key_cache, k_cache_size_per_block_, "cpu:0", false, false);
64-
65-
if (!register_k_result.has_value()) {
66-
LOG(ERROR) << "Failed to register local memory for key cache: "
67-
<< toString(register_k_result.error());
68-
return false;
69-
}
70-
71-
void* value_cache = static_cast<char*>(
72-
host_kv_caches_->at(block).get_v_cache().data_ptr());
73-
74-
auto register_v_result = client_ptr_->RegisterLocalMemory(
75-
value_cache, v_cache_size_per_block_, "cpu:0", false, false);
76-
77-
if (!register_v_result.has_value()) {
78-
LOG(ERROR) << "Failed to register local memory for value cache: "
79-
<< toString(register_v_result.error());
58+
if (config_.total_size > 0 && config_.tensor_data != nullptr) {
59+
auto result = client_ptr_->RegisterLocalMemory(
60+
config_.tensor_data, config_.total_size, "cpu:0", false, false);
61+
if (!result.has_value()) {
62+
LOG(ERROR) << "Failed to register local memory: "
63+
<< toString(result.error());
8064
return false;
8165
}
66+
} else {
67+
LOG(FATAL) << "rdma must RegisterLocalMemory, but got register size: "
68+
<< config_.total_size
69+
<< ", and data ptr: " << uint64_t(config_.tensor_data);
8270
}
8371
}
8472
is_initialized_ = true;

xllm/core/framework/kv_cache/kv_cache_store.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ struct StoreConfig {
1919
std::string master_server_address = "";
2020
int replica_num = 1;
2121
uint32_t tp_rank = 0;
22+
size_t total_size = 0;
23+
void* tensor_data = nullptr;
2224
};
2325

2426
class KVCacheStore {

xllm/core/framework/model/model_input_params.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ limitations under the License.
2727

2828
namespace xllm {
2929

30-
enum class TransferType : uint8_t { G2H = 0, H2D = 1, D2G = 2 };
30+
enum class TransferType : uint8_t {
31+
G2H = 0, // global memory(KVCache store) to host memory(DRAM)
32+
H2D = 1, // host memory(DRAM) to device memory(HBM)
33+
D2G = 2 // host memory(DRAM) to global memory(KVCache store)
34+
};
3135

3236
struct BlockTransferInfo {
3337
int32_t src_block_id = -1;

0 commit comments

Comments
 (0)