Skip to content

Commit 23386ef

Browse files
authored
feat: support single-node with multi devices in offline inference. (#267)
Signed-off-by: pengtao.156 <pengtao.156@jd.com>
1 parent 6c0f3e5 commit 23386ef

File tree

21 files changed

+368
-25
lines changed

21 files changed

+368
-25
lines changed

examples/generate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
# python examples/generate.py --model='/path/models/Qwen2-7B-Instruct' --devices='npu:0'
1+
# python examples/generate.py --model='/path/models/Qwen2-7B-Instruct' --devices='npu:0'
2+
# python generate.py --model='/path/models/Qwen2-7B-Instruct' --devices='npu:0,npu:1'
23

34
from xllm import ArgumentParser, LLM, RequestParams
45

examples/generate_vlm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
# python examples/generate_vlm.py --model='/path/models/Qwen2.5-VL-7B' --devices='npu:0' --master_node_addr=127.0.0.1:8888
1+
# python examples/generate_vlm.py --model='/path/models/Qwen2.5-VL-7B' --devices='npu:0'
2+
# python generate_vlm.py --model='/path/models/Qwen2.5-VL-7B' --devices='npu:0,npu:1'
23

34
import os
45
import signal

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,8 @@ def apply_patch():
610610
},
611611
zip_safe=False,
612612
py_modules=["xllm/launch_xllm", "xllm/__init__",
613-
"xllm/pybind/llm", "xllm/pybind/vlm", "xllm/pybind/args"],
613+
"xllm/pybind/llm", "xllm/pybind/vlm",
614+
"xllm/pybind/util", "xllm/pybind/args"],
614615
entry_points={
615616
'console_scripts': [
616617
'xllm = xllm.launch_xllm:launch_xllm'

xllm/core/common/options.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ class Options {
170170
PROPERTY(int, max_requests_per_batch) = 0;
171171

172172
PROPERTY(bool, enable_continuous_kvcache) = false;
173+
174+
// for offline inference: start with offline inference, default is false
175+
PROPERTY(bool, enable_offline_inference) = false;
176+
// for offline inference: the path to spawn worker binary
177+
PROPERTY(std::string, spawn_worker_path) = "";
173178
};
174179

175180
} // namespace xllm

xllm/core/distributed_runtime/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ if(USE_NPU)
44
include_directories(
55
${CMAKE_SOURCE_DIR}/third_party/spdlog/include
66
)
7+
8+
add_subdirectory(spawn_worker_server)
79
endif()
810

911
cc_library(

xllm/core/distributed_runtime/dist_manager.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ void DistManager::setup_multi_node_workers(
141141
// Node2: 0+4, 1+4, 2+4, 3+4
142142
const int32_t rank = static_cast<int32_t>(i) + base_rank;
143143

144+
// we use spawn process worker to launch a xllm instance
145+
// when starting an offline inference task with multi-gpu/npu/mpu/...
146+
bool use_spawn_worker = options.enable_offline_inference() && i > 0;
144147
ParallelArgs parallel_args(rank, world_size, dp_size, nullptr, ep_size);
145148
servers_.emplace_back(std::make_unique<WorkerServer>(i,
146149
master_node_addr,
@@ -149,7 +152,8 @@ void DistManager::setup_multi_node_workers(
149152
parallel_args,
150153
devices[i],
151154
worker_server_options,
152-
worker_type));
155+
worker_type,
156+
use_spawn_worker));
153157
}
154158

155159
// Master node need to wait all workers done
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
include(cc_binary)

# Executable target `spawn_worker`, built from the spawn worker server
# sources in this directory.
cc_binary(
  NAME
    spawn_worker
  HDRS
    spawn_worker_server.h
  SRCS
    spawn_worker_server.cpp
    spawn_worker_server_process.cpp
  DEPS
    :models
    :model
    :distributed_runtime
    absl::strings
    xllm_kernels
    ascendcl
    nnopbase
    atb
    c_sec
    spdlog::spdlog
)

# Make the export_module target depend on spawn_worker so the binary is
# built before export_module.
add_dependencies(export_module spawn_worker)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "spawn_worker_server.h"
17+
18+
#include <absl/strings/str_split.h>
19+
#if defined(USE_NPU)
20+
#include <acl/acl.h>
21+
#endif
22+
#include <signal.h>
23+
#include <sys/prctl.h>
24+
25+
#include <cstdlib>
26+
27+
#include "core/distributed_runtime/worker_server.h"
28+
#include "core/platform/device.h"
29+
#include "core/runtime/options.h"
30+
31+
namespace xllm {
32+
33+
bool xllm::SpawnWorkerServer::g_running_ = true;
34+
35+
// Sets up the worker side of a spawned offline-inference process.
//
// The constructor fills a runtime Options object, mirrors a subset of the
// settings into the process-wide gflags, binds the NPU device (USE_NPU
// builds) and finally constructs a WorkerServer for the given ranks.
//
// @param master_node_addr     address forwarded to options, flags and server
// @param local_rank           first argument handed to WorkerServer
// @param global_rank          rank passed to ParallelArgs
// @param world_size           world size passed to ParallelArgs
// @param device_idx           index used to build the "npu:<idx>" device name
// @param num_decoding_tokens  stored in the runtime options
// @param block_size           stored in the runtime options and FLAGS_block_size
//
// NOTE(review): the stop flag, the device and the WorkerServer are all
// function locals, so they are destroyed when this constructor returns.
// Presumably WorkerServer's destructor blocks until the server finishes;
// confirm against worker_server.h.
SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr,
                                     int local_rank,
                                     int global_rank,
                                     int world_size,
                                     int device_idx,
                                     int num_decoding_tokens,
                                     int block_size) {
  // TODO: pass whole xllm::runtime::Options here from main process.
  xllm::runtime::Options worker_options;
  worker_options.master_node_addr(master_node_addr)
      .enable_offline_inference(true)
      .enable_schedule_overlap(false)
      .num_decoding_tokens(num_decoding_tokens)
      .block_size(block_size);

  // Keep the global flags in sync with the values received on the command
  // line of the spawned process.
  FLAGS_enable_schedule_overlap = false;
  FLAGS_master_node_addr = master_node_addr;
  FLAGS_block_size = block_size;

  std::atomic<bool> finished(false);
#if defined(USE_NPU)
  xllm::Device device("npu:" + std::to_string(device_idx));
  device.set_device();
  device.init_device_context();
  FLAGS_enable_atb_comm_multiprocess = true;
#endif

  // dp_size and ep_size are fixed to 1; no process group is supplied.
  ParallelArgs rank_layout(global_rank, world_size, 1, nullptr, 1);
  WorkerServer worker_server(local_rank,
                             master_node_addr,
                             finished,
                             rank_layout,
                             device,
                             worker_options,
                             WorkerType::LLM,
                             /*use_spawn_worker=*/false);
}
71+
72+
// Signal callback registered in run(): clears the running flag so that the
// wait loop in run() can end. The signal number itself is not inspected.
void SpawnWorkerServer::handle_signal(int /*signum*/) {
  SpawnWorkerServer::g_running_ = false;
}
73+
74+
// Registers handle_signal for SIGINT and SIGTERM, then keeps the calling
// thread parked, re-checking the running flag after every 5-second sleep
// until the flag has been cleared.
void SpawnWorkerServer::run() {
  signal(SIGINT, &SpawnWorkerServer::handle_signal);
  signal(SIGTERM, &SpawnWorkerServer::handle_signal);

  // main thread waiting here
  for (;;) {
    if (!SpawnWorkerServer::g_running_) {
      break;
    }
    sleep(5);
  }
}
83+
84+
} // namespace xllm
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
18+
#include <string>
19+
20+
namespace xllm {
21+
22+
// Worker-side wrapper used by the spawned worker process of offline
// inference: the constructor receives the launch parameters parsed from the
// command line, and run() keeps the process alive until a signal arrives.
class SpawnWorkerServer final {
 public:
  // All parameters are the values passed on the worker's command line:
  // master address, local/global rank, world size, device index, number of
  // decoding tokens and block size.
  explicit SpawnWorkerServer(const std::string& master_node_addr,
                             int local_rank,
                             int global_rank,
                             int world_size,
                             int device_idx,
                             int num_decoding_tokens,
                             int block_size);

  ~SpawnWorkerServer() = default;

  // Waits on the calling thread while g_running_ stays true.
  void run();

  // Signal handler; sets g_running_ to false.
  static void handle_signal(int signum);

  // Process-wide "keep running" flag shared by run() and handle_signal().
  static bool g_running_;
};
40+
41+
} // namespace xllm
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <gflags/gflags.h>
17+
#include <glog/logging.h>
18+
#include <signal.h>
19+
#include <sys/prctl.h>
20+
21+
#include "spawn_worker_server.h"
22+
23+
// Worker argv from engine process:
24+
// @master_node_addr
25+
// @local_rank
26+
// @global_rank
27+
// @world_size
28+
// @device_idx
29+
// @num_decoding_tokens
30+
// @block_size
31+
int main(int argc, char* argv[]) {
32+
if (argc < 7) {
33+
LOG(ERROR)
34+
<< "Spwan worker process receive wrong args. Need 7 args, receive "
35+
<< argc;
36+
return 1;
37+
}
38+
39+
// set PR_SET_PDEATHSIG flag that child should exit
40+
// when parent process exit
41+
if (prctl(PR_SET_PDEATHSIG, SIGHUP) == -1) {
42+
perror("prctl");
43+
return EXIT_FAILURE;
44+
}
45+
46+
std::string master_node_addr = std::string(argv[1]);
47+
int local_rank = atoi(argv[2]);
48+
int global_rank = atoi(argv[3]);
49+
int world_size = atoi(argv[4]);
50+
int device_idx = atoi(argv[5]);
51+
int num_decoding_tokens = atoi(argv[6]);
52+
int block_size = atoi(argv[7]);
53+
54+
LOG(INFO) << "Spwan worker: "
55+
<< "master_node_addr = " << master_node_addr
56+
<< ", local_rank = " << local_rank
57+
<< ", world_size = " << world_size
58+
<< ", device_idx = " << device_idx
59+
<< ", num_decoding_tokens = " << num_decoding_tokens
60+
<< ", block_size = " << block_size << "\n";
61+
62+
xllm::SpawnWorkerServer worker(master_node_addr,
63+
local_rank,
64+
global_rank,
65+
world_size,
66+
device_idx,
67+
num_decoding_tokens,
68+
block_size);
69+
70+
worker.run();
71+
72+
return 0;
73+
}

0 commit comments

Comments
 (0)