11"""
2- Monkey patch for vLLM weight loading to skip lm_head weights on non-last pipeline stages.
2+ Monkey patch for vLLM weight loading to skip non-existent weights on different pipeline stages.
33This is similar to the approach used in sglang monkey patches.
44"""
55
99logger = logging .getLogger (__name__ )
1010
1111_vllm_patch_applied = False
12+ _is_first_stage = False # Default to False
1213_is_last_stage = True # Default to True for safety
1314
1415
15- def set_vllm_pipeline_stage (is_last_stage : bool ):
16- """Set whether this is the last pipeline stage."""
17- global _is_last_stage
16+ def set_vllm_pipeline_stage (is_first_stage : bool , is_last_stage : bool ):
17+ """Set whether this is the first and/or last pipeline stage."""
18+ global _is_first_stage , _is_last_stage
19+ _is_first_stage = is_first_stage
1820 _is_last_stage = is_last_stage
19- logger .debug (f"Set vLLM pipeline stage: is_last_stage= { is_last_stage } " )
21+ logger .debug (f"Set vLLM pipeline stage: is_first_stage= { _is_first_stage } , is_last_stage= { _is_last_stage } " )
2022
2123
2224def apply_vllm_weight_loader_patch ():
2325 """
24- Apply monkey patch to vLLM's default loader to skip lm_head initialization check
25- when not on the last pipeline stage .
26+ Apply monkey patch to vLLM's default loader to skip initialization checks
27+ for weights that are not expected on certain pipeline stages .
2628
27- This patch intercepts ValueError exceptions during weight loading and checks if they
28- are related to lm_head.weight not being initialized. If this occurs on a non-last
29- pipeline stage, the error is suppressed as expected behavior. Otherwise, the error
30- is re-raised.
29+ - Skips `embed_tokens` check on non-first stages.
30+ - Skips `lm_head` check on non-last stages.
3131 """
3232 global _vllm_patch_applied
3333
@@ -41,28 +41,37 @@ def apply_vllm_weight_loader_patch():
         original_load_weights = default_loader.DefaultModelLoader.load_weights
 
         def patched_load_weights(self, model: Any, model_config: Any):
-            """Patched load_weights that handles lm_head for pipeline parallelism."""
-            global _is_last_stage
+            """Patched load_weights that handles embed_tokens and lm_head for pipeline parallelism."""
+            global _is_first_stage, _is_last_stage
 
             try:
                 # Call original load_weights
                 original_load_weights(self, model, model_config)
             except ValueError as e:
                 error_msg = str(e)
-                # Check if this is the lm_head initialization error
-                if "lm_head.weight" in error_msg and "not initialized from checkpoint" in error_msg:
+                uninitialized_weights = "not initialized from checkpoint" in error_msg
+
+                # Case 1: embed_tokens.weight not found
+                if "model.embed_tokens.weight" in error_msg and uninitialized_weights:
+                    if not _is_first_stage:
+                        # Expected behavior for non-first pipeline stages
+                        logger.info("Skipping embed_tokens.weight initialization check on non-first pipeline stage")
+                    else:
+                        # This is the first stage, embed_tokens should be initialized
+                        logger.error("embed_tokens.weight not initialized on first pipeline stage, this is an error")
+                        raise
+
+                # Case 2: lm_head.weight not found
+                elif "lm_head.weight" in error_msg and uninitialized_weights:
                     if not _is_last_stage:
                         # Expected behavior for non-last pipeline stages
-                        logger.info(
-                            "Skipping lm_head.weight initialization check on non-last pipeline stage"
-                        )
-                        return
+                        logger.info("Skipping lm_head.weight initialization check on non-last pipeline stage")
                     else:
                         # This is the last stage, lm_head should be initialized
-                        logger.error(
-                            "lm_head.weight not initialized on last pipeline stage, this is an error"
-                        )
+                        logger.error("lm_head.weight not initialized on last pipeline stage, this is an error")
                         raise
+
+                # Case 3: Other errors
                 else:
                     # Different error, re-raise
                     raise
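
For context, a minimal usage sketch of the new API: both calls have to happen before vLLM loads the model weights, and the stage flags are derived from the pipeline-parallel rank. The module path and the rank values below are illustrative assumptions, not part of this diff.

# Usage sketch; import path and rank values are hypothetical examples.
from vllm_weight_loader_patch import (
    apply_vllm_weight_loader_patch,
    set_vllm_pipeline_stage,
)

pp_rank, pp_size = 1, 4  # e.g. the second of four pipeline stages
set_vllm_pipeline_stage(
    is_first_stage=(pp_rank == 0),
    is_last_stage=(pp_rank == pp_size - 1),
)
apply_vllm_weight_loader_patch()  # patch must be in place before weight loading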