From 32ed415cb80ab7fd098addfea6dee832d79b9a16 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 20 Oct 2025 14:34:54 +0800 Subject: [PATCH 01/36] init vllm --- src/backend/server/server_args.py | 8 + src/parallax/server/executor.py | 71 ++++--- src/parallax/server/server_args.py | 8 + src/parallax/vllm/__init__.py | 17 ++ src/parallax/vllm/model_runner.py | 295 +++++++++++++++++++++++++++++ 5 files changed, 376 insertions(+), 23 deletions(-) create mode 100644 src/parallax/vllm/__init__.py create mode 100644 src/parallax/vllm/model_runner.py diff --git a/src/backend/server/server_args.py b/src/backend/server/server_args.py index 9624eb11..44e1dcdf 100644 --- a/src/backend/server/server_args.py +++ b/src/backend/server/server_args.py @@ -45,6 +45,14 @@ def parse_args() -> argparse.Namespace: "--is-local-network", type=bool, default=True, help="Whether to use local network" ) + parser.add_argument( + "--gpu_backend", + type=str, + default="sglang", + choices=["sglang", "vllm"], + help="GPU backend to use", + ) + args = parser.parse_args() return args diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index dd0b819e..ba45aefe 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -93,39 +93,63 @@ def __init__( # IPC Communication Configs executor_input_ipc_addr: Optional[str] = None, executor_output_ipc_addr: Optional[str] = None, - # GPU/SGLang Specialized Configs + # GPU Backend Configs + gpu_backend: Optional[str] = "sglang", # "sglang" or "vllm" attention_backend: Optional[str] = "torch_native", moe_runner_backend: Optional[str] = "auto", ): # Backend self.device = get_current_device() - logger.debug(f"Executor initializing on device: {self.device}") + self.gpu_backend = gpu_backend + logger.debug(f"Executor initializing on device: {self.device}, gpu_backend: {gpu_backend}") # Sharded Model if self.device == "cuda": - from sglang.srt.managers.schedule_batch import ScheduleBatch + if gpu_backend == "vllm": + from parallax.vllm.model_runner import initialize_vllm_model_runner - from parallax.sglang.model_runner import initialize_sgl_model_runner + logger.debug( + f"Initializing vLLM model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" + ) + self.model_runner, self.config, self.tokenizer = initialize_vllm_model_runner( + model_repo, + start_layer, + end_layer, + kv_cache_memory_fraction, + kv_block_size, + max_num_seqs=max_batch_size, + max_model_len=max_sequence_length, + ) + logger.debug( + f"vLLM model runner initialized. num_layers={self.config.get('num_hidden_layers')}" + ) + # vLLM manages its own KV cache and batching + self.running_batch = None + self.cur_batch = None + else: # sglang backend + from sglang.srt.managers.schedule_batch import ScheduleBatch - logger.debug( - f"Initializing CUDA model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" - ) - self.model_runner, self.config, self.tokenizer = initialize_sgl_model_runner( - model_repo, - start_layer, - end_layer, - kv_cache_memory_fraction, - attention_backend, - kv_block_size, - moe_runner_backend, - ) - logger.debug( - f"CUDA model runner initialized. 
num_layers={self.config.get('num_hidden_layers')}" - ) - # SGL KV Cache Manager is already initialized in ScheduleBatch - # TODO: Replace ScheduleBatch to Parallax inflight batch - self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False) - self.cur_batch = None + from parallax.sglang.model_runner import initialize_sgl_model_runner + + logger.debug( + f"Initializing SGLang model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" + ) + self.model_runner, self.config, self.tokenizer = initialize_sgl_model_runner( + model_repo, + start_layer, + end_layer, + kv_cache_memory_fraction, + attention_backend, + kv_block_size, + moe_runner_backend, + ) + logger.debug( + f"SGLang model runner initialized. num_layers={self.config.get('num_hidden_layers')}" + ) + # SGL KV Cache Manager is already initialized in ScheduleBatch + # TODO: Replace ScheduleBatch to Parallax inflight batch + self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False) + self.cur_batch = None else: logger.debug( f"Initializing MLX sharded model loader for repo={model_repo}, layers=[{start_layer}, {end_layer})" @@ -1230,6 +1254,7 @@ def create_executor_config(args: argparse.Namespace): "recv_from_peer_addr": args.recv_from_peer_addr if "recv_from_peer_addr" in args else None, "executor_input_ipc_addr": args.executor_input_ipc, "executor_output_ipc_addr": args.executor_output_ipc, + "gpu_backend": args.gpu_backend if hasattr(args, "gpu_backend") else "sglang", "attention_backend": args.attention_backend, "moe_runner_backend": args.moe_runner_backend, } diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index 7b8eead8..c488e821 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -165,6 +165,14 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") + parser.add_argument( + "--gpu_backend", + type=str, + default="sglang", + choices=["sglang", "vllm"], + help="GPU backend to use", + ) + args = parser.parse_args() # Validate arguments diff --git a/src/parallax/vllm/__init__.py b/src/parallax/vllm/__init__.py new file mode 100644 index 00000000..ed34a077 --- /dev/null +++ b/src/parallax/vllm/__init__.py @@ -0,0 +1,17 @@ +""" +vLLM backend integration for Parallax distributed inference. + +This module provides vLLM model runner with pipeline parallelism support. +""" + +from parallax.vllm.model_runner import ( + ParallaxVLLMEngine, + form_vllm_engine_args, + initialize_vllm_model_runner, +) + +__all__ = [ + "ParallaxVLLMEngine", + "form_vllm_engine_args", + "initialize_vllm_model_runner", +] diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py new file mode 100644 index 00000000..c34cd3cd --- /dev/null +++ b/src/parallax/vllm/model_runner.py @@ -0,0 +1,295 @@ +""" +Imports vLLM ModelRunner related modules and wrap them into create functions. +We use monkey patch to modify vLLM originated methods. The main purpose is to pass +arguments needed by decentralized inference with pipeline parallelism. 
+""" + +import logging +import os +import random +from typing import Any, Dict, List, Optional, Tuple + +import torch +from mlx_lm.utils import get_model_path, load_config +from vllm import EngineArgs, LLMEngine +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + LoRAConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, +) +from vllm.executor.ray_gpu_executor import RayGPUExecutor +from vllm.model_executor.layers.sampler import Sampler +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import get_distributed_init_method, get_ip, get_open_port + +from parallax.utils.tokenizer_utils import load_tokenizer + +logger = logging.getLogger(__name__) + + +class ParallaxVLLMEngine: + """ + Wrapper around vLLM Engine that supports pipeline parallelism for decentralized inference. + This class handles the sharding of layers across different nodes. + """ + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + pp_start_layer: int, + pp_end_layer: int, + **kwargs, + ): + """ + Initialize ParallaxVLLMEngine with pipeline parallelism support. + + Args: + model_config: vLLM model configuration + cache_config: vLLM cache configuration + parallel_config: vLLM parallel configuration + scheduler_config: vLLM scheduler configuration + device_config: vLLM device configuration + load_config: vLLM load configuration + lora_config: Optional LoRA configuration + pp_start_layer: Starting layer index for this shard (inclusive) + pp_end_layer: Ending layer index for this shard (exclusive) + """ + self.pp_start_layer = pp_start_layer + self.pp_end_layer = pp_end_layer + self.model_config = model_config + self.cache_config = cache_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.load_config = load_config + self.lora_config = lora_config + + # Modify model config to only load specified layers + self.model_config.hf_config.start_layer = pp_start_layer + self.model_config.hf_config.end_layer = pp_end_layer + + # Initialize the vLLM engine + # Note: vLLM doesn't natively support arbitrary layer sharding, + # so we need to monkey patch the model loading + from vllm.worker.model_runner import ModelRunner + + self.model_runner = None + self.is_first_peer = pp_start_layer == 0 + self.is_last_peer = pp_end_layer == model_config.hf_config.num_hidden_layers + + logger.info( + f"Initialized ParallaxVLLMEngine: layers [{pp_start_layer}, {pp_end_layer}), " + f"is_first={self.is_first_peer}, is_last={self.is_last_peer}" + ) + + def initialize_model(self): + """Initialize the model with the specified layer range.""" + # Import here to avoid circular dependency + from vllm.worker.worker import Worker + + # Create worker with modified configuration + worker = Worker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, + local_rank=0, + rank=0, + distributed_init_method=get_distributed_init_method(get_ip(), get_open_port()), + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + ) + + # Initialize worker + worker.init_device() + worker.load_model() + + self.model_runner = worker.model_runner + logger.info("vLLM model 
loaded successfully") + + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[torch.Tensor], + ) -> SamplerOutput: + """ + Execute the model on the given sequences. + + Args: + seq_group_metadata_list: List of sequence group metadata + kv_caches: List of KV cache tensors + + Returns: + SamplerOutput containing logits or sampled tokens + """ + if self.model_runner is None: + raise RuntimeError("Model not initialized. Call initialize_model() first.") + + return self.model_runner.execute_model( + seq_group_metadata_list=seq_group_metadata_list, kv_caches=kv_caches + ) + + +def form_vllm_engine_args( + model_path: str, + dtype: str = "bfloat16", + kv_block_size: int = 16, + gpu_memory_utilization: float = 0.85, + max_num_seqs: int = 256, + max_model_len: Optional[int] = None, + enforce_eager: bool = False, + **kwargs, +) -> EngineArgs: + """ + Creates vLLM EngineArgs object with Parallax-specific configurations. + + Args: + model_path: Path or name of the model + dtype: Data type for model weights (e.g., "bfloat16", "float16") + kv_block_size: Block size for paged attention KV cache + gpu_memory_utilization: Fraction of GPU memory to use + max_num_seqs: Maximum number of sequences to process + max_model_len: Maximum model context length + enforce_eager: Whether to enforce eager execution (disable CUDA graphs) + + Returns: + EngineArgs: vLLM engine arguments + """ + engine_args = EngineArgs( + model=model_path, + dtype=dtype, + tokenizer=model_path, + trust_remote_code=True, + gpu_memory_utilization=gpu_memory_utilization, + max_num_seqs=max_num_seqs, + max_model_len=max_model_len, + block_size=kv_block_size, + enforce_eager=enforce_eager, + # Disable tensor parallelism for now (will be handled by Parallax) + tensor_parallel_size=1, + pipeline_parallel_size=1, + **kwargs, + ) + return engine_args + + +def initialize_vllm_model_runner( + original_model_path: str, + start_layer: int, + end_layer: int, + kv_cache_memory_fraction: float, + kv_block_size: int, + max_num_seqs: int = 256, + max_model_len: Optional[int] = None, + enforce_eager: bool = False, +) -> Tuple[ParallaxVLLMEngine, Dict[str, Any], Any]: + """ + Creates a Parallax vLLM Engine object for decentralized inference. 
+ + Args: + original_model_path: Original model path or name + start_layer: Starting layer index (inclusive) + end_layer: Ending layer index (exclusive) + kv_cache_memory_fraction: Fraction of memory for KV cache + kv_block_size: Block size for paged attention + max_num_seqs: Maximum number of sequences + max_model_len: Maximum model context length + enforce_eager: Whether to disable CUDA graphs + + Returns: + Tuple of (vllm_engine, config_dict, tokenizer) + """ + # Load model configuration + model_path = get_model_path(original_model_path)[0] + config = load_config(model_path) + tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) + + # Get dtype from config + dtype = str(config.get("torch_dtype", "bfloat16")).replace("torch.", "") + + # Create engine args + engine_args = form_vllm_engine_args( + model_path=original_model_path, + dtype=dtype, + kv_block_size=kv_block_size, + gpu_memory_utilization=kv_cache_memory_fraction, + max_num_seqs=max_num_seqs, + max_model_len=max_model_len, + enforce_eager=enforce_eager, + ) + + # Create model, cache, parallel, scheduler, and device configs + model_config = ModelConfig( + model=original_model_path, + tokenizer=original_model_path, + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=0, + max_model_len=max_model_len, + ) + + cache_config = CacheConfig( + block_size=kv_block_size, + gpu_memory_utilization=kv_cache_memory_fraction, + swap_space=4, # GB + cache_dtype=dtype, + ) + + parallel_config = ParallelConfig( + pipeline_parallel_size=1, + tensor_parallel_size=1, + worker_use_ray=False, + max_parallel_loading_workers=None, + ) + + scheduler_config = SchedulerConfig( + max_num_batched_tokens=None, + max_num_seqs=max_num_seqs, + max_model_len=model_config.max_model_len, + ) + + device_config = DeviceConfig(device="cuda") + + load_config = LoadConfig( + load_format="auto", + download_dir=None, + model_loader_extra_config=None, + ) + + # Create Parallax vLLM Engine + logger.info( + f"Creating ParallaxVLLMEngine: model={original_model_path}, " + f"layers=[{start_layer}, {end_layer}), dtype={dtype}" + ) + + vllm_engine = ParallaxVLLMEngine( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config, + lora_config=None, + pp_start_layer=start_layer, + pp_end_layer=end_layer, + ) + + # Initialize the model + vllm_engine.initialize_model() + + logger.info(f"vLLM model runner initialized for layers [{start_layer}, {end_layer})") + + return vllm_engine, config, tokenizer From 15be623f33382d61cd4ebfef7bc4e278a0146b70 Mon Sep 17 00:00:00 2001 From: Alien mac air <2214632589@qq.com> Date: Mon, 20 Oct 2025 21:40:51 +0800 Subject: [PATCH 02/36] update --- src/parallax/server/executor.py | 111 +++++++++++++++++++++++++++--- src/parallax/vllm/batch_info.py | 80 +++++++++++++++++++++ src/parallax/vllm/model_runner.py | 68 +++++++++++++++++- 3 files changed, 250 insertions(+), 9 deletions(-) create mode 100644 src/parallax/vllm/batch_info.py diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index ba45aefe..7c23c623 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -373,13 +373,50 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s Prepares inputs for SGLang model runner from a batch of prefill requests. 
        Returns: SGLang ScheduleBatch
         """
-        from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
-
-        from parallax.sglang.batch_info import form_sgl_batch_prefill
-
         batch_size = len(batched_requests)
         if batch_size == 0:
             return None
+
+        if self.gpu_backend == "vllm":
+            from parallax.vllm.batch_info import form_vllm_batch_prefill
+
+            pp_proxy_tensors = None
+            if not self.is_first_peer:
+                hidden_states = torch.cat(
+                    [
+                        (
+                            req.hidden_states
+                            if req.hidden_states.ndim == 2
+                            else req.hidden_states.unsqueeze(0)
+                        )
+                        for req in batched_requests
+                    ],
+                    dim=0,
+                )
+                residual = torch.zeros(
+                    hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
+                )
+                pp_proxy_tensors = {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+                logger.debug(f"PP Proxy: hidden_states shape: {hidden_states.shape}")
+
+            fb = None
+            if self.is_first_peer:
+                fb = form_vllm_batch_prefill(batched_requests, self.pad_token_id)
+            lengths = [req.total_length for req in batched_requests]
+            return {
+                "input_ids": fb["input_ids"] if fb else None,
+                "pp_proxy_tensors": pp_proxy_tensors,
+                "lengths": torch.tensor(lengths, device=self.device),
+                "requests": batched_requests,
+            }
+
+        # sglang path
+        from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
+        from parallax.sglang.batch_info import form_sgl_batch_prefill
+
         schedule_batch, forward_batch = form_sgl_batch_prefill(batched_requests, self.model_runner)
         self.cur_batch = schedule_batch
 
@@ -423,14 +460,48 @@ def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[st
         Prepares inputs for SGLang model runner from a batch of decode requests.
         Returns: SGLang ScheduleBatch
         """
-        from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
-
-        from parallax.sglang.batch_info import form_sgl_batch_decode
-
         batch_size = len(batched_requests)
         if batch_size == 0:
             return None
 
+        if self.gpu_backend == "vllm":
+            from parallax.vllm.batch_info import form_vllm_batch_decode
+
+            pp_proxy_tensors = None
+            if not self.is_first_peer:
+                hidden_states = torch.cat(
+                    [
+                        (
+                            req.hidden_states
+                            if req.hidden_states.ndim == 2
+                            else req.hidden_states.unsqueeze(0)
+                        )
+                        for req in batched_requests
+                    ],
+                    dim=0,
+                )
+                residual = torch.zeros(
+                    hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
+                )
+                pp_proxy_tensors = {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+                logger.debug(f"PP Proxy: hidden_states shape: {hidden_states.shape}")
+
+            fb = form_vllm_batch_decode(batched_requests, self.is_first_peer)
+            lengths = [req.total_length for req in batched_requests]
+            return {
+                "input_ids": fb["input_ids"] if fb else None,
+                "pp_proxy_tensors": pp_proxy_tensors,
+                "lengths": torch.tensor(lengths, device=self.device),
+                "requests": batched_requests,
+            }
+
+        # sglang path
+        from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
+        from parallax.sglang.batch_info import form_sgl_batch_decode
+
         lengths = []
         for req in batched_requests:
             lengths.append(req.total_length)
@@ -1003,6 +1074,30 @@ def _process_batch_cuda(
         """
         Process a batch of requests in CUDA.
         """
+        if self.gpu_backend == "vllm":
+            # vLLM custom forward pass
+            input_ids = prepared_inputs.get("input_ids")
+            pp_proxy_tensors = prepared_inputs.get("pp_proxy_tensors")
+            ret = self.model_runner.forward(
+                input_ids=input_ids,
+                lengths=prepared_inputs.get("lengths"),
+                pp_proxy_tensors=pp_proxy_tensors,
+                return_logits=(self.is_last_peer and return_decoded_tokens),
+            )
+
+            if self.is_last_peer and return_decoded_tokens:
+                logits = ret.get("logits")
+                assert logits is not None, "vLLM last peer must return logits"
+                next_token_ids = self.model_runner.sample_argmax(logits)
+                return next_token_ids
+
+            # Other peers: return hidden_states + residual
+            hidden_states = ret.get("hidden_states")
+            residual = ret.get("residual")
+            assert hidden_states is not None and residual is not None
+            return hidden_states + residual
+
+        # sglang path
         assert "forward_batch" in prepared_inputs, "forward_batch should be in cuda prepared inputs"
         assert (
             "pp_proxy_tensors" in prepared_inputs
diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py
new file mode 100644
index 00000000..8a07e844
--- /dev/null
+++ b/src/parallax/vllm/batch_info.py
@@ -0,0 +1,80 @@
+"""
+Assembles batch inputs for the vLLM backend; the interface mirrors sglang's form_* helpers so the executor can switch between them.
+Note: no KV cache management is done here; each peer's model handles it on its own (performance may trail vLLM's internal scheduling).
+"""
+
+from __future__ import annotations
+
+from typing import Any, Dict, List, Tuple
+
+import torch
+
+from parallax.server.request import Request
+
+
+def _pad_2d(seqs: List[torch.Tensor], pad_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Pads a list of 1D token tensors into a 2D (B, L_max) tensor; returns the padded tensor and a lengths tensor."""
+    if not seqs:
+        return torch.empty(0, 0, dtype=torch.long), torch.empty(0, dtype=torch.long)
+    max_len = max(x.numel() for x in seqs)
+    bsz = len(seqs)
+    padded = torch.full((bsz, max_len), pad_id, dtype=seqs[0].dtype, device=seqs[0].device)
+    lengths = torch.empty(bsz, dtype=torch.long, device=seqs[0].device)
+    for i, x in enumerate(seqs):
+        L = x.numel()
+        padded[i, :L] = x
+        lengths[i] = L
+    return padded, lengths
+
+
+def form_vllm_batch_prefill(
+    batched_requests: List[Request], pad_token_id: int
+) -> Dict[str, Any]:
+    """First peer: uses input_ids; intermediate/last peers: the executor passes intermediate_tensors.
+    Only input_ids/lengths/requests are assembled here (for the first peer).
+    """
+    if len(batched_requests) == 0:
+        return None
+    # Collect tokens (first-peer case)
+    token_lists: List[torch.Tensor] = []
+    for req in batched_requests:
+        assert hasattr(req, "input_ids") and req.input_ids is not None
+        # Convert list[int] to a torch tensor
+        token_lists.append(torch.tensor(req.input_ids, dtype=torch.long, device="cuda"))
+    input_ids, lengths = _pad_2d(token_lists, pad_token_id)
+    return {
+        "input_ids": input_ids,
+        "lengths": lengths,
+        "requests": batched_requests,
+    }
+
+
+def form_vllm_batch_decode(
+    batched_requests: List[Request], is_first_peer: bool
+) -> Dict[str, Any]:
+    """Decode batch:
+    - First peer: only the last generated token is passed.
+    - Intermediate/last peers: the executor provides intermediate_tensors.
+    Only the inputs needed by the first peer are assembled here.
+    """
+    if len(batched_requests) == 0:
+        return None
+    if not is_first_peer:
+        # Non-first peers do not need token inputs
+        return {
+            "input_ids": None,
+            "lengths": torch.tensor([1 for _ in batched_requests], device="cuda"),
+            "requests": batched_requests,
+        }
+    last_tokens: List[torch.Tensor] = []
+    for req in batched_requests:
+        assert req.output_ids is not None and len(req.output_ids) > 0
+        last_tokens.append(torch.tensor([req.output_ids[-1]], dtype=torch.long, device="cuda"))
+    input_ids, lengths = _pad_2d(last_tokens, pad_id=0)
+    return {
+        "input_ids": input_ids,
+        "lengths": lengths,
+        "requests": batched_requests,
+    }
+
+
diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py
index c34cd3cd..392ca18b 100644
--- a/src/parallax/vllm/model_runner.py
+++ b/src/parallax/vllm/model_runner.py
@@ -23,7 +23,7 @@
 )
 from vllm.executor.ray_gpu_executor import RayGPUExecutor
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import SamplerOutput, IntermediateTensors
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 
 from parallax.utils.tokenizer_utils import load_tokenizer
@@ -119,6 +119,72 @@ def initialize_model(self):
         self.model_runner = worker.model_runner
         logger.info("vLLM model loaded successfully")
 
+    def forward(
+        self,
+        *,
+        input_ids: Optional[torch.Tensor],
+        lengths: Optional[torch.Tensor],
+        pp_proxy_tensors: Optional[Dict[str, torch.Tensor]],
+        return_logits: bool,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Runs a single forward pass:
+        - The first peer passes input_ids/lengths;
+        - Other peers pass pp_proxy_tensors={hidden_states, residual};
+        - The last peer sets return_logits=True so tokens can be sampled.
+        """
+        assert self.model_runner is not None, "Model not initialized"
+
+        # positions (built from lengths; shape [B, L] -> increasing sequence); not needed on non-first peers
+        positions = None
+        if input_ids is not None and lengths is not None:
+            batch_size, max_len = input_ids.shape
+            positions = torch.arange(max_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
+
+        inter_tensors = None
+        if pp_proxy_tensors is not None:
+            inter_tensors = IntermediateTensors(
+                tensors={
+                    "hidden_states": pp_proxy_tensors["hidden_states"],
+                    "residual": pp_proxy_tensors["residual"],
+                }
+            )
+
+        # Call the underlying model directly
+        outputs = self.model_runner.model(
+            input_ids=input_ids,
+            positions=positions,
+            intermediate_tensors=inter_tensors,
+            inputs_embeds=None,
+        )
+
+        # Convention: the model returns hidden_states/residual or logits
+        ret: Dict[str, torch.Tensor] = {}
+        if isinstance(outputs, dict):
+            if "hidden_states" in outputs:
+                ret["hidden_states"] = outputs["hidden_states"]
+            if "residual" in outputs:
+                ret["residual"] = outputs["residual"]
+            if return_logits and "logits" in outputs:
+                ret["logits"] = outputs["logits"]
+        elif isinstance(outputs, tuple):
+            # Fallback: assume (hidden_states, residual) or (logits,)
+            if return_logits and len(outputs) == 1:
+                ret["logits"] = outputs[0]
+            elif len(outputs) >= 2:
+                ret["hidden_states"] = outputs[0]
+                ret["residual"] = outputs[1]
+        else:
+            # The model may return logits directly
+            if return_logits:
+                ret["logits"] = outputs
+        return ret
+
+    @staticmethod
+    def sample_argmax(logits: torch.Tensor) -> torch.Tensor:
+        """Simple greedy sampling. logits: [B, V] -> token_ids: [B]"""
+        return torch.argmax(logits, dim=-1)
+
     def execute_model(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],

From 544a7a230c77085f1009c7c3f76c31a370f8b76b Mon Sep 17 00:00:00 2001
From: yuhao_zhang
Date: Tue, 21 Oct 2025 10:09:20 +0800
Subject: [PATCH 03/36] rebase

---
 src/parallax/server/executor.py   | 199 ++++------------
 src/parallax/vllm/batch_info.py   | 170 ++++++++------
 src/parallax/vllm/model_runner.py | 362 +-----------------------------
 3 files changed, 152 insertions(+), 579 deletions(-)

diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py
index 7c23c623..1da3df2e 100644
--- a/src/parallax/server/executor.py
+++ b/src/parallax/server/executor.py
@@ -93,39 +93,63 @@ def __init__(
         # IPC Communication Configs
         executor_input_ipc_addr: Optional[str] = None,
         executor_output_ipc_addr: Optional[str] = None,
-        # GPU Backend Configs
-        gpu_backend: Optional[str] = "sglang",  # "sglang"
or "vllm" + # GPU/SGLang Specialized Configs attention_backend: Optional[str] = "torch_native", moe_runner_backend: Optional[str] = "auto", ): # Backend self.device = get_current_device() - self.gpu_backend = gpu_backend - logger.debug(f"Executor initializing on device: {self.device}, gpu_backend: {gpu_backend}") + logger.debug(f"Executor initializing on device: {self.device}") # Sharded Model if self.device == "cuda": - if gpu_backend == "vllm": - from parallax.vllm.model_runner import initialize_vllm_model_runner + from sglang.srt.managers.schedule_batch import ScheduleBatch - logger.debug( - f"Initializing vLLM model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" - ) - self.model_runner, self.config, self.tokenizer = initialize_vllm_model_runner( - model_repo, - start_layer, - end_layer, - kv_cache_memory_fraction, - kv_block_size, - max_num_seqs=max_batch_size, - max_model_len=max_sequence_length, - ) - logger.debug( - f"vLLM model runner initialized. num_layers={self.config.get('num_hidden_layers')}" - ) - # vLLM manages its own KV cache and batching - self.running_batch = None - self.cur_batch = None - else: # sglang backend - from sglang.srt.managers.schedule_batch import ScheduleBatch - - from parallax.sglang.model_runner import initialize_sgl_model_runner + from parallax.sglang.model_runner import initialize_sgl_model_runner - logger.debug( - f"Initializing SGLang model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" - ) - self.model_runner, self.config, self.tokenizer = initialize_sgl_model_runner( - model_repo, - start_layer, - end_layer, - kv_cache_memory_fraction, - attention_backend, - kv_block_size, - moe_runner_backend, - ) - logger.debug( - f"SGLang model runner initialized. num_layers={self.config.get('num_hidden_layers')}" - ) - # SGL KV Cache Manager is already initialized in ScheduleBatch - # TODO: Replace ScheduleBatch to Parallax inflight batch - self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False) - self.cur_batch = None + logger.debug( + f"Initializing CUDA model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" + ) + self.model_runner, self.config, self.tokenizer = initialize_sgl_model_runner( + model_repo, + start_layer, + end_layer, + kv_cache_memory_fraction, + attention_backend, + kv_block_size, + moe_runner_backend, + ) + logger.debug( + f"CUDA model runner initialized. 
num_layers={self.config.get('num_hidden_layers')}" + ) + # SGL KV Cache Manager is already initialized in ScheduleBatch + # TODO: Replace ScheduleBatch to Parallax inflight batch + self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False) + self.cur_batch = None else: logger.debug( f"Initializing MLX sharded model loader for repo={model_repo}, layers=[{start_layer}, {end_layer})" @@ -167,6 +143,7 @@ def __init__( self.finished_batch = [] self.start_layer = start_layer self.end_layer = end_layer + self.is_first_peer = start_layer == 0 self.is_last_peer = end_layer == self.config.get("num_hidden_layers") self.num_shard_layers = end_layer - start_layer @@ -340,6 +317,22 @@ def recv_requests_from_peer(self) -> List[Request]: forward_request = forward_pb2.ForwardRequest() forward_request.ParseFromString(recv_req[1]) recv_req = proto_to_request(forward_request, self.device) + + # Convert hidden_states dtype if necessary + if recv_req is not None and len(recv_req) > 0: + for req in recv_req: + if req.hidden_states is not None: + if req.hidden_states.dtype != self.dtype: + logger.debug( + f"Converting hidden_states dtype from {req.hidden_states.dtype} to {self.dtype} for request {req.request_id}" + ) + if self.device == "cuda": + req.hidden_states = req.hidden_states.to(self.dtype) + elif self.device == "mlx": + req.hidden_states = req.hidden_states.astype(self.dtype) + else: + raise ValueError(f"Unsupported device type: {self.device}") + # Move current position for first peer if self.is_first_peer: for req in recv_req: @@ -373,50 +366,13 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s Prepares inputs for SGLang model runner from a batch of prefill requests. Returns: SGLang ScheduleBatch """ - batch_size = len(batched_requests) - if batch_size == 0: - return None - - if self.gpu_backend == "vllm": - from parallax.vllm.batch_info import form_vllm_batch_prefill - - pp_proxy_tensors = None - if not self.is_first_peer: - hidden_states = torch.cat( - [ - ( - req.hidden_states - if req.hidden_states.ndim == 2 - else req.hidden_states.unsqueeze(0) - ) - for req in batched_requests - ], - dim=0, - ) - residual = torch.zeros( - hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device - ) - pp_proxy_tensors = { - "hidden_states": hidden_states, - "residual": residual, - } - logger.debug(f"PP Proxy: hidden_states shape: {hidden_states.shape}") - - fb = None - if self.is_first_peer: - fb = form_vllm_batch_prefill(batched_requests, self.pad_token_id) - lengths = [req.total_length for req in batched_requests] - return { - "input_ids": fb["input_ids"] if fb else None, - "pp_proxy_tensors": pp_proxy_tensors, - "lengths": torch.tensor(lengths, device=self.device), - "requests": batched_requests, - } - - # sglang 路径 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors + from parallax.sglang.batch_info import form_sgl_batch_prefill + batch_size = len(batched_requests) + if batch_size == 0: + return None schedule_batch, forward_batch = form_sgl_batch_prefill(batched_requests, self.model_runner) self.cur_batch = schedule_batch @@ -460,48 +416,14 @@ def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[st Prepares inputs for SGLang model runner from a batch of decode requests. 
Returns: SGLang ScheduleBatch """ + from sglang.srt.model_executor.forward_batch_info import PPProxyTensors + + from parallax.sglang.batch_info import form_sgl_batch_decode + batch_size = len(batched_requests) if batch_size == 0: return None - if self.gpu_backend == "vllm": - from parallax.vllm.batch_info import form_vllm_batch_decode - - pp_proxy_tensors = None - if not self.is_first_peer: - hidden_states = torch.cat( - [ - ( - req.hidden_states - if req.hidden_states.ndim == 2 - else req.hidden_states.unsqueeze(0) - ) - for req in batched_requests - ], - dim=0, - ) - residual = torch.zeros( - hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device - ) - pp_proxy_tensors = { - "hidden_states": hidden_states, - "residual": residual, - } - logger.debug(f"PP Proxy: hidden_states shape: {hidden_states.shape}") - - fb = form_vllm_batch_decode(batched_requests, self.is_first_peer) - lengths = [req.total_length for req in batched_requests] - return { - "input_ids": fb["input_ids"] if fb else None, - "pp_proxy_tensors": pp_proxy_tensors, - "lengths": torch.tensor(lengths, device=self.device), - "requests": batched_requests, - } - - # sglang 路径 - from sglang.srt.model_executor.forward_batch_info import PPProxyTensors - from parallax.sglang.batch_info import form_sgl_batch_decode - lengths = [] for req in batched_requests: lengths.append(req.total_length) @@ -1074,30 +996,6 @@ def _process_batch_cuda( """ Process a batch of requests in CUDA. """ - if self.gpu_backend == "vllm": - # vLLM 自定义前向 - input_ids = prepared_inputs.get("input_ids") - pp_proxy_tensors = prepared_inputs.get("pp_proxy_tensors") - ret = self.model_runner.forward( - input_ids=input_ids, - lengths=prepared_inputs.get("lengths"), - pp_proxy_tensors=pp_proxy_tensors, - return_logits=(self.is_last_peer and return_decoded_tokens), - ) - - if self.is_last_peer and return_decoded_tokens: - logits = ret.get("logits") - assert logits is not None, "vLLM last peer must return logits" - next_token_ids = self.model_runner.sample_argmax(logits) - return next_token_ids - - # 其它 peer:返回 hidden_states + residual - hidden_states = ret.get("hidden_states") - residual = ret.get("residual") - assert hidden_states is not None and residual is not None - return hidden_states + residual - - # sglang 路径 assert "forward_batch" in prepared_inputs, "forward_batch should be in cuda prepared inputs" assert ( "pp_proxy_tensors" in prepared_inputs @@ -1349,7 +1247,6 @@ def create_executor_config(args: argparse.Namespace): "recv_from_peer_addr": args.recv_from_peer_addr if "recv_from_peer_addr" in args else None, "executor_input_ipc_addr": args.executor_input_ipc, "executor_output_ipc_addr": args.executor_output_ipc, - "gpu_backend": args.gpu_backend if hasattr(args, "gpu_backend") else "sglang", "attention_backend": args.attention_backend, "moe_runner_backend": args.moe_runner_backend, } diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index 8a07e844..257b8ec5 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -1,80 +1,116 @@ -""" -为 vLLM 后端组装批次输入,接口尽量与 sglang 的 form_* 类似,便于在 executor 中切换。 -注意:这里不做 KV 管理,由各 peer 内的模型自行处理(性能可能不及 vLLM 内部调度)。 -""" - from __future__ import annotations -from typing import Any, Dict, List, Tuple - -import torch +from typing import Any, Dict, List from parallax.server.request import Request +from parallax.server.sampling.sampling_params import ( + SamplingParams as ParallaxSamplingParams, +) +from vllm.v1.request import Request as VLLMRequest +from 
vllm.sampling_params import ( + SamplingParams as VLLMSamplingParams, + StructuredOutputsParams, +) +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +def transform_sampling_params_to_vllm(old_params: ParallaxSamplingParams) -> VLLMSamplingParams: + """Transforms Parallax SamplingParams to vLLM SamplingParams format.""" + # Map Parallax json_schema -> vLLM structured_outputs + structured = ( + StructuredOutputsParams(json=old_params.json_schema) + if getattr(old_params, "json_schema", None) is not None + else None + ) + + # vLLM uses max_tokens/min_tokens naming + params = VLLMSamplingParams( + max_tokens=old_params.max_new_tokens, + min_tokens=old_params.min_new_tokens, + temperature=old_params.temperature, + top_p=old_params.top_p, + min_p=old_params.min_p, + top_k=old_params.top_k, + stop_token_ids=( + list(old_params.stop_token_ids) + if getattr(old_params, "stop_token_ids", None) is not None + else None + ), + ignore_eos=old_params.ignore_eos, + stop=old_params.stop_strs, + repetition_penalty=old_params.repetition_penalty, + presence_penalty=old_params.presence_penalty, + frequency_penalty=old_params.frequency_penalty, + structured_outputs=structured, + ) + return params + + +def transform_requests_to_vllm(batched_requests: List[Request]) -> List[VLLMRequest]: + """Transforms Parallax Request to vLLM Request format. + Note: Only used if we later choose to feed vLLM Engine directly. + """ + vllm_reqs = [] + for old_req in batched_requests: + sampling_params = transform_sampling_params_to_vllm(old_req.sampling_params) + vllm_req = VLLMRequest( + request_id=old_req.request_id, + prompt_token_ids=old_req.input_ids, + sampling_params=sampling_params, + eos_token_id=getattr(old_req, "eos_token_id", None), + client_index=getattr(old_req, "client_index", 0), + ) + vllm_reqs.append(vllm_req) + return vllm_reqs + +def form_vllm_batch_prefill(batched_requests: List[Request], pad_token_id: int) -> Dict[str, Any] | None: + """Builds the vLLM prefill batch inputs for the first peer. -def _pad_2d(seqs: List[torch.Tensor], pad_id: int) -> Tuple[torch.Tensor, torch.Tensor]: - """将一组 1D token 张量 padding 为 2D (B, L_max),返回 padded 和 长度 tensor。""" - if not seqs: - return torch.empty(0, 0, dtype=torch.long), torch.empty(0, dtype=torch.long) - max_len = max(x.numel() for x in seqs) - bsz = len(seqs) - padded = torch.full((bsz, max_len), pad_id, dtype=seqs[0].dtype, device=seqs[0].device) - lengths = torch.empty(bsz, dtype=torch.long, device=seqs[0].device) - for i, x in enumerate(seqs): - L = x.numel() - padded[i, :L] = x - lengths[i] = L - return padded, lengths - - -def form_vllm_batch_prefill( - batched_requests: List[Request], pad_token_id: int -) -> Dict[str, Any]: - """首个 peer: 使用 input_ids;中间/最后 peer: 由 executor 传入 intermediate_tensors。 - 这里仅组装 input_ids/lengths/requests(供首个 peer 使用)。 + Returns a dict with: + - input_ids: List[List[int]] padded to max prompt length with pad_token_id. 
""" - if len(batched_requests) == 0: + batch_size = len(batched_requests) + if batch_size == 0: return None - # 收集 tokens(first peer 情况) - token_lists: List[torch.Tensor] = [] + + # Collect prompts and compute max length + seqs: List[List[int]] = [] + max_len = 0 for req in batched_requests: - assert hasattr(req, "input_ids") and req.input_ids is not None - # 将 list[int] 转为 torch tensor - token_lists.append(torch.tensor(req.input_ids, dtype=torch.long, device="cuda")) - input_ids, lengths = _pad_2d(token_lists, pad_token_id) - return { - "input_ids": input_ids, - "lengths": lengths, - "requests": batched_requests, - } - - -def form_vllm_batch_decode( - batched_requests: List[Request], is_first_peer: bool -) -> Dict[str, Any]: - """解码批次: - - 首个 peer: 仅传最后一个 token。 - - 中间/最后 peer: 由 executor 提供 intermediate_tensors。 - 这里只组装首个 peer 所需的输入。 + assert req.is_prefill, f"Request {req.request_id} is not a prefill request." + assert req.input_ids is not None and len(req.input_ids) > 0, ( + f"Request {req.request_id} has empty input_ids for prefill" + ) + seqs.append(req.input_ids) + if len(req.input_ids) > max_len: + max_len = len(req.input_ids) + + # Right-pad to max_len with pad_token_id + padded: List[List[int]] = [seq + [pad_token_id] * (max_len - len(seq)) for seq in seqs] + + return {"input_ids": padded} + + +def form_vllm_batch_decode(batched_requests: List[Request], is_first_peer: bool) -> Dict[str, Any] | None: + """Builds the vLLM decode batch inputs for the first peer. + + For decode, the first peer feeds the last generated token per request. + Other peers return None (they use pp_proxy_tensors path). """ - if len(batched_requests) == 0: - return None if not is_first_peer: - # 非首个 peer 不需要 tokens 输入 - return { - "input_ids": None, - "lengths": torch.tensor([1 for _ in batched_requests], device="cuda"), - "requests": batched_requests, - } - last_tokens: List[torch.Tensor] = [] - for req in batched_requests: - assert req.output_ids is not None and len(req.output_ids) > 0 - last_tokens.append(torch.tensor([req.output_ids[-1]], dtype=torch.long, device="cuda")) - input_ids, lengths = _pad_2d(last_tokens, pad_id=0) - return { - "input_ids": input_ids, - "lengths": lengths, - "requests": batched_requests, - } + return None + # For first peer, gather the next-step input token ids (last output token) + tokens: List[int] = [] + for req in batched_requests: + assert req.is_decoding, f"Request {req.request_id} is not a decode request." + assert req.output_ids is not None and len(req.output_ids) > 0, ( + f"Decode step requires at least one output token for {req.request_id}" + ) + tokens.append(req.output_ids[-1]) + # Use shape [batch, 1] for consistency + return {"input_ids": [[tok] for tok in tokens]} diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 392ca18b..2ae28399 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -1,361 +1 @@ -""" -Imports vLLM ModelRunner related modules and wrap them into create functions. -We use monkey patch to modify vLLM originated methods. The main purpose is to pass -arguments needed by decentralized inference with pipeline parallelism. 
-""" - -import logging -import os -import random -from typing import Any, Dict, List, Optional, Tuple - -import torch -from mlx_lm.utils import get_model_path, load_config -from vllm import EngineArgs, LLMEngine -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoadConfig, - LoRAConfig, - ModelConfig, - ParallelConfig, - SchedulerConfig, -) -from vllm.executor.ray_gpu_executor import RayGPUExecutor -from vllm.model_executor.layers.sampler import Sampler -from vllm.sequence import SamplerOutput, IntermediateTensors -from vllm.utils import get_distributed_init_method, get_ip, get_open_port - -from parallax.utils.tokenizer_utils import load_tokenizer - -logger = logging.getLogger(__name__) - - -class ParallaxVLLMEngine: - """ - Wrapper around vLLM Engine that supports pipeline parallelism for decentralized inference. - This class handles the sharding of layers across different nodes. - """ - - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - pp_start_layer: int, - pp_end_layer: int, - **kwargs, - ): - """ - Initialize ParallaxVLLMEngine with pipeline parallelism support. - - Args: - model_config: vLLM model configuration - cache_config: vLLM cache configuration - parallel_config: vLLM parallel configuration - scheduler_config: vLLM scheduler configuration - device_config: vLLM device configuration - load_config: vLLM load configuration - lora_config: Optional LoRA configuration - pp_start_layer: Starting layer index for this shard (inclusive) - pp_end_layer: Ending layer index for this shard (exclusive) - """ - self.pp_start_layer = pp_start_layer - self.pp_end_layer = pp_end_layer - self.model_config = model_config - self.cache_config = cache_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.load_config = load_config - self.lora_config = lora_config - - # Modify model config to only load specified layers - self.model_config.hf_config.start_layer = pp_start_layer - self.model_config.hf_config.end_layer = pp_end_layer - - # Initialize the vLLM engine - # Note: vLLM doesn't natively support arbitrary layer sharding, - # so we need to monkey patch the model loading - from vllm.worker.model_runner import ModelRunner - - self.model_runner = None - self.is_first_peer = pp_start_layer == 0 - self.is_last_peer = pp_end_layer == model_config.hf_config.num_hidden_layers - - logger.info( - f"Initialized ParallaxVLLMEngine: layers [{pp_start_layer}, {pp_end_layer}), " - f"is_first={self.is_first_peer}, is_last={self.is_last_peer}" - ) - - def initialize_model(self): - """Initialize the model with the specified layer range.""" - # Import here to avoid circular dependency - from vllm.worker.worker import Worker - - # Create worker with modified configuration - worker = Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, - local_rank=0, - rank=0, - distributed_init_method=get_distributed_init_method(get_ip(), get_open_port()), - lora_config=self.lora_config, - kv_cache_dtype=self.cache_config.cache_dtype, - ) - - # Initialize worker - worker.init_device() - worker.load_model() - - self.model_runner = worker.model_runner - logger.info("vLLM model 
loaded successfully") - - def forward( - self, - *, - input_ids: Optional[torch.Tensor], - lengths: Optional[torch.Tensor], - pp_proxy_tensors: Optional[Dict[str, torch.Tensor]], - return_logits: bool, - ) -> Dict[str, torch.Tensor]: - """ - 进行一次前向: - - 首个 peer 传 input_ids/lengths; - - 其它 peer 传 pp_proxy_tensors={hidden_states,residual}; - - 最后一个 peer 设置 return_logits=True 以便采样。 - """ - assert self.model_runner is not None, "Model not initialized" - - # positions(简单从长度构建,形状 [B, L] -> 递增序列);非首个 peer 不需要 - positions = None - if input_ids is not None and lengths is not None: - batch_size, max_len = input_ids.shape - positions = torch.arange(max_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1) - - inter_tensors = None - if pp_proxy_tensors is not None: - inter_tensors = IntermediateTensors( - tensors={ - "hidden_states": pp_proxy_tensors["hidden_states"], - "residual": pp_proxy_tensors["residual"], - } - ) - - # 直接调用底层模型 - outputs = self.model_runner.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=inter_tensors, - inputs_embeds=None, - ) - - # 约定:模型返回包含 hidden_states/residual 或 logits - ret: Dict[str, torch.Tensor] = {} - if isinstance(outputs, dict): - if "hidden_states" in outputs: - ret["hidden_states"] = outputs["hidden_states"] - if "residual" in outputs: - ret["residual"] = outputs["residual"] - if return_logits and "logits" in outputs: - ret["logits"] = outputs["logits"] - elif isinstance(outputs, tuple): - # 兜底:假设 (hidden_states, residual) 或 (logits,) - if return_logits and len(outputs) == 1: - ret["logits"] = outputs[0] - elif len(outputs) >= 2: - ret["hidden_states"] = outputs[0] - ret["residual"] = outputs[1] - else: - # 可能直接返回 logits - if return_logits: - ret["logits"] = outputs - return ret - - @staticmethod - def sample_argmax(logits: torch.Tensor) -> torch.Tensor: - """简单贪心采样。logits: [B, V] -> token_ids: [B]""" - return torch.argmax(logits, dim=-1) - - def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - kv_caches: List[torch.Tensor], - ) -> SamplerOutput: - """ - Execute the model on the given sequences. - - Args: - seq_group_metadata_list: List of sequence group metadata - kv_caches: List of KV cache tensors - - Returns: - SamplerOutput containing logits or sampled tokens - """ - if self.model_runner is None: - raise RuntimeError("Model not initialized. Call initialize_model() first.") - - return self.model_runner.execute_model( - seq_group_metadata_list=seq_group_metadata_list, kv_caches=kv_caches - ) - - -def form_vllm_engine_args( - model_path: str, - dtype: str = "bfloat16", - kv_block_size: int = 16, - gpu_memory_utilization: float = 0.85, - max_num_seqs: int = 256, - max_model_len: Optional[int] = None, - enforce_eager: bool = False, - **kwargs, -) -> EngineArgs: - """ - Creates vLLM EngineArgs object with Parallax-specific configurations. 
- - Args: - model_path: Path or name of the model - dtype: Data type for model weights (e.g., "bfloat16", "float16") - kv_block_size: Block size for paged attention KV cache - gpu_memory_utilization: Fraction of GPU memory to use - max_num_seqs: Maximum number of sequences to process - max_model_len: Maximum model context length - enforce_eager: Whether to enforce eager execution (disable CUDA graphs) - - Returns: - EngineArgs: vLLM engine arguments - """ - engine_args = EngineArgs( - model=model_path, - dtype=dtype, - tokenizer=model_path, - trust_remote_code=True, - gpu_memory_utilization=gpu_memory_utilization, - max_num_seqs=max_num_seqs, - max_model_len=max_model_len, - block_size=kv_block_size, - enforce_eager=enforce_eager, - # Disable tensor parallelism for now (will be handled by Parallax) - tensor_parallel_size=1, - pipeline_parallel_size=1, - **kwargs, - ) - return engine_args - - -def initialize_vllm_model_runner( - original_model_path: str, - start_layer: int, - end_layer: int, - kv_cache_memory_fraction: float, - kv_block_size: int, - max_num_seqs: int = 256, - max_model_len: Optional[int] = None, - enforce_eager: bool = False, -) -> Tuple[ParallaxVLLMEngine, Dict[str, Any], Any]: - """ - Creates a Parallax vLLM Engine object for decentralized inference. - - Args: - original_model_path: Original model path or name - start_layer: Starting layer index (inclusive) - end_layer: Ending layer index (exclusive) - kv_cache_memory_fraction: Fraction of memory for KV cache - kv_block_size: Block size for paged attention - max_num_seqs: Maximum number of sequences - max_model_len: Maximum model context length - enforce_eager: Whether to disable CUDA graphs - - Returns: - Tuple of (vllm_engine, config_dict, tokenizer) - """ - # Load model configuration - model_path = get_model_path(original_model_path)[0] - config = load_config(model_path) - tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) - - # Get dtype from config - dtype = str(config.get("torch_dtype", "bfloat16")).replace("torch.", "") - - # Create engine args - engine_args = form_vllm_engine_args( - model_path=original_model_path, - dtype=dtype, - kv_block_size=kv_block_size, - gpu_memory_utilization=kv_cache_memory_fraction, - max_num_seqs=max_num_seqs, - max_model_len=max_model_len, - enforce_eager=enforce_eager, - ) - - # Create model, cache, parallel, scheduler, and device configs - model_config = ModelConfig( - model=original_model_path, - tokenizer=original_model_path, - tokenizer_mode="auto", - trust_remote_code=True, - dtype=dtype, - seed=0, - max_model_len=max_model_len, - ) - - cache_config = CacheConfig( - block_size=kv_block_size, - gpu_memory_utilization=kv_cache_memory_fraction, - swap_space=4, # GB - cache_dtype=dtype, - ) - - parallel_config = ParallelConfig( - pipeline_parallel_size=1, - tensor_parallel_size=1, - worker_use_ray=False, - max_parallel_loading_workers=None, - ) - - scheduler_config = SchedulerConfig( - max_num_batched_tokens=None, - max_num_seqs=max_num_seqs, - max_model_len=model_config.max_model_len, - ) - - device_config = DeviceConfig(device="cuda") - - load_config = LoadConfig( - load_format="auto", - download_dir=None, - model_loader_extra_config=None, - ) - - # Create Parallax vLLM Engine - logger.info( - f"Creating ParallaxVLLMEngine: model={original_model_path}, " - f"layers=[{start_layer}, {end_layer}), dtype={dtype}" - ) - - vllm_engine = ParallaxVLLMEngine( - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - 
scheduler_config=scheduler_config, - device_config=device_config, - load_config=load_config, - lora_config=None, - pp_start_layer=start_layer, - pp_end_layer=end_layer, - ) - - # Initialize the model - vllm_engine.initialize_model() - - logger.info(f"vLLM model runner initialized for layers [{start_layer}, {end_layer})") - - return vllm_engine, config, tokenizer +pass From 5696936de86183cb7b4fff239882fafc649abea4 Mon Sep 17 00:00:00 2001 From: Alien mac air <2214632589@qq.com> Date: Tue, 21 Oct 2025 15:06:44 +0800 Subject: [PATCH 04/36] add support without model shard (PP) --- src/parallax/server/executor.py | 293 +++++++++++++++++++++--------- src/parallax/sglang/batch_info.py | 2 +- src/parallax/vllm/batch_info.py | 207 +++++++++++++++++---- src/parallax/vllm/model_runner.py | 229 ++++++++++++++++++++++- 4 files changed, 614 insertions(+), 117 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 1da3df2e..cf9e31c8 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -71,6 +71,8 @@ def __init__( start_layer: int, end_layer: int, dtype: str = "float16", + # Backend selection + gpu_backend: str = "sglang", # Scheduler Configs max_batch_size: Optional[int] = 8, max_sequence_length: Optional[int] = None, @@ -100,32 +102,43 @@ def __init__( # Backend self.device = get_current_device() logger.debug(f"Executor initializing on device: {self.device}") + self.backend_type = gpu_backend # Sharded Model if self.device == "cuda": - from sglang.srt.managers.schedule_batch import ScheduleBatch + if self.backend_type == "vllm": + from parallax.vllm.model_runner import ( + initialize_vllm_model_runner as initialize_cuda_model_runner, + ) - from parallax.sglang.model_runner import initialize_sgl_model_runner + logger.debug( + f"Initializing vLLM model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" + ) + elif self.backend_type == "sglang": + from sglang.srt.managers.schedule_batch import ScheduleBatch as CudaScheduleBatch + from parallax.sglang.model_runner import ( + initialize_sgl_model_runner as initialize_cuda_model_runner, + ) - logger.debug( - f"Initializing CUDA model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" - ) - self.model_runner, self.config, self.tokenizer = initialize_sgl_model_runner( + logger.debug( + f"Initializing SGLang model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" + ) + else: + raise ValueError(f"Unsupported GPU backend type: {self.backend_type}") + + self.model_runner, self.config, self.tokenizer = initialize_cuda_model_runner( model_repo, start_layer, end_layer, kv_cache_memory_fraction, attention_backend, kv_block_size, - moe_runner_backend, ) - logger.debug( - f"CUDA model runner initialized. num_layers={self.config.get('num_hidden_layers')}" - ) - # SGL KV Cache Manager is already initialized in ScheduleBatch - # TODO: Replace ScheduleBatch to Parallax inflight batch - self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False) + self.running_batch = None self.cur_batch = None + if self == "sglang": + self.running_batch = CudaScheduleBatch(reqs=[], batch_is_full=False) + else: logger.debug( f"Initializing MLX sharded model loader for repo={model_repo}, layers=[{start_layer}, {end_layer})" @@ -363,19 +376,16 @@ def recv_requests_from_peer(self) -> List[Request]: def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]: """ - Prepares inputs for SGLang model runner from a batch of prefill requests. 
- Returns: SGLang ScheduleBatch + Prepares inputs for CUDA backends from a batch of prefill requests. + Routes to SGLang or vLLM depending on backend_type. """ from sglang.srt.model_executor.forward_batch_info import PPProxyTensors - from parallax.sglang.batch_info import form_sgl_batch_prefill - batch_size = len(batched_requests) if batch_size == 0: return None - schedule_batch, forward_batch = form_sgl_batch_prefill(batched_requests, self.model_runner) - self.cur_batch = schedule_batch + # Prepare PP proxy tensors (common for both backends when not first peer) pp_proxy_tensors = None if not self.is_first_peer: hidden_states = torch.cat( @@ -399,40 +409,55 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s } ) logger.debug(f"PP Proxy: hidden_states shape: {hidden_states.shape}") + + # Prepare lengths (common for both backends) lengths = [] for req in batched_requests: lengths.append(req.total_length) - ret = { - "forward_batch": forward_batch, - "pp_proxy_tensors": pp_proxy_tensors, - "lengths": torch.tensor(lengths, device=self.device), - "requests": batched_requests, - } - logger.debug(f"Prepared CUDA prefill batch (size={batch_size})") - return ret + lengths_tensor = torch.tensor(lengths, device=self.device) + + if self.backend_type == "vllm": + from parallax.vllm.batch_info import form_vllm_batch_prefill + + schedule_outputs_prefill = form_vllm_batch_prefill(batched_requests, self.model_runner) + + ret = { + "scheduler_output": schedule_outputs_prefill, + "pp_proxy_tensors": pp_proxy_tensors, + "lengths": lengths_tensor, + "requests": batched_requests, + } + logger.debug(f"Prepared CUDA prefill batch (vllm, size={batch_size})") + return ret + else: + from parallax.sglang.batch_info import form_sgl_batch_prefill + + schedule_batch, forward_batch = form_sgl_batch_prefill( + batched_requests, self.model_runner + ) + self.cur_batch = schedule_batch + + ret = { + "forward_batch": forward_batch, + "pp_proxy_tensors": pp_proxy_tensors, + "lengths": lengths_tensor, + "requests": batched_requests, + } + logger.debug(f"Prepared CUDA prefill batch (sglang, size={batch_size})") + return ret def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[str, Any]: """ - Prepares inputs for SGLang model runner from a batch of decode requests. - Returns: SGLang ScheduleBatch + Prepares inputs for CUDA backends from a batch of decode requests. + Routes to SGLang or vLLM depending on backend_type. 
""" from sglang.srt.model_executor.forward_batch_info import PPProxyTensors - from parallax.sglang.batch_info import form_sgl_batch_decode - batch_size = len(batched_requests) if batch_size == 0: return None - lengths = [] - for req in batched_requests: - lengths.append(req.total_length) - forward_batch = form_sgl_batch_decode( - batched_requests, - self.model_runner, - self.running_batch, - self.is_first_peer, - ) + # Prepare PP proxy tensors (common for both backends when not first peer) pp_proxy_tensors = None if not self.is_first_peer: hidden_states = torch.cat( @@ -456,14 +481,43 @@ def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[st } ) logger.debug(f"PP Proxy: hidden_states shape: {hidden_states.shape}") - ret = { - "forward_batch": forward_batch, - "pp_proxy_tensors": pp_proxy_tensors, - "lengths": torch.tensor(lengths, device=self.device), - "requests": batched_requests, - } - logger.debug(f"Prepared CUDA decode batch (size={batch_size})") - return ret + + # Prepare lengths (common for both backends) + lengths = [] + for req in batched_requests: + lengths.append(req.total_length) + lengths_tensor = torch.tensor(lengths, device=self.device) + + if self.backend_type == "vllm": + from parallax.vllm.batch_info import form_vllm_batch_decode + + scheduler_outputs_decode = form_vllm_batch_decode(batched_requests, self.model_runner) + ret = { + "scheduler_output": scheduler_outputs_decode, + "pp_proxy_tensors": pp_proxy_tensors, + "lengths": lengths_tensor, + "requests": batched_requests, + } + logger.debug(f"Prepared CUDA decode batch (vllm, size={batch_size})") + return ret + else: + from parallax.sglang.batch_info import form_sgl_batch_decode + + forward_batch = form_sgl_batch_decode( + batched_requests, + self.model_runner, + self.running_batch, + self.is_first_peer, + ) + + ret = { + "forward_batch": forward_batch, + "pp_proxy_tensors": pp_proxy_tensors, + "lengths": lengths_tensor, + "requests": batched_requests, + } + logger.debug(f"Prepared CUDA decode batch (sglang, size={batch_size})") + return ret def _prepare_mlx_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]: """Prepares inputs for ShardedModel from a batch of prefill requests.""" @@ -728,7 +782,7 @@ def _handle_cuda_input_requests(self, requests: List[Request]): Cuda specialized handle function. The main difference is to remove all the kv cache operations. """ - from parallax.sglang.batch_info import release_cuda_request + from parallax.sglang.batch_info import release_sglang_request if self.is_first_peer: # First peer can receive InitialRequests from the client RPC, @@ -754,7 +808,12 @@ def _handle_cuda_input_requests(self, requests: List[Request]): # Check for termination. if self.scheduler.check_and_update_request_status(original_req): logger.debug(f"Releasing resources for finished request {req.request_id}") - release_cuda_request(self.running_batch, req.request_id) + if self.backend_type == "sglang": + release_sglang_request(self.running_batch, req.request_id) + elif self.backend_type == "vllm": + from parallax.vllm.batch_info import release_vllm_request + + release_vllm_request(self.model_runner, req.request_id) if not self.is_last_peer: self.finished_batch.append(req) else: @@ -781,7 +840,12 @@ def _handle_cuda_input_requests(self, requests: List[Request]): ), "Non-first peers must receive IntermediateRequests." 
if req.is_finished or req.hidden_states is None: self.scheduler.evict_request(req.request_id, req.status) - release_cuda_request(self.running_batch, req.request_id) + if self.backend_type == "sglang": + release_sglang_request(self.running_batch, req.request_id) + elif self.backend_type == "vllm": + from parallax.vllm.batch_info import release_vllm_request + + release_vllm_request(self.model_runner, req.request_id) if not self.is_last_peer: self.finished_batch.append(req) else: @@ -995,37 +1059,98 @@ def _process_batch_cuda( ): """ Process a batch of requests in CUDA. + + Supports both vLLM and SGLang backends with Pipeline Parallelism. + + Args: + prepared_inputs: Dict containing batch data and metadata + return_decoded_tokens: If True, return token IDs (last peer); + If False, return hidden states (intermediate peer) + + Returns: + token_ids (Tensor): If return_decoded_tokens=True + hidden_states (Tensor): If return_decoded_tokens=False """ - assert "forward_batch" in prepared_inputs, "forward_batch should be in cuda prepared inputs" - assert ( - "pp_proxy_tensors" in prepared_inputs - ), "pp_proxy_tensors should be in cuda prepared inputs" - forward_batch = prepared_inputs["forward_batch"] - pp_proxy_tensors = prepared_inputs["pp_proxy_tensors"] - logits_output, _ = self.model_runner.forward( - forward_batch=forward_batch, - pp_proxy_tensors=pp_proxy_tensors, - ) + if self.backend_type == "vllm": + # ========== vLLM Backend ========== + assert ( + "scheduler_output" in prepared_inputs + ), "scheduler_output should be provided for vLLM backend" + assert ( + "pp_proxy_tensors" in prepared_inputs + ), "pp_proxy_tensors should be in cuda prepared inputs" + scheduler_output = prepared_inputs["scheduler_output"] + pp_proxy_tensors = prepared_inputs["pp_proxy_tensors"] + intermediate_tensors = None + if pp_proxy_tensors is not None: + # Convert SGLang's PPProxyTensors to vLLM's IntermediateTensors + from vllm.sequence import IntermediateTensors + + intermediate_tensors = IntermediateTensors(pp_proxy_tensors.tensors) + logger.debug(f"vLLM: Using intermediate_tensors for PP (non-first peer)") + + # Execute model with vLLM + output = self.model_runner.execute_model( + scheduler_output=scheduler_output, + intermediate_tensors=intermediate_tensors, + ) - if self.cur_batch: - if self.cur_batch.forward_mode.is_extend(): - # Merge the new batch into the running batch - if not self.cur_batch.is_empty(): - if self.running_batch.is_empty(): - self.running_batch = self.cur_batch - else: - # Merge running_batch with prefill batch - self.running_batch.merge_batch(self.cur_batch) - self.cur_batch = None + # Return appropriate output based on peer position + if return_decoded_tokens: + # Last peer: return sampled token IDs + return output.sampled_token_ids + else: + # Intermediate peer: return hidden states for next peer + if hasattr(output, "hidden_states") and output.hidden_states is not None: + return output.hidden_states + else: + raise RuntimeError( + "vLLM backend: expected hidden_states in output for PP, but got None. " + "This typically means the model runner is not configured for pipeline parallelism." 
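The only pipeline-parallel glue the vLLM path needs above is repackaging SGLang-style proxy tensors into vLLM's `IntermediateTensors`. A minimal sketch of that conversion, using the same `vllm.sequence` import as the hunk:

```python
def to_intermediate_tensors(pp_proxy_tensors):
    """Wrap a PPProxyTensors-like object (exposing .tensors) for vLLM.

    Returns None for the first peer, which starts from token ids instead of
    incoming hidden states.
    """
    from vllm.sequence import IntermediateTensors

    if pp_proxy_tensors is None:
        return None
    return IntermediateTensors(pp_proxy_tensors.tensors)
```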
+ ) - if return_decoded_tokens: - next_token_ids = self.model_runner.sample(logits_output, forward_batch) - return next_token_ids - # Currently hack the result of (hidden_state + residual) here for GPU - final_hidden_states = ( - logits_output.tensors["hidden_states"] + logits_output.tensors["residual"] - ) - return final_hidden_states + else: # self.backend_type == "sglang" + # ========== SGLang Backend ========== + assert ( + "forward_batch" in prepared_inputs + ), "forward_batch should be in cuda prepared inputs" + assert ( + "pp_proxy_tensors" in prepared_inputs + ), "pp_proxy_tensors should be in cuda prepared inputs" + + forward_batch = prepared_inputs["forward_batch"] + pp_proxy_tensors = prepared_inputs["pp_proxy_tensors"] + + # Execute model with SGLang + logits_output, _ = self.model_runner.forward( + forward_batch=forward_batch, + pp_proxy_tensors=pp_proxy_tensors, + ) + + # SGLang-specific batch management: merge prefill batch into running batch + if self.cur_batch: + if self.cur_batch.forward_mode.is_extend(): + # Merge the new batch into the running batch + if not self.cur_batch.is_empty(): + if self.running_batch.is_empty(): + self.running_batch = self.cur_batch + else: + # Merge running_batch with prefill batch + self.running_batch.merge_batch(self.cur_batch) + self.cur_batch = None + + # Return appropriate output based on peer position + if return_decoded_tokens: + # Last peer: sample and return token IDs + next_token_ids = self.model_runner.sample(logits_output, forward_batch) + return next_token_ids + else: + # Intermediate peer: return hidden states for next peer + # Note: SGLang stores hidden_states + residual separately + final_hidden_states = ( + logits_output.tensors["hidden_states"] + logits_output.tensors["residual"] + ) + return final_hidden_states def _process_batch_mlx( self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True @@ -1206,9 +1331,14 @@ def run_loop(self): for req in batch_to_process: self.scheduler.evict_request(req.request_id, req.status) if self.device == "cuda": - from parallax.sglang.batch_info import release_cuda_request + if self.backend_type == "vllm": + from parallax.vllm.batch_info import release_vllm_request + + release_vllm_request(self.model_runner, req.request_id) + elif self.backend_type == "sglang": + from parallax.sglang.batch_info import release_sglang_request - release_cuda_request(self.running_batch, req.request_id) + release_sglang_request(self.running_batch, req.request_id) else: self.kv_cache_manager.release_request(req.request_id) @@ -1234,6 +1364,7 @@ def create_executor_config(args: argparse.Namespace): "start_layer": args.start_layer, "end_layer": args.end_layer, "dtype": args.dtype, + "gpu_backend": args.gpu_backend if hasattr(args, "gpu_backend") else "sglang", "max_sequence_length": args.max_sequence_length if "max_sequence_length" in args else None, "max_batch_size": args.max_batch_size if "max_batch_size" in args else None, "kv_block_size": args.kv_block_size, diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 4ce2a89f..0b3e3c77 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -191,7 +191,7 @@ def form_sgl_batch_decode( return forward_batch -def release_cuda_request(running_batch: ScheduleBatch, request_id: str): +def release_sglang_request(running_batch: ScheduleBatch, request_id: str): """Release KV Cache and other resources for finished/aborted requests.""" seq_lens_cpu = running_batch.seq_lens.cpu().numpy() idx = 
find_index(running_batch, request_id) diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index 257b8ec5..eb7b7f02 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -1,3 +1,16 @@ +""" +Store information about a vLLM batch. + +This module provides batch formation utilities for vLLM v1 backend integration. +It transforms Parallax requests into vLLM-compatible structures for both prefill +and decode stages. + +Key differences from SGLang: +- vLLM uses SchedulerOutput (flat) vs SGLang's ScheduleBatch (hierarchical) +- KV Cache is managed independently via KVCache object +- Sampling is integrated in execute_model() call +""" + from __future__ import annotations from typing import Any, Dict, List @@ -17,7 +30,14 @@ def transform_sampling_params_to_vllm(old_params: ParallaxSamplingParams) -> VLLMSamplingParams: - """Transforms Parallax SamplingParams to vLLM SamplingParams format.""" + """Transforms Parallax SamplingParams to vLLM SamplingParams format. + + Args: + old_params: Parallax sampling parameters + + Returns: + vLLM SamplingParams object + """ # Map Parallax json_schema -> vLLM structured_outputs structured = ( StructuredOutputsParams(json=old_params.json_schema) @@ -50,7 +70,15 @@ def transform_sampling_params_to_vllm(old_params: ParallaxSamplingParams) -> VLL def transform_requests_to_vllm(batched_requests: List[Request]) -> List[VLLMRequest]: """Transforms Parallax Request to vLLM Request format. + Note: Only used if we later choose to feed vLLM Engine directly. + Currently we bypass the Engine and use GPUModelRunner directly. + + Args: + batched_requests: List of Parallax requests + + Returns: + List of vLLM Request objects """ vllm_reqs = [] for old_req in batched_requests: @@ -63,54 +91,165 @@ def transform_requests_to_vllm(batched_requests: List[Request]) -> List[VLLMRequ client_index=getattr(old_req, "client_index", 0), ) vllm_reqs.append(vllm_req) + return vllm_reqs -def form_vllm_batch_prefill(batched_requests: List[Request], pad_token_id: int) -> Dict[str, Any] | None: - """Builds the vLLM prefill batch inputs for the first peer. +def form_vllm_batch_prefill( + batched_requests: List[Request], + model_runner: Any = None, +) -> Dict[str, Any]: + """Prepare a vLLM prefill batch. + + Constructs a SchedulerOutput for vLLM v1 GPUModelRunner that contains: + - NewRequestData for each request (new prefill requests) + - KV cache block allocations + - Token scheduling information - Returns a dict with: - - input_ids: List[List[int]] padded to max prompt length with pad_token_id. 
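The sampling-parameter conversion above leans on vLLM's structured-outputs support for JSON-schema constrained decoding. A hypothetical stand-alone example of that mapping, assuming the vLLM version this series targets (where `SamplingParams` accepts a `structured_outputs` field, as the hunk relies on); the schema literal is illustrative:

```python
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
from vllm.sampling_params import StructuredOutputsParams

schema = {"type": "object", "properties": {"answer": {"type": "string"}}}
params = VLLMSamplingParams(
    temperature=0.0,
    structured_outputs=StructuredOutputsParams(json=schema),
)
```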
+ Args: + batched_requests: List of Parallax requests to prefill + model_runner: vLLM GPUModelRunner instance + + Returns: + Dict containing: + - scheduler_output: SchedulerOutput for vLLM + - requests: Original Parallax requests + Returns None if batched_requests is empty """ - batch_size = len(batched_requests) - if batch_size == 0: + from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput + + if not batched_requests: return None - # Collect prompts and compute max length - seqs: List[List[int]] = [] - max_len = 0 + # Initialize KV cache manager if not already done + # This is a lightweight wrapper around vLLM's KV cache + if not hasattr(model_runner, "_parallax_kv_cache"): + from parallax.vllm.model_runner import VLLMKVCacheManager + + model_runner._parallax_kv_cache = VLLMKVCacheManager( + model_runner, model_runner.kv_cache_config.block_size + ) + + kv_cache = model_runner._parallax_kv_cache + + # Build NewRequestData for each request + new_request_data_list = [] for req in batched_requests: - assert req.is_prefill, f"Request {req.request_id} is not a prefill request." - assert req.input_ids is not None and len(req.input_ids) > 0, ( - f"Request {req.request_id} has empty input_ids for prefill" + sampling_params = transform_sampling_params_to_vllm(req.sampling_params) + + # Allocate KV cache blocks for this request + block_ids = kv_cache.allocate(req.request_id, len(req.input_ids)) + + new_req_data = NewRequestData( + req_id=req.request_id, + prompt_token_ids=req.input_ids, + mm_features=[], # Multimodal features (empty for text-only) + sampling_params=sampling_params, + pooling_params=None, # For embedding models + block_ids=block_ids, + num_computed_tokens=0, # Prefill starts from scratch + lora_request=None, # LoRA adapter + prompt_embeds=None, # Soft prompts ) - seqs.append(req.input_ids) - if len(req.input_ids) > max_len: - max_len = len(req.input_ids) + new_request_data_list.append(new_req_data) + + # Build SchedulerOutput + # This is the main data structure that vLLM's model runner expects + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_request_data_list, + scheduled_cached_reqs=CachedRequestData.make_empty(), # No cached reqs in prefill + num_scheduled_tokens={req.request_id: len(req.input_ids) for req in batched_requests}, + total_num_scheduled_tokens=sum(len(req.input_ids) for req in batched_requests), + scheduled_spec_decode_tokens={}, # Speculative decoding tokens + scheduled_encoder_inputs={}, # For encoder-decoder models + num_common_prefix_blocks=[], # Prefix caching + finished_req_ids=set(), # No finished requests in prefill + free_encoder_mm_hashes=[], # Encoder multimodal hash cleanup + structured_output_request_ids=[], # Requests using structured output + grammar_bitmask=None, # Grammar constraints + kv_connector_metadata=None, # KV connector for disaggregation + ) + + return scheduler_output, batched_requests - # Right-pad to max_len with pad_token_id - padded: List[List[int]] = [seq + [pad_token_id] * (max_len - len(seq)) for seq in seqs] - return {"input_ids": padded} +def form_vllm_batch_decode( + batched_requests: List[Request], + model_runner: Any = None, +) -> Dict[str, Any]: + """Prepare a vLLM decode batch. + Constructs a SchedulerOutput for vLLM v1 GPUModelRunner for decode stage. 
+ Key differences from prefill: + - Uses CachedRequestData (not NewRequestData) + - Each request processes exactly 1 token + - KV cache blocks are already allocated -def form_vllm_batch_decode(batched_requests: List[Request], is_first_peer: bool) -> Dict[str, Any] | None: - """Builds the vLLM decode batch inputs for the first peer. + Args: + batched_requests: List of Parallax requests in decode phase + model_runner: vLLM GPUModelRunner instance - For decode, the first peer feeds the last generated token per request. - Other peers return None (they use pp_proxy_tensors path). + Returns: + Dict containing: + - scheduler_output: SchedulerOutput for vLLM + - requests: Original Parallax requests + Returns None if batched_requests is empty """ - if not is_first_peer: + from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput + + if not batched_requests: return None - # For first peer, gather the next-step input token ids (last output token) - tokens: List[int] = [] - for req in batched_requests: - assert req.is_decoding, f"Request {req.request_id} is not a decode request." - assert req.output_ids is not None and len(req.output_ids) > 0, ( - f"Decode step requires at least one output token for {req.request_id}" - ) - tokens.append(req.output_ids[-1]) + # Get KV cache manager (should already be initialized in prefill) + kv_cache = model_runner._parallax_kv_cache + + req_ids = [req.request_id for req in batched_requests] + + # Build CachedRequestData for decode + # These are requests that already have KV cache allocated + cached_req_data = CachedRequestData( + req_ids=req_ids, + resumed_from_preemption=[False] * len(req_ids), # Not resuming from preemption + new_token_ids=[[] for _ in req_ids], # Empty for non-pipeline-parallel + resumed_req_token_ids=[None for _ in req_ids], # Not resumed + new_block_ids=[None for _ in req_ids], # No new blocks needed for decode + num_computed_tokens=[req.current_position for req in batched_requests], + num_output_tokens=[ + len(req.output_ids) if hasattr(req, "output_ids") else 0 for req in batched_requests + ], + ) + + # Build SchedulerOutput for decode + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], # No new requests in decode + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={req_id: 1 for req_id in req_ids}, # 1 token per request in decode + total_num_scheduled_tokens=len(req_ids), + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[], + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids=[], + grammar_bitmask=None, + kv_connector_metadata=None, + ) + + return scheduler_output + + +def release_vllm_request(model_runner: Any, request_id: str): + """Release KV Cache and other resources for finished/aborted requests. + + Similar to SGLang's release_cuda_request but for vLLM backend. 
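The decode builder above encodes a simple invariant: every running request contributes exactly one scheduled token per step, so the total equals the batch size. A tiny illustration (not from the patch):

```python
def decode_token_schedule(req_ids):
    """One token per request per decode step, mirroring the SchedulerOutput above."""
    num_scheduled_tokens = {req_id: 1 for req_id in req_ids}
    return num_scheduled_tokens, len(req_ids)


schedule, total = decode_token_schedule(["req-0", "req-1"])
assert total == sum(schedule.values()) == 2
```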
+ + Args: + model_runner: vLLM GPUModelRunner instance + request_id: ID of the request to release + """ + if not hasattr(model_runner, "_parallax_kv_cache"): + logger.warning(f"KV cache manager not found when releasing request {request_id}") + return - # Use shape [batch, 1] for consistency - return {"input_ids": [[tok] for tok in tokens]} + kv_cache = model_runner._parallax_kv_cache + kv_cache.free(request_id) diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 2ae28399..41330517 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -1 +1,228 @@ -pass +""" +vLLM Model Runner wrapper for Parallax. + +Integrates vLLM v1 GPUModelRunner for CUDA backend. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple + +import torch +from transformers import AutoConfig, AutoTokenizer + +from vllm.config import ( + CacheConfig, + DecodingConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, +) +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.config import VllmConfig + +from parallax.server.request import Request +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +def initialize_vllm_model_runner( + model_repo: str, + start_layer: int, + end_layer: int, + kv_cache_memory_fraction: float, + attention_backend: str, + kv_block_size: int, + dtype: str = "float16", +) -> Tuple[GPUModelRunner, Dict, Any]: + """Initialize vLLM GPUModelRunner. + + Args: + model_repo: HuggingFace model repo path + start_layer: Start layer index (for PP) + end_layer: End layer index (for PP) + kv_cache_memory_fraction: Fraction of GPU memory for KV cache + attention_backend: Attention backend (e.g., "flash_attn") + kv_block_size: KV cache block size + dtype: Model dtype + + Returns: + (model_runner, config_dict, tokenizer) + """ + logger.info(f"Initializing vLLM model runner for {model_repo}") + + # Load HuggingFace config and tokenizer + hf_config = AutoConfig.from_pretrained(model_repo, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_repo, trust_remote_code=True) + + num_hidden_layers = hf_config.num_hidden_layers + + # Build vLLM configs + model_config = ModelConfig( + model=model_repo, + tokenizer=model_repo, + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=0, + max_model_len=getattr(hf_config, "max_position_embeddings", 4096), + ) + + cache_config = CacheConfig( + block_size=kv_block_size, + gpu_memory_utilization=kv_cache_memory_fraction, + swap_space=0, + cache_dtype="auto", + ) + + # For single-node in Parallax, we don't use vLLM's internal PP + parallel_config = ParallelConfig( + pipeline_parallel_size=1, + tensor_parallel_size=1, + distributed_executor_backend=None, + ) + + device_config = DeviceConfig(device="cuda") + load_config = LoadConfig(load_format="auto") + + # Minimal scheduler config (we bypass vLLM scheduler) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=8192, + max_num_seqs=256, + max_model_len=model_config.max_model_len, + ) + + decoding_config = DecodingConfig() + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + 
device_config=device_config, + load_config=load_config, + lora_config=None, + speculative_config=None, + decoding_config=decoding_config, + observability_config=None, + prompt_adapter_config=None, + quant_config=None, + compilation_config=None, + ) + + # Determine KV cache blocks + kv_cache_config = KVCacheConfig( + block_size=kv_block_size, + num_gpu_blocks=None, # Will be calculated by model runner + ) + + # Initialize GPUModelRunner + model_runner = GPUModelRunner( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + device="cuda", + ) + + # Load model + logger.info("Loading vLLM model...") + model_runner.load_model() + logger.info("vLLM model loaded successfully") + + # Return config as dict for compatibility with Parallax executor + config_dict = { + "num_hidden_layers": num_hidden_layers, + "hidden_size": hf_config.hidden_size, + "num_attention_heads": hf_config.num_attention_heads, + "num_key_value_heads": getattr( + hf_config, "num_key_value_heads", hf_config.num_attention_heads + ), + } + + return model_runner, config_dict, tokenizer + + +class VLLMKVCacheManager: + """Simple KV cache block manager for vLLM integration.""" + + def __init__(self, model_runner: GPUModelRunner, block_size: int): + self.model_runner = model_runner + self.block_size = block_size + self.request_blocks: Dict[str, List[int]] = {} + self.next_block_id = 0 + + # Get available blocks from model runner + self.total_blocks = model_runner.kv_cache_config.num_gpu_blocks + self.free_blocks = list(range(self.total_blocks)) + + def allocate(self, request_id: str, num_tokens: int) -> Tuple[List[int], ...]: + """Allocate KV cache blocks for a request. + + Returns: + block_ids: Tuple of lists of block IDs (one per KV cache layer) + """ + num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size + + if len(self.free_blocks) < num_blocks_needed: + raise RuntimeError( + f"Not enough KV cache blocks. Needed: {num_blocks_needed}, Available: {len(self.free_blocks)}" + ) + + # Allocate blocks + allocated = [] + for _ in range(num_blocks_needed): + block_id = self.free_blocks.pop(0) + allocated.append(block_id) + + self.request_blocks[request_id] = allocated + + # vLLM expects tuple of lists (one per layer group) + # For simplicity, we use the same blocks for all layers + return (allocated,) + + def free(self, request_id: str): + """Free KV cache blocks for a request.""" + if request_id in self.request_blocks: + blocks = self.request_blocks.pop(request_id) + self.free_blocks.extend(blocks) + + def get_blocks(self, request_id: str) -> Tuple[List[int], ...]: + """Get allocated blocks for a request.""" + return (self.request_blocks.get(request_id, []),) + + +from __future__ import annotations + +from typing import Any, Tuple + +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +def initialize_vllm_model_runner( + model_repo: str, + start_layer: int, + end_layer: int, + kv_cache_memory_fraction: float, + attention_backend: str, + kv_block_size: int, + dtype: str = "float16", +) -> Tuple[Any, dict, Any]: + """Initialize vLLM GPUModelRunner (scaffold). + + This function is a placeholder and documents the expected return values: + - model_runner: An object exposing execute_model() compatible with vLLM v1. + - config: A dict-like model config with at least num_hidden_layers. + - tokenizer: Tokenizer instance used by executor. 
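The interim `VLLMKVCacheManager` above is a plain free-list allocator; the only arithmetic involved is a ceiling division of the token count by the block size. A minimal sketch of that math with illustrative block sizes:

```python
def blocks_needed(num_tokens: int, block_size: int) -> int:
    """Ceiling division used by the allocator above."""
    return (num_tokens + block_size - 1) // block_size


assert blocks_needed(1, 64) == 1
assert blocks_needed(64, 64) == 1
assert blocks_needed(65, 64) == 2
```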
+ """ + raise NotImplementedError( + "vLLM backend scaffolding is present, but the actual model runner " + "initialization is not implemented yet. Please implement " + "parallax.vllm.model_runner.initialize_vllm_model_runner() per the plan." + ) From 9878bb9745b85c1bc5ad0c569168bd01baef4fd4 Mon Sep 17 00:00:00 2001 From: Alien mac air <2214632589@qq.com> Date: Tue, 21 Oct 2025 19:31:01 +0800 Subject: [PATCH 05/36] up date kvcache --- src/parallax/vllm/batch_info.py | 133 ++++++++++--- src/parallax/vllm/model_runner.py | 303 ++++++++++++++++++++---------- 2 files changed, 310 insertions(+), 126 deletions(-) diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index eb7b7f02..dece171e 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -103,12 +103,12 @@ def form_vllm_batch_prefill( Constructs a SchedulerOutput for vLLM v1 GPUModelRunner that contains: - NewRequestData for each request (new prefill requests) - - KV cache block allocations + - KV cache block allocations via vLLM's native KVCacheManager - Token scheduling information Args: batched_requests: List of Parallax requests to prefill - model_runner: vLLM GPUModelRunner instance + model_runner: ParallaxVLLMModelRunner instance with initialized kv_cache_manager Returns: Dict containing: @@ -117,28 +117,65 @@ def form_vllm_batch_prefill( Returns None if batched_requests is empty """ from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput + from vllm.v1.request import Request as VLLMRequest if not batched_requests: return None - # Initialize KV cache manager if not already done - # This is a lightweight wrapper around vLLM's KV cache - if not hasattr(model_runner, "_parallax_kv_cache"): - from parallax.vllm.model_runner import VLLMKVCacheManager - - model_runner._parallax_kv_cache = VLLMKVCacheManager( - model_runner, model_runner.kv_cache_config.block_size + # Get vLLM's KVCacheManager from model_runner + if not hasattr(model_runner, "kv_cache_manager"): + raise RuntimeError( + "model_runner must have kv_cache_manager initialized. " + "Call model_runner.initialize_kv_cache_manager() first." 
) - kv_cache = model_runner._parallax_kv_cache + kv_cache_manager = model_runner.kv_cache_manager # Build NewRequestData for each request new_request_data_list = [] + vllm_requests = [] + for req in batched_requests: sampling_params = transform_sampling_params_to_vllm(req.sampling_params) - # Allocate KV cache blocks for this request - block_ids = kv_cache.allocate(req.request_id, len(req.input_ids)) + # Create vLLM Request object for KV cache management + vllm_req = VLLMRequest( + request_id=req.request_id, + prompt_token_ids=req.input_ids, + sampling_params=sampling_params, + eos_token_id=getattr(req, "eos_token_id", None), + arrival_time=getattr(req, "arrival_time", 0.0), + ) + vllm_requests.append(vllm_req) + + # Check for prefix cache hits + computed_blocks, num_computed_tokens = kv_cache_manager.get_computed_blocks(vllm_req) + + # Allocate KV cache blocks for the remaining tokens + num_new_tokens = len(req.input_ids) - num_computed_tokens + if num_new_tokens > 0: + new_blocks = kv_cache_manager.allocate_slots( + request=vllm_req, + num_new_tokens=num_new_tokens, + num_new_computed_tokens=num_computed_tokens, + new_computed_blocks=computed_blocks if num_computed_tokens > 0 else None, + ) + + if new_blocks is None: + # Cannot allocate blocks (OOM) + logger.warning(f"Cannot allocate KV cache for request {req.request_id}") + # Free any allocated blocks for previous requests in this batch + for prev_req in vllm_requests[:-1]: + kv_cache_manager.free(prev_req) + return None + + # Combine computed blocks and new blocks + all_blocks = computed_blocks + new_blocks if num_computed_tokens > 0 else new_blocks + else: + all_blocks = computed_blocks + + # Get block IDs for the request + block_ids = all_blocks.get_block_ids() new_req_data = NewRequestData( req_id=req.request_id, @@ -147,7 +184,7 @@ def form_vllm_batch_prefill( sampling_params=sampling_params, pooling_params=None, # For embedding models block_ids=block_ids, - num_computed_tokens=0, # Prefill starts from scratch + num_computed_tokens=num_computed_tokens, lora_request=None, # LoRA adapter prompt_embeds=None, # Soft prompts ) @@ -183,11 +220,11 @@ def form_vllm_batch_decode( Key differences from prefill: - Uses CachedRequestData (not NewRequestData) - Each request processes exactly 1 token - - KV cache blocks are already allocated + - KV cache blocks already allocated, may need to extend Args: batched_requests: List of Parallax requests in decode phase - model_runner: vLLM GPUModelRunner instance + model_runner: ParallaxVLLMModelRunner instance with initialized kv_cache_manager Returns: Dict containing: @@ -196,14 +233,51 @@ def form_vllm_batch_decode( Returns None if batched_requests is empty """ from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput + from vllm.v1.request import Request as VLLMRequest if not batched_requests: return None - # Get KV cache manager (should already be initialized in prefill) - kv_cache = model_runner._parallax_kv_cache + # Get vLLM's KVCacheManager + if not hasattr(model_runner, "kv_cache_manager"): + raise RuntimeError( + "model_runner must have kv_cache_manager initialized. " + "Call model_runner.initialize_kv_cache_manager() first." 
+ ) + + kv_cache_manager = model_runner.kv_cache_manager - req_ids = [req.request_id for req in batched_requests] + req_ids = [] + vllm_requests = [] + + for req in batched_requests: + req_ids.append(req.request_id) + + # Create or retrieve vLLM Request object + # In decode phase, request should already exist + sampling_params = transform_sampling_params_to_vllm(req.sampling_params) + vllm_req = VLLMRequest( + request_id=req.request_id, + prompt_token_ids=req.input_ids, + sampling_params=sampling_params, + eos_token_id=getattr(req, "eos_token_id", None), + arrival_time=getattr(req, "arrival_time", 0.0), + ) + vllm_req.num_computed_tokens = req.current_position - 1 # Update computed tokens + vllm_requests.append(vllm_req) + + # Allocate slot for 1 new decode token + # This may require allocating a new block if current block is full + new_blocks = kv_cache_manager.allocate_slots( + request=vllm_req, + num_new_tokens=1, # Decode generates 1 token at a time + num_new_computed_tokens=0, + ) + + if new_blocks is None: + # Cannot allocate (OOM), need to preempt or wait + logger.warning(f"Cannot allocate KV cache for decode request {req.request_id}") + return None # Build CachedRequestData for decode # These are requests that already have KV cache allocated @@ -241,15 +315,28 @@ def form_vllm_batch_decode( def release_vllm_request(model_runner: Any, request_id: str): """Release KV Cache and other resources for finished/aborted requests. - Similar to SGLang's release_cuda_request but for vLLM backend. + Uses vLLM's native KVCacheManager to properly free allocated blocks + and update prefix cache if enabled. Args: - model_runner: vLLM GPUModelRunner instance + model_runner: ParallaxVLLMModelRunner instance with kv_cache_manager request_id: ID of the request to release """ - if not hasattr(model_runner, "_parallax_kv_cache"): + from vllm.v1.request import Request as VLLMRequest + + if not hasattr(model_runner, "kv_cache_manager"): logger.warning(f"KV cache manager not found when releasing request {request_id}") return - kv_cache = model_runner._parallax_kv_cache - kv_cache.free(request_id) + kv_cache_manager = model_runner.kv_cache_manager + + # Create a minimal vLLM Request object for the free operation + # Note: We need the request object, not just the ID + # In a real scenario, we should maintain a mapping of request_id -> vLLMRequest + # For now, we'll use the KVCacheManager's coordinator directly + try: + # The coordinator can free by request_id directly + kv_cache_manager.coordinator.free(request_id) + logger.debug(f"Released KV cache for request {request_id}") + except Exception as e: + logger.warning(f"Error releasing KV cache for request {request_id}: {e}") diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 41330517..fb8fa96a 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -1,7 +1,8 @@ """ -vLLM Model Runner wrapper for Parallax. +vLLM Model Runner wrapper for Parallax with Pipeline Parallelism support. Integrates vLLM v1 GPUModelRunner for CUDA backend. +Uses vLLM's native Pipeline Parallelism mechanism to load only required layers. 
""" from __future__ import annotations @@ -25,6 +26,11 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.config import VllmConfig +from vllm.distributed import ( + initialize_model_parallel, + get_pp_group, +) +from vllm.v1.core.kv_cache_manager import KVCacheManager from parallax.server.request import Request from parallax_utils.logging_config import get_logger @@ -32,6 +38,133 @@ logger = get_logger(__name__) +class ParallaxVLLMModelRunner(GPUModelRunner): + """ + Extended vLLM GPUModelRunner that leverages vLLM's native Pipeline Parallelism. + + This class uses vLLM's PPMissingLayer mechanism to load only the required layers + during model initialization, avoiding the need to load and then prune the full model. + """ + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + device: str, + start_layer: int, + end_layer: int, + num_hidden_layers: int, + ): + """ + Args: + vllm_config: vLLM configuration object + kv_cache_config: KV cache configuration + device: Device to run on (e.g., "cuda") + start_layer: First layer index to load (inclusive) + end_layer: Last layer index to load (exclusive) + num_hidden_layers: Total number of layers in the full model + """ + # Store layer information before calling super().__init__ + self.start_layer = start_layer + self.end_layer = end_layer + self.num_hidden_layers = num_hidden_layers + self.num_shard_layers = end_layer - start_layer + + self.is_first_peer = start_layer == 0 + self.is_last_peer = end_layer == num_hidden_layers + + # Calculate PP rank and size for vLLM + # We simulate a PP setup where each Parallax peer is a PP rank + self.pp_rank = 0 # Will be updated based on layer range + self.pp_size = 1 # Single node, but with layer slicing + + # Call parent init + super().__init__(vllm_config=vllm_config, device=torch.device(device)) + self.kv_cache_config = kv_cache_config + + logger.info( + f"ParallaxVLLMModelRunner initialized: layers [{start_layer}, {end_layer}), " + f"is_first={self.is_first_peer}, is_last={self.is_last_peer}" + ) + + def initialize_kv_cache_manager(self, max_model_len: int) -> KVCacheManager: + """ + Initialize vLLM's native KVCacheManager. + + This should be called after the model is loaded to properly set up + the KV cache management system. + + Args: + max_model_len: Maximum sequence length the model can handle + + Returns: + Initialized KVCacheManager instance + """ + logger.info("Initializing vLLM KVCacheManager...") + + kv_cache_manager = KVCacheManager( + kv_cache_config=self.kv_cache_config, + max_model_len=max_model_len, + enable_caching=True, # Enable prefix caching + use_eagle=False, # Not using EAGLE speculative decoding + log_stats=True, # Enable stats logging + enable_kv_cache_events=False, # Disable KV cache events for now + dcp_world_size=1, # Decode Context Parallelism world size + ) + + self.kv_cache_manager = kv_cache_manager + logger.info( + f"KVCacheManager initialized: block_size={kv_cache_manager.block_size}, " + f"usage={kv_cache_manager.usage:.2%}" + ) + + return kv_cache_manager + + def load_model(self) -> None: + """ + Load model using vLLM's native layer loading mechanism. + + This method uses vLLM's make_layers function which creates PPMissingLayer + placeholders for layers outside [start_layer, end_layer), ensuring only + the required layers are actually loaded from checkpoint. 
+ """ + logger.info(f"Loading vLLM model with layers [{self.start_layer}, {self.end_layer})...") + + # Temporarily override vLLM's PP configuration for this peer + # This allows us to use vLLM's layer skipping mechanism + import vllm.distributed.parallel_state as parallel_state + from vllm.distributed.utils import get_pp_indices + + # Monkey-patch get_pp_indices to return our custom layer range + original_get_pp_indices = parallel_state.get_pp_indices + + def custom_get_pp_indices(num_layers: int, rank: int, world_size: int): + """Return our custom layer range instead of vLLM's calculated range.""" + logger.debug( + f"custom_get_pp_indices called: num_layers={num_layers}, " + f"returning [{self.start_layer}, {self.end_layer})" + ) + return self.start_layer, self.end_layer + + # Temporarily replace the function + import vllm.distributed.utils + + vllm.distributed.utils.get_pp_indices = custom_get_pp_indices + + try: + # Now call the parent load_model, which will use our custom layer range + super().load_model() + logger.info( + f"Successfully loaded {self.num_shard_layers} layers " + f"[{self.start_layer}:{self.end_layer}]" + ) + finally: + # Restore original function + vllm.distributed.utils.get_pp_indices = original_get_pp_indices + + logger.info("Model loaded successfully with partial layers") + + def initialize_vllm_model_runner( model_repo: str, start_layer: int, @@ -40,22 +173,40 @@ def initialize_vllm_model_runner( attention_backend: str, kv_block_size: int, dtype: str = "float16", -) -> Tuple[GPUModelRunner, Dict, Any]: - """Initialize vLLM GPUModelRunner. - - Args: - model_repo: HuggingFace model repo path - start_layer: Start layer index (for PP) - end_layer: End layer index (for PP) - kv_cache_memory_fraction: Fraction of GPU memory for KV cache - attention_backend: Attention backend (e.g., "flash_attn") - kv_block_size: KV cache block size - dtype: Model dtype - - Returns: - (model_runner, config_dict, tokenizer) +) -> Tuple[ParallaxVLLMModelRunner, Dict, Any]: + """Initialize vLLM GPUModelRunner with true partial layer loading. + + This function leverages vLLM's native Pipeline Parallelism mechanism to load + only the required layers, avoiding the memory overhead of loading the full model. + + The key insight is to monkey-patch vLLM's get_pp_indices function during model + loading, which allows us to control exactly which layers are loaded. Layers + outside the [start_layer, end_layer) range are replaced with PPMissingLayer + placeholders that consume minimal memory. + + Args: + model_repo: HuggingFace model repo path + start_layer: Start layer index (inclusive) + end_layer: End layer index (exclusive) + kv_cache_memory_fraction: Fraction of GPU memory for KV cache + attention_backend: Attention backend (e.g., "flash_attn") + kv_block_size: KV cache block size + dtype: Model dtype + + Returns: + (model_runner, config_dict, tokenizer) + + Example: + >>> # Load only layers 8-16 of a 32-layer model + >>> runner, config, tok = initialize_vllm_model_runner( + ... "meta-llama/Llama-2-7b-hf", 8, 16, 0.8, "flash_attn", 64 + ... 
) + >>> # Only 8 layers are actually loaded into memory + ``` """ - logger.info(f"Initializing vLLM model runner for {model_repo}") + logger.info( + f"Initializing vLLM model runner for {model_repo}, " f"layers=[{start_layer}, {end_layer})" + ) # Load HuggingFace config and tokenizer hf_config = AutoConfig.from_pretrained(model_repo, trust_remote_code=True) @@ -63,6 +214,12 @@ def initialize_vllm_model_runner( num_hidden_layers = hf_config.num_hidden_layers + if end_layer > num_hidden_layers: + raise ValueError( + f"end_layer ({end_layer}) cannot be greater than " + f"num_hidden_layers ({num_hidden_layers})" + ) + # Build vLLM configs model_config = ModelConfig( model=model_repo, @@ -81,9 +238,19 @@ def initialize_vllm_model_runner( cache_dtype="auto", ) - # For single-node in Parallax, we don't use vLLM's internal PP + # Configure PP for layer slicing + # We set pp_size > 1 to enable vLLM's layer skipping mechanism + # but use our custom get_pp_indices to control which layers to load + is_first_peer = start_layer == 0 + is_last_peer = end_layer == num_hidden_layers + + # Calculate a virtual PP size that makes sense + # For example, if we have 32 layers and loading [8, 16), we're in the "middle" + # Set pp_size=2 to enable PP mode, and we'll override the layer calculation + virtual_pp_size = 2 if not (is_first_peer and is_last_peer) else 1 + parallel_config = ParallelConfig( - pipeline_parallel_size=1, + pipeline_parallel_size=virtual_pp_size, tensor_parallel_size=1, distributed_executor_backend=None, ) @@ -122,18 +289,26 @@ def initialize_vllm_model_runner( num_gpu_blocks=None, # Will be calculated by model runner ) - # Initialize GPUModelRunner - model_runner = GPUModelRunner( + # Initialize our custom ParallaxVLLMModelRunner + model_runner = ParallaxVLLMModelRunner( vllm_config=vllm_config, kv_cache_config=kv_cache_config, device="cuda", + start_layer=start_layer, + end_layer=end_layer, + num_hidden_layers=num_hidden_layers, ) - # Load model - logger.info("Loading vLLM model...") + # Load model with partial layers + logger.info("Loading vLLM model (partial layers)...") model_runner.load_model() logger.info("vLLM model loaded successfully") + # Initialize KV Cache Manager after model is loaded + logger.info("Initializing KV Cache Manager...") + model_runner.initialize_kv_cache_manager(max_model_len=model_config.max_model_len) + logger.info("KV Cache Manager initialized successfully") + # Return config as dict for compatibility with Parallax executor config_dict = { "num_hidden_layers": num_hidden_layers, @@ -142,87 +317,9 @@ def initialize_vllm_model_runner( "num_key_value_heads": getattr( hf_config, "num_key_value_heads", hf_config.num_attention_heads ), + "head_dim": getattr( + hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads + ), } return model_runner, config_dict, tokenizer - - -class VLLMKVCacheManager: - """Simple KV cache block manager for vLLM integration.""" - - def __init__(self, model_runner: GPUModelRunner, block_size: int): - self.model_runner = model_runner - self.block_size = block_size - self.request_blocks: Dict[str, List[int]] = {} - self.next_block_id = 0 - - # Get available blocks from model runner - self.total_blocks = model_runner.kv_cache_config.num_gpu_blocks - self.free_blocks = list(range(self.total_blocks)) - - def allocate(self, request_id: str, num_tokens: int) -> Tuple[List[int], ...]: - """Allocate KV cache blocks for a request. 
- - Returns: - block_ids: Tuple of lists of block IDs (one per KV cache layer) - """ - num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size - - if len(self.free_blocks) < num_blocks_needed: - raise RuntimeError( - f"Not enough KV cache blocks. Needed: {num_blocks_needed}, Available: {len(self.free_blocks)}" - ) - - # Allocate blocks - allocated = [] - for _ in range(num_blocks_needed): - block_id = self.free_blocks.pop(0) - allocated.append(block_id) - - self.request_blocks[request_id] = allocated - - # vLLM expects tuple of lists (one per layer group) - # For simplicity, we use the same blocks for all layers - return (allocated,) - - def free(self, request_id: str): - """Free KV cache blocks for a request.""" - if request_id in self.request_blocks: - blocks = self.request_blocks.pop(request_id) - self.free_blocks.extend(blocks) - - def get_blocks(self, request_id: str) -> Tuple[List[int], ...]: - """Get allocated blocks for a request.""" - return (self.request_blocks.get(request_id, []),) - - -from __future__ import annotations - -from typing import Any, Tuple - -from parallax_utils.logging_config import get_logger - -logger = get_logger(__name__) - - -def initialize_vllm_model_runner( - model_repo: str, - start_layer: int, - end_layer: int, - kv_cache_memory_fraction: float, - attention_backend: str, - kv_block_size: int, - dtype: str = "float16", -) -> Tuple[Any, dict, Any]: - """Initialize vLLM GPUModelRunner (scaffold). - - This function is a placeholder and documents the expected return values: - - model_runner: An object exposing execute_model() compatible with vLLM v1. - - config: A dict-like model config with at least num_hidden_layers. - - tokenizer: Tokenizer instance used by executor. - """ - raise NotImplementedError( - "vLLM backend scaffolding is present, but the actual model runner " - "initialization is not implemented yet. Please implement " - "parallax.vllm.model_runner.initialize_vllm_model_runner() per the plan." - ) From 7cafb418e5e3f8c961e3f41ddd93ebd8541a5f3f Mon Sep 17 00:00:00 2001 From: Alien mac air <2214632589@qq.com> Date: Tue, 21 Oct 2025 20:11:28 +0800 Subject: [PATCH 06/36] link kv_cache_manager & model_runner --- src/parallax/vllm/batch_info.py | 157 +++++++++++++++++------------- src/parallax/vllm/model_runner.py | 40 +++++++- 2 files changed, 126 insertions(+), 71 deletions(-) diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index dece171e..1448dfd3 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -13,12 +13,13 @@ from __future__ import annotations -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from parallax.server.request import Request from parallax.server.sampling.sampling_params import ( SamplingParams as ParallaxSamplingParams, ) +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.request import Request as VLLMRequest from vllm.sampling_params import ( SamplingParams as VLLMSamplingParams, @@ -68,7 +69,10 @@ def transform_sampling_params_to_vllm(old_params: ParallaxSamplingParams) -> VLL return params -def transform_requests_to_vllm(batched_requests: List[Request]) -> List[VLLMRequest]: +def transform_requests_to_vllm( + batched_requests: List[Request], + model_runner: Any | None = None, +) -> List[VLLMRequest]: """Transforms Parallax Request to vLLM Request format. Note: Only used if we later choose to feed vLLM Engine directly. 
@@ -83,22 +87,52 @@ def transform_requests_to_vllm(batched_requests: List[Request]) -> List[VLLMRequ vllm_reqs = [] for old_req in batched_requests: sampling_params = transform_sampling_params_to_vllm(old_req.sampling_params) + block_hasher = getattr(model_runner, "request_block_hasher", None) if model_runner else None vllm_req = VLLMRequest( request_id=old_req.request_id, prompt_token_ids=old_req.input_ids, sampling_params=sampling_params, + pooling_params=None, eos_token_id=getattr(old_req, "eos_token_id", None), client_index=getattr(old_req, "client_index", 0), + block_hasher=block_hasher, ) + output_ids = getattr(old_req, "output_ids", None) or [] + if output_ids: + vllm_req.append_output_token_ids(output_ids) vllm_reqs.append(vllm_req) return vllm_reqs +def _build_vllm_request( + req: Request, + sampling_params: VLLMSamplingParams, + model_runner: Any, + *, + include_outputs: bool, +) -> VLLMRequest: + block_hasher = getattr(model_runner, "request_block_hasher", None) + vllm_req = VLLMRequest( + request_id=req.request_id, + prompt_token_ids=getattr(req, "input_ids", None), + sampling_params=sampling_params, + pooling_params=None, + eos_token_id=getattr(req, "eos_token_id", None), + arrival_time=getattr(req, "arrival_time", 0.0), + block_hasher=block_hasher, + ) + if include_outputs: + output_ids = getattr(req, "output_ids", None) or [] + if output_ids: + vllm_req.append_output_token_ids(output_ids) + return vllm_req + + def form_vllm_batch_prefill( batched_requests: List[Request], model_runner: Any = None, -) -> Dict[str, Any]: +) -> Optional[SchedulerOutput]: """Prepare a vLLM prefill batch. Constructs a SchedulerOutput for vLLM v1 GPUModelRunner that contains: @@ -111,14 +145,8 @@ def form_vllm_batch_prefill( model_runner: ParallaxVLLMModelRunner instance with initialized kv_cache_manager Returns: - Dict containing: - - scheduler_output: SchedulerOutput for vLLM - - requests: Original Parallax requests - Returns None if batched_requests is empty + SchedulerOutput compatible with vLLM GPUModelRunner, or None if the batch is empty. 
""" - from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput - from vllm.v1.request import Request as VLLMRequest - if not batched_requests: return None @@ -131,28 +159,27 @@ def form_vllm_batch_prefill( kv_cache_manager = model_runner.kv_cache_manager + num_common_prefix_blocks = [0] * getattr(kv_cache_manager, "num_kv_cache_groups", 1) + + created_vllm_requests: List[VLLMRequest] = [] + # Build NewRequestData for each request new_request_data_list = [] - vllm_requests = [] + num_scheduled_tokens: Dict[str, int] = {} + total_tokens = 0 for req in batched_requests: sampling_params = transform_sampling_params_to_vllm(req.sampling_params) - # Create vLLM Request object for KV cache management - vllm_req = VLLMRequest( - request_id=req.request_id, - prompt_token_ids=req.input_ids, - sampling_params=sampling_params, - eos_token_id=getattr(req, "eos_token_id", None), - arrival_time=getattr(req, "arrival_time", 0.0), - ) - vllm_requests.append(vllm_req) + vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=False) + created_vllm_requests.append(vllm_req) # Check for prefix cache hits computed_blocks, num_computed_tokens = kv_cache_manager.get_computed_blocks(vllm_req) # Allocate KV cache blocks for the remaining tokens - num_new_tokens = len(req.input_ids) - num_computed_tokens + prompt_token_ids = getattr(req, "input_ids", None) or [] + num_new_tokens = max(len(prompt_token_ids) - num_computed_tokens, 0) if num_new_tokens > 0: new_blocks = kv_cache_manager.allocate_slots( request=vllm_req, @@ -165,7 +192,7 @@ def form_vllm_batch_prefill( # Cannot allocate blocks (OOM) logger.warning(f"Cannot allocate KV cache for request {req.request_id}") # Free any allocated blocks for previous requests in this batch - for prev_req in vllm_requests[:-1]: + for prev_req in created_vllm_requests[:-1]: kv_cache_manager.free(prev_req) return None @@ -190,16 +217,20 @@ def form_vllm_batch_prefill( ) new_request_data_list.append(new_req_data) + scheduled_tokens = len(prompt_token_ids) + num_scheduled_tokens[req.request_id] = scheduled_tokens + total_tokens += scheduled_tokens + # Build SchedulerOutput # This is the main data structure that vLLM's model runner expects scheduler_output = SchedulerOutput( scheduled_new_reqs=new_request_data_list, scheduled_cached_reqs=CachedRequestData.make_empty(), # No cached reqs in prefill - num_scheduled_tokens={req.request_id: len(req.input_ids) for req in batched_requests}, - total_num_scheduled_tokens=sum(len(req.input_ids) for req in batched_requests), + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_tokens, scheduled_spec_decode_tokens={}, # Speculative decoding tokens scheduled_encoder_inputs={}, # For encoder-decoder models - num_common_prefix_blocks=[], # Prefix caching + num_common_prefix_blocks=num_common_prefix_blocks, # Prefix caching baseline finished_req_ids=set(), # No finished requests in prefill free_encoder_mm_hashes=[], # Encoder multimodal hash cleanup structured_output_request_ids=[], # Requests using structured output @@ -207,13 +238,13 @@ def form_vllm_batch_prefill( kv_connector_metadata=None, # KV connector for disaggregation ) - return scheduler_output, batched_requests + return scheduler_output def form_vllm_batch_decode( batched_requests: List[Request], model_runner: Any = None, -) -> Dict[str, Any]: +) -> Optional[SchedulerOutput]: """Prepare a vLLM decode batch. Constructs a SchedulerOutput for vLLM v1 GPUModelRunner for decode stage. 
@@ -227,14 +258,8 @@ def form_vllm_batch_decode( model_runner: ParallaxVLLMModelRunner instance with initialized kv_cache_manager Returns: - Dict containing: - - scheduler_output: SchedulerOutput for vLLM - - requests: Original Parallax requests - Returns None if batched_requests is empty + SchedulerOutput describing the decode work, or None if the batch is empty. """ - from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput - from vllm.v1.request import Request as VLLMRequest - if not batched_requests: return None @@ -247,61 +272,63 @@ def form_vllm_batch_decode( kv_cache_manager = model_runner.kv_cache_manager - req_ids = [] - vllm_requests = [] + req_ids: List[str] = [] + resumed_from_preemption: List[bool] = [] + new_token_ids: List[List[int]] = [] + resumed_req_token_ids: List[List[int] | None] = [] + new_block_ids: List[tuple[List[int], ...] | None] = [] + num_computed_tokens: List[int] = [] + num_output_tokens: List[int] = [] + num_scheduled_tokens: Dict[str, int] = {} for req in batched_requests: req_ids.append(req.request_id) + resumed_from_preemption.append(False) + new_token_ids.append([]) + resumed_req_token_ids.append(None) - # Create or retrieve vLLM Request object - # In decode phase, request should already exist sampling_params = transform_sampling_params_to_vllm(req.sampling_params) - vllm_req = VLLMRequest( - request_id=req.request_id, - prompt_token_ids=req.input_ids, - sampling_params=sampling_params, - eos_token_id=getattr(req, "eos_token_id", None), - arrival_time=getattr(req, "arrival_time", 0.0), - ) - vllm_req.num_computed_tokens = req.current_position - 1 # Update computed tokens - vllm_requests.append(vllm_req) + vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=True) + + prompt_ids = getattr(req, "input_ids", None) or [] + output_ids = getattr(req, "output_ids", None) or [] + computed_token_count = len(prompt_ids) + len(output_ids) + vllm_req.num_computed_tokens = computed_token_count - # Allocate slot for 1 new decode token - # This may require allocating a new block if current block is full new_blocks = kv_cache_manager.allocate_slots( request=vllm_req, - num_new_tokens=1, # Decode generates 1 token at a time + num_new_tokens=1, num_new_computed_tokens=0, ) if new_blocks is None: - # Cannot allocate (OOM), need to preempt or wait logger.warning(f"Cannot allocate KV cache for decode request {req.request_id}") return None - # Build CachedRequestData for decode - # These are requests that already have KV cache allocated + new_block_ids.append(new_blocks.get_block_ids(allow_none=True)) + num_computed_tokens.append(computed_token_count) + num_output_tokens.append(len(output_ids)) + num_scheduled_tokens[req.request_id] = 1 + cached_req_data = CachedRequestData( req_ids=req_ids, - resumed_from_preemption=[False] * len(req_ids), # Not resuming from preemption - new_token_ids=[[] for _ in req_ids], # Empty for non-pipeline-parallel - resumed_req_token_ids=[None for _ in req_ids], # Not resumed - new_block_ids=[None for _ in req_ids], # No new blocks needed for decode - num_computed_tokens=[req.current_position for req in batched_requests], - num_output_tokens=[ - len(req.output_ids) if hasattr(req, "output_ids") else 0 for req in batched_requests - ], + resumed_from_preemption=resumed_from_preemption, + new_token_ids=new_token_ids, + resumed_req_token_ids=resumed_req_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + num_output_tokens=num_output_tokens, ) # Build SchedulerOutput for 
decode scheduler_output = SchedulerOutput( scheduled_new_reqs=[], # No new requests in decode scheduled_cached_reqs=cached_req_data, - num_scheduled_tokens={req_id: 1 for req_id in req_ids}, # 1 token per request in decode - total_num_scheduled_tokens=len(req_ids), + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=sum(num_scheduled_tokens.values()), scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=[], + num_common_prefix_blocks=[0] * getattr(kv_cache_manager, "num_kv_cache_groups", 1), finished_req_ids=set(), free_encoder_mm_hashes=[], structured_output_request_ids=[], @@ -322,8 +349,6 @@ def release_vllm_request(model_runner: Any, request_id: str): model_runner: ParallaxVLLMModelRunner instance with kv_cache_manager request_id: ID of the request to release """ - from vllm.v1.request import Request as VLLMRequest - if not hasattr(model_runner, "kv_cache_manager"): logger.warning(f"KV cache manager not found when releasing request {request_id}") return diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index fb8fa96a..2177cfaa 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -7,7 +7,8 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Tuple +import importlib +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from transformers import AutoConfig, AutoTokenizer @@ -21,16 +22,15 @@ ParallelConfig, SchedulerConfig, ) -from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput -from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.config import VllmConfig from vllm.distributed import ( initialize_model_parallel, get_pp_group, ) from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.worker.gpu_model_runner import GPUModelRunner from parallax.server.request import Request from parallax_utils.logging_config import get_logger @@ -78,6 +78,9 @@ def __init__( self.pp_rank = 0 # Will be updated based on layer range self.pp_size = 1 # Single node, but with layer slicing + self.request_block_hasher: Optional[Callable[[Any], List[Any]]] = None + self.enable_prefix_caching: bool = True + # Call parent init super().__init__(vllm_config=vllm_config, device=torch.device(device)) self.kv_cache_config = kv_cache_config @@ -113,6 +116,33 @@ def initialize_kv_cache_manager(self, max_model_len: int) -> KVCacheManager: ) self.kv_cache_manager = kv_cache_manager + cache_config = self.vllm_config.cache_config + enable_prefix = cache_config.enable_prefix_caching + if enable_prefix is None: + enable_prefix = True + self.enable_prefix_caching = enable_prefix + + self.request_block_hasher = None + if enable_prefix and kv_cache_manager.block_size is not None: + try: + hashing_mod = importlib.import_module("vllm.utils.hashing") + get_hash_fn_by_name: Callable[[str], Callable[[Any], bytes]] = getattr( + hashing_mod, "get_hash_fn_by_name" + ) + hash_fn = get_hash_fn_by_name(cache_config.prefix_caching_hash_algo) + except (ModuleNotFoundError, AttributeError) as exc: + logger.warning("Unable to initialize prefix cache hashing: %s", exc) + else: + init_none_hash(hash_fn) + block_size = kv_cache_manager.block_size + if block_size is None and 
self.kv_cache_config.kv_cache_groups: + block_size = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + if block_size is not None: + self.request_block_hasher = get_request_block_hasher(block_size, hash_fn) + logger.info( + "Initialized prefix cache block hasher with block_size=%d", block_size + ) + logger.info( f"KVCacheManager initialized: block_size={kv_cache_manager.block_size}, " f"usage={kv_cache_manager.usage:.2%}" From 7873e58ef43aa5e941c23e570a2fb8f411d4f240 Mon Sep 17 00:00:00 2001 From: Alien mac air <2214632589@qq.com> Date: Tue, 21 Oct 2025 21:48:09 +0800 Subject: [PATCH 07/36] add kvcache config --- src/parallax/vllm/batch_info.py | 36 -------------- src/parallax/vllm/model_runner.py | 81 +++++++++++++++++++++++++++---- 2 files changed, 71 insertions(+), 46 deletions(-) diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index 1448dfd3..9c413d10 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -69,42 +69,6 @@ def transform_sampling_params_to_vllm(old_params: ParallaxSamplingParams) -> VLL return params -def transform_requests_to_vllm( - batched_requests: List[Request], - model_runner: Any | None = None, -) -> List[VLLMRequest]: - """Transforms Parallax Request to vLLM Request format. - - Note: Only used if we later choose to feed vLLM Engine directly. - Currently we bypass the Engine and use GPUModelRunner directly. - - Args: - batched_requests: List of Parallax requests - - Returns: - List of vLLM Request objects - """ - vllm_reqs = [] - for old_req in batched_requests: - sampling_params = transform_sampling_params_to_vllm(old_req.sampling_params) - block_hasher = getattr(model_runner, "request_block_hasher", None) if model_runner else None - vllm_req = VLLMRequest( - request_id=old_req.request_id, - prompt_token_ids=old_req.input_ids, - sampling_params=sampling_params, - pooling_params=None, - eos_token_id=getattr(old_req, "eos_token_id", None), - client_index=getattr(old_req, "client_index", 0), - block_hasher=block_hasher, - ) - output_ids = getattr(old_req, "output_ids", None) or [] - if output_ids: - vllm_req.append_output_token_ids(output_ids) - vllm_reqs.append(vllm_req) - - return vllm_reqs - - def _build_vllm_request( req: Request, sampling_params: VLLMSamplingParams, diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 2177cfaa..a2559156 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -28,8 +28,13 @@ get_pp_group, ) from vllm.v1.core.kv_cache_manager import KVCacheManager -from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.core.kv_cache_utils import ( + get_request_block_hasher, + init_none_hash, + get_kv_cache_configs, + generate_scheduler_kv_cache_config, +) +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.worker.gpu_model_runner import GPUModelRunner from parallax.server.request import Request @@ -49,7 +54,7 @@ class ParallaxVLLMModelRunner(GPUModelRunner): def __init__( self, vllm_config: VllmConfig, - kv_cache_config: KVCacheConfig, + kv_cache_config: Optional[KVCacheConfig], device: str, start_layer: int, end_layer: int, @@ -58,7 +63,7 @@ def __init__( """ Args: vllm_config: vLLM configuration object - kv_cache_config: KV cache configuration + kv_cache_config: KV cache configuration (can be None, will be created by KVCacheManager) device: Device to run on (e.g., "cuda") 
start_layer: First layer index to load (inclusive) end_layer: Last layer index to load (exclusive) @@ -83,6 +88,7 @@ def __init__( # Call parent init super().__init__(vllm_config=vllm_config, device=torch.device(device)) + # KV cache config will be created by KVCacheManager during initialization self.kv_cache_config = kv_cache_config logger.info( @@ -90,6 +96,59 @@ def __init__( f"is_first={self.is_first_peer}, is_last={self.is_last_peer}" ) + def _create_kv_cache_config(self) -> KVCacheConfig: + """ + Create KV cache configuration from the loaded model. + + This method leverages vLLM's native KV cache configuration generation + by extracting KV cache specs from the model's attention layers and + using vLLM's utilities to generate the proper configuration. + + Returns: + KVCacheConfig: Properly configured KV cache configuration + """ + logger.info("Generating KV cache configuration from model...") + + # Get KV cache specs from model's attention layers + # This method is provided by vLLM's GPUModelRunner + kv_cache_specs = self.model.get_kv_cache_spec() + + # Get available GPU memory for KV cache + # Use vLLM's memory profiling utilities + from vllm.utils import get_gpu_memory + + free_memory, _ = get_gpu_memory(self.device.index or 0) + + # Calculate available memory for KV cache based on cache_config + gpu_memory_utilization = self.cache_config.gpu_memory_utilization + available_memory = int(free_memory * gpu_memory_utilization) + + logger.info( + f"Available GPU memory for KV cache: " + f"{available_memory / (1024**3):.2f} GB " + f"({gpu_memory_utilization:.1%} of {free_memory / (1024**3):.2f} GB)" + ) + + # Use vLLM's utility to generate KV cache config + # This handles all the complexity of different attention types, + # hybrid models, sliding windows, etc. + kv_cache_configs = get_kv_cache_configs( + vllm_config=self.vllm_config, + kv_cache_specs=[kv_cache_specs], # Single worker + available_memory=[available_memory], + ) + + # For scheduler (single worker case), we can use the first config + kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) + + logger.info( + f"KV cache config generated: " + f"num_blocks={kv_cache_config.num_blocks}, " + f"num_groups={len(kv_cache_config.kv_cache_groups)}" + ) + + return kv_cache_config + def initialize_kv_cache_manager(self, max_model_len: int) -> KVCacheManager: """ Initialize vLLM's native KVCacheManager. 
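Under the hood, get_kv_cache_configs sizes the cache by (roughly) dividing the KV budget by the per-block footprint of the layers this peer holds. A back-of-the-envelope helper with illustrative numbers, not taken from any specific model, assuming full attention with separate K and V entries:

    def estimate_kv_cache_blocks(
        free_bytes: int,
        memory_fraction: float,
        num_layers: int,
        block_size: int = 16,
        num_kv_heads: int = 8,
        head_size: int = 128,
        dtype_bytes: int = 2,  # fp16 / bf16
    ) -> int:
        # Per-layer page: K and V planes, block_size tokens, num_kv_heads * head_size each.
        page_bytes_per_layer = 2 * block_size * num_kv_heads * head_size * dtype_bytes
        # One logical block spans every attention layer resident on this peer.
        block_bytes = page_bytes_per_layer * num_layers
        return int(free_bytes * memory_fraction) // block_bytes

    # Example: 24 GiB free, 80% budget, 20 resident layers -> roughly 15k blocks.
    print(estimate_kv_cache_blocks(24 * 1024**3, 0.8, 20))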
@@ -105,6 +164,10 @@ def initialize_kv_cache_manager(self, max_model_len: int) -> KVCacheManager: """ logger.info("Initializing vLLM KVCacheManager...") + # Generate KV cache config from model if not already provided + if self.kv_cache_config is None: + self.kv_cache_config = self._create_kv_cache_config() + kv_cache_manager = KVCacheManager( kv_cache_config=self.kv_cache_config, max_model_len=max_model_len, @@ -313,16 +376,14 @@ def initialize_vllm_model_runner( compilation_config=None, ) - # Determine KV cache blocks - kv_cache_config = KVCacheConfig( - block_size=kv_block_size, - num_gpu_blocks=None, # Will be calculated by model runner - ) + # Note: KVCacheConfig will be created by vLLM's KVCacheManager during initialization + # We don't need to manually create it here as it requires complex layer-specific information + # The KVCacheManager will handle this based on the model's architecture # Initialize our custom ParallaxVLLMModelRunner model_runner = ParallaxVLLMModelRunner( vllm_config=vllm_config, - kv_cache_config=kv_cache_config, + kv_cache_config=None, # Will be created by KVCacheManager device="cuda", start_layer=start_layer, end_layer=end_layer, From 4dfd95184cc731bd78bc86a283bc0a0e877a72bf Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 22 Oct 2025 09:40:36 +0800 Subject: [PATCH 08/36] pre commit --- src/parallax/vllm/batch_info.py | 11 +++++------ src/parallax/vllm/model_runner.py | 15 ++++----------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index 9c413d10..6330554a 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -15,16 +15,15 @@ from typing import Any, Dict, List, Optional +from vllm.sampling_params import SamplingParams as VLLMSamplingParams +from vllm.sampling_params import StructuredOutputsParams +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.request import Request as VLLMRequest + from parallax.server.request import Request from parallax.server.sampling.sampling_params import ( SamplingParams as ParallaxSamplingParams, ) -from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput -from vllm.v1.request import Request as VLLMRequest -from vllm.sampling_params import ( - SamplingParams as VLLMSamplingParams, - StructuredOutputsParams, -) from parallax_utils.logging_config import get_logger logger = get_logger(__name__) diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index a2559156..ae5fed43 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -12,7 +12,6 @@ import torch from transformers import AutoConfig, AutoTokenizer - from vllm.config import ( CacheConfig, DecodingConfig, @@ -21,23 +20,18 @@ ModelConfig, ParallelConfig, SchedulerConfig, -) -from vllm.config import VllmConfig -from vllm.distributed import ( - initialize_model_parallel, - get_pp_group, + VllmConfig, ) from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_utils import ( + generate_scheduler_kv_cache_config, + get_kv_cache_configs, get_request_block_hasher, init_none_hash, - get_kv_cache_configs, - generate_scheduler_kv_cache_config, ) -from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu_model_runner import GPUModelRunner -from parallax.server.request import Request from parallax_utils.logging_config 
import get_logger logger = get_logger(__name__) @@ -226,7 +220,6 @@ def load_model(self) -> None: # Temporarily override vLLM's PP configuration for this peer # This allows us to use vLLM's layer skipping mechanism import vllm.distributed.parallel_state as parallel_state - from vllm.distributed.utils import get_pp_indices # Monkey-patch get_pp_indices to return our custom layer range original_get_pp_indices = parallel_state.get_pp_indices From 27950cee0956e842c8c8f3d9935c007083716b65 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 22 Oct 2025 10:11:06 +0800 Subject: [PATCH 09/36] update --- src/parallax/server/executor.py | 9 +++++++-- src/parallax/server/server_args.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index cf9e31c8..3a9903dd 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -115,7 +115,10 @@ def __init__( f"Initializing vLLM model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" ) elif self.backend_type == "sglang": - from sglang.srt.managers.schedule_batch import ScheduleBatch as CudaScheduleBatch + from sglang.srt.managers.schedule_batch import ( + ScheduleBatch as CudaScheduleBatch, + ) + from parallax.sglang.model_runner import ( initialize_sgl_model_runner as initialize_cuda_model_runner, ) @@ -1336,7 +1339,9 @@ def run_loop(self): release_vllm_request(self.model_runner, req.request_id) elif self.backend_type == "sglang": - from parallax.sglang.batch_info import release_sglang_request + from parallax.sglang.batch_info import ( + release_sglang_request, + ) release_sglang_request(self.running_batch, req.request_id) else: diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index c488e821..3f6a8458 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -166,7 +166,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") parser.add_argument( - "--gpu_backend", + "--gpu-backend", type=str, default="sglang", choices=["sglang", "vllm"], From 672da47a1575bbe31f5c4e4f3a20d841d5007e96 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 22 Oct 2025 04:27:39 +0000 Subject: [PATCH 10/36] update --- src/parallax/sglang/model_runner.py | 5 ++--- src/parallax/vllm/__init__.py | 17 ----------------- src/parallax/vllm/model_runner.py | 6 +----- 3 files changed, 3 insertions(+), 25 deletions(-) delete mode 100644 src/parallax/vllm/__init__.py diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 1e0b4d9f..a005a03f 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -574,9 +574,8 @@ def initialize_sgl_model_runner( model_config.hf_config.tie_word_embeddings = False model_config.hf_config.start_layer = start_layer model_config.hf_config.end_layer = end_layer - print("Model config:", model_config) - print("model_start_layer:", model_config.hf_config.start_layer) - print("model_end_layer:", model_config.hf_config.end_layer) + logger.debug(f"model_start_layer: {model_config.hf_config.start_layer}") + logger.debug(f"model_end_layer: {model_config.hf_config.end_layer}") model_runner = ParallaxModelRunner( model_config=model_config, mem_fraction_static=kv_cache_memory_fraction, diff --git a/src/parallax/vllm/__init__.py b/src/parallax/vllm/__init__.py deleted file mode 100644 index ed34a077..00000000 --- 
a/src/parallax/vllm/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -vLLM backend integration for Parallax distributed inference. - -This module provides vLLM model runner with pipeline parallelism support. -""" - -from parallax.vllm.model_runner import ( - ParallaxVLLMEngine, - form_vllm_engine_args, - initialize_vllm_model_runner, -) - -__all__ = [ - "ParallaxVLLMEngine", - "form_vllm_engine_args", - "initialize_vllm_model_runner", -] diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index ae5fed43..148b70b4 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -9,12 +9,11 @@ import importlib from typing import Any, Callable, Dict, List, Optional, Tuple - +import vllm import torch from transformers import AutoConfig, AutoTokenizer from vllm.config import ( CacheConfig, - DecodingConfig, DeviceConfig, LoadConfig, ModelConfig, @@ -351,8 +350,6 @@ def initialize_vllm_model_runner( max_model_len=model_config.max_model_len, ) - decoding_config = DecodingConfig() - vllm_config = VllmConfig( model_config=model_config, cache_config=cache_config, @@ -362,7 +359,6 @@ def initialize_vllm_model_runner( load_config=load_config, lora_config=None, speculative_config=None, - decoding_config=decoding_config, observability_config=None, prompt_adapter_config=None, quant_config=None, From 687891ca60574490da093f2ecd99c75b212780db Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 22 Oct 2025 04:27:53 +0000 Subject: [PATCH 11/36] update --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a53887e3..e5d5cc02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ mac = [ ] gpu = [ + "vllm==0.11.0", "mlx-lm==0.28.0", "mlx[cpu]==0.29.1", "sglang[all]==0.5.2", From f4b3bde020b4f6bdc7e74adb5c9c0a1571845510 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Fri, 24 Oct 2025 07:28:54 +0000 Subject: [PATCH 12/36] run success but response error --- src/parallax/server/executor.py | 39 ++- src/parallax/sglang/model_runner.py | 2 + src/parallax/vllm/batch_info.py | 7 +- src/parallax/vllm/model_runner.py | 369 ++++++++++++++++++++++++---- 4 files changed, 359 insertions(+), 58 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 3a9903dd..13e4ea2a 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -129,13 +129,20 @@ def __init__( else: raise ValueError(f"Unsupported GPU backend type: {self.backend_type}") + # Prepare all parameters for model runner initialization + model_runner_params = { + "model_repo": model_repo, + "start_layer": start_layer, + "end_layer": end_layer, + "kv_cache_memory_fraction": kv_cache_memory_fraction, + "attention_backend": attention_backend, + "kv_block_size": kv_block_size, + "max_num_tokens_per_batch": max_num_tokens_per_batch, + "dtype": dtype, + } + self.model_runner, self.config, self.tokenizer = initialize_cuda_model_runner( - model_repo, - start_layer, - end_layer, - kv_cache_memory_fraction, - attention_backend, - kv_block_size, + **model_runner_params ) self.running_batch = None self.cur_batch = None @@ -1100,12 +1107,28 @@ def _process_batch_cuda( # Return appropriate output based on peer position if return_decoded_tokens: - # Last peer: return sampled token IDs - return output.sampled_token_ids + # Last peer: return sampled token IDs as tensor + # Convert list[list[int]] to tensor + import torch + + sampled_token_ids = output.sampled_token_ids + if 
isinstance(sampled_token_ids, list) and len(sampled_token_ids) > 0: + # Convert to tensor: pad sequences to same length + max_len = max(len(seq) for seq in sampled_token_ids) + padded_tokens = [] + for seq in sampled_token_ids: + padded_seq = seq + [-1] * (max_len - len(seq)) # Pad with -1 + padded_tokens.append(padded_seq) + return torch.tensor(padded_tokens, dtype=torch.int64) + else: + return torch.tensor(sampled_token_ids, dtype=torch.int64) else: # Intermediate peer: return hidden states for next peer if hasattr(output, "hidden_states") and output.hidden_states is not None: return output.hidden_states + elif hasattr(output, "tensors") and "hidden_states" in output.tensors: + # Handle IntermediateTensors case + return output.tensors["hidden_states"] else: raise RuntimeError( "vLLM backend: expected hidden_states in output for PP, but got None. " diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index a005a03f..f0d9b723 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -529,6 +529,8 @@ def initialize_sgl_model_runner( attention_backend: str, kv_block_size: int, moe_runner_backend: str, + max_num_tokens_per_batch: int = 1024, + **kwargs, ): """ Creates a SGL ModelRunner object. diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index 6330554a..64ce36ce 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -122,7 +122,7 @@ def form_vllm_batch_prefill( kv_cache_manager = model_runner.kv_cache_manager - num_common_prefix_blocks = [0] * getattr(kv_cache_manager, "num_kv_cache_groups", 1) + num_common_prefix_blocks = [0] * len(model_runner.kv_cache_config.kv_cache_groups) created_vllm_requests: List[VLLMRequest] = [] @@ -248,7 +248,8 @@ def form_vllm_batch_decode( req_ids.append(req.request_id) resumed_from_preemption.append(False) new_token_ids.append([]) - resumed_req_token_ids.append(None) + # For decode requests, we don't have resumed token IDs + resumed_req_token_ids.append([]) sampling_params = transform_sampling_params_to_vllm(req.sampling_params) vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=True) @@ -277,10 +278,8 @@ def form_vllm_batch_decode( req_ids=req_ids, resumed_from_preemption=resumed_from_preemption, new_token_ids=new_token_ids, - resumed_req_token_ids=resumed_req_token_ids, new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, - num_output_tokens=num_output_tokens, ) # Build SchedulerOutput for decode diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 148b70b4..16c525e4 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -14,6 +14,7 @@ from transformers import AutoConfig, AutoTokenizer from vllm.config import ( CacheConfig, + CompilationConfig, DeviceConfig, LoadConfig, ModelConfig, @@ -28,7 +29,7 @@ get_request_block_hasher, init_none_hash, ) -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheGroupSpec, KVCacheTensor from vllm.v1.worker.gpu_model_runner import GPUModelRunner from parallax_utils.logging_config import get_logger @@ -36,6 +37,69 @@ logger = get_logger(__name__) +def _create_kv_cache_config_from_specs( + kv_cache_group: KVCacheGroupSpec, + attn_layers: List[str], + kv_cache_memory_fraction: float, +) -> KVCacheConfig: + """ + Create KV cache configuration from KV cache group specs and attention layers. 
+ + This is a standalone function that can be used by both the model runner's + _create_kv_cache_config method and the initialize_vllm_model_runner function. + + Args: + kv_cache_group: KV cache group specification + attn_layers: List of attention layer names + kv_cache_memory_fraction: Fraction of GPU memory to use for KV cache + + Returns: + KVCacheConfig: Properly configured KV cache configuration + """ + import torch + + # Calculate available GPU memory for KV cache + free_memory, total_memory = torch.cuda.mem_get_info(0) + available_memory = int(free_memory * kv_cache_memory_fraction) + + logger.info( + f"Available GPU memory for KV cache: " + f"{available_memory / (1024**3):.2f} GB " + f"({kv_cache_memory_fraction:.1%} of {free_memory / (1024**3):.2f} GB)" + ) + + # Calculate page_size_bytes for proper tensor sizing + page_size_bytes = kv_cache_group.kv_cache_spec.page_size_bytes + + # Calculate reasonable number of blocks based on available memory + # Each block needs page_size_bytes, so we can fit this many blocks + max_blocks_by_memory = available_memory // page_size_bytes + + # Use a conservative estimate (80% of max possible blocks) + # But ensure we don't exceed available memory + num_blocks = max(100, min(1000, int(max_blocks_by_memory * 0.8))) + + logger.info(f"Calculated KV cache blocks: {num_blocks} (max possible: {max_blocks_by_memory})") + + # Ensure tensor size is divisible by page_size_bytes + tensor_size_bytes = page_size_bytes * num_blocks + + # Ensure KVCacheTensor.shared_by covers all attention layers; otherwise + # vLLM will assert that some layers are not initialized. + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[ + KVCacheTensor( + size=tensor_size_bytes, + shared_by=attn_layers, + ) + ], + kv_cache_groups=[kv_cache_group], + ) + + return kv_cache_config + + class ParallaxVLLMModelRunner(GPUModelRunner): """ Extended vLLM GPUModelRunner that leverages vLLM's native Pipeline Parallelism. @@ -89,7 +153,7 @@ def __init__( f"is_first={self.is_first_peer}, is_last={self.is_last_peer}" ) - def _create_kv_cache_config(self) -> KVCacheConfig: + def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVCacheConfig: """ Create KV cache configuration from the loaded model. 
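For orientation, a standalone usage sketch of the _create_kv_cache_config_from_specs helper added above; the layer range and spec values are hypothetical and only illustrate the expected argument shapes (the real call site passes specs derived from the loaded model):

    import torch
    from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec

    from parallax.vllm.model_runner import _create_kv_cache_config_from_specs

    # Hypothetical shard holding layers [8, 16) of the model.
    layer_names = [f"model.layers.{i}" for i in range(8, 16)]
    group = KVCacheGroupSpec(
        layer_names=layer_names,
        kv_cache_spec=FullAttentionSpec(
            block_size=16,
            num_kv_heads=8,
            head_size=128,
            dtype=torch.bfloat16,
        ),
    )
    config = _create_kv_cache_config_from_specs(
        kv_cache_group=group,
        attn_layers=layer_names,
        kv_cache_memory_fraction=0.7,
    )
    # kv_cache_tensors[0].shared_by covers every resident attention layer.
    print(config.num_blocks, len(config.kv_cache_groups))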
@@ -103,36 +167,110 @@ def _create_kv_cache_config(self) -> KVCacheConfig: logger.info("Generating KV cache configuration from model...") # Get KV cache specs from model's attention layers - # This method is provided by vLLM's GPUModelRunner - kv_cache_specs = self.model.get_kv_cache_spec() + # Try to access the method directly, bypassing cudagraph wrapper if needed + try: + kv_cache_specs = self.model.get_kv_cache_spec() + except AttributeError: + # If cudagraph wrapper is blocking access, try to get specs from the underlying model + logger.warning( + "Cannot access get_kv_cache_spec due to cudagraph wrapper, using fallback method" + ) + # Use a simplified approach - let KVCacheManager handle the details + kv_cache_specs = None # Get available GPU memory for KV cache - # Use vLLM's memory profiling utilities - from vllm.utils import get_gpu_memory + # Use PyTorch's native memory info function + import torch - free_memory, _ = get_gpu_memory(self.device.index or 0) + free_memory, total_memory = torch.cuda.mem_get_info(self.device.index or 0) - # Calculate available memory for KV cache based on cache_config - gpu_memory_utilization = self.cache_config.gpu_memory_utilization - available_memory = int(free_memory * gpu_memory_utilization) + # Calculate available memory for KV cache + # Use provided fraction or fall back to cache_config + memory_fraction = ( + kv_cache_memory_fraction + if kv_cache_memory_fraction is not None + else self.cache_config.gpu_memory_utilization + ) + available_memory = int(free_memory * memory_fraction) logger.info( f"Available GPU memory for KV cache: " f"{available_memory / (1024**3):.2f} GB " - f"({gpu_memory_utilization:.1%} of {free_memory / (1024**3):.2f} GB)" + f"({memory_fraction:.1%} of {free_memory / (1024**3):.2f} GB)" ) # Use vLLM's utility to generate KV cache config # This handles all the complexity of different attention types, # hybrid models, sliding windows, etc. 
- kv_cache_configs = get_kv_cache_configs( - vllm_config=self.vllm_config, - kv_cache_specs=[kv_cache_specs], # Single worker - available_memory=[available_memory], - ) + if kv_cache_specs is not None: + kv_cache_configs = get_kv_cache_configs( + vllm_config=self.vllm_config, + kv_cache_specs=[kv_cache_specs], # Single worker + available_memory=[available_memory], + ) + # For scheduler (single worker case), we can use the first config + kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) + else: + # Fallback: create a basic KV cache config + logger.info("Using fallback KV cache configuration") - # For scheduler (single worker case), we can use the first config - kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) + # Try to get model info from the loaded model to create a more accurate config + try: + # Get model architecture info from the loaded model + model = self.model + if hasattr(model, "model") and hasattr(model.model, "config"): + hf_config = model.model.config + num_hidden_layers = getattr(hf_config, "num_hidden_layers", 28) + num_attention_heads = getattr(hf_config, "num_attention_heads", 8) + hidden_size = getattr(hf_config, "hidden_size", 1024) + head_size = hidden_size // num_attention_heads + else: + # Fallback to default values + num_hidden_layers = 28 + num_attention_heads = 8 + head_size = 128 + + logger.info( + f"Using model info: layers={num_hidden_layers}, heads={num_attention_heads}, head_size={head_size}" + ) + + except Exception as e: + logger.warning(f"Could not get model info: {e}, using defaults") + num_hidden_layers = 28 + num_attention_heads = 8 + head_size = 128 + + # Create a basic KV cache group with the block size from cache config + from vllm.v1.kv_cache_interface import KVCacheGroupSpec, FullAttentionSpec + + # Get the correct dtype from the model config to match query/key dtypes + model_dtype = self.vllm_config.model_config.dtype + if isinstance(model_dtype, str): + from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE + + model_dtype = STR_DTYPE_TO_TORCH_DTYPE.get(model_dtype, torch.float16) + + kv_cache_group = KVCacheGroupSpec( + layer_names=[ + f"model.layers.{i}" for i in range(self.start_layer, self.end_layer) + ], # Only loaded layers + kv_cache_spec=FullAttentionSpec( + block_size=self.cache_config.block_size, + num_kv_heads=num_attention_heads, # Use actual model info + head_size=head_size, # Use actual model info + dtype=model_dtype, # Use model dtype instead of hardcoded float16 + ), + ) + + # Use the extracted function to create KV cache config + # Get layer names for the loaded layers + layer_names = [f"model.layers.{i}" for i in range(self.start_layer, self.end_layer)] + + kv_cache_config = _create_kv_cache_config_from_specs( + kv_cache_group=kv_cache_group, + attn_layers=layer_names, + kv_cache_memory_fraction=memory_fraction, + ) logger.info( f"KV cache config generated: " @@ -186,24 +324,86 @@ def initialize_kv_cache_manager(self, max_model_len: int) -> KVCacheManager: hashing_mod, "get_hash_fn_by_name" ) hash_fn = get_hash_fn_by_name(cache_config.prefix_caching_hash_algo) + init_none_hash(hash_fn) except (ModuleNotFoundError, AttributeError) as exc: logger.warning("Unable to initialize prefix cache hashing: %s", exc) - else: - init_none_hash(hash_fn) - block_size = kv_cache_manager.block_size - if block_size is None and self.kv_cache_config.kv_cache_groups: - block_size = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size - if block_size is not None: - self.request_block_hasher = 
get_request_block_hasher(block_size, hash_fn) - logger.info( - "Initialized prefix cache block hasher with block_size=%d", block_size - ) + # Use a simple fallback hash function + def simple_hash_fn(obj: Any) -> bytes: + return str(hash(str(obj))).encode("utf-8") + + hash_fn = simple_hash_fn + logger.info("Using simple fallback hash function for prefix caching") + + # Initialize block hasher regardless of whether we got the hash function from vLLM or fallback + block_size = kv_cache_manager.block_size + if block_size is None and self.kv_cache_config.kv_cache_groups: + block_size = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + if block_size is not None: + self.request_block_hasher = get_request_block_hasher(block_size, hash_fn) + logger.info("Initialized prefix cache block hasher with block_size=%d", block_size) + + # Add detailed debugging information logger.info( f"KVCacheManager initialized: block_size={kv_cache_manager.block_size}, " f"usage={kv_cache_manager.usage:.2%}" ) + # Debug block pool information + if hasattr(kv_cache_manager.block_pool, "get_num_free_blocks"): + free_blocks = kv_cache_manager.block_pool.get_num_free_blocks() + total_blocks = getattr(kv_cache_manager.block_pool, "num_blocks", "unknown") + logger.info(f"Block pool: {free_blocks} free blocks out of {total_blocks} total blocks") + + # Debug coordinator information + if hasattr(kv_cache_manager.coordinator, "block_pool"): + coordinator_pool = kv_cache_manager.coordinator.block_pool + if hasattr(coordinator_pool, "get_num_free_blocks"): + coordinator_free = coordinator_pool.get_num_free_blocks() + logger.info(f"Coordinator block pool: {coordinator_free} free blocks") + + # Test KV cache allocation with a dummy request + try: + + from vllm.v1.request import Request + from vllm.sampling_params import SamplingParams + + # Create a test request to verify KV cache allocation + test_request = Request( + request_id="test_kv_cache", + prompt_token_ids=[1, 2, 3, 4, 5], # Dummy token IDs + sampling_params=SamplingParams( + temperature=0.0, + max_tokens=10, + ), + pooling_params=None, + eos_token_id=2, + lora_request=None, + ) + + # Try to allocate some blocks for the test request + allocated_blocks = kv_cache_manager.allocate_slots( + request=test_request, + num_new_tokens=5, + ) + + if allocated_blocks is not None: + logger.info( + f"Test KV cache allocation successful: {len(allocated_blocks.blocks[0])} blocks allocated" + ) + logger.info(f"KV cache usage after test: {kv_cache_manager.usage:.2%}") + + # Free the test blocks + kv_cache_manager.free(test_request) + logger.info( + f"KV cache usage after freeing test blocks: {kv_cache_manager.usage:.2%}" + ) + else: + logger.warning("Test KV cache allocation failed - no blocks available") + + except Exception as e: + logger.warning(f"KV cache test failed: {e}") + return kv_cache_manager def load_model(self) -> None: @@ -218,10 +418,10 @@ def load_model(self) -> None: # Temporarily override vLLM's PP configuration for this peer # This allows us to use vLLM's layer skipping mechanism - import vllm.distributed.parallel_state as parallel_state + from vllm.distributed.utils import get_pp_indices # Monkey-patch get_pp_indices to return our custom layer range - original_get_pp_indices = parallel_state.get_pp_indices + original_get_pp_indices = get_pp_indices def custom_get_pp_indices(num_layers: int, rank: int, world_size: int): """Return our custom layer range instead of vLLM's calculated range.""" @@ -257,7 +457,9 @@ def initialize_vllm_model_runner( 
kv_cache_memory_fraction: float, attention_backend: str, kv_block_size: int, + max_num_tokens_per_batch: int = 1024, dtype: str = "float16", + **kwargs, ) -> Tuple[ParallaxVLLMModelRunner, Dict, Any]: """Initialize vLLM GPUModelRunner with true partial layer loading. @@ -293,10 +495,47 @@ def initialize_vllm_model_runner( f"Initializing vLLM model runner for {model_repo}, " f"layers=[{start_layer}, {end_layer})" ) - # Load HuggingFace config and tokenizer + # Load HuggingFace config and tokenizer first hf_config = AutoConfig.from_pretrained(model_repo, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_repo, trust_remote_code=True) + # Calculate virtual PP size (needed for both configs) + is_first_peer = start_layer == 0 + is_last_peer = end_layer == hf_config.num_hidden_layers + virtual_pp_size = 2 if not (is_first_peer and is_last_peer) else 1 + + # Initialize vLLM distributed environment for pipeline parallelism + # This is required for vLLM's pipeline parallel mechanism to work + import vllm.distributed.parallel_state as parallel_state + import os + + # Initialize distributed environment if not already initialized + if not parallel_state.model_parallel_is_initialized(): + logger.info("Initializing vLLM distributed environment...") + + # Set required environment variables for single GPU scenario + if "RANK" not in os.environ: + os.environ["RANK"] = "0" + if "WORLD_SIZE" not in os.environ: + os.environ["WORLD_SIZE"] = "1" + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = "0" + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "12355" + + try: + parallel_state.init_distributed_environment() + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=1, # Single GPU + pipeline_model_parallel_size=virtual_pp_size, # Match ParallelConfig + ) + logger.info(f"vLLM distributed environment initialized with pp_size={virtual_pp_size}") + except Exception as e: + logger.warning(f"Failed to initialize distributed environment: {e}") + logger.info("Continuing without distributed initialization...") + num_hidden_layers = hf_config.num_hidden_layers if end_layer > num_hidden_layers: @@ -326,13 +565,7 @@ def initialize_vllm_model_runner( # Configure PP for layer slicing # We set pp_size > 1 to enable vLLM's layer skipping mechanism # but use our custom get_pp_indices to control which layers to load - is_first_peer = start_layer == 0 - is_last_peer = end_layer == num_hidden_layers - - # Calculate a virtual PP size that makes sense - # For example, if we have 32 layers and loading [8, 16), we're in the "middle" - # Set pp_size=2 to enable PP mode, and we'll override the layer calculation - virtual_pp_size = 2 if not (is_first_peer and is_last_peer) else 1 + # virtual_pp_size is already calculated above parallel_config = ParallelConfig( pipeline_parallel_size=virtual_pp_size, @@ -344,8 +577,11 @@ def initialize_vllm_model_runner( load_config = LoadConfig(load_format="auto") # Minimal scheduler config (we bypass vLLM scheduler) + # Ensure max_num_batched_tokens is at least as large as max_model_len + # Use the provided max_num_tokens_per_batch parameter + max_batched_tokens = max(max_num_tokens_per_batch, model_config.max_model_len) scheduler_config = SchedulerConfig( - max_num_batched_tokens=8192, + max_num_batched_tokens=max_batched_tokens, max_num_seqs=256, max_model_len=model_config.max_model_len, ) @@ -362,17 +598,17 @@ def initialize_vllm_model_runner( 
observability_config=None, prompt_adapter_config=None, quant_config=None, - compilation_config=None, + compilation_config=CompilationConfig(), + kv_transfer_config=None, + kv_events_config=None, + additional_config={}, + instance_id="", ) - # Note: KVCacheConfig will be created by vLLM's KVCacheManager during initialization - # We don't need to manually create it here as it requires complex layer-specific information - # The KVCacheManager will handle this based on the model's architecture - - # Initialize our custom ParallaxVLLMModelRunner + # Initialize runner first; we'll build KV cache config after model load model_runner = ParallaxVLLMModelRunner( vllm_config=vllm_config, - kv_cache_config=None, # Will be created by KVCacheManager + kv_cache_config=None, device="cuda", start_layer=start_layer, end_layer=end_layer, @@ -384,6 +620,47 @@ def initialize_vllm_model_runner( model_runner.load_model() logger.info("vLLM model loaded successfully") + # Let vLLM automatically generate KV cache configuration + # This ensures proper shape and format compatibility + logger.info("Letting vLLM automatically generate KV cache configuration...") + + # Get KV cache specs from the loaded model + kv_cache_specs = model_runner.get_kv_cache_spec() + + if not kv_cache_specs: + raise RuntimeError("No KV cache specs found in the loaded model") + + # Calculate available memory for KV cache + import torch + + free_memory, total_memory = torch.cuda.mem_get_info(0) + available_memory = int(free_memory * kv_cache_memory_fraction) + + logger.info( + f"Available GPU memory for KV cache: " + f"{available_memory / (1024**3):.2f} GB " + f"({kv_cache_memory_fraction:.1%} of {free_memory / (1024**3):.2f} GB)" + ) + + # Use vLLM's utility to generate KV cache config + from vllm.v1.core.kv_cache_utils import get_kv_cache_configs, generate_scheduler_kv_cache_config + + kv_cache_configs = get_kv_cache_configs( + vllm_config=model_runner.vllm_config, + kv_cache_specs=[kv_cache_specs], # Single worker + available_memory=[available_memory], + ) + + # For single worker case, use the first config + kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) + + model_runner.kv_cache_config = kv_cache_config + + # Initialize GPU-side KV cache (creates attn_groups, block tables, etc.) 
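# (Aside, not part of the patch: the bring-up order this function converges on, using the
#  names introduced above; the memory figure is whatever torch.cuda.mem_get_info reports.)
#
#   runner = ParallaxVLLMModelRunner(vllm_config, kv_cache_config=None, ...)
#   runner.load_model()                                    # loads only layers [start, end)
#   specs = runner.get_kv_cache_spec()                     # per-attention-layer KV specs
#   cfgs = get_kv_cache_configs(vllm_config, [specs], [available_memory])
#   kv_cache_config = generate_scheduler_kv_cache_config(cfgs)
#   runner.initialize_kv_cache(kv_cache_config)            # GPU-side cache tensors
#   runner.initialize_kv_cache_manager(max_model_len)      # block-level bookkeeping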
+ logger.info("Initializing GPUModelRunner KV cache...") + model_runner.initialize_kv_cache(kv_cache_config) + logger.info("GPUModelRunner KV cache initialized successfully") + # Initialize KV Cache Manager after model is loaded logger.info("Initializing KV Cache Manager...") model_runner.initialize_kv_cache_manager(max_model_len=model_config.max_model_len) From 5b6bc799d5655c141137d47e16f02552413e7db8 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 27 Oct 2025 09:03:23 +0000 Subject: [PATCH 13/36] update --- src/parallax/vllm/model_runner.py | 124 ++++++------------------------ 1 file changed, 25 insertions(+), 99 deletions(-) diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 16c525e4..66426e45 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -31,8 +31,10 @@ ) from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheGroupSpec, KVCacheTensor from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from parallax.utils.tokenizer_utils import load_tokenizer from parallax_utils.logging_config import get_logger +from mlx_lm.utils import get_model_path, load_config logger = get_logger(__name__) @@ -215,30 +217,14 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC logger.info("Using fallback KV cache configuration") # Try to get model info from the loaded model to create a more accurate config - try: + # Get model architecture info from the loaded model - model = self.model - if hasattr(model, "model") and hasattr(model.model, "config"): - hf_config = model.model.config - num_hidden_layers = getattr(hf_config, "num_hidden_layers", 28) - num_attention_heads = getattr(hf_config, "num_attention_heads", 8) - hidden_size = getattr(hf_config, "hidden_size", 1024) - head_size = hidden_size // num_attention_heads - else: - # Fallback to default values - num_hidden_layers = 28 - num_attention_heads = 8 - head_size = 128 - - logger.info( - f"Using model info: layers={num_hidden_layers}, heads={num_attention_heads}, head_size={head_size}" - ) - - except Exception as e: - logger.warning(f"Could not get model info: {e}, using defaults") - num_hidden_layers = 28 - num_attention_heads = 8 - head_size = 128 + model = self.model + hf_config = model.model.config + num_attention_heads = getattr(hf_config, "num_attention_heads", 8) + hidden_size = getattr(hf_config, "hidden_size", 1024) + head_size = hidden_size // num_attention_heads + # Create a basic KV cache group with the block size from cache config from vllm.v1.kv_cache_interface import KVCacheGroupSpec, FullAttentionSpec @@ -314,7 +300,8 @@ def initialize_kv_cache_manager(self, max_model_len: int) -> KVCacheManager: enable_prefix = cache_config.enable_prefix_caching if enable_prefix is None: enable_prefix = True - self.enable_prefix_caching = enable_prefix + + self.enable_prefix_caching = False self.request_block_hasher = None if enable_prefix and kv_cache_manager.block_size is not None: @@ -349,60 +336,7 @@ def simple_hash_fn(obj: Any) -> bytes: f"usage={kv_cache_manager.usage:.2%}" ) - # Debug block pool information - if hasattr(kv_cache_manager.block_pool, "get_num_free_blocks"): - free_blocks = kv_cache_manager.block_pool.get_num_free_blocks() - total_blocks = getattr(kv_cache_manager.block_pool, "num_blocks", "unknown") - logger.info(f"Block pool: {free_blocks} free blocks out of {total_blocks} total blocks") - - # Debug coordinator information - if hasattr(kv_cache_manager.coordinator, "block_pool"): - coordinator_pool = 
kv_cache_manager.coordinator.block_pool - if hasattr(coordinator_pool, "get_num_free_blocks"): - coordinator_free = coordinator_pool.get_num_free_blocks() - logger.info(f"Coordinator block pool: {coordinator_free} free blocks") - - # Test KV cache allocation with a dummy request - try: - - from vllm.v1.request import Request - from vllm.sampling_params import SamplingParams - - # Create a test request to verify KV cache allocation - test_request = Request( - request_id="test_kv_cache", - prompt_token_ids=[1, 2, 3, 4, 5], # Dummy token IDs - sampling_params=SamplingParams( - temperature=0.0, - max_tokens=10, - ), - pooling_params=None, - eos_token_id=2, - lora_request=None, - ) - - # Try to allocate some blocks for the test request - allocated_blocks = kv_cache_manager.allocate_slots( - request=test_request, - num_new_tokens=5, - ) - - if allocated_blocks is not None: - logger.info( - f"Test KV cache allocation successful: {len(allocated_blocks.blocks[0])} blocks allocated" - ) - logger.info(f"KV cache usage after test: {kv_cache_manager.usage:.2%}") - - # Free the test blocks - kv_cache_manager.free(test_request) - logger.info( - f"KV cache usage after freeing test blocks: {kv_cache_manager.usage:.2%}" - ) - else: - logger.warning("Test KV cache allocation failed - no blocks available") - - except Exception as e: - logger.warning(f"KV cache test failed: {e}") + return kv_cache_manager @@ -495,13 +429,15 @@ def initialize_vllm_model_runner( f"Initializing vLLM model runner for {model_repo}, " f"layers=[{start_layer}, {end_layer})" ) - # Load HuggingFace config and tokenizer first - hf_config = AutoConfig.from_pretrained(model_repo, trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(model_repo, trust_remote_code=True) + model_path = get_model_path(model_repo)[0] + config = load_config(model_path) + tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) + dtype = config.get("torch_dtype", "bfloat16") # Calculate virtual PP size (needed for both configs) + num_hidden_layers = getattr(config, "num_hidden_layers", 28) is_first_peer = start_layer == 0 - is_last_peer = end_layer == hf_config.num_hidden_layers + is_last_peer = end_layer == num_hidden_layers virtual_pp_size = 2 if not (is_first_peer and is_last_peer) else 1 # Initialize vLLM distributed environment for pipeline parallelism @@ -536,7 +472,7 @@ def initialize_vllm_model_runner( logger.warning(f"Failed to initialize distributed environment: {e}") logger.info("Continuing without distributed initialization...") - num_hidden_layers = hf_config.num_hidden_layers + if end_layer > num_hidden_layers: raise ValueError( @@ -552,7 +488,7 @@ def initialize_vllm_model_runner( trust_remote_code=True, dtype=dtype, seed=0, - max_model_len=getattr(hf_config, "max_position_embeddings", 4096), + max_model_len=getattr(config, "max_position_embeddings", 4096), ) cache_config = CacheConfig( @@ -574,7 +510,7 @@ def initialize_vllm_model_runner( ) device_config = DeviceConfig(device="cuda") - load_config = LoadConfig(load_format="auto") + load_config_for_config = LoadConfig(load_format="auto") # Minimal scheduler config (we bypass vLLM scheduler) # Ensure max_num_batched_tokens is at least as large as max_model_len @@ -592,7 +528,7 @@ def initialize_vllm_model_runner( parallel_config=parallel_config, scheduler_config=scheduler_config, device_config=device_config, - load_config=load_config, + load_config=load_config_for_config, lora_config=None, speculative_config=None, observability_config=None, @@ -667,16 
+603,6 @@ def initialize_vllm_model_runner( logger.info("KV Cache Manager initialized successfully") # Return config as dict for compatibility with Parallax executor - config_dict = { - "num_hidden_layers": num_hidden_layers, - "hidden_size": hf_config.hidden_size, - "num_attention_heads": hf_config.num_attention_heads, - "num_key_value_heads": getattr( - hf_config, "num_key_value_heads", hf_config.num_attention_heads - ), - "head_dim": getattr( - hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads - ), - } - - return model_runner, config_dict, tokenizer + + + return model_runner, config, tokenizer From 57eb1c7bffd3d3815087eb5a8267166709a2c285 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Thu, 30 Oct 2025 08:55:44 +0000 Subject: [PATCH 14/36] update --- src/parallax/vllm/batch_info.py | 107 +++------------- src/parallax/vllm/model_runner.py | 201 +++--------------------------- 2 files changed, 29 insertions(+), 279 deletions(-) diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index 64ce36ce..2008f104 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -1,15 +1,3 @@ -""" -Store information about a vLLM batch. - -This module provides batch formation utilities for vLLM v1 backend integration. -It transforms Parallax requests into vLLM-compatible structures for both prefill -and decode stages. - -Key differences from SGLang: -- vLLM uses SchedulerOutput (flat) vs SGLang's ScheduleBatch (hierarchical) -- KV Cache is managed independently via KVCache object -- Sampling is integrated in execute_model() call -""" from __future__ import annotations @@ -30,22 +18,11 @@ def transform_sampling_params_to_vllm(old_params: ParallaxSamplingParams) -> VLLMSamplingParams: - """Transforms Parallax SamplingParams to vLLM SamplingParams format. - - Args: - old_params: Parallax sampling parameters - - Returns: - vLLM SamplingParams object - """ - # Map Parallax json_schema -> vLLM structured_outputs structured = ( StructuredOutputsParams(json=old_params.json_schema) if getattr(old_params, "json_schema", None) is not None else None ) - - # vLLM uses max_tokens/min_tokens naming params = VLLMSamplingParams( max_tokens=old_params.max_new_tokens, min_tokens=old_params.min_new_tokens, @@ -96,24 +73,9 @@ def form_vllm_batch_prefill( batched_requests: List[Request], model_runner: Any = None, ) -> Optional[SchedulerOutput]: - """Prepare a vLLM prefill batch. - - Constructs a SchedulerOutput for vLLM v1 GPUModelRunner that contains: - - NewRequestData for each request (new prefill requests) - - KV cache block allocations via vLLM's native KVCacheManager - - Token scheduling information - - Args: - batched_requests: List of Parallax requests to prefill - model_runner: ParallaxVLLMModelRunner instance with initialized kv_cache_manager - - Returns: - SchedulerOutput compatible with vLLM GPUModelRunner, or None if the batch is empty. - """ if not batched_requests: return None - # Get vLLM's KVCacheManager from model_runner if not hasattr(model_runner, "kv_cache_manager"): raise RuntimeError( "model_runner must have kv_cache_manager initialized. 
" @@ -126,7 +88,6 @@ def form_vllm_batch_prefill( created_vllm_requests: List[VLLMRequest] = [] - # Build NewRequestData for each request new_request_data_list = [] num_scheduled_tokens: Dict[str, int] = {} total_tokens = 0 @@ -137,10 +98,8 @@ def form_vllm_batch_prefill( vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=False) created_vllm_requests.append(vllm_req) - # Check for prefix cache hits computed_blocks, num_computed_tokens = kv_cache_manager.get_computed_blocks(vllm_req) - # Allocate KV cache blocks for the remaining tokens prompt_token_ids = getattr(req, "input_ids", None) or [] num_new_tokens = max(len(prompt_token_ids) - num_computed_tokens, 0) if num_new_tokens > 0: @@ -152,31 +111,27 @@ def form_vllm_batch_prefill( ) if new_blocks is None: - # Cannot allocate blocks (OOM) logger.warning(f"Cannot allocate KV cache for request {req.request_id}") - # Free any allocated blocks for previous requests in this batch for prev_req in created_vllm_requests[:-1]: kv_cache_manager.free(prev_req) return None - # Combine computed blocks and new blocks all_blocks = computed_blocks + new_blocks if num_computed_tokens > 0 else new_blocks else: all_blocks = computed_blocks - # Get block IDs for the request block_ids = all_blocks.get_block_ids() new_req_data = NewRequestData( req_id=req.request_id, prompt_token_ids=req.input_ids, - mm_features=[], # Multimodal features (empty for text-only) + mm_features=[], sampling_params=sampling_params, - pooling_params=None, # For embedding models + pooling_params=None, block_ids=block_ids, num_computed_tokens=num_computed_tokens, - lora_request=None, # LoRA adapter - prompt_embeds=None, # Soft prompts + lora_request=None, + prompt_embeds=None, ) new_request_data_list.append(new_req_data) @@ -184,21 +139,19 @@ def form_vllm_batch_prefill( num_scheduled_tokens[req.request_id] = scheduled_tokens total_tokens += scheduled_tokens - # Build SchedulerOutput - # This is the main data structure that vLLM's model runner expects scheduler_output = SchedulerOutput( scheduled_new_reqs=new_request_data_list, - scheduled_cached_reqs=CachedRequestData.make_empty(), # No cached reqs in prefill + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_tokens, - scheduled_spec_decode_tokens={}, # Speculative decoding tokens - scheduled_encoder_inputs={}, # For encoder-decoder models - num_common_prefix_blocks=num_common_prefix_blocks, # Prefix caching baseline - finished_req_ids=set(), # No finished requests in prefill - free_encoder_mm_hashes=[], # Encoder multimodal hash cleanup - structured_output_request_ids=[], # Requests using structured output - grammar_bitmask=None, # Grammar constraints - kv_connector_metadata=None, # KV connector for disaggregation + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=num_common_prefix_blocks, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids=[], + grammar_bitmask=None, + kv_connector_metadata=None, ) return scheduler_output @@ -208,25 +161,9 @@ def form_vllm_batch_decode( batched_requests: List[Request], model_runner: Any = None, ) -> Optional[SchedulerOutput]: - """Prepare a vLLM decode batch. - - Constructs a SchedulerOutput for vLLM v1 GPUModelRunner for decode stage. 
- Key differences from prefill: - - Uses CachedRequestData (not NewRequestData) - - Each request processes exactly 1 token - - KV cache blocks already allocated, may need to extend - - Args: - batched_requests: List of Parallax requests in decode phase - model_runner: ParallaxVLLMModelRunner instance with initialized kv_cache_manager - - Returns: - SchedulerOutput describing the decode work, or None if the batch is empty. - """ if not batched_requests: return None - # Get vLLM's KVCacheManager if not hasattr(model_runner, "kv_cache_manager"): raise RuntimeError( "model_runner must have kv_cache_manager initialized. " @@ -248,7 +185,6 @@ def form_vllm_batch_decode( req_ids.append(req.request_id) resumed_from_preemption.append(False) new_token_ids.append([]) - # For decode requests, we don't have resumed token IDs resumed_req_token_ids.append([]) sampling_params = transform_sampling_params_to_vllm(req.sampling_params) @@ -282,9 +218,8 @@ def form_vllm_batch_decode( num_computed_tokens=num_computed_tokens, ) - # Build SchedulerOutput for decode scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], # No new requests in decode + scheduled_new_reqs=[], scheduled_cached_reqs=cached_req_data, num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=sum(num_scheduled_tokens.values()), @@ -302,27 +237,13 @@ def form_vllm_batch_decode( def release_vllm_request(model_runner: Any, request_id: str): - """Release KV Cache and other resources for finished/aborted requests. - - Uses vLLM's native KVCacheManager to properly free allocated blocks - and update prefix cache if enabled. - - Args: - model_runner: ParallaxVLLMModelRunner instance with kv_cache_manager - request_id: ID of the request to release - """ if not hasattr(model_runner, "kv_cache_manager"): logger.warning(f"KV cache manager not found when releasing request {request_id}") return kv_cache_manager = model_runner.kv_cache_manager - # Create a minimal vLLM Request object for the free operation - # Note: We need the request object, not just the ID - # In a real scenario, we should maintain a mapping of request_id -> vLLMRequest - # For now, we'll use the KVCacheManager's coordinator directly try: - # The coordinator can free by request_id directly kv_cache_manager.coordinator.free(request_id) logger.debug(f"Released KV cache for request {request_id}") except Exception as e: diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 66426e45..b3eac185 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -1,9 +1,3 @@ -""" -vLLM Model Runner wrapper for Parallax with Pipeline Parallelism support. - -Integrates vLLM v1 GPUModelRunner for CUDA backend. -Uses vLLM's native Pipeline Parallelism mechanism to load only required layers. -""" from __future__ import annotations @@ -44,23 +38,8 @@ def _create_kv_cache_config_from_specs( attn_layers: List[str], kv_cache_memory_fraction: float, ) -> KVCacheConfig: - """ - Create KV cache configuration from KV cache group specs and attention layers. - - This is a standalone function that can be used by both the model runner's - _create_kv_cache_config method and the initialize_vllm_model_runner function. 
- - Args: - kv_cache_group: KV cache group specification - attn_layers: List of attention layer names - kv_cache_memory_fraction: Fraction of GPU memory to use for KV cache - - Returns: - KVCacheConfig: Properly configured KV cache configuration - """ import torch - # Calculate available GPU memory for KV cache free_memory, total_memory = torch.cuda.mem_get_info(0) available_memory = int(free_memory * kv_cache_memory_fraction) @@ -70,24 +49,16 @@ def _create_kv_cache_config_from_specs( f"({kv_cache_memory_fraction:.1%} of {free_memory / (1024**3):.2f} GB)" ) - # Calculate page_size_bytes for proper tensor sizing page_size_bytes = kv_cache_group.kv_cache_spec.page_size_bytes - # Calculate reasonable number of blocks based on available memory - # Each block needs page_size_bytes, so we can fit this many blocks max_blocks_by_memory = available_memory // page_size_bytes - # Use a conservative estimate (80% of max possible blocks) - # But ensure we don't exceed available memory num_blocks = max(100, min(1000, int(max_blocks_by_memory * 0.8))) logger.info(f"Calculated KV cache blocks: {num_blocks} (max possible: {max_blocks_by_memory})") - # Ensure tensor size is divisible by page_size_bytes tensor_size_bytes = page_size_bytes * num_blocks - # Ensure KVCacheTensor.shared_by covers all attention layers; otherwise - # vLLM will assert that some layers are not initialized. kv_cache_config = KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[ @@ -103,12 +74,6 @@ def _create_kv_cache_config_from_specs( class ParallaxVLLMModelRunner(GPUModelRunner): - """ - Extended vLLM GPUModelRunner that leverages vLLM's native Pipeline Parallelism. - - This class uses vLLM's PPMissingLayer mechanism to load only the required layers - during model initialization, avoiding the need to load and then prune the full model. - """ def __init__( self, @@ -119,16 +84,6 @@ def __init__( end_layer: int, num_hidden_layers: int, ): - """ - Args: - vllm_config: vLLM configuration object - kv_cache_config: KV cache configuration (can be None, will be created by KVCacheManager) - device: Device to run on (e.g., "cuda") - start_layer: First layer index to load (inclusive) - end_layer: Last layer index to load (exclusive) - num_hidden_layers: Total number of layers in the full model - """ - # Store layer information before calling super().__init__ self.start_layer = start_layer self.end_layer = end_layer self.num_hidden_layers = num_hidden_layers @@ -137,17 +92,13 @@ def __init__( self.is_first_peer = start_layer == 0 self.is_last_peer = end_layer == num_hidden_layers - # Calculate PP rank and size for vLLM - # We simulate a PP setup where each Parallax peer is a PP rank - self.pp_rank = 0 # Will be updated based on layer range - self.pp_size = 1 # Single node, but with layer slicing + self.pp_rank = 0 + self.pp_size = 1 self.request_block_hasher: Optional[Callable[[Any], List[Any]]] = None self.enable_prefix_caching: bool = True - # Call parent init super().__init__(vllm_config=vllm_config, device=torch.device(device)) - # KV cache config will be created by KVCacheManager during initialization self.kv_cache_config = kv_cache_config logger.info( @@ -156,38 +107,20 @@ def __init__( ) def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVCacheConfig: - """ - Create KV cache configuration from the loaded model. 
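The block-count arithmetic above, restated as a standalone sketch (the 80% factor and the 100/1000 clamp follow the code in this patch; `page_size_bytes` comes from the KV-cache group spec):

import torch

def estimate_kv_cache_blocks(page_size_bytes: int, kv_cache_memory_fraction: float) -> int:
    # Free GPU memory on device 0; a fraction of it is reserved for the KV cache.
    free_memory, _total = torch.cuda.mem_get_info(0)
    available = int(free_memory * kv_cache_memory_fraction)
    max_blocks_by_memory = available // page_size_bytes
    # Conservative clamp used here: 80% of what fits, bounded to [100, 1000] blocks.
    return max(100, min(1000, int(max_blocks_by_memory * 0.8)))

# Example: 2 MiB pages and 80% of free memory reserved for KV cache.
# num_blocks = estimate_kv_cache_blocks(2 * 1024 * 1024, 0.8)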
- - This method leverages vLLM's native KV cache configuration generation - by extracting KV cache specs from the model's attention layers and - using vLLM's utilities to generate the proper configuration. - - Returns: - KVCacheConfig: Properly configured KV cache configuration - """ logger.info("Generating KV cache configuration from model...") - # Get KV cache specs from model's attention layers - # Try to access the method directly, bypassing cudagraph wrapper if needed try: kv_cache_specs = self.model.get_kv_cache_spec() except AttributeError: - # If cudagraph wrapper is blocking access, try to get specs from the underlying model logger.warning( "Cannot access get_kv_cache_spec due to cudagraph wrapper, using fallback method" ) - # Use a simplified approach - let KVCacheManager handle the details kv_cache_specs = None - # Get available GPU memory for KV cache - # Use PyTorch's native memory info function import torch free_memory, total_memory = torch.cuda.mem_get_info(self.device.index or 0) - # Calculate available memory for KV cache - # Use provided fraction or fall back to cache_config memory_fraction = ( kv_cache_memory_fraction if kv_cache_memory_fraction is not None @@ -201,35 +134,24 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC f"({memory_fraction:.1%} of {free_memory / (1024**3):.2f} GB)" ) - # Use vLLM's utility to generate KV cache config - # This handles all the complexity of different attention types, - # hybrid models, sliding windows, etc. if kv_cache_specs is not None: kv_cache_configs = get_kv_cache_configs( vllm_config=self.vllm_config, - kv_cache_specs=[kv_cache_specs], # Single worker + kv_cache_specs=[kv_cache_specs], available_memory=[available_memory], ) - # For scheduler (single worker case), we can use the first config kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) else: - # Fallback: create a basic KV cache config logger.info("Using fallback KV cache configuration") - # Try to get model info from the loaded model to create a more accurate config - - # Get model architecture info from the loaded model model = self.model hf_config = model.model.config num_attention_heads = getattr(hf_config, "num_attention_heads", 8) hidden_size = getattr(hf_config, "hidden_size", 1024) head_size = hidden_size // num_attention_heads - - # Create a basic KV cache group with the block size from cache config from vllm.v1.kv_cache_interface import KVCacheGroupSpec, FullAttentionSpec - # Get the correct dtype from the model config to match query/key dtypes model_dtype = self.vllm_config.model_config.dtype if isinstance(model_dtype, str): from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE @@ -239,17 +161,15 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC kv_cache_group = KVCacheGroupSpec( layer_names=[ f"model.layers.{i}" for i in range(self.start_layer, self.end_layer) - ], # Only loaded layers + ], kv_cache_spec=FullAttentionSpec( block_size=self.cache_config.block_size, - num_kv_heads=num_attention_heads, # Use actual model info - head_size=head_size, # Use actual model info - dtype=model_dtype, # Use model dtype instead of hardcoded float16 + num_kv_heads=num_attention_heads, + head_size=head_size, + dtype=model_dtype, ), ) - # Use the extracted function to create KV cache config - # Get layer names for the loaded layers layer_names = [f"model.layers.{i}" for i in range(self.start_layer, self.end_layer)] kv_cache_config = _create_kv_cache_config_from_specs( @@ -267,32 +187,19 
@@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC return kv_cache_config def initialize_kv_cache_manager(self, max_model_len: int) -> KVCacheManager: - """ - Initialize vLLM's native KVCacheManager. - - This should be called after the model is loaded to properly set up - the KV cache management system. - - Args: - max_model_len: Maximum sequence length the model can handle - - Returns: - Initialized KVCacheManager instance - """ logger.info("Initializing vLLM KVCacheManager...") - # Generate KV cache config from model if not already provided if self.kv_cache_config is None: self.kv_cache_config = self._create_kv_cache_config() kv_cache_manager = KVCacheManager( kv_cache_config=self.kv_cache_config, max_model_len=max_model_len, - enable_caching=True, # Enable prefix caching - use_eagle=False, # Not using EAGLE speculative decoding - log_stats=True, # Enable stats logging - enable_kv_cache_events=False, # Disable KV cache events for now - dcp_world_size=1, # Decode Context Parallelism world size + enable_caching=True, + use_eagle=False, + log_stats=True, + enable_kv_cache_events=False, + dcp_world_size=1, ) self.kv_cache_manager = kv_cache_manager @@ -315,14 +222,12 @@ def initialize_kv_cache_manager(self, max_model_len: int) -> KVCacheManager: except (ModuleNotFoundError, AttributeError) as exc: logger.warning("Unable to initialize prefix cache hashing: %s", exc) - # Use a simple fallback hash function def simple_hash_fn(obj: Any) -> bytes: return str(hash(str(obj))).encode("utf-8") hash_fn = simple_hash_fn logger.info("Using simple fallback hash function for prefix caching") - # Initialize block hasher regardless of whether we got the hash function from vLLM or fallback block_size = kv_cache_manager.block_size if block_size is None and self.kv_cache_config.kv_cache_groups: block_size = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size @@ -330,55 +235,38 @@ def simple_hash_fn(obj: Any) -> bytes: self.request_block_hasher = get_request_block_hasher(block_size, hash_fn) logger.info("Initialized prefix cache block hasher with block_size=%d", block_size) - # Add detailed debugging information logger.info( f"KVCacheManager initialized: block_size={kv_cache_manager.block_size}, " f"usage={kv_cache_manager.usage:.2%}" ) - - return kv_cache_manager def load_model(self) -> None: - """ - Load model using vLLM's native layer loading mechanism. - - This method uses vLLM's make_layers function which creates PPMissingLayer - placeholders for layers outside [start_layer, end_layer), ensuring only - the required layers are actually loaded from checkpoint. 
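The layer-slicing trick described here relies on temporarily swapping vLLM's `get_pp_indices` for the duration of `load_model`, as the monkey patch in the next hunk shows. A context-manager variant of the same idea, as a sketch (the try/finally restore matches the real code; the wrapper itself is illustrative):

import contextlib
import vllm.distributed.utils

@contextlib.contextmanager
def patched_pp_indices(start_layer: int, end_layer: int):
    """Temporarily make vLLM load only layers [start_layer, end_layer)."""
    original = vllm.distributed.utils.get_pp_indices

    def custom_get_pp_indices(num_layers: int, rank: int, world_size: int):
        return start_layer, end_layer

    vllm.distributed.utils.get_pp_indices = custom_get_pp_indices
    try:
        yield
    finally:
        vllm.distributed.utils.get_pp_indices = original

# Usage sketch:
# with patched_pp_indices(8, 16):
#     model_runner.load_model()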
- """ logger.info(f"Loading vLLM model with layers [{self.start_layer}, {self.end_layer})...") - # Temporarily override vLLM's PP configuration for this peer - # This allows us to use vLLM's layer skipping mechanism from vllm.distributed.utils import get_pp_indices - # Monkey-patch get_pp_indices to return our custom layer range original_get_pp_indices = get_pp_indices def custom_get_pp_indices(num_layers: int, rank: int, world_size: int): - """Return our custom layer range instead of vLLM's calculated range.""" logger.debug( f"custom_get_pp_indices called: num_layers={num_layers}, " f"returning [{self.start_layer}, {self.end_layer})" ) return self.start_layer, self.end_layer - # Temporarily replace the function import vllm.distributed.utils vllm.distributed.utils.get_pp_indices = custom_get_pp_indices try: - # Now call the parent load_model, which will use our custom layer range super().load_model() logger.info( f"Successfully loaded {self.num_shard_layers} layers " f"[{self.start_layer}:{self.end_layer}]" ) finally: - # Restore original function vllm.distributed.utils.get_pp_indices = original_get_pp_indices logger.info("Model loaded successfully with partial layers") @@ -395,36 +283,6 @@ def initialize_vllm_model_runner( dtype: str = "float16", **kwargs, ) -> Tuple[ParallaxVLLMModelRunner, Dict, Any]: - """Initialize vLLM GPUModelRunner with true partial layer loading. - - This function leverages vLLM's native Pipeline Parallelism mechanism to load - only the required layers, avoiding the memory overhead of loading the full model. - - The key insight is to monkey-patch vLLM's get_pp_indices function during model - loading, which allows us to control exactly which layers are loaded. Layers - outside the [start_layer, end_layer) range are replaced with PPMissingLayer - placeholders that consume minimal memory. - - Args: - model_repo: HuggingFace model repo path - start_layer: Start layer index (inclusive) - end_layer: End layer index (exclusive) - kv_cache_memory_fraction: Fraction of GPU memory for KV cache - attention_backend: Attention backend (e.g., "flash_attn") - kv_block_size: KV cache block size - dtype: Model dtype - - Returns: - (model_runner, config_dict, tokenizer) - - Example: - >>> # Load only layers 8-16 of a 32-layer model - >>> runner, config, tok = initialize_vllm_model_runner( - ... "meta-llama/Llama-2-7b-hf", 8, 16, 0.8, "flash_attn", 64 - ... 
) - >>> # Only 8 layers are actually loaded into memory - ``` - """ logger.info( f"Initializing vLLM model runner for {model_repo}, " f"layers=[{start_layer}, {end_layer})" ) @@ -434,22 +292,17 @@ def initialize_vllm_model_runner( tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") - # Calculate virtual PP size (needed for both configs) num_hidden_layers = getattr(config, "num_hidden_layers", 28) is_first_peer = start_layer == 0 is_last_peer = end_layer == num_hidden_layers virtual_pp_size = 2 if not (is_first_peer and is_last_peer) else 1 - # Initialize vLLM distributed environment for pipeline parallelism - # This is required for vLLM's pipeline parallel mechanism to work import vllm.distributed.parallel_state as parallel_state import os - # Initialize distributed environment if not already initialized if not parallel_state.model_parallel_is_initialized(): logger.info("Initializing vLLM distributed environment...") - # Set required environment variables for single GPU scenario if "RANK" not in os.environ: os.environ["RANK"] = "0" if "WORLD_SIZE" not in os.environ: @@ -464,23 +317,20 @@ def initialize_vllm_model_runner( try: parallel_state.init_distributed_environment() parallel_state.initialize_model_parallel( - tensor_model_parallel_size=1, # Single GPU - pipeline_model_parallel_size=virtual_pp_size, # Match ParallelConfig + tensor_model_parallel_size=1, + pipeline_model_parallel_size=virtual_pp_size, ) logger.info(f"vLLM distributed environment initialized with pp_size={virtual_pp_size}") except Exception as e: logger.warning(f"Failed to initialize distributed environment: {e}") logger.info("Continuing without distributed initialization...") - - if end_layer > num_hidden_layers: raise ValueError( f"end_layer ({end_layer}) cannot be greater than " f"num_hidden_layers ({num_hidden_layers})" ) - # Build vLLM configs model_config = ModelConfig( model=model_repo, tokenizer=model_repo, @@ -498,11 +348,6 @@ def initialize_vllm_model_runner( cache_dtype="auto", ) - # Configure PP for layer slicing - # We set pp_size > 1 to enable vLLM's layer skipping mechanism - # but use our custom get_pp_indices to control which layers to load - # virtual_pp_size is already calculated above - parallel_config = ParallelConfig( pipeline_parallel_size=virtual_pp_size, tensor_parallel_size=1, @@ -512,9 +357,6 @@ def initialize_vllm_model_runner( device_config = DeviceConfig(device="cuda") load_config_for_config = LoadConfig(load_format="auto") - # Minimal scheduler config (we bypass vLLM scheduler) - # Ensure max_num_batched_tokens is at least as large as max_model_len - # Use the provided max_num_tokens_per_batch parameter max_batched_tokens = max(max_num_tokens_per_batch, model_config.max_model_len) scheduler_config = SchedulerConfig( max_num_batched_tokens=max_batched_tokens, @@ -541,7 +383,6 @@ def initialize_vllm_model_runner( instance_id="", ) - # Initialize runner first; we'll build KV cache config after model load model_runner = ParallaxVLLMModelRunner( vllm_config=vllm_config, kv_cache_config=None, @@ -551,22 +392,17 @@ def initialize_vllm_model_runner( num_hidden_layers=num_hidden_layers, ) - # Load model with partial layers logger.info("Loading vLLM model (partial layers)...") model_runner.load_model() logger.info("vLLM model loaded successfully") - # Let vLLM automatically generate KV cache configuration - # This ensures proper shape and format compatibility logger.info("Letting vLLM automatically generate KV cache 
configuration...") - # Get KV cache specs from the loaded model kv_cache_specs = model_runner.get_kv_cache_spec() if not kv_cache_specs: raise RuntimeError("No KV cache specs found in the loaded model") - # Calculate available memory for KV cache import torch free_memory, total_memory = torch.cuda.mem_get_info(0) @@ -578,31 +414,24 @@ def initialize_vllm_model_runner( f"({kv_cache_memory_fraction:.1%} of {free_memory / (1024**3):.2f} GB)" ) - # Use vLLM's utility to generate KV cache config from vllm.v1.core.kv_cache_utils import get_kv_cache_configs, generate_scheduler_kv_cache_config kv_cache_configs = get_kv_cache_configs( vllm_config=model_runner.vllm_config, - kv_cache_specs=[kv_cache_specs], # Single worker + kv_cache_specs=[kv_cache_specs], available_memory=[available_memory], ) - # For single worker case, use the first config kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) model_runner.kv_cache_config = kv_cache_config - # Initialize GPU-side KV cache (creates attn_groups, block tables, etc.) logger.info("Initializing GPUModelRunner KV cache...") model_runner.initialize_kv_cache(kv_cache_config) logger.info("GPUModelRunner KV cache initialized successfully") - # Initialize KV Cache Manager after model is loaded logger.info("Initializing KV Cache Manager...") model_runner.initialize_kv_cache_manager(max_model_len=model_config.max_model_len) logger.info("KV Cache Manager initialized successfully") - # Return config as dict for compatibility with Parallax executor - - return model_runner, config, tokenizer From 35e5baf8f2046705a38de599d68f357d8fa19d2f Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Fri, 31 Oct 2025 02:40:17 +0000 Subject: [PATCH 15/36] update --- src/parallax/server/executor.py | 3 ++- src/parallax/sglang/model_runner.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index e6a56187..7ff8aeb6 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -139,6 +139,7 @@ def __init__( "kv_block_size": kv_block_size, "max_num_tokens_per_batch": max_num_tokens_per_batch, "dtype": dtype, + "moe_runner_backend": moe_runner_backend, } self.model_runner, self.config, self.tokenizer = initialize_cuda_model_runner( @@ -146,7 +147,7 @@ def __init__( ) self.running_batch = None self.cur_batch = None - if self == "sglang": + if self.backend_type == "sglang": self.running_batch = CudaScheduleBatch(reqs=[], batch_is_full=False) else: diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 85ae5228..98335f76 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -539,7 +539,7 @@ def apply_parallax_monkey_patch(): def initialize_sgl_model_runner( - original_model_path: str, + model_repo: str, start_layer: int, end_layer: int, kv_cache_memory_fraction: float, @@ -557,7 +557,7 @@ def initialize_sgl_model_runner( - tokenizer: tokenizer driven by mlx-lm """ apply_parallax_monkey_patch() - model_path = get_model_path(original_model_path)[0] + model_path = get_model_path(model_repo)[0] config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") @@ -578,7 +578,7 @@ def initialize_sgl_model_runner( kv_block_size = 1 server_args = form_sgl_server_args( - original_model_path, + model_repo, dtype, attention_backend, kv_block_size, @@ -589,7 +589,7 @@ def 
initialize_sgl_model_runner( if (quantization_config := config.get("quantization_config", None)) is not None: quant_method = quantization_config.get("quant_method") model_config = ModelConfig( - model_path=original_model_path, + model_path=model_repo, model_override_args="{}", dtype=dtype, quantization=quant_method, From 508cea4c660f6f46b6a9fde38ddd4f8e7d7a86f0 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Thu, 6 Nov 2025 13:34:02 +0800 Subject: [PATCH 16/36] update args --- src/backend/server/server_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/server/server_args.py b/src/backend/server/server_args.py index 3d62be4a..feaf60c9 100644 --- a/src/backend/server/server_args.py +++ b/src/backend/server/server_args.py @@ -37,7 +37,7 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument( - "--gpu_backend", + "--gpu-backend", type=str, default="sglang", choices=["sglang", "vllm"], From 9db035c72ccf20119584bff7e11b247b84937135 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Thu, 6 Nov 2025 15:43:52 +0800 Subject: [PATCH 17/36] success run --- src/parallax/vllm/batch_info.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index 2008f104..db0a9bd9 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -184,7 +184,13 @@ def form_vllm_batch_decode( for req in batched_requests: req_ids.append(req.request_id) resumed_from_preemption.append(False) - new_token_ids.append([]) + output_ids = getattr(req, "output_ids", None) or [] + if output_ids: + last_token = output_ids[-1] + new_token_ids.append([last_token]) + else: + new_token_ids.append([]) + resumed_req_token_ids.append([]) sampling_params = transform_sampling_params_to_vllm(req.sampling_params) @@ -192,7 +198,10 @@ def form_vllm_batch_decode( prompt_ids = getattr(req, "input_ids", None) or [] output_ids = getattr(req, "output_ids", None) or [] - computed_token_count = len(prompt_ids) + len(output_ids) + if output_ids: + computed_token_count = len(prompt_ids) + len(output_ids) - 1 + else: + computed_token_count = len(prompt_ids) vllm_req.num_computed_tokens = computed_token_count new_blocks = kv_cache_manager.allocate_slots( From fb728d87d4f259944b14563af448ed80e9ea5e75 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Thu, 6 Nov 2025 18:52:57 +0800 Subject: [PATCH 18/36] update model path --- src/parallax/sglang/model_runner.py | 626 ++++++++++++++-------------- src/parallax/vllm/model_runner.py | 12 +- 2 files changed, 322 insertions(+), 316 deletions(-) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index aa2c1029..de997f9e 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -1,313 +1,313 @@ -""" -Imports sglang ModelRunner related modules and wrap them into create functions. -We use monkey patch to modify sglang originated methods. The main purpose is to pass -arguments needed by decentralized inference. 
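A compact restatement of the decode accounting introduced in the batch_info.py change above: the most recent sampled token is handed to vLLM as this step's new token, so it is excluded from the computed-token count (field names are illustrative):

def decode_step_accounting(prompt_ids, output_ids):
    """Return (new_token_ids, num_computed_tokens) for one decode step."""
    if output_ids:
        # The last sampled token still needs its KV entry written this step.
        new_token_ids = [output_ids[-1]]
        num_computed_tokens = len(prompt_ids) + len(output_ids) - 1
    else:
        new_token_ids = []
        num_computed_tokens = len(prompt_ids)
    return new_token_ids, num_computed_tokens

# e.g. a 5-token prompt with 3 tokens generated so far:
# decode_step_accounting([1, 2, 3, 4, 5], [9, 8, 7]) == ([7], 7)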
-""" - -import logging -import os -import random - -import sglang -import sglang.srt.distributed.parallel_state -import torch -from mlx_lm.utils import load_config -from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.distributed import ( - get_tp_group, - get_world_group, - init_distributed_environment, - set_custom_all_reduce, - set_mscclpp_all_reduce, -) -from sglang.srt.layers.dp_attention import ( - get_attention_tp_group, - initialize_dp_attention, -) -from sglang.srt.layers.moe import initialize_moe_config -from sglang.srt.model_executor.model_runner import ModelRunner as SGLModelRunner -from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import ( - cpu_has_amx_support, - get_available_gpu_memory, - get_bool_env_var, - monkey_patch_p2p_access_check, -) - -from parallax.sglang.monkey_patch import apply_parallax_sglang_monkey_patch -from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( - set_layer_range_for_filtering, -) -from parallax.utils.tokenizer_utils import load_tokenizer - -logger = logging.getLogger(__name__) - -_is_cpu_amx_available = cpu_has_amx_support() - - -class ParallaxModelRunner(SGLModelRunner): - """ - Parallax ModelRunner module. - pp_start_layer and pp_end_layer are passed to initialize states of distribution. - """ - - def __init__( - self, - model_config: ModelConfig, - mem_fraction_static: float, - gpu_id: int, - tp_rank: int, - tp_size: int, - moe_ep_rank: int, - moe_ep_size: int, - pp_rank: int, - pp_size: int, - nccl_port: int, - server_args: ServerArgs, - pp_start_layer: int, - pp_end_layer: int, - ): - """Add pp_start_layer and pp_end_layer for decentralized model""" - self.pp_start_layer = pp_start_layer - self.pp_end_layer = pp_end_layer - num_hidden_layers = model_config.hf_config.num_hidden_layers - set_layer_range_for_filtering(pp_start_layer, pp_end_layer, num_hidden_layers) - - super().__init__( - model_config=model_config, - mem_fraction_static=mem_fraction_static, - gpu_id=gpu_id, - tp_rank=tp_rank, - tp_size=tp_size, - pp_rank=pp_rank, - pp_size=pp_size, - moe_ep_rank=moe_ep_rank, - moe_ep_size=moe_ep_size, - nccl_port=nccl_port, - server_args=server_args, - ) - - def init_torch_distributed(self): - """ - Modifies init_torch_distributed in sglang. - The only difference is to replace initialize_model_parallel. 
- """ - logger.info("Init torch distributed begin.") - - try: - torch.get_device_module(self.device).set_device(self.gpu_id) - except Exception: - logger.warning( - f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} \ - {self.tp_rank=} {self.tp_size=}" - ) - raise - - if self.device == "cuda": - backend = "nccl" - elif self.device == "xpu": - backend = "xccl" - elif self.device == "hpu": - backend = "hccl" - elif self.device == "cpu": - backend = "gloo" - elif self.device == "npu": - backend = "hccl" - - before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) - if not self.server_args.enable_p2p_check: - monkey_patch_p2p_access_check() - - if self.server_args.dist_init_addr: - dist_init_method = f"tcp://{self.server_args.dist_init_addr}" - else: - dist_init_method = f"tcp://127.0.0.1:{self.dist_port}" - set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) - set_mscclpp_all_reduce(self.server_args.enable_mscclpp) - - if not self.is_draft_worker: - if self.device == "cpu": - if _is_cpu_amx_available: - # Bind OpenMP threads to CPU cores - torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid) - - # Set local size to hint SGLang to use shared memory based AllReduce - os.environ["LOCAL_SIZE"] = str(self.tp_size) - torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank) - else: - logger.warning( - "init_cpu_threads_env and shared memory based AllReduce is disabled \ - since intel amx backend is not available" - ) - - # Only initialize the distributed environment on the target model worker. - init_distributed_environment( - backend=backend, - world_size=self.tp_size * self.pp_size, - rank=self.tp_size * self.pp_rank + self.tp_rank, - local_rank=self.gpu_id, - distributed_init_method=dist_init_method, - timeout=self.server_args.dist_timeout, - ) - - # Use monkey patch modified function - sglang.srt.distributed.parallel_state.initialize_model_parallel( - tensor_model_parallel_size=self.tp_size, - pipeline_model_parallel_size=self.pp_size, - expert_model_parallel_size=self.moe_ep_size, - duplicate_tp_group=self.server_args.enable_pdmux, - pp_start_layer=self.pp_start_layer, - pp_end_layer=self.pp_end_layer, - hidden_layers=self.model_config.num_hidden_layers, - ) - - initialize_dp_attention( - self.server_args, - self.model_config, - ) - - min_per_gpu_memory = get_available_gpu_memory( - self.device, - self.gpu_id, - distributed=get_world_group().world_size > 1, - cpu_group=get_world_group().cpu_group, - ) - self.tp_group = get_tp_group() - self.attention_tp_group = get_attention_tp_group() - - # Check memory for tensor parallelism - local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id) - if self.tp_size > 1 and not self.is_draft_worker: - if min_per_gpu_memory < local_gpu_memory * 0.9: - if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"): - logger.warning( - "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. " - f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}" - ) - else: - raise ValueError( - "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. " - f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}" - ) - - logger.info( - f"Init torch distributed ends. 
mem usage={(before_avail_memory - local_gpu_memory):.2f} GB" - ) - - # This is a hack for initializing CudaGraphRunner - self.server_args.pp_size = 2 - - return min_per_gpu_memory - - -def form_sgl_server_args( - model_path: str, - dtype: str = "bfloat16", - attention_backend: str = "flashinfer", - kv_block_size: int = 64, - moe_runner_backend="auto", -): - """Creates a SGL ServerArgs object""" - sgl_server_args = ServerArgs( - model_path=model_path, - dtype=dtype, - attention_backend=attention_backend, - page_size=kv_block_size, - mem_fraction_static=0.85, - moe_runner_backend=moe_runner_backend, - ) - return sgl_server_args - - -def initialize_sgl_model_runner( - model_repo: str, - start_layer: int, - end_layer: int, - kv_cache_memory_fraction: float, - attention_backend: str, - kv_block_size: int, - moe_runner_backend: str, - max_num_tokens_per_batch: int = 1024, - **kwargs, -): - """ - Creates a SGL ModelRunner object. - Returns: - - model_runner: SGL model runner - - config: model config driven by mlx-lm - - tokenizer: tokenizer driven by mlx-lm - """ - apply_parallax_sglang_monkey_patch() - - # Use selective download for GPU models to save bandwidth and disk space - from parallax.utils.selective_download import get_model_path_with_selective_download - - logger.info( - f"Downloading model with selective weight files for layers [{start_layer}, {end_layer})" - ) - model_path = get_model_path_with_selective_download( - original_model_path, - start_layer=start_layer, - end_layer=end_layer, - ) - - config = load_config(model_path) - tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) - dtype = config.get("torch_dtype", "bfloat16") - nccl_port = random.randint(4000, 5000) - - # Handling mxfp4 arguments - quant_method = config.get("quant_method", None) - quantization_config = config.get("quantization_config", None) - if quant_method is None and quantization_config is not None: - quant_method = quantization_config.get("quant_method", None) - if quant_method == "mxfp4": - attention_backend = "triton" - moe_runner_backend = "triton_kernel" - - architectures = config.get("architectures", []) - if architectures and any("Qwen3Next" in arch for arch in architectures): - logger.debug(f"Qwen3-Next model detected, setting kv_block_size to 1") - kv_block_size = 1 - - server_args = form_sgl_server_args( - str(model_path), - dtype, - attention_backend, - kv_block_size, - moe_runner_backend, - ) - initialize_moe_config(server_args) - quant_method = None - if (quantization_config := config.get("quantization_config", None)) is not None: - quant_method = quantization_config.get("quant_method") - model_config = ModelConfig( - model_path=str(model_path), - model_override_args="{}", - dtype=dtype, - quantization=quant_method, - ) - # TODO: Fix me - model_config.hf_config.tie_word_embeddings = False - model_config.hf_config.start_layer = start_layer - model_config.hf_config.end_layer = end_layer - - logger.debug(f"model_start_layer: {model_config.hf_config.start_layer}") - logger.debug(f"model_end_layer: {model_config.hf_config.end_layer}") - - model_runner = ParallaxModelRunner( - model_config=model_config, - mem_fraction_static=kv_cache_memory_fraction, - gpu_id=0, - tp_rank=0, - tp_size=1, - pp_rank=0, - pp_size=1, - moe_ep_rank=0, - moe_ep_size=1, - nccl_port=nccl_port, - server_args=server_args, - pp_start_layer=start_layer, - pp_end_layer=end_layer, - ) - return model_runner, config, tokenizer +""" +Imports sglang ModelRunner related modules and wrap them into create 
functions. +We use monkey patch to modify sglang originated methods. The main purpose is to pass +arguments needed by decentralized inference. +""" + +import logging +import os +import random + +import sglang +import sglang.srt.distributed.parallel_state +import torch +from mlx_lm.utils import load_config +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed import ( + get_tp_group, + get_world_group, + init_distributed_environment, + set_custom_all_reduce, + set_mscclpp_all_reduce, +) +from sglang.srt.layers.dp_attention import ( + get_attention_tp_group, + initialize_dp_attention, +) +from sglang.srt.layers.moe import initialize_moe_config +from sglang.srt.model_executor.model_runner import ModelRunner as SGLModelRunner +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import ( + cpu_has_amx_support, + get_available_gpu_memory, + get_bool_env_var, + monkey_patch_p2p_access_check, +) + +from parallax.sglang.monkey_patch import apply_parallax_sglang_monkey_patch +from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( + set_layer_range_for_filtering, +) +from parallax.utils.tokenizer_utils import load_tokenizer + +logger = logging.getLogger(__name__) + +_is_cpu_amx_available = cpu_has_amx_support() + + +class ParallaxModelRunner(SGLModelRunner): + """ + Parallax ModelRunner module. + pp_start_layer and pp_end_layer are passed to initialize states of distribution. + """ + + def __init__( + self, + model_config: ModelConfig, + mem_fraction_static: float, + gpu_id: int, + tp_rank: int, + tp_size: int, + moe_ep_rank: int, + moe_ep_size: int, + pp_rank: int, + pp_size: int, + nccl_port: int, + server_args: ServerArgs, + pp_start_layer: int, + pp_end_layer: int, + ): + """Add pp_start_layer and pp_end_layer for decentralized model""" + self.pp_start_layer = pp_start_layer + self.pp_end_layer = pp_end_layer + num_hidden_layers = model_config.hf_config.num_hidden_layers + set_layer_range_for_filtering(pp_start_layer, pp_end_layer, num_hidden_layers) + + super().__init__( + model_config=model_config, + mem_fraction_static=mem_fraction_static, + gpu_id=gpu_id, + tp_rank=tp_rank, + tp_size=tp_size, + pp_rank=pp_rank, + pp_size=pp_size, + moe_ep_rank=moe_ep_rank, + moe_ep_size=moe_ep_size, + nccl_port=nccl_port, + server_args=server_args, + ) + + def init_torch_distributed(self): + """ + Modifies init_torch_distributed in sglang. + The only difference is to replace initialize_model_parallel. 
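The distributed init inside init_torch_distributed below flattens the (pp_rank, tp_rank) grid into a single global rank; a small sketch of that layout (pure arithmetic, no sglang imports):

def global_rank(tp_rank: int, tp_size: int, pp_rank: int, pp_size: int) -> int:
    """Row-major layout used below: all TP ranks of a PP stage are contiguous."""
    assert 0 <= tp_rank < tp_size and 0 <= pp_rank < pp_size
    return tp_size * pp_rank + tp_rank

# world_size = tp_size * pp_size; e.g. tp_size=2, pp_size=3 gives ranks 0..5,
# and (pp_rank=1, tp_rank=1) maps to global rank 3.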
+ """ + logger.info("Init torch distributed begin.") + + try: + torch.get_device_module(self.device).set_device(self.gpu_id) + except Exception: + logger.warning( + f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} \ + {self.tp_rank=} {self.tp_size=}" + ) + raise + + if self.device == "cuda": + backend = "nccl" + elif self.device == "xpu": + backend = "xccl" + elif self.device == "hpu": + backend = "hccl" + elif self.device == "cpu": + backend = "gloo" + elif self.device == "npu": + backend = "hccl" + + before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) + if not self.server_args.enable_p2p_check: + monkey_patch_p2p_access_check() + + if self.server_args.dist_init_addr: + dist_init_method = f"tcp://{self.server_args.dist_init_addr}" + else: + dist_init_method = f"tcp://127.0.0.1:{self.dist_port}" + set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) + set_mscclpp_all_reduce(self.server_args.enable_mscclpp) + + if not self.is_draft_worker: + if self.device == "cpu": + if _is_cpu_amx_available: + # Bind OpenMP threads to CPU cores + torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid) + + # Set local size to hint SGLang to use shared memory based AllReduce + os.environ["LOCAL_SIZE"] = str(self.tp_size) + torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank) + else: + logger.warning( + "init_cpu_threads_env and shared memory based AllReduce is disabled \ + since intel amx backend is not available" + ) + + # Only initialize the distributed environment on the target model worker. + init_distributed_environment( + backend=backend, + world_size=self.tp_size * self.pp_size, + rank=self.tp_size * self.pp_rank + self.tp_rank, + local_rank=self.gpu_id, + distributed_init_method=dist_init_method, + timeout=self.server_args.dist_timeout, + ) + + # Use monkey patch modified function + sglang.srt.distributed.parallel_state.initialize_model_parallel( + tensor_model_parallel_size=self.tp_size, + pipeline_model_parallel_size=self.pp_size, + expert_model_parallel_size=self.moe_ep_size, + duplicate_tp_group=self.server_args.enable_pdmux, + pp_start_layer=self.pp_start_layer, + pp_end_layer=self.pp_end_layer, + hidden_layers=self.model_config.num_hidden_layers, + ) + + initialize_dp_attention( + self.server_args, + self.model_config, + ) + + min_per_gpu_memory = get_available_gpu_memory( + self.device, + self.gpu_id, + distributed=get_world_group().world_size > 1, + cpu_group=get_world_group().cpu_group, + ) + self.tp_group = get_tp_group() + self.attention_tp_group = get_attention_tp_group() + + # Check memory for tensor parallelism + local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id) + if self.tp_size > 1 and not self.is_draft_worker: + if min_per_gpu_memory < local_gpu_memory * 0.9: + if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"): + logger.warning( + "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. " + f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}" + ) + else: + raise ValueError( + "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. " + f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}" + ) + + logger.info( + f"Init torch distributed ends. 
mem usage={(before_avail_memory - local_gpu_memory):.2f} GB" + ) + + # This is a hack for initializing CudaGraphRunner + self.server_args.pp_size = 2 + + return min_per_gpu_memory + + +def form_sgl_server_args( + model_path: str, + dtype: str = "bfloat16", + attention_backend: str = "flashinfer", + kv_block_size: int = 64, + moe_runner_backend="auto", +): + """Creates a SGL ServerArgs object""" + sgl_server_args = ServerArgs( + model_path=model_path, + dtype=dtype, + attention_backend=attention_backend, + page_size=kv_block_size, + mem_fraction_static=0.85, + moe_runner_backend=moe_runner_backend, + ) + return sgl_server_args + + +def initialize_sgl_model_runner( + model_repo: str, + start_layer: int, + end_layer: int, + kv_cache_memory_fraction: float, + attention_backend: str, + kv_block_size: int, + moe_runner_backend: str, + max_num_tokens_per_batch: int = 1024, + **kwargs, +): + """ + Creates a SGL ModelRunner object. + Returns: + - model_runner: SGL model runner + - config: model config driven by mlx-lm + - tokenizer: tokenizer driven by mlx-lm + """ + apply_parallax_sglang_monkey_patch() + + # Use selective download for GPU models to save bandwidth and disk space + from parallax.utils.selective_download import get_model_path_with_selective_download + + logger.info( + f"Downloading model with selective weight files for layers [{start_layer}, {end_layer})" + ) + model_path = get_model_path_with_selective_download( + model_repo, + start_layer=start_layer, + end_layer=end_layer, + ) + + config = load_config(model_path) + tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) + dtype = config.get("torch_dtype", "bfloat16") + nccl_port = random.randint(4000, 5000) + + # Handling mxfp4 arguments + quant_method = config.get("quant_method", None) + quantization_config = config.get("quantization_config", None) + if quant_method is None and quantization_config is not None: + quant_method = quantization_config.get("quant_method", None) + if quant_method == "mxfp4": + attention_backend = "triton" + moe_runner_backend = "triton_kernel" + + architectures = config.get("architectures", []) + if architectures and any("Qwen3Next" in arch for arch in architectures): + logger.debug(f"Qwen3-Next model detected, setting kv_block_size to 1") + kv_block_size = 1 + + server_args = form_sgl_server_args( + str(model_path), + dtype, + attention_backend, + kv_block_size, + moe_runner_backend, + ) + initialize_moe_config(server_args) + quant_method = None + if (quantization_config := config.get("quantization_config", None)) is not None: + quant_method = quantization_config.get("quant_method") + model_config = ModelConfig( + model_path=str(model_path), + model_override_args="{}", + dtype=dtype, + quantization=quant_method, + ) + # TODO: Fix me + model_config.hf_config.tie_word_embeddings = False + model_config.hf_config.start_layer = start_layer + model_config.hf_config.end_layer = end_layer + + logger.debug(f"model_start_layer: {model_config.hf_config.start_layer}") + logger.debug(f"model_end_layer: {model_config.hf_config.end_layer}") + + model_runner = ParallaxModelRunner( + model_config=model_config, + mem_fraction_static=kv_cache_memory_fraction, + gpu_id=0, + tp_rank=0, + tp_size=1, + pp_rank=0, + pp_size=1, + moe_ep_rank=0, + moe_ep_size=1, + nccl_port=nccl_port, + server_args=server_args, + pp_start_layer=start_layer, + pp_end_layer=end_layer, + ) + return model_runner, config, tokenizer diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 
b3eac185..00594d4c 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -283,11 +283,17 @@ def initialize_vllm_model_runner( dtype: str = "float16", **kwargs, ) -> Tuple[ParallaxVLLMModelRunner, Dict, Any]: + from parallax.utils.selective_download import get_model_path_with_selective_download logger.info( f"Initializing vLLM model runner for {model_repo}, " f"layers=[{start_layer}, {end_layer})" ) - model_path = get_model_path(model_repo)[0] + model_path = get_model_path_with_selective_download( + model_repo, + start_layer=start_layer, + end_layer=end_layer, + ) + config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") @@ -332,8 +338,8 @@ def initialize_vllm_model_runner( ) model_config = ModelConfig( - model=model_repo, - tokenizer=model_repo, + model=model_path, + tokenizer=model_path, tokenizer_mode="auto", trust_remote_code=True, dtype=dtype, From 717f69ff631044536ec8551397b6362ff8032e82 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Thu, 6 Nov 2025 19:56:27 +0800 Subject: [PATCH 19/36] test PP with mac --- src/parallax/server/executor.py | 122 ++++++++++++------ .../weight_loader_filter.py | 4 +- src/parallax/vllm/model_runner.py | 95 ++++++++++++-- 3 files changed, 173 insertions(+), 48 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index d7e00f31..9f707d81 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -407,26 +407,45 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s # Prepare PP proxy tensors (common for both backends when not first peer) pp_proxy_tensors = None if not self.is_first_peer: - hidden_states = torch.cat( - [ - ( - req.hidden_states - if req.hidden_states.ndim == 2 - else req.hidden_states.unsqueeze(0) - ) - for req in batched_requests - ], - dim=0, - ) + # Concatenate hidden states from all requests + # For vLLM, we need to flatten to (total_tokens, hidden_size) + # For SGLang, we keep the batch dimension + hidden_states_list = [] + for req in batched_requests: + hs = req.hidden_states + if hs.ndim == 2: + # Already (seq_len, hidden_size) or (1, hidden_size) + hidden_states_list.append(hs) + elif hs.ndim == 3: + # (1, seq_len, hidden_size) -> (seq_len, hidden_size) + hidden_states_list.append(hs.squeeze(0)) + else: + # (hidden_size,) -> (1, hidden_size) + hidden_states_list.append(hs.unsqueeze(0)) + + # Concatenate along sequence dimension to get (total_tokens, hidden_size) + hidden_states = torch.cat(hidden_states_list, dim=0) + + # Create residual tensor with same shape residual = torch.zeros( hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device ) - pp_proxy_tensors = PPProxyTensors( - { + + if self.backend_type == "vllm": + # For vLLM, pass directly as IntermediateTensors + from vllm.sequence import IntermediateTensors + pp_proxy_tensors = IntermediateTensors({ "hidden_states": hidden_states, "residual": residual, - } - ) + }) + else: + # For SGLang, use PPProxyTensors + pp_proxy_tensors = PPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) logger.debug(f"PP Proxy: hidden_states shape: {hidden_states.shape}") # Prepare lengths (common for both backends) @@ -479,26 +498,45 @@ def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[st # Prepare PP proxy tensors (common for both backends when not first peer) pp_proxy_tensors = None 
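The shape handling in the executor hunks below boils down to flattening every request's hidden states to (tokens, hidden_size) before concatenation and pairing them with a zero residual. A standalone sketch with plain tensors (the IntermediateTensors constructor usage mirrors the code in this patch; the helper name is illustrative):

import torch
from vllm.sequence import IntermediateTensors

def pack_hidden_states(per_request_hidden):
    """Flatten a list of per-request hidden states to (total_tokens, hidden_size)."""
    flat = []
    for hs in per_request_hidden:
        if hs.ndim == 3:        # (1, seq_len, hidden) -> (seq_len, hidden)
            hs = hs.squeeze(0)
        elif hs.ndim == 1:      # (hidden,) -> (1, hidden)
            hs = hs.unsqueeze(0)
        flat.append(hs)         # ndim == 2 is already (tokens, hidden)
    hidden_states = torch.cat(flat, dim=0)
    residual = torch.zeros_like(hidden_states)
    return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})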
if not self.is_first_peer: - hidden_states = torch.cat( - [ - ( - req.hidden_states - if req.hidden_states.ndim == 2 - else req.hidden_states.unsqueeze(0) - ) - for req in batched_requests - ], - dim=0, - ) + # Concatenate hidden states from all requests + # For vLLM, we need to flatten to (total_tokens, hidden_size) + # For SGLang, we keep the batch dimension + hidden_states_list = [] + for req in batched_requests: + hs = req.hidden_states + if hs.ndim == 2: + # Already (seq_len, hidden_size) or (1, hidden_size) + hidden_states_list.append(hs) + elif hs.ndim == 3: + # (1, seq_len, hidden_size) -> (seq_len, hidden_size) + hidden_states_list.append(hs.squeeze(0)) + else: + # (hidden_size,) -> (1, hidden_size) + hidden_states_list.append(hs.unsqueeze(0)) + + # Concatenate along sequence dimension to get (total_tokens, hidden_size) + hidden_states = torch.cat(hidden_states_list, dim=0) + + # Create residual tensor with same shape residual = torch.zeros( hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device ) - pp_proxy_tensors = PPProxyTensors( - { + + if self.backend_type == "vllm": + # For vLLM, pass directly as IntermediateTensors + from vllm.sequence import IntermediateTensors + pp_proxy_tensors = IntermediateTensors({ "hidden_states": hidden_states, "residual": residual, - } - ) + }) + else: + # For SGLang, use PPProxyTensors + pp_proxy_tensors = PPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) logger.debug(f"PP Proxy: hidden_states shape: {hidden_states.shape}") # Prepare lengths (common for both backends) @@ -1092,13 +1130,13 @@ def _process_batch_cuda( ), "pp_proxy_tensors should be in cuda prepared inputs" scheduler_output = prepared_inputs["scheduler_output"] pp_proxy_tensors = prepared_inputs["pp_proxy_tensors"] - intermediate_tensors = None - if pp_proxy_tensors is not None: - # Convert SGLang's PPProxyTensors to vLLM's IntermediateTensors - from vllm.sequence import IntermediateTensors - - intermediate_tensors = IntermediateTensors(pp_proxy_tensors.tensors) + # For vLLM, pp_proxy_tensors is already an IntermediateTensors object + intermediate_tensors = pp_proxy_tensors if pp_proxy_tensors is not None else None + if intermediate_tensors is not None: logger.debug(f"vLLM: Using intermediate_tensors for PP (non-first peer)") + + # Import IntermediateTensors for type checking + from vllm.sequence import IntermediateTensors # Execute model with vLLM output = self.model_runner.execute_model( @@ -1125,14 +1163,22 @@ def _process_batch_cuda( return torch.tensor(sampled_token_ids, dtype=torch.int64) else: # Intermediate peer: return hidden states for next peer - if hasattr(output, "hidden_states") and output.hidden_states is not None: + # vLLM with Parallax PP should return IntermediateTensors + if isinstance(output, IntermediateTensors): + # Got IntermediateTensors from monkey-patched forward + if "hidden_states" in output.tensors: + return output.tensors["hidden_states"] + else: + # Return the full IntermediateTensors (might be just hidden_states tensor) + return output + elif hasattr(output, "hidden_states") and output.hidden_states is not None: return output.hidden_states elif hasattr(output, "tensors") and "hidden_states" in output.tensors: - # Handle IntermediateTensors case return output.tensors["hidden_states"] else: raise RuntimeError( "vLLM backend: expected hidden_states in output for PP, but got None. " + f"Output type: {type(output)}, is_last_peer={self.is_last_peer}. 
" "This typically means the model runner is not configured for pipeline parallelism." ) diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index 029db284..7bd48082 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -39,8 +39,8 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: filtered_files = filter_weight_files_by_layer_range_for_load( model_path=model_path, weight_files=hf_weights_files, - pp_start_layer=pp_start_layer, - pp_end_layer=pp_end_layer, + start_layer=pp_start_layer, + end_layer=pp_end_layer, is_first_shard=is_first_shard, is_last_shard=is_last_shard, ) diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 00594d4c..e260cca5 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -2,7 +2,7 @@ from __future__ import annotations import importlib -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import vllm import torch from transformers import AutoConfig, AutoTokenizer @@ -16,6 +16,7 @@ SchedulerConfig, VllmConfig, ) +from vllm.distributed.parallel_state import GroupCoordinator as VLLMGroupCoordinator from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_utils import ( generate_scheduler_kv_cache_config, @@ -25,6 +26,7 @@ ) from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheGroupSpec, KVCacheTensor from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.sequence import IntermediateTensors from parallax.utils.tokenizer_utils import load_tokenizer from parallax_utils.logging_config import get_logger @@ -33,6 +35,47 @@ logger = get_logger(__name__) +class ParallaxVLLMGroupCoordinator(VLLMGroupCoordinator): + """ + Parallax version of vLLM's GroupCoordinator. + Override is_first_rank and is_last_rank to use layer ranges instead of process ranks. 
+ """ + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, torch.distributed.Backend], + use_device_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + pp_start_layer: int = 0, + pp_end_layer: int = 0, + num_hidden_layers: int = 0, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + use_device_communicator=use_device_communicator, + use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, + ) + self.pp_start_layer = pp_start_layer + self.pp_end_layer = pp_end_layer + self.num_hidden_layers = num_hidden_layers + + @property + def is_first_rank(self) -> bool: + """Return whether this is the first pipeline stage based on layer range.""" + return self.pp_start_layer == 0 + + @property + def is_last_rank(self) -> bool: + """Return whether this is the last pipeline stage based on layer range.""" + return self.pp_end_layer >= self.num_hidden_layers + + def _create_kv_cache_config_from_specs( kv_cache_group: KVCacheGroupSpec, attn_layers: List[str], @@ -301,14 +344,17 @@ def initialize_vllm_model_runner( num_hidden_layers = getattr(config, "num_hidden_layers", 28) is_first_peer = start_layer == 0 is_last_peer = end_layer == num_hidden_layers - virtual_pp_size = 2 if not (is_first_peer and is_last_peer) else 1 + + # For single process, always use pp_size=1 + virtual_pp_size = 1 import vllm.distributed.parallel_state as parallel_state import os if not parallel_state.model_parallel_is_initialized(): - logger.info("Initializing vLLM distributed environment...") + logger.info(f"Initializing vLLM distributed environment...") + # Set environment variables for distributed initialization if "RANK" not in os.environ: os.environ["RANK"] = "0" if "WORLD_SIZE" not in os.environ: @@ -322,14 +368,47 @@ def initialize_vllm_model_runner( try: parallel_state.init_distributed_environment() + + # Initialize with pp_size=1 for single process parallel_state.initialize_model_parallel( tensor_model_parallel_size=1, - pipeline_model_parallel_size=virtual_pp_size, + pipeline_model_parallel_size=1, ) - logger.info(f"vLLM distributed environment initialized with pp_size={virtual_pp_size}") + + # Monkey patch the PP group with our custom Parallax coordinator + # that uses layer ranges to determine is_first_rank/is_last_rank + original_pp_group = parallel_state._PP + if original_pp_group is not None: + # Get backend from device_group (torch is already imported at module level) + import torch.distributed + backend = torch.distributed.get_backend(original_pp_group.device_group) + + # Create a Parallax PP group coordinator + # Need to wrap ranks in a list of lists for group_ranks parameter + parallax_pp_group = ParallaxVLLMGroupCoordinator( + group_ranks=[original_pp_group.ranks], + local_rank=original_pp_group.local_rank, + torch_distributed_backend=backend, + use_device_communicator=original_pp_group.use_device_communicator, + use_message_queue_broadcaster=(original_pp_group.mq_broadcaster is not None), + group_name="pp", + pp_start_layer=start_layer, + pp_end_layer=end_layer, + num_hidden_layers=num_hidden_layers, + ) + # Replace the PP group + parallel_state._PP = parallax_pp_group + logger.info( + f"Replaced vLLM PP group with Parallax coordinator: " + f"is_first_rank={parallax_pp_group.is_first_rank}, " + f"is_last_rank={parallax_pp_group.is_last_rank}" + ) + + logger.info(f"vLLM distributed 
environment initialized") except Exception as e: logger.warning(f"Failed to initialize distributed environment: {e}") - logger.info("Continuing without distributed initialization...") + logger.error(f"vLLM distributed initialization failed. Error: {e}") + raise if end_layer > num_hidden_layers: raise ValueError( @@ -338,8 +417,8 @@ def initialize_vllm_model_runner( ) model_config = ModelConfig( - model=model_path, - tokenizer=model_path, + model=str(model_path), + tokenizer=str(model_path), tokenizer_mode="auto", trust_remote_code=True, dtype=dtype, From 9bfa201651d640defeecf7d4ffe0a1392e74fd49 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 11 Nov 2025 20:38:18 +0800 Subject: [PATCH 20/36] test pass --- src/parallax/server/executor.py | 80 ++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 27 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 9f707d81..6f100c23 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -422,22 +422,25 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s else: # (hidden_size,) -> (1, hidden_size) hidden_states_list.append(hs.unsqueeze(0)) - + # Concatenate along sequence dimension to get (total_tokens, hidden_size) hidden_states = torch.cat(hidden_states_list, dim=0) - + # Create residual tensor with same shape residual = torch.zeros( hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device ) - + if self.backend_type == "vllm": # For vLLM, pass directly as IntermediateTensors from vllm.sequence import IntermediateTensors - pp_proxy_tensors = IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual, - }) + + pp_proxy_tensors = IntermediateTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) else: # For SGLang, use PPProxyTensors pp_proxy_tensors = PPProxyTensors( @@ -513,22 +516,25 @@ def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[st else: # (hidden_size,) -> (1, hidden_size) hidden_states_list.append(hs.unsqueeze(0)) - + # Concatenate along sequence dimension to get (total_tokens, hidden_size) hidden_states = torch.cat(hidden_states_list, dim=0) - + # Create residual tensor with same shape residual = torch.zeros( hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device ) - + if self.backend_type == "vllm": # For vLLM, pass directly as IntermediateTensors from vllm.sequence import IntermediateTensors - pp_proxy_tensors = IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual, - }) + + pp_proxy_tensors = IntermediateTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) else: # For SGLang, use PPProxyTensors pp_proxy_tensors = PPProxyTensors( @@ -1134,7 +1140,7 @@ def _process_batch_cuda( intermediate_tensors = pp_proxy_tensors if pp_proxy_tensors is not None else None if intermediate_tensors is not None: logger.debug(f"vLLM: Using intermediate_tensors for PP (non-first peer)") - + # Import IntermediateTensors for type checking from vllm.sequence import IntermediateTensors @@ -1164,23 +1170,42 @@ def _process_batch_cuda( else: # Intermediate peer: return hidden states for next peer # vLLM with Parallax PP should return IntermediateTensors + def _merge_hidden_and_residual(hidden_tensor, residual_tensor): + if hidden_tensor is None: + return None + if residual_tensor is not None: + # vLLM separates residual connections; downstream peers expect the merged tensor. 
+ hidden_tensor = hidden_tensor + residual_tensor + return hidden_tensor + if isinstance(output, IntermediateTensors): - # Got IntermediateTensors from monkey-patched forward - if "hidden_states" in output.tensors: - return output.tensors["hidden_states"] - else: - # Return the full IntermediateTensors (might be just hidden_states tensor) + tensors = output.tensors + merged = _merge_hidden_and_residual( + tensors.get("hidden_states"), tensors.get("residual") + ) + if merged is not None: + return merged + # Return full object if hidden states are packed under a different key + if tensors: return output elif hasattr(output, "hidden_states") and output.hidden_states is not None: - return output.hidden_states + residual = getattr(output, "residual", None) + merged = _merge_hidden_and_residual(output.hidden_states, residual) + if merged is not None: + return merged elif hasattr(output, "tensors") and "hidden_states" in output.tensors: - return output.tensors["hidden_states"] - else: - raise RuntimeError( - "vLLM backend: expected hidden_states in output for PP, but got None. " - f"Output type: {type(output)}, is_last_peer={self.is_last_peer}. " - "This typically means the model runner is not configured for pipeline parallelism." + tensors = output.tensors + merged = _merge_hidden_and_residual( + tensors.get("hidden_states"), tensors.get("residual") ) + if merged is not None: + return merged + + raise RuntimeError( + "vLLM backend: expected hidden_states in output for PP, but got None. " + f"Output type: {type(output)}, is_last_peer={self.is_last_peer}. " + "This typically means the model runner is not configured for pipeline parallelism." + ) else: # self.backend_type == "sglang" # ========== SGLang Backend ========== @@ -1462,6 +1487,7 @@ def shutdown(self): logger.debug("Executor shutting down...") self._should_stop = True import time + time.sleep(0.1) # Give run_loop a moment to exit gracefully try: From dfbc0c1a6f29caf4728196e196d9af119f5f3025 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 11 Nov 2025 20:49:20 +0800 Subject: [PATCH 21/36] pre-commit --- src/parallax/vllm/batch_info.py | 3 +- src/parallax/vllm/model_runner.py | 75 ++++++++++++++++--------------- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index db0a9bd9..b92bb4f5 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -1,4 +1,3 @@ - from __future__ import annotations from typing import Any, Dict, List, Optional @@ -190,7 +189,7 @@ def form_vllm_batch_decode( new_token_ids.append([last_token]) else: new_token_ids.append([]) - + resumed_req_token_ids.append([]) sampling_params = transform_sampling_params_to_vllm(req.sampling_params) diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index e260cca5..00d50908 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -1,11 +1,10 @@ - from __future__ import annotations import importlib from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import vllm + import torch -from transformers import AutoConfig, AutoTokenizer +from mlx_lm.utils import load_config from vllm.config import ( CacheConfig, CompilationConfig, @@ -26,11 +25,9 @@ ) from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheGroupSpec, KVCacheTensor from vllm.v1.worker.gpu_model_runner import GPUModelRunner -from vllm.sequence import IntermediateTensors -from parallax.utils.tokenizer_utils import load_tokenizer 
+from parallax.utils.tokenizer_utils import load_tokenizer from parallax_utils.logging_config import get_logger -from mlx_lm.utils import get_model_path, load_config logger = get_logger(__name__) @@ -40,7 +37,7 @@ class ParallaxVLLMGroupCoordinator(VLLMGroupCoordinator): Parallax version of vLLM's GroupCoordinator. Override is_first_rank and is_last_rank to use layer ranges instead of process ranks. """ - + def __init__( self, group_ranks: List[List[int]], @@ -64,12 +61,12 @@ def __init__( self.pp_start_layer = pp_start_layer self.pp_end_layer = pp_end_layer self.num_hidden_layers = num_hidden_layers - + @property def is_first_rank(self) -> bool: """Return whether this is the first pipeline stage based on layer range.""" return self.pp_start_layer == 0 - + @property def is_last_rank(self) -> bool: """Return whether this is the last pipeline stage based on layer range.""" @@ -98,7 +95,7 @@ def _create_kv_cache_config_from_specs( num_blocks = max(100, min(1000, int(max_blocks_by_memory * 0.8))) - logger.info(f"Calculated KV cache blocks: {num_blocks} (max possible: {max_blocks_by_memory})") + logger.debug(f"Calculated KV cache blocks: {num_blocks} (max possible: {max_blocks_by_memory})") tensor_size_bytes = page_size_bytes * num_blocks @@ -171,7 +168,7 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC ) available_memory = int(free_memory * memory_fraction) - logger.info( + logger.debug( f"Available GPU memory for KV cache: " f"{available_memory / (1024**3):.2f} GB " f"({memory_fraction:.1%} of {free_memory / (1024**3):.2f} GB)" @@ -185,7 +182,7 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC ) kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) else: - logger.info("Using fallback KV cache configuration") + logger.debug("Using fallback KV cache configuration") model = self.model hf_config = model.model.config @@ -193,7 +190,7 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC hidden_size = getattr(hf_config, "hidden_size", 1024) head_size = hidden_size // num_attention_heads - from vllm.v1.kv_cache_interface import KVCacheGroupSpec, FullAttentionSpec + from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec model_dtype = self.vllm_config.model_config.dtype if isinstance(model_dtype, str): @@ -202,9 +199,7 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC model_dtype = STR_DTYPE_TO_TORCH_DTYPE.get(model_dtype, torch.float16) kv_cache_group = KVCacheGroupSpec( - layer_names=[ - f"model.layers.{i}" for i in range(self.start_layer, self.end_layer) - ], + layer_names=[f"model.layers.{i}" for i in range(self.start_layer, self.end_layer)], kv_cache_spec=FullAttentionSpec( block_size=self.cache_config.block_size, num_kv_heads=num_attention_heads, @@ -221,7 +216,7 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC kv_cache_memory_fraction=memory_fraction, ) - logger.info( + logger.debug( f"KV cache config generated: " f"num_blocks={kv_cache_config.num_blocks}, " f"num_groups={len(kv_cache_config.kv_cache_groups)}" @@ -230,7 +225,7 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC return kv_cache_config def initialize_kv_cache_manager(self, max_model_len: int) -> KVCacheManager: - logger.info("Initializing vLLM KVCacheManager...") + logger.debug("Initializing vLLM KVCacheManager...") if self.kv_cache_config is None: self.kv_cache_config = 
self._create_kv_cache_config() @@ -278,7 +273,7 @@ def simple_hash_fn(obj: Any) -> bytes: self.request_block_hasher = get_request_block_hasher(block_size, hash_fn) logger.info("Initialized prefix cache block hasher with block_size=%d", block_size) - logger.info( + logger.debug( f"KVCacheManager initialized: block_size={kv_cache_manager.block_size}, " f"usage={kv_cache_manager.usage:.2%}" ) @@ -286,7 +281,7 @@ def simple_hash_fn(obj: Any) -> bytes: return kv_cache_manager def load_model(self) -> None: - logger.info(f"Loading vLLM model with layers [{self.start_layer}, {self.end_layer})...") + logger.debug(f"Loading vLLM model with layers [{self.start_layer}, {self.end_layer})...") from vllm.distributed.utils import get_pp_indices @@ -305,14 +300,14 @@ def custom_get_pp_indices(num_layers: int, rank: int, world_size: int): try: super().load_model() - logger.info( + logger.debug( f"Successfully loaded {self.num_shard_layers} layers " f"[{self.start_layer}:{self.end_layer}]" ) finally: vllm.distributed.utils.get_pp_indices = original_get_pp_indices - logger.info("Model loaded successfully with partial layers") + logger.debug("Model loaded successfully with partial layers") def initialize_vllm_model_runner( @@ -327,6 +322,7 @@ def initialize_vllm_model_runner( **kwargs, ) -> Tuple[ParallaxVLLMModelRunner, Dict, Any]: from parallax.utils.selective_download import get_model_path_with_selective_download + logger.info( f"Initializing vLLM model runner for {model_repo}, " f"layers=[{start_layer}, {end_layer})" ) @@ -344,15 +340,16 @@ def initialize_vllm_model_runner( num_hidden_layers = getattr(config, "num_hidden_layers", 28) is_first_peer = start_layer == 0 is_last_peer = end_layer == num_hidden_layers - + # For single process, always use pp_size=1 virtual_pp_size = 1 - import vllm.distributed.parallel_state as parallel_state import os + import vllm.distributed.parallel_state as parallel_state + if not parallel_state.model_parallel_is_initialized(): - logger.info(f"Initializing vLLM distributed environment...") + logger.debug(f"Initializing vLLM distributed environment...") # Set environment variables for distributed initialization if "RANK" not in os.environ: @@ -368,21 +365,22 @@ def initialize_vllm_model_runner( try: parallel_state.init_distributed_environment() - + # Initialize with pp_size=1 for single process parallel_state.initialize_model_parallel( tensor_model_parallel_size=1, pipeline_model_parallel_size=1, ) - + # Monkey patch the PP group with our custom Parallax coordinator # that uses layer ranges to determine is_first_rank/is_last_rank original_pp_group = parallel_state._PP if original_pp_group is not None: # Get backend from device_group (torch is already imported at module level) import torch.distributed + backend = torch.distributed.get_backend(original_pp_group.device_group) - + # Create a Parallax PP group coordinator # Need to wrap ranks in a list of lists for group_ranks parameter parallax_pp_group = ParallaxVLLMGroupCoordinator( @@ -398,13 +396,13 @@ def initialize_vllm_model_runner( ) # Replace the PP group parallel_state._PP = parallax_pp_group - logger.info( + logger.debug( f"Replaced vLLM PP group with Parallax coordinator: " f"is_first_rank={parallax_pp_group.is_first_rank}, " f"is_last_rank={parallax_pp_group.is_last_rank}" ) - - logger.info(f"vLLM distributed environment initialized") + + logger.debug(f"vLLM distributed environment initialized") except Exception as e: logger.warning(f"Failed to initialize distributed environment: {e}") logger.error(f"vLLM distributed 
initialization failed. Error: {e}") @@ -481,7 +479,7 @@ def initialize_vllm_model_runner( model_runner.load_model() logger.info("vLLM model loaded successfully") - logger.info("Letting vLLM automatically generate KV cache configuration...") + logger.debug("Letting vLLM automatically generate KV cache configuration...") kv_cache_specs = model_runner.get_kv_cache_spec() @@ -499,7 +497,10 @@ def initialize_vllm_model_runner( f"({kv_cache_memory_fraction:.1%} of {free_memory / (1024**3):.2f} GB)" ) - from vllm.v1.core.kv_cache_utils import get_kv_cache_configs, generate_scheduler_kv_cache_config + from vllm.v1.core.kv_cache_utils import ( + generate_scheduler_kv_cache_config, + get_kv_cache_configs, + ) kv_cache_configs = get_kv_cache_configs( vllm_config=model_runner.vllm_config, @@ -511,12 +512,12 @@ def initialize_vllm_model_runner( model_runner.kv_cache_config = kv_cache_config - logger.info("Initializing GPUModelRunner KV cache...") + logger.debug("Initializing GPUModelRunner KV cache...") model_runner.initialize_kv_cache(kv_cache_config) - logger.info("GPUModelRunner KV cache initialized successfully") + logger.debug("GPUModelRunner KV cache initialized successfully") - logger.info("Initializing KV Cache Manager...") + logger.debug("Initializing KV Cache Manager...") model_runner.initialize_kv_cache_manager(max_model_len=model_config.max_model_len) - logger.info("KV Cache Manager initialized successfully") + logger.debug("KV Cache Manager initialized successfully") return model_runner, config, tokenizer From d233d34173d91ad979c7a87778bd0de84feda137 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 11 Nov 2025 21:01:50 +0800 Subject: [PATCH 22/36] update --- src/parallax/server/executor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index c7b6eddf..f189f959 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -147,6 +147,9 @@ def __init__( "max_num_tokens_per_batch": max_num_tokens_per_batch, "dtype": dtype, "moe_runner_backend": moe_runner_backend, + "tp_rank": tp_rank, + "tp_size": tp_size, + "nccl_port": nccl_port, } self.model_runner, self.config, self.tokenizer = initialize_cuda_model_runner( From 191f218fc27067c2a098084a63c3c3028a6c146e Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 12 Nov 2025 11:10:02 +0800 Subject: [PATCH 23/36] update log and pyproject --- pyproject.toml | 7 ++++++- src/parallax/vllm/model_runner.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9cd69865..b2a58f3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,10 +46,15 @@ mac = [ ] gpu = [ + "sglang[all]==0.5.4.post1", + "mlx-lm==0.28.0", + "mlx[cpu]==0.29.1", +] + +vllm = [ "vllm==0.11.0", "mlx-lm==0.28.0", "mlx[cpu]==0.29.1", - "sglang[all]==0.5.4.post1", ] benchmark = [ diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 00d50908..0cb3114a 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -147,7 +147,7 @@ def __init__( ) def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVCacheConfig: - logger.info("Generating KV cache configuration from model...") + logger.debug("Generating KV cache configuration from model...") try: kv_cache_specs = self.model.get_kv_cache_spec() From 3c67c0f693ef9df5d5f8f9514357f6c2afa5a7ba Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Wed, 12 Nov 2025 13:02:12 +0800 Subject: [PATCH 24/36] add weight 
load fiter --- src/parallax/server/executor.py | 4 +- src/parallax/vllm/model_runner.py | 34 +++++++- src/parallax/vllm/monkey_patch.py | 27 +++++++ .../vllm/monkey_patch_utils/weight_loader.py | 79 +++++++++++++++++++ 4 files changed, 141 insertions(+), 3 deletions(-) create mode 100644 src/parallax/vllm/monkey_patch.py create mode 100644 src/parallax/vllm/monkey_patch_utils/weight_loader.py diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 7ff32500..87b908cf 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -1578,6 +1578,7 @@ def shutdown(self): def run_executor_process(args, gradient_server=None): """Run executor as a subprocess""" + executor = None try: executor = Executor.create_from_args(args, gradient_server) executor.run_loop() @@ -1586,7 +1587,8 @@ def run_executor_process(args, gradient_server=None): except Exception as e: logger.exception(e) finally: - executor.shutdown() + if executor is not None: + executor.shutdown() def stop_executor_process(executor_process): diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 0cb3114a..ddfe9b4e 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -28,6 +28,11 @@ from parallax.utils.tokenizer_utils import load_tokenizer from parallax_utils.logging_config import get_logger +from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( + apply_weight_loader_filter_patch, + set_layer_range_for_filtering, +) +from parallax.vllm.monkey_patch import apply_parallax_vllm_monkey_patch logger = get_logger(__name__) @@ -194,8 +199,12 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC model_dtype = self.vllm_config.model_config.dtype if isinstance(model_dtype, str): - from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE - + try: + from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE # type: ignore + except Exception: + # Older/newer vLLM versions may not expose torch_utils. + # Fall back silently and default to float16. + STR_DTYPE_TO_TORCH_DTYPE = {} model_dtype = STR_DTYPE_TO_TORCH_DTYPE.get(model_dtype, torch.float16) kv_cache_group = KVCacheGroupSpec( @@ -340,6 +349,27 @@ def initialize_vllm_model_runner( num_hidden_layers = getattr(config, "num_hidden_layers", 28) is_first_peer = start_layer == 0 is_last_peer = end_layer == num_hidden_layers + + # Apply Parallax vLLM monkey patches for pipeline parallelism + try: + apply_parallax_vllm_monkey_patch(is_last_stage=is_last_peer) + logger.debug( + f"Applied Parallax vLLM monkey patches: is_last_stage={is_last_peer}" + ) + except Exception as e: + logger.warning("Failed to apply Parallax vLLM monkey patches: %s", e) + + # Apply layer-range-based weight file filtering before any model load. + # Reuse the generic monkey patch used by sglang implementation to reduce + # local weight file reads when loading a partial layer shard. 
+ try: + set_layer_range_for_filtering(start_layer, end_layer, num_hidden_layers) + apply_weight_loader_filter_patch() + logger.debug( + f"Applied weight loader filter monkey patch for layers [{start_layer}, {end_layer})" + ) + except Exception as e: + logger.warning("Failed to apply weight loader filter patch for vLLM loading: %s", e) # For single process, always use pp_size=1 virtual_pp_size = 1 diff --git a/src/parallax/vllm/monkey_patch.py b/src/parallax/vllm/monkey_patch.py new file mode 100644 index 00000000..a24c07c7 --- /dev/null +++ b/src/parallax/vllm/monkey_patch.py @@ -0,0 +1,27 @@ +""" +Monkey patches for vLLM to support Parallax pipeline parallelism. + +This module provides a unified entry point for applying all vLLM-related monkey patches +required for Parallax's distributed inference with pipeline parallelism. +""" + +from parallax.vllm.monkey_patch_utils.weight_loader import ( + apply_vllm_weight_loader_patch, + set_vllm_pipeline_stage, +) + + +## Here are patch functions for vLLM +## Hopefully, when vLLM supports pipeline parallelism natively in the way we need, +## we can remove these patches +def apply_parallax_vllm_monkey_patch(is_last_stage: bool = True): + """ + Apply all Parallax monkey patches for vLLM. + + Args: + is_last_stage: Whether this is the last pipeline stage. This affects + whether lm_head weights are expected to be loaded. + """ + set_vllm_pipeline_stage(is_last_stage) + apply_vllm_weight_loader_patch() + diff --git a/src/parallax/vllm/monkey_patch_utils/weight_loader.py b/src/parallax/vllm/monkey_patch_utils/weight_loader.py new file mode 100644 index 00000000..5841b1bb --- /dev/null +++ b/src/parallax/vllm/monkey_patch_utils/weight_loader.py @@ -0,0 +1,79 @@ +""" +Monkey patch for vLLM weight loading to skip lm_head weights on non-last pipeline stages. +This is similar to the approach used in sglang monkey patches. +""" +import logging +from typing import Any + +logger = logging.getLogger(__name__) + +_vllm_patch_applied = False +_is_last_stage = True # Default to True for safety + + +def set_vllm_pipeline_stage(is_last_stage: bool): + """Set whether this is the last pipeline stage.""" + global _is_last_stage + _is_last_stage = is_last_stage + logger.debug(f"Set vLLM pipeline stage: is_last_stage={is_last_stage}") + + +def apply_vllm_weight_loader_patch(): + """ + Apply monkey patch to vLLM's default loader to skip lm_head initialization check + when not on the last pipeline stage. + + This patch intercepts ValueError exceptions during weight loading and checks if they + are related to lm_head.weight not being initialized. If this occurs on a non-last + pipeline stage, the error is suppressed as expected behavior. Otherwise, the error + is re-raised. 
+ """ + global _vllm_patch_applied + + if _vllm_patch_applied: + logger.debug("vLLM weight loader patch already applied, skipping") + return + + try: + from vllm.model_executor.model_loader import default_loader + + original_load_weights = default_loader.DefaultModelLoader.load_weights + + def patched_load_weights(self, model: Any, model_config: Any): + """Patched load_weights that handles lm_head for pipeline parallelism.""" + global _is_last_stage + + try: + # Call original load_weights + original_load_weights(self, model, model_config) + except ValueError as e: + error_msg = str(e) + # Check if this is the lm_head initialization error + if "lm_head.weight" in error_msg and "not initialized from checkpoint" in error_msg: + if not _is_last_stage: + # Expected behavior for non-last pipeline stages + logger.info( + "Skipping lm_head.weight initialization check on non-last pipeline stage" + ) + return + else: + # This is the last stage, lm_head should be initialized + logger.error( + "lm_head.weight not initialized on last pipeline stage, this is an error" + ) + raise + else: + # Different error, re-raise + raise + + # Apply the patch + default_loader.DefaultModelLoader.load_weights = patched_load_weights + _vllm_patch_applied = True + logger.info("Successfully applied vLLM weight loader patch for pipeline parallelism") + + except ImportError as e: + logger.warning(f"Could not apply vLLM weight loader patch: {e}") + except Exception as e: + logger.error(f"Error applying vLLM weight loader patch: {e}") + raise + From efc5a0d885c7230c4e3db8f99c1e6129f2e29834 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 12 Nov 2025 13:04:50 +0800 Subject: [PATCH 25/36] pre-commit --- src/parallax/vllm/model_runner.py | 16 ++++++++-------- src/parallax/vllm/monkey_patch.py | 3 +-- .../vllm/monkey_patch_utils/weight_loader.py | 18 +++++++++--------- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index ddfe9b4e..d8c8cec2 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -26,13 +26,13 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheGroupSpec, KVCacheTensor from vllm.v1.worker.gpu_model_runner import GPUModelRunner -from parallax.utils.tokenizer_utils import load_tokenizer -from parallax_utils.logging_config import get_logger from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( apply_weight_loader_filter_patch, set_layer_range_for_filtering, ) +from parallax.utils.tokenizer_utils import load_tokenizer from parallax.vllm.monkey_patch import apply_parallax_vllm_monkey_patch +from parallax_utils.logging_config import get_logger logger = get_logger(__name__) @@ -200,7 +200,9 @@ def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVC model_dtype = self.vllm_config.model_config.dtype if isinstance(model_dtype, str): try: - from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE # type: ignore + from vllm.utils.torch_utils import ( + STR_DTYPE_TO_TORCH_DTYPE, # type: ignore + ) except Exception: # Older/newer vLLM versions may not expose torch_utils. # Fall back silently and default to float16. 
@@ -349,16 +351,14 @@ def initialize_vllm_model_runner( num_hidden_layers = getattr(config, "num_hidden_layers", 28) is_first_peer = start_layer == 0 is_last_peer = end_layer == num_hidden_layers - + # Apply Parallax vLLM monkey patches for pipeline parallelism try: apply_parallax_vllm_monkey_patch(is_last_stage=is_last_peer) - logger.debug( - f"Applied Parallax vLLM monkey patches: is_last_stage={is_last_peer}" - ) + logger.debug(f"Applied Parallax vLLM monkey patches: is_last_stage={is_last_peer}") except Exception as e: logger.warning("Failed to apply Parallax vLLM monkey patches: %s", e) - + # Apply layer-range-based weight file filtering before any model load. # Reuse the generic monkey patch used by sglang implementation to reduce # local weight file reads when loading a partial layer shard. diff --git a/src/parallax/vllm/monkey_patch.py b/src/parallax/vllm/monkey_patch.py index a24c07c7..a22511bf 100644 --- a/src/parallax/vllm/monkey_patch.py +++ b/src/parallax/vllm/monkey_patch.py @@ -17,11 +17,10 @@ def apply_parallax_vllm_monkey_patch(is_last_stage: bool = True): """ Apply all Parallax monkey patches for vLLM. - + Args: is_last_stage: Whether this is the last pipeline stage. This affects whether lm_head weights are expected to be loaded. """ set_vllm_pipeline_stage(is_last_stage) apply_vllm_weight_loader_patch() - diff --git a/src/parallax/vllm/monkey_patch_utils/weight_loader.py b/src/parallax/vllm/monkey_patch_utils/weight_loader.py index 5841b1bb..849fe930 100644 --- a/src/parallax/vllm/monkey_patch_utils/weight_loader.py +++ b/src/parallax/vllm/monkey_patch_utils/weight_loader.py @@ -2,6 +2,7 @@ Monkey patch for vLLM weight loading to skip lm_head weights on non-last pipeline stages. This is similar to the approach used in sglang monkey patches. """ + import logging from typing import Any @@ -22,27 +23,27 @@ def apply_vllm_weight_loader_patch(): """ Apply monkey patch to vLLM's default loader to skip lm_head initialization check when not on the last pipeline stage. - + This patch intercepts ValueError exceptions during weight loading and checks if they are related to lm_head.weight not being initialized. If this occurs on a non-last pipeline stage, the error is suppressed as expected behavior. Otherwise, the error is re-raised. 
""" global _vllm_patch_applied - + if _vllm_patch_applied: logger.debug("vLLM weight loader patch already applied, skipping") return - + try: from vllm.model_executor.model_loader import default_loader - + original_load_weights = default_loader.DefaultModelLoader.load_weights - + def patched_load_weights(self, model: Any, model_config: Any): """Patched load_weights that handles lm_head for pipeline parallelism.""" global _is_last_stage - + try: # Call original load_weights original_load_weights(self, model, model_config) @@ -65,15 +66,14 @@ def patched_load_weights(self, model: Any, model_config: Any): else: # Different error, re-raise raise - + # Apply the patch default_loader.DefaultModelLoader.load_weights = patched_load_weights _vllm_patch_applied = True logger.info("Successfully applied vLLM weight loader patch for pipeline parallelism") - + except ImportError as e: logger.warning(f"Could not apply vLLM weight loader patch: {e}") except Exception as e: logger.error(f"Error applying vLLM weight loader patch: {e}") raise - From c56b7bda08b8c9fc88d16f2be3a4c6b7a73e3a69 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 12 Nov 2025 13:13:29 +0800 Subject: [PATCH 26/36] update args name --- src/parallax/sglang/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index ca626e45..d1624a13 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -253,7 +253,7 @@ def initialize_sgl_model_runner( f"Downloading model with selective weight files for layers [{start_layer}, {end_layer})" ) model_path = get_model_path_with_selective_download( - model_repo, start_layer=start_layer, end_layer=end_layer, use_hfcache=use_hfcache + model_repo, start_layer=start_layer, end_layer=end_layer, local_files_only=use_hfcache ) config = load_config(model_path) From 41ebfecd3e26c87f58ba7ba85481dd5240eecd5f Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Wed, 12 Nov 2025 08:46:55 +0000 Subject: [PATCH 27/36] update load weights --- src/parallax/vllm/model_runner.py | 9 ++-- src/parallax/vllm/monkey_patch.py | 5 +- .../vllm/monkey_patch_utils/weight_loader.py | 53 +++++++++++-------- 3 files changed, 39 insertions(+), 28 deletions(-) diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index d8c8cec2..3d35c38f 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -315,6 +315,7 @@ def custom_get_pp_indices(num_layers: int, rank: int, world_size: int): f"Successfully loaded {self.num_shard_layers} layers " f"[{self.start_layer}:{self.end_layer}]" ) + finally: vllm.distributed.utils.get_pp_indices = original_get_pp_indices @@ -347,15 +348,15 @@ def initialize_vllm_model_runner( config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") - - num_hidden_layers = getattr(config, "num_hidden_layers", 28) + + num_hidden_layers = config.get("num_hidden_layers") is_first_peer = start_layer == 0 is_last_peer = end_layer == num_hidden_layers # Apply Parallax vLLM monkey patches for pipeline parallelism try: - apply_parallax_vllm_monkey_patch(is_last_stage=is_last_peer) - logger.debug(f"Applied Parallax vLLM monkey patches: is_last_stage={is_last_peer}") + apply_parallax_vllm_monkey_patch(is_first_stage=is_first_peer, is_last_stage=is_last_peer) + logger.debug(f"Applied Parallax vLLM monkey patches: 
is_first_stage={is_first_peer}, is_last_stage={is_last_peer}") except Exception as e: logger.warning("Failed to apply Parallax vLLM monkey patches: %s", e) diff --git a/src/parallax/vllm/monkey_patch.py b/src/parallax/vllm/monkey_patch.py index a22511bf..5c098730 100644 --- a/src/parallax/vllm/monkey_patch.py +++ b/src/parallax/vllm/monkey_patch.py @@ -14,13 +14,14 @@ ## Here are patch functions for vLLM ## Hopefully, when vLLM supports pipeline parallelism natively in the way we need, ## we can remove these patches -def apply_parallax_vllm_monkey_patch(is_last_stage: bool = True): +def apply_parallax_vllm_monkey_patch(is_first_stage: bool, is_last_stage: bool): """ Apply all Parallax monkey patches for vLLM. Args: + is_first_stage: Whether this is the first pipeline stage. is_last_stage: Whether this is the last pipeline stage. This affects whether lm_head weights are expected to be loaded. """ - set_vllm_pipeline_stage(is_last_stage) + set_vllm_pipeline_stage(is_first_stage, is_last_stage) apply_vllm_weight_loader_patch() diff --git a/src/parallax/vllm/monkey_patch_utils/weight_loader.py b/src/parallax/vllm/monkey_patch_utils/weight_loader.py index 849fe930..4e6eb943 100644 --- a/src/parallax/vllm/monkey_patch_utils/weight_loader.py +++ b/src/parallax/vllm/monkey_patch_utils/weight_loader.py @@ -1,5 +1,5 @@ """ -Monkey patch for vLLM weight loading to skip lm_head weights on non-last pipeline stages. +Monkey patch for vLLM weight loading to skip non-existent weights on different pipeline stages. This is similar to the approach used in sglang monkey patches. """ @@ -9,25 +9,25 @@ logger = logging.getLogger(__name__) _vllm_patch_applied = False +_is_first_stage = False # Default to False _is_last_stage = True # Default to True for safety -def set_vllm_pipeline_stage(is_last_stage: bool): - """Set whether this is the last pipeline stage.""" - global _is_last_stage +def set_vllm_pipeline_stage(is_first_stage: bool, is_last_stage: bool): + """Set whether this is the first and/or last pipeline stage.""" + global _is_first_stage, _is_last_stage + _is_first_stage = is_first_stage _is_last_stage = is_last_stage - logger.debug(f"Set vLLM pipeline stage: is_last_stage={is_last_stage}") + logger.debug(f"Set vLLM pipeline stage: is_first_stage={_is_first_stage}, is_last_stage={_is_last_stage}") def apply_vllm_weight_loader_patch(): """ - Apply monkey patch to vLLM's default loader to skip lm_head initialization check - when not on the last pipeline stage. + Apply monkey patch to vLLM's default loader to skip initialization checks + for weights that are not expected on certain pipeline stages. - This patch intercepts ValueError exceptions during weight loading and checks if they - are related to lm_head.weight not being initialized. If this occurs on a non-last - pipeline stage, the error is suppressed as expected behavior. Otherwise, the error - is re-raised. + - Skips `embed_tokens` check on non-first stages. + - Skips `lm_head` check on non-last stages. 
""" global _vllm_patch_applied @@ -41,28 +41,37 @@ def apply_vllm_weight_loader_patch(): original_load_weights = default_loader.DefaultModelLoader.load_weights def patched_load_weights(self, model: Any, model_config: Any): - """Patched load_weights that handles lm_head for pipeline parallelism.""" - global _is_last_stage + """Patched load_weights that handles embed_tokens and lm_head for pipeline parallelism.""" + global _is_first_stage, _is_last_stage try: # Call original load_weights original_load_weights(self, model, model_config) except ValueError as e: error_msg = str(e) - # Check if this is the lm_head initialization error - if "lm_head.weight" in error_msg and "not initialized from checkpoint" in error_msg: + uninitialized_weights = "not initialized from checkpoint" in error_msg + + # Case 1: embed_tokens.weight not found + if "model.embed_tokens.weight" in error_msg and uninitialized_weights: + if not _is_first_stage: + # Expected behavior for non-first pipeline stages + logger.info("Skipping embed_tokens.weight initialization check on non-first pipeline stage") + else: + # This is the first stage, embed_tokens should be initialized + logger.error("embed_tokens.weight not initialized on first pipeline stage, this is an error") + raise + + # Case 2: lm_head.weight not found + elif "lm_head.weight" in error_msg and uninitialized_weights: if not _is_last_stage: # Expected behavior for non-last pipeline stages - logger.info( - "Skipping lm_head.weight initialization check on non-last pipeline stage" - ) - return + logger.info("Skipping lm_head.weight initialization check on non-last pipeline stage") else: # This is the last stage, lm_head should be initialized - logger.error( - "lm_head.weight not initialized on last pipeline stage, this is an error" - ) + logger.error("lm_head.weight not initialized on last pipeline stage, this is an error") raise + + # Case 3: Other errors else: # Different error, re-raise raise From ffeb18f9a523ecaf736075928137f5b24efc81c9 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Fri, 14 Nov 2025 15:29:28 +0800 Subject: [PATCH 28/36] fix bug done --- src/parallax/server/executor.py | 86 +++++++++++++++++++++++++++++++ src/parallax/vllm/batch_info.py | 40 +++++++++++++- src/parallax/vllm/model_runner.py | 22 ++++++-- 3 files changed, 143 insertions(+), 5 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 5470d7a7..01d2cdd0 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -446,6 +446,57 @@ def recv_requests_from_peer(self) -> List[Request]: return recv_reqs + def _compute_expected_intermediate_tokens(self, scheduler_output: Any) -> Optional[int]: + """Estimate the padded token count expected by vLLM for this batch.""" + if scheduler_output is None: + return None + + total_tokens = getattr(scheduler_output, "total_num_scheduled_tokens", None) + if total_tokens is None: + return None + + try: + total_tokens = int(total_tokens) + except (TypeError, ValueError): + return None + + model_runner = getattr(self, "model_runner", None) + if model_runner is None: + return None + + get_num_input_tokens = getattr(model_runner, "_get_num_input_tokens", None) + get_dp_padding = getattr(model_runner, "get_dp_padding", None) + if get_num_input_tokens is None or get_dp_padding is None: + return None + + num_input_tokens = get_num_input_tokens(total_tokens) + num_pad, _ = get_dp_padding(num_input_tokens) + return num_input_tokens + num_pad + + @staticmethod + def _pad_or_trim_tensor(tensor: 
torch.Tensor, target_len: int) -> torch.Tensor: + if target_len < 0: + return tensor + current_len = tensor.shape[0] + if current_len == target_len: + return tensor + if current_len > target_len: + return tensor[:target_len] + pad_shape = (target_len - current_len,) + tensor.shape[1:] + pad = tensor.new_zeros(pad_shape) + return torch.cat((tensor, pad), dim=0) + + def _resize_intermediate_tensors(self, intermediate_tensors, target_len: Optional[int]): + if intermediate_tensors is None or target_len is None: + return intermediate_tensors + if target_len < 0: + return intermediate_tensors + + # Create a list to avoid "dictionary changed size during iteration". + for key, tensor in list(intermediate_tensors.items()): + intermediate_tensors[key] = self._pad_or_trim_tensor(tensor, target_len) + return intermediate_tensors + def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]: """ Prepares inputs for CUDA backends from a batch of prefill requests. @@ -459,6 +510,7 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s # Prepare PP proxy tensors (common for both backends when not first peer) pp_proxy_tensors = None + pp_proxy_initial_tokens = None if not self.is_first_peer: # Concatenate hidden states from all requests # For vLLM, we need to flatten to (total_tokens, hidden_size) @@ -478,6 +530,7 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s # Concatenate along sequence dimension to get (total_tokens, hidden_size) hidden_states = torch.cat(hidden_states_list, dim=0) + pp_proxy_initial_tokens = hidden_states.shape[0] # Create residual tensor with same shape residual = torch.zeros( @@ -515,6 +568,29 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s schedule_outputs_prefill = form_vllm_batch_prefill(batched_requests, self.model_runner) + if not self.is_first_peer and pp_proxy_tensors is not None: + target_tokens = self._compute_expected_intermediate_tokens(schedule_outputs_prefill) + if target_tokens is not None: + before = pp_proxy_tensors["hidden_states"].shape[0] + pp_proxy_tensors = self._resize_intermediate_tensors( + pp_proxy_tensors, target_tokens + ) + after = pp_proxy_tensors["hidden_states"].shape[0] + if after != before: + logger.debug( + "PP Proxy: resized hidden_states from %d to %d tokens (requested=%s, initial=%s)", + before, + after, + target_tokens, + pp_proxy_initial_tokens, + ) + + if not self.is_first_peer and pp_proxy_tensors is not None: + logger.debug( + "PP Proxy: hidden_states shape after adjustment: %s", + tuple(pp_proxy_tensors["hidden_states"].shape), + ) + ret = { "scheduler_output": schedule_outputs_prefill, "pp_proxy_tensors": pp_proxy_tensors, @@ -572,6 +648,7 @@ def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[st # Concatenate along sequence dimension to get (total_tokens, hidden_size) hidden_states = torch.cat(hidden_states_list, dim=0) + pp_proxy_initial_tokens = hidden_states.shape[0] # Create residual tensor with same shape residual = torch.zeros( @@ -918,6 +995,10 @@ def _handle_cuda_input_requests(self, requests: List[Request]): assert req.next_token_id is not None original_req.commit_new_token(req.next_token_id) + logger.debug( + f"[FirstPeer-CUDA] Committed token {req.next_token_id} for {req.request_id}, " + f"output_ids now has {len(original_req.output_ids)} tokens" + ) if len(req.routing_table) > 0: original_req.routing_table = req.routing_table @@ -1102,6 +1183,8 @@ def 
_prepare_next_single_request(self, request: Request, hidden_states: Any) -> assert isinstance( request, IntermediateRequest ), "Last peer must receive an IntermediateRequest." + logger.info(f"hidden_states shape: {hidden_states.shape}") + logger.info(f"hidden_states: {hidden_states}") if self.device == "cuda": assert hidden_states.dtype in ( torch.int64, @@ -1143,6 +1226,7 @@ def _prepare_next_batch_requests( for i, src_request in enumerate(requests): if self.is_last_peer: # Last peer gets a 1D array of token IDs + logger.info(f"hidden_states: {hidden_states}") hidden_state_for_req = hidden_states[i : i + 1] else: # Other peers get a 3D array of hidden states @@ -1217,6 +1301,7 @@ def _process_batch_cuda( import torch sampled_token_ids = output.sampled_token_ids + logger.info(f"sampled_token_ids: {sampled_token_ids}") if isinstance(sampled_token_ids, list) and len(sampled_token_ids) > 0: # Convert to tensor: pad sequences to same length max_len = max(len(seq) for seq in sampled_token_ids) @@ -1498,6 +1583,7 @@ def run_loop(self): output = self.process_batch( prepared_inputs, return_decoded_tokens=self.is_last_peer ) + logger.info(f"output: {output}") # Update metrics with per-layer latency sample (throttled by decode steps) if batch_type == "decode_batch": try: diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index b92bb4f5..d2371391 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -159,6 +159,7 @@ def form_vllm_batch_prefill( def form_vllm_batch_decode( batched_requests: List[Request], model_runner: Any = None, + scheduler: Any = None, ) -> Optional[SchedulerOutput]: if not batched_requests: return None @@ -183,7 +184,32 @@ def form_vllm_batch_decode( for req in batched_requests: req_ids.append(req.request_id) resumed_from_preemption.append(False) + + # For GPU workers (non-first peer), IntermediateRequest doesn't have output_ids + # We need to get it from vLLM's CachedRequestState in model_runner output_ids = getattr(req, "output_ids", None) or [] + + # If this request doesn't have output_ids (IntermediateRequest case), + # try to get it from model_runner's cached request state (vLLM internal state) + if not output_ids and hasattr(model_runner, "requests"): + cached_req_state = model_runner.requests.get(req.request_id) + if cached_req_state is not None: + output_ids = getattr(cached_req_state, "output_token_ids", []) + logger.debug( + f"[Decode] Retrieved output_token_ids from vLLM CachedRequestState for " + f"{req.request_id}: len={len(output_ids)}" + ) + + # Fallback: try scheduler if available + if not output_ids and scheduler is not None: + running_req = scheduler.get_running_request(req.request_id) + if running_req is not None: + output_ids = getattr(running_req, "output_ids", None) or [] + logger.debug( + f"[Decode] Retrieved output_ids from scheduler for {req.request_id}: " + f"len={len(output_ids)}" + ) + if output_ids: last_token = output_ids[-1] new_token_ids.append([last_token]) @@ -196,13 +222,23 @@ def form_vllm_batch_decode( vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=True) prompt_ids = getattr(req, "input_ids", None) or [] - output_ids = getattr(req, "output_ids", None) or [] + # For decode stage, computed_token_count should be the total number of tokens + # that have been processed (including all output tokens). + # In pipeline parallelism, this must match what GPU worker expects. 
if output_ids: - computed_token_count = len(prompt_ids) + len(output_ids) - 1 + # All tokens (prompt + all generated outputs) have been computed + computed_token_count = len(prompt_ids) + len(output_ids) else: + # First decode step: only prompt has been computed computed_token_count = len(prompt_ids) vllm_req.num_computed_tokens = computed_token_count + # Debug logging to track state synchronization + logger.debug( + f"[Decode] req_id={req.request_id}, prompt_len={len(prompt_ids)}, " + f"output_len={len(output_ids)}, computed_tokens={computed_token_count}" + ) + new_blocks = kv_cache_manager.allocate_slots( request=vllm_req, num_new_tokens=1, diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 3d35c38f..600d99ae 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -319,7 +319,21 @@ def custom_get_pp_indices(num_layers: int, rank: int, world_size: int): finally: vllm.distributed.utils.get_pp_indices = original_get_pp_indices - logger.debug("Model loaded successfully with partial layers") + def execute_model(self, scheduler_output, intermediate_tensors=None): + """ + Execute the model with the given scheduler output and intermediate tensors. + If this is not the first peer, and the intermediate_tensors buffer is not initialized, + initialize it. + """ + if not self.is_first_peer and self.intermediate_tensors is None: + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, + ) + logger.debug("Successfully initialized intermediate_tensors buffer") + + return super().execute_model(scheduler_output, intermediate_tensors) def initialize_vllm_model_runner( @@ -348,7 +362,7 @@ def initialize_vllm_model_runner( config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") - + num_hidden_layers = config.get("num_hidden_layers") is_first_peer = start_layer == 0 is_last_peer = end_layer == num_hidden_layers @@ -356,7 +370,9 @@ def initialize_vllm_model_runner( # Apply Parallax vLLM monkey patches for pipeline parallelism try: apply_parallax_vllm_monkey_patch(is_first_stage=is_first_peer, is_last_stage=is_last_peer) - logger.debug(f"Applied Parallax vLLM monkey patches: is_first_stage={is_first_peer}, is_last_stage={is_last_peer}") + logger.debug( + f"Applied Parallax vLLM monkey patches: is_first_stage={is_first_peer}, is_last_stage={is_last_peer}" + ) except Exception as e: logger.warning("Failed to apply Parallax vLLM monkey patches: %s", e) From 05e58462825067e4c2caa52f3e41d19f23d6fe02 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Fri, 14 Nov 2025 15:35:58 +0800 Subject: [PATCH 29/36] pre-commit --- src/parallax/server/executor.py | 1 - .../vllm/monkey_patch_utils/weight_loader.py | 22 ++++++++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 01d2cdd0..e8300ab9 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -648,7 +648,6 @@ def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[st # Concatenate along sequence dimension to get (total_tokens, hidden_size) hidden_states = torch.cat(hidden_states_list, dim=0) - pp_proxy_initial_tokens = hidden_states.shape[0] # Create residual tensor with same shape residual = torch.zeros( diff --git 
a/src/parallax/vllm/monkey_patch_utils/weight_loader.py b/src/parallax/vllm/monkey_patch_utils/weight_loader.py index 4e6eb943..45e6fc87 100644 --- a/src/parallax/vllm/monkey_patch_utils/weight_loader.py +++ b/src/parallax/vllm/monkey_patch_utils/weight_loader.py @@ -18,7 +18,9 @@ def set_vllm_pipeline_stage(is_first_stage: bool, is_last_stage: bool): global _is_first_stage, _is_last_stage _is_first_stage = is_first_stage _is_last_stage = is_last_stage - logger.debug(f"Set vLLM pipeline stage: is_first_stage={_is_first_stage}, is_last_stage={_is_last_stage}") + logger.debug( + f"Set vLLM pipeline stage: is_first_stage={_is_first_stage}, is_last_stage={_is_last_stage}" + ) def apply_vllm_weight_loader_patch(): @@ -55,22 +57,30 @@ def patched_load_weights(self, model: Any, model_config: Any): if "model.embed_tokens.weight" in error_msg and uninitialized_weights: if not _is_first_stage: # Expected behavior for non-first pipeline stages - logger.info("Skipping embed_tokens.weight initialization check on non-first pipeline stage") + logger.info( + "Skipping embed_tokens.weight initialization check on non-first pipeline stage" + ) else: # This is the first stage, embed_tokens should be initialized - logger.error("embed_tokens.weight not initialized on first pipeline stage, this is an error") + logger.error( + "embed_tokens.weight not initialized on first pipeline stage, this is an error" + ) raise # Case 2: lm_head.weight not found elif "lm_head.weight" in error_msg and uninitialized_weights: if not _is_last_stage: # Expected behavior for non-last pipeline stages - logger.info("Skipping lm_head.weight initialization check on non-last pipeline stage") + logger.info( + "Skipping lm_head.weight initialization check on non-last pipeline stage" + ) else: # This is the last stage, lm_head should be initialized - logger.error("lm_head.weight not initialized on last pipeline stage, this is an error") + logger.error( + "lm_head.weight not initialized on last pipeline stage, this is an error" + ) raise - + # Case 3: Other errors else: # Different error, re-raise From f4eb7e2255f3f5738e316a665c62e67dce0aceca Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Sun, 16 Nov 2025 13:20:49 +0800 Subject: [PATCH 30/36] refactor code --- src/parallax/server/executor.py | 141 ++++++------------------------ src/parallax/vllm/batch_info.py | 92 +++++++++++++++++++ src/parallax/vllm/model_runner.py | 8 +- 3 files changed, 121 insertions(+), 120 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index e8300ab9..61d6ddcd 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -125,7 +125,7 @@ def __init__( ) elif self.backend_type == "sglang": from sglang.srt.managers.schedule_batch import ( - ScheduleBatch as CudaScheduleBatch, + ScheduleBatch as SGLangScheduleBatch, ) from parallax.sglang.model_runner import ( @@ -162,17 +162,12 @@ def __init__( f"CUDA model runner initialized. 
num_layers={self.config.get('num_hidden_layers')}" ) self.cur_batch = None + self.running_batch = None + if self.backend_type == "sglang": - self.running_batch = CudaScheduleBatch(reqs=[], batch_is_full=False) - # Set tp_group for tensor parallelism support - if ( - hasattr(self.model_runner, "tp_group") - and self.model_runner.tp_group is not None - ): - self.tp_group = self.model_runner.tp_group - self.tp_cpu_group = self.tp_group.cpu_group - else: - self.running_batch = None + self.running_batch = SGLangScheduleBatch(reqs=[], batch_is_full=False) + self.tp_group = self.model_runner.tp_group + self.tp_cpu_group = self.tp_group.cpu_group else: logger.debug( @@ -446,57 +441,6 @@ def recv_requests_from_peer(self) -> List[Request]: return recv_reqs - def _compute_expected_intermediate_tokens(self, scheduler_output: Any) -> Optional[int]: - """Estimate the padded token count expected by vLLM for this batch.""" - if scheduler_output is None: - return None - - total_tokens = getattr(scheduler_output, "total_num_scheduled_tokens", None) - if total_tokens is None: - return None - - try: - total_tokens = int(total_tokens) - except (TypeError, ValueError): - return None - - model_runner = getattr(self, "model_runner", None) - if model_runner is None: - return None - - get_num_input_tokens = getattr(model_runner, "_get_num_input_tokens", None) - get_dp_padding = getattr(model_runner, "get_dp_padding", None) - if get_num_input_tokens is None or get_dp_padding is None: - return None - - num_input_tokens = get_num_input_tokens(total_tokens) - num_pad, _ = get_dp_padding(num_input_tokens) - return num_input_tokens + num_pad - - @staticmethod - def _pad_or_trim_tensor(tensor: torch.Tensor, target_len: int) -> torch.Tensor: - if target_len < 0: - return tensor - current_len = tensor.shape[0] - if current_len == target_len: - return tensor - if current_len > target_len: - return tensor[:target_len] - pad_shape = (target_len - current_len,) + tensor.shape[1:] - pad = tensor.new_zeros(pad_shape) - return torch.cat((tensor, pad), dim=0) - - def _resize_intermediate_tensors(self, intermediate_tensors, target_len: Optional[int]): - if intermediate_tensors is None or target_len is None: - return intermediate_tensors - if target_len < 0: - return intermediate_tensors - - # Create a list to avoid "dictionary changed size during iteration". - for key, tensor in list(intermediate_tensors.items()): - intermediate_tensors[key] = self._pad_or_trim_tensor(tensor, target_len) - return intermediate_tensors - def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]: """ Prepares inputs for CUDA backends from a batch of prefill requests. 
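The helpers this refactor moves out of the executor (expected-token computation, pad-or-trim, resize of the incoming intermediate tensors) boil down to one operation: make the token dimension of the received hidden states match the padded token count vLLM expects for the batch. A self-contained sketch of that pad-or-trim step, with an assumed hidden size of 4096 purely for illustration:

import torch


def pad_or_trim(tensor: torch.Tensor, target_len: int) -> torch.Tensor:
    """Zero-pad or trim along dim 0 so the token count equals target_len."""
    if target_len < 0 or tensor.shape[0] == target_len:
        return tensor
    if tensor.shape[0] > target_len:
        return tensor[:target_len]
    pad = tensor.new_zeros((target_len - tensor.shape[0],) + tensor.shape[1:])
    return torch.cat((tensor, pad), dim=0)


hidden = torch.randn(5, 4096)                      # 5 real tokens from the previous peer
assert pad_or_trim(hidden, 8).shape == (8, 4096)   # padded up to the expected count
assert pad_or_trim(hidden, 3).shape == (3, 4096)   # or trimmed down
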
@@ -510,7 +454,6 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s # Prepare PP proxy tensors (common for both backends when not first peer) pp_proxy_tensors = None - pp_proxy_initial_tokens = None if not self.is_first_peer: # Concatenate hidden states from all requests # For vLLM, we need to flatten to (total_tokens, hidden_size) @@ -530,7 +473,7 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s # Concatenate along sequence dimension to get (total_tokens, hidden_size) hidden_states = torch.cat(hidden_states_list, dim=0) - pp_proxy_initial_tokens = hidden_states.shape[0] + hidden_states.shape[0] # Create residual tensor with same shape residual = torch.zeros( @@ -564,32 +507,19 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s lengths_tensor = torch.tensor(lengths, device=self.device) if self.backend_type == "vllm": - from parallax.vllm.batch_info import form_vllm_batch_prefill + from parallax.vllm.batch_info import ( + compute_expected_intermediate_tokens, + form_vllm_batch_prefill, + resize_intermediate_tensors, + ) schedule_outputs_prefill = form_vllm_batch_prefill(batched_requests, self.model_runner) if not self.is_first_peer and pp_proxy_tensors is not None: - target_tokens = self._compute_expected_intermediate_tokens(schedule_outputs_prefill) - if target_tokens is not None: - before = pp_proxy_tensors["hidden_states"].shape[0] - pp_proxy_tensors = self._resize_intermediate_tensors( - pp_proxy_tensors, target_tokens - ) - after = pp_proxy_tensors["hidden_states"].shape[0] - if after != before: - logger.debug( - "PP Proxy: resized hidden_states from %d to %d tokens (requested=%s, initial=%s)", - before, - after, - target_tokens, - pp_proxy_initial_tokens, - ) - - if not self.is_first_peer and pp_proxy_tensors is not None: - logger.debug( - "PP Proxy: hidden_states shape after adjustment: %s", - tuple(pp_proxy_tensors["hidden_states"].shape), + target_tokens = compute_expected_intermediate_tokens( + schedule_outputs_prefill, self.model_runner ) + pp_proxy_tensors = resize_intermediate_tensors(pp_proxy_tensors, target_tokens) ret = { "scheduler_output": schedule_outputs_prefill, @@ -621,7 +551,6 @@ def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[st Prepares inputs for CUDA backends from a batch of decode requests. Routes to SGLang or vLLM depending on backend_type. 
""" - from sglang.srt.model_executor.forward_batch_info import PPProxyTensors batch_size = len(batched_requests) if batch_size == 0: @@ -655,23 +584,17 @@ def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[st ) if self.backend_type == "vllm": - # For vLLM, pass directly as IntermediateTensors - from vllm.sequence import IntermediateTensors - - pp_proxy_tensors = IntermediateTensors( - { - "hidden_states": hidden_states, - "residual": residual, - } - ) + from vllm.sequence import IntermediateTensors as CudaPPProxyTensors else: - # For SGLang, use PPProxyTensors - pp_proxy_tensors = PPProxyTensors( - { - "hidden_states": hidden_states, - "residual": residual, - } + from sglang.srt.model_executor.forward_batch_info import ( + PPProxyTensors as CudaPPProxyTensors, ) + pp_proxy_tensors = CudaPPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) logger.debug(f"PP Proxy: hidden_states shape: {hidden_states.shape}") # Prepare lengths (common for both backends) @@ -1182,8 +1105,6 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> assert isinstance( request, IntermediateRequest ), "Last peer must receive an IntermediateRequest." - logger.info(f"hidden_states shape: {hidden_states.shape}") - logger.info(f"hidden_states: {hidden_states}") if self.device == "cuda": assert hidden_states.dtype in ( torch.int64, @@ -1256,18 +1177,7 @@ def _process_batch_cuda( self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True ): """ - Process a batch of requests in CUDA. - - Supports both vLLM and SGLang backends with Pipeline Parallelism. - - Args: - prepared_inputs: Dict containing batch data and metadata - return_decoded_tokens: If True, return token IDs (last peer); - If False, return hidden states (intermediate peer) - - Returns: - token_ids (Tensor): If return_decoded_tokens=True - hidden_states (Tensor): If return_decoded_tokens=False + Process a batch of requests in CUDA, supports both vLLM and SGLang backends. """ if self.backend_type == "vllm": # ========== vLLM Backend ========== @@ -1582,7 +1492,6 @@ def run_loop(self): output = self.process_batch( prepared_inputs, return_decoded_tokens=self.is_last_peer ) - logger.info(f"output: {output}") # Update metrics with per-layer latency sample (throttled by decode steps) if batch_type == "decode_batch": try: diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index d2371391..f9f05965 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -2,8 +2,10 @@ from typing import Any, Dict, List, Optional +import torch from vllm.sampling_params import SamplingParams as VLLMSamplingParams from vllm.sampling_params import StructuredOutputsParams +from vllm.sequence import IntermediateTensors from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.request import Request as VLLMRequest @@ -16,6 +18,95 @@ logger = get_logger(__name__) +def compute_expected_intermediate_tokens(scheduler_output: Any, model_runner: Any) -> Optional[int]: + """ + Estimate the padded token count expected by vLLM for this batch. + + This function computes the total number of tokens including padding that vLLM + expects for data parallel processing. 
+ + Args: + scheduler_output: SchedulerOutput from vLLM scheduler + model_runner: The vLLM model runner instance + + Returns: + Expected total token count including padding, or None if unable to compute + """ + if scheduler_output is None: + return None + + total_tokens = getattr(scheduler_output, "total_num_scheduled_tokens", None) + if total_tokens is None: + return None + + try: + total_tokens = int(total_tokens) + except (TypeError, ValueError): + return None + + if model_runner is None: + return None + + get_num_input_tokens = getattr(model_runner, "_get_num_input_tokens", None) + get_dp_padding = getattr(model_runner, "get_dp_padding", None) + if get_num_input_tokens is None or get_dp_padding is None: + return None + + num_input_tokens = get_num_input_tokens(total_tokens) + num_pad, _ = get_dp_padding(num_input_tokens) + return num_input_tokens + num_pad + + +def pad_or_trim_tensor(tensor: torch.Tensor, target_len: int) -> torch.Tensor: + """ + Pad or trim a tensor to the target length along dimension 0. + + Args: + tensor: Input tensor to pad/trim + target_len: Target length for dimension 0. If negative, returns unchanged. + + Returns: + Tensor with dimension 0 adjusted to target_len + """ + if target_len < 0: + return tensor + current_len = tensor.shape[0] + if current_len == target_len: + return tensor + if current_len > target_len: + return tensor[:target_len] + pad_shape = (target_len - current_len,) + tensor.shape[1:] + pad = tensor.new_zeros(pad_shape) + return torch.cat((tensor, pad), dim=0) + + +def resize_intermediate_tensors( + intermediate_tensors: IntermediateTensors, target_len: Optional[int] +) -> IntermediateTensors: + """ + Resize all tensors in IntermediateTensors to match the target length. + + This is needed for vLLM pipeline parallelism when the actual token count + doesn't match the expected padded count for data parallel processing. + + Args: + intermediate_tensors: vLLM IntermediateTensors containing hidden states + target_len: Target token count. If None or negative, returns unchanged. + + Returns: + IntermediateTensors with all tensors resized to target_len + """ + if intermediate_tensors is None or target_len is None: + return intermediate_tensors + if target_len < 0: + return intermediate_tensors + + # Create a list to avoid "dictionary changed size during iteration". 
+ for key, tensor in list(intermediate_tensors.items()): + intermediate_tensors[key] = pad_or_trim_tensor(tensor, target_len) + return intermediate_tensors + + def transform_sampling_params_to_vllm(old_params: ParallaxSamplingParams) -> VLLMSamplingParams: structured = ( StructuredOutputsParams(json=old_params.json_schema) @@ -160,6 +251,7 @@ def form_vllm_batch_decode( batched_requests: List[Request], model_runner: Any = None, scheduler: Any = None, + **kwargs, ) -> Optional[SchedulerOutput]: if not batched_requests: return None diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index 600d99ae..cf25efe7 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -559,12 +559,12 @@ def initialize_vllm_model_runner( model_runner.kv_cache_config = kv_cache_config - logger.debug("Initializing GPUModelRunner KV cache...") + logger.info("Initializing GPUModelRunner KV cache...") model_runner.initialize_kv_cache(kv_cache_config) - logger.debug("GPUModelRunner KV cache initialized successfully") + logger.info("GPUModelRunner KV cache initialized successfully") - logger.debug("Initializing KV Cache Manager...") + logger.info("Initializing KV Cache Manager...") model_runner.initialize_kv_cache_manager(max_model_len=model_config.max_model_len) - logger.debug("KV Cache Manager initialized successfully") + logger.info("KV Cache Manager initialized successfully") return model_runner, config, tokenizer From 4f3de54901746b16443562502a85441ec035ef15 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Sun, 16 Nov 2025 13:22:15 +0800 Subject: [PATCH 31/36] rm code --- src/parallax/server/executor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 61d6ddcd..3ed5a310 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -473,7 +473,6 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s # Concatenate along sequence dimension to get (total_tokens, hidden_size) hidden_states = torch.cat(hidden_states_list, dim=0) - hidden_states.shape[0] # Create residual tensor with same shape residual = torch.zeros( From 07ecac6bdbafbfef0ae3eee3c1a306f6c7a06e22 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 17 Nov 2025 09:49:27 +0800 Subject: [PATCH 32/36] refactor executor --- src/parallax/server/executor.py | 43 ++------------------------------- 1 file changed, 2 insertions(+), 41 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 3ed5a310..e0b3f889 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -1145,7 +1145,6 @@ def _prepare_next_batch_requests( for i, src_request in enumerate(requests): if self.is_last_peer: # Last peer gets a 1D array of token IDs - logger.info(f"hidden_states: {hidden_states}") hidden_state_for_req = hidden_states[i : i + 1] else: # Other peers get a 3D array of hidden states @@ -1204,12 +1203,9 @@ def _process_batch_cuda( # Return appropriate output based on peer position if return_decoded_tokens: - # Last peer: return sampled token IDs as tensor - # Convert list[list[int]] to tensor import torch sampled_token_ids = output.sampled_token_ids - logger.info(f"sampled_token_ids: {sampled_token_ids}") if isinstance(sampled_token_ids, list) and len(sampled_token_ids) > 0: # Convert to tensor: pad sequences to same length max_len = max(len(seq) for seq in sampled_token_ids) @@ -1222,43 +1218,8 @@ def _process_batch_cuda( 
return torch.tensor(sampled_token_ids, dtype=torch.int64) else: # Intermediate peer: return hidden states for next peer - # vLLM with Parallax PP should return IntermediateTensors - def _merge_hidden_and_residual(hidden_tensor, residual_tensor): - if hidden_tensor is None: - return None - if residual_tensor is not None: - # vLLM separates residual connections; downstream peers expect the merged tensor. - hidden_tensor = hidden_tensor + residual_tensor - return hidden_tensor - - if isinstance(output, IntermediateTensors): - tensors = output.tensors - merged = _merge_hidden_and_residual( - tensors.get("hidden_states"), tensors.get("residual") - ) - if merged is not None: - return merged - # Return full object if hidden states are packed under a different key - if tensors: - return output - elif hasattr(output, "hidden_states") and output.hidden_states is not None: - residual = getattr(output, "residual", None) - merged = _merge_hidden_and_residual(output.hidden_states, residual) - if merged is not None: - return merged - elif hasattr(output, "tensors") and "hidden_states" in output.tensors: - tensors = output.tensors - merged = _merge_hidden_and_residual( - tensors.get("hidden_states"), tensors.get("residual") - ) - if merged is not None: - return merged - - raise RuntimeError( - "vLLM backend: expected hidden_states in output for PP, but got None. " - f"Output type: {type(output)}, is_last_peer={self.is_last_peer}. " - "This typically means the model runner is not configured for pipeline parallelism." - ) + final_hidden_states = output.tensors["hidden_states"] + output.tensors["residual"] + return final_hidden_states else: # self.backend_type == "sglang" # ========== SGLang Backend ========== From 16d9e571283fce1a95464dac32eb4a91464379c2 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 17 Nov 2025 09:49:50 +0800 Subject: [PATCH 33/36] pre-commit --- src/parallax/server/executor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index e0b3f889..3963fdf9 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -1193,7 +1193,6 @@ def _process_batch_cuda( logger.debug(f"vLLM: Using intermediate_tensors for PP (non-first peer)") # Import IntermediateTensors for type checking - from vllm.sequence import IntermediateTensors # Execute model with vLLM output = self.model_runner.execute_model( From 64db1c1c17de21674d41c2dc02381b9f03f7c865 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 17 Nov 2025 10:58:21 +0800 Subject: [PATCH 34/36] fix single gpu bug --- src/parallax/vllm/batch_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index f9f05965..95f2c442 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -319,7 +319,7 @@ def form_vllm_batch_decode( # In pipeline parallelism, this must match what GPU worker expects. 
if output_ids: # All tokens (prompt + all generated outputs) have been computed - computed_token_count = len(prompt_ids) + len(output_ids) + computed_token_count = len(prompt_ids) + len(output_ids) - 1 else: # First decode step: only prompt has been computed computed_token_count = len(prompt_ids) From 1c68d24fd18e299e035e2eb305e5f8756c02597d Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 17 Nov 2025 20:12:38 +0800 Subject: [PATCH 35/36] pre-commit --- .pre-commit-config.yaml | 1 + src/parallax/server/radix_cache.py | 846 +++++++++--------- src/parallax/sglang/batch_info.py | 446 ++++----- .../monkey_patch_utils/gpt_oss_model.py | 384 ++++---- .../monkey_patch_utils/triton_backend.py | 222 ++--- src/parallax/utils/tokenizer_utils.py | 250 +++--- src/parallax_utils/ascii_anime.py | 454 +++++----- tests/test_prefix_cache.py | 64 +- 8 files changed, 1334 insertions(+), 1333 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 01d7ce5f..6a985232 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,6 +7,7 @@ repos: exclude: '^src/frontend/dist/|\.svg$|\.png$|\.jpg$|\.jpeg$|\.gif$|\.webp$' - id: trailing-whitespace - id: mixed-line-ending + args: ['--fix=lf'] - repo: https://github.com/PyCQA/autoflake rev: v2.3.1 diff --git a/src/parallax/server/radix_cache.py b/src/parallax/server/radix_cache.py index 91841708..0bd07d27 100755 --- a/src/parallax/server/radix_cache.py +++ b/src/parallax/server/radix_cache.py @@ -1,423 +1,423 @@ -""" -Prefix Cache class for KV Cache reuse. -This module is implemented using radix tree, which retains the -same as SGLang. -""" - -import heapq -import time -from collections import defaultdict -from functools import partial -from typing import Dict, List, Optional, Tuple - -import mlx.core as mx - -from parallax.server.kv_cache import KVCache -from parallax.server.request import Request - - -class TreeNode: - """ - Radix tree node data structure. - Key: token id list. It should be an empty list for the root node. - Value: kv cache positions. - """ - - counter = 0 - - def __init__(self, node_id: Optional[int] = None): - self.children = defaultdict(TreeNode) - self.parent: TreeNode = None - self.key: List[int] = None - self.value: Optional[List[int]] = None - self.kv_cache = None - self.lock_ref = 0 - self.last_access_time = time.monotonic() - - self.hit_count = 0 - - self.node_id = TreeNode.counter if node_id is None else node_id - TreeNode.counter += 1 - - @property - def evicted(self): - """Check if this node has been evicted""" - return self.value is None - - def __lt__(self, other: "TreeNode"): - return self.last_access_time < other.last_access_time - - -def _key_match_page_size1(key0: List, key1: List): - """Key match function especially for page_size=1""" - i = 0 - for k0, k1 in zip(key0, key1): - if k0 != k1: - break - i += 1 - return i - - -def _key_match_paged(key0: List, key1: List, page_size: int): - """Key match function for page_size>1""" - min_len = min(len(key0), len(key1)) - - i = 0 - while i < min_len: - if key0[i : i + page_size] != key1[i : i + page_size]: - break - i += page_size - - return i - - -class RadixCache: - """ - Manages Radix Cache for the running executor. - Note: Currently only support page_size=1. 
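This hunk removes and re-adds radix_cache.py with what appears to be identical content, consistent with the mixed-line-ending --fix=lf hook added above; the logic itself is unchanged. As a tiny self-contained illustration of the page_size=1 key matching the cache is built on (the token ids are made up):

def match_len(key0, key1):
    # Token-by-token longest common prefix, i.e. the page_size=1 match rule.
    n = 0
    for a, b in zip(key0, key1):
        if a != b:
            break
        n += 1
    return n

cached = [1, 2, 3, 4, 5]            # token ids already held along a tree path
incoming = [1, 2, 3, 9, 9, 9]       # token ids of a new request
print(match_len(cached, incoming))  # 3 -> reuse KV for 3 tokens, prefill the rest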
- """ - - def __init__( - self, - num_kv_heads: int, - head_dim: int, - num_layers: int, - dtype: mx.Dtype, - page_size: int = 1, - max_num_tokens: int = None, - ): - self.num_kv_heads = num_kv_heads - self.head_dim = head_dim - self.num_layers = num_layers - self.dtype = dtype - self.page_size = page_size - self.req_to_token: Dict[str, List[int]] = {} - if max_num_tokens is None: - self.max_num_tokens = 10000 - else: - self.max_num_tokens = max_num_tokens - - if self.page_size == 1: - self.key_match_fn = _key_match_page_size1 - self.get_child_key_fn = lambda key: key[0] - else: - self.key_match_fn = partial(_key_match_paged, page_size=page_size) - self.get_child_key_fn = lambda key: tuple(key[:page_size]) - self.reset() - - def reset(self): - """Reset function for the whole tree""" - self.root_node = TreeNode() - self.root_node.key = [] - self.root_node.value = [] - self.root_node.lock_ref = 1 - self.evictable_size_ = 0 - self.protected_size_ = 0 - self.req_to_token = {} - - def update_req_to_token(self, req_id: str, token_ids: List[int]): - """Update the req->tokens dict""" - value = self.req_to_token.get(req_id) - if value: - self.req_to_token[req_id] = self.req_to_token[req_id] + token_ids - else: - self.req_to_token[req_id] = token_ids - - def evict_request(self, req_id: str): - """Remove a single request. Used when request if finished or cached""" - del self.req_to_token[req_id] - - def match_prefix( - self, - key: List[int], - ) -> Tuple[mx.array, mx.array, int]: - """Find the matching prefix from the radix tree. - Args: - key: A list of token IDs to find a matching prefix. - Returns: - A tuple of (value, matched last node) - Note that this API can modify the internal state of the Radix tree. - The last node creates a new child if the prefix is shorter than - the last node's value. - """ - if len(key) == 0: - return ( - [], - self.root_node, - ) - - if self.page_size != 1: - page_aligned_len = len(key) // self.page_size * self.page_size - key = key[:page_aligned_len] - - value, last_node = self._match_prefix_helper(self.root_node, key) - return value, last_node - - def fetch_kv_cache(self, node: TreeNode): - """ - Get and concat kv cache from a tree node to the root. - """ - assert node != self.root_node, "should not fetch from the root node." 
- k_cache, v_cache = node.kv_cache.fetch() - node = node.parent - while node != self.root_node: - cur_k_cache, cur_v_cache = node.kv_cache.fetch() - k_cache = mx.concatenate([cur_k_cache, k_cache], axis=2) - v_cache = mx.concatenate([cur_v_cache, v_cache], axis=2) - node = node.parent - return k_cache, v_cache - - def insert(self, key: List, value, k_cache: mx.array, v_cache: mx.array): - """Insert a tree node.""" - if value is None: - value = list(key) - return self._insert_helper(self.root_node, key, value, k_cache, v_cache) - - def evict(self, num_tokens: int): - """Remove cached tokens until the total tokens stored is reduced by num_tokens""" - leaves = self._collect_leaves() - heapq.heapify(leaves) - - num_evicted = 0 - while num_evicted < num_tokens and len(leaves) > 0: - x = heapq.heappop(leaves) - - if x == self.root_node: - break - if x.lock_ref > 0: - continue - - # self.token_to_kv_pool_allocator.free(x.value) TODO - num_evicted += len(x.value) - self._delete_leaf(x) - - if len(x.parent.children) == 0: - heapq.heappush(leaves, x.parent) - - def pretty_print(self): - """Print the whole tree.""" - self._print_helper(self.root_node, 0) - print(f"#tokens: {self.total_size()}") - - def total_size(self): - """Get the total number of tokens stored in the tree.""" - return self._total_size_helper() - - def increase_lock_ref(self, node: TreeNode): - """Increase the lock reference by 1 from a node to the root.""" - delta = 0 - while node != self.root_node: - if node.lock_ref == 0: - self.evictable_size_ -= len(node.value) - self.protected_size_ += len(node.value) - delta -= len(node.value) - node.lock_ref += 1 - node = node.parent - return delta - - def decrease_lock_ref(self, node: TreeNode): - """decrease the lock reference by 1 from a node to the root.""" - delta = 0 - while node != self.root_node: - if node.lock_ref == 1: - self.evictable_size_ += len(node.value) - self.protected_size_ -= len(node.value) - delta += len(node.value) - if node.lock_ref > 0: - node.lock_ref -= 1 - node = node.parent - return delta - - def cache_finished_request(self, req: Request, k_cache: mx.array, v_cache: mx.array): - """Cache request when it finishes.""" - token_ids = self.req_to_token[req.request_id] - _, node = self.insert(key=token_ids, value=None, k_cache=k_cache, v_cache=v_cache) - self.decrease_lock_ref(node) - - if self.protected_size_ > self.max_num_tokens: - self.evict(self.protected_size_) - elif self.protected_size_ + self.evictable_size_ > self.max_num_tokens: - self.evict(self.protected_size_ + self.evictable_size_ - self.max_num_tokens) - - def cache_unfinished_request(self, req: Request, k_cache: mx.array, v_cache: mx.array): - """Cache request when it is unfinished.""" - token_ids = self.req_to_token[req.request_id] - _, node = self.insert(key=token_ids, value=None, k_cache=k_cache, v_cache=v_cache) - self.increase_lock_ref(node) - if self.protected_size_ > self.max_num_tokens: - self.evict(self.protected_size_) - elif self.protected_size_ + self.evictable_size_ > self.max_num_tokens: - self.evict(self.protected_size_ + self.evictable_size_ - self.max_num_tokens) - - def _match_prefix_helper(self, node: TreeNode, key: List): - """Match prefix helper function""" - node.last_access_time = time.monotonic() - - child_key = self.get_child_key_fn(key) - - value = [] - while len(key) > 0 and child_key in node.children.keys(): - child = node.children[child_key] - child.last_access_time = time.monotonic() - prefix_len = self.key_match_fn(child.key, key) - if prefix_len >= len(child.key): - 
value += child.value - node = child - key = key[prefix_len:] - if len(key): - child_key = self.get_child_key_fn(key) - else: - new_node = self._split_node(child.key, child, prefix_len) - value += new_node.value - node = new_node - break - - return value, node - - def _split_node(self, key, child: TreeNode, split_len: int): - """Split a node for insertion. Note that this node be any nodes in the tree.""" - # new_node -> child - new_node = TreeNode() - new_node.children = {self.get_child_key_fn(key[split_len:]): child} - new_node.parent = child.parent - new_node.lock_ref = child.lock_ref - new_node.key = child.key[:split_len] - new_node.value = child.value[:split_len] - child.parent = new_node - child.key = child.key[split_len:] - child.value = child.value[split_len:] - new_node.parent.children[self.get_child_key_fn(key)] = new_node - - child_k_cache, child_v_cache = child.kv_cache.fetch() - # create kv cache for new_node - new_k_cache = child_k_cache[..., :split_len, :] - new_v_cache = child_v_cache[..., :split_len, :] - new_node.kv_cache = KVCache( - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - num_layers=self.num_layers, - dtype=self.dtype, - block_size=self.page_size, - num_initial_tokens=self.page_size, - ) - new_node.kv_cache.update(new_k_cache, new_v_cache) - # update kv cache for child - child_k_cache = child_k_cache[..., split_len:, :] - child_v_cache = child_v_cache[..., split_len:, :] - child.kv_cache = KVCache( - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - num_layers=self.num_layers, - dtype=self.dtype, - block_size=self.page_size, - num_initial_tokens=self.page_size, - ) - child.kv_cache.update(child_k_cache, child_v_cache) - - return new_node - - def _collect_leaves(self): - """Returns all the leaf nodes from the root""" - ret_list = [] - stack = [self.root_node] - - while stack: - cur_node = stack.pop() - if len(cur_node.children) == 0: - ret_list.append(cur_node) - else: - stack.extend(cur_node.children.values()) - - return ret_list - - def _insert_helper( - self, node: TreeNode, key: List, value: List, k_cache: mx.array, v_cache: mx.array - ): - """Insert key-value helper function""" - node.last_access_time = time.monotonic() - if len(key) == 0: - return 0 - - child_key = self.get_child_key_fn(key) - - total_prefix_length = 0 - while len(key) > 0 and child_key in node.children.keys(): - node = node.children[child_key] - node.last_access_time = time.monotonic() - prefix_len = self.key_match_fn(node.key, key) - total_prefix_length += prefix_len - key = key[prefix_len:] - value = value[prefix_len:] - - if prefix_len < len(node.key): - new_node = self._split_node(node.key, node, prefix_len) - node = new_node - - if len(key): - child_key = self.get_child_key_fn(key) - - if len(key): - new_node = TreeNode() - new_node.parent = node - new_node.key = key - new_node.value = value - node.children[child_key] = new_node - self.evictable_size_ += len(value) - - # create kvcache for new_node - kv_cache = KVCache( - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - num_layers=self.num_layers, - dtype=self.dtype, - block_size=self.page_size, - num_initial_tokens=self.page_size, - ) - k_cache = k_cache[..., total_prefix_length:, :] - v_cache = v_cache[..., total_prefix_length:, :] - kv_cache.update(k_cache, v_cache) - new_node.kv_cache = kv_cache - - node = new_node - - return total_prefix_length, node - - def _delete_leaf(self, node): - """Deletes a leaf node.""" - for k, v in node.parent.children.items(): - if v == node: - del 
node.parent.children[k] - break - self.evictable_size_ -= len(node.key) - - def _print_helper(self, node: TreeNode, indent: int): - """Prints the radix tree in a human-readable format.""" - stack = [(node, indent)] - while stack: - current_node, current_indent = stack.pop() - print( - " " * current_indent, - len(current_node.key), - current_node.key[:10], - f"r={current_node.lock_ref}", - current_node.kv_cache, - ) - for key, child in current_node.children.items(): - stack.append((child, current_indent + 2)) - - assert key == self.get_child_key_fn( - child.key - ), f"{key=}, {self.get_child_key_fn(child.key)=}" - - def _total_size_helper(self): - """Get total number of tokens stored helper function""" - total_size = 0 - stack = [self.root_node] - while stack: - current_node = stack.pop() - total_size += len(current_node.value) - for child in current_node.children.values(): - if child.evicted: - continue - stack.append(child) - return total_size +""" +Prefix Cache class for KV Cache reuse. +This module is implemented using radix tree, which retains the +same as SGLang. +""" + +import heapq +import time +from collections import defaultdict +from functools import partial +from typing import Dict, List, Optional, Tuple + +import mlx.core as mx + +from parallax.server.kv_cache import KVCache +from parallax.server.request import Request + + +class TreeNode: + """ + Radix tree node data structure. + Key: token id list. It should be an empty list for the root node. + Value: kv cache positions. + """ + + counter = 0 + + def __init__(self, node_id: Optional[int] = None): + self.children = defaultdict(TreeNode) + self.parent: TreeNode = None + self.key: List[int] = None + self.value: Optional[List[int]] = None + self.kv_cache = None + self.lock_ref = 0 + self.last_access_time = time.monotonic() + + self.hit_count = 0 + + self.node_id = TreeNode.counter if node_id is None else node_id + TreeNode.counter += 1 + + @property + def evicted(self): + """Check if this node has been evicted""" + return self.value is None + + def __lt__(self, other: "TreeNode"): + return self.last_access_time < other.last_access_time + + +def _key_match_page_size1(key0: List, key1: List): + """Key match function especially for page_size=1""" + i = 0 + for k0, k1 in zip(key0, key1): + if k0 != k1: + break + i += 1 + return i + + +def _key_match_paged(key0: List, key1: List, page_size: int): + """Key match function for page_size>1""" + min_len = min(len(key0), len(key1)) + + i = 0 + while i < min_len: + if key0[i : i + page_size] != key1[i : i + page_size]: + break + i += page_size + + return i + + +class RadixCache: + """ + Manages Radix Cache for the running executor. + Note: Currently only support page_size=1. 
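A minimal sketch of the LRU-style eviction idea used by RadixCache.evict: a heap over leaf nodes ordered by last access time, with locked (lock_ref > 0) leaves skipped. The real method also re-pushes a parent once all of its children are gone, which is omitted here.

import heapq

class Leaf:
    def __init__(self, name, num_tokens, last_access, locked=False):
        self.name = name
        self.num_tokens = num_tokens
        self.last_access = last_access
        self.locked = locked           # stands in for lock_ref > 0

    def __lt__(self, other):           # heap pops the least recently used first
        return self.last_access < other.last_access

leaves = [Leaf("a", 40, 1.0), Leaf("b", 25, 3.0), Leaf("c", 10, 2.0, locked=True)]
heapq.heapify(leaves)

need, freed = 50, 0
while leaves and freed < need:
    leaf = heapq.heappop(leaves)
    if leaf.locked:                    # protected leaves are skipped, not freed
        continue
    freed += leaf.num_tokens
print(freed)                           # 65: "a" then "b" evicted, "c" survives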
+ """ + + def __init__( + self, + num_kv_heads: int, + head_dim: int, + num_layers: int, + dtype: mx.Dtype, + page_size: int = 1, + max_num_tokens: int = None, + ): + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.num_layers = num_layers + self.dtype = dtype + self.page_size = page_size + self.req_to_token: Dict[str, List[int]] = {} + if max_num_tokens is None: + self.max_num_tokens = 10000 + else: + self.max_num_tokens = max_num_tokens + + if self.page_size == 1: + self.key_match_fn = _key_match_page_size1 + self.get_child_key_fn = lambda key: key[0] + else: + self.key_match_fn = partial(_key_match_paged, page_size=page_size) + self.get_child_key_fn = lambda key: tuple(key[:page_size]) + self.reset() + + def reset(self): + """Reset function for the whole tree""" + self.root_node = TreeNode() + self.root_node.key = [] + self.root_node.value = [] + self.root_node.lock_ref = 1 + self.evictable_size_ = 0 + self.protected_size_ = 0 + self.req_to_token = {} + + def update_req_to_token(self, req_id: str, token_ids: List[int]): + """Update the req->tokens dict""" + value = self.req_to_token.get(req_id) + if value: + self.req_to_token[req_id] = self.req_to_token[req_id] + token_ids + else: + self.req_to_token[req_id] = token_ids + + def evict_request(self, req_id: str): + """Remove a single request. Used when request if finished or cached""" + del self.req_to_token[req_id] + + def match_prefix( + self, + key: List[int], + ) -> Tuple[mx.array, mx.array, int]: + """Find the matching prefix from the radix tree. + Args: + key: A list of token IDs to find a matching prefix. + Returns: + A tuple of (value, matched last node) + Note that this API can modify the internal state of the Radix tree. + The last node creates a new child if the prefix is shorter than + the last node's value. + """ + if len(key) == 0: + return ( + [], + self.root_node, + ) + + if self.page_size != 1: + page_aligned_len = len(key) // self.page_size * self.page_size + key = key[:page_aligned_len] + + value, last_node = self._match_prefix_helper(self.root_node, key) + return value, last_node + + def fetch_kv_cache(self, node: TreeNode): + """ + Get and concat kv cache from a tree node to the root. + """ + assert node != self.root_node, "should not fetch from the root node." 
+ k_cache, v_cache = node.kv_cache.fetch() + node = node.parent + while node != self.root_node: + cur_k_cache, cur_v_cache = node.kv_cache.fetch() + k_cache = mx.concatenate([cur_k_cache, k_cache], axis=2) + v_cache = mx.concatenate([cur_v_cache, v_cache], axis=2) + node = node.parent + return k_cache, v_cache + + def insert(self, key: List, value, k_cache: mx.array, v_cache: mx.array): + """Insert a tree node.""" + if value is None: + value = list(key) + return self._insert_helper(self.root_node, key, value, k_cache, v_cache) + + def evict(self, num_tokens: int): + """Remove cached tokens until the total tokens stored is reduced by num_tokens""" + leaves = self._collect_leaves() + heapq.heapify(leaves) + + num_evicted = 0 + while num_evicted < num_tokens and len(leaves) > 0: + x = heapq.heappop(leaves) + + if x == self.root_node: + break + if x.lock_ref > 0: + continue + + # self.token_to_kv_pool_allocator.free(x.value) TODO + num_evicted += len(x.value) + self._delete_leaf(x) + + if len(x.parent.children) == 0: + heapq.heappush(leaves, x.parent) + + def pretty_print(self): + """Print the whole tree.""" + self._print_helper(self.root_node, 0) + print(f"#tokens: {self.total_size()}") + + def total_size(self): + """Get the total number of tokens stored in the tree.""" + return self._total_size_helper() + + def increase_lock_ref(self, node: TreeNode): + """Increase the lock reference by 1 from a node to the root.""" + delta = 0 + while node != self.root_node: + if node.lock_ref == 0: + self.evictable_size_ -= len(node.value) + self.protected_size_ += len(node.value) + delta -= len(node.value) + node.lock_ref += 1 + node = node.parent + return delta + + def decrease_lock_ref(self, node: TreeNode): + """decrease the lock reference by 1 from a node to the root.""" + delta = 0 + while node != self.root_node: + if node.lock_ref == 1: + self.evictable_size_ += len(node.value) + self.protected_size_ -= len(node.value) + delta += len(node.value) + if node.lock_ref > 0: + node.lock_ref -= 1 + node = node.parent + return delta + + def cache_finished_request(self, req: Request, k_cache: mx.array, v_cache: mx.array): + """Cache request when it finishes.""" + token_ids = self.req_to_token[req.request_id] + _, node = self.insert(key=token_ids, value=None, k_cache=k_cache, v_cache=v_cache) + self.decrease_lock_ref(node) + + if self.protected_size_ > self.max_num_tokens: + self.evict(self.protected_size_) + elif self.protected_size_ + self.evictable_size_ > self.max_num_tokens: + self.evict(self.protected_size_ + self.evictable_size_ - self.max_num_tokens) + + def cache_unfinished_request(self, req: Request, k_cache: mx.array, v_cache: mx.array): + """Cache request when it is unfinished.""" + token_ids = self.req_to_token[req.request_id] + _, node = self.insert(key=token_ids, value=None, k_cache=k_cache, v_cache=v_cache) + self.increase_lock_ref(node) + if self.protected_size_ > self.max_num_tokens: + self.evict(self.protected_size_) + elif self.protected_size_ + self.evictable_size_ > self.max_num_tokens: + self.evict(self.protected_size_ + self.evictable_size_ - self.max_num_tokens) + + def _match_prefix_helper(self, node: TreeNode, key: List): + """Match prefix helper function""" + node.last_access_time = time.monotonic() + + child_key = self.get_child_key_fn(key) + + value = [] + while len(key) > 0 and child_key in node.children.keys(): + child = node.children[child_key] + child.last_access_time = time.monotonic() + prefix_len = self.key_match_fn(child.key, key) + if prefix_len >= len(child.key): + 
value += child.value + node = child + key = key[prefix_len:] + if len(key): + child_key = self.get_child_key_fn(key) + else: + new_node = self._split_node(child.key, child, prefix_len) + value += new_node.value + node = new_node + break + + return value, node + + def _split_node(self, key, child: TreeNode, split_len: int): + """Split a node for insertion. Note that this node be any nodes in the tree.""" + # new_node -> child + new_node = TreeNode() + new_node.children = {self.get_child_key_fn(key[split_len:]): child} + new_node.parent = child.parent + new_node.lock_ref = child.lock_ref + new_node.key = child.key[:split_len] + new_node.value = child.value[:split_len] + child.parent = new_node + child.key = child.key[split_len:] + child.value = child.value[split_len:] + new_node.parent.children[self.get_child_key_fn(key)] = new_node + + child_k_cache, child_v_cache = child.kv_cache.fetch() + # create kv cache for new_node + new_k_cache = child_k_cache[..., :split_len, :] + new_v_cache = child_v_cache[..., :split_len, :] + new_node.kv_cache = KVCache( + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + num_layers=self.num_layers, + dtype=self.dtype, + block_size=self.page_size, + num_initial_tokens=self.page_size, + ) + new_node.kv_cache.update(new_k_cache, new_v_cache) + # update kv cache for child + child_k_cache = child_k_cache[..., split_len:, :] + child_v_cache = child_v_cache[..., split_len:, :] + child.kv_cache = KVCache( + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + num_layers=self.num_layers, + dtype=self.dtype, + block_size=self.page_size, + num_initial_tokens=self.page_size, + ) + child.kv_cache.update(child_k_cache, child_v_cache) + + return new_node + + def _collect_leaves(self): + """Returns all the leaf nodes from the root""" + ret_list = [] + stack = [self.root_node] + + while stack: + cur_node = stack.pop() + if len(cur_node.children) == 0: + ret_list.append(cur_node) + else: + stack.extend(cur_node.children.values()) + + return ret_list + + def _insert_helper( + self, node: TreeNode, key: List, value: List, k_cache: mx.array, v_cache: mx.array + ): + """Insert key-value helper function""" + node.last_access_time = time.monotonic() + if len(key) == 0: + return 0 + + child_key = self.get_child_key_fn(key) + + total_prefix_length = 0 + while len(key) > 0 and child_key in node.children.keys(): + node = node.children[child_key] + node.last_access_time = time.monotonic() + prefix_len = self.key_match_fn(node.key, key) + total_prefix_length += prefix_len + key = key[prefix_len:] + value = value[prefix_len:] + + if prefix_len < len(node.key): + new_node = self._split_node(node.key, node, prefix_len) + node = new_node + + if len(key): + child_key = self.get_child_key_fn(key) + + if len(key): + new_node = TreeNode() + new_node.parent = node + new_node.key = key + new_node.value = value + node.children[child_key] = new_node + self.evictable_size_ += len(value) + + # create kvcache for new_node + kv_cache = KVCache( + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + num_layers=self.num_layers, + dtype=self.dtype, + block_size=self.page_size, + num_initial_tokens=self.page_size, + ) + k_cache = k_cache[..., total_prefix_length:, :] + v_cache = v_cache[..., total_prefix_length:, :] + kv_cache.update(k_cache, v_cache) + new_node.kv_cache = kv_cache + + node = new_node + + return total_prefix_length, node + + def _delete_leaf(self, node): + """Deletes a leaf node.""" + for k, v in node.parent.children.items(): + if v == node: + del 
node.parent.children[k] + break + self.evictable_size_ -= len(node.key) + + def _print_helper(self, node: TreeNode, indent: int): + """Prints the radix tree in a human-readable format.""" + stack = [(node, indent)] + while stack: + current_node, current_indent = stack.pop() + print( + " " * current_indent, + len(current_node.key), + current_node.key[:10], + f"r={current_node.lock_ref}", + current_node.kv_cache, + ) + for key, child in current_node.children.items(): + stack.append((child, current_indent + 2)) + + assert key == self.get_child_key_fn( + child.key + ), f"{key=}, {self.get_child_key_fn(child.key)=}" + + def _total_size_helper(self): + """Get total number of tokens stored helper function""" + total_size = 0 + stack = [self.root_node] + while stack: + current_node = stack.pop() + total_size += len(current_node.value) + for child in current_node.children.values(): + if child.evicted: + continue + stack.append(child) + return total_size diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 364300c2..cd8f3767 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -1,223 +1,223 @@ -""" -Store information about a SGLang batch. -The following is the flow of data structures for a batch in SGLang: - -ScheduleBatch -> ModelWorkerBatch -> ForwardBatch -""" - -from types import SimpleNamespace -from typing import List - -import torch -from sglang.srt.managers.schedule_batch import Req, ScheduleBatch -from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.sampling.sampling_batch_info import ( - SamplingBatchInfo as SGLSamplingBatchInfo, -) -from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams -from sglang.srt.speculative.spec_info import SpeculativeAlgorithm - -from parallax.server.request import Request -from parallax.server.sampling.sampling_params import ( - SamplingParams as ParallaxSamplingParams, -) -from parallax_utils.logging_config import get_logger - -logger = get_logger(__name__) - - -def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> SGLSamplingParams: - """Transforms Parallax SamplingParams to SGLang.SamplingParams format""" - params = SGLSamplingParams( - max_new_tokens=old_params.max_new_tokens, - min_new_tokens=old_params.min_new_tokens, - temperature=old_params.temperature, - top_p=old_params.top_p, - min_p=old_params.min_p, - top_k=old_params.top_k, - stop_token_ids=old_params.stop_token_ids, - ignore_eos=old_params.ignore_eos, - stop=old_params.stop_strs, - repetition_penalty=old_params.repetition_penalty, - presence_penalty=old_params.presence_penalty, - json_schema=old_params.json_schema, - ) - return params - - -def transform_requests_to_sglang(old_requests: List[Request]) -> List[Req]: - """Transforms Parallax Request to SGLang.Req format""" - reqs = [] - for old_req in old_requests: - sampling_params = transform_sampling_params_to_sglang(old_req.sampling_params) - req = Req( - rid=old_req.request_id, - origin_input_text="", - origin_input_ids=old_req.input_ids, - sampling_params=sampling_params, - ) - req.init_next_round_input() - reqs.append(req) - return reqs - - -def form_sgl_batch_prefill( - requests: List[Request], - model_runner: ModelRunner, -) -> ForwardBatch: - """Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow""" - sgl_reqs = transform_requests_to_sglang(requests) - - def dummy_evict(*args): - pass - - 
dummy_tree_cache = SimpleNamespace( - page_size=model_runner.server_args.page_size, - device=model_runner.device, - token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, - evictable_size=0, - ) - dummy_tree_cache.evict = dummy_evict - schedule_batch = ScheduleBatch.init_new( - reqs=sgl_reqs, - req_to_token_pool=model_runner.req_to_token_pool, - token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, - tree_cache=dummy_tree_cache, - model_config=model_runner.model_config, - enable_overlap=False, - spec_algorithm=SpeculativeAlgorithm.NONE, - ) - schedule_batch.prepare_for_extend() - model_worker_batch = schedule_batch.get_model_worker_batch() - forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) - return schedule_batch, forward_batch - - -def select_batch( - origin_batch: ScheduleBatch, - keep_indices: List[int], -) -> ScheduleBatch: - """ - Copy a subset of requests to form a new ScheduleBatch from the running ScheduleBatch. - Since the requests are not necessary selected in the loop, we need to copy by indicies to select - the real requests to run. - """ - ret = origin_batch.copy() - if keep_indices is None or len(keep_indices) == 0: - return None - - keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to( - origin_batch.device, non_blocking=True - ) - - ret.token_to_kv_pool_allocator = origin_batch.token_to_kv_pool_allocator - ret.req_to_token_pool = origin_batch.req_to_token_pool - ret.tree_cache = origin_batch.tree_cache - - if origin_batch.model_config.is_encoder_decoder: - ret.encoder_lens = origin_batch.encoder_lens[keep_indices_device] - ret.encoder_lens_cpu = [origin_batch.encoder_lens_cpu[i] for i in keep_indices] - - ret.reqs = [origin_batch.reqs[i] for i in keep_indices] - if origin_batch.multimodal_inputs is not None: - ret.multimodal_inputs = [origin_batch.multimodal_inputs[i] for i in keep_indices] - ret.seq_lens_cpu = origin_batch.seq_lens_cpu[keep_indices] - ret.req_pool_indices = origin_batch.req_pool_indices[keep_indices_device] - ret.seq_lens = origin_batch.seq_lens[keep_indices_device] - ret.orig_seq_lens = origin_batch.orig_seq_lens[keep_indices_device] - - if origin_batch.out_cache_loc is not None: - ret.out_cache_loc = origin_batch.out_cache_loc[keep_indices_device] - ret.seq_lens_sum = ret.seq_lens.sum().item() - - if origin_batch.output_ids is not None: - ret.output_ids = origin_batch.output_ids[keep_indices_device] - - ret.return_logprob = any(req.return_logprob for req in origin_batch.reqs) - if ret.return_logprob: - ret.top_logprobs_nums = [origin_batch.top_logprobs_nums[i] for i in keep_indices] - ret.token_ids_logprobs = [origin_batch.token_ids_logprobs[i] for i in keep_indices] - else: - ret.top_logprobs_nums = None - ret.token_ids_logprobs = None - - ret.has_stream = any(req.stream for req in origin_batch.reqs) - ret.has_grammar = any(req.grammar for req in origin_batch.reqs) - - ret.sampling_info = SGLSamplingBatchInfo.from_schedule_batch( - ret, origin_batch.model_config.vocab_size - ) - - return ret - - -def find_index(running_batch: ScheduleBatch, request_id: str): - """Helper function for finding the requests in the running batch by request_id""" - for index, req in enumerate(running_batch.reqs): - if req.rid == request_id: - return index - logger.exception( - f"Request {request_id} not found in running batch, size: {len(running_batch.reqs)}, \ - reqs: {[request.rid for request in running_batch.reqs]}" - ) - return -1 - - -def form_sgl_batch_decode( - requests: List[Request], - model_runner: 
ModelRunner, - running_batch: ScheduleBatch, - is_first_rank: bool, -) -> ForwardBatch: - """ - Forms the decoding batch in this round. - The returned ScheduleBatch is a copy of subset of the running batch. - ModelWorkerBatch -> ForwardBatch are generated from the selected ScheduleBatch. - """ - ready_indices = list( - filter(lambda x: x != -1, [find_index(running_batch, req.request_id) for req in requests]) - ) - ret = select_batch(running_batch, ready_indices) - if is_first_rank: - output_ids = [] - for request in requests: - output_ids.append(request.output_ids[-1]) - ret.output_ids = torch.tensor(output_ids, dtype=torch.int64).to( - ret.device, non_blocking=True - ) - else: - # Set an empty output_ids tensor - batch_size = len(ready_indices) - ret.output_ids = torch.empty(batch_size, dtype=torch.int64).to( - ret.device, non_blocking=True - ) - ret.prepare_for_decode() - # TODO: this is a hack to make the seq_lens correct due to select_batch is not refference running batch's seq_lens - # need to fix this - running_batch.seq_lens[ready_indices] += 1 - running_batch.seq_lens_cpu[ready_indices] += 1 - running_batch.orig_seq_lens[ready_indices] += 1 - - model_worker_batch = ret.get_model_worker_batch() - forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) - - return forward_batch - - -def release_sglang_request(running_batch: ScheduleBatch, request_id: str): - """Release KV Cache and other resources for finished/aborted requests.""" - if running_batch is None or running_batch.is_empty(): - return - seq_lens_cpu = running_batch.seq_lens.cpu().numpy() - idx = find_index(running_batch, request_id) - req = running_batch.reqs.pop(idx) - - # Free kv cache - page_size = running_batch.token_to_kv_pool_allocator.page_size - last_uncached_pos = (len(req.prefix_indices) // page_size) * page_size - token_indices = running_batch.req_to_token_pool.req_to_token[ - req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] - ] - running_batch.token_to_kv_pool_allocator.free(token_indices) - running_batch.req_to_token_pool.free(req.req_pool_idx) +""" +Store information about a SGLang batch. 
+The following is the flow of data structures for a batch in SGLang: + +ScheduleBatch -> ModelWorkerBatch -> ForwardBatch +""" + +from types import SimpleNamespace +from typing import List + +import torch +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.sampling.sampling_batch_info import ( + SamplingBatchInfo as SGLSamplingBatchInfo, +) +from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm + +from parallax.server.request import Request +from parallax.server.sampling.sampling_params import ( + SamplingParams as ParallaxSamplingParams, +) +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> SGLSamplingParams: + """Transforms Parallax SamplingParams to SGLang.SamplingParams format""" + params = SGLSamplingParams( + max_new_tokens=old_params.max_new_tokens, + min_new_tokens=old_params.min_new_tokens, + temperature=old_params.temperature, + top_p=old_params.top_p, + min_p=old_params.min_p, + top_k=old_params.top_k, + stop_token_ids=old_params.stop_token_ids, + ignore_eos=old_params.ignore_eos, + stop=old_params.stop_strs, + repetition_penalty=old_params.repetition_penalty, + presence_penalty=old_params.presence_penalty, + json_schema=old_params.json_schema, + ) + return params + + +def transform_requests_to_sglang(old_requests: List[Request]) -> List[Req]: + """Transforms Parallax Request to SGLang.Req format""" + reqs = [] + for old_req in old_requests: + sampling_params = transform_sampling_params_to_sglang(old_req.sampling_params) + req = Req( + rid=old_req.request_id, + origin_input_text="", + origin_input_ids=old_req.input_ids, + sampling_params=sampling_params, + ) + req.init_next_round_input() + reqs.append(req) + return reqs + + +def form_sgl_batch_prefill( + requests: List[Request], + model_runner: ModelRunner, +) -> ForwardBatch: + """Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow""" + sgl_reqs = transform_requests_to_sglang(requests) + + def dummy_evict(*args): + pass + + dummy_tree_cache = SimpleNamespace( + page_size=model_runner.server_args.page_size, + device=model_runner.device, + token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, + evictable_size=0, + ) + dummy_tree_cache.evict = dummy_evict + schedule_batch = ScheduleBatch.init_new( + reqs=sgl_reqs, + req_to_token_pool=model_runner.req_to_token_pool, + token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, + tree_cache=dummy_tree_cache, + model_config=model_runner.model_config, + enable_overlap=False, + spec_algorithm=SpeculativeAlgorithm.NONE, + ) + schedule_batch.prepare_for_extend() + model_worker_batch = schedule_batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) + return schedule_batch, forward_batch + + +def select_batch( + origin_batch: ScheduleBatch, + keep_indices: List[int], +) -> ScheduleBatch: + """ + Copy a subset of requests to form a new ScheduleBatch from the running ScheduleBatch. + Since the requests are not necessary selected in the loop, we need to copy by indicies to select + the real requests to run. 
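A small standalone example of the index-based selection select_batch performs: tensor state is gathered with an index tensor while Python lists are sliced with the same indices. The values are illustrative only.

import torch

seq_lens = torch.tensor([17, 33, 8, 41])          # per-request state in the batch
output_ids = torch.tensor([101, 102, 103, 104])
reqs = ["req-a", "req-b", "req-c", "req-d"]

keep_indices = [0, 2]                             # requests selected this step
idx = torch.tensor(keep_indices, dtype=torch.int64)

print(seq_lens[idx].tolist())                     # [17, 8]
print(output_ids[idx].tolist())                   # [101, 103]
print([reqs[i] for i in keep_indices])            # ['req-a', 'req-c']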
+ """ + ret = origin_batch.copy() + if keep_indices is None or len(keep_indices) == 0: + return None + + keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to( + origin_batch.device, non_blocking=True + ) + + ret.token_to_kv_pool_allocator = origin_batch.token_to_kv_pool_allocator + ret.req_to_token_pool = origin_batch.req_to_token_pool + ret.tree_cache = origin_batch.tree_cache + + if origin_batch.model_config.is_encoder_decoder: + ret.encoder_lens = origin_batch.encoder_lens[keep_indices_device] + ret.encoder_lens_cpu = [origin_batch.encoder_lens_cpu[i] for i in keep_indices] + + ret.reqs = [origin_batch.reqs[i] for i in keep_indices] + if origin_batch.multimodal_inputs is not None: + ret.multimodal_inputs = [origin_batch.multimodal_inputs[i] for i in keep_indices] + ret.seq_lens_cpu = origin_batch.seq_lens_cpu[keep_indices] + ret.req_pool_indices = origin_batch.req_pool_indices[keep_indices_device] + ret.seq_lens = origin_batch.seq_lens[keep_indices_device] + ret.orig_seq_lens = origin_batch.orig_seq_lens[keep_indices_device] + + if origin_batch.out_cache_loc is not None: + ret.out_cache_loc = origin_batch.out_cache_loc[keep_indices_device] + ret.seq_lens_sum = ret.seq_lens.sum().item() + + if origin_batch.output_ids is not None: + ret.output_ids = origin_batch.output_ids[keep_indices_device] + + ret.return_logprob = any(req.return_logprob for req in origin_batch.reqs) + if ret.return_logprob: + ret.top_logprobs_nums = [origin_batch.top_logprobs_nums[i] for i in keep_indices] + ret.token_ids_logprobs = [origin_batch.token_ids_logprobs[i] for i in keep_indices] + else: + ret.top_logprobs_nums = None + ret.token_ids_logprobs = None + + ret.has_stream = any(req.stream for req in origin_batch.reqs) + ret.has_grammar = any(req.grammar for req in origin_batch.reqs) + + ret.sampling_info = SGLSamplingBatchInfo.from_schedule_batch( + ret, origin_batch.model_config.vocab_size + ) + + return ret + + +def find_index(running_batch: ScheduleBatch, request_id: str): + """Helper function for finding the requests in the running batch by request_id""" + for index, req in enumerate(running_batch.reqs): + if req.rid == request_id: + return index + logger.exception( + f"Request {request_id} not found in running batch, size: {len(running_batch.reqs)}, \ + reqs: {[request.rid for request in running_batch.reqs]}" + ) + return -1 + + +def form_sgl_batch_decode( + requests: List[Request], + model_runner: ModelRunner, + running_batch: ScheduleBatch, + is_first_rank: bool, +) -> ForwardBatch: + """ + Forms the decoding batch in this round. + The returned ScheduleBatch is a copy of subset of the running batch. + ModelWorkerBatch -> ForwardBatch are generated from the selected ScheduleBatch. 
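A quick illustration of the ready_indices filtering this function applies before building the decode batch; the request ids are made up and the lookup is simplified (the real find_index also logs missing requests).

def find_pos(running_ids, request_id):
    # -1 means the request has no entry in the running batch.
    for i, rid in enumerate(running_ids):
        if rid == request_id:
            return i
    return -1

running = ["req-a", "req-b", "req-c"]             # requests with live KV state
incoming = ["req-c", "req-x", "req-a"]            # this round's decode requests
ready = [p for p in (find_pos(running, r) for r in incoming) if p != -1]
print(ready)                                      # [2, 0]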
+ """ + ready_indices = list( + filter(lambda x: x != -1, [find_index(running_batch, req.request_id) for req in requests]) + ) + ret = select_batch(running_batch, ready_indices) + if is_first_rank: + output_ids = [] + for request in requests: + output_ids.append(request.output_ids[-1]) + ret.output_ids = torch.tensor(output_ids, dtype=torch.int64).to( + ret.device, non_blocking=True + ) + else: + # Set an empty output_ids tensor + batch_size = len(ready_indices) + ret.output_ids = torch.empty(batch_size, dtype=torch.int64).to( + ret.device, non_blocking=True + ) + ret.prepare_for_decode() + # TODO: this is a hack to make the seq_lens correct due to select_batch is not refference running batch's seq_lens + # need to fix this + running_batch.seq_lens[ready_indices] += 1 + running_batch.seq_lens_cpu[ready_indices] += 1 + running_batch.orig_seq_lens[ready_indices] += 1 + + model_worker_batch = ret.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) + + return forward_batch + + +def release_sglang_request(running_batch: ScheduleBatch, request_id: str): + """Release KV Cache and other resources for finished/aborted requests.""" + if running_batch is None or running_batch.is_empty(): + return + seq_lens_cpu = running_batch.seq_lens.cpu().numpy() + idx = find_index(running_batch, request_id) + req = running_batch.reqs.pop(idx) + + # Free kv cache + page_size = running_batch.token_to_kv_pool_allocator.page_size + last_uncached_pos = (len(req.prefix_indices) // page_size) * page_size + token_indices = running_batch.req_to_token_pool.req_to_token[ + req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] + ] + running_batch.token_to_kv_pool_allocator.free(token_indices) + running_batch.req_to_token_pool.free(req.req_pool_idx) diff --git a/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py b/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py index acc26458..447484e6 100644 --- a/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py +++ b/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py @@ -1,192 +1,192 @@ -## This is a patch file for sglang GPT-OSS model to support loading mxFP4 MoE experts weights - -import math - -import torch -from sglang.srt.distributed import ( - get_moe_expert_parallel_rank, - get_moe_expert_parallel_world_size, - get_moe_tensor_parallel_rank, - get_moe_tensor_parallel_world_size, -) -from sglang.srt.layers.utils import get_layer_id -from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.models.gpt_oss import GptOssForCausalLM - - -def _parallax_load_mxfp4_experts_weights(self, weights): - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - mxfp4_block = 32 - - moe_tp_rank = get_moe_tensor_parallel_rank() - moe_tp_size = get_moe_tensor_parallel_world_size() - moe_ep_rank = get_moe_expert_parallel_rank() - moe_ep_size = get_moe_expert_parallel_world_size() - - intermediate_size = self.config.intermediate_size - assert ( - intermediate_size % mxfp4_block == 0 - ), f"{intermediate_size=} must be divisible by {mxfp4_block=}" - intermediate_size_block = intermediate_size // mxfp4_block - - per_rank_intermediate_size_block = math.ceil(intermediate_size_block / moe_tp_size) - - per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block - - # Calculate common slicing bounds for current rank - assert self.config.num_local_experts % moe_ep_size == 0 - moe_num_global_experts = self.config.num_local_experts - moe_num_local_experts = 
self.config.num_local_experts // moe_ep_size - - moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size - moe_tp_rank_end = min((moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size) - - moe_ep_rank_start = moe_ep_rank * moe_num_local_experts - moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts - - for name, weight in weights: - ############################################################################ - ## TODO: remove when sglang code support pipeline parallelism - ## This is a patch code for sgalng - layer_id = get_layer_id(name) - if ( - layer_id is not None - and hasattr(self.model, "start_layer") - and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer) - ): - continue - ## End of patch - ############################################################################ - weight = weight.cuda() - - if "gate_up_proj_blocks" in name: - # Handle MLP gate and up projection weights - new_name = name.replace("gate_up_proj_blocks", "w13_weight") - - # flat weight from (E, 2 * N, block_size, entry_per_block) - # to (E, 2 * N, -1), shouldn't trigger copy for contiguous - weight = weight.view(moe_num_global_experts, 2 * intermediate_size, -1).contiguous() - - narrow_weight = weight[ - moe_ep_rank_start:moe_ep_rank_end, - 2 * moe_tp_rank_start : 2 * moe_tp_rank_end, - ..., - ] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader( - param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None, - ) - loaded_params.add(new_name) - - elif "down_proj_blocks" in name: - # Handle MLP down projection weights - new_name = name.replace("down_proj_blocks", "w2_weight") - # same flatten here, but since 2 mx4 value are packed in 1 - # uint8, divide by 2 - weight = weight.view(moe_num_global_experts, -1, intermediate_size // 2).contiguous() - narrow_weight = weight[ - moe_ep_rank_start:moe_ep_rank_end, - ..., - moe_tp_rank_start // 2 : moe_tp_rank_end // 2, - ] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader( - param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None, - ) - loaded_params.add(new_name) - - elif "gate_up_proj_scales" in name: - # Handle MLP gate and up projection weights scale - new_name = name.replace("gate_up_proj_scales", "w13_weight_scale") - narrow_weight = weight[ - moe_ep_rank_start:moe_ep_rank_end, - 2 * moe_tp_rank_start : 2 * moe_tp_rank_end, - ..., - ] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader( - param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None, - ) - loaded_params.add(new_name) - - elif "down_proj_scales" in name: - # Handle MLP down projection weights - new_name = name.replace("down_proj_scales", "w2_weight_scale") - narrow_weight = weight[ - moe_ep_rank_start:moe_ep_rank_end, - ..., - moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block, - ] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader( - param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None, - ) - loaded_params.add(new_name) - elif "gate_up_proj_bias" in name: - # Handle MLP gate and up projection biases - new_name = name.replace("gate_up_proj_bias", "w13_weight_bias") - - narrow_weight = weight[ - moe_ep_rank_start:moe_ep_rank_end, - 2 * moe_tp_rank_start : 2 * 
moe_tp_rank_end, - ] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader( - param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None, - ) - loaded_params.add(new_name) - - elif "down_proj_bias" in name: - narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...] - if moe_tp_rank != 0: - narrow_weight = torch.zeros_like(narrow_weight) - - # Handle MLP down projection bias - new_name = name.replace("down_proj_bias", "w2_weight_bias") - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader( - param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None, - ) - loaded_params.add(new_name) - - return loaded_params - - -def apply_gpt_oss_monkey_patch(): - GptOssForCausalLM._load_mxfp4_experts_weights = _parallax_load_mxfp4_experts_weights +## This is a patch file for sglang GPT-OSS model to support loading mxFP4 MoE experts weights + +import math + +import torch +from sglang.srt.distributed import ( + get_moe_expert_parallel_rank, + get_moe_expert_parallel_world_size, + get_moe_tensor_parallel_rank, + get_moe_tensor_parallel_world_size, +) +from sglang.srt.layers.utils import get_layer_id +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.gpt_oss import GptOssForCausalLM + + +def _parallax_load_mxfp4_experts_weights(self, weights): + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + mxfp4_block = 32 + + moe_tp_rank = get_moe_tensor_parallel_rank() + moe_tp_size = get_moe_tensor_parallel_world_size() + moe_ep_rank = get_moe_expert_parallel_rank() + moe_ep_size = get_moe_expert_parallel_world_size() + + intermediate_size = self.config.intermediate_size + assert ( + intermediate_size % mxfp4_block == 0 + ), f"{intermediate_size=} must be divisible by {mxfp4_block=}" + intermediate_size_block = intermediate_size // mxfp4_block + + per_rank_intermediate_size_block = math.ceil(intermediate_size_block / moe_tp_size) + + per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block + + # Calculate common slicing bounds for current rank + assert self.config.num_local_experts % moe_ep_size == 0 + moe_num_global_experts = self.config.num_local_experts + moe_num_local_experts = self.config.num_local_experts // moe_ep_size + + moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size + moe_tp_rank_end = min((moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size) + + moe_ep_rank_start = moe_ep_rank * moe_num_local_experts + moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts + + for name, weight in weights: + ############################################################################ + ## TODO: remove when sglang code support pipeline parallelism + ## This is a patch code for sgalng + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer) + ): + continue + ## End of patch + ############################################################################ + weight = weight.cuda() + + if "gate_up_proj_blocks" in name: + # Handle MLP gate and up projection weights + new_name = name.replace("gate_up_proj_blocks", "w13_weight") + + # flat weight from (E, 2 * N, block_size, entry_per_block) + # to (E, 2 * N, -1), shouldn't trigger copy for contiguous + weight = weight.view(moe_num_global_experts, 2 * 
intermediate_size, -1).contiguous() + + narrow_weight = weight[ + moe_ep_rank_start:moe_ep_rank_end, + 2 * moe_tp_rank_start : 2 * moe_tp_rank_end, + ..., + ] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, + ) + loaded_params.add(new_name) + + elif "down_proj_blocks" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_blocks", "w2_weight") + # same flatten here, but since 2 mx4 value are packed in 1 + # uint8, divide by 2 + weight = weight.view(moe_num_global_experts, -1, intermediate_size // 2).contiguous() + narrow_weight = weight[ + moe_ep_rank_start:moe_ep_rank_end, + ..., + moe_tp_rank_start // 2 : moe_tp_rank_end // 2, + ] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, + ) + loaded_params.add(new_name) + + elif "gate_up_proj_scales" in name: + # Handle MLP gate and up projection weights scale + new_name = name.replace("gate_up_proj_scales", "w13_weight_scale") + narrow_weight = weight[ + moe_ep_rank_start:moe_ep_rank_end, + 2 * moe_tp_rank_start : 2 * moe_tp_rank_end, + ..., + ] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, + ) + loaded_params.add(new_name) + + elif "down_proj_scales" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_scales", "w2_weight_scale") + narrow_weight = weight[ + moe_ep_rank_start:moe_ep_rank_end, + ..., + moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block, + ] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, + ) + loaded_params.add(new_name) + elif "gate_up_proj_bias" in name: + # Handle MLP gate and up projection biases + new_name = name.replace("gate_up_proj_bias", "w13_weight_bias") + + narrow_weight = weight[ + moe_ep_rank_start:moe_ep_rank_end, + 2 * moe_tp_rank_start : 2 * moe_tp_rank_end, + ] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, + ) + loaded_params.add(new_name) + + elif "down_proj_bias" in name: + narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...] 
+ if moe_tp_rank != 0: + narrow_weight = torch.zeros_like(narrow_weight) + + # Handle MLP down projection bias + new_name = name.replace("down_proj_bias", "w2_weight_bias") + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, + ) + loaded_params.add(new_name) + + return loaded_params + + +def apply_gpt_oss_monkey_patch(): + GptOssForCausalLM._load_mxfp4_experts_weights = _parallax_load_mxfp4_experts_weights diff --git a/src/parallax/sglang/monkey_patch_utils/triton_backend.py b/src/parallax/sglang/monkey_patch_utils/triton_backend.py index cfd08074..da7e2169 100644 --- a/src/parallax/sglang/monkey_patch_utils/triton_backend.py +++ b/src/parallax/sglang/monkey_patch_utils/triton_backend.py @@ -1,111 +1,111 @@ -from typing import Optional - -import torch -from sglang.srt.layers.attention.triton_backend import TritonAttnBackend -from sglang.srt.layers.dp_attention import get_attention_tp_size -from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.utils import get_bool_env_var, get_device_core_count, get_int_env_var - - -def parallax_triton_backend_init( - self, - model_runner: ModelRunner, - skip_prefill: bool = False, - kv_indptr_buf: Optional[torch.Tensor] = None, -): - # Lazy import to avoid the initialization of cuda context - from sglang.srt.layers.attention.triton_ops.decode_attention import ( - decode_attention_fwd, - ) - from sglang.srt.layers.attention.triton_ops.extend_attention import ( - extend_attention_fwd, - ) - - self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd) - self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd) - - # Parse args - self.skip_prefill = skip_prefill - max_bs = model_runner.req_to_token_pool.size - self.sliding_window_size = model_runner.sliding_window_size - self.req_to_token = model_runner.req_to_token_pool.req_to_token - self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator - self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens - self.speculative_num_steps = model_runner.server_args.speculative_num_steps - self.num_head = model_runner.model_config.num_attention_heads // get_attention_tp_size() - self.num_kv_head = model_runner.model_config.get_num_kv_heads(get_attention_tp_size()) - # Modifies layer id to support pipeline parallel - if model_runner.hybrid_gdn_config is not None: - # For hybrid linear models, layer_id = 0 may not be full attention - self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() - else: - - ################################################################################ - ## Patch for PP: get pp_start_layer - self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer( - model_runner.pp_start_layer - ).shape[-1] - ## End of patch - ################################################################################ - self.max_context_len = model_runner.model_config.context_len - self.device = model_runner.device - self.device_core_count = get_device_core_count(model_runner.gpu_id) - self.static_kv_splits = get_bool_env_var("SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false") - self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits - - # Decide whether enable deterministic inference with batch-invariant operations - self.enable_deterministic = model_runner.server_args.enable_deterministic_inference - - # Configure deterministic inference 
settings - if self.enable_deterministic: - # Use fixed split tile size for batch invariance - self.split_tile_size = get_int_env_var("SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256) - # Set static_kv_splits to False to use deterministic logic instead - self.static_kv_splits = False - else: - self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size - - if self.split_tile_size is not None: - self.max_kv_splits = ( - self.max_context_len + self.split_tile_size - 1 - ) // self.split_tile_size - # Check arguments - assert not ( - model_runner.sliding_window_size is not None - and model_runner.model_config.is_encoder_decoder - ), "Sliding window and cross attention are not supported together" - - # Initialize buffers - # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled - if kv_indptr_buf is None: - self.kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) - else: - self.kv_indptr = kv_indptr_buf - - # If sliding window is enabled, we might need two sets of buffers - # because of interleaved attention types (e.g. for Gemma3) - self.window_kv_indptr = None - if self.sliding_window_size is not None and self.sliding_window_size > 0: - if kv_indptr_buf is None: - self.window_kv_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) - else: - # When provided a buffer, create a clone for the second buffer - self.window_kv_indptr = torch.zeros_like(kv_indptr_buf) - - if not self.skip_prefill: - self.qo_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) - - self.mask_indptr = torch.zeros((max_bs + 1,), dtype=torch.int64, device=model_runner.device) - - # Initialize forward metadata - from sglang.srt.layers.attention.triton_backend import ForwardMetadata - - self.forward_metadata: ForwardMetadata = None - - self.cuda_graph_custom_mask = None - - -def apply_triton_backend_init_monkey_patch(): - TritonAttnBackend.__init__ = parallax_triton_backend_init +from typing import Optional + +import torch +from sglang.srt.layers.attention.triton_backend import TritonAttnBackend +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.utils import get_bool_env_var, get_device_core_count, get_int_env_var + + +def parallax_triton_backend_init( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, +): + # Lazy import to avoid the initialization of cuda context + from sglang.srt.layers.attention.triton_ops.decode_attention import ( + decode_attention_fwd, + ) + from sglang.srt.layers.attention.triton_ops.extend_attention import ( + extend_attention_fwd, + ) + + self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd) + self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd) + + # Parse args + self.skip_prefill = skip_prefill + max_bs = model_runner.req_to_token_pool.size + self.sliding_window_size = model_runner.sliding_window_size + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator + self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens + self.speculative_num_steps = model_runner.server_args.speculative_num_steps + self.num_head = model_runner.model_config.num_attention_heads // get_attention_tp_size() + self.num_kv_head = 
model_runner.model_config.get_num_kv_heads(get_attention_tp_size()) + # Modifies layer id to support pipeline parallel + if model_runner.hybrid_gdn_config is not None: + # For hybrid linear models, layer_id = 0 may not be full attention + self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() + else: + + ################################################################################ + ## Patch for PP: get pp_start_layer + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer( + model_runner.pp_start_layer + ).shape[-1] + ## End of patch + ################################################################################ + self.max_context_len = model_runner.model_config.context_len + self.device = model_runner.device + self.device_core_count = get_device_core_count(model_runner.gpu_id) + self.static_kv_splits = get_bool_env_var("SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false") + self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits + + # Decide whether enable deterministic inference with batch-invariant operations + self.enable_deterministic = model_runner.server_args.enable_deterministic_inference + + # Configure deterministic inference settings + if self.enable_deterministic: + # Use fixed split tile size for batch invariance + self.split_tile_size = get_int_env_var("SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256) + # Set static_kv_splits to False to use deterministic logic instead + self.static_kv_splits = False + else: + self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size + + if self.split_tile_size is not None: + self.max_kv_splits = ( + self.max_context_len + self.split_tile_size - 1 + ) // self.split_tile_size + # Check arguments + assert not ( + model_runner.sliding_window_size is not None + and model_runner.model_config.is_encoder_decoder + ), "Sliding window and cross attention are not supported together" + + # Initialize buffers + # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled + if kv_indptr_buf is None: + self.kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) + else: + self.kv_indptr = kv_indptr_buf + + # If sliding window is enabled, we might need two sets of buffers + # because of interleaved attention types (e.g. for Gemma3) + self.window_kv_indptr = None + if self.sliding_window_size is not None and self.sliding_window_size > 0: + if kv_indptr_buf is None: + self.window_kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + else: + # When provided a buffer, create a clone for the second buffer + self.window_kv_indptr = torch.zeros_like(kv_indptr_buf) + + if not self.skip_prefill: + self.qo_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) + + self.mask_indptr = torch.zeros((max_bs + 1,), dtype=torch.int64, device=model_runner.device) + + # Initialize forward metadata + from sglang.srt.layers.attention.triton_backend import ForwardMetadata + + self.forward_metadata: ForwardMetadata = None + + self.cuda_graph_custom_mask = None + + +def apply_triton_backend_init_monkey_patch(): + TritonAttnBackend.__init__ = parallax_triton_backend_init diff --git a/src/parallax/utils/tokenizer_utils.py b/src/parallax/utils/tokenizer_utils.py index faefef22..63bee300 100755 --- a/src/parallax/utils/tokenizer_utils.py +++ b/src/parallax/utils/tokenizer_utils.py @@ -1,125 +1,125 @@ -""" -Implements parallax detokenizers for performance. 
-""" - -import json -from functools import partial -from json import JSONDecodeError - -from mlx_lm.tokenizer_utils import ( - BPEStreamingDetokenizer, - NaiveStreamingDetokenizer, - SPMStreamingDetokenizer, - _is_bpe_decoder, - _is_spm_decoder, - _is_spm_decoder_no_space, -) -from mlx_lm.tokenizer_utils import load_tokenizer as _mlx_load_tokenizer - - -class ParallaxNaiveStreamingDetokenizer(NaiveStreamingDetokenizer): - """A custom BPE streaming detokenizer that add an argument 'tokenizer'""" - - def __init__(self, tokenizer, tokenmap): - self._tokenizer = tokenizer - self._tokenizer.decode([0]) - self.reset() - - -class ParallaxBPEStreamingDetokenizer(BPEStreamingDetokenizer): - """A custom BPE streaming detokenizer that skips initializing tokenmap""" - - def __init__(self, tokenizer, tokenmap): - self.clean_spaces = tokenizer.clean_up_tokenization_spaces - self.tokenmap = tokenmap - self.reset() - self.make_byte_decoder() - - -class ParallaxSPMStreamingDetokenizer(SPMStreamingDetokenizer): - """A custom SPM streaming detokenizer that skips initializing tokenmap""" - - def __init__(self, tokenizer, tokenmap, trim_space=True): - self.trim_space = trim_space - self._sep = "\u2581".encode() - self.tokenmap = tokenmap - self.reset() - - -def _get_spm_tokenmap(tokenizer): - """Initialize spm tokenmap for reuse""" - # Extract the tokens in a list from id to text - tokenmap = [""] * (max(tokenizer.vocab.values()) + 1) - for value, tokenid in tokenizer.vocab.items(): - if value.startswith("<0x"): - # Replace bytes with their value - tokenmap[tokenid] = bytes([int(value[3:5], 16)]) - else: - tokenmap[tokenid] = value.encode() - return tokenmap - - -def _get_bpe_tokenmap(tokenizer): - """Initialize bpe tokenmap for reuse""" - # Extract the tokens in a list from id to text - tokenmap = [None] * len(tokenizer.vocab) - for value, tokenid in tokenizer.vocab.items(): - tokenmap[tokenid] = value - return tokenmap - - -def load_detokenizer(model_path, tokenizer): - """Load a huggingface tokenizer and try to infer the type of streaming - detokenizer to use. - - Note, to use a fast streaming tokenizer, pass a local file path rather than - a Hugging Face repo ID. - """ - detokenizer_class = NaiveStreamingDetokenizer - tokenmap = None - - tokenizer_file = model_path / "tokenizer.json" - if tokenizer_file.exists(): - with open(tokenizer_file, "r", encoding="utf-8") as fid: - try: - tokenizer_content = json.load(fid) - except JSONDecodeError as e: - raise JSONDecodeError("Failed to parse tokenizer.json", e.doc, e.pos) - - if "decoder" in tokenizer_content: - if _is_spm_decoder(tokenizer_content["decoder"]): - detokenizer_class = ParallaxSPMStreamingDetokenizer - tokenmap = _get_spm_tokenmap(tokenizer) - elif _is_spm_decoder_no_space(tokenizer_content["decoder"]): - detokenizer_class = partial(ParallaxSPMStreamingDetokenizer, trim_space=False) - tokenmap = _get_spm_tokenmap(tokenizer) - elif _is_bpe_decoder(tokenizer_content["decoder"]): - detokenizer_class = ParallaxBPEStreamingDetokenizer - tokenmap = _get_bpe_tokenmap(tokenizer) - - return detokenizer_class, tokenmap - - -def load_tokenizer(model_path, trust_remote_code=True, tokenizer_config_extra=None, **kwargs): - """ - Wrapper function for MLX load_tokenizer that defaults trust_remote_code to True. - This is needed for models like Kimi-K2 that contain custom code. 
- - Args: - model_path: Path to the model - trust_remote_code: Whether to trust remote code (defaults to True) - tokenizer_config_extra: Extra config to pass to AutoTokenizer.from_pretrained - **kwargs: Additional arguments to pass to the original load_tokenizer - - Returns: - The loaded tokenizer - """ - if tokenizer_config_extra is None: - tokenizer_config_extra = {} - - # Add trust_remote_code to the tokenizer config - if trust_remote_code: - tokenizer_config_extra = tokenizer_config_extra.copy() - tokenizer_config_extra["trust_remote_code"] = True - - return _mlx_load_tokenizer(model_path, tokenizer_config_extra=tokenizer_config_extra, **kwargs) +""" +Implements parallax detokenizers for performance. +""" + +import json +from functools import partial +from json import JSONDecodeError + +from mlx_lm.tokenizer_utils import ( + BPEStreamingDetokenizer, + NaiveStreamingDetokenizer, + SPMStreamingDetokenizer, + _is_bpe_decoder, + _is_spm_decoder, + _is_spm_decoder_no_space, +) +from mlx_lm.tokenizer_utils import load_tokenizer as _mlx_load_tokenizer + + +class ParallaxNaiveStreamingDetokenizer(NaiveStreamingDetokenizer): + """A custom BPE streaming detokenizer that add an argument 'tokenizer'""" + + def __init__(self, tokenizer, tokenmap): + self._tokenizer = tokenizer + self._tokenizer.decode([0]) + self.reset() + + +class ParallaxBPEStreamingDetokenizer(BPEStreamingDetokenizer): + """A custom BPE streaming detokenizer that skips initializing tokenmap""" + + def __init__(self, tokenizer, tokenmap): + self.clean_spaces = tokenizer.clean_up_tokenization_spaces + self.tokenmap = tokenmap + self.reset() + self.make_byte_decoder() + + +class ParallaxSPMStreamingDetokenizer(SPMStreamingDetokenizer): + """A custom SPM streaming detokenizer that skips initializing tokenmap""" + + def __init__(self, tokenizer, tokenmap, trim_space=True): + self.trim_space = trim_space + self._sep = "\u2581".encode() + self.tokenmap = tokenmap + self.reset() + + +def _get_spm_tokenmap(tokenizer): + """Initialize spm tokenmap for reuse""" + # Extract the tokens in a list from id to text + tokenmap = [""] * (max(tokenizer.vocab.values()) + 1) + for value, tokenid in tokenizer.vocab.items(): + if value.startswith("<0x"): + # Replace bytes with their value + tokenmap[tokenid] = bytes([int(value[3:5], 16)]) + else: + tokenmap[tokenid] = value.encode() + return tokenmap + + +def _get_bpe_tokenmap(tokenizer): + """Initialize bpe tokenmap for reuse""" + # Extract the tokens in a list from id to text + tokenmap = [None] * len(tokenizer.vocab) + for value, tokenid in tokenizer.vocab.items(): + tokenmap[tokenid] = value + return tokenmap + + +def load_detokenizer(model_path, tokenizer): + """Load a huggingface tokenizer and try to infer the type of streaming + detokenizer to use. + + Note, to use a fast streaming tokenizer, pass a local file path rather than + a Hugging Face repo ID. 
+ """ + detokenizer_class = NaiveStreamingDetokenizer + tokenmap = None + + tokenizer_file = model_path / "tokenizer.json" + if tokenizer_file.exists(): + with open(tokenizer_file, "r", encoding="utf-8") as fid: + try: + tokenizer_content = json.load(fid) + except JSONDecodeError as e: + raise JSONDecodeError("Failed to parse tokenizer.json", e.doc, e.pos) + + if "decoder" in tokenizer_content: + if _is_spm_decoder(tokenizer_content["decoder"]): + detokenizer_class = ParallaxSPMStreamingDetokenizer + tokenmap = _get_spm_tokenmap(tokenizer) + elif _is_spm_decoder_no_space(tokenizer_content["decoder"]): + detokenizer_class = partial(ParallaxSPMStreamingDetokenizer, trim_space=False) + tokenmap = _get_spm_tokenmap(tokenizer) + elif _is_bpe_decoder(tokenizer_content["decoder"]): + detokenizer_class = ParallaxBPEStreamingDetokenizer + tokenmap = _get_bpe_tokenmap(tokenizer) + + return detokenizer_class, tokenmap + + +def load_tokenizer(model_path, trust_remote_code=True, tokenizer_config_extra=None, **kwargs): + """ + Wrapper function for MLX load_tokenizer that defaults trust_remote_code to True. + This is needed for models like Kimi-K2 that contain custom code. + + Args: + model_path: Path to the model + trust_remote_code: Whether to trust remote code (defaults to True) + tokenizer_config_extra: Extra config to pass to AutoTokenizer.from_pretrained + **kwargs: Additional arguments to pass to the original load_tokenizer + + Returns: + The loaded tokenizer + """ + if tokenizer_config_extra is None: + tokenizer_config_extra = {} + + # Add trust_remote_code to the tokenizer config + if trust_remote_code: + tokenizer_config_extra = tokenizer_config_extra.copy() + tokenizer_config_extra["trust_remote_code"] = True + + return _mlx_load_tokenizer(model_path, tokenizer_config_extra=tokenizer_config_extra, **kwargs) diff --git a/src/parallax_utils/ascii_anime.py b/src/parallax_utils/ascii_anime.py index 680ea222..5a0c53f0 100755 --- a/src/parallax_utils/ascii_anime.py +++ b/src/parallax_utils/ascii_anime.py @@ -1,227 +1,227 @@ -import json -import math -import os - -from parallax_utils.file_util import get_project_root - - -class HexColorPrinter: - COLOR_MAP = { - "#000000": ("\033[30m", (0, 0, 0)), - "#800000": ("\033[31m", (128, 0, 0)), - "#008000": ("\033[32m", (0, 128, 0)), - "#808000": ("\033[33m", (128, 128, 0)), - "#000080": ("\033[34m", (0, 0, 128)), - "#800080": ("\033[35m", (128, 0, 128)), - "#008080": ("\033[36m", (0, 128, 128)), - "#c0c0c0": ("\033[37m", (192, 192, 192)), - "#808080": ("\033[90m", (128, 128, 128)), - "#ff0000": ("\033[91m", (255, 0, 0)), - "#00ff00": ("\033[92m", (0, 255, 0)), - "#ffff00": ("\033[93m", (255, 255, 0)), - "#0000ff": ("\033[94m", (0, 0, 255)), - "#ff00ff": ("\033[95m", (255, 0, 255)), - "#00ffff": ("\033[96m", (0, 255, 255)), - "#ffffff": ("\033[97m", (255, 255, 255)), - } - - RESET = "\033[0m" - SHOW = "\033[97m" - WHITE = "\033[97m" - - @classmethod - def hex_to_rgb(cls, hex_color): - hex_color = hex_color.lstrip("#") - return tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4)) - - @classmethod - def color_distance(cls, rgb1, rgb2): - return math.sqrt(sum((c1 - c2) ** 2 for c1, c2 in zip(rgb1, rgb2))) - - @classmethod - def find_closest_color(cls, target_hex): - target_rgb = cls.hex_to_rgb(target_hex) - min_distance = float("inf") - closest_color = "\033[97m" - - for _, (ansi_code, rgb) in cls.COLOR_MAP.items(): - distance = cls.color_distance(target_rgb, rgb) - if distance < min_distance: - min_distance = distance - closest_color = ansi_code - if 
closest_color == "\033[37m": - closest_color = "\033[35m" - if closest_color == "\033[90m": - closest_color = "\033[95m" - - return closest_color - - -def clear_screen(): - # Clear screen command for different operating systems - os.system("cls" if os.name == "nt" else "clear") - - -def handle_colors_data(raw_data): - color_dict = {} - if raw_data is not None: - config = json.loads(raw_data) - for key, value in config.items(): - if isinstance(value, str) and value.startswith("#"): - color_dict[key] = value - return color_dict - - -def process_context_color_run(content, colors): - res = [] - for row, row_str in enumerate(content): - processed_row = "" - for column, text in enumerate(row_str): - position_str = str(column) + "," + str(row) - hex_color = colors.get(position_str, None) - if text in ("▝", "#", ".") and hex_color == "#000000": - text = " " - elif row == 11 and text not in ("▝", "#", " "): - color = HexColorPrinter.WHITE - processed_row += color - else: - if hex_color: - color = HexColorPrinter.find_closest_color(hex_color) - processed_row += color - processed_row += text - processed_row += HexColorPrinter.RESET - res.append(processed_row) - return res - - -def process_context_color_join(content, colors, model_name): - res = [] - if len(model_name) > 30: - model_name = model_name[:30] - name_len = len(model_name) - for row, row_str in enumerate(content): - processed_row = "" - for column, text in enumerate(row_str): - position_str = str(column) + "," + str(row) - hex_color = colors.get(position_str, None) - if text in ("▝", "#", ".") and hex_color == "#000000": - if hex_color == "#000000": - text = " " - elif row == 7 and 9 <= column <= 38: - pos = column - 9 - if pos < name_len: - text = model_name[pos] - processed_row += HexColorPrinter.RESET - else: - text = " " - if hex_color: - color = HexColorPrinter.find_closest_color(hex_color) - processed_row += color - else: - if hex_color: - color = HexColorPrinter.find_closest_color(hex_color) - processed_row += color - processed_row += text - processed_row += HexColorPrinter.RESET - res.append(processed_row) - return res - - -def display_ascii_animation_run(animation_data): - frames = animation_data.get("frames", []) - # loop = animation_data.get('loop', False) - - if not frames: - print("No animation frames found in the JSON data.") - return - - if len(frames) > 0: - last_frame = frames[-1] - content = last_frame.get("content", None) - colors_data = last_frame.get("colors", None) - foreground = colors_data.get("foreground", None) - colors = handle_colors_data(foreground) - - if content: - res = process_context_color_run(content, colors) - res = "\n".join(res) - clear_screen() - print(res) - - # for frame_data in frames: - # content = frame_data.get("content", None) - # delay = frame_data.get("duration", 30) / 1000.0 - # colors_data = frame_data.get("colors", None) - # foreground = colors_data.get("foreground", None) - # colors = handle_colors_data(foreground) - - # if content: - # res = process_context_color_run(content, colors) - # res = "\n".join(res) - # clear_screen() - # print(res) - # time.sleep(delay) - - -def display_ascii_animation_join(animation_data, model_name): - frames = animation_data.get("frames", []) - # loop = animation_data.get('loop', False) - - if not frames: - print("No animation frames found in the JSON data.") - return - - if len(frames) > 0: - last_frame = frames[-1] - content = last_frame.get("content", None) - colors_data = last_frame.get("colors", None) - foreground = colors_data.get("foreground", None) - 
colors = handle_colors_data(foreground) - - if content: - res = process_context_color_join(content, colors, model_name) - res = "\n".join(res) - clear_screen() - print(res) - - # for frame_data in frames: - # content = frame_data.get("content", None) - # delay = frame_data.get("duration", 30) / 1000.0 - # colors_data = frame_data.get("colors", None) - # foreground = colors_data.get("foreground", None) - # colors = handle_colors_data(foreground) - - # if content: - # res = process_context_color_join(content, colors, model_name) - # res = "\n".join(res) - # clear_screen() - # print(res) - # time.sleep(delay) - - -def display_parallax_run(): - file_path = str(get_project_root()) + "/src/parallax_utils/anime/parallax_run.json" - try: - with open(file_path, "r") as f: - animation_data = json.load(f) - except FileNotFoundError: - print(f"Error: The file '{file_path}' was not found.") - return - except json.JSONDecodeError: - print(f"Error: The file '{file_path}' contains invalid JSON.") - return - display_ascii_animation_run(animation_data) - - -def display_parallax_join(model_name): - file_path = str(get_project_root()) + "/src/parallax_utils/anime/parallax_join.json" - try: - with open(file_path, "r") as f: - animation_data = json.load(f) - except FileNotFoundError: - print(f"Error: The file '{file_path}' was not found.") - return - except json.JSONDecodeError: - print(f"Error: The file '{file_path}' contains invalid JSON.") - return - display_ascii_animation_join(animation_data, model_name) +import json +import math +import os + +from parallax_utils.file_util import get_project_root + + +class HexColorPrinter: + COLOR_MAP = { + "#000000": ("\033[30m", (0, 0, 0)), + "#800000": ("\033[31m", (128, 0, 0)), + "#008000": ("\033[32m", (0, 128, 0)), + "#808000": ("\033[33m", (128, 128, 0)), + "#000080": ("\033[34m", (0, 0, 128)), + "#800080": ("\033[35m", (128, 0, 128)), + "#008080": ("\033[36m", (0, 128, 128)), + "#c0c0c0": ("\033[37m", (192, 192, 192)), + "#808080": ("\033[90m", (128, 128, 128)), + "#ff0000": ("\033[91m", (255, 0, 0)), + "#00ff00": ("\033[92m", (0, 255, 0)), + "#ffff00": ("\033[93m", (255, 255, 0)), + "#0000ff": ("\033[94m", (0, 0, 255)), + "#ff00ff": ("\033[95m", (255, 0, 255)), + "#00ffff": ("\033[96m", (0, 255, 255)), + "#ffffff": ("\033[97m", (255, 255, 255)), + } + + RESET = "\033[0m" + SHOW = "\033[97m" + WHITE = "\033[97m" + + @classmethod + def hex_to_rgb(cls, hex_color): + hex_color = hex_color.lstrip("#") + return tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4)) + + @classmethod + def color_distance(cls, rgb1, rgb2): + return math.sqrt(sum((c1 - c2) ** 2 for c1, c2 in zip(rgb1, rgb2))) + + @classmethod + def find_closest_color(cls, target_hex): + target_rgb = cls.hex_to_rgb(target_hex) + min_distance = float("inf") + closest_color = "\033[97m" + + for _, (ansi_code, rgb) in cls.COLOR_MAP.items(): + distance = cls.color_distance(target_rgb, rgb) + if distance < min_distance: + min_distance = distance + closest_color = ansi_code + if closest_color == "\033[37m": + closest_color = "\033[35m" + if closest_color == "\033[90m": + closest_color = "\033[95m" + + return closest_color + + +def clear_screen(): + # Clear screen command for different operating systems + os.system("cls" if os.name == "nt" else "clear") + + +def handle_colors_data(raw_data): + color_dict = {} + if raw_data is not None: + config = json.loads(raw_data) + for key, value in config.items(): + if isinstance(value, str) and value.startswith("#"): + color_dict[key] = value + return color_dict + + +def 
process_context_color_run(content, colors): + res = [] + for row, row_str in enumerate(content): + processed_row = "" + for column, text in enumerate(row_str): + position_str = str(column) + "," + str(row) + hex_color = colors.get(position_str, None) + if text in ("▝", "#", ".") and hex_color == "#000000": + text = " " + elif row == 11 and text not in ("▝", "#", " "): + color = HexColorPrinter.WHITE + processed_row += color + else: + if hex_color: + color = HexColorPrinter.find_closest_color(hex_color) + processed_row += color + processed_row += text + processed_row += HexColorPrinter.RESET + res.append(processed_row) + return res + + +def process_context_color_join(content, colors, model_name): + res = [] + if len(model_name) > 30: + model_name = model_name[:30] + name_len = len(model_name) + for row, row_str in enumerate(content): + processed_row = "" + for column, text in enumerate(row_str): + position_str = str(column) + "," + str(row) + hex_color = colors.get(position_str, None) + if text in ("▝", "#", ".") and hex_color == "#000000": + if hex_color == "#000000": + text = " " + elif row == 7 and 9 <= column <= 38: + pos = column - 9 + if pos < name_len: + text = model_name[pos] + processed_row += HexColorPrinter.RESET + else: + text = " " + if hex_color: + color = HexColorPrinter.find_closest_color(hex_color) + processed_row += color + else: + if hex_color: + color = HexColorPrinter.find_closest_color(hex_color) + processed_row += color + processed_row += text + processed_row += HexColorPrinter.RESET + res.append(processed_row) + return res + + +def display_ascii_animation_run(animation_data): + frames = animation_data.get("frames", []) + # loop = animation_data.get('loop', False) + + if not frames: + print("No animation frames found in the JSON data.") + return + + if len(frames) > 0: + last_frame = frames[-1] + content = last_frame.get("content", None) + colors_data = last_frame.get("colors", None) + foreground = colors_data.get("foreground", None) + colors = handle_colors_data(foreground) + + if content: + res = process_context_color_run(content, colors) + res = "\n".join(res) + clear_screen() + print(res) + + # for frame_data in frames: + # content = frame_data.get("content", None) + # delay = frame_data.get("duration", 30) / 1000.0 + # colors_data = frame_data.get("colors", None) + # foreground = colors_data.get("foreground", None) + # colors = handle_colors_data(foreground) + + # if content: + # res = process_context_color_run(content, colors) + # res = "\n".join(res) + # clear_screen() + # print(res) + # time.sleep(delay) + + +def display_ascii_animation_join(animation_data, model_name): + frames = animation_data.get("frames", []) + # loop = animation_data.get('loop', False) + + if not frames: + print("No animation frames found in the JSON data.") + return + + if len(frames) > 0: + last_frame = frames[-1] + content = last_frame.get("content", None) + colors_data = last_frame.get("colors", None) + foreground = colors_data.get("foreground", None) + colors = handle_colors_data(foreground) + + if content: + res = process_context_color_join(content, colors, model_name) + res = "\n".join(res) + clear_screen() + print(res) + + # for frame_data in frames: + # content = frame_data.get("content", None) + # delay = frame_data.get("duration", 30) / 1000.0 + # colors_data = frame_data.get("colors", None) + # foreground = colors_data.get("foreground", None) + # colors = handle_colors_data(foreground) + + # if content: + # res = process_context_color_join(content, colors, model_name) + # res 
= "\n".join(res) + # clear_screen() + # print(res) + # time.sleep(delay) + + +def display_parallax_run(): + file_path = str(get_project_root()) + "/src/parallax_utils/anime/parallax_run.json" + try: + with open(file_path, "r") as f: + animation_data = json.load(f) + except FileNotFoundError: + print(f"Error: The file '{file_path}' was not found.") + return + except json.JSONDecodeError: + print(f"Error: The file '{file_path}' contains invalid JSON.") + return + display_ascii_animation_run(animation_data) + + +def display_parallax_join(model_name): + file_path = str(get_project_root()) + "/src/parallax_utils/anime/parallax_join.json" + try: + with open(file_path, "r") as f: + animation_data = json.load(f) + except FileNotFoundError: + print(f"Error: The file '{file_path}' was not found.") + return + except json.JSONDecodeError: + print(f"Error: The file '{file_path}' contains invalid JSON.") + return + display_ascii_animation_join(animation_data, model_name) diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index 5b27bc03..ab96a773 100755 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -1,32 +1,32 @@ -""" -Tests for the radix tree. -""" - -import mlx.core as mx - -from parallax.server.radix_cache import RadixCache - -if __name__ == "__main__": - DATA_TYPE = mx.bfloat16 - tree = RadixCache( - num_kv_heads=1, - head_dim=4, - num_layers=10, - dtype=DATA_TYPE, - page_size=1, - max_num_tokens=10000, - ) - arr_for_test = mx.zeros([tree.num_layers, tree.num_kv_heads, 1, tree.head_dim], dtype=DATA_TYPE) - - tree.insert("Hello", None, arr_for_test, arr_for_test) - tree.insert("Hello", None, arr_for_test, arr_for_test) - tree.insert("Hello_L.A.!", None, arr_for_test, arr_for_test) - tree.insert("Hello_world! Happy", None, arr_for_test, arr_for_test) - tree.insert("I love you!", None, arr_for_test, arr_for_test) - tree.pretty_print() - - print(tree.match_prefix("I love you! aha")) - - tree.evict(5) - tree.evict(10) - tree.pretty_print() +""" +Tests for the radix tree. +""" + +import mlx.core as mx + +from parallax.server.radix_cache import RadixCache + +if __name__ == "__main__": + DATA_TYPE = mx.bfloat16 + tree = RadixCache( + num_kv_heads=1, + head_dim=4, + num_layers=10, + dtype=DATA_TYPE, + page_size=1, + max_num_tokens=10000, + ) + arr_for_test = mx.zeros([tree.num_layers, tree.num_kv_heads, 1, tree.head_dim], dtype=DATA_TYPE) + + tree.insert("Hello", None, arr_for_test, arr_for_test) + tree.insert("Hello", None, arr_for_test, arr_for_test) + tree.insert("Hello_L.A.!", None, arr_for_test, arr_for_test) + tree.insert("Hello_world! Happy", None, arr_for_test, arr_for_test) + tree.insert("I love you!", None, arr_for_test, arr_for_test) + tree.pretty_print() + + print(tree.match_prefix("I love you! 
aha")) + + tree.evict(5) + tree.evict(10) + tree.pretty_print() From 7058c4cb773d838e485c4a11df7d21200d50e93c Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 17 Nov 2025 20:16:09 +0800 Subject: [PATCH 36/36] rm useless args --- src/backend/server/server_args.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/backend/server/server_args.py b/src/backend/server/server_args.py index e7c21bf9..61d876b1 100644 --- a/src/backend/server/server_args.py +++ b/src/backend/server/server_args.py @@ -42,14 +42,6 @@ def parse_args() -> argparse.Namespace: help="Use local Hugging Face cache only (no network download)", ) - parser.add_argument( - "--gpu-backend", - type=str, - default="sglang", - choices=["sglang", "vllm"], - help="GPU backend to use", - ) - args = parser.parse_args() return args