@@ -116,72 +116,116 @@ def __init__(
         pass
 
         # Orchestration helpers
-    def bootstrap(self) -> bool:
-        """Bootstrapping: first-time layer allocation and optional warm-up.
+    def bootstrap(self, *, clear_existing: bool = False, skip_warmup: bool = False) -> bool:
+        """Bootstrap or globally rebalance the layer allocation.
+        This method can be used for both initial bootstrapping and global rebalancing.
+        When clear_existing=True, it first deallocates all existing allocations before
+        performing global allocation (rebalancing behavior). When clear_existing=False,
+        it performs allocation on top of existing state (initial bootstrapping behavior).
 
-        Returns True if a full pipeline was established; False otherwise.
+        Args:
+            clear_existing: If True, deallocate all existing allocations before reallocating.
+                This is used for global rebalancing. Default is False.
+            skip_warmup: If True, skip the warm-up and truncate step. Default is False.
+
+        Returns:
+            True if a full pipeline was established; False otherwise.
         """
-        if len(self.nodes) < self.min_nodes_bootstrapping:
+        # Check node count only for initial bootstrapping (not rebalancing)
+        if not clear_existing and len(self.nodes) < self.min_nodes_bootstrapping:
             logger.debug(
                 f"Bootstrapping deferred: have {len(self.nodes)} nodes; need >= {self.min_nodes_bootstrapping}"
             )
             return False
-        logger.debug("Bootstrapping layer allocator")
+
+        # Clear existing allocations if this is a rebalance
+        if clear_existing:
+            logger.debug("Performing global rebalance (clearing existing allocations)")
+            self._bootstrapped = False
+            self._bootstrapped_event.clear()
+            for n in self.nodes:
+                if n.start_layer is not None and n.end_layer is not None:
+                    self.layer_allocator.deallocate(n)
+        else:
+            logger.debug("Bootstrapping layer allocator")
+
+        # Perform global allocation
         success = self.layer_allocator.global_allocation()
         if not success:
-            logger.warning("Bootstrapping failed to produce a full pipeline")
+            logger.warning("Global allocation failed to produce a full pipeline")
             return False
+
         assignments = self.list_node_allocations()
         logger.debug(f"Layer allocator assignments: {assignments}")
+
         # Optional warm-up to find turning points and truncate node ranges
-        if self.request_warm_up_for_reshard > 0:
+        # Rebalance callers typically pass skip_warmup=True to bypass this step
+        if not skip_warmup and self.request_warm_up_for_reshard > 0:
             self._run_warmup_and_truncate()
             assignments = self.list_node_allocations()
             logger.debug(f"Layer allocator assignments after turn-point warm-up: {assignments}")
 
         if not self.layer_allocator.has_full_pipeline():
             logger.warning("Bootstrapping failed to produce a full pipeline")
             return False
+
         self._bootstrapped = True
         self._bootstrapped_event.set()
-        logger.debug("Bootstrapping completed successfully; full pipeline established")
+        action = "rebalance" if clear_existing else "bootstrapping"
+        logger.debug(f"{action.capitalize()} completed successfully; full pipeline established")
         return True
 
     def list_node_allocations(self) -> List[Tuple[str, int, int]]:
         """List the allocations of all nodes."""
         return self.layer_allocator.list_node_allocations()
 
     # Warm-up and re-shard
-    def _run_warmup_and_truncate(self) -> None:
+    def _run_warmup_and_truncate(self, override_warmup_count: int = 0) -> None:
         """Run a brief warm-up to detect truncation points and shrink shards.
 
         Uses layer-level DP turning points (node_id, layer_idx, kind):
         - kind == "tail": drop [layer_idx, end) on that node
         - kind == "head": drop [start, layer_idx) on that node
+
+        Note: Always uses DynamicProgrammingRouting for finding turning points,
+        regardless of the current request_router type, since turning-point
+        detection requires layer-level DP analysis.
+
+        Args:
+            override_warmup_count: If > 0, use this value instead of request_warm_up_for_reshard.
+                Default is 0, which means use request_warm_up_for_reshard.
         """
         nodes_list = list(self.nodes)
         if not nodes_list:
             return
         num_layers = self.model_info.num_layers
+
         # The number of warm-up requests can be used to repeat detection, but a
         # single pass is sufficient with our DP model; we repeat to smooth noise.
+        warmup_count = (
+            override_warmup_count if override_warmup_count > 0 else self.request_warm_up_for_reshard
+        )
+
         agg_turns: Dict[Tuple[str, int, str], int] = {}
-        for _ in range(self.request_warm_up_for_reshard):
-            turns = self.request_router.find_turning_points(nodes_list, num_layers)
+        for _ in range(warmup_count):
+            turns = DynamicProgrammingRouting.find_turning_points(nodes_list, num_layers)
             for t in turns:
                 agg_turns[t] = agg_turns.get(t, 0) + 1
+
         # Apply truncation for consistently observed turning points
+        # Note: Must go through the layer_allocator (not node.set_layer_allocation) to
+        # properly update its internal state (node_allocation dict and layer_to_load)
         for node_id, layer_idx, kind in agg_turns:
             node = next((n for n in self.nodes if n.node_id == node_id), None)
             if node is None or node.start_layer is None or node.end_layer is None:
                 continue
             start, end = node.start_layer, node.end_layer
             if kind == "tail":
                 if layer_idx < end:
-                    node.set_layer_allocation(start, layer_idx)
+                    self.layer_allocator.reallocate(node, start, layer_idx)
             elif kind == "head":
                 if layer_idx > start:
-                    node.set_layer_allocation(layer_idx, end)
 
     def update_node_info(
         self,
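The `tail`/`head` truncation rule documented in `_run_warmup_and_truncate` can be illustrated in isolation. The sketch below is a minimal, self-contained rendering of that rule under hypothetical names (`apply_turning_points`, plain dict state); it is not code from this diff, but it shows what a turning point does to a node's `[start, end)` layer range:

```python
from typing import Dict, List, Tuple

def apply_turning_points(
    ranges: Dict[str, Tuple[int, int]],
    turns: List[Tuple[str, int, str]],
) -> Dict[str, Tuple[int, int]]:
    """Return new [start, end) ranges after truncating at each turning point."""
    new_ranges = dict(ranges)
    for node_id, layer_idx, kind in turns:
        if node_id not in new_ranges:
            continue
        start, end = new_ranges[node_id]
        if kind == "tail" and layer_idx < end:
            new_ranges[node_id] = (start, layer_idx)  # drop [layer_idx, end)
        elif kind == "head" and layer_idx > start:
            new_ranges[node_id] = (layer_idx, end)    # drop [start, layer_idx)
    return new_ranges

# Two nodes overlap on layers 14-15; a "tail" turning point at layer 14 on n0
# trims the overlap so the pipeline stays contiguous without duplicated layers.
print(apply_turning_points({"n0": (0, 16), "n1": (14, 32)}, [("n0", 14, "tail")]))
# -> {'n0': (0, 14), 'n1': (14, 32)}
```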
@@ -316,19 +360,20 @@ def leave(self, node_id: str) -> None:
316360 f"Mixed assignment detected ({ manual_count } manual, { total_count - manual_count } automatic); skipping rebalance"
317361 )
318362 else :
319- # All nodes are automatic, proceed with rebalance
320- self ._bootstrapped = False
321- self ._bootstrapped_event .clear ()
322- for n in self .nodes :
323- if n .start_layer is not None and n .end_layer is not None :
324- self .layer_allocator .deallocate (n )
325- success = self .layer_allocator .global_allocation ()
326- if not success :
327- logger .warning ("Global rebalance failed to produce a full pipeline" )
363+ # All nodes are automatic, try adjustment first, then rebalance if needed
364+ if not self .layer_allocator .has_full_pipeline ():
365+ logger .debug (
366+ "No full pipeline after node leave, attempting warmup and truncate"
367+ )
368+ self ._run_warmup_and_truncate (override_warmup_count = 1 )
369+ if not self .layer_allocator .has_full_pipeline ():
370+ self .bootstrap (clear_existing = True , skip_warmup = True )
371+ else :
372+ logger .debug (
373+ "Pipeline recovered through warmup and truncate, skipping global rebalance"
374+ )
328375 else :
329- logger .debug ("Global rebalance completed successfully" )
330- self ._bootstrapped = True
331- self ._bootstrapped_event .set ()
376+ self .bootstrap (clear_existing = True , skip_warmup = True )
332377
333378 with self ._node_count_cv :
334379 self ._node_count_cv .notify_all ()
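The reworked branch in `leave()` amounts to a small decision procedure: attempt a cheap warm-up/truncate repair before paying for a full clear-and-reallocate rebalance. A minimal runnable sketch with stub callables standing in for the allocator and bootstrap calls (hypothetical names, not code from this diff):

```python
from typing import Callable

def recover_after_leave(
    pipeline_is_full: Callable[[], bool],
    warmup_and_truncate: Callable[[], None],
    global_rebalance: Callable[[], None],
) -> None:
    """Mirror the automatic-assignment branch of leave() shown above."""
    if not pipeline_is_full():
        warmup_and_truncate()        # cheap local shard adjustment first
        if not pipeline_is_full():
            global_rebalance()       # still broken: clear everything and reallocate
        # else: pipeline recovered, skip the expensive global rebalance
    else:
        global_rebalance()           # pipeline intact but membership changed

# Toy run: the pipeline reports broken at first and the warm-up/truncate pass
# repairs it, so no global rebalance is triggered (nothing is printed).
state = {"full": False}
recover_after_leave(
    pipeline_is_full=lambda: state["full"],
    warmup_and_truncate=lambda: state.update(full=True),
    global_rebalance=lambda: print("global rebalance"),
)
```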