Skip to content

Commit e2fdd34

Browse files
committed
feat: implement batch prefetch from store.
1 parent 1b9dd1f commit e2fdd34

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+501
-158
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,14 +336,18 @@ DEFINE_string(store_protocol,
336336
"tcp",
337337
"KV cache store protocol(e.g. tcp, rdma).");
338338

339-
DEFINE_string(store_master_server_entry,
339+
DEFINE_string(store_master_server_address,
340340
"",
341341
"The address information of the store master service.");
342342

343-
DEFINE_string(store_metadata_connstring,
343+
DEFINE_string(store_metadata_server,
344344
"",
345345
"The address of the kv cache store metadata service.");
346346

347+
DEFINE_string(store_local_hostname,
348+
"",
349+
"The local host name of the kv cache store client.");
350+
347351
// --- computation communication parallel config ---
348352

349353
DEFINE_bool(

xllm/core/common/global_flags.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,11 @@ DECLARE_bool(enable_kvcache_store);
161161

162162
DECLARE_string(store_protocol);
163163

164-
DECLARE_string(store_master_server_entry);
164+
DECLARE_string(store_master_server_address);
165165

166-
DECLARE_string(store_metadata_connstring);
166+
DECLARE_string(store_metadata_server);
167+
168+
DECLARE_string(store_local_hostname);
167169

168170
DECLARE_bool(enable_multi_stream_parallel);
169171

xllm/core/common/options.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ std::string Options::to_string() const {
5454
<< ", enable_cache_upload: " << enable_cache_upload()
5555
<< ", enable_kvcache_store: " << enable_kvcache_store()
5656
<< ", store_protocol: " << store_protocol()
57-
<< ", store_master_server_entry: " << store_master_server_entry()
58-
<< ", store_metadata_connstring: " << store_metadata_connstring()
57+
<< ", store_master_server_address: " << store_master_server_address()
58+
<< ", store_metadata_server: " << store_metadata_server()
59+
<< ", store_local_hostname: " << store_local_hostname()
5960
<< ", enable_multi_stream_parallel: " << enable_multi_stream_parallel()
6061
<< ", enable_continuous_kvcache: " << enable_continuous_kvcache()
6162
<< ", disable_ttft_profiling: " << disable_ttft_profiling()

xllm/core/common/options.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,11 @@ class Options {
146146

147147
PROPERTY(std::string, store_protocol) = "tcp";
148148

149-
PROPERTY(std::string, store_master_server_entry) = "";
149+
PROPERTY(std::string, store_master_server_address) = "";
150150

151-
PROPERTY(std::string, store_metadata_connstring) = "";
151+
PROPERTY(std::string, store_metadata_server) = "";
152+
153+
PROPERTY(std::string, store_local_hostname) = "";
152154

153155
PROPERTY(bool, enable_multi_stream_parallel) = false;
154156

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ limitations under the License.
1818
#include <brpc/controller.h>
1919
#include <glog/logging.h>
2020

21+
#include <future>
22+
2123
namespace xllm {
2224

2325
bool CommChannel::init_brpc(const std::string& server_address) {
@@ -335,6 +337,94 @@ void CommChannel::transfer_kv_blocks(
335337
stub_->TransferBlocks(&cntl, &pb_block_transfer_info, &response, nullptr);
336338
}
337339

340+
class ClientStreamReceiver : public brpc::StreamInputHandler {
341+
private:
342+
const std::atomic<bool>& termination_flag_;
343+
std::shared_ptr<std::atomic<uint32_t>> success_cnt_;
344+
std::promise<void> close_promise_;
345+
std::atomic<bool> promise_set_{false};
346+
347+
public:
348+
ClientStreamReceiver(const std::atomic<bool>& termination_flag,
349+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt)
350+
: termination_flag_(termination_flag), success_cnt_(success_cnt) {}
351+
352+
~ClientStreamReceiver() {
353+
if (!promise_set_.exchange(true)) {
354+
try {
355+
close_promise_.set_value();
356+
} catch (const std::exception& e) {
357+
LOG(WARNING) << "Exception in destructor: " << e.what();
358+
}
359+
}
360+
}
361+
362+
std::future<void> get_close_future() { return close_promise_.get_future(); }
363+
364+
int on_received_messages(brpc::StreamId id,
365+
butil::IOBuf* const messages[],
366+
size_t size) override {
367+
for (size_t i = 0; i < size; ++i) {
368+
std::string msg_str = messages[i]->to_string();
369+
int32_t success_cnt = std::stoi(msg_str);
370+
371+
if (success_cnt > 0 &&
372+
!termination_flag_.load(std::memory_order_acquire)) {
373+
success_cnt_->fetch_add(success_cnt, std::memory_order_relaxed);
374+
} else {
375+
brpc::StreamClose(id);
376+
if (!promise_set_.exchange(true)) {
377+
close_promise_.set_value();
378+
}
379+
break;
380+
}
381+
}
382+
return 0;
383+
}
384+
385+
virtual void on_idle_timeout(brpc::StreamId id) override {
386+
if (!promise_set_.exchange(true)) {
387+
close_promise_.set_value();
388+
}
389+
}
390+
391+
virtual void on_closed(brpc::StreamId id) override {
392+
if (!promise_set_.exchange(true)) {
393+
close_promise_.set_value();
394+
}
395+
}
396+
};
397+
398+
void CommChannel::prefetch_from_storage(
399+
const std::atomic<bool>& flag,
400+
const std::vector<BlockTransferInfo>& block_transfer_info,
401+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
402+
proto::BlockTransferInfos pb_block_transfer_info;
403+
if (!block_transfer_info_to_proto(
404+
0x0, block_transfer_info, &pb_block_transfer_info)) {
405+
return;
406+
}
407+
ClientStreamReceiver receiver(flag, success_cnt);
408+
brpc::Controller cntl;
409+
brpc::StreamOptions stream_options;
410+
brpc::StreamId stream_id;
411+
proto::Status response;
412+
stream_options.handler = &receiver;
413+
if (brpc::StreamCreate(&stream_id, cntl, &stream_options) != 0) {
414+
LOG(ERROR) << "Failed to create stream";
415+
return;
416+
}
417+
418+
stub_->PrefetchFromStorage(
419+
&cntl, &pb_block_transfer_info, &response, nullptr);
420+
421+
if (cntl.Failed()) {
422+
LOG(ERROR) << "Fail to connect stream, " << cntl.ErrorText();
423+
}
424+
425+
receiver.get_close_future().wait();
426+
}
427+
338428
bool CommChannel::get_last_step_result_async(
339429
folly::Promise<std::optional<RawForwardOutput>>& promise) {
340430
proto::Empty req;

xllm/core/distributed_runtime/comm_channel.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,6 @@ class CommChannel {
8787
const uint64_t kv_cache_size,
8888
const std::vector<std::vector<int64_t>>& kv_cache_shape);
8989

90-
virtual bool load_kv_blocks_from_store_async(
91-
const std::vector<CacheBlockInfo>& cache_block_info,
92-
folly::Promise<uint32_t>& promise);
93-
9490
virtual void transfer_kv_blocks(
9591
const std::vector<BlockTransferInfo>& block_transfer_info,
9692
folly::Promise<uint32_t>& promise);
@@ -99,6 +95,11 @@ class CommChannel {
9995
const uint64_t batch_id,
10096
const std::vector<BlockTransferInfo>& block_transfer_info);
10197

98+
virtual void prefetch_from_storage(
99+
const std::atomic<bool>& flag,
100+
const std::vector<BlockTransferInfo>& block_transfer_info,
101+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt);
102+
102103
virtual bool get_last_step_result_async(
103104
folly::Promise<std::optional<RawForwardOutput>>& promise);
104105

xllm/core/distributed_runtime/remote_worker.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ limitations under the License.
3535
#include "util/hash_util.h"
3636

3737
namespace xllm {
38+
3839
RemoteWorker::RemoteWorker(int32_t global_rank,
3940
const std::string& server_address,
4041
const torch::Device& d,
@@ -286,7 +287,7 @@ folly::SemiFuture<uint32_t> RemoteWorker::transfer_kv_blocks(
286287
const std::vector<BlockTransferInfo>& block_transfer_info) {
287288
folly::Promise<uint32_t> promise;
288289
auto future = promise.getSemiFuture();
289-
general_threadpool_.schedule(
290+
copy_threadpool_.schedule(
290291
[this,
291292
block_transfer_info = std::move(block_transfer_info),
292293
promise = std::move(promise)]() mutable {
@@ -298,14 +299,27 @@ folly::SemiFuture<uint32_t> RemoteWorker::transfer_kv_blocks(
298299
void RemoteWorker::transfer_kv_blocks(
299300
const uint64_t batch_id,
300301
const std::vector<BlockTransferInfo>& block_transfer_info) {
301-
general_threadpool_.schedule(
302+
copy_threadpool_.schedule(
302303
[this,
303304
batch_id = batch_id,
304305
block_transfer_info = std::move(block_transfer_info)]() mutable {
305306
channel_->transfer_kv_blocks(batch_id, block_transfer_info);
306307
});
307308
}
308309

310+
void RemoteWorker::prefetch_from_storage(
311+
const std::atomic<bool>& flag,
312+
const std::vector<BlockTransferInfo>& block_transfer_info,
313+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
314+
copy_threadpool_.schedule(
315+
[this,
316+
flag = &flag,
317+
block_transfer_info = std::move(block_transfer_info),
318+
success_cnt = success_cnt]() mutable {
319+
channel_->prefetch_from_storage(flag, block_transfer_info, success_cnt);
320+
});
321+
}
322+
309323
const torch::Device& RemoteWorker::device() const {
310324
LOG(ERROR) << "RemoteWorker Method device is UnImplemented.";
311325
}

xllm/core/distributed_runtime/remote_worker.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ class RemoteWorker : public WorkerClient {
117117
const uint64_t batch_id,
118118
const std::vector<BlockTransferInfo>& block_transfer_info) override;
119119

120+
virtual void prefetch_from_storage(
121+
const std::atomic<bool>& flag,
122+
const std::vector<BlockTransferInfo>& block_transfer_info,
123+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) override;
124+
120125
// Run the model and return the output.
121126
virtual folly::SemiFuture<std::optional<ForwardOutput>> step_async(
122127
const ForwardInput& inputs) override;
@@ -144,9 +149,8 @@ class RemoteWorker : public WorkerClient {
144149
// connection resource
145150
std::unique_ptr<CommChannel> channel_;
146151
ThreadPool threadpool_;
147-
// general working thread
148-
// do some overlap work with model execute
149-
ThreadPool general_threadpool_{4};
152+
// copy working thread
153+
ThreadPool copy_threadpool_{4};
150154
const torch::Device device_;
151155
};
152156
} // namespace xllm

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,13 +418,12 @@ void WorkerService::PullKVCache(::google::protobuf::RpcController* controller,
418418

419419
void WorkerService::TransferBlocks(
420420
::google::protobuf::RpcController* controller,
421-
const ::xllm::proto::BlockTransferInfos* req,
422-
::xllm::proto::TransferStatus* resp,
421+
const proto::BlockTransferInfos* req,
422+
proto::TransferStatus* resp,
423423
::google::protobuf::Closure* done) {
424424
brpc::ClosureGuard done_guard(done);
425425
std::vector<BlockTransferInfo> block_transfer_info;
426-
uint64_t batch_id;
427-
proto_to_block_transfer_info(*req, batch_id, block_transfer_info);
426+
uint64_t batch_id = proto_to_block_transfer_info(*req, block_transfer_info);
428427

429428
if (batch_id == 0x0) {
430429
resp->set_success_cnt(worker_->transfer_kv_blocks(block_transfer_info));
@@ -434,6 +433,114 @@ void WorkerService::TransferBlocks(
434433
return;
435434
}
436435

436+
class ServerStreamHandler : public brpc::StreamInputHandler {
437+
private:
438+
std::promise<void> close_promise_;
439+
std::atomic<bool> promise_set_{false};
440+
441+
public:
442+
~ServerStreamHandler() {
443+
if (!promise_set_.exchange(true)) {
444+
try {
445+
close_promise_.set_value();
446+
} catch (const std::exception& e) {
447+
LOG(WARNING) << "Exception in destructor: " << e.what();
448+
}
449+
}
450+
}
451+
452+
std::future<void> get_close_future() { return close_promise_.get_future(); }
453+
454+
int on_received_messages(brpc::StreamId id,
455+
butil::IOBuf* const messages[],
456+
size_t size) override {
457+
LOG(WARNING) << "ServerStreamHandler::on_received_messages not implement.";
458+
return 0;
459+
}
460+
461+
void on_closed(brpc::StreamId id) override {
462+
if (!promise_set_.exchange(true)) {
463+
close_promise_.set_value();
464+
}
465+
}
466+
467+
void on_idle_timeout(brpc::StreamId id) override {
468+
if (!promise_set_.exchange(true)) {
469+
LOG(WARNING) << "Stream idle timeout: " << id;
470+
close_promise_.set_value();
471+
}
472+
}
473+
};
474+
475+
void WorkerService::PrefetchFromStorage(
476+
google::protobuf::RpcController* controller,
477+
const proto::BlockTransferInfos* req,
478+
proto::Status* resp,
479+
google::protobuf::Closure* done) {
480+
brpc::ClosureGuard done_guard(done);
481+
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
482+
483+
auto stream_handler = std::make_unique<ServerStreamHandler>();
484+
auto stream_id = std::make_unique<brpc::StreamId>();
485+
brpc::StreamOptions stream_options;
486+
stream_options.handler = stream_handler.get();
487+
if (brpc::StreamAccept(stream_id.get(), *cntl, &stream_options) != 0) {
488+
resp->set_ok(false);
489+
LOG(ERROR) << "Failed to accept stream!";
490+
return;
491+
}
492+
493+
std::vector<BlockTransferInfo> block_transfer_info;
494+
proto_to_block_transfer_info(*req, block_transfer_info);
495+
496+
copy_threadpool_.schedule(
497+
[this,
498+
block_transfer_info = std::move(block_transfer_info),
499+
stream_id = std::move(stream_id),
500+
stream_handler = std::move(stream_handler)]() mutable {
501+
Slice<BlockTransferInfo> transfer_slice{block_transfer_info};
502+
auto close_future = stream_handler->get_close_future();
503+
bool is_completed = false;
504+
505+
for (size_t i = 0; i < transfer_slice.size();
506+
i += stream_copy_batch_size_) {
507+
auto current_slice = transfer_slice.slice(
508+
i, std::min(i + stream_copy_batch_size_, transfer_slice.size()));
509+
510+
auto success_cnt = worker_->prefetch_from_storage(current_slice);
511+
512+
if (success_cnt != current_slice.size() ||
513+
i + stream_copy_batch_size_ >= transfer_slice.size()) {
514+
is_completed = true;
515+
}
516+
517+
butil::IOBuf buf;
518+
buf.append(std::to_string(success_cnt));
519+
if (brpc::StreamWrite(*stream_id.get(), buf) != 0) {
520+
brpc::StreamClose(*stream_id.get());
521+
is_completed = false;
522+
break;
523+
}
524+
525+
if (is_completed) {
526+
if (success_cnt != 0) {
527+
butil::IOBuf buf_end;
528+
buf_end.append("0");
529+
brpc::StreamWrite(*stream_id.get(), buf_end);
530+
}
531+
break;
532+
}
533+
}
534+
if (is_completed) {
535+
close_future.wait();
536+
}
537+
brpc::StreamClose(*stream_id.get());
538+
});
539+
540+
resp->set_ok(true);
541+
return;
542+
}
543+
437544
void WorkerService::GetDeviceInfo(::google::protobuf::RpcController* controller,
438545
const proto::Empty* req,
439546
proto::DeviceInfo* resp,

0 commit comments

Comments
 (0)