Commit f48ede6

feat: enable multi-process mode when running VLM model. (#330)
1 parent 014d305 commit f48ede6

27 files changed: +473 additions, -220 deletions

examples/generate_vlm.py

File mode changed from 100644 to 100755
Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# python generate_vlm.py --model /path/to/Qwen2.5-VL-7B-Instruct/ --disable_prefix_cache --disable_chunked_prefill --max_seqs_per_batch 4
+# python generate_vlm.py --model /path/to/Qwen2.5-VL-7B-Instruct/ --disable_prefix_cache --disable_chunked_prefill --max_seqs_per_batch 4 --devices='npu:0' --enable_shm
 
 import os
 import signal

xllm/core/common/options.cpp

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@ std::string Options::to_string() const {
        << ", devices: " << devices().value_or("null")
        << ", draft_model_path: " << draft_model_path().value_or("null")
        << ", draft_devices: " << draft_devices().value_or("null")
-       << ",limit_image_per_prompt: " << limit_image_per_prompt()
+       << ", backend: " << backend()
+       << ", limit_image_per_prompt: " << limit_image_per_prompt()
        << ", block_size: " << block_size()
        << ", max_cache_size: " << max_cache_size()
        << ", max_memory_utilization: " << max_memory_utilization()

xllm/core/common/options.h

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,9 @@ class Options {
 
   PROPERTY(std::optional<std::string>, draft_devices);
 
+  // model backend
+  PROPERTY(std::string, backend);
+
   // max image num per prompt, default 4
   PROPERTY(int32_t, limit_image_per_prompt) = 4;
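For context, the new backend option is just another string property next to the existing ones, and it is now echoed by Options::to_string() (see the options.cpp hunk above). Below is a minimal stand-in sketch of that shape, assuming the PROPERTY macro expands to a getter plus a fluent setter; DemoOptions is illustrative only and not xllm's real Options class.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>

// Simplified stand-in for xllm's Options class; the real class declares these
// accessors via its PROPERTY macro.
class DemoOptions {
 public:
  DemoOptions& backend(std::string v) {
    backend_ = std::move(v);
    return *this;
  }
  const std::string& backend() const { return backend_; }

  std::string to_string() const {
    std::stringstream ss;
    // Mirrors the options.cpp change: the backend string is printed alongside
    // the existing limit_image_per_prompt field.
    ss << "backend: " << backend()
       << ", limit_image_per_prompt: " << limit_image_per_prompt_;
    return ss.str();
  }

 private:
  std::string backend_;
  int32_t limit_image_per_prompt_ = 4;
};

int main() {
  DemoOptions opts;
  opts.backend("vlm");
  std::cout << opts.to_string() << std::endl;  // backend: vlm, limit_image_per_prompt: 4
}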

xllm/core/distributed_runtime/dist_manager.cpp

Lines changed: 14 additions & 62 deletions
@@ -31,65 +31,10 @@ namespace xllm {
 
 DistManager::DistManager(const runtime::Options& options) {
   auto master_node_addr = options.master_node_addr().value_or("");
-  // Single-Node Worker Mode
-  if (master_node_addr.empty()) {
-    setup_single_node_workers(options);
-  } else {
-    // Multi-node Worker Mode
+  if (!master_node_addr.empty()) {
     setup_multi_node_workers(options, master_node_addr);
-  }
-}
-
-void DistManager::setup_single_node_workers(const runtime::Options& options) {
-  const auto& devices = options.devices();
-  CHECK_EQ((devices.size() % options.dp_size()), 0)
-      << "Device size must be divisible by dp size in single-node serving "
-         "mode.";
-
-  // initialize process groups if there are multiple devices
-  if (devices.size() > 1) {
-    // create a process group for each device if there are multiple gpus
-    process_groups_ = parallel_state::create_npu_process_groups(devices);
-  }
-
-  const int32_t dp_local_tp_size = devices.size() / options.dp_size();
-  if (options.dp_size() > 1 && options.dp_size() < devices.size()) {
-    dp_local_process_groups_.reserve(options.dp_size());
-    for (size_t dp_rank = 0; dp_rank < options.dp_size(); ++dp_rank) {
-      auto dp_local_group_device_begin_idx = devices.begin();
-      std::advance(dp_local_group_device_begin_idx, dp_rank * dp_local_tp_size);
-      auto dp_local_group_device_end_idx = devices.begin();
-      std::advance(dp_local_group_device_end_idx,
-                   (dp_rank + 1) * dp_local_tp_size);
-      std::vector<torch::Device> dp_local_group_devices;
-      std::copy(dp_local_group_device_begin_idx,
-                dp_local_group_device_end_idx,
-                std::back_inserter(dp_local_group_devices));
-      dp_local_process_groups_.emplace_back(
-          parallel_state::create_npu_process_groups(dp_local_group_devices));
-    }
-  }
-
-  // create a worker(as worker client also) for each device
-  const int32_t world_size = static_cast<int32_t>(devices.size());
-  WorkerType worker_type =
-      (options.task_type() == "generate") ? WorkerType::LLM : WorkerType::ELM;
-  for (size_t i = 0; i < devices.size(); ++i) {
-    const int32_t rank = static_cast<int32_t>(i);
-    ProcessGroup* pg = world_size > 1 ? process_groups_[i].get() : nullptr;
-    // dp local process groups
-    ProcessGroup* dp_local_pg =
-        (options.dp_size() > 1 && options.dp_size() < world_size)
-            ? (dp_local_process_groups_[i / dp_local_tp_size]
-                                       [i % dp_local_tp_size])
-                  .get()
-            : nullptr;
-    ParallelArgs parallel_args(
-        rank, world_size, pg, dp_local_pg, options.dp_size());
-    workers_.emplace_back(std::make_unique<Worker>(
-        parallel_args, devices[i], options, worker_type));
-    worker_clients_.emplace_back(
-        std::make_unique<WorkerClient>(workers_.back().get()));
+  } else {
+    LOG(FATAL) << "master_node_addr is empty.";
   }
 }
 
@@ -166,10 +111,17 @@ void DistManager::setup_multi_node_workers(
 
   runtime::Options worker_server_options = options;
   worker_server_options.world_size(world_size);
-
-  WorkerType worker_type =
-      (options.task_type() == "generate") ? WorkerType::LLM : WorkerType::ELM;
-
+  WorkerType worker_type("LLM");
+  const auto& model_backend = options.backend();
+  if (model_backend == "llm") {
+    worker_type =
+        (options.task_type() == "generate") ? WorkerType::LLM : WorkerType::ELM;
+  } else if (model_backend == "vlm") {
+    worker_type = (options.task_type() == "generate") ? WorkerType::VLM
+                                                      : WorkerType::EVLM;
+  } else {
+    LOG(ERROR) << "Unsupported " << model_backend << " in multi-node.";
+  }
   // create local workers
   for (size_t i = 0; i < devices.size(); ++i) {
     // worldsize = 8
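To make the dispatch added in the hunk above easier to follow, here is a small self-contained sketch of how the backend string and task type map to a worker type. The WorkerType enum and select_worker_type() helper below are simplified stand-ins for illustration (the real WorkerType is constructible from a string such as "LLM", and the real code logs an error rather than throwing); they are not xllm code.

#include <iostream>
#include <stdexcept>
#include <string>

enum class WorkerType { LLM, ELM, VLM, EVLM };  // simplified stand-in

// Mirrors the selection logic in setup_multi_node_workers(): an "llm" backend
// keeps the existing LLM/ELM split, a "vlm" backend maps to VLM/EVLM, and any
// other backend is rejected.
WorkerType select_worker_type(const std::string& backend,
                              const std::string& task_type) {
  const bool generate = (task_type == "generate");
  if (backend == "llm") {
    return generate ? WorkerType::LLM : WorkerType::ELM;
  }
  if (backend == "vlm") {
    return generate ? WorkerType::VLM : WorkerType::EVLM;
  }
  throw std::runtime_error("Unsupported backend in multi-node: " + backend);
}

int main() {
  auto type = select_worker_type("vlm", "generate");
  std::cout << (type == WorkerType::VLM ? "VLM" : "other") << "\n";  // VLM
}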

xllm/core/distributed_runtime/dist_manager.h

File mode changed from 100644 to 100755
Lines changed: 0 additions & 2 deletions
@@ -34,8 +34,6 @@ class DistManager {
 
  private:
   DISALLOW_COPY_AND_ASSIGN(DistManager);
-
-  void setup_single_node_workers(const runtime::Options& options);
   void setup_multi_node_workers(const runtime::Options& options,
                                 const std::string& master_node_addr);

xllm/core/distributed_runtime/worker_server.cpp

Lines changed: 5 additions & 7 deletions
@@ -36,6 +36,7 @@ limitations under the License.
 #include "framework/parallel_state/collective_communicator.h"
 #include "framework/parallel_state/mapping_npu.h"
 #include "framework/state_dict/state_dict.h"
+#include "runtime/forward_params.h"
 #include "runtime/worker.h"
 #include "server/xllm_server_registry.h"
 #include "util/net.h"
@@ -65,6 +66,7 @@ void WorkerServer::create_server(
     int32_t dp_size,
     int local_rank,
     int32_t ep_size,
+    WorkerType worker_type,
     std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager,
     std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager) {
   Device device(d);
@@ -106,11 +108,6 @@ void WorkerServer::create_server(
   comm.create_process_groups(master_node_addr, device);
 #endif
 
-  WorkerType worker_type =
-      (options.task_type() == "generate") ? WorkerType::LLM : WorkerType::ELM;
-  CHECK(worker_type == WorkerType::LLM || worker_type == WorkerType::ELM)
-      << "Multi Node only support LLM and ELM Now, but get task type = "
-      << options.task_type();
   std::unique_ptr<Worker> worker =
       std::make_unique<Worker>(*parallel_args, device, options, worker_type);
   worker_service->set_worker(std::move(worker));
@@ -216,8 +213,8 @@ WorkerServer::WorkerServer(int local_worker_idx,
                            const runtime::Options& options,
                            WorkerType worker_type,
                            bool use_spawn_worker) {
-  if (worker_type == WorkerType::LLM || worker_type == WorkerType::ELM) {
-    // TODO: Refactor these code later.
+  if (worker_type == WorkerType::LLM || worker_type == WorkerType::ELM ||
+      worker_type == WorkerType::VLM || worker_type == WorkerType::EVLM) {
     if (use_spawn_worker) {
       // start worker in a spawn process(for offline inference worker.)
       create_spawn_server(local_worker_idx,
@@ -251,6 +248,7 @@ WorkerServer::WorkerServer(int local_worker_idx,
                   parallel_args.dp_size(),
                   local_worker_idx,
                   parallel_args.ep_size(),
+                  worker_type,
                   std::move(input_shm_manager),
                   std::move(output_shm_manager));
 }
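The net effect of the worker_server.cpp changes is that create_server() no longer re-derives the worker type from options.task_type(); the caller decides it once and passes it down, and the constructor's guard now also admits VLM/EVLM workers. A simplified sketch of that flow follows; DemoWorker and the free create_server() function below are illustrative stand-ins, not the real WorkerServer API.

#include <iostream>
#include <memory>

enum class WorkerType { LLM, ELM, VLM, EVLM };  // simplified stand-in

// Mirrors the widened guard in WorkerServer's constructor: VLM and EVLM now
// take the same code path as LLM and ELM.
bool is_supported_worker(WorkerType t) {
  return t == WorkerType::LLM || t == WorkerType::ELM ||
         t == WorkerType::VLM || t == WorkerType::EVLM;
}

struct DemoWorker {
  explicit DemoWorker(WorkerType t) : type(t) {}
  WorkerType type;
};

// The worker type arrives as a parameter instead of being recomputed from the
// task type inside the server-creation path.
std::unique_ptr<DemoWorker> create_server(WorkerType worker_type) {
  return std::make_unique<DemoWorker>(worker_type);
}

int main() {
  const WorkerType type = WorkerType::EVLM;
  if (is_supported_worker(type)) {
    auto worker = create_server(type);
    std::cout << static_cast<int>(worker->type) << "\n";  // 3
  }
}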

xllm/core/distributed_runtime/worker_server.h

File mode changed from 100644 to 100755
Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ class WorkerServer {
       int32_t dp_size,
       int local_rank,
       int32_t ep_size,
+      WorkerType worker_type,
       std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager,
       std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager);

xllm/core/distributed_runtime/worker_service.cpp

File mode changed from 100644 to 100755
Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
 #include "worker_service.h"
 
 #include <brpc/closure_guard.h>

xllm/core/framework/batch/batch.cpp

File mode changed from 100644 to 100755
Lines changed: 4 additions & 3 deletions
@@ -52,8 +52,8 @@ void Batch::add(Sequence* sequence, uint32_t allowed_max_token) {
   input_embeddings_vec_.emplace_back(input_embedding);
 
   const auto& mm_data = sequence->get_mm_data();
-  // if (sequence->is_prefill_stage() && mm_data.valid()) // TODO:Compatible
-  // With Chunked Prefill
+  // if (sequence->is_prefill_stage() && mm_data.valid()) // TODO:Compatible
+  // With Chunked Prefill
   if ((sequence->kv_state().kv_cache_tokens_num() <
        sequence->num_prompt_tokens()) &&
       mm_data.valid())
@@ -83,6 +83,7 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens,
 
 RawForwardInput Batch::prepare_forward_input(uint32_t start_idx,
                                              uint32_t end_idx,
+                                             const ModelArgs& args,
                                              ThreadPool* thread_pool) {
   BatchInputBuilder builder(sequences_,
                             allowed_max_tokens_,
@@ -91,7 +92,7 @@ RawForwardInput Batch::prepare_forward_input(uint32_t start_idx,
                             copy_in_cache_block_infos_,
                             copy_out_cache_block_infos_,
                             swap_cache_block_infos_,
-                            nullptr,
+                            &args,
                             thread_pool);
   return builder.build_raw_forward_input(start_idx, end_idx);
 }

xllm/core/framework/batch/batch.h

File mode changed from 100644 to 100755
Lines changed: 2 additions & 1 deletion
@@ -77,7 +77,8 @@ class Batch {
   // Convert Batch to pb type, which will be pass to remote worker.
   RawForwardInput prepare_forward_input(uint32_t start_idx,
                                         uint32_t end_idx,
-                                        ThreadPool* thread_pool = nullptr);
+                                        const ModelArgs& args,
+                                        ThreadPool* thread_pool);
 
   // process output
   //
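After this change, callers of Batch::prepare_forward_input() must pass the ModelArgs explicitly, and the builder receives &args where it previously got a hard-coded nullptr (see the batch.cpp hunks above). A minimal stand-in sketch of the new call shape follows; all types below are simplified placeholders, not xllm's real ones.

#include <cstdint>
#include <iostream>
#include <string>

// Simplified placeholder types, for illustration only.
struct ModelArgs { std::string model_type; };
struct ThreadPool {};
struct RawForwardInput { const ModelArgs* args = nullptr; };

struct DemoBatch {
  // New-style signature: ModelArgs is a required parameter and no longer
  // defaults away; it is forwarded as &args instead of nullptr.
  RawForwardInput prepare_forward_input(uint32_t /*start_idx*/,
                                        uint32_t /*end_idx*/,
                                        const ModelArgs& args,
                                        ThreadPool* /*thread_pool*/) {
    RawForwardInput input;
    input.args = &args;
    return input;
  }
};

int main() {
  DemoBatch batch;
  ModelArgs args{"demo-vlm"};  // hypothetical model type, illustrative only
  ThreadPool pool;
  auto input = batch.prepare_forward_input(0, 8, args, &pool);
  std::cout << input.args->model_type << "\n";  // demo-vlm
}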
