
Commit fdc82a9

feat(scheduler): warmup before global rebalance (#204)
Co-authored-by: christ-tt
1 parent adc4fc1 commit fdc82a9

File tree

4 files changed, +188 -26 lines changed

src/scheduling/layer_allocation.py

Lines changed: 5 additions & 0 deletions
@@ -207,6 +207,11 @@ def deallocate(self, node: Node) -> None:
         node.is_active = False
         self._update_layer_loads_heap()
 
+    def reallocate(self, node: Node, start_layer: int, end_layer: int) -> None:
+        """Reallocate a node to a specific layer range."""
+        self.deallocate(node)
+        self.allocate(node, start_layer, end_layer)
+
     def declare(self, node: Node) -> None:
         """Declare a node to the allocator."""
         if node.node_id not in self.node_id_to_node:
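
The new reallocate helper keeps the allocator's bookkeeping (node_allocation, layer_to_load) consistent by going through deallocate and allocate rather than mutating the node directly. A minimal usage sketch; the allocator variable and the concrete layer numbers are illustrative, not part of the commit:

    # Hypothetical usage: shrink node "n1" from layers [0, 22) to [0, 16).
    node = allocator.node_id_to_node["n1"]          # node was previously declared/allocated
    start, end = node.start_layer, node.end_layer   # e.g. (0, 22)
    cut = 16
    if cut < end:
        # Internally this is deallocate(node) followed by allocate(node, start, cut),
        # so layer loads and the allocation map stay in sync.
        allocator.reallocate(node, start, cut)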

src/scheduling/request_routing.py

Lines changed: 7 additions & 1 deletion
@@ -52,11 +52,17 @@ class DynamicProgrammingRouting(RequestRoutingStrategy):
     minimum-latency node sequence and total latency.
     """
 
-    def find_turning_points(self, nodes: List[Node], num_layers: int) -> List[Tuple[str, int, str]]:
+    @staticmethod
+    def find_turning_points(nodes: List[Node], num_layers: int) -> List[Tuple[str, int, str]]:
         """Find shard truncation points via layer-level DP.
 
         DP state is (layer l, node i that hosts l). Node cost uses the node's
         per-layer latency proxy; edge cost uses RTT between nodes.
+
+        This is a static method that can be called directly without creating an instance:
+            DynamicProgrammingRouting.find_turning_points(nodes, num_layers)
+
+        It can also be called via an instance, which will work due to Python's method resolution.
         """
         if num_layers <= 0 or not nodes:
             return []
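
Because find_turning_points is now a @staticmethod, the scheduler (and tests) can run the layer-level DP without instantiating a router. A short sketch, assuming nodes is a populated List[Node] and a 28-layer model; the print loop is illustrative only:

    # Call the DP analysis directly on the class; no DynamicProgrammingRouting instance needed.
    turns = DynamicProgrammingRouting.find_turning_points(nodes, num_layers=28)
    for node_id, layer_idx, kind in turns:
        # kind is "tail" (drop [layer_idx, end)) or "head" (drop [start, layer_idx)).
        print(f"cut {kind} of {node_id} at layer {layer_idx}")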

src/scheduling/scheduler.py

Lines changed: 70 additions & 25 deletions
@@ -116,72 +116,116 @@ def __init__(
         pass
 
     # Orchestration helpers
-    def bootstrap(self) -> bool:
-        """Bootstrapping: first-time layer allocation and optional warm-up.
+    def bootstrap(self, *, clear_existing: bool = False, skip_warmup: bool = False) -> bool:
+        """Bootstrapping:
+        This method can be used for both initial bootstrapping and global rebalancing.
+        When clear_existing=True, it first deallocates all existing allocations before
+        performing global allocation (rebalancing behavior). When clear_existing=False,
+        it performs allocation on top of existing state (initial bootstrapping behavior).
 
-        Returns True if a full pipeline was established; False otherwise.
+        Args:
+            clear_existing: If True, deallocate all existing allocations before reallocating.
+                This is used for global rebalancing. Default is False.
+            skip_warmup: If True, skip the warm-up and truncate step. Default is False.
+
+        Returns:
+            True if a full pipeline was established; False otherwise.
         """
-        if len(self.nodes) < self.min_nodes_bootstrapping:
+        # Check node count only for initial bootstrapping (not rebalancing)
+        if not clear_existing and len(self.nodes) < self.min_nodes_bootstrapping:
            logger.debug(
                f"Bootstrapping deferred: have {len(self.nodes)} nodes; need >= {self.min_nodes_bootstrapping}"
            )
            return False
-        logger.debug("Bootstrapping layer allocator")
+
+        # Clear existing allocations if this is a rebalance
+        if clear_existing:
+            logger.debug("Performing global rebalance (clearing existing allocations)")
+            self._bootstrapped = False
+            self._bootstrapped_event.clear()
+            for n in self.nodes:
+                if n.start_layer is not None and n.end_layer is not None:
+                    self.layer_allocator.deallocate(n)
+        else:
+            logger.debug("Bootstrapping layer allocator")
+
+        # Perform global allocation
         success = self.layer_allocator.global_allocation()
         if not success:
-            logger.warning("Bootstrapping failed to produce a full pipeline")
+            logger.warning("Global allocation failed to produce a full pipeline")
            return False
+
        assignments = self.list_node_allocations()
        logger.debug(f"Layer allocator assignments: {assignments}")
+
        # Optional warm-up to find turning points and truncate node ranges
-        if self.request_warm_up_for_reshard > 0:
+        # Skip warmup for rebalancing scenarios (can be overridden with skip_warmup=False)
+        if not skip_warmup and self.request_warm_up_for_reshard > 0:
            self._run_warmup_and_truncate()
            assignments = self.list_node_allocations()
            logger.debug(f"Layer allocator assignments after turn-point warm-up: {assignments}")
 
        if not self.layer_allocator.has_full_pipeline():
            logger.warning("Bootstrapping failed to produce a full pipeline")
            return False
+
        self._bootstrapped = True
        self._bootstrapped_event.set()
-        logger.debug("Bootstrapping completed successfully; full pipeline established")
+        action = "rebalance" if clear_existing else "bootstrapping"
+        logger.debug(f"{action.capitalize()} completed successfully; full pipeline established")
        return True
 
    def list_node_allocations(self) -> List[Tuple[str, int, int]]:
        """List the allocations of all nodes."""
        return self.layer_allocator.list_node_allocations()
 
    # Warm-up and re-shard
-    def _run_warmup_and_truncate(self) -> None:
+    def _run_warmup_and_truncate(self, override_warmup_count: int = 0) -> None:
        """Run a brief warm-up to detect truncation points and shrink shards.
 
        Uses layer-level DP turning points (node_id, layer_idx, kind):
        - kind == "tail": drop [layer_idx, end) on that node
        - kind == "head": drop [start, layer_idx) on that node
+
+        Note: Always uses DynamicProgrammingRouting for finding turning points,
+        regardless of the current request_router type, since turning points
+        detection requires layer-level DP analysis.
+
+        Args:
+            override_warmup_count: If > 0, use this value instead of request_warm_up_for_reshard.
+                Default is 0, which means use request_warm_up_for_reshard.
        """
        nodes_list = list(self.nodes)
        if not nodes_list:
            return
        num_layers = self.model_info.num_layers
+
        # The number of warm-up requests can be used to repeat detection, but a
        # single pass is sufficient with our DP model; we repeat to smooth noise.
+        warmup_count = (
+            override_warmup_count if override_warmup_count > 0 else self.request_warm_up_for_reshard
+        )
+
        agg_turns: Dict[Tuple[str, int, str], int] = {}
-        for _ in range(self.request_warm_up_for_reshard):
-            turns = self.request_router.find_turning_points(nodes_list, num_layers)
+        for _ in range(warmup_count):
+            turns = DynamicProgrammingRouting.find_turning_points(nodes_list, num_layers)
            for t in turns:
                agg_turns[t] = agg_turns.get(t, 0) + 1
+
        # Apply truncation for consistently observed turning points
+        # Note: Must use layer_allocator.allocate/deallocate to properly update
+        # internal state (node_allocation dict and layer_to_load)
        for node_id, layer_idx, kind in agg_turns:
            node = next((n for n in self.nodes if n.node_id == node_id), None)
            if node is None or node.start_layer is None or node.end_layer is None:
                continue
            start, end = node.start_layer, node.end_layer
            if kind == "tail":
                if layer_idx < end:
-                    node.set_layer_allocation(start, layer_idx)
+                    self.layer_allocator.reallocate(node, start, layer_idx)
            elif kind == "head":
                if layer_idx > start:
-                    node.set_layer_allocation(layer_idx, end)
 
    def update_node_info(
        self,
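
With the new keyword-only flags, the same entry point now covers both first-time bootstrap and global rebalance. A hedged sketch of the two call patterns, assuming sched is a Scheduler whose nodes have already joined:

    # Initial bootstrap: enforces min_nodes_bootstrapping and, if configured,
    # runs the warm-up/truncate pass before declaring the pipeline ready.
    ok = sched.bootstrap()

    # Global rebalance: clear every existing allocation, re-run global allocation,
    # and skip the warm-up (this is how leave() invokes it in the next hunk).
    ok = sched.bootstrap(clear_existing=True, skip_warmup=True)
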
@@ -316,19 +360,20 @@ def leave(self, node_id: str) -> None:
                    f"Mixed assignment detected ({manual_count} manual, {total_count - manual_count} automatic); skipping rebalance"
                )
            else:
-                # All nodes are automatic, proceed with rebalance
-                self._bootstrapped = False
-                self._bootstrapped_event.clear()
-                for n in self.nodes:
-                    if n.start_layer is not None and n.end_layer is not None:
-                        self.layer_allocator.deallocate(n)
-                success = self.layer_allocator.global_allocation()
-                if not success:
-                    logger.warning("Global rebalance failed to produce a full pipeline")
+                # All nodes are automatic, try adjustment first, then rebalance if needed
+                if not self.layer_allocator.has_full_pipeline():
+                    logger.debug(
+                        "No full pipeline after node leave, attempting warmup and truncate"
+                    )
+                    self._run_warmup_and_truncate(override_warmup_count=1)
+                    if not self.layer_allocator.has_full_pipeline():
+                        self.bootstrap(clear_existing=True, skip_warmup=True)
+                    else:
+                        logger.debug(
+                            "Pipeline recovered through warmup and truncate, skipping global rebalance"
+                        )
                else:
-                    logger.debug("Global rebalance completed successfully")
-                    self._bootstrapped = True
-                    self._bootstrapped_event.set()
+                    self.bootstrap(clear_existing=True, skip_warmup=True)
 
        with self._node_count_cv:
            self._node_count_cv.notify_all()
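
The rewritten leave branch is effectively a two-step recovery policy: try a cheap warm-up/truncate adjustment first and fall back to a full rebalance only if the pipeline is still broken. An illustrative restatement of that policy as a standalone helper (hypothetical, not part of the commit):

    def recover_after_leave(sched) -> None:
        """Mirror of the policy added to Scheduler.leave (illustrative only)."""
        if not sched.layer_allocator.has_full_pipeline():
            # Cheap path: a single warm-up pass may close the gap by truncating shards.
            sched._run_warmup_and_truncate(override_warmup_count=1)
            if not sched.layer_allocator.has_full_pipeline():
                # Expensive path: clear all allocations and redo global allocation.
                sched.bootstrap(clear_existing=True, skip_warmup=True)
        else:
            # Pipeline survived the departure; rebalance to redistribute layers.
            sched.bootstrap(clear_existing=True, skip_warmup=True)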

tests/scheduler_tests/test_scheduler.py

Lines changed: 106 additions & 0 deletions
@@ -133,3 +133,109 @@ def test_scheduler_single_node_leave_then_rejoin_reassigns_layers():
     assert (
         n1_rejoin.start_layer is not None and n1_rejoin.end_layer is not None
     ), "After re-join, single node should be assigned a full layer range"
+
+
+def test_scheduler_three_nodes_sequential_join_leave_rejoin():
+    """Test scheduler with 28-layer model, 3 nodes each capable of 22 layers.
+
+    Scenario:
+    - 28-layer model
+    - n1, n2, n3 all can host 22 layers
+    - min_nodes_bootstrapping=2
+    - n1, n2, n3 join sequentially
+    - n1 leaves and rejoins
+    - n2 leaves and rejoins
+    - n3 leaves and rejoins
+    """
+    model = build_model_info(28)
+
+    # Create nodes that can each host 22 layers
+    # Calculation: 100GB can host 16 layers, so 22 layers need ~137.5GB
+    # Using 138GB to ensure capacity for 22 layers with a small margin
+    n1 = build_node("n1", model, tflops=312.0, mem_gb=138.0, x=0, y=0)
+    n2 = build_node("n2", model, tflops=312.0, mem_gb=138.0, x=1, y=0)
+    n3 = build_node("n3", model, tflops=312.0, mem_gb=138.0, x=2, y=0)
+
+    # Verify nodes can host 22 layers
+    assert n1.get_decoder_layer_capacity() >= 22, "n1 should be able to host 22 layers"
+    assert n2.get_decoder_layer_capacity() >= 22, "n2 should be able to host 22 layers"
+    assert n3.get_decoder_layer_capacity() >= 22, "n3 should be able to host 22 layers"
+
+    # Initialize scheduler with min_nodes_bootstrapping=2, no nodes initially
+    sched = Scheduler(model, [], strategy="dp", min_nodes_bootstrapping=2)
+
+    # Step 1: n1 joins (not enough nodes yet)
+    sched.enqueue_join(n1)
+    sched._process_joins()  # type: ignore[attr-defined]
+    assert len(sched.nodes) == 1
+    assert not sched.layer_allocator.has_full_pipeline()
+
+    # Step 2: n2 joins (now we have 2 nodes, should bootstrap)
+    sched.enqueue_join(n2)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    ok = sched.bootstrap()
+    assert ok, "Bootstrap should succeed with 2 nodes"
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Step 3: n3 joins (dynamic join after bootstrap)
+    sched.enqueue_join(n3)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n3.start_layer is not None and n3.end_layer is not None
+    assert len(sched.nodes) == 3
+
+    # Step 4: n1 leaves and rejoins
+    n1_id = n1.node_id
+    sched.leave(n1_id)
+    assert n1 not in sched.nodes
+    assert len(sched.nodes) == 2
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Rejoin n1
+    n1_rejoin = build_node("n1", model, tflops=312.0, mem_gb=138.0, x=0, y=0)
+    sched.enqueue_join(n1_rejoin)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n1_rejoin.start_layer is not None and n1_rejoin.end_layer is not None
+    assert len(sched.nodes) == 3
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Step 5: n2 leaves and rejoins
+    n2_id = n2.node_id
+    sched.leave(n2_id)
+    assert n2 not in sched.nodes
+    assert len(sched.nodes) == 2
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Rejoin n2
+    n2_rejoin = build_node("n2", model, tflops=312.0, mem_gb=138.0, x=1, y=0)
+    sched.enqueue_join(n2_rejoin)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n2_rejoin.start_layer is not None and n2_rejoin.end_layer is not None
+    assert len(sched.nodes) == 3
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Step 6: n3 leaves and rejoins
+    n3_id = n3.node_id
+    sched.leave(n3_id)
+    assert n3 not in sched.nodes
+    assert len(sched.nodes) == 2
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Rejoin n3
+    n3_rejoin = build_node("n3", model, tflops=312.0, mem_gb=138.0, x=2, y=0)
+    sched.enqueue_join(n3_rejoin)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n3_rejoin.start_layer is not None and n3_rejoin.end_layer is not None
+    assert len(sched.nodes) == 3
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Final verification: all nodes should have layer assignments
+    allocations = sched.list_node_allocations()
+    assert len(allocations) == 3, "All 3 nodes should have layer assignments"
+    # Verify full pipeline coverage
+    total_covered = sum(e - s for _, s, e in allocations)
+    assert total_covered >= model.num_layers, "All layers should be covered"
