Commit 41ebfec

Author: yuhao-zh
Commit message: update load weights
1 parent: c56b7bd

File tree: 3 files changed (+39, -28 lines)


src/parallax/vllm/model_runner.py

Lines changed: 5 additions & 4 deletions
@@ -315,6 +315,7 @@ def custom_get_pp_indices(num_layers: int, rank: int, world_size: int):
                 f"Successfully loaded {self.num_shard_layers} layers "
                 f"[{self.start_layer}:{self.end_layer}]"
             )
+
     finally:
         vllm.distributed.utils.get_pp_indices = original_get_pp_indices

@@ -347,15 +348,15 @@ def initialize_vllm_model_runner(
     config = load_config(model_path)
     tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
     dtype = config.get("torch_dtype", "bfloat16")
-
-    num_hidden_layers = getattr(config, "num_hidden_layers", 28)
+
+    num_hidden_layers = config.get("num_hidden_layers")
     is_first_peer = start_layer == 0
     is_last_peer = end_layer == num_hidden_layers

     # Apply Parallax vLLM monkey patches for pipeline parallelism
     try:
-        apply_parallax_vllm_monkey_patch(is_last_stage=is_last_peer)
-        logger.debug(f"Applied Parallax vLLM monkey patches: is_last_stage={is_last_peer}")
+        apply_parallax_vllm_monkey_patch(is_first_stage=is_first_peer, is_last_stage=is_last_peer)
+        logger.debug(f"Applied Parallax vLLM monkey patches: is_first_stage={is_first_peer}, is_last_stage={is_last_peer}")
     except Exception as e:
         logger.warning("Failed to apply Parallax vLLM monkey patches: %s", e)
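Note on the num_hidden_layers fix: load_config evidently returns a plain dict here (the surrounding lines already call config.get("eos_token_id", ...) and config.get("torch_dtype", ...)), so getattr never found the key and always returned the fallback of 28. A minimal sketch of the failure mode, with a hypothetical 36-layer config:

    config = {"num_hidden_layers": 36}        # load_config returns a dict, not an object
    getattr(config, "num_hidden_layers", 28)  # -> 28: dict keys are items, not attributes
    config.get("num_hidden_layers")           # -> 36: the fixed lookup

With the old code, is_last_peer = end_layer == num_hidden_layers could never be True for any model that does not have exactly 28 layers.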

src/parallax/vllm/monkey_patch.py

Lines changed: 3 additions & 2 deletions
@@ -14,13 +14,14 @@
 ## Here are patch functions for vLLM
 ## Hopefully, when vLLM supports pipeline parallelism natively in the way we need,
 ## we can remove these patches
-def apply_parallax_vllm_monkey_patch(is_last_stage: bool = True):
+def apply_parallax_vllm_monkey_patch(is_first_stage: bool, is_last_stage: bool):
     """
     Apply all Parallax monkey patches for vLLM.

     Args:
+        is_first_stage: Whether this is the first pipeline stage.
         is_last_stage: Whether this is the last pipeline stage. This affects
             whether lm_head weights are expected to be loaded.
     """
-    set_vllm_pipeline_stage(is_last_stage)
+    set_vllm_pipeline_stage(is_first_stage, is_last_stage)
     apply_vllm_weight_loader_patch()
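For context, a minimal sketch of how a caller now drives the two-flag entry point from a shard's layer range (the values are illustrative, and the import path assumes the src/ layout maps to the parallax package):

    from parallax.vllm.monkey_patch import apply_parallax_vllm_monkey_patch

    # Hypothetical middle shard of a 36-layer model
    start_layer, end_layer, num_hidden_layers = 12, 24, 36

    is_first_peer = start_layer == 0               # False: no embed_tokens weights expected
    is_last_peer = end_layer == num_hidden_layers  # False: no lm_head weights expected

    apply_parallax_vllm_monkey_patch(is_first_stage=is_first_peer, is_last_stage=is_last_peer)

Dropping the old is_last_stage=True default also makes call sites explicit: omitting either flag is now a TypeError rather than silently treating the stage as last.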

src/parallax/vllm/monkey_patch_utils/weight_loader.py

Lines changed: 31 additions & 22 deletions
@@ -1,5 +1,5 @@
 """
-Monkey patch for vLLM weight loading to skip lm_head weights on non-last pipeline stages.
+Monkey patch for vLLM weight loading to skip non-existent weights on different pipeline stages.
 This is similar to the approach used in sglang monkey patches.
 """

@@ -9,25 +9,25 @@
 logger = logging.getLogger(__name__)

 _vllm_patch_applied = False
+_is_first_stage = False  # Default to False
 _is_last_stage = True  # Default to True for safety


-def set_vllm_pipeline_stage(is_last_stage: bool):
-    """Set whether this is the last pipeline stage."""
-    global _is_last_stage
+def set_vllm_pipeline_stage(is_first_stage: bool, is_last_stage: bool):
+    """Set whether this is the first and/or last pipeline stage."""
+    global _is_first_stage, _is_last_stage
+    _is_first_stage = is_first_stage
     _is_last_stage = is_last_stage
-    logger.debug(f"Set vLLM pipeline stage: is_last_stage={is_last_stage}")
+    logger.debug(f"Set vLLM pipeline stage: is_first_stage={_is_first_stage}, is_last_stage={_is_last_stage}")


 def apply_vllm_weight_loader_patch():
     """
-    Apply monkey patch to vLLM's default loader to skip lm_head initialization check
-    when not on the last pipeline stage.
+    Apply monkey patch to vLLM's default loader to skip initialization checks
+    for weights that are not expected on certain pipeline stages.

-    This patch intercepts ValueError exceptions during weight loading and checks if they
-    are related to lm_head.weight not being initialized. If this occurs on a non-last
-    pipeline stage, the error is suppressed as expected behavior. Otherwise, the error
-    is re-raised.
+    - Skips `embed_tokens` check on non-first stages.
+    - Skips `lm_head` check on non-last stages.
     """
     global _vllm_patch_applied

@@ -41,28 +41,37 @@ def apply_vllm_weight_loader_patch():
     original_load_weights = default_loader.DefaultModelLoader.load_weights

     def patched_load_weights(self, model: Any, model_config: Any):
-        """Patched load_weights that handles lm_head for pipeline parallelism."""
-        global _is_last_stage
+        """Patched load_weights that handles embed_tokens and lm_head for pipeline parallelism."""
+        global _is_first_stage, _is_last_stage

         try:
             # Call original load_weights
             original_load_weights(self, model, model_config)
         except ValueError as e:
             error_msg = str(e)
-            # Check if this is the lm_head initialization error
-            if "lm_head.weight" in error_msg and "not initialized from checkpoint" in error_msg:
+            uninitialized_weights = "not initialized from checkpoint" in error_msg
+
+            # Case 1: embed_tokens.weight not found
+            if "model.embed_tokens.weight" in error_msg and uninitialized_weights:
+                if not _is_first_stage:
+                    # Expected behavior for non-first pipeline stages
+                    logger.info("Skipping embed_tokens.weight initialization check on non-first pipeline stage")
+                else:
+                    # This is the first stage, embed_tokens should be initialized
+                    logger.error("embed_tokens.weight not initialized on first pipeline stage, this is an error")
+                    raise
+
+            # Case 2: lm_head.weight not found
+            elif "lm_head.weight" in error_msg and uninitialized_weights:
                 if not _is_last_stage:
                     # Expected behavior for non-last pipeline stages
-                    logger.info(
-                        "Skipping lm_head.weight initialization check on non-last pipeline stage"
-                    )
-                    return
+                    logger.info("Skipping lm_head.weight initialization check on non-last pipeline stage")
                 else:
                     # This is the last stage, lm_head should be initialized
-                    logger.error(
-                        "lm_head.weight not initialized on last pipeline stage, this is an error"
-                    )
+                    logger.error("lm_head.weight not initialized on last pipeline stage, this is an error")
                     raise
+
+            # Case 3: Other errors
             else:
                 # Different error, re-raise
                 raise
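Two details worth noting. First, the old explicit return (old line 59) was redundant and is safely dropped: an except block that completes without re-raising already suppresses the exception. Second, the patch is a thin intercept-classify-suppress wrapper around the original loader. A self-contained toy illustrating the same pattern for a middle shard (fake_load_weights is hypothetical, not a vLLM API):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    _is_first_stage, _is_last_stage = False, False  # middle shard: neither first nor last

    def fake_load_weights():
        # Stand-in for DefaultModelLoader.load_weights failing on a missing embedding
        raise ValueError("model.embed_tokens.weight is not initialized from checkpoint")

    try:
        fake_load_weights()
    except ValueError as e:
        msg = str(e)
        if "model.embed_tokens.weight" in msg and "not initialized from checkpoint" in msg:
            if not _is_first_stage:
                # Expected on a non-first stage; falling off the handler suppresses the error
                logger.info("Skipping embed_tokens.weight check on non-first pipeline stage")
            else:
                raise
        else:
            raise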
