Commit b76a93a

feat(scheduler): Launch TP>0 as subprocesses & fit scheduler (#222)
1 parent d57d7a9 commit b76a93a

13 files changed: +94 −29 lines

docker/Dockerfile.blackwell

Lines changed: 2 additions & 3 deletions

@@ -1,6 +1,6 @@
-FROM lmsysorg/sglang:v0.5.3rc1
+FROM lmsysorg/sglang:v0.5.4.post1
 
-ENV SGL_ENABLE_JIT_DEEPGEMM=0
+ENV SGLANG_ENABLE_JIT_DEEPGEMM=0
 
 WORKDIR /parallax
 
@@ -9,4 +9,3 @@ COPY src ./src
 COPY pyproject.toml ./pyproject.toml
 
 RUN pip install -e '.[gpu]'
-RUN pip install https://github.com/sgl-project/whl/releases/download/v0.3.7/sgl_kernel-0.3.7+cu128-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall

docker/Dockerfile.hopper

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
-FROM lmsysorg/sglang:v0.5.3rc1
+FROM lmsysorg/sglang:v0.5.4.post1
 
-ENV SGL_ENABLE_JIT_DEEPGEMM=0
+ENV SGLANG_ENABLE_JIT_DEEPGEMM=0
 
 WORKDIR /parallax
 

src/backend/server/rpc_connection_handler.py

Lines changed: 3 additions & 0 deletions

@@ -153,6 +153,7 @@ def get_layer_allocation(self, current_node_id):
                 ),
                 "start_layer": start_layer,
                 "end_layer": end_layer,
+                "tp_size": node.hardware.num_gpus,
             }
         return {}

@@ -182,13 +183,15 @@ def build_node(self, node_json: dict):
 
     def build_hardware(self, hardware_json):
         node_id = hardware_json.get("node_id")
+        num_gpus = hardware_json.get("num_gpus")
         tflops_fp16 = hardware_json.get("tflops_fp16")
         gpu_name = hardware_json.get("gpu_name")
         memory_gb = hardware_json.get("memory_gb")
         memory_bandwidth_gbps = hardware_json.get("memory_bandwidth_gbps")
         device = hardware_json.get("device")
         return NodeHardwareInfo(
             node_id=node_id,
+            num_gpus=num_gpus,
             tflops_fp16=tflops_fp16,
             gpu_name=gpu_name,
             memory_gb=memory_gb,
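To make the new field concrete, here is a minimal, self-contained sketch of the hardware-JSON round trip. The NodeHardwareInfo below is a local stand-in dataclass, not an import from the repo, and the device field is omitted for brevity:

# A minimal sketch of the num_gpus round trip; NodeHardwareInfo here is a
# local stand-in, not the repo's class.
from dataclasses import dataclass


@dataclass
class NodeHardwareInfo:
    node_id: str
    num_gpus: int
    tflops_fp16: float
    gpu_name: str
    memory_gb: float
    memory_bandwidth_gbps: float


def build_hardware(hardware_json: dict) -> NodeHardwareInfo:
    # Mirrors the handler above: each field, including the new num_gpus,
    # is read off the node's JSON payload.
    return NodeHardwareInfo(
        node_id=hardware_json.get("node_id"),
        num_gpus=hardware_json.get("num_gpus"),
        tflops_fp16=hardware_json.get("tflops_fp16"),
        gpu_name=hardware_json.get("gpu_name"),
        memory_gb=hardware_json.get("memory_gb"),
        memory_bandwidth_gbps=hardware_json.get("memory_bandwidth_gbps"),
    )


hw = build_hardware({
    "node_id": "node-0",
    "num_gpus": 2,
    "tflops_fp16": 180.0,
    "gpu_name": "H100",
    "memory_gb": 80.0,
    "memory_bandwidth_gbps": 2039.0,
})
assert hw.num_gpus == 2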

src/backend/server/scheduler_manage.py

Lines changed: 1 addition & 0 deletions

@@ -121,6 +121,7 @@ def build_node_info(self, node):
         return {
             "node_id": node.node_id,
             "status": NODE_STATUS_AVAILABLE if node.is_active else NODE_STATUS_WAITING,
+            "gpu_num": node.hardware.num_gpus,
             "gpu_name": node.hardware.gpu_name,
             "gpu_memory": node.hardware.memory_gb,
         }

src/parallax/cli.py

Lines changed: 1 addition & 1 deletion

@@ -226,7 +226,7 @@ def join_command(args, passthrough_args: list[str] | None = None):
 
     # Set environment variable for the subprocess
     env = os.environ.copy()
-    env["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
+    env["SGLANG_ENABLE_JIT_DEEPGEMM"] = "0"
 
     # Build the command to run the launch.py script
    passthrough_args = passthrough_args or []
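The rename from SGL_ENABLE_JIT_DEEPGEMM to SGLANG_ENABLE_JIT_DEEPGEMM tracks the variable name used by the newer sglang base image pinned in the Dockerfiles above. A small sketch of the env-propagation pattern join_command uses, with an illustrative command in place of the real launch invocation:

# Sketch of the pattern above: copy the parent environment, set the renamed
# flag, and hand the env to the child process. The command is illustrative,
# not the exact one join_command builds.
import os
import subprocess

env = os.environ.copy()
env["SGLANG_ENABLE_JIT_DEEPGEMM"] = "0"  # renamed from SGL_ENABLE_JIT_DEEPGEMM

subprocess.run(
    ["python", "-c", "import os; print(os.environ['SGLANG_ENABLE_JIT_DEEPGEMM'])"],
    env=env,
    check=True,
)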

src/parallax/launch.py

Lines changed: 43 additions & 13 deletions

@@ -2,7 +2,11 @@
 Launch the Parallax server.
 
 This script is used to launch the Parallax server.
-It will start the P2P server and the executor.
+It will start the following services:
+1. Executor with tp_rank=0 in the main process.
+2. Executor with tp_rank>0, each tp_rank as a subprocess.
+3. HTTP server as a subprocess.
+4. P2P server as a thread in the main process.
 
 Example command:
 python src/parallax/launch.py \
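The new docstring summarizes the launch layout this commit introduces: rank 0 stays in the main process, ranks 1..tp_size-1 become subprocesses. A minimal, runnable sketch of that layout, where run_executor stands in for the repo's run_executor_process / Executor.run_loop machinery:

# Self-contained sketch of the tp_rank layout described above.
import argparse
import multiprocessing


def run_executor(args: argparse.Namespace) -> None:
    # Stand-in for the real executor loop.
    print(f"executor running with tp_rank={args.tp_rank}")


if __name__ == "__main__":
    args = argparse.Namespace(tp_size=4, tp_rank=0)

    procs = []
    for tp_rank in range(1, args.tp_size):  # ranks 1..tp_size-1 -> subprocesses
        args_copy = argparse.Namespace(**vars(args))
        args_copy.tp_rank = tp_rank
        proc = multiprocessing.Process(target=run_executor, args=(args_copy,))
        proc.start()
        procs.append(proc)

    args.tp_rank = 0  # rank 0 runs in the main process
    run_executor(args)

    for proc in procs:
        proc.join()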
@@ -41,7 +45,7 @@
 gradient_server = None
 http_server_process = None
 executor = None
-executor_procs = []
+executor_subprocs = []
 try:
     args = parse_args()
     set_log_level(args.log_level)

@@ -75,6 +79,7 @@
             pp_start_layer=args.start_layer,
             pp_end_layer=args.end_layer,
             hidden_layers=config.get("num_hidden_layers"),
+            tp_size=args.tp_size,
             tcp_port=args.tcp_port,
             udp_port=args.udp_port,
             dht_prefix=args.dht_prefix,

@@ -91,18 +96,21 @@
         )
         if gradient_server is not None:
             gradient_server.status = ServerState.READY
-        tp_rank_range = range(args.tp_size)
-        for tp_rank in tp_rank_range:
+
+        # For each tp_rank > 0, create a subprocess and run executor
+        for tp_rank in range(1, args.tp_size):
             args_copy = argparse.Namespace(**vars(args))
             args_copy.tp_rank = tp_rank
             proc = multiprocessing.Process(
                 target=run_executor_process,
                 args=(args_copy,),
             )
             proc.start()
-            executor_procs.append(proc)
-        for executor_process in executor_procs:
-            executor_process.join()
+            executor_subprocs.append(proc)
+        # Launch executor with tp_rank=0 in the main process
+        args.tp_rank = 0
+        executor = Executor.create_from_args(args)
+        executor.run_loop()
     else:
         gradient_server = launch_p2p_server(
             initial_peers=args.initial_peers,
@@ -111,6 +119,7 @@
             pp_start_layer=args.start_layer,
             pp_end_layer=args.end_layer,
             hidden_layers=None,
+            tp_size=args.tp_size,
             tcp_port=args.tcp_port,
             udp_port=args.udp_port,
             dht_prefix=args.dht_prefix,

@@ -128,9 +137,7 @@
         args.start_layer = gradient_server.block_start_index
         args.end_layer = gradient_server.block_end_index
         args.model_path = gradient_server.model_name
-        # TODO: Implement inter-process communication to enable TP.
-        # For scheduler mode, currently only support tp_rank=0
-        args.tp_rank = 0
+        args.tp_size = gradient_server.tp_size
 
         logger.debug(
             f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}"
@@ -148,6 +155,18 @@
     # Main execution loop with layer reallocation support
     while True:
         try:
+            # For each tp_rank > 0, create a subprocess and run executor
+            for tp_rank in range(1, args.tp_size):
+                args_copy = argparse.Namespace(**vars(args))
+                args_copy.tp_rank = tp_rank
+                proc = multiprocessing.Process(
+                    target=run_executor_process,
+                    args=(args_copy,),
+                )
+                proc.start()
+                executor_subprocs.append(proc)
+            # Launch executor with tp_rank=0 in the main process
+            args.tp_rank = 0
             executor = Executor.create_from_args(args, gradient_server=gradient_server)
             if gradient_server is not None:
                 gradient_server.status = ServerState.READY

@@ -159,7 +178,18 @@
                     logger.warning(
                         "Layer allocation changed! Reloading executor with new layers..."
                     )
+
+                    # shutdown all executor processes
+                    thread_pool = []
+                    for executor_process in executor_subprocs:
+                        t = threading.Thread(
+                            target=stop_executor_process, args=(executor_process,)
+                        )
+                        t.start()
+                        thread_pool.append(t)
                     executor.shutdown()
+                    for t in thread_pool:
+                        t.join()
 
                     if args.start_layer == 0:
                         http_server_process = stop_http_server(http_server_process)
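On reallocation, every TP subprocess must be stopped before the rank-0 executor reloads, so the diff fans the stops out across threads and joins them afterwards. A self-contained sketch of that fan-out, where stop_executor_process is a stand-in that terminates and joins a process:

# Sketch of the shutdown fan-out above: each subprocess is stopped from its
# own thread so the stops overlap, then the threads are joined.
import multiprocessing
import threading
import time


def stop_executor_process(proc: multiprocessing.Process) -> None:
    proc.terminate()
    proc.join()


def worker() -> None:
    time.sleep(60)  # stands in for a long-running executor


if __name__ == "__main__":
    executor_subprocs = [multiprocessing.Process(target=worker) for _ in range(3)]
    for p in executor_subprocs:
        p.start()

    thread_pool = []
    for p in executor_subprocs:
        t = threading.Thread(target=stop_executor_process, args=(p,))
        t.start()
        thread_pool.append(t)

    # the main-process executor.shutdown() would happen here
    for t in thread_pool:
        t.join()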
@@ -210,13 +240,13 @@
     if gradient_server is not None:
         gradient_server.shutdown()
 
-    # Shutdown executor subprocess for scheduler mode
-    for executor_process in executor_procs:
+    # Shutdown executor subprocesses
+    for executor_process in executor_subprocs:
         t = threading.Thread(target=stop_executor_process, args=(executor_process,))
         t.start()
         thread_pool.append(t)
 
-    # Shutdown executor main process for non-scheduler mode
+    # Shutdown executor main process
     if executor is not None:
         executor.shutdown()
 

src/parallax/p2p/server.py

Lines changed: 5 additions & 0 deletions

@@ -201,6 +201,7 @@ def __init__(
         block_start_index: int = 0,
         block_end_index: int = 1,
         hidden_layers: int = 128,
+        tp_size: int = 1,
         dht_prefix: str = "gradient",
         host_maddrs: List[str] = [],
         http_port: Optional[int] = None,

@@ -220,6 +221,7 @@
         self.block_start_index = block_start_index
         self.block_end_index = block_end_index
         self.hidden_layers = hidden_layers
+        self.tp_size = tp_size
         self.dht_prefix = dht_prefix
         self.host_maddrs = host_maddrs
         self.announce_maddrs = announce_maddrs

@@ -346,6 +348,7 @@ def run(self):
                 self.block_start_index = response.get("start_layer")
                 self.block_end_index = response.get("end_layer")
                 self.model_name = response.get("model_name")
+                self.tp_size = response.get("tp_size")
 
                 # Publish executor metrics to backend on each update
                 def _publish_metrics(_snapshot):

@@ -738,6 +741,7 @@ def launch_p2p_server(
     pp_start_layer: int,
     pp_end_layer: int,
     hidden_layers: int,
+    tp_size: int,
     tcp_port: int,
     udp_port: int,
     dht_prefix: str,

@@ -761,6 +765,7 @@
         block_start_index=pp_start_layer,
         block_end_index=pp_end_layer,
         hidden_layers=hidden_layers,
+        tp_size=tp_size,
         dht_prefix=dht_prefix,
         host_maddrs=[f"/ip4/0.0.0.0/tcp/{tcp_port}", f"/ip4/0.0.0.0/udp/{udp_port}/quic-v1"],
         announce_maddrs=announce_maddrs,
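tp_size now rides the same path as the layer bounds: it is accepted by launch_p2p_server, stored at construction time, and overwritten by the scheduler's allocation response in run(). A stand-in class (not the repo's GradientServer) showing just that flow:

# Minimal sketch of the tp_size plumbing above.
class P2PServerSketch:
    def __init__(self, tp_size: int = 1):
        self.tp_size = tp_size

    def apply_allocation(self, response: dict) -> None:
        # Mirrors run(): the scheduler's answer wins over the constructor value.
        self.tp_size = response.get("tp_size")


server = P2PServerSketch(tp_size=1)
server.apply_allocation({"start_layer": 0, "end_layer": 16, "tp_size": 2})
assert server.tp_size == 2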

src/parallax/server/server_info.py

Lines changed: 8 additions & 1 deletion

@@ -32,6 +32,7 @@ class HardwareInfo:
     total_ram_gb: float
     chip: str
     tflops_fp16: float
+    num_gpus: int
 
     def dumps(self) -> Dict[str, Any]:
         """Serializes the HardwareInfo object to a dictionary."""

@@ -99,7 +100,7 @@ def detect(cls) -> "AppleSiliconHardwareInfo":
                 "Please add it to the _APPLE_PEAK_FP16 dictionary."
             ) from e
 
-        return cls(total_ram_gb=round(total_gb, 1), chip=chip, tflops_fp16=flops)
+        return cls(num_gpus=1, total_ram_gb=round(total_gb, 1), chip=chip, tflops_fp16=flops)
 
 
 @dataclass

@@ -143,6 +144,7 @@ def detect(cls) -> "NvidiaHardwareInfo":
         if torch is None or not torch.cuda.is_available():
             raise RuntimeError("CUDA not available; cannot detect NVIDIA hardware")
 
+        device_count = torch.cuda.device_count()
         device_index = torch.cuda.current_device()
         props = torch.cuda.get_device_properties(device_index)
         name = getattr(props, "name", f"cuda:{device_index}")

@@ -156,6 +158,7 @@
 
         spec = cls._match_gpu_specs(name, total_vram_gb)
         return cls(
+            num_gpus=device_count,
             total_ram_gb=round(total_gb, 1),
             chip=name,
             tflops_fp16=float(spec["tflops_fp16"]),

@@ -179,6 +182,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
         # Fallback to a conservative default
         return {
             "node_id": node_id,
+            "num_gpus": 1,
             "tflops_fp16": 50.0,
             "gpu_name": "Unknown",
             "memory_gb": 16.0,

@@ -189,6 +193,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
     if isinstance(hw, NvidiaHardwareInfo):
         return {
             "node_id": node_id,
+            "num_gpus": hw.num_gpus,
             "tflops_fp16": hw.tflops_fp16,
             "gpu_name": hw.chip,
             "memory_gb": hw.vram_gb,

@@ -200,6 +205,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
         est_bandwidth = 100.0
         return {
             "node_id": node_id,
+            "num_gpus": hw.num_gpus,
             "tflops_fp16": hw.tflops_fp16,
             "gpu_name": hw.chip,
             "memory_gb": hw.total_ram_gb,

@@ -209,6 +215,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
     # Generic fallback
     return {
         "node_id": node_id,
+        "num_gpus": hw.num_gpus,
         "tflops_fp16": hw.tflops_fp16,
         "gpu_name": "Unknown",
         "memory_gb": 16.0,

src/scheduling/README.md

Lines changed: 2 additions & 2 deletions

@@ -117,13 +117,13 @@ model = ModelInfo( # instantiate with your model's parameters
 
 n0 = Node(
     node_id="node-0",
-    hardware=NodeHardwareInfo(node_id="node-0", tflops_fp16=180.0, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
+    hardware=NodeHardwareInfo(node_id="node-0", tflops_fp16=180.0, num_gpus=1, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
     model_info=model,
 )
 
 n1 = Node(
     node_id="node-1",
-    hardware=NodeHardwareInfo(node_id="node-1", tflops_fp16=180.0, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
+    hardware=NodeHardwareInfo(node_id="node-1", tflops_fp16=180.0, num_gpus=1, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
     model_info=model,
 )
 
src/scheduling/layer_allocation.py

Lines changed: 3 additions & 1 deletion

@@ -287,7 +287,9 @@ def should_global_rebalance(self) -> bool:
         if len(layer_heap) < 2:
             return False
 
-        total_cluster_memory = sum(node.hardware.memory_gb for node in self.nodes)
+        total_cluster_memory = sum(
+            (node.hardware.num_gpus * node.hardware.memory_gb) for node in self.nodes
+        )
 
         if total_cluster_memory == 0:
             raise ValueError("Total cluster memory is zero")
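The effect of the rebalance change in numbers: total cluster memory now multiplies per-GPU memory by each node's GPU count instead of counting one GPU per node. A tiny worked example with plain tuples in place of the repo's Node objects:

# Worked sketch of the rebalance-threshold change.
nodes = [(2, 80.0), (1, 24.0)]  # (num_gpus, memory_gb per GPU)

total_cluster_memory = sum(num_gpus * memory_gb for num_gpus, memory_gb in nodes)
assert total_cluster_memory == 184.0  # previously 80.0 + 24.0 = 104.0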

0 commit comments

Comments
 (0)