Skip to content

Commit efc5a0d

Browse files
committed
pre-commit
1 parent 3c67c0f commit efc5a0d

File tree

3 files changed

+18
-19
lines changed

3 files changed

+18
-19
lines changed

src/parallax/vllm/model_runner.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheGroupSpec, KVCacheTensor
2727
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
2828

29-
from parallax.utils.tokenizer_utils import load_tokenizer
30-
from parallax_utils.logging_config import get_logger
3129
from parallax.sglang.monkey_patch_utils.weight_loader_filter import (
3230
apply_weight_loader_filter_patch,
3331
set_layer_range_for_filtering,
3432
)
33+
from parallax.utils.tokenizer_utils import load_tokenizer
3534
from parallax.vllm.monkey_patch import apply_parallax_vllm_monkey_patch
35+
from parallax_utils.logging_config import get_logger
3636

3737
logger = get_logger(__name__)
3838

@@ -200,7 +200,9 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC
200200
model_dtype = self.vllm_config.model_config.dtype
201201
if isinstance(model_dtype, str):
202202
try:
203-
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE # type: ignore
203+
from vllm.utils.torch_utils import (
204+
STR_DTYPE_TO_TORCH_DTYPE, # type: ignore
205+
)
204206
except Exception:
205207
# Older/newer vLLM versions may not expose torch_utils.
206208
# Fall back silently and default to float16.
@@ -349,16 +351,14 @@ def initialize_vllm_model_runner(
349351
num_hidden_layers = getattr(config, "num_hidden_layers", 28)
350352
is_first_peer = start_layer == 0
351353
is_last_peer = end_layer == num_hidden_layers
352-
354+
353355
# Apply Parallax vLLM monkey patches for pipeline parallelism
354356
try:
355357
apply_parallax_vllm_monkey_patch(is_last_stage=is_last_peer)
356-
logger.debug(
357-
f"Applied Parallax vLLM monkey patches: is_last_stage={is_last_peer}"
358-
)
358+
logger.debug(f"Applied Parallax vLLM monkey patches: is_last_stage={is_last_peer}")
359359
except Exception as e:
360360
logger.warning("Failed to apply Parallax vLLM monkey patches: %s", e)
361-
361+
362362
# Apply layer-range-based weight file filtering before any model load.
363363
# Reuse the generic monkey patch used by sglang implementation to reduce
364364
# local weight file reads when loading a partial layer shard.

src/parallax/vllm/monkey_patch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
def apply_parallax_vllm_monkey_patch(is_last_stage: bool = True):
1818
"""
1919
Apply all Parallax monkey patches for vLLM.
20-
20+
2121
Args:
2222
is_last_stage: Whether this is the last pipeline stage. This affects
2323
whether lm_head weights are expected to be loaded.
2424
"""
2525
set_vllm_pipeline_stage(is_last_stage)
2626
apply_vllm_weight_loader_patch()
27-

src/parallax/vllm/monkey_patch_utils/weight_loader.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Monkey patch for vLLM weight loading to skip lm_head weights on non-last pipeline stages.
33
This is similar to the approach used in sglang monkey patches.
44
"""
5+
56
import logging
67
from typing import Any
78

@@ -22,27 +23,27 @@ def apply_vllm_weight_loader_patch():
2223
"""
2324
Apply monkey patch to vLLM's default loader to skip lm_head initialization check
2425
when not on the last pipeline stage.
25-
26+
2627
This patch intercepts ValueError exceptions during weight loading and checks if they
2728
are related to lm_head.weight not being initialized. If this occurs on a non-last
2829
pipeline stage, the error is suppressed as expected behavior. Otherwise, the error
2930
is re-raised.
3031
"""
3132
global _vllm_patch_applied
32-
33+
3334
if _vllm_patch_applied:
3435
logger.debug("vLLM weight loader patch already applied, skipping")
3536
return
36-
37+
3738
try:
3839
from vllm.model_executor.model_loader import default_loader
39-
40+
4041
original_load_weights = default_loader.DefaultModelLoader.load_weights
41-
42+
4243
def patched_load_weights(self, model: Any, model_config: Any):
4344
"""Patched load_weights that handles lm_head for pipeline parallelism."""
4445
global _is_last_stage
45-
46+
4647
try:
4748
# Call original load_weights
4849
original_load_weights(self, model, model_config)
@@ -65,15 +66,14 @@ def patched_load_weights(self, model: Any, model_config: Any):
6566
else:
6667
# Different error, re-raise
6768
raise
68-
69+
6970
# Apply the patch
7071
default_loader.DefaultModelLoader.load_weights = patched_load_weights
7172
_vllm_patch_applied = True
7273
logger.info("Successfully applied vLLM weight loader patch for pipeline parallelism")
73-
74+
7475
except ImportError as e:
7576
logger.warning(f"Could not apply vLLM weight loader patch: {e}")
7677
except Exception as e:
7778
logger.error(f"Error applying vLLM weight loader patch: {e}")
7879
raise
79-

0 commit comments

Comments
 (0)