diff --git a/.gitignore b/.gitignore index 63408699f4..3fb49db8b1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist .vscode tmp/ requirements-musa.txt +CLAUDE.md diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 05aaaadca8..308cc8ab11 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -7,7 +7,7 @@ import torch import torch.nn.functional as F import triton -from typing import final, List +from typing import final, List, Optional from tqdm import tqdm from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights @@ -33,6 +33,10 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) from .attention import get_prefill_att_backend_class, get_decode_att_backend_class from .attention import BaseAttBackend @@ -91,6 +95,7 @@ def __init__(self, kvargs): self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode + self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) self.is_mtp_mode = self.args.mtp_mode in [ "vanilla_with_att", "eagle_with_att", @@ -104,19 +109,21 @@ def __init__(self, kvargs): self._verify_params() self._init_quant() - self._init_weights() - self._init_req_manager() - self._init_mem_manager() + with self.torch_memory_saver.region(tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup): + self._init_weights() + with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): + self._init_req_manager() + self._init_mem_manager() + self._init_kv_move_buffer() + # 因为类似 qwen3.5 的linear 架构的模型,其 req_manager 会存储运行时使用的大量 linear state # 这可能会占用大量的显存,所以,req_manger 中保存的 mem_manger 是mem manager 初始化后再赋值 self.req_manager.mem_manager = self.mem_manager - - self._init_kv_move_buffer() self._check_mem_size() self._init_infer_layer() self._init_some_value() self._init_custom() - self._load_hf_weights() + self.load_weights(self.weight_dict) # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() @@ -181,17 +188,15 @@ def _init_weights(self, start_layer_index=0): ] return - def _load_hf_weights(self): + def load_weights(self, weight_dict: dict): + assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None" load_hf_weights( self.data_type, - weight_dir=self.weight_dir_, + self.weight_dir_, pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, + weight_dict=weight_dict, ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 @@ -1015,6 +1020,7 @@ def _check_max_len_infer(self): ) logger.error(exception_str) raise Exception(exception_str) + torch.cuda.empty_cache() return def autotune_layers(self): @@ -1149,6 +1155,9 @@ def _init_padded_req(self): del b_seq_len del b_ready_cache_len del model_output + del b_mtp_index + del b_prefill_start_loc + del b_q_seq_len torch.cuda.empty_cache() return @@ -1169,3 +1178,72 @@ def _gen_special_model_input(self, token_num: int): special_model_input["mtp_draft_input_hiddens"] = None return special_model_input + + def 
release_memory_occupation(self, tags: Optional[List[MemoryTag]]): + torch.cuda.synchronize() + if tags is None: + self.release_all() + return + if MemoryTag.WEIGHT in tags: + self.release_weight() + if MemoryTag.KV_CACHE in tags: + self.release_kv_cache() + if MemoryTag.GRAPH in tags: + self.release_graph() + return + + def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]): + if tags is None: + self.resume_all() + return + if MemoryTag.WEIGHT in tags: + self.resume_weight() + if MemoryTag.KV_CACHE in tags: + self.resume_kv_cache() + if MemoryTag.GRAPH in tags: + self.resume_graph() + return + + def release_weight(self): + self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + torch.cuda.empty_cache() + gc.collect() + + def release_kv_cache(self): + self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + torch.cuda.empty_cache() + gc.collect() + + def release_graph(self): + self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + torch.cuda.empty_cache() + gc.collect() + + def release_all(self): + self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + torch.cuda.empty_cache() + gc.collect() + + def resume_weight(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) + + def resume_kv_cache(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + + def resume_graph(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) + + def resume_all(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) + self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 782150661e..5e8036ee81 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -8,6 +8,10 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) from .infer_struct import InferStateInfo @@ -26,6 +30,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = self.max_batch_size = max_batch_size self.graph_max_len_in_batch = max_len_in_batch self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap + self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) # gen cuda graph batch_sizes # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] @@ -94,7 +99,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): if param_name not in pure_para_set: delattr(infer_state, param_name) - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output = decode_func(infer_state) self.graph[batch_size] = (graph_obj, infer_state, model_output) graph_obj.replay() @@ -128,7 +133,7 @@ def _capture_decode_overlap( if para_name not in pure_para_set1: delattr(infer_state1, para_name) - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output, model_output1 = 
decode_func(infer_state, infer_state1) self.graph[batch_size] = ( graph_obj, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 8f54e14a72..6ca48299f0 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -33,6 +33,7 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] = None, + moe_layer_index: int = 0, ) -> None: super().__init__(data_type=data_type) self.w1_weight_name = gate_proj_name @@ -50,6 +51,7 @@ def __init__( self.enable_ep_moe = get_env_start_args().enable_ep_moe self.n_routed_experts = n_routed_experts self.num_fused_shared_experts = num_fused_shared_experts + self.moe_layer_index = moe_layer_index self._init_config(network_config) self._init_redundancy_expert_params() self._init_parallel_params() @@ -130,6 +132,7 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -145,6 +148,8 @@ def experts( topk_group=topk_group, num_expert_group=num_expert_group, is_prefill=is_prefill, + moe_layer_index=self.moe_layer_index, + microbatch_index=microbatch_index, ) def low_latency_dispatch( @@ -295,6 +300,7 @@ def _create_weight(self): device_id=self.device_id_, num_experts=self.local_n_routed_experts, ) + self.w1, self.w3 = w13_param_list self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0]) self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1]) self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2) @@ -307,7 +313,8 @@ def _get_expert_weight_list(self, weight_pack: WeightPack): return weight_list def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]): - + # for merged weights + self._load_merge_weight(weights) # Load each expert with TP slicing for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): with self.lock: @@ -332,6 +339,7 @@ def _load_expert( w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}" w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}" w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}" + row_slice_func = self.row_slicer._slice_weight col_slice_func = self.col_slicer._slice_weight if w1_weight in weights: @@ -341,6 +349,19 @@ def _load_expert( if w2_weight in weights: self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx]) + def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): + w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}" + w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}" + w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}" + row_slice_func = self.row_slicer._slice_weight + col_slice_func = self.col_slicer._slice_weight + if w1_merge_weight in weights: + 
self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1) + if w2_merge_weight in weights: + self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2) + if w3_merge_weight in weights: + self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3) + def _load_expert_scale( self, expert_idx: int, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index 6ed0cef0b4..4ca1605be4 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -8,6 +8,7 @@ from lightllm.common.quantization import Quantcfg from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel import routing_manager as _routing_mgr logger = init_logger(__name__) @@ -46,6 +47,7 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] = None, + moe_layer_index: int = 0, ) -> None: network_config["norm_topk_prob"] = None super().__init__( @@ -62,6 +64,7 @@ def __init__( num_fused_shared_experts=num_fused_shared_experts, layer_num=layer_num, network_config=network_config, + moe_layer_index=moe_layer_index, ) self.hidden_size = network_config["hidden_size"] @@ -144,10 +147,15 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._router(router_logits, top_k) + # Rollout router replay + if _routing_mgr.g_routing_capture_manager is not None: + _routing_mgr.g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index 00587ac185..1c93cb13dc 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -62,5 +62,7 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index d6e923a115..90b525d275 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -3,6 +3,7 @@ from lightllm.common.quantization.no_quant import WeightPack from lightllm.common.quantization.quantize_method import QuantizationMethod from .base_impl import FuseMoeBaseImpl +from lightllm.common.basemodel import routing_manager as _routing_mgr class FuseMoeTriton(FuseMoeBaseImpl): @@ -125,6 +126,8 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, 
@@ -137,6 +140,10 @@ def __call__( num_expert_group=num_expert_group, scoring_func=scoring_func, ) + + if _routing_mgr.g_routing_capture_manager is not None and moe_layer_index is not None: + _routing_mgr.g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index) + output = self._fused_experts( input_tensor=input_tensor, w13=w13, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index ddbf98a866..15f050c14a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -47,17 +47,17 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten # 默认weight 的shape是 outxin,这也是目前最通用的约定。 -# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。 +# 这里约定row-wise沿着倒数第二维切分,col-wise沿着第一维切分。 class RowSliceMixin(SliceMixinTpl): def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: assert ( - weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[0]) - return weight[start:end, :] + weight.shape[-2] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-2] * self.repeat_times_} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-2]) + return weight[..., start:end, :] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: assert ( @@ -75,17 +75,17 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( - weight_scale.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[0]) - return weight_scale[start:end] + weight_scale.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[-2]} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-2]) + return weight_scale[..., start:end, :] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_zero_point.shape[0]) - return weight_zero_point[start:end] + weight_zero_point.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-2]} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[-2]) + return weight_zero_point[..., start:end, :] class ColSliceMixin(SliceMixinTpl): @@ -94,10 +94,10 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: assert ( - weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[1]) - return weight[:, start:end] + weight.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-1] * self.repeat_times_ 
} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-1]) + return weight[..., start:end] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return bias / self.tp_world_size_ * self.repeat_times_ @@ -110,16 +110,16 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[1]) - return weight_scale[:, start:end] + ), f"tp slice error {weight_scale.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-1]) + return weight_scale[..., start:end] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_zero_point.shape[1]) - return weight_zero_point[:, start:end] + weight_zero_point.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[-1]) + return weight_zero_point[..., start:end] # awq 的量化权重是inxout存储格式,需要定制实现。 diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py new file mode 100644 index 0000000000..77b611130f --- /dev/null +++ b/lightllm/common/basemodel/routing_manager.py @@ -0,0 +1,191 @@ +import atexit +import torch +import numpy as np +from multiprocessing import shared_memory +from typing import Optional +from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_current_rank_in_dp +from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray +from lightllm.utils.envs_utils import get_unique_server_name + +logger = init_logger(__name__) + + +def routing_dtype_id_to_np(dtype_id: int): + if dtype_id == 1: + return np.uint8 + elif dtype_id == 2: + return np.int16 + return np.int32 + + +def get_routing_config_shm() -> SharedArray: + service_name = get_unique_server_name() + return SharedArray(f"{service_name}_routing_config", shape=(4,), dtype=np.int32) + + +class RoutingCaptureManager: + def __init__( + self, + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, + ): + self.num_moe_layers = num_moe_layers + self.topk = topk + self.num_experts = num_experts + self.kv_cache_size = kv_cache_size + + self.dtype = torch.uint8 if num_experts <= 255 else torch.int16 + dtype_bytes = 1 if self.dtype == torch.uint8 else 2 + + # Shape: (kv_cache_size, num_moe_layers, topk) — on CPU to save GPU memory. + # Written after forward() via flush_to_routing_buffer(), read on request finish. + routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes + self.routing_buffer = torch.zeros( + (kv_cache_size, num_moe_layers, topk), + dtype=self.dtype, + device="cpu", + ) + + # Capture buffers: simple contiguous tensors written to during forward(). 
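+        # Two capture buffers (index 0 and 1) are kept, one per decode microbatch, so overlapped
+        # microbatches record their top-k expert ids independently before being flushed into the
+        # CPU-side routing_buffer.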
+ capture_buf_size = max_capture_tokens * num_moe_layers * topk * dtype_bytes + self._capture_buffer = [ + torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2) + ] + + dtype_name = "uint8" if self.dtype == torch.uint8 else "int16" + logger.info( + f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, " + f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, " + f"capture_buffer={capture_buf_size / 1024 / 1024:.2f}MB x2, dtype={dtype_name}" + ) + + @property + def np_dtype(self): + return np.uint8 if self.dtype == torch.uint8 else np.int16 + + @property + def dtype_id(self) -> int: + return 1 if self.dtype == torch.uint8 else 2 + + def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None: + num_tokens = topk_ids.shape[0] + self._capture_buffer[microbatch_index][:num_tokens, moe_layer_index, :] = topk_ids.to(self.dtype) + + def flush_to_routing_buffer(self, mem_indexes: torch.Tensor, num_tokens: int, microbatch_index: int = 0) -> None: + buf = self._capture_buffer[microbatch_index][:num_tokens] # (num_tokens, num_moe_layers, topk) + self.routing_buffer[mem_indexes[:num_tokens].cpu(), :, :] = buf.cpu() + + def extract_routing_data(self, mem_indexes: torch.Tensor) -> np.ndarray: + cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes + return self.routing_buffer[cpu_indexes, :, :].numpy() + + +g_routing_capture_manager: Optional[RoutingCaptureManager] = None + + +def create_routing_capture_manager( + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, +) -> None: + global g_routing_capture_manager + assert g_routing_capture_manager is None, "RoutingCaptureManager already exists" + g_routing_capture_manager = RoutingCaptureManager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=kv_cache_size, + max_capture_tokens=max_capture_tokens, + ) + + +def cleanup_routing_shm_pool() -> None: + """Unlink all pre-allocated routing SHM segments. Called at server shutdown.""" + try: + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + except Exception: + return + + service_name = get_unique_server_name() + + for i in range(args.running_max_req_size): + name = f"{service_name}_shm_routing_{i}" + try: + shm = shared_memory.SharedMemory(name=name) + shm.close() + shm.unlink() + except Exception: + pass + + config_name = f"{service_name}_routing_config" + try: + shm = shared_memory.SharedMemory(name=config_name) + shm.close() + shm.unlink() + except Exception: + pass + + +def init_routing_capture(model, num_moe_layers: int) -> None: + dp_rank = get_current_rank_in_dp() + logger.info(f"init_routing_capture called: num_moe_layers={num_moe_layers}, dp_rank={dp_rank}") + if dp_rank != 0: + logger.info(f"Skipping routing capture initialization on dp_rank={dp_rank}") + return + + if num_moe_layers == 0: + logger.warning( + "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled." + ) + return + + num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0)) + topk = model.config.get("num_experts_per_tok", 0) + assert num_experts > 0 and topk > 0 + + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + + # Capture buffer must fit the max tokens in any single forward call. 
+ # For prefill that's batch_max_tokens; for decode it's graph_max_batch_size. + batch_max_tokens = args.batch_max_tokens or args.max_req_total_len or 8192 + max_capture_tokens = max(batch_max_tokens, args.graph_max_batch_size) + + logger.info( + f"Initializing routing capture: num_moe_layers={num_moe_layers}, " + f"topk={topk}, num_experts={num_experts}, max_capture_tokens={max_capture_tokens}" + ) + + create_routing_capture_manager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=model.mem_manager.size + 1, + max_capture_tokens=max_capture_tokens, + ) + + mgr = g_routing_capture_manager + dtype_id = mgr.dtype_id + + max_req_total_len = args.max_req_total_len + + # Write config to cross-process SHM + shm = get_routing_config_shm() + shm.arr[0] = num_moe_layers + shm.arr[1] = topk + shm.arr[2] = dtype_id + shm.arr[3] = max_req_total_len + logger.info( + f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}, " + f"dtype_id={dtype_id}, max_tokens={max_req_total_len}" + ) + atexit.register(cleanup_routing_shm_pool) diff --git a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py index caca4bb621..6385786bec 100644 --- a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py @@ -33,7 +33,13 @@ def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: return super().get_att_input_params(layer_index) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - super()._init_buffers(size, dtype, head_num, head_dim, layer_num) + # 将原来一次性申请的大 tensor (layer_num, size+1, 2*head_num, head_dim) + # 拆分成 layer_num 个独立的小 tensor,避免单次大块连续显存的申请, + # 在 RL release/resume 显存回放等场景下减少 OOM 概率。 + # 所有按 self.kv_buffer[layer_index] 形式的访问保持完全兼容。 + self.kv_buffer = [ + torch.empty((size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num) + ] # TODO 初始化线性 att 对应的部分 buffer. 
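+        # (TODO above, translated: initialize the buffers for the linear-attention part.
+        # The per-layer split above only covers the dense kv_buffer.)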
self._init_linear_att_buffers() return diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..c75c871c72 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..14026090e6 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json 
@@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json new file mode 100644 index 0000000000..939c939523 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 2, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + 
"NUM_STAGE": 1, + "num_warps": 8 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..13ba4ba8e5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ee316f610b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + 
"num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..e027701092 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ddda23d257 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..560ca6c09d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..0713de7996 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + 
"num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..e950ff0954 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 
128, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..7f479b8382 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 256, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "8448": { + "BLOCK_SIZE": 256, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..b3051c6584 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "256": { + 
"BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "8448": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..fdb3212216 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..a94e669353 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 
1, + "num_warps": 4 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..441421fd5d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..864d1d3f18 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 4096, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..bcf56e01f7 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 128, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ba1dc8a75d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "3072": { + "BLOCK_SIZE": 2048, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..6f109e1c6e --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "5120": { + "BLOCK_SIZE": 32768, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..198a196dfb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..537c7a90eb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 512, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..9a6dcb6fbf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "4096": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..df501847ec --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "5120": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index c62a2572ff..60c12e4c07 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -106,14 +106,10 @@ def __init__( self.configs_gen_func = configs_gen_func self.kernel_name = kernel_name - self.cache_dir = os.path.join( - Path(__file__).parent, - "autotune_kernel_configs", - get_triton_version(), - get_current_device_name(), - self.kernel_name, - ) - os.makedirs(self.cache_dir, exist_ok=True) + # cache_dir 依赖 get_current_device_name(),后者要求 torch.cuda.is_available()。 + # 这里 lazy 化,避免 CPU-only 的进程(例如 Ray driver / verl rollout replica + # 入口)在 import 时就触发 TypeError。 + self._cache_dir: Optional[str] = None self.fn = fn self.static_key_func = static_key_func self.run_key_func = run_key_func @@ -209,6 +205,25 @@ def __call__(self, *args, **kwargs): return self.fn(*args, **kwargs) + @property + def cache_dir(self) -> str: + if self._cache_dir is None: + device_name = get_current_device_name() + if device_name is None: + raise RuntimeError( + f"Autotuner for kernel {self.kernel_name} requires a visible CUDA/MUSA device " + f"to resolve its cache directory, but torch.cuda.is_available() is False." 
+ ) + self._cache_dir = os.path.join( + Path(__file__).parent, + "autotune_kernel_configs", + get_triton_version(), + device_name, + self.kernel_name, + ) + os.makedirs(self._cache_dir, exist_ok=True) + return self._cache_dir + def _try_load_cache(self, static_key): if static_key in self.cached_configs: return False diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index fa2dee444f..88c4b1e8ee 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -231,6 +231,7 @@ def _moe_ffn_tp( use_grouped_topk=self.n_group, topk_group=self.topk_group, num_expert_group=self.n_group, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: @@ -258,6 +259,7 @@ def _moe_ffn_edp( topk_group=self.topk_group, num_expert_group=self.n_group, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None: diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 3eb09f9176..bd72035072 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -242,6 +242,9 @@ def _init_moe(self): # == 0 时,说明不存在融合共享专家,共享专家单独加载和进行推理。 if self.num_fused_shared_experts == 0: self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True) + first_moe = self.network_config_["first_k_dense_replace"] + freq = self.network_config_.get("moe_layer_freq", 1) + moe_layer_index = (self.layer_num_ - first_moe) // freq self.experts = FusedMoeWeight( gate_proj_name="gate_proj", down_proj_name="down_proj", @@ -256,6 +259,7 @@ def _init_moe(self): num_fused_shared_experts=self.num_fused_shared_experts, layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=moe_layer_index, ) def _init_ffn(self): diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e596eed97c..79bd327068 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -6,6 +6,7 @@ from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager @@ -49,6 +50,9 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + if self.args.enable_return_routed_experts: + num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe) + init_routing_capture(self, num_moe_layers) def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index 4c457fd993..bb9e6140bf 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ 
b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -52,6 +52,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) - use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) hidden_states = hidden_states.view(num_tokens, hidden_dim) return self._tpsp_reduce(input=hidden_states, infer_state=infer_state) diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index 7c8c30940e..7278c62fec 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -55,6 +55,7 @@ def _init_moe(self): num_fused_shared_experts=0, layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=self.layer_num_, ) def _init_weight_names(self): diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index 9e9561eb24..cff748933d 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -2,6 +2,7 @@ from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.models.registry import ModelRegistry +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class @@ -21,6 +22,12 @@ class GptOssTpPartModel(LlamaTpPartModel): def __init__(self, kvargs): super().__init__(kvargs) + def _init_custom(self): + super()._init_custom() + if self.args.enable_return_routed_experts: + num_moe_layers = len(self.trans_layers_weight) + init_routing_capture(self, num_moe_layers) + def _init_att_backend(self): self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])( model=self diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py deleted file mode 100644 index b0e27ac1de..0000000000 --- a/lightllm/models/mixtral/layer_infer/_custom_ops.py +++ /dev/null @@ -1,46 +0,0 @@ -import functools -import json -import os -from typing import Any, Dict, Optional, Tuple - -import torch -import triton -import triton.language as tl -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - -# Pytorch version -# Triton version in progress -def topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output, - topk=2, -): - scores = torch.softmax(gating_output, dim=-1) - topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False) - return topk_weights, topk_ids - - -def fused_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - alloc_tensor_func=torch.empty, -): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - M, _ = hidden_states.shape - - topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device) - topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, 
gating_output.float(), topk) - del token_expert_indicies # Not used. Will be used in the future. - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index 0cf651598a..d90c631547 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -1,9 +1,6 @@ -import os import torch -import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight @@ -21,25 +18,14 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.num_experts_per_tok, + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, renormalize=self.renormalize, - alloc_tensor_func=self.alloc_tensor, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl - - ffn2_out = fused_experts_impl( - hidden_states=hidden_states, - w1=layer_weight.experts.w1[0], - w2=layer_weight.experts.w2[0], - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=False, - w1_scale=None, - w2_scale=None, - alloc_tensor_func=self.alloc_tensor, - ) - return self._tpsp_reduce(input=ffn2_out, infer_state=infer_state) + return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index 51c62fd4cb..d93cb5fb58 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -57,4 +57,5 @@ def _init_moe(self): quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=self.layer_num_, ) diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py index 3c2d7b4e87..35bf38de58 100644 --- a/lightllm/models/mixtral/model.py +++ b/lightllm/models/mixtral/model.py @@ -2,6 +2,7 @@ import numpy as np from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer @@ -45,6 +46,9 @@ def _verify_params(self): def _init_custom(self): self._init_to_get_rotary() + if self.args.enable_return_routed_experts: + num_moe_layers = len(self.trans_layers_weight) + init_routing_capture(self, num_moe_layers) return def 
_init_mem_manager(self): diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 237c4ad897..c94135573b 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -12,6 +12,7 @@ from .vision_process import smart_resize from lightllm.models.qwen2.model import Qwen2TpPartModel import os +from typing import Union, List # Warp of the origal tokenizer class QWen2VLTokenizer(BaseMultiModalTokenizer): @@ -52,9 +53,13 @@ def get_image_token_length(self, img: ImageItem): def get_audio_token_length(self, audio: AudioItem): raise NotImplementedError - def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): - - origin_ids = self.tokenizer.encode(prompt) + def encode(self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams = None, **kwargs): + if isinstance(prompt, str): + origin_ids = self.tokenizer.encode(prompt) + elif isinstance(prompt, list): + origin_ids = prompt + else: + raise ValueError(f"Unsupported prompt type: {type(prompt)}") # -> origin_ids = [token for token in origin_ids if token != self.image_token_id] diff --git a/lightllm/models/qwen3_5_moe/layer_infer/__init__.py b/lightllm/models/qwen3_5_moe/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py index fe4b1883bd..44425e7e10 100644 --- a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py @@ -12,7 +12,6 @@ def load_hf_weights(self, weights): def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_size: int): layer_prefix = f"model.layers.{layer_num}." 
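# Aside: a self-contained sketch of the gate_up_proj split that this hunk performs.
# Per the comment in the diff, the fused expert weight is laid out as
# [num_experts, 2 * moe_intermediate_size, hidden_size], with the gate half stacked
# above the up half along dim 1. Shapes below are illustrative only; this is not the
# real loading code.
import torch

num_experts, inter_size, hidden = 4, 8, 16
fused = torch.randn(num_experts, 2 * inter_size, hidden)  # stand-in for gate_up_proj
gate = fused[:, :inter_size, :]   # first half -> gate_proj
up = fused[:, inter_size:, :]     # second half -> up_proj
assert gate.shape == up.shape == (num_experts, inter_size, hidden)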
keys = list(weights.keys()) - num_experts = 0 for k in keys: if not k.startswith(layer_prefix): @@ -20,21 +19,8 @@ def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_s if "mlp.experts.gate_up_proj" in k: fused_weight = weights.pop(k) # [num_experts, 2*inter_size, hidden_size] - num_experts = fused_weight.shape[0] - prefix = k.rsplit(".gate_up_proj", 1)[0] gate_weight = fused_weight[:, :moe_intermediate_size, :] up_weight = fused_weight[:, moe_intermediate_size:, :] - - for expert_idx in range(num_experts): - weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx] - weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx] - - elif "mlp.experts.down_proj" in k: - down_weight = weights.pop(k) # [num_experts, hidden_size, inter_size] - num_experts = down_weight.shape[0] - - prefix = k.rsplit(".down_proj", 1)[0] - - for expert_idx in range(num_experts): - weights[f"{prefix}.{expert_idx}.down_proj.weight"] = down_weight[expert_idx] + weights[f"{prefix}.gate_proj"] = gate_weight + weights[f"{prefix}.up_proj"] = up_weight diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 54e4373652..744ddc9d4f 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -85,6 +85,7 @@ def _moe_ffn_tp( use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) return hidden_states.view(num_tokens, hidden_dim) @@ -104,6 +105,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) ep_output = ep_output.view(token_num, hidden_dim) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index e525cb2d20..5358229949 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -52,6 +52,11 @@ def _init_moe(self): tp_rank=0, tp_world_size=1, ) + mlp_only = set(self.network_config_.get("mlp_only_layers", [])) + step = self.network_config_.get("decoder_sparse_step", 1) + moe_layer_index = sum( + 1 for i in range(self.layer_num_) if self.n_routed_experts > 0 and i not in mlp_only and (i + 1) % step == 0 + ) self.experts = FusedMoeWeight( gate_proj_name="gate_proj", down_proj_name="down_proj", @@ -65,6 +70,7 @@ def _init_moe(self): quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=moe_layer_index, ) def _init_qkv(self): diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index b71d7f4878..506f69e3ff 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -1,9 +1,14 @@ import torch from typing import final from lightllm.models.registry import ModelRegistry -from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer -from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import ( + Qwen3MOETransformerLayerInfer, +) +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import ( + 
Qwen3MOETransformerLayerWeight, +) from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -28,3 +33,6 @@ def _init_custom(self): # Only initialize DeepEP group for MoE models with num_experts if "num_experts" in self.config and self.config["num_experts"] > 0: dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + if self.args.enable_return_routed_experts: + num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe) + init_routing_capture(self, num_moe_layers) diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index 9f83832a7e..b4be10d0d0 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -26,6 +26,7 @@ def _pre_init(self, kvargs: dict): return def _init_custom(self): + super()._init_custom() self._cos_cached = self.main_model._cos_cached self._sin_cached = self.main_model._sin_cached return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index f33f58b86d..eedbbf7526 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -1,8 +1,7 @@ import argparse -def make_argument_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() +def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument( "--run_mode", @@ -61,6 +60,15 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="p d mode, decode node used for kv move manager rpyc server port", ) + parser.add_argument( + "--control_rpyc_port", + type=int, + default=None, + help=( + "rpyc port on master router for control-plane ops " + "(flush_cache, update_weights, etc.); auto-allocated if unset" + ), + ) parser.add_argument( "--select_p_d_node_strategy", type=str, @@ -263,6 +271,12 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--nccl_port", type=int, default=None, help="the nccl_port to build a distributed environment for PyTorch" ) + parser.add_argument( + "--lightllm_instance_id", + type=int, + default=0, + help="Instance ID (0~7) for multi-instance port isolation. Each ID maps to a dedicated port range.", + ) parser.add_argument( "--use_config_server_to_init_nccl", action="store_true", @@ -747,6 +761,12 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. 
GB used.""" ) + parser.add_argument( + "--enable_torch_memory_saver", + action="store_true", + help="""enable torch memory saver, which is used for release_memory and resume_memory during RL training.""", + ) + parser.add_argument("--enable_weight_cpu_backup", action="store_true", help="""enable weight cpu backup.""") parser.add_argument( "--disk_cache_dir", type=str, @@ -820,4 +840,10 @@ def make_argument_parser() -> argparse.ArgumentParser: If the op is not implemented for the platform and the hardware support triton, it will use triton implementation.""", ) + parser.add_argument( + "--enable_return_routed_experts", + action="store_true", + default=False, + help="Enable returning routed expert indices for MoE models (R3 feature).", + ) return parser diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index c106ca1cd9..0681bdb47d 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -34,7 +34,7 @@ import uuid from PIL import Image import multiprocessing as mp -from typing import AsyncGenerator, Union +from typing import Any, AsyncGenerator, Union from typing import Callable from lightllm.server import TokenLoad from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect @@ -50,6 +50,7 @@ from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError from lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import ReleaseMemoryReq, ResumeMemoryReq from dataclasses import dataclass from .api_openai import chat_completions_impl, completions_impl @@ -61,6 +62,16 @@ ModelCard, ModelListResponse, ) +from .io_struct import ( + AbortReq, + FlushCacheReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq, + UpdateWeightsFromIPCReq, + GeneralModelToHttpRpcRsp, +) from .build_prompt import build_prompt, init_tokenizer logger = init_logger(__name__) @@ -188,6 +199,22 @@ def get_model_name(): return {"model_name": g_objs.args.model_name} +@app.get("/get_server_info") +@app.post("/get_server_info") +def get_server_info(): + # 将 StartArgs 转换为字典格式 + from dataclasses import asdict + + server_info: dict[str, Any] = asdict(g_objs.args) + return {**server_info} + + +@app.get("/get_weight_version") +@app.post("/get_weight_version") +def get_weight_version(): + return {"weight_version": g_objs.args.weight_version} + + @app.get("/healthz", summary="Check server health") @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") @@ -412,6 +439,88 @@ async def metrics() -> Response: return response +@app.post("/abort_request") +async def abort_request(request: AbortReq, raw_request: Request): + """Abort a request.""" + try: + await g_objs.httpserver_manager.abort_request(request) + return Response(status_code=200) + except Exception as e: + return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") + + +async def handle_request_common(request_obj, handler): + try: + ret: GeneralModelToHttpRpcRsp = await handler(request_obj) + if ret.success: + return JSONResponse({"success": ret.success, "message": ret.msg}, status_code=200) + else: + return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg) + except Exception as e: + logger.error("handle_request_common (%s) error occurred: %s", str(request_obj), str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") + 
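# Aside: a hedged client-side sketch for the control-plane routes registered below
# (pause/continue_generation, flush_cache, release/resume_memory_occupation).
# The base URL is illustrative, and the "tags" field sent to /release_memory_occupation
# and /resume_memory_occupation is an assumption about ReleaseMemoryReq/ResumeMemoryReq;
# check lightllm/server/io_struct.py for the actual request schemas.
import requests

BASE = "http://127.0.0.1:8000"  # assumed server address

requests.post(f"{BASE}/pause_generation")
requests.post(f"{BASE}/flush_cache")
# tags=None is assumed here to mean "release everything"; a subset such as
# ["kv_cache"] may also be accepted, depending on the real schema.
requests.post(f"{BASE}/release_memory_occupation", json={"tags": None})
# ... later, e.g. after an RL weight sync ...
requests.post(f"{BASE}/resume_memory_occupation", json={"tags": None})
requests.post(f"{BASE}/continue_generation")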
+ +@app.post("/init_weights_update_group") +async def init_weights_update_group(request: InitWeightsUpdateGroupReq, raw_request: Request): + """Init weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.init_weights_update_group) + + +@app.post("/destroy_weights_update_group") +async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupReq, raw_request: Request): + """Destroy weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.destroy_weights_update_group) + + +@app.post("/update_weights_from_distributed") +async def update_weights_from_distributed(request: UpdateWeightsFromDistributedReq, raw_request: Request): + """Update model parameter from distributed online.""" + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_distributed) + + +@app.post("/update_weights_from_tensor") +async def update_weights_from_tensor(request: UpdateWeightsFromTensorReq, raw_request: Request): + """Update model parameter from distributed online.""" + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_tensor) + + +@app.post("/update_weights_from_ipc") +async def update_weights_from_ipc(request: UpdateWeightsFromIPCReq, raw_request: Request): + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_ipc) + + +@app.post("/flush_cache") +@app.get("/flush_cache") +async def flush_cache(): + """Flush the radix cache.""" + return await handle_request_common(FlushCacheReq(), g_objs.httpserver_manager.flush_cache) + + +@app.post("/pause_generation") +async def pause_generation(): + await g_objs.httpserver_manager.pause_generation() + return Response(content="Generation paused successfully.", status_code=200) + + +@app.post("/continue_generation") +async def continue_generation(): + await g_objs.httpserver_manager.continue_generation() + return Response(content="Generation continued successfully.", status_code=200) + + +@app.get("/release_memory_occupation") +@app.post("/release_memory_occupation") +async def release_memory_occupation(request: ReleaseMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.release_memory_occupation) + + +@app.get("/resume_memory_occupation") +@app.post("/resume_memory_occupation") +async def resume_memory_occupation(request: ResumeMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.resume_memory_occupation) + + @app.websocket("/pd_register") async def register_and_keep_alive(websocket: WebSocket): await websocket.accept() diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index 39a5808aab..28d57ccdc4 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -35,6 +35,9 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] return_details = sample_params_dict.pop("return_details", False) + return_routed_experts = sample_params_dict.pop( + "return_routed_experts", httpserver_manager.args.enable_return_routed_experts + ) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() @@ -53,6 +56,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt_token_ids = None is_first_metadata = True input_usage = None + routed_experts_data = None 
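# Aside: a hedged sketch of consuming the "routed_experts" field that this handler
# attaches to the non-streaming generate response when routed-expert return is enabled
# (via --enable_return_routed_experts or the per-request return_routed_experts flag).
# It assumes the field carries the {"shape", "dtype", "data" (base64)} dict built by
# Req.get_routing_metadata later in this diff; the endpoint path and server address
# are illustrative.
import base64
import numpy as np
import requests

resp = requests.post(
    "http://127.0.0.1:8000/generate",  # assumed route
    json={"inputs": "hello", "parameters": {"max_new_tokens": 8, "return_routed_experts": True}},
).json()

routing = resp.get("routed_experts")
if routing is not None:
    arr = np.frombuffer(base64.b64decode(routing["data"]), dtype=routing["dtype"])
    arr = arr.reshape(routing["shape"])  # expected layout: [num_tokens, num_moe_layers, topk]
    print(arr.shape)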
async for sub_req_id, request_output, metadata, finish_status in results_generator: # when set "--return_all_prompt_logprobs", the first token metadata will contains # prompt_logprobs and prompt_token_ids @@ -78,6 +82,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status + if "routed_experts" in metadata: + routed_experts_data = metadata["routed_experts"] n = sampling_params.n sub_ids = list(final_output_dict.keys())[:n] final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids] @@ -102,6 +108,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana ret["prompt_logprobs"] = prompt_logprobs if input_usage is not None: ret["input_usage"] = input_usage + if return_routed_experts and routed_experts_data is not None: + ret["routed_experts"] = routed_experts_data return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) @@ -112,6 +120,7 @@ async def lightllm_generate_stream(request: Request, httpserver_manager: HttpSer prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] _ = sample_params_dict.pop("return_details", False) + _ = sample_params_dict.pop("return_routed_experts", None) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 6e04d5d47e..5306ecb698 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -1,11 +1,22 @@ import torch -from .api_cli import make_argument_parser +from .api_cli import add_cli_args +from lightllm.server.core.objs.start_args_type import StartArgs +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess - parser = make_argument_parser() - args = parser.parse_args() - from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start + +def launch_server(args: StartArgs): + from .api_start import pd_master_start, normal_or_p_d_start, config_server_start, visual_only_start + + try: + # this code will not be ok for settings to fork to subprocess + torch.multiprocessing.set_start_method("spawn") + except RuntimeError as e: + logger.warning(f"Failed to set start method: {e}") + except Exception as e: + logger.error(f"Failed to set start method: {e}") + raise e if args.run_mode == "pd_master": pd_master_start(args) @@ -15,3 +26,13 @@ visual_only_start(args) else: normal_or_p_d_start(args) + + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser() + add_cli_args(parser) + args = parser.parse_args() + + launch_server(StartArgs(**vars(args))) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 8c6af128c8..ad3e269369 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -1,3 +1,4 @@ +import multiprocessing as mp import os import sys import time @@ -17,6 +18,7 @@ from lightllm.utils.multinode_utils import send_and_receive_node_ip from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size +from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.utils.config_utils import ( has_audio_module, 
has_vision_module, @@ -59,9 +61,31 @@ def signal_handler(sig, frame): process_manager.terminate_all_processes() logger.info("All processes have been terminated gracefully.") sys.exit(0) + elif sig == signal.SIGHUP: + logger.info("Received SIGHUP (terminal closed), shutting down gracefully...") + if http_server_process and http_server_process.poll() is None: + http_server_process.send_signal(signal.SIGTERM) + + start_time = time.time() + while (time.time() - start_time) < 60: + if not is_process_active(http_server_process.pid): + logger.info("httpserver exit") + break + time.sleep(1) + + if time.time() - start_time < 60: + logger.info("HTTP server has exited gracefully") + else: + logger.warning("HTTP server did not exit in time, killing it...") + kill_recursive(http_server_process) + + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully due to terminal closure.") + sys.exit(0) signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGHUP, signal_handler) logger.info(f"start process pid {os.getpid()}") if http_server_process: @@ -69,13 +93,15 @@ def signal_handler(sig, frame): return -def normal_or_p_d_start(args): - from lightllm.server.core.objs.start_args_type import StartArgs +def _set_envs_and_config(args: StartArgs): + mp.set_start_method("spawn", force=True) - args: StartArgs = args + +def _launch_subprocesses(args: StartArgs): + + _set_envs_and_config(args) auto_set_max_req_total_len(args) - set_unique_server_name(args) if args.enable_mps: from lightllm.utils.device_utils import enable_mps @@ -143,12 +169,6 @@ def normal_or_p_d_start(args): check_recommended_shm_size(args) assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] - # 确保单机上多实列不冲突 - if args.zmq_mode == "ipc:///tmp/": - zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_" - args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功 - args.zmq_mode = zmq_mode - logger.info(f"zmq mode head: {args.zmq_mode}") logger.info(f"use tgi api: {args.use_tgi_api}") @@ -207,12 +227,16 @@ def normal_or_p_d_start(args): # mtp params check if args.mtp_mode is not None: - assert args.mtp_draft_model_dir is not None + if args.mtp_draft_model_dir is None: + args.mtp_draft_model_dir = [args.model_dir] * args.mtp_step assert args.mtp_step > 0 else: assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + # automatically set visual_dp based on visual_tp and tp + if args.visual_tp < args.tp and args.tp % args.visual_tp == 0: + args.visual_dp = args.tp // args.visual_tp if args.afs_image_embed_dir is not None: os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True) os.chmod(args.afs_image_embed_dir, 0o777) @@ -334,6 +358,8 @@ def normal_or_p_d_start(args): already_uesd_ports.append(args.nccl_port) if args.pd_decode_rpyc_port is not None: already_uesd_ports.append(args.pd_decode_rpyc_port) + if args.control_rpyc_port is not None: + already_uesd_ports.append(args.control_rpyc_port) if args.visual_nccl_ports is not None: already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) if not args.disable_audio and args.audio_nccl_ports is not None: @@ -346,7 +372,8 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp, + num=11 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp, + instance_id=args.lightllm_instance_id, 
used_ports=already_uesd_ports, ) logger.info(f"alloced ports: {can_use_ports}") @@ -361,8 +388,9 @@ def normal_or_p_d_start(args): metric_port, multi_level_kv_cache_port, pd_decode_rpyc_port, - ) = can_use_ports[0:10] - can_use_ports = can_use_ports[10:] + control_rpyc_port, + ) = can_use_ports[0:11] + can_use_ports = can_use_ports[11:] if args.visual_nccl_ports is None: args.visual_nccl_ports = can_use_ports[: args.visual_dp] @@ -381,6 +409,18 @@ def normal_or_p_d_start(args): args.nccl_port = nccl_port if args.pd_decode_rpyc_port is None: args.pd_decode_rpyc_port = pd_decode_rpyc_port + if args.control_rpyc_port is None: + args.control_rpyc_port = control_rpyc_port + + set_unique_server_name(args) + + # 确保单机上多实列不冲突 + if args.zmq_mode == "ipc:///tmp/": + zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_" + args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功 + args.zmq_mode = zmq_mode + logger.info(f"zmq mode head: {args.zmq_mode}") + args.router_port = router_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port @@ -487,6 +527,13 @@ def normal_or_p_d_start(args): ], ) + return process_manager + + +def normal_or_p_d_start(args: StartArgs): + + process_manager = _launch_subprocesses(args) + # 启动 Hypercorn command = [ "hypercorn", @@ -522,7 +569,7 @@ def normal_or_p_d_start(args): return -def pd_master_start(args): +def pd_master_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "pd_master": return @@ -541,10 +588,7 @@ def pd_master_start(args): logger.info(f"all start args:{args}") can_use_ports = alloc_can_use_network_port( - num=1, - used_ports=[ - args.port, - ], + num=1, used_nccl_ports=[args.nccl_port, args.port], instance_id=args.lightllm_instance_id ) metric_port = can_use_ports[0] diff --git a/lightllm/server/audioserver/model_infer/__init__.py b/lightllm/server/audioserver/model_infer/__init__.py index 6068b000ce..2709977c26 100644 --- a/lightllm/server/audioserver/model_infer/__init__.py +++ b/lightllm/server/audioserver/model_infer/__init__.py @@ -8,6 +8,7 @@ from lightllm.utils.retry_utils import retry from rpyc.utils.factory import unix_connect from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.process_check import start_parent_check_thread from .model_rpc_client import AudioModelRpcClient from .model_rpc import AudioModelRpcServer from ..objs import rpyc_config @@ -17,6 +18,7 @@ def _init_env(socket_path: str, success_event): graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::audio_model_infer") + start_parent_check_thread() import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/server/core/objs/out_token_circlequeue.py b/lightllm/server/core/objs/out_token_circlequeue.py index ea99dae5f6..8019c9a1a1 100644 --- a/lightllm/server/core/objs/out_token_circlequeue.py +++ b/lightllm/server/core/objs/out_token_circlequeue.py @@ -4,6 +4,9 @@ LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 1280)) LIGHTLLM_OUT_TOKEN_QUEUE_SIZE = int(os.getenv("LIGHTLLM_OUT_TOKEN_QUEUE_SIZE", 8)) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class QueueItem(ctypes.Structure): @@ -24,9 +27,19 @@ def __init__(self): def set(self, token_str: str, src_index: int, special: bool, count_output_tokens: int): str_bytes = token_str.encode("utf-8") - assert ( - len(str_bytes) <= LIGHTLLM_TOKEN_MAX_BYTES - ), f"Token string {len(str_bytes)} exceeds maximum length of 
{LIGHTLLM_TOKEN_MAX_BYTES} bytes." + max_data_len = max(LIGHTLLM_TOKEN_MAX_BYTES - 1, 0) + if len(str_bytes) > max_data_len: + logger.error( + "Token string exceeds max bytes: bytes=%d limit=%d src_index=%d count_output_tokens=%d preview=%s", + len(str_bytes), + max_data_len, + src_index, + count_output_tokens, + token_str, + ) + str_bytes = str_bytes[:max_data_len] + # Ensure truncation never leaves an incomplete UTF-8 sequence. + str_bytes = str_bytes.decode("utf-8", errors="ignore").encode("utf-8") ctypes.memmove(self.data, str_bytes, len(str_bytes)) self.data_len = len(str_bytes) self.src_index = src_index diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index 5d3a511d21..4489ccd708 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -114,13 +114,18 @@ def __init__( def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() - cls._do_sample = generation_cfg.get("do_sample", False) - cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) - cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) - cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) - cls._temperature = generation_cfg.get("temperature", 1.0) - cls._top_p = generation_cfg.get("top_p", 1.0) - cls._top_k = generation_cfg.get("top_k", -1) + + def _cfg(key, default): + v = generation_cfg.get(key) + return v if v is not None else default + + cls._do_sample = _cfg("do_sample", False) + cls._presence_penalty = _cfg("presence_penalty", 0.0) + cls._frequency_penalty = _cfg("frequency_penalty", 0.0) + cls._repetition_penalty = _cfg("repetition_penalty", 1.0) + cls._temperature = _cfg("temperature", 1.0) + cls._top_p = _cfg("top_p", 1.0) + cls._top_k = _cfg("top_k", -1) cls._stop_sequences = generation_cfg.get("stop", None) except: pass diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 7f2b697091..d954870393 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -1,6 +1,7 @@ import os import math import ctypes +import base64 import numpy as np import time from .sampling_params import SamplingParams @@ -14,6 +15,7 @@ from lightllm.utils.kv_cache_utils import compute_token_list_hash from typing import List, Any, Union from lightllm.utils.log_utils import init_logger +from lightllm.utils.shm_utils import create_or_link_shm logger = init_logger(__name__) @@ -25,19 +27,20 @@ class FinishStatus(ctypes.Structure): NO_FINISH = 0 FINISHED_STOP = 1 FINISHED_LENGTH = 2 + FINISHED_ABORTED = 3 def __init__(self, init_state=NO_FINISH): self.status = init_state def set_status(self, new_status): - assert 0 <= new_status <= 2 + assert 0 <= new_status <= 3 self.status = new_status def get_status(self): return self.status def is_finished(self): - return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH + return self.FINISHED_STOP <= self.status <= self.FINISHED_ABORTED def is_stopped(self): return self.status == self.FINISHED_STOP @@ -50,6 +53,8 @@ def get_finish_reason(self): return "stop" elif self.status == self.FINISHED_LENGTH: return "length" + elif self.status == self.FINISHED_ABORTED: + return "abort" return None @@ -125,6 +130,8 @@ class Req(ctypes.Structure): ("token_hash_page_len_list", TokenPageLenList), # 用于保存查找匹配到的可以被复用的cpu cache 页面信息。 ("cpu_cache_match_page_indexes", CpuCachePageList), + # Number of tokens 
in routing data SHM, written by model worker, read by HTTP server. + ("shm_routing_num_tokens", ctypes.c_int), ] def get_str(self): @@ -182,6 +189,7 @@ def init( self._mtp_step = get_env_start_args().mtp_step self.stop_str_matched = False self.stop_str_matched_token_index = -1 + self.shm_routing_num_tokens = 0 self.post_init() @@ -277,6 +285,69 @@ def link_logprobs_shm_array(self): self.shm_logprobs.link_shm() return + def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.int8): + """Create routing SHM at actual size (on-demand, not pre-allocated). + + Uses smart mode: links if same-sized SHM exists, otherwise creates new. + """ + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_tokens, num_moe_layers, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) + self.shm_routing_data.create_shm() + self.shm_routing_num_tokens = num_tokens + return + + def link_routing_data_shm_array(self, num_moe_layers: int, topk: int, np_dtype=np.int8): + """Link to routing SHM from the reader side (HTTP server).""" + if num_moe_layers == 0: + return + num_tokens = self.shm_routing_num_tokens + if num_tokens <= 0: + return + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_tokens, num_moe_layers, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) + self.shm_routing_data.link_shm() + return + + def get_routing_data(self): + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: + return None + return self.shm_routing_data.arr + + def close_routing_data_shm_array(self): + """Close and unlink routing SHM (on-demand, no longer pooled).""" + if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None: + self.shm_routing_data.close_shm() + self.shm_routing_data = None + self.shm_routing_num_tokens = 0 + return + + def get_routing_metadata(self, num_moe_layers: int, topk: int, dtype_id: int = 1): + if num_moe_layers == 0 or topk == 0: + return None + if self.shm_routing_num_tokens <= 0: + return None + try: + from lightllm.common.basemodel.routing_manager import routing_dtype_id_to_np + + np_dtype = routing_dtype_id_to_np(dtype_id) + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: + self.link_routing_data_shm_array(num_moe_layers, topk, np_dtype=np_dtype) + routing_data = self.get_routing_data() + if routing_data is None: + return None + return { + "shape": list(routing_data.shape), + "dtype": str(routing_data.dtype), + "data": base64.b64encode(routing_data.tobytes()).decode("ascii"), + } + except Exception as e: + logger.warning(f"Failed to read routing data for req {self.request_id}: {e}") + return None + def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() @@ -297,9 +368,8 @@ def can_release(self): ref_count_ok = self.ref_count == 1 can_released_mark = self.can_released_mark - if self.is_aborted and can_released_mark and ref_count_ok: - return True - + # if self.is_aborted and can_released_mark and ref_count_ok: + # return True ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty(): diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index c94f3c6957..0946622928 100644 --- 
a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -313,6 +313,8 @@ class SamplingParams(ctypes.Structure): ("ignore_eos", ctypes.c_bool), # the max number of image patches to be used in the internvl model, for the test ("image_max_patch_num", ctypes.c_int), + ("min_pixels", ctypes.c_int), + ("max_pixels", ctypes.c_int), ("max_new_tokens", ctypes.c_int), ("min_new_tokens", ctypes.c_int), # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty @@ -356,6 +358,9 @@ class SamplingParams(ctypes.Structure): def init(self, tokenizer, **kwargs): super().__init__() + # 移除kwargs中为null的参数,避免覆盖默认值 + kwargs = {k: v for k, v in kwargs.items() if v is not None} + self.best_of = kwargs.get("best_of", 1) self.n = kwargs.get("n", self.best_of) self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) @@ -437,15 +442,18 @@ def init(self, tokenizer, **kwargs): def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() - cls._do_sample = generation_cfg.get("do_sample", False) - cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) - cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) - cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) - if cls._repetition_penalty is None: - cls._repetition_penalty = 1.0 - cls._temperature = generation_cfg.get("temperature", 1.0) - cls._top_p = generation_cfg.get("top_p", 1.0) - cls._top_k = generation_cfg.get("top_k", -1) + + def _cfg(key, default): + v = generation_cfg.get(key) + return v if v is not None else default + + cls._do_sample = _cfg("do_sample", False) + cls._presence_penalty = _cfg("presence_penalty", 0.0) + cls._frequency_penalty = _cfg("frequency_penalty", 0.0) + cls._repetition_penalty = _cfg("repetition_penalty", 1.0) + cls._temperature = _cfg("temperature", 1.0) + cls._top_p = _cfg("top_p", 1.0) + cls._top_k = _cfg("top_k", -1) except: pass @@ -512,6 +520,8 @@ def to_dict(self): "image_max_patch_num": self.image_max_patch_num, "max_new_tokens": self.max_new_tokens, "min_new_tokens": self.min_new_tokens, + "min_pixels": self.min_pixels, + "max_pixels": self.max_pixels, "exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(), "stop_sequences": self.stop_sequences.to_list(), "best_of": self.best_of, diff --git a/lightllm/server/core/objs/shm_array.py b/lightllm/server/core/objs/shm_array.py index c5ad512c6b..74d64b6c5e 100644 --- a/lightllm/server/core/objs/shm_array.py +++ b/lightllm/server/core/objs/shm_array.py @@ -26,6 +26,13 @@ def link_shm(self): self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) return + def detach_shm(self): + """Close handle without unlinking (SHM persists for reuse).""" + if self.shm is not None: + self.shm.close() + self.shm = None + self.arr = None + def close_shm(self): if self.shm is not None: self.shm.close() diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 954daa50fe..5e13794200 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import List, Optional, Tuple -# 只是为了更好的编程提示 +# 服务启动参数 @dataclass @@ -9,16 +9,27 @@ class StartArgs: run_mode: str = field( default="normal", metadata={ - "choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", 
"nixl_decode", "visual_only"] + "choices": [ + "normal", + "prefill", + "decode", + "nixl_prefill", + "nixl_decode", + "pd_master", + "config_server", + "visual_only", + ] }, ) + performance_mode: str = field(default=None, metadata={"choices": ["personal"]}) host: str = field(default="127.0.0.1") port: int = field(default=8000) + httpserver_workers: int = field(default=1) zmq_mode: str = field( default="ipc:///tmp/", metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"}, ) - pd_master_ip: str = field(default="127.0.0.1") + pd_master_ip: str = field(default="0.0.0.0") pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) @@ -26,18 +37,35 @@ class StartArgs: afs_image_embed_dir: str = field(default=None) afs_embed_capacity: int = field(default=250000) pd_decode_rpyc_port: int = field(default=None) - select_p_d_node_strategy: str = field(default=None) + control_rpyc_port: int = field(default=None) + select_p_d_node_strategy: str = field( + default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]} + ) model_name: str = field(default="default_model_name") + model_owner: Optional[str] = field(default=None) model_dir: Optional[str] = field(default=None) - tokenizer_mode: str = field(default="slow") + tokenizer_mode: str = field(default="fast") load_way: str = field(default="HF") max_total_token_num: Optional[int] = field(default=None) mem_fraction: float = field(default=0.9) batch_max_tokens: Optional[int] = field(default=None) - eos_id: List[int] = field(default_factory=list) + eos_id: Optional[List[int]] = field(default=None) tool_call_parser: Optional[str] = field( default=None, - metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen", "qwen3_coder"]}, + metadata={ + "choices": [ + "qwen25", + "llama3", + "mistral", + "deepseekv3", + "qwen", + "deepseekv31", + "deepseekv32", + "glm47", + "kimi_k2", + "qwen3_coder", + ] + }, ) reasoning_parser: Optional[str] = field( default=None, @@ -60,7 +88,7 @@ class StartArgs: }, ) chat_template: Optional[str] = field(default=None) - running_max_req_size: int = field(default=512) + running_max_req_size: int = field(default=256) tp: int = field(default=1) dp: int = field(default=1) nnodes: int = field(default=1) @@ -69,12 +97,13 @@ class StartArgs: max_req_total_len: Optional[int] = field(default=None) nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=None) + lightllm_instance_id: int = field(default=0) use_config_server_to_init_nccl: bool = field(default=False) trust_remote_code: bool = field(default=False) detail_log: bool = field(default=False) disable_log_stats: bool = field(default=False) log_stats_interval: int = field(default=10) - router_token_ratio: float = field(default=0.0) + router_token_ratio: float = field(default=None) router_max_wait_tokens: int = field(default=1) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) @@ -82,7 +111,7 @@ class StartArgs: disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) - output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]}) + output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]}) first_token_constraint_mode: bool = field(default=False) 
enable_multimodal: bool = field(default=False) disable_vision: Optional[bool] = field(default=None) @@ -105,12 +134,12 @@ class StartArgs: health_monitor: bool = field(default=False) metric_gateway: Optional[str] = field(default=None) job_name: str = field(default="lightllm") - grouping_key: List[str] = field(default_factory=list) + grouping_key: List[str] = field(default_factory=lambda: []) push_interval: int = field(default=10) visual_node_id: int = field(default=None) visual_infer_batch_size: int = field(default=None) visual_send_batch_size: int = field(default=1) - visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) + visual_gpu_ids: List[int] = field(default=None) visual_tp: int = field(default=1) visual_dp: int = field(default=1) visual_nccl_ports: List[int] = field(default=None) @@ -128,18 +157,18 @@ class StartArgs: graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) graph_max_len_in_batch: int = field(default=0) - quant_type: Optional[str] = field(default=None) + quant_type: Optional[str] = field(default="none") quant_cfg: Optional[str] = field(default=None) - vit_quant_type: Optional[str] = field(default=None) + vit_quant_type: Optional[str] = field(default="none") vit_quant_cfg: Optional[str] = field(default=None) llm_prefill_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} + default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} ) llm_decode_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} + default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} ) vit_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} + default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} ) llm_kv_type: str = field( default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]} @@ -160,8 +189,6 @@ class StartArgs: "eagle_with_att", "vanilla_no_att", "eagle_no_att", - "qwen3next_vanilla", - "qwen3next_eagle", None, ] }, @@ -174,7 +201,7 @@ class StartArgs: pd_node_id: int = field(default=-1) enable_cpu_cache: bool = field(default=False) cpu_cache_storage_size: float = field(default=2) - cpu_cache_token_page_size: int = field(default=64) + cpu_cache_token_page_size: int = field(default=256) enable_disk_cache: bool = field(default=False) disk_cache_storage_size: float = field(default=10) disk_cache_dir: Optional[str] = field(default=None) @@ -189,6 +216,27 @@ class StartArgs: metric_port: int = field(default=None) multinode_httpmanager_port: int = field(default=12345) multi_level_kv_cache_port: int = field(default=None) + # multi_modal + enable_multimodal_audio: bool = field(default=False) + + disable_shm_warning: bool = field(default=False) + dp_balancer: str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]}) + enable_custom_allgather: bool = field(default=False) + enable_fused_shared_experts: bool = field(default=False) + enable_mps: bool = field(default=False) + multinode_router_gloo_port: int = field(default=20001) + schedule_time_interval: float = field(default=0.03) + use_dynamic_prompt_cache: bool = field(default=False) + disable_custom_allreduce: bool = field(default=False) + enable_torch_memory_saver: bool = field(default=False) + 
enable_weight_cpu_backup: bool = field(default=False) + hardware_platform: str = field(default="cuda", metadata={"choices": ["cuda", "musa"]}) + enable_torch_fallback: bool = field(default=False) + enable_triton_fallback: bool = field(default=False) + + enable_return_routed_experts: bool = field(default=False) + + weight_version: str = "default" # hybrid attention model (Qwen3Next) linear_att_hash_page_size: int = field(default=512) diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py index 9aa3a8effc..c77379986c 100644 --- a/lightllm/server/detokenization/decode_req.py +++ b/lightllm/server/detokenization/decode_req.py @@ -62,11 +62,7 @@ def stop_sequences_str_match(self) -> bool: return False def need_detoken(self): - if ( - (not self.req.is_aborted) - and (not self.req.stop_str_matched) - and len(self.output_ids) < self.req.candetoken_out_len - ): + if (not self.req.stop_str_matched) and len(self.output_ids) < self.req.candetoken_out_len: return True return False @@ -83,8 +79,6 @@ def get_decode_tokens(self): return prefix_tokens, read_tokens def can_set_release_mark(self): - if self.req.is_aborted: - return True if self.req.stop_str_matched: return True if ( diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 389171ba8a..ab4b61acf7 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -6,7 +6,6 @@ import zmq import inspect from lightllm.server.core.objs import ShmReqManager, StartArgs -from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.utils.graceful_utils import graceful_registry from typing import Union, Dict, List from .decode import decode_token @@ -17,6 +16,7 @@ import time from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.core.objs.io_objs import GroupReqIndexes logger = init_logger(__name__) @@ -75,7 +75,6 @@ def handle_loop(self): # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - assert isinstance(recv_obj, GroupReqIndexes) self._add_new_group_req_index(recv_obj=recv_obj) # 当队列中存在较多的请求时,将一次接受的数量上调 diff --git a/lightllm/server/httpserver/control_rpyc_client.py b/lightllm/server/httpserver/control_rpyc_client.py new file mode 100644 index 0000000000..5382e71137 --- /dev/null +++ b/lightllm/server/httpserver/control_rpyc_client.py @@ -0,0 +1,76 @@ +import asyncio +import socket +from typing import Optional + +import rpyc +from rpyc.utils.classic import obtain + +from lightllm.server.io_struct import GeneralModelToHttpRpcRsp +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class ControlRpycClient: + """到 master router 控制面 rpyc service 的异步客户端。 + + - 懒初始化连接, 断连时自动重连。 + - 所有调用通过 `call(method_name, *args)` 派发到 server 端的 root service。 + - 错误统一封装为 GeneralModelToHttpRpcRsp, 调用方无需处理异常。 + """ + + def __init__( + self, + host: str = "127.0.0.1", + port: Optional[int] = None, + config: Optional[dict] = None, + ping_timeout: float = 2.0, + ): + self.host = host + self.port = port + self.config = config if config is not None else {"allow_pickle": True, "sync_request_timeout": 600} + self.ping_timeout = ping_timeout + self._lock: asyncio.Lock = asyncio.Lock() + self._conn: Optional[rpyc.core.protocol.Connection] = None + + async def _get_conn(self) -> rpyc.core.protocol.Connection: 
+ async with self._lock: + if self._conn is not None: + try: + self._conn.ping(timeout=self.ping_timeout) + return self._conn + except BaseException: + try: + self._conn.close() + except BaseException: + pass + self._conn = None + + self._conn = await asyncio.to_thread( + rpyc.connect, + self.host, + self.port, + config=self.config, + ) + self._conn._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return self._conn + + async def call(self, method_name: str, *args) -> GeneralModelToHttpRpcRsp: + try: + conn = await self._get_conn() + ret = await asyncio.to_thread(getattr(conn.root, method_name), *args) + return obtain(ret) + except BaseException as e: + logger.exception(f"control rpyc call {method_name} failed: {e}") + return GeneralModelToHttpRpcRsp( + success=False, msg=f"control rpyc call {method_name} error: {e}", func_name=method_name + ) + + async def close(self): + async with self._lock: + if self._conn is not None: + try: + self._conn.close() + except BaseException: + pass + self._conn = None diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 4c049f77c0..362e5f6629 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -3,14 +3,15 @@ import zmq.asyncio import asyncio import uvloop -import rpyc import socket +import rpyc import time import copy import hashlib import datetime import pickle from frozendict import frozendict +from .control_rpyc_client import ControlRpycClient asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict, Optional, AsyncGenerator @@ -29,8 +30,21 @@ from lightllm.server.core.objs.shm_req_manager import ShmReqManager from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.common.basemodel.routing_manager import get_routing_config_shm from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient +from lightllm.server.io_struct import ( + AbortReq, + FlushCacheReq, + ReleaseMemoryReq, + ResumeMemoryReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq, + UpdateWeightsFromIPCReq, + GeneralModelToHttpRpcRsp, +) from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name @@ -74,7 +88,7 @@ def __init__( self.multinode_req_manager = context.socket(zmq.PULL) self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}") logger.info( - f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}" + f"HttpServerManager listening for master node requests on *:{args.multinode_httpmanager_port}" ) self.enable_multimodal = args.enable_multimodal @@ -123,6 +137,15 @@ def __init__( self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + # Cache routing config for MoE expert routing data extraction + self._routing_shm = get_routing_config_shm() if args.enable_return_routed_experts else None + + self.is_pause = False + self.is_pause_cond = asyncio.Condition() + + # 控制面 rpyc client: 到 master router 的 rpyc service, 懒初始化 + 自动重连 + self._control_rpyc_client = ControlRpycClient(host="127.0.0.1", 
port=args.control_rpyc_port) + # 用于记录真实的--max_total_token_num 参数,当这个参数在启动参数中没有设置的时候,其是在推理进程中被分析出来的, # 这个时候如果 --max_req_total_len > --max_total_token_num 时,如果httpserver放过一些非法的输入进入后续的模块可能 # 会触发整个系统崩溃,所以httpserver需要知道真实的 max_total_token_num的数据,用于提前拦截非法请求等参数。 @@ -255,11 +278,12 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar async def loop_for_request(self): assert self.args.node_rank > 0 while True: - ( - prompt, - sampling_params, - multimodal_params, - ) = await self.multinode_req_manager.recv_pyobj() + req_obj = await self.multinode_req_manager.recv_pyobj() + if isinstance(req_obj, AbortReq): + asyncio.create_task(self.abort_request(req_obj)) + continue + # 兼容 main 的协议: master 用 tuple 转发 generate 请求 + prompt, sampling_params, multimodal_params = req_obj results_generator = self.generate(prompt, sampling_params, multimodal_params, None) async def generate_wrapper(results_generator): @@ -331,8 +355,20 @@ async def generate( "verify_and_preload_done", ) + # Debug logging for multimodal requests + if multimodal_params and multimodal_params.images: + logger.debug( + f"[MULTIMODAL_DEBUG] req_id={group_request_id}, " + f"num_images={len(multimodal_params.images)}, " + f"max_new_tokens={sampling_params.max_new_tokens}" + ) + # 记录请求到达的相关信息 await self._log_req_header(request_headers, group_request_id) + + async with self.is_pause_cond: + await self.is_pause_cond.wait_for(lambda: not self.is_pause) + # encode prompt_ids = await self._encode(prompt, multimodal_params, sampling_params) self._log_stage_timing( @@ -406,12 +442,6 @@ async def generate( "shm_req_init_done", ) - logger.debug( - f"alloc shm_req for req_id {group_request_id}, " - f"shm_req num: {sampling_params.n} details (req_id, index_in_shm_mem): " - f"{[(req_obj.request_id, req_obj.index_in_shm_mem) for req_obj in req_objs]}" - ) - req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time) self.req_id_to_out_inf[group_request_id] = req_status @@ -537,7 +567,21 @@ async def _encode( # 这里的校验对多模态不是很充分, to do if all(isinstance(e, int) for e in prompt): - if not self.enable_multimodal and not self.pd_mode.is_D(): + if self.enable_multimodal: + assert ( + len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + ), "too many multimodal items!" 
+ if multimodal_params.audios: + assert self.args.enable_multimodal_audio, "audio multimodal not enabled" + await self._alloc_multimodal_resources(multimodal_params, sampling_params) + prompt_ids = self.tokenizer.encode( + prompt, + multimodal_params, + add_special_tokens=sampling_params.add_special_tokens, + already_tokenized=True, + ) + return prompt_ids + elif not self.enable_multimodal and not self.pd_mode.is_D(): if all(e < self.vocab_size for e in prompt): return prompt else: @@ -776,6 +820,20 @@ async def abort(self, group_req_id: int) -> bool: logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") return True + async def abort_request(self, request: AbortReq): + request_id = request.request_id + abort_all = request.abort_all + # 多节点纯 tp 运行模式下,master 需要把 abort 转发给 slave,使 slave 也清掉自己 shm 里的请求。 + if self.is_multinode_tp_master: + for sender in self.multinode_req_manager: + sender.send_pyobj(request, protocol=pickle.HIGHEST_PROTOCOL) + if request_id is not None and not abort_all: + await self.abort(request_id) + if abort_all: + for group_req_id in list(self.req_id_to_out_inf.keys()): + await self.abort(group_req_id) + return + async def recycle_resource_loop(self): pre_time_mark = time.time() @@ -798,6 +856,11 @@ async def recycle_resource_loop(self): self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) _is_aborted = False for req in req_status.group_req_objs.shm_req_objs: + if hasattr(req, "shm_routing_data") and req.shm_routing_data is not None: + try: + req.close_routing_data_shm_array() + except Exception as e: + logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}") _is_aborted = _is_aborted or req.is_aborted logger.debug(f"httpserver release req_id {req.request_id}, index {req.index_in_shm_mem}") await self.shm_req_manager.async_put_back_req_obj(req) @@ -843,60 +906,7 @@ async def handle_loop(self): pass try: - for group_req_id_ in list(self.req_id_to_out_inf.keys()): - req_status = self.req_id_to_out_inf.get(group_req_id_, None) - if req_status is None: - continue - - token_list = [] - for req in req_status.group_req_objs.shm_req_objs: - req_id = req.request_id - read_token_count = 1 - if req.out_tokens_queue.is_full(): - read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE - - for _ in range(read_token_count): - if not req.out_tokens_queue.is_empty(): - - text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() - req.cumlogprob += float(req.shm_logprobs.arr[src_index]) - metadata = { - "id": int(req.shm_prompt_ids.arr[src_index]), - "logprob": float(req.shm_logprobs.arr[src_index]), - "cumlogprob": float(req.cumlogprob) / count_output_tokens, - "special": special, - "count_output_tokens": count_output_tokens, - "prompt_cache_len": req.prompt_cache_len, - "cpu_prompt_cache_len": req.cpu_prompt_cache_len, - "disk_prompt_cache_len": req.disk_prompt_cache_len, - "mtp_accepted_token_num": req.mtp_accepted_token_num, - } - if self.args.return_all_prompt_logprobs: - metadata.update(req.get_all_prompt_metadata()) - if self.args.use_reward_model: - metadata["score"] = float(req.reward_score) - - req.out_tokens_queue.pop_no_ret() - - finished_token_index = ( - req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index - ) - - if finished_token_index != src_index: - token_list.append((req_id, text, metadata, FinishStatus())) - else: - if req.stop_str_matched: - finish_status = FinishStatus(FinishStatus.FINISHED_STOP) - else: - finish_status = FinishStatus(req.finish_status.status) - - 
token_list.append((req_id, text, metadata, finish_status)) - else: - break - - async with req_status.lock: - req_status.out_token_info_list.extend(token_list) - req_status.event.set() + await self._handle_token_output() except BaseException as e: logger.exception(str(e)) raise e @@ -904,6 +914,132 @@ async def handle_loop(self): self.recycle_event.set() return + async def _handle_token_output(self): + for group_req_id_ in list(self.req_id_to_out_inf.keys()): + req_status = self.req_id_to_out_inf.get(group_req_id_, None) + if req_status is None: + continue + + token_list = [] + for req in req_status.group_req_objs.shm_req_objs: + req_id = req.request_id + read_token_count = 1 + if req.out_tokens_queue.is_full(): + read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE + + for _ in range(read_token_count): + if not req.out_tokens_queue.is_empty(): + + text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() + req.cumlogprob += float(req.shm_logprobs.arr[src_index]) + metadata = { + "id": int(req.shm_prompt_ids.arr[src_index]), + "logprob": float(req.shm_logprobs.arr[src_index]), + "cumlogprob": float(req.cumlogprob) / count_output_tokens, + "special": special, + "count_output_tokens": count_output_tokens, + "prompt_cache_len": req.prompt_cache_len, + "cpu_prompt_cache_len": req.cpu_prompt_cache_len, + "mtp_accepted_token_num": req.mtp_accepted_token_num, + } + if self.args.return_all_prompt_logprobs: + metadata.update(req.get_all_prompt_metadata()) + if self.args.use_reward_model: + metadata["score"] = float(req.reward_score) + + req.out_tokens_queue.pop_no_ret() + + finished_token_index = ( + req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index + ) + + if finished_token_index != src_index: + token_list.append((req_id, text, metadata, FinishStatus())) + else: + if req.stop_str_matched: + finish_status = FinishStatus(FinishStatus.FINISHED_STOP) + else: + finish_status = FinishStatus(req.finish_status.status) + + if self._routing_shm is not None: + _num_moe = int(self._routing_shm.arr[0]) + _topk = int(self._routing_shm.arr[1]) + _dtype_id = int(self._routing_shm.arr[2]) + if _num_moe > 0: + routing_meta = req.get_routing_metadata(_num_moe, _topk, dtype_id=_dtype_id) + if routing_meta is not None: + metadata["routed_experts"] = routing_meta + + token_list.append((req_id, text, metadata, finish_status)) + else: + break + + async with req_status.lock: + req_status.out_token_info_list.extend(token_list) + req_status.event.set() + + async def pause_generation(self): + # 因为请求是从master node转发到slave node的 + # 所以只要master暂停了,slave自然暂停。 + if self.is_pause: + return + async with self.is_pause_cond: + self.is_pause = True + while True: + await self.abort_request(AbortReq(request_id=None, abort_all=True)) + running_req_num = len(list(self.req_id_to_out_inf.keys())) + if running_req_num == 0: + break + await asyncio.sleep(1.0) + + async def continue_generation(self): + async with self.is_pause_cond: + self.is_pause = False + self.is_pause_cond.notify_all() + + # -------- master router 控制面 rpyc 调用 -------- + + async def flush_cache(self, request: FlushCacheReq): + return await self._control_rpyc_client.call("flush_cache", request) + + async def release_memory_occupation(self, request: ReleaseMemoryReq): + assert len(self.req_id_to_out_inf) == 0, "there are still requests running, cannot release memory occupation" + await self.pause_generation() + return await self._control_rpyc_client.call("release_memory_occupation", request.tags) + + async def 
resume_memory_occupation(self, request: ResumeMemoryReq): + ret = await self._control_rpyc_client.call("resume_memory_occupation", request.tags) + if ret.success: + await self.continue_generation() + return ret + + async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + return await self._control_rpyc_client.call("init_weights_update_group", request) + + async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + return await self._control_rpyc_client.call("destroy_weights_update_group", request) + + async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + if request.abort_all_requests: + await self.abort_request(AbortReq(abort_all=True)) + if request.flush_cache: + await self.flush_cache(FlushCacheReq()) + return await self._control_rpyc_client.call("update_weights_from_distributed", request) + + async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> GeneralModelToHttpRpcRsp: + if request.abort_all_requests: + await self.abort_request(AbortReq(abort_all=True)) + if request.flush_cache: + await self.flush_cache(FlushCacheReq()) + return await self._control_rpyc_client.call("update_weights_from_tensor", request) + + async def update_weights_from_ipc(self, request: UpdateWeightsFromIPCReq) -> GeneralModelToHttpRpcRsp: + if request.abort_all_requests: + await self.abort_request(AbortReq(abort_all=True)) + if request.flush_cache: + await self.flush_cache(FlushCacheReq()) + return await self._control_rpyc_client.call("update_weights_from_ipc", request) + class ReqStatus: def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None: diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py new file mode 100644 index 0000000000..12ed501c58 --- /dev/null +++ b/lightllm/server/io_struct.py @@ -0,0 +1,101 @@ +from dataclasses import dataclass +from typing import List, Optional, Any, Union +from lightllm.utils.torch_memory_saver_utils import MemoryTag + + +@dataclass +class AbortReq: + # 外部调用传入,等同内部的 group_req_id + request_id: Optional[int] = None + abort_all: bool = False + + +def _normalize_memory_tags(tags): + if tags is None: + return None + return [tag if isinstance(tag, MemoryTag) else MemoryTag(tag) for tag in tags] + + +@dataclass +class FlushCacheReq: + pass + + +@dataclass +class ReleaseMemoryReq: + tags: Optional[List[MemoryTag]] = None + + def __post_init__(self): + self.tags = _normalize_memory_tags(self.tags) + + +@dataclass +class ResumeMemoryReq: + tags: Optional[List[MemoryTag]] = None + + def __post_init__(self): + self.tags = _normalize_memory_tags(self.tags) + + +@dataclass +class GeneralHttpToModelRpcReq: + func_name: str + func_args: Optional[Any] = None + + +@dataclass +class GeneralModelToHttpRpcRsp: + success: bool + msg: Optional[str] + func_name: str + func_rsp: Optional[Any] = None + + +@dataclass +class InitWeightsUpdateGroupReq: + master_address: str + master_port: int + rank_offset: int + world_size: int + group_name: str = "weight_update_group" + backend: str = "nccl" + + +@dataclass +class DestroyWeightsUpdateGroupReq: + group_name: str = "weight_update_group" + + +@dataclass +class UpdateWeightsFromDistributedReq: + names: List[str] + dtypes: List[str] + shapes: List[List[int]] + group_name: str = "weight_update_group" + flush_cache: bool = True + abort_all_requests: bool = False + weight_version: Optional[str] = None + + +@dataclass +class UpdateWeightsFromTensorReq: + """Update model weights from 
tensor input. + + - Tensors are serialized for transmission + - Data is structured in JSON for easy transmission over HTTP + """ + + serialized_named_tensors: List[Union[str, bytes]] + load_format: Optional[str] = None + flush_cache: bool = True + abort_all_requests: bool = False + weight_version: Optional[str] = None + + +@dataclass +class UpdateWeightsFromIPCReq: + ipc_handle: str = None + use_shm: bool = False + flush_cache: bool = True + abort_all_requests: bool = False + weight_version: Optional[str] = None diff --git a/lightllm/server/multi_level_kv_cache/manager.py b/lightllm/server/multi_level_kv_cache/manager.py index 0a7dec0005..205ae0e537 100644 --- a/lightllm/server/multi_level_kv_cache/manager.py +++ b/lightllm/server/multi_level_kv_cache/manager.py @@ -12,7 +12,7 @@ from queue import Queue from typing import List from lightllm.server.core.objs import ShmReqManager, Req, StartArgs -from lightllm.server.core.objs.io_objs import GroupReqIndexes +from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.utils.graceful_utils import graceful_registry from .cpu_cache_client import CpuKvCacheClient from lightllm.utils.log_utils import init_logger @@ -219,7 +219,7 @@ def recv_loop(self): # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - assert isinstance(recv_obj, GroupReqIndexes) + assert isinstance(recv_obj, GroupReqIndexes), f"unexpected req type: {type(recv_obj)}" recv_objs.append(recv_obj) start_time = recv_obj.time_mark diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index cb5ec52bd2..ffb3e232f9 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -75,9 +75,11 @@ def read(self): assert self._preload_data is not None ans = self._preload_data self._preload_data = None - self._data = None return ans + def free(self): + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid @@ -164,9 +166,11 @@ def read(self): assert self._preload_data is not None ans = self._preload_data self._preload_data = None - self._data = None return ans + def free(self): + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid @@ -220,3 +224,10 @@ def to_origin_dict(self): ret["images"] = [i.to_origin_dict() for i in self.images] ret["audios"] = [a.to_origin_dict() for a in self.audios] return ret + + def free(self): + for image in self.images: + image.free() + for audio in self.audios: + audio.free() + return diff --git a/lightllm/server/req_id_generator.py b/lightllm/server/req_id_generator.py index f7c099c292..bc81d835b6 100644 --- a/lightllm/server/req_id_generator.py +++ b/lightllm/server/req_id_generator.py @@ -30,7 +30,8 @@ def __init__(self): self.current_id.arr[0] = 0 self.current_id.arr[1] = 0 self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock") - self._wait_all_workers_ready() + if self.args.httpserver_workers > 1: + self._wait_all_workers_ready() logger.info("ReqIDGenerator init finished") def _wait_all_workers_ready(self): diff --git a/lightllm/server/router/control_rpyc.py b/lightllm/server/router/control_rpyc.py new file mode 100644 index 0000000000..d803abedc3 --- /dev/null +++ b/lightllm/server/router/control_rpyc.py @@ -0,0 +1,61 @@ +import threading +import rpyc +from rpyc.utils.classic import obtain +from rpyc.utils.server import ThreadedServer +from lightllm.utils.log_utils import init_logger + +logger = 
init_logger(__name__) + + +class RouterControlRpcService(rpyc.Service): + """挂在 master router 进程上的控制面 rpyc service。httpserver 通过 rpyc 同步调用, + rpyc 线程把请求投递到 router 的 asyncio loop(_control_op_queue),等结果。 + 多机情况下,只有 node_rank=0 启动这个 service;slave 通过 router 的 NCCL 广播协同。""" + + def __init__(self, router): + super().__init__() + self._router = router + + def exposed_flush_cache(self, request): + return self._router.submit_control_op("flush_cache", obtain(request)) + + def exposed_release_memory_occupation(self, tags): + return self._router.submit_control_op("release_memory_occupation", obtain(tags)) + + def exposed_resume_memory_occupation(self, tags): + return self._router.submit_control_op("resume_memory_occupation", obtain(tags)) + + def exposed_init_weights_update_group(self, request): + return self._router.submit_control_op("init_weights_update_group", obtain(request)) + + def exposed_destroy_weights_update_group(self, request): + return self._router.submit_control_op("destroy_weights_update_group", obtain(request)) + + def exposed_update_weights_from_distributed(self, request): + return self._router.submit_control_op("update_weights_from_distributed", obtain(request)) + + def exposed_update_weights_from_tensor(self, request): + return self._router.submit_control_op("update_weights_from_tensor", obtain(request)) + + def exposed_update_weights_from_ipc(self, request): + return self._router.submit_control_op("update_weights_from_ipc", obtain(request)) + + +def start_control_rpyc_server(router, port: int) -> None: + """在 daemon 线程里启动 rpyc ThreadedServer。绑定 127.0.0.1,httpserver 同机调用即可。""" + + def _run(): + try: + t = ThreadedServer( + RouterControlRpcService(router), + hostname="127.0.0.1", + port=port, + protocol_config={"allow_pickle": True, "sync_request_timeout": 600}, + ) + logger.info(f"control rpyc server listening on 127.0.0.1:{port}") + t.start() + except BaseException as e: + logger.exception(f"control rpyc server crashed: {e}") + + th = threading.Thread(target=_run, name="control_rpyc_server", daemon=True) + th.start() diff --git a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py index 73c6dba54d..082f6bec08 100644 --- a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py @@ -467,6 +467,10 @@ def clear_tree_nodes(self): self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num) return + def flush_cache(self): + self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num) + return + def deref_to_first_big_page_node(self, node: LinearAttPagedTreeNode) -> Optional[LinearAttPagedTreeNode]: assert not node.is_big_page_node() iter_node = node diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 88b099459b..f17186da11 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -106,6 +106,7 @@ class RadixCache: def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): from lightllm.common.kv_cache_mem_manager import MemoryManager + self.total_token_num = total_token_num self.mem_manager: MemoryManager = mem_manager self._key_dtype = torch.int64 self._value_dtype = torch.int64 @@ -425,6 +426,10 @@ def clear_tree_nodes(self): self.refed_tokens_num.arr[0] = 0 return + def flush_cache(self): + 
self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num)
+        return
+
     def dec_node_ref_counter(self, node: TreeNode):
         if node is None:
             return
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 24f8da6e6f..6a6698ee96 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -5,6 +5,8 @@
 import pickle
 import inspect
 import setproctitle
+import queue
+import concurrent.futures
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 import zmq
@@ -12,7 +14,7 @@
 import torch.multiprocessing as mp
 import torch.distributed as dist
 import multiprocessing
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 from .batch import Batch, Req
 from .model_infer.model_rpc import start_model_process, ModelRpcClient
 from .req_queue import build_req_queue
@@ -29,6 +31,10 @@
 from lightllm.server.router.token_load import TokenLoad
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.common.basemodel.infer_lock import g_router_lock
+from lightllm.server.io_struct import (
+    GeneralHttpToModelRpcReq,
+    GeneralModelToHttpRpcRsp,
+)
 from lightllm.common.kv_cache_mem_manager import ReadOnlyStaticsMemoryManager
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
@@ -110,6 +116,11 @@ def __init__(self, args: StartArgs):
             else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False)
         )
         self.router_statics = RouterStatics(self.args)
+
+        # Control-plane rpyc queue: the rpyc thread puts (req, future) pairs here; the asyncio main loop drains it every step.
+        self._control_op_queue: "queue.Queue[Tuple[GeneralHttpToModelRpcReq, concurrent.futures.Future]]" = (
+            queue.Queue()
+        )
         return
 
     async def wait_to_model_ready(self):
@@ -363,8 +374,13 @@ def _get_aborted_reqs_from_running_batch(self) -> List[Req]:
         ans = []
         if self.running_batch is None:
             return ans
-        for req in self.running_batch.reqs:
-            if req.is_aborted and req._router_aborted is False:
+        aborted_req_mask = torch.tensor(
+            [req.is_aborted for req in self.running_batch.reqs], dtype=torch.bool, device="cpu"
+        )
+        if self.is_multinode_tp:
+            dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+        for req, is_aborted in zip(self.running_batch.reqs, aborted_req_mask.numpy()):
+            if is_aborted and req._router_aborted is False:
                 req._router_aborted = True
                 ans.append(req)
         return ans
@@ -461,9 +477,22 @@ def _multinode_tp_generate_new_batch(self):
                 dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
                 req_id_select_mark = [1 for _ in range(len(req_ids))]
                 req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+                # TODO: req_id_select_mark and aborted_req_mask could be merged into a single allreduce here.
                 dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+                aborted_req_mask = torch.tensor(
+                    [req.is_aborted for req in new_batch.reqs], dtype=torch.bool, device="cpu"
+                )
+                dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
                 back_req_list = []
-                for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
+                for req_id, select, is_aborted in zip(
+                    req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy()
+                ):
+                    # Release multi-node aborted requests; when select == 0, is_aborted is guaranteed to be False.
+                    if is_aborted and select == 1:
+                        req = new_batch.pop_req(req_id)
+                        self.req_queue.free_aborted_req(req)
+                        self.shm_req_manager.put_back_req_obj(req)
+                        continue
                     if select == 0:
                         req = new_batch.pop_req(req_id)
back_req_list.append(req)
@@ -479,23 +508,28 @@
             else:
                 req_ids = [None for _ in range(req_num)]
                 dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
-                all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list])
+                # all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list])
+                id_to_req_obj = {req.request_id: req for req in self.req_queue.waiting_req_list}
                 req_id_select_mark = []
+                aborted_req_mask = []
                 for req_id in req_ids:
-                    req_id_select_mark.append(1 if req_id in all_req_id_set else 0)
+                    req_id_select_mark.append(1 if req_id in id_to_req_obj else 0)
+                    aborted_req_mask.append(id_to_req_obj[req_id].is_aborted if req_id in id_to_req_obj else False)
                 req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
                 dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
-                select_req_ids = []
-                for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
-                    if select == 1:
-                        select_req_ids.append(req_id)
-
+                aborted_req_mask = torch.tensor(aborted_req_mask, dtype=torch.bool, device="cpu")
+                dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
                 select_reqs = []
-                for req_id in select_req_ids:
-                    for req in self.req_queue.waiting_req_list:
-                        if req.request_id == req_id:
-                            select_reqs.append(req)
-
+                for req_id, select, is_aborted in zip(
+                    req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy()
+                ):
+                    if select == 1:
+                        req = id_to_req_obj[req_id]
+                        if is_aborted:
+                            self.req_queue.free_aborted_req(req)
+                            self.shm_req_manager.put_back_req_obj(req)
+                            continue
+                        select_reqs.append(req)
                 for req in select_reqs:
                     self.req_queue.waiting_req_list.remove(req)
                 if select_reqs:
@@ -522,7 +556,7 @@ async def _recv_new_reqs_and_schedule(self):
                 if isinstance(recv_req, GroupReqIndexes):
                     self._add_req(recv_req)
                 else:
-                    assert False, f"Error Req Inf {recv_req}"
+                    raise ValueError(f"Unknown request type: {type(recv_req)}")
             # 当队列中存在较多的请求时,将一次接受的数量上调
             self.recv_max_count = min(int(self.recv_max_count * 1.3), 256)
@@ -531,6 +565,8 @@
             # 当队列已经开始清空的时候,将一次接受的数量下调
             self.recv_max_count = 64
 
+        await self._process_special_reqs()
+
         if self.is_multinode_tp:
             self._multinode_tp_generate_new_batch()
         else:
@@ -538,9 +574,88 @@
             self._generate_new_batch()
         return
 
+    async def _process_special_reqs(self):
+        # Master: pop (req, future) pairs from the rpyc queue; on slave nodes the queue is always empty (no rpyc service).
+        pairs: List[Tuple[GeneralHttpToModelRpcReq, concurrent.futures.Future]] = []
+        while True:
+            try:
+                pair = self._control_op_queue.get_nowait()
+                pairs.append(pair)
+            except queue.Empty:
+                break
+
+        reqs: List[GeneralHttpToModelRpcReq] = [req for req, _ in pairs]
+
+        # Multi-node TP: the master broadcasts reqs to the slave routers over NCCL; a slave reaching this point in its own main loop receives the master's reqs from the broadcast.
+        if self.is_multinode_tp:
+            reqs = self.broadcast_reqs_to_other_nodes(reqs)
+
+        for i, req in enumerate(reqs):
+            assert isinstance(req, GeneralHttpToModelRpcReq), "special request must be GeneralHttpToModelRpcReq"
+            try:
+                ret = await self.forward_to_model(req)
+            except BaseException as e:
+                logger.exception(f"forward_to_model failed for {req.func_name}: {e}")
+                ret = GeneralModelToHttpRpcRsp(
+                    success=False, msg=f"forward_to_model error: {e}", func_name=req.func_name
+                )
+            # Only the master holds futures; on slaves, pairs is always empty.
+            if i < len(pairs):
+                _, fut = pairs[i]
+                if not fut.done():
+                    fut.set_result(ret)
+
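The control-plane plumbing above is a thread-to-event-loop handoff: a synchronous caller parks a concurrent.futures.Future on a plain queue and the asyncio loop fulfils it on its next step. A minimal standalone sketch of that pattern, with purely illustrative names (it is not the router's real API):

import asyncio
import concurrent.futures
import queue
import threading

op_queue: "queue.Queue[tuple[str, concurrent.futures.Future]]" = queue.Queue()

def submit_op(name: str, timeout: float = 5.0) -> str:
    # Runs in a worker (e.g. rpyc) thread: enqueue the op and block on the future.
    fut: concurrent.futures.Future = concurrent.futures.Future()
    op_queue.put((name, fut))
    return fut.result(timeout=timeout)

async def main_loop() -> None:
    # Runs in the asyncio loop: drain the queue each step and resolve the futures.
    for _ in range(50):
        try:
            name, fut = op_queue.get_nowait()
            fut.set_result(f"{name}: done")
        except queue.Empty:
            pass
        await asyncio.sleep(0.01)

threading.Thread(target=lambda: print(submit_op("flush_cache")), daemon=True).start()
asyncio.run(main_loop())

The future lets the blocking rpyc thread wait without touching the event loop directly, which is why the router only ever resolves futures from inside its own loop.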
+    def broadcast_reqs_to_other_nodes(self, reqs: List[GeneralHttpToModelRpcReq]):
+        req_num = len(reqs)
+        if self.node_rank == 0:
+            req_nums = [len(reqs)]
+            dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+            req_num = req_nums[0]
+            if req_num > 0:
+                dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group)
+        else:
+            req_nums = [None]
+            dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+            req_num = req_nums[0]
+            if req_num > 0:
+                reqs = [None for _ in range(req_num)]
+                dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group)
+        return reqs
+
+    async def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
+        forward_to_model_tasks = []
+        for model_rpc_client in self.model_rpc_clients:
+            forward_to_model_tasks.append(model_rpc_client.forward_to_model(req))
+        all_ret = await asyncio.gather(*forward_to_model_tasks)
+        success = all(ret.success for ret in all_ret)
+        ret: GeneralModelToHttpRpcRsp = all_ret[0]
+        ret.success = success
+        if self.is_multinode_tp:
+            output_list = [None for _ in range(self.nnodes)] if self.node_rank == 0 else None
+            dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group)
+            if self.node_rank == 0:
+                for res in output_list:
+                    res: GeneralModelToHttpRpcRsp
+                    if not res.success:
+                        ret = res
+                        break
+        return ret
+
     def clean_up(self):
         return
 
+    def submit_control_op(self, func_name: str, func_args, timeout: float = 300.0) -> GeneralModelToHttpRpcRsp:
+        """Called from an rpyc thread: post a control-plane op to the asyncio main loop and wait synchronously for the result."""
+        req = GeneralHttpToModelRpcReq(func_name=func_name, func_args=func_args)
+        fut: concurrent.futures.Future = concurrent.futures.Future()
+        self._control_op_queue.put((req, fut))
+        try:
+            return fut.result(timeout=timeout)
+        except concurrent.futures.TimeoutError:
+            return GeneralModelToHttpRpcRsp(
+                success=False, msg=f"control op {func_name} timeout after {timeout}s", func_name=func_name
+            )
+
 def start_router_process(args, pipe_writer):
     # 注册 graceful 退出的处理
@@ -572,6 +687,12 @@ def handle_exception(loop, context):
         router.clean_up()
         raise
 
+    # The master node starts the control-plane rpyc service; slaves do not need it and receive ops via NCCL broadcast.
+    if args.node_rank == 0 and args.control_rpyc_port is not None:
+        from .control_rpyc import start_control_rpyc_server
+
+        start_control_rpyc_server(router, args.control_rpyc_port)
+
     pipe_writer.send("init ok")
     loop.run_until_complete(router.loop_for_fwd())
     return
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 7c19b5748e..bb8ec1eada 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -24,6 +24,7 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo
 from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
+from lightllm.common.basemodel import routing_manager as _routing_mgr
 
 logger = init_logger(__name__)
@@ -122,6 +123,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
         return req_objs
 
+    def _extract_routing_data(self, req: "InferReq"):
+        if req.shm_req.shm_routing_num_tokens > 0:
+            return
+        mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]
+        mgr = _routing_mgr.g_routing_capture_manager
+        routing_data = mgr.extract_routing_data(mem_indexes)
+        req.shm_req.create_routing_data_shm_array(mgr.num_moe_layers, req.cur_kv_len, mgr.topk, np_dtype=mgr.np_dtype)
+        req.shm_req.shm_routing_data.arr[:] =
routing_data + req.shm_req.shm_routing_data.detach_shm() + def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) @@ -848,6 +859,8 @@ def update_finish_status(self, eos_ids, output_len: int): self.finish_status.set_status(FinishStatus.FINISHED_STOP) elif output_len >= self.sampling_param.shm_param.max_new_tokens: self.finish_status.set_status(FinishStatus.FINISHED_LENGTH) + elif self.infer_aborted: + self.finish_status.set_status(FinishStatus.FINISHED_ABORTED) return def _stop_sequences_matched(self, output_len: int): @@ -937,6 +950,8 @@ def handle( shm_req.shm_cur_output_len = self.output_len if finish_status.is_finished(): + if _routing_mgr.g_routing_capture_manager is not None: + g_infer_context._extract_routing_data(req_obj) shm_req.finish_token_index = shm_req.input_len + self.output_len - 1 shm_req.finish_status = req_obj.finish_status diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index ca982ec0f0..1509f0c9cf 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -4,7 +4,7 @@ import time import threading import torch.distributed as dist -from typing import List, Tuple, Callable, Optional +from typing import List, Tuple, Callable, Optional, Union from transformers.configuration_utils import PretrainedConfig from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger @@ -19,7 +19,7 @@ from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify -from lightllm.utils.dist_utils import init_distributed_env +from lightllm.utils.dist_utils import init_distributed_env, init_custom_process_group from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs from lightllm.server.core.objs.io_objs import AbortedReqCmd, StopStrMatchedReqCmd @@ -34,8 +34,11 @@ enable_radix_tree_timer_merge, get_radix_tree_merge_update_delta, ) +from lightllm.utils.serializer import LocalSerializedTensor, MultiprocessingSerializer +from lightllm.utils.patch_torch import monkey_patch_torch_reductions +from lightllm.utils.tensor_bucket import FlattenedTensorBucket, FlattenedTensorMetadata +from lightllm.distributed import dist_group_manager from lightllm.distributed.communication_op import ( - dist_group_manager, all_gather_into_tensor, all_reduce, broadcast, @@ -49,7 +52,17 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.common.basemodel import routing_manager as _routing_mgr +from lightllm.utils.torch_memory_saver_utils import MemoryTag from .multi_level_kv_cache import MultiLevelKvCacheModule +from lightllm.server.io_struct import ( + FlushCacheReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromIPCReq, + UpdateWeightsFromTensorReq, +) class ModeBackend: @@ -122,6 +135,8 @@ def init_model(self, kvargs): ) dist_group_manager.create_groups(group_size=group_size) # 
set the default group + self._model_update_group = {} + self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) # 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在 @@ -366,6 +381,219 @@ def init_mtp_draft_model(self, main_kvargs: dict): self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return + def flush_cache(self, request: FlushCacheReq): + if self.radix_cache is not None: + self.radix_cache.flush_cache() + return True, "Succeeded to flush cache." + + def release_memory_occupation(self, tags: List[MemoryTag]): + try: + self.model.release_memory_occupation(tags) + self.flush_cache(request=None) + return True, "Succeeded to release memory occupation." + except Exception as e: + self.logger.error(f"release memory occupation failed: {str(e)}") + return False, f"release memory occupation failed: {str(e)}" + + def resume_memory_occupation(self, tags: List[MemoryTag]): + try: + self.model.resume_memory_occupation(tags) + return True, "Succeeded to resume memory occupation." + except Exception as e: + self.logger.error(f"resume memory occupation failed: {str(e)}") + return False, f"resume memory occupation failed: {str(e)}" + + def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + assert torch.distributed.is_initialized(), "Default torch process group must be initialized" + + assert request.group_name != "", "Group name cannot be empty" + rank_offset = request.rank_offset + rank = rank_offset + self.rank_in_dp + world_size = request.world_size + group_name = request.group_name + self.logger.info( + f"init custom process group: master_address={request.master_address}, master_port={request.master_port}, " + f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, " + f" backend={request.backend}" + ) + + try: + if group_name in self._model_update_group: + raise ValueError(f"Process group with name {group_name} already exists.") + + self._model_update_group[group_name] = init_custom_process_group( + backend=request.backend, + init_method=f"tcp://{request.master_address}:{request.master_port}", + world_size=world_size, + rank=rank, + group_name=group_name, + ) + return True, "Succeeded to initialize custom process group." + + except Exception as e: + message = f"Failed to initialize custom process group: {e}." + self.logger.error(message) + return False, message + + def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + try: + if request.group_name in self._model_update_group: + pg = self._model_update_group.pop(request.group_name) + torch.distributed.destroy_process_group(pg) + return True, "Succeeded to destroy custom process group." + else: + return False, "The group to be destroyed does not exist." + except Exception as e: + message = f"Failed to destroy custom process group: {e}." + self.logger.error(message) + return False, message + + def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + """ + Update specific parameter in the model weights online + through `_model_update_group` process group. + + Args: + name: the name of the parameter to be updated. + dtype: the data type of the parameter to be updated. + shape: the shape of the parameter to be updated. + """ + + assert request.group_name in self._model_update_group, ( + f"Group {request.group_name} not in {list(self._model_update_group.keys())}. " + "Please call `init_weights_update_group` first." 
+ ) + + try: + weights = [] + handles = [] + for name, dtype, shape in zip(request.names, request.dtypes, request.shapes): + target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) + weight = torch.empty(shape, dtype=target_dtype, device="cuda") + handles.append( + torch.distributed.broadcast( + weight, + src=0, + group=self._model_update_group[request.group_name], + async_op=True, + ) + ) + weights.append((name, weight)) + for handle in handles: + handle.wait() + + self.model.load_weights(weights) + return True, "Succeeded to update parameter online from distributed." + + except Exception as e: + error_msg = ( + f"Failed to update parameter online: {e}. " + f"The full weights of the ModelRunner are partially updated. " + f"Please discard the whole weights." + ) + self.logger.error(error_msg) + return False, error_msg + + def _update_weights_from_flattened_bucket( + self, + flattened_tensor_bucket_dict, + ): + """Handle flattened bucket format for weight updates""" + flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"] + metadata = flattened_tensor_bucket_dict["metadata"] + + # Convert metadata dict to our format + converted_metadata = [] + for meta in metadata: + converted_meta = FlattenedTensorMetadata( + name=meta.name, + shape=meta.shape, + dtype=meta.dtype, + start_idx=meta.start_idx, + end_idx=meta.end_idx, + numel=meta.numel, + ) + converted_metadata.append(converted_meta) + + # Create bucket and reconstruct tensors + bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=converted_metadata) + reconstructed_tensors = bucket.reconstruct_tensors() + + named_tensors = {name: tensor for name, tensor in reconstructed_tensors} + + # Load the reconstructed tensors using the standard method + self.model.load_weights(named_tensors) + + return True, "Succeeded to update parameter online from flattened bucket tensor." + + def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq): + try: + monkey_patch_torch_reductions() + if request.load_format == "flattened_bucket": + # Handle flattened bucket format + serialized_named_tensors = MultiprocessingSerializer.deserialize( + request.serialized_named_tensors[self.rank_in_dp] + ) + return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=serialized_named_tensors) + + # We need to get device after patch otherwise the device would be wrong + self.device_module = torch.get_device_module("cuda") + infered_device = self.device_module.current_device() + + named_tensors = MultiprocessingSerializer.deserialize(request.serialized_named_tensors[self.rank_in_dp]) + + def _unwrap_tensor(tensor, tp_rank, device): + if isinstance(tensor, LocalSerializedTensor): + tensor = tensor.get(tp_rank) + clone = tensor.to(device).clone() + del tensor # free the ipc tensor + return clone + + named_tensors = { + name: _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=infered_device) + for name, tensor in named_tensors + } + + self.model.load_weights(named_tensors) + + return True, "Succeeded to update parameter online from tensor." + + except Exception as e: + message = f"Failed to update parameter online from tensor. Reason: {e}." 
+ self.logger.error(message) + + return False, message + + def update_weights_from_ipc(self, request: UpdateWeightsFromIPCReq): + try: + from .bucketed_weight_transfer import BucketedWeightReceiver, get_zmq_handle + + zmq_handle = get_zmq_handle() + use_shm = request.use_shm + recv_device = torch.device("cuda", self.current_device_id) + self.logger.debug( + "[LightLLM] base_backend.update_weights_from_ipc: request.ipc_handle=%r, " + "resolved zmq_handle=%r, cuda_device_id=%s", + request.ipc_handle, + zmq_handle, + self.current_device_id, + ) + + bucketed_weight_receiver = BucketedWeightReceiver( + zmq_handle=zmq_handle, device=recv_device, use_shm=use_shm + ) + bucketed_weight_receiver.receive_weights(on_bucket_received=self.model.load_weights) + return True, "Succeeded to update parameter online from ipc." + + except Exception as e: + import traceback + + traceback.print_exc() + message = f"Failed to update parameter online from tensor. Reason: {e}." + self.logger.error(message) + + return False, message + def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor): """ 这个函数会把next token id和logprobs保存到pinned memory中 @@ -616,7 +844,7 @@ def _get_classed_reqs( paused_reqs.append(req_obj) continue - if req_obj.infer_aborted or req_obj.finish_status.is_finished(): + if req_obj.finish_status.is_finished(): if support_overlap: # 延迟处理 req_obj.filter_mark = True @@ -826,6 +1054,18 @@ def _sample_and_scatter_token( ) return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu + def _flush_routing_to_kv_buffer(self, mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None: + """Scatter captured routing data from capture buffer to KV-indexed GPU buffer. + + Must be called AFTER model.forward() completes. mem_indexes should be the + original (unpadded) tensor — either CPU or CUDA. + """ + if _routing_mgr.g_routing_capture_manager is not None and mem_indexes is not None: + if not mem_indexes.is_cuda: + mem_indexes = mem_indexes.cuda(non_blocking=True) + num_tokens = mem_indexes.shape[0] + _routing_mgr.g_routing_capture_manager.flush_to_routing_buffer(mem_indexes, num_tokens, microbatch_index) + def _dp_all_gather_prefill_and_decode_req_num( self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq] ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/lightllm/server/router/model_infer/mode_backend/bucketed_weight_transfer.py b/lightllm/server/router/model_infer/mode_backend/bucketed_weight_transfer.py new file mode 100644 index 0000000000..47ccbf5e32 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/bucketed_weight_transfer.py @@ -0,0 +1,339 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Bucketed weight transfer via ZMQ + IPC (or shared memory fallback). 
+copy from https://github.com/verl-project/verl/blob/main/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py +""" + +import gc +import logging +import os +from multiprocessing import shared_memory +from typing import Callable, TypedDict + +import torch +import zmq +from torch.multiprocessing.reductions import reduce_tensor +from lightllm.utils.patch_torch import _device_to_uuid as get_device_uuid + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) + + +is_cuda_available = torch.cuda.is_available() + +def get_device_name() -> str: + """Function that gets the torch.device based on the current machine. + This currently only supports CPU, CUDA, NPU. + Returns: + device + """ + if is_cuda_available: + device = "cuda" + else: + device = "cpu" + return device + + +def get_torch_device() -> any: + """Return the corresponding torch attribute based on the device type string. + Returns: + module: The corresponding torch device namespace, or torch.cuda if not found. + """ + device_name = get_device_name() + try: + return getattr(torch, device_name) + except AttributeError: + logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") + return torch.cuda + +def get_device_id() -> int: + """Return current device id based on the device type. + Returns: + device index + """ + return get_torch_device().current_device() + + +def get_zmq_handle() -> str: + return f"ipc:///tmp/rl-colocate-zmq-{get_device_uuid(get_device_id())}.sock" + + + +class TensorMetadata(TypedDict): + name: str + shape: torch.Size + dtype: torch.dtype + offset: int + + +# copy from https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/rlhf_utils.py +def rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor: + func, args = handle + list_args = list(args) + if device_id is not None: + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + buffer = func(*list_args) + return buffer + + +def create_shared_memory(size: int, name: str): + """Create shared memory for weight transfer. If already exists, attach to it.""" + try: + shm = shared_memory.SharedMemory(name=name, create=True, size=size) + except FileExistsError: + shm = shared_memory.SharedMemory(name=name) + assert shm.size >= size, f"Stale shm segment '{name}': expected {size} bytes, got {shm.size}" + return shm + + +def rebuild_shared_memory(name: str, size: int, dtype=torch.uint8): + """Rebuild tensor from shared memory.""" + shm = shared_memory.SharedMemory(name=name) + tensor = torch.frombuffer(shm.buf[:size], dtype=dtype) + + return tensor, shm + + +class BucketedWeightSender: + """ + Send model weights via bucketed IPC transfer over ZMQ. + + Packs weight tensors into a fixed-size communication buffer and sends them + in buckets to the receiver. Supports CUDA IPC and shared memory fallback. 
+ + Args: + zmq_handle: ZMQ IPC socket path (e.g., "ipc:///tmp/rl-colocate-zmq-.sock") + bucket_size_mb: Communication buffer size in MB + use_shm: Use shared memory instead of CUDA IPC (for NPU compatibility) + """ + + def __init__( + self, + zmq_handle: str, + bucket_size_mb: int = 512, + use_shm: bool = False, + ): + self.zmq_handle = zmq_handle + self.bucket_size_mb = bucket_size_mb + self.bucket_size = int(bucket_size_mb) << 20 + self.use_shm = use_shm + + self.zmq_context = zmq.Context.instance() + self.socket = None + self.buffer = None + self.shm = None + + async def async_send_weights(self, weights): + """ + Send weights to the receiver. Accepts a sync generator or async iterator. + + Args: + weights: Generator or async iterator yielding (name, tensor) pairs + """ + from verl.workers.rollout.utils import ensure_async_iterator + + try: + self._init_socket() + self._init_buffer() + + # send bucket weights + offset = 0 + bucket_meta: dict[str, TensorMetadata] = {} + # dtype = PrecisionType.to_dtype(self.config.dtype) + async for name, weight in ensure_async_iterator(weights): + # model parameters are in fp32 full precision + # (vermouth1992) we should not force cast weight here because some parameters + # (such as moe gate) have to keep fp32 precision. If a weight is bf16 in the rollout side, + # the rollout should automatically cast on demand. However, this would incur a higher weight + # transfer volume. + # weight = weight.to(dtype, non_blocking=True) + + # fill the tensor bucket + if offset + weight.nbytes > self.bucket_size: + torch.cuda.synchronize() + self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": False}) + self.socket.recv() + bucket_meta = {} + offset = 0 + + # TODO: slice embedding layer weight into chunks + assert offset + weight.nbytes <= self.bucket_size, ( + f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." + f"Please increase rollout.update_weights_bucket_megabytes({self.bucket_size_mb} MB)." 
+ ) + bucket_meta[name] = { + "name": name, + "shape": weight.shape, + "dtype": weight.dtype, + "offset": offset, + } + self.buffer[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) + offset += weight.nbytes + + # send the last bucket + torch.cuda.synchronize() + self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": True}) + self.socket.recv() + finally: + self._cleanup() + + def _init_socket(self): + """Initialize ZMQ REQ socket and bind.""" + self.socket = self.zmq_context.socket(zmq.REQ) + self.socket.bind(self.zmq_handle) + + def _init_buffer(self): + """build communication buffer""" + buffer, shm = None, None + if not self.use_shm: + buffer = torch.empty(self.bucket_size, dtype=torch.uint8, device=f"{get_device_name()}:{get_device_id()}") + handle = reduce_tensor(buffer) + self.socket.send_pyobj(handle) + else: + import uuid + + # Create unique name for shared memory + shm_name = f"verl_weights_{uuid.uuid4().hex}" + shm = create_shared_memory(self.bucket_size, shm_name) + buffer = torch.frombuffer(shm.buf, dtype=torch.uint8) + + comm_metadata = {"name": shm_name, "size": self.bucket_size} + self.socket.send_pyobj(comm_metadata) + + self.socket.recv() + self.buffer = buffer + self.shm = shm + + def _cleanup(self): + """clean up""" + if self.socket is not None: + self.socket.close() + self.socket = None + del self.buffer + self.buffer = None + if self.shm is not None: + self.shm.close() + self.shm.unlink() + del self.shm + self.shm = None + gc.collect() + torch.cuda.ipc_collect() + torch.cuda.empty_cache() + + +class BucketedWeightReceiver: + """ + Receive model weights via bucketed IPC transfer over ZMQ. + + Receives weight tensors from BucketedWeightSender and passes each + bucket to a callback for processing (e.g., loading into the model). + + Args: + zmq_handle: ZMQ IPC socket path (must match sender) + device: Target device for received tensors + use_shm: Use shared memory instead of CUDA IPC + """ + + def __init__( + self, + zmq_handle: str, + device: torch.device, + use_shm: bool = False, + ): + self.zmq_handle = zmq_handle + self.device = device + self.use_shm = use_shm + + self.zmq_context = zmq.Context.instance() + self.socket = None + self.buffer = None + self.shm = None + + def receive_weights(self, on_bucket_received: callable): + """ + Receive weights from sender and process each bucket via callback. + + Args: + on_bucket_received: Callback function(weight_dict: dict[str, torch.Tensor]) called per bucket. + """ + try: + self._init_socket() + self._init_buffer() + + # receive bucket and update weights + while True: + metadata = self.socket.recv_pyobj() + weights, tensor = [], None + for name, meta in metadata["bucket_meta"].items(): + shape, dtype, offset = meta["shape"], meta["dtype"], meta["offset"] + size = dtype.itemsize * shape.numel() + # NOTE: we need to clone the tensor to release CUDA IPC memory + # but for shared memory, it's not necessary and if we do clone, + # it will cause extra memory copy overhead and slow down the process. 
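+                    # For example (illustrative numbers): a bf16 weight of shape
+                    # [4096, 4096] gives size = 2 * 4096 * 4096 bytes = 32 MiB, read
+                    # back from self.buffer[offset : offset + size] just below.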
+ tensor = self.buffer[offset : offset + size].view(dtype=dtype).view(shape) + if not self.use_shm: + tensor = tensor.clone() + else: + tensor = tensor.to(self.device) + weights.append((name, tensor)) + torch.cuda.synchronize() + self.socket.send(b"") + on_bucket_received(dict(weights)) + del weights, tensor + if metadata["is_last"]: + break + finally: + self._cleanup() + + def _init_socket(self): + """Initialize ZMQ REP socket and connect.""" + self.socket = self.zmq_context.socket(zmq.REP) + self.socket.connect(self.zmq_handle) + + def _init_buffer(self): + """Receive and rebuild communication buffer from sender.""" + comm_metadata = self.socket.recv_pyobj() + buffer, shm = None, None + if not self.use_shm: + handle = comm_metadata + buffer = rebuild_ipc(handle, self.device.index) + assert buffer.dtype == torch.uint8 + else: + shm_name = comm_metadata["name"] + shm_size = comm_metadata["size"] + buffer, shm = rebuild_shared_memory(shm_name, shm_size, dtype=torch.uint8) + self.socket.send(b"") + self.buffer = buffer + self.shm = shm + + def _cleanup(self): + """clean up""" + if self.socket is not None: + self.socket.close() + self.socket = None + # Synchronize before releasing the buffer to ensure all async ops + # referencing it (e.g. clone, .to()) have completed. + torch.cuda.synchronize() + del self.buffer + self.buffer = None + if self.shm is not None: + self.shm.close() + del self.shm + self.shm = None + gc.collect() + torch.cuda.ipc_collect() + torch.cuda.empty_cache() diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 60045fab6c..e068c00c76 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -109,6 +109,7 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -140,6 +141,7 @@ def prefill_normal( extra_post_req_handle_func=self.extra_post_req_handle_func, nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) + # 第四阶段 event_pack.notify_pre_post_handle() return @@ -152,6 +154,7 @@ def decode_normal( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -190,6 +193,7 @@ def prefill_mtp( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -244,6 +248,7 @@ def decode_mtp( with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = 
model_input.b_mtp_index model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] @@ -266,6 +271,7 @@ def decode_mtp( key="mtp_accept_len", gpu_tensor=mtp_accept_len, ) + verify_event = torch.cuda.Event() verify_event.record() diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 5a179cb620..ebc55b7ef4 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -40,8 +40,8 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq ) with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) logits = model_output.logits batch_idx, run_reqs = self._diverse_copy( diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index c83e8cd4a5..c9484dba6f 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -151,6 +151,7 @@ def prefill_normal( run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -198,6 +199,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -246,6 +248,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -319,6 +323,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -373,6 +379,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] req_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output: ModelOutput = 
self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) b_has_out_cpu = model_input.b_prefill_has_output_cpu[0:req_num] logits = model_output.logits[0:req_num, :] b_req_idx = model_input.b_req_idx[0:req_num] @@ -438,6 +445,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None if req_num > 0: logits = model_output.logits[0:req_num, :] @@ -646,6 +654,8 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I ) = padded_overlap_prepare_prefill_inputs(prefill_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits req_num0, req_num1 = len(run_reqs0), len(run_reqs1) @@ -747,8 +757,9 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf b_mtp_index_cpu0 = model_input0.b_mtp_index b_mtp_index_cpu1 = model_input1.b_mtp_index with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits run_reqs = run_reqs0 + run_reqs1 @@ -788,6 +799,20 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf ) all_next_token_ids.append(next_token_ids) + # Copy accepted buffer states back to buffer[0] for MTP + # Only copy when accept_len > 1 + mask = mtp_accept_len > 1 + if mask.sum() > 0: + actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]] + src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ + actual_req_idxes, mtp_accept_len[mask] - 1 + ] + dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): + g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + src_buffer_indexes, dst_buffer_indexes + ) + verify_event = torch.cuda.Event() verify_event.record() diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 408b173371..2060753ae6 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -37,6 +37,8 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.torch_memory_saver_utils import MemoryTag +from lightllm.server.io_struct import GeneralHttpToModelRpcReq, GeneralModelToHttpRpcRsp logger = init_logger(__name__) @@ -128,6 +130,35 @@ def exposed_init_model(self, kvargs): def exposed_get_max_total_token_num(self): return self.backend.get_max_total_token_num() + def release_memory_occupation(self, tags: List[MemoryTag]): + try: + 
self.backend.release_memory_occupation(tags) + return True + except BaseException as e: + logger.exception(f"release memory occupation failed: {str(e)}") + return False + + def resume_memory_occupation(self, tags: List[MemoryTag]): + try: + self.backend.resume_memory_occupation(tags) + return True + except BaseException as e: + logger.exception(f"resume memory occupation failed: {str(e)}") + return False + + def exposed_forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: + try: + req = obtain(req) + if self.backend is None or not hasattr(self.backend, req.func_name): + raise ValueError(f"Backend does not support function {req.func_name}") + success, ret = getattr(self.backend, req.func_name)(req.func_args) + return GeneralModelToHttpRpcRsp(success=success, msg=str(ret), func_name=req.func_name, func_rsp=ret) + except BaseException as e: + logger.exception(f"forward to model backend failed: {str(e)}") + return GeneralModelToHttpRpcRsp( + success=False, msg=f"forward to model backend failed: {str(e)}", func_name=req.func_name + ) + class ModelRpcClient: def __init__(self, conn): @@ -151,6 +182,7 @@ async def _func(*args, **kwargs): self._init_model = async_wrap(self.conn.root.init_model) self._get_max_total_token_num = async_wrap(self.conn.root.get_max_total_token_num) + self._forward_to_model = async_wrap(self.conn.root.forward_to_model) return async def init_model(self, kvargs): @@ -162,6 +194,10 @@ async def get_max_total_token_num(self): ans = self._get_max_total_token_num() return obtain(await ans) + async def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: + ans = self._forward_to_model(req) + return obtain(await ans) + def _init_env( args, @@ -222,7 +258,11 @@ async def start_model_process( success_event, ), ) - proc.start() + from lightllm.utils.torch_memory_saver_utils import TorchMemorySaverWrapper + + torch_memory_saver = TorchMemorySaverWrapper(args.enable_torch_memory_saver) + with torch_memory_saver.configure_subprocess(): + proc.start() # Use asyncio.to_thread to make the blocking wait non-blocking await asyncio.to_thread(success_event.wait, timeout=40) diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py index 73113a59b8..e1e2479c86 100644 --- a/lightllm/server/router/req_queue/base_queue.py +++ b/lightllm/server/router/req_queue/base_queue.py @@ -33,6 +33,17 @@ def free_aborted_req_cpu_cache_pages(self, req: Req): req.cpu_cache_match_page_indexes.clear() self.router.cpu_cache_client.lock.release() + def free_aborted_req(self, req: Req): + # 为了让http server 能正常返回请求,还没有开始推理的请求,直接设置结束,返回空字符串 + input_len = req.input_len + req.link_prompt_ids_shm_array() + req.link_logprobs_shm_array() + req.candetoken_out_len = 1 + req.finish_token_index = input_len + req.shm_prompt_ids.arr[input_len] = self.args.eos_id[0] + req.shm_logprobs.arr[input_len] = 0 + req.finish_status.set_status(FinishStatus.FINISHED_ABORTED) + def extend(self, req_group: List[Req]): for req in req_group: req.sample_params.suggested_dp_index = self.dp_index diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 884b5930b0..24f017d95c 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -84,8 +84,8 @@ def generate_new_batch(self, current_batch: Batch): waiting_queue = self.waiting_req_list for req in waiting_queue: - if req.is_aborted: - # 
由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. + if req.is_aborted and not self.router.is_multinode_tp: + # 由于管理的复杂性,只有没有被调度运行过的单节点请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 aborted_count += 1 abort_req_list.append(req) @@ -104,6 +104,7 @@ def generate_new_batch(self, current_batch: Batch): req: Req = req logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py index 3b831c92a6..0b17bbd1c0 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py @@ -75,7 +75,7 @@ def generate_new_batch(self, current_batch: Batch): waiting_queue = self.waiting_req_list for req in waiting_queue: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 aborted_count += 1 @@ -95,6 +95,7 @@ def generate_new_batch(self, current_batch: Batch): req: Req = req logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py index 4c2ebf7c00..1bfb8fc59e 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py @@ -41,7 +41,7 @@ def generate_new_batch(self, current_batch: Batch): abort_req_list = [] aborted_count = 0 for req in self.waiting_req_list: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. 
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token和管理req对象的泄漏 aborted_count += 1 @@ -58,6 +58,7 @@ def generate_new_batch(self, current_batch: Batch): req: Req = req logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index a73823b8b7..e5f731df5f 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -27,6 +27,12 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None: self.reqs_waiting_for_dp_index: List[List[Req]] = [] return + def free_aborted_req(self, req: Req): + dp_index = req.sample_params.suggested_dp_index + assert dp_index >= 0 and dp_index < self.dp_size_in_node + self.inner_queues[dp_index].free_aborted_req(req) + return + def get_dp_queue(self, dp_index: int): assert dp_index < self.dp_size_in_node, "dp index out of range" return self.inner_queues[dp_index] diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index ae3c4204db..de0e31bafd 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -4,12 +4,14 @@ import uuid import os import multiprocessing +import setproctitle from lightllm.utils.retry_utils import retry from rpyc.utils.factory import unix_connect from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer from lightllm.utils.graceful_utils import graceful_registry -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name from .model_rpc_client import VisualModelRpcClient from .model_rpc import VisualModelRpcServer from ..objs import rpyc_config @@ -18,6 +20,8 @@ def _init_env(socket_path: str, success_event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_model_infer") + start_parent_check_thread() import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 5b9705ed0e..87aae86e4f 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -80,12 +80,15 @@ def init_vision_distributed_env(kvargs): device_id = kvargs["device_id"] set_current_device_id(device_id) torch.cuda.set_device(device_id) + # 不要在init_process_group时,显示的传入device_id + # 这会触发torch的device-bound split优化,会默认后面想加入新进程组的rank + # 都已经存在于默认组,这样RL更新weight的init_group时,外部想加入的组,在执行 + # 通信原语时例如all_reduce,会永远等不到LightLLM默认组里的回复,从而导致错误结果。 dist.init_process_group( "nccl", init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}', rank=kvargs["tp_rank_id"], world_size=tp_world_size, - device_id=torch.device(f"cuda:{device_id}"), ) # warmup nccl communicator _a = torch.zeros([1]).to(f"cuda:{device_id}") @@ -150,7 +153,6 @@ def init_distributed_env(kvargs): init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}', rank=kvargs["rank_id"], world_size=kvargs["world_size"], - device_id=torch.device(f"cuda:{device_id}"), ) # warmup nccl communicator _a = 
torch.zeros([1]).to(f"cuda:{device_id}") @@ -316,3 +318,71 @@ def _init_nccl_env(): assert response.status_code == 200, f"Failed to init config server nccl tcp store: {response.status_code}" return + + +# copy from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py#L1675 +def init_custom_process_group( + backend=None, + init_method=None, + timeout=None, + world_size=-1, + rank=-1, + store=None, + group_name=None, + pg_options=None, + device_id=None, +): + from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, + ) + + assert (store is None) or (init_method is None), "Cannot specify both init_method and store." + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + if backend: + backend = Backend(backend) + else: + backend = Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore(group_name, store) + + # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 + # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 + # We need to determine the appropriate parameter name based on PyTorch version + pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + device_id=device_id, + ) + + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index b87096d945..c3a466191d 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -1,45 +1,92 @@ import socket import subprocess import ipaddress -import random +import os from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) +DEFAULT_BASE_PORT = 10000 +PORTS_PER_INSTANCE = 1000 +MAX_INSTANCE_ID = 7 -def alloc_can_use_network_port(num=3, used_ports=None, from_port_num=10000): - port_list = [] - for port in range(from_port_num, 65536): + +def _is_port_available(port: int) -> bool: + """Check if a port is available by attempting to bind it.""" + try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - result = s.connect_ex(("localhost", port)) - if result != 0 and port not in used_ports: - port_list.append(port) - if len(port_list) > num * 30: - break + s.bind(("", port)) + return True + except OSError: + return False - if len(port_list) < num: - return None - random.shuffle(port_list) - return port_list[0:num] +def alloc_can_use_network_port(num=3, instance_id=0, used_ports=None): + """ + Allocate available network ports within an instance-specific range. + + Each instance gets a dedicated 1000-port range starting from BASE_PORT + (default 10000, override via LIGHTLLM_BASE_PORT env var). + Instance 0: 10000-10999, Instance 1: 11000-11999, etc. 
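+
+    For example (hypothetical values): with LIGHTLLM_BASE_PORT=20000 and
+    instance_id=2, ports are drawn from [22000, 23000), so
+
+        ports = alloc_can_use_network_port(num=3, instance_id=2)
+        # -> e.g. [22000, 22001, 22002], depending on what is already bound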
+ """ + if instance_id < 0 or instance_id > MAX_INSTANCE_ID: + raise ValueError(f"instance_id must be in range [0, {MAX_INSTANCE_ID}], got {instance_id}") + + base_port = int(os.environ.get("LIGHTLLM_BASE_PORT", DEFAULT_BASE_PORT)) + range_start = base_port + instance_id * PORTS_PER_INSTANCE + range_end = range_start + PORTS_PER_INSTANCE + used_set = set(used_ports) if used_ports else set() + + port_list = [] + for port in range(range_start, range_end): + if len(port_list) >= num: + break + if port in used_set: + continue + if _is_port_available(port): + port_list.append(port) + used_set.add(port) + + if len(port_list) >= num: + logger.info( + f"Instance {instance_id}: allocated {len(port_list)} ports in [{range_start}, {range_end}): {port_list}" + ) + return port_list + + raise RuntimeError( + f"Failed to allocate {num} ports for instance {instance_id} in range [{range_start}, {range_end}). " + f"Only found {len(port_list)} available. Try a different instance_id or set LIGHTLLM_BASE_PORT." + ) def alloc_can_use_port(min_port, max_port): port_list = [] for port in range(min_port, max_port): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - result = s.connect_ex(("localhost", port)) + try: + test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result = test_socket.connect_ex(("localhost", port)) + test_socket.close() + if result != 0: port_list.append(port) + except Exception: + continue return port_list def find_available_port(start_port, end_port): for port in range(start_port, end_port + 1): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - result = sock.connect_ex(("localhost", port)) + try: + test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result = test_socket.connect_ex(("localhost", port)) + test_socket.close() + if result != 0: return port + except Exception: + continue return None diff --git a/lightllm/utils/patch_torch.py b/lightllm/utils/patch_torch.py new file mode 100644 index 0000000000..9f51edeb64 --- /dev/null +++ b/lightllm/utils/patch_torch.py @@ -0,0 +1,63 @@ +# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/patch_torch.py +from typing import Callable, Union + +import torch +from packaging import version +from torch.multiprocessing import reductions + + +def monkey_patch_torch_reductions(): + """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed""" + + # Currently, NPU does not support UUID. This has been temporarily commented out, + # with support expected in the fourth quarter. + # if _is_npu: + # return + + if hasattr(reductions, "_reduce_tensor_original"): + return + + reductions._reduce_tensor_original = reductions.reduce_tensor + reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor + + reductions.reduce_tensor = _reduce_tensor_modified + reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified + + reductions.init_reductions() + + +# The signature has not been changed for years, and we will not need this when the next version is released, +# so it looks safe to use a constant. 
+_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 + + +def _reduce_tensor_modified(*args, **kwargs): + output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) + output_args = _modify_tuple(output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid) + return output_fn, output_args + + +def _rebuild_cuda_tensor_modified(*args): + args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid) + return reductions._rebuild_cuda_tensor_original(*args) + + +def _device_to_uuid(device: int) -> str: + return str(torch.cuda.get_device_properties(device).uuid) + + +def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int: + if isinstance(device_maybe_uuid, int): + return device_maybe_uuid + + if isinstance(device_maybe_uuid, str): + for device in range(torch.cuda.device_count()): + if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid: + return device + raise Exception("Invalid device_uuid=" + device_maybe_uuid) + + raise Exception(f"Unknown type: {device_maybe_uuid=}") + + +def _modify_tuple(t, index: int, modifier: Callable): + return *t[:index], modifier(t[index]), *t[index + 1 :] diff --git a/lightllm/utils/serializer.py b/lightllm/utils/serializer.py new file mode 100644 index 0000000000..d8180aeb0c --- /dev/null +++ b/lightllm/utils/serializer.py @@ -0,0 +1,131 @@ +# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py + +import base64 +import pickle +import io +from dataclasses import dataclass +from multiprocessing.reduction import ForkingPickler +from typing import List + + +class MultiprocessingSerializer: + @staticmethod + def serialize(obj, output_str: bool = False): + """ + Serialize a Python object using ForkingPickler. + + Args: + obj: The object to serialize. + output_str (bool): If True, return a base64-encoded string instead of raw bytes. + + Returns: + bytes or str: The serialized object. + """ + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + output = buf.read() + + if output_str: + # Convert bytes to base64-encoded string + output = base64.b64encode(output).decode("utf-8") + + return output + + @staticmethod + def deserialize(data): + """ + Deserialize a previously serialized object. + + Args: + data (bytes or str): The serialized data, optionally base64-encoded. + + Returns: + The deserialized Python object. 
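+
+        Round-trip sketch (the tensor is illustrative; any object built from the
+        allowlisted modules below will load):
+
+            payload = MultiprocessingSerializer.serialize(torch.ones(2), output_str=True)
+            restored = MultiprocessingSerializer.deserialize(payload)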
+ """ + if isinstance(data, str): + # Decode base64 string to bytes + data = base64.b64decode(data, validate=True) + + return SafeUnpickler(io.BytesIO(data)).load() + + +class SafeUnpickler(pickle.Unpickler): + ALLOWED_MODULE_PREFIXES = { + # --- Python types --- + "builtins.", + "collections.", + "copyreg.", + "functools.", + "itertools.", + "operator.", + "types.", + "weakref.", + # --- PyTorch types --- + "torch.", + "torch._tensor.", + "torch.storage.", + "torch.nn.parameter.", + "torch.autograd.function.", + # --- torch distributed --- + "torch.distributed.", + "torch.distributed._shard.", + "torch.distributed._composable.", + "torch._C._distributed_c10d.", + "torch._C._distributed_fsdp.", + "torch.distributed.optim.", + # --- multiprocessing --- + "multiprocessing.resource_sharer.", + "multiprocessing.reduction.", + "pickletools.", + # --- PEFT / LoRA --- + "peft.", + "transformers.", + "huggingface_hub.", + # --- SGLang & Unitest --- + "sglang.srt.weight_sync.tensor_bucket.", + "sglang.srt.model_executor.model_runner.", + "sglang.srt.layers.", + "sglang.srt.utils.", + # --- LightLLM --- + "lightllm.utils.", + } + + DENY_CLASSES = { + ("builtins", "eval"), + ("builtins", "exec"), + ("builtins", "compile"), + ("os", "system"), + ("subprocess", "Popen"), + ("subprocess", "run"), + ("codecs", "decode"), + ("types", "CodeType"), + ("types", "FunctionType"), + } + + def find_class(self, module, name): + # Block deterministic attacks + if (module, name) in self.DENY_CLASSES: + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164" + ) + # Allowlist of safe-to-load modules. + if any((module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES): + return super().find_class(module, name) + + # Block everything else. (Potential attack surface) + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164" + ) + + +@dataclass +class LocalSerializedTensor: + """torch.Tensor that gets serialized by MultiprocessingSerializer + (which only serializes a pointer and not the data). + The i-th element in the list corresponds to i-th rank's GPU.""" + + values: List[bytes] + + def get(self, rank: int): + return MultiprocessingSerializer.deserialize(self.values[rank]) diff --git a/lightllm/utils/tensor_bucket.py b/lightllm/utils/tensor_bucket.py new file mode 100644 index 0000000000..a9d7a367dd --- /dev/null +++ b/lightllm/utils/tensor_bucket.py @@ -0,0 +1,104 @@ +# copy from +# https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/ +# srt/weight_sync/tensor_bucket.py +from dataclasses import dataclass +from typing import List, Tuple + +import torch + + +@dataclass +class FlattenedTensorMetadata: + """Metadata for a tensor in a flattened bucket""" + + name: str + shape: torch.Size + dtype: torch.dtype + start_idx: int + end_idx: int + numel: int + + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single tensor for efficient processing + while preserving all metadata needed for reconstruction. + """ + + # This field is solely for users of to check whether the class supports this feature + supports_multi_dtypes = True + + def __init__( + self, + named_tensors: List[Tuple[str, torch.Tensor]] = None, + flattened_tensor: torch.Tensor = None, + metadata: List[FlattenedTensorMetadata] = None, + ): + """ + Initialize a tensor bucket from a list of named tensors OR from pre-flattened data. 
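+
+        Round-trip sketch (the name and shape are illustrative):
+
+            bucket = FlattenedTensorBucket(named_tensors=[("w", torch.ones(2, 2))])
+            rebuilt = FlattenedTensorBucket(
+                flattened_tensor=bucket.get_flattened_tensor(), metadata=bucket.get_metadata()
+            ).reconstruct_tensors()  # -> [("w", <tensor of shape (2, 2)>)]
+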
+ Args: + named_tensors: List of (name, tensor) tuples (for creating new bucket) + flattened_tensor: Pre-flattened tensor (for reconstruction) + metadata: Pre-computed metadata (for reconstruction) + """ + if named_tensors is not None: + # Create bucket from named tensors + self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors) + self.flattened_tensor: torch.Tensor = None + + if not named_tensors: + raise ValueError("Cannot create empty tensor bucket") + + # Collect metadata and flatten tensors + current_idx = 0 + flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors) + + for i, (name, tensor) in enumerate(named_tensors): + flattened = tensor.flatten().view(torch.uint8) + flattened_tensors[i] = flattened + + # Store metadata + + numel = flattened.numel() + metadata_obj = FlattenedTensorMetadata( + name=name, + shape=tensor.shape, + dtype=tensor.dtype, + start_idx=current_idx, + end_idx=current_idx + numel, + numel=numel, + ) + self.metadata[i] = metadata_obj + current_idx += numel + + # Concatenate all flattened tensors + self.flattened_tensor = torch.cat(flattened_tensors, dim=0) + else: + # Initialize from pre-flattened data + if flattened_tensor is None or metadata is None: + raise ValueError("Must provide either named_tensors or both flattened_tensor and metadata") + self.flattened_tensor = flattened_tensor + self.metadata = metadata + + def get_flattened_tensor(self) -> torch.Tensor: + """Get the flattened tensor containing all bucket tensors""" + return self.flattened_tensor + + def get_metadata(self) -> List[FlattenedTensorMetadata]: + """Get metadata for all tensors in the bucket""" + return self.metadata + + def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: + """ + Reconstruct original tensors from flattened tensor with optimized performance. + Uses memory-efficient operations to minimize allocations and copies. + """ + # preallocate the result list + reconstructed = [None] * len(self.metadata) + + for i, meta in enumerate(self.metadata): + tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].view(meta.dtype).reshape(meta.shape) + + reconstructed[i] = (meta.name, tensor) + + return reconstructed diff --git a/lightllm/utils/torch_memory_saver_utils.py b/lightllm/utils/torch_memory_saver_utils.py new file mode 100644 index 0000000000..c1184ef30c --- /dev/null +++ b/lightllm/utils/torch_memory_saver_utils.py @@ -0,0 +1,92 @@ +import torch +from contextlib import contextmanager +from enum import Enum +from lightllm.utils.log_utils import init_logger + +try: + from torch_memory_saver import ( + torch_memory_saver, + configure_subprocess, + ) + + HAS_TORCH_MEMORY_SAVER = True + +except ImportError: + HAS_TORCH_MEMORY_SAVER = False + pass + +logger = init_logger(__name__) + + +class MemoryTag(Enum): + KV_CACHE = "kv_cache" + WEIGHT = "weights" + GRAPH = "graph" + + def is_kv_cache(self): + return self == MemoryTag.KV_CACHE + + def is_weight(self): + return self == MemoryTag.WEIGHT + + def is_graph(self): + return self == MemoryTag.GRAPH + + def __str__(self): + return self.value + + +class TorchMemorySaverWrapper: + def __new__(cls, enable_torch_memory_saver: bool = False): + if enable_torch_memory_saver: + assert ( + HAS_TORCH_MEMORY_SAVER + ), "torch_memory_saver is not installed, please install it via `pip install torch_memory_saver`." 
+ return _TorchMemorySaver() + else: + return _TorchMemorySaverFake() + + +class _TorchMemorySaver: + def configure_subprocess(self): + return configure_subprocess() + + def region(self, tag: MemoryTag, enable_cpu_backup: bool = False): + return torch_memory_saver.region(tag=tag.value, enable_cpu_backup=enable_cpu_backup) + + def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs): + return torch_memory_saver.cuda_graph(cuda_graph=graph_obj, **kwargs, tag=MemoryTag.GRAPH.value) + + def disable(self): + return torch_memory_saver.disable() + + def pause(self, tag: MemoryTag): + return torch_memory_saver.pause(tag=tag.value) + + def resume(self, tag: MemoryTag): + return torch_memory_saver.resume(tag=tag.value) + + +class _TorchMemorySaverFake: + @contextmanager + def configure_subprocess(self): + yield + + @contextmanager + def region(self, tag: MemoryTag, enable_cpu_backup: bool = False): + yield + + def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs): + return torch.cuda.graph(graph_obj, **kwargs) + + @contextmanager + def disable(self): + yield + + def pause(self, tag: MemoryTag): + logger.warning("torch_memory_saver is not enabled, pause is not supported.") + return + + def resume(self, tag: MemoryTag): + logger.warning("torch_memory_saver is not enabled, resume is not supported.") + return diff --git a/requirements.txt b/requirements.txt index d37ae05690..2ede1b24c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -94,6 +94,7 @@ partial_json_parser==0.2.1.1.post6 websockets==15.0.1 cupy-cuda12x==13.6.0 nixl==0.8.0 +torch_memory_saver==0.0.9 xformers==0.0.33.post2 redis==7.3.0 litellm>=1.52.0,<1.85 diff --git a/test/test_api/test_abort_request.py b/test/test_api/test_abort_request.py new file mode 100644 index 0000000000..ca99f6298c --- /dev/null +++ b/test/test_api/test_abort_request.py @@ -0,0 +1,432 @@ +""" +Test the /abort_request endpoint against a running lightllm server. + +What this test asserts (and why it does not assert "stream becomes finish_reason='abort'"): + + In normal / chunked_prefill mode, /abort_request: + - sets shm_req.is_aborted = True + - drives the router to send AbortedReqCmd, which sets InferReq.infer_aborted = True + on the worker + - causes still-waiting (not yet scheduled) reqs to be freed with FINISHED_ABORTED + - but does NOT cause already-running reqs to early-exit; they finish at max_new_tokens + / EOS / stop sequence as usual. (The shm flag is consumed by audio/visual servers + and pd_nixl mode, but the LLM inference loop never short-circuits on it.) 
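+
+    For reference, the two request bodies this test POSTs to /abort_request (the id
+    value below is illustrative):
+
+        {"abort_all": false, "request_id": 123456}   # abort one group_request_id
+        {"abort_all": true}                          # abort everything queued/running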
+ +So the test verifies the contract that actually exists today: + + Stage A: bogus request_id -> HTTP 200, server log "not exist" warning + Stage B: abort_all on an idle server -> HTTP 200, no errors + Stage C: abort_all on a running stream + -> HTTP 200; server log shows "aborted group_request_id N" warning + -> the stream terminates within reasonable time (whether via abort or + natural max_new_tokens completion) + Stage D: abort by SPECIFIC request_id on a running stream + -> resolve the lightllm_req_id from the server log (via X-Request-Id), + POST /abort_request with that exact id, verify the targeted log + warning lands and the stream terminates + Stage E: server remains healthy and answers a fresh /generate + +Usage: + python test/test_api/test_abort_request.py \ + --url http://127.0.0.1:8000 \ + --server_log_path /tmp/lightllm_test/server.log +""" + +import argparse +import json +import os +import re +import sys +import threading +import time +import uuid +from typing import List, Optional, Tuple + +import requests + + +GREEN = "\033[32m" +RED = "\033[31m" +YELLOW = "\033[33m" +RESET = "\033[0m" + + +def banner(msg: str): + print(f"\n{YELLOW}=== {msg} ==={RESET}", flush=True) + + +def ok(msg: str): + print(f" {GREEN}OK{RESET} {msg}", flush=True) + + +def fail(msg: str): + print(f" {RED}FAIL{RESET} {msg}", flush=True) + + +# ---------------- HTTP helpers ---------------- + + +def _get_health(url: str, timeout=5): + return requests.get(url + "/health", timeout=timeout) + + +def post_abort(url: str, request_id: Optional[int] = None, abort_all: bool = False) -> Tuple[int, str]: + payload = {"abort_all": abort_all} + if request_id is not None: + payload["request_id"] = request_id + r = requests.post(url + "/abort_request", json=payload, timeout=30) + return r.status_code, r.text + + +# ---------------- streaming helpers ---------------- + + +def _stream_run( + url: str, + prompt: str, + max_new_tokens: int, + x_request_id: str, + out: dict, + close_after_n: Optional[int] = None, +): + """ + Issue a /generate_stream and append every event to out["events"]. + If close_after_n is set, the underlying socket is forcibly closed + (TCP RST via SO_LINGER + close) after that many events arrive — kept + here for completeness even though no current stage uses it. Sets + out["error"] on transport errors. + """ + headers = {"X-Request-Id": x_request_id, "Content-Type": "application/json"} + body = { + "inputs": prompt, + "parameters": { + "max_new_tokens": max_new_tokens, + "do_sample": False, + "ignore_eos": True, + }, + } + out["events"] = [] + out["start"] = time.time() + out["error"] = None + out["closed_intentionally"] = False + try: + # urllib3 keeps the socket pooled; we need direct access to force-close. + with requests.post(url + "/generate_stream", json=body, headers=headers, stream=True, timeout=120) as r: + r.raise_for_status() + for raw in r.iter_lines(decode_unicode=True): + if not raw: + continue + if raw.startswith("data:"): + raw = raw[len("data:") :] + try: + ev = json.loads(raw) + except Exception: + continue + ev["_t"] = time.time() - out["start"] + out["events"].append(ev) + if close_after_n is not None and len(out["events"]) >= close_after_n: + out["closed_intentionally"] = True + # Reach into urllib3 to force a TCP RST so the server sees + # the disconnect immediately rather than after a graceful + # FIN that hypercorn might not propagate while the response + # is mid-stream. 
+ try: + import socket as _socket + + sock = r.raw._fp.fp.raw._sock # type: ignore[attr-defined] + # SO_LINGER with timeout 0 -> RST on close. + l_onoff, l_linger = 1, 0 + sock.setsockopt( + _socket.SOL_SOCKET, + _socket.SO_LINGER, + int.to_bytes(l_onoff, 4, "little") + int.to_bytes(l_linger, 4, "little"), + ) + sock.close() + except Exception as e: + out["close_error"] = repr(e) + break + if ev.get("finished"): + break + except Exception as e: + out["error"] = repr(e) + out["end"] = time.time() + + +def start_stream( + url: str, prompt: str, max_new_tokens: int, close_after_n: Optional[int] = None +) -> Tuple[threading.Thread, dict, str]: + xid = uuid.uuid4().hex + out = {} + th = threading.Thread(target=_stream_run, args=(url, prompt, max_new_tokens, xid, out, close_after_n)) + th.daemon = True + th.start() + return th, out, xid + + +def wait_for_first_token(out: dict, timeout: float = 30.0) -> bool: + deadline = time.time() + timeout + while time.time() < deadline: + if out.get("events"): + return True + time.sleep(0.05) + return False + + +def get_finish_reason(out: dict) -> Optional[str]: + for ev in reversed(out.get("events") or []): + fr = ev.get("finish_reason") + if fr: + return fr + return None + + +# ---------------- log helpers ---------------- + + +def _read_log_tail(server_log_path: Optional[str], max_bytes: int = 256 * 1024) -> str: + if not server_log_path or not os.path.exists(server_log_path): + return "" + try: + size = os.path.getsize(server_log_path) + with open(server_log_path, "rb") as f: + if size > max_bytes: + f.seek(size - max_bytes) + return f.read().decode("utf-8", errors="ignore") + except FileNotFoundError: + return "" + + +def grep_log_for_pattern(server_log_path: Optional[str], pattern: re.Pattern, timeout: float = 5.0) -> Optional[str]: + """Poll the tail of the server log for a regex match.""" + deadline = time.time() + timeout + while time.time() < deadline: + tail = _read_log_tail(server_log_path) + m = pattern.search(tail) + if m: + return m.group(0) + time.sleep(0.1) + return None + + +def grep_log_after_offset( + server_log_path: Optional[str], start_offset: int, pattern: re.Pattern, timeout: float = 5.0 +) -> Optional[str]: + """Poll the server log starting at start_offset for a regex match. 
+ Only content written after start_offset is considered, so this isolates + a stage from log produced by earlier stages.""" + if not server_log_path: + return None + deadline = time.time() + timeout + while time.time() < deadline: + try: + with open(server_log_path, "rb") as f: + f.seek(start_offset) + new = f.read().decode("utf-8", errors="ignore") + except FileNotFoundError: + new = "" + m = pattern.search(new) + if m: + return m.group(0) + time.sleep(0.1) + return None + + +def server_log_size(server_log_path: Optional[str]) -> int: + if not server_log_path or not os.path.exists(server_log_path): + return 0 + return os.path.getsize(server_log_path) + + +def lookup_lightllm_req_id_from_log(server_log_path: str, x_request_id: str, timeout: float = 5.0) -> Optional[int]: + pattern = re.compile(rf"received req X-Request-Id:{re.escape(x_request_id)}\b.*?lightllm_req_id:(\d+)") + deadline = time.time() + timeout + while time.time() < deadline: + tail = _read_log_tail(server_log_path) + m = pattern.search(tail) + if m: + return int(m.group(1)) + time.sleep(0.1) + return None + + +# ---------------- stages ---------------- + + +def stage_a_bogus_id(url: str) -> bool: + banner("Stage A: abort with a non-existent id") + bogus = 99_999_999 + code, text = post_abort(url, request_id=bogus, abort_all=False) + print(f" /abort_request request_id={bogus} -> HTTP {code} body={text!r}") + if code != 200: + fail(f"expected HTTP 200, got {code}") + return False + ok("HTTP 200") + return True + + +def stage_b_abort_all_idle(url: str) -> bool: + banner("Stage B: abort_all on an idle server") + code, text = post_abort(url, abort_all=True) + print(f" /abort_request abort_all=true -> HTTP {code} body={text!r}") + if code != 200: + fail(f"expected HTTP 200, got {code}") + return False + ok("HTTP 200") + return True + + +def stage_c_abort_running(url: str, server_log_path: Optional[str]) -> bool: + banner("Stage C: abort_all on a running stream") + log_offset = server_log_size(server_log_path) + th, out, xid = start_stream(url, "Recite the alphabet repeatedly.", max_new_tokens=200) + if not wait_for_first_token(out, timeout=30.0): + fail("did not receive any tokens before abort") + return False + first_t = out["events"][0]["_t"] + ok(f"first token at +{first_t:.2f}s") + + target_id = lookup_lightllm_req_id_from_log(server_log_path, xid, timeout=5.0) if server_log_path else None + print(f" resolved lightllm_req_id from log: {target_id}") + + code, text = post_abort(url, abort_all=True) + print(f" /abort_request abort_all=true -> HTTP {code} body={text!r}") + if code != 200: + fail(f"expected HTTP 200, got {code}") + return False + + th.join(timeout=60.0) + if th.is_alive(): + fail("stream did not terminate within 60s of abort") + return False + fr = get_finish_reason(out) + n = len(out.get("events") or []) + print(f" stream events received: {n}, finish_reason={fr!r}, error={out.get('error')!r}") + + # The api itself succeeded; whether the stream got a clean 'abort' finish reason + # depends on which mode-backend the server is running. We DO assert the abort + # warning landed in the server log though, scoped to log content produced after + # this stage started so we don't match earlier-stage residue. 
+ if server_log_path: + if target_id is not None: + pat = re.compile(rf"aborted group_request_id {target_id}\b") + else: + pat = re.compile(r"aborted group_request_id \d+") + hit = grep_log_after_offset(server_log_path, log_offset, pat, timeout=5.0) + if not hit: + fail("could not find 'aborted group_request_id' in server log (post-stage)") + return False + ok(f"server log recorded: {hit!r}") + else: + print(" no --server_log_path; skipped log assertion") + ok("stream terminated and abort acknowledged") + return True + + +def stage_d_abort_by_id(url: str, server_log_path: Optional[str]) -> bool: + banner("Stage D: abort by specific request_id on a running stream") + if not server_log_path: + print(" --server_log_path not provided; skipping (we need the log to resolve req_id)") + return True + + log_offset = server_log_size(server_log_path) + th, out, xid = start_stream(url, "Sing a long lullaby for the moon.", max_new_tokens=300) + if not wait_for_first_token(out, timeout=30.0): + fail("did not receive any tokens before abort") + return False + ok(f"first token at +{out['events'][0]['_t']:.2f}s, X-Request-Id={xid[:8]}…") + + target_id = lookup_lightllm_req_id_from_log(server_log_path, xid, timeout=5.0) + if target_id is None: + fail("could not resolve lightllm_req_id from server log; cannot test by-id abort") + return False + print(f" resolved lightllm_req_id: {target_id}") + + code, text = post_abort(url, request_id=target_id, abort_all=False) + print(f" /abort_request request_id={target_id} -> HTTP {code} body={text!r}") + if code != 200: + fail(f"expected HTTP 200, got {code}") + return False + + th.join(timeout=60.0) + if th.is_alive(): + fail("stream did not terminate within 60s") + return False + fr = get_finish_reason(out) + n = len(out.get("events") or []) + print(f" stream events received: {n}, finish_reason={fr!r}") + + pat = re.compile(rf"aborted group_request_id {target_id}\b") + hit = grep_log_after_offset(server_log_path, log_offset, pat, timeout=5.0) + if not hit: + fail(f"could not find 'aborted group_request_id {target_id}' in server log (post-stage)") + return False + ok(f"server log recorded: {hit!r}") + return True + + +def stage_e_health_after(url: str) -> bool: + banner("Stage E: server still serves a normal /generate") + r = requests.post( + url + "/generate", + json={ + "inputs": "The capital of France is", + "parameters": {"max_new_tokens": 6, "do_sample": False}, + }, + timeout=60, + ) + print(f" /generate -> HTTP {r.status_code} {r.text[:200]}") + if r.status_code != 200: + fail(f"final /generate failed with {r.status_code}") + return False + body = r.json() + text = body.get("generated_text") + if isinstance(text, list): + text = text[0] + if not text or not text.strip(): + fail("final /generate returned empty text") + return False + ok(f"final /generate returned {text!r}") + return True + + +# ---------------- main ---------------- + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--url", default="http://127.0.0.1:8000") + ap.add_argument( + "--server_log_path", + default=None, + help="optional path to the server stdout/stderr log; enables log-grep assertions", + ) + args = ap.parse_args() + + try: + r = _get_health(args.url) + r.raise_for_status() + except Exception as e: + fail(f"server at {args.url} not reachable: {e}") + sys.exit(1) + ok(f"server reachable at {args.url}") + + results = [] + results.append(("A", stage_a_bogus_id(args.url))) + results.append(("B", stage_b_abort_all_idle(args.url))) + results.append(("C", 
stage_c_abort_running(args.url, args.server_log_path))) + results.append(("D", stage_d_abort_by_id(args.url, args.server_log_path))) + results.append(("E", stage_e_health_after(args.url))) + + print("\n" + "=" * 50) + all_ok = True + for name, passed in results: + tag = f"{GREEN}PASS{RESET}" if passed else f"{RED}FAIL{RESET}" + print(f" Stage {name}: {tag}") + all_ok = all_ok and passed + if not all_ok: + sys.exit(1) + print(f"\n{GREEN}ALL ABORT STAGES PASSED{RESET}") + + +if __name__ == "__main__": + main() diff --git a/test/test_api/test_r3.py b/test/test_api/test_r3.py new file mode 100644 index 0000000000..85c4e44ef9 --- /dev/null +++ b/test/test_api/test_r3.py @@ -0,0 +1,92 @@ +import sys +import argparse +import requests +import base64 +import numpy as np + + +def test_routing_export(url: str = "http://localhost:8000"): + print(f"Testing routing export at {url}") + print("-" * 50) + + try: + response = requests.post( + f"{url}/generate", + json={ + "inputs": "What is the capital of France? What is the capital of France?", + "parameters": { + "max_new_tokens": 50, + # "return_routed_experts": True, + # "repetition_penalty": 1.0, + }, + }, + timeout=60, + ) + except requests.exceptions.ConnectionError: + print(f"ERROR: Cannot connect to server at {url}") + print("Make sure the LightLLM server is running with --enable_return_routed_experts") + return False + except requests.exceptions.Timeout: + print("ERROR: Request timed out") + return False + + print(f"Status: {response.status_code}") + + if response.status_code != 200: + print(f"ERROR: Request failed with status {response.status_code}") + print(f"Response: {response.text}") + return False + + res = response.json() + print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...") + + if "routed_experts" not in res or not res["routed_experts"]: + print("\nWARNING: No routed_experts in response.") + print("This could mean:") + print(" - The model is not a MoE model") + print(" - The server was not started with --enable_return_routed_experts") + print(" - The routing capture manager was not initialized") + return False + + routing_info = res["routed_experts"] + shape = routing_info["shape"] + dtype_str = routing_info["dtype"] + dtype = np.dtype(dtype_str) + data = base64.b64decode(routing_info["data"]) + routing_array = np.frombuffer(data, dtype=dtype).reshape(shape) + + print(f"\n{'=' * 50}") + print("ROUTING CAPTURE SUCCESS!") + print(f"{'=' * 50}") + print(f"Shape: {shape}") + print(f"Dtype: {dtype}") + print(f"Num tokens: {shape[0]}") + print(f"Num MoE layers: {shape[1]}") + print(f"Top-K: {shape[2]}") + + # Compute payload size savings + int32_size = np.prod(shape) * 4 + actual_size = len(data) + savings = (1 - actual_size / int32_size) * 100 + print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)") + + print(f"\nSample routing (first layer, first 5 tokens):") + num_tokens_to_show = shape[0] + for i in range(num_tokens_to_show): + print(f" Token {i}: experts {routing_array[i, 0, :].tolist()}") + + if np.all(routing_array == 0): + print("\nWARNING: All routing data is zeros. 
Capture may not be working correctly.") + return False + + print("\nTest PASSED!") + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test R3 routing export feature") + parser.add_argument("--url", default="http://localhost:8000", help="Server URL") + args = parser.parse_args() + + success = test_routing_export(args.url) + sys.exit(0 if success else 1) diff --git a/test/test_api/test_rl_endpoints.py b/test/test_api/test_rl_endpoints.py new file mode 100644 index 0000000000..dda304924b --- /dev/null +++ b/test/test_api/test_rl_endpoints.py @@ -0,0 +1,309 @@ +""" +Test release_memory_occupation / resume_memory_occupation / update_weights_from_tensor +against a running lightllm server. + +Sequence: + 1. baseline generate (sanity) + 2. release_memory_occupation -> GPU memory should drop sharply + 3. resume_memory_occupation -> GPU memory should grow back + (without --enable_weight_cpu_backup the weight + memory is allocated empty, so generation right + after resume is expected to be garbage) + 4. update_weights_from_tensor (per-batch CUDA-IPC handoff) for every parameter + found on disk -> repopulate weights + 5. final generate -> should produce a sensible answer again + +The "trainer" runs in this same process: it holds tensors on a free GPU, serialises +them via lightllm.utils.serializer.MultiprocessingSerializer (CUDA IPC handles, not +data), then asks the server to clone them into its weight buffers. No NCCL group +is required, so this is safe to interrupt without leaving the server hung. + +Usage: + python test/test_api/test_rl_endpoints.py \ + --url http://127.0.0.1:8000 \ + --model_dir /nvme/models/Qwen3.5-35B-A3B \ + --tp 2 \ + --client_device 2 + +Notes: + - This script must run on the same machine as the server (CUDA IPC). + - --client_device picks a free GPU for the in-process trainer; must differ from + the TP worker GPUs (workers occupy 0..tp-1 in the user's launch command). +""" + +import argparse +import json +import os +import subprocess +import sys +import time +from glob import glob +from typing import Dict, List, Tuple + +# Make the repo importable when this script is invoked by path rather than -m. 
+_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +import requests +import torch +from safetensors import safe_open + +from lightllm.utils.patch_torch import monkey_patch_torch_reductions +from lightllm.utils.serializer import MultiprocessingSerializer + + +GREEN = "\033[32m" +RED = "\033[31m" +YELLOW = "\033[33m" +RESET = "\033[0m" + + +def banner(msg: str): + print(f"\n{YELLOW}=== {msg} ==={RESET}", flush=True) + + +def ok(msg: str): + print(f" {GREEN}OK{RESET} {msg}", flush=True) + + +def fail(msg: str): + print(f" {RED}FAIL{RESET} {msg}", flush=True) + + +def gpu_mem_used_mib() -> List[int]: + out = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"]).decode() + return [int(x.strip()) for x in out.strip().splitlines()] + + +def post(url: str, path: str, payload=None, timeout=600): + r = requests.post(url + path, json=payload or {}, timeout=timeout) + try: + body = r.json() + except Exception: + body = r.text + return r.status_code, body + + +def generate(url: str, prompt: str, max_new_tokens: int = 16) -> str: + r = requests.post( + url + "/generate", + json={ + "inputs": prompt, + "parameters": {"max_new_tokens": max_new_tokens, "do_sample": False}, + }, + timeout=120, + ) + r.raise_for_status() + data = r.json() + if isinstance(data.get("generated_text"), list): + return data["generated_text"][0] + return data.get("generated_text", json.dumps(data)) + + +def looks_garbage(text: str) -> bool: + """Heuristic: post-resume text is usually a single repeated character (e.g. '!!!!').""" + s = text.strip() + if not s: + return True + return len(set(s)) == 1 + + +# ---------------- weight-update helpers (update_weights_from_tensor) ---------------- + + +def _list_safetensor_shards(model_dir: str) -> List[str]: + shards = sorted(glob(os.path.join(model_dir, "*.safetensors"))) + if not shards: + raise RuntimeError(f"no .safetensors found under {model_dir}") + return shards + + +def _send_update_from_tensor( + url: str, + serialized_per_rank: List[str], + flush_cache: bool = False, +): + code, body = post( + url, + "/update_weights_from_tensor", + { + "serialized_named_tensors": serialized_per_rank, + "load_format": None, + "flush_cache": flush_cache, + "abort_all_requests": False, + }, + timeout=600, + ) + return code, body + + +def update_weights_from_disk_via_tensor_api( + url: str, + model_dir: str, + tp: int, + client_device: int, + batch_per_request: int = 8, + flush_cache_at_end: bool = True, +): + """ + Acts as an in-process "trainer": loads every safetensor shard onto + cuda:client_device, then ships each batch of (name, tensor) to the server + via /update_weights_from_tensor. The server worker on each TP rank receives + a CUDA IPC handle, copies into its weight buffer. + """ + banner("update_weights_from_tensor (CUDA IPC)") + # Server side patches its own copy; we patch ours so reductions can serialise + # CUDA tensors with UUID-based device addressing. 
+ monkey_patch_torch_reductions() + torch.cuda.set_device(client_device) + device = f"cuda:{client_device}" + + shards = _list_safetensor_shards(model_dir) + print(f" found {len(shards)} safetensor shards, batch_per_request={batch_per_request}", flush=True) + + total_params = 0 + total_bytes = 0 + t0 = time.time() + for shard_idx, shard in enumerate(shards): + shard_t0 = time.time() + with safe_open(shard, framework="pt") as f: + keys = list(f.keys()) + for i in range(0, len(keys), batch_per_request): + batch_keys = keys[i : i + batch_per_request] + # Load batch onto the client GPU. .contiguous() guarantees a + # whole-tensor allocation (safetensors slices are already + # contiguous, but this is cheap insurance). + tensors = [f.get_tensor(k).to(device).contiguous() for k in batch_keys] + named: List[Tuple[str, torch.Tensor]] = list(zip(batch_keys, tensors)) + + # Same payload to every TP rank — the server clones full + # tensors per rank and lets model.load_weights handle the TP + # sharding internally (matching how update_weights_from_* + # paths are written). + blob = MultiprocessingSerializer.serialize(named, output_str=True) + serialized_per_rank = [blob] * tp + + # Last batch flushes the prefix cache so old KV from the + # previous weight version cannot poison subsequent gens. + is_last = (shard_idx == len(shards) - 1) and (i + batch_per_request >= len(keys)) + code, body = _send_update_from_tensor( + url, + serialized_per_rank, + flush_cache=(flush_cache_at_end and is_last), + ) + if code != 200: + fail(f"update batch failed: {code} {body}") + raise RuntimeError(f"update batch failed: {code} {body}") + total_params += len(batch_keys) + total_bytes += sum(t.numel() * t.element_size() for t in tensors) + # Free client-side memory before next batch — the worker has + # already cloned the data by the time post() returned. 
+ for t in tensors: + del t + del tensors, named + torch.cuda.empty_cache() + + print( + f" shard {shard_idx+1}/{len(shards)} done " + f"(+{len(keys)} tensors, {time.time()-shard_t0:.1f}s, " + f"running total {total_params} params, {total_bytes/1e9:.1f} GB)", + flush=True, + ) + + dt = time.time() - t0 + ok(f"streamed {total_params} params, {total_bytes/1e9:.1f} GB in {dt:.1f}s") + + +# ---------------- main flow ---------------- + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--url", default="http://127.0.0.1:8000") + ap.add_argument("--model_dir", required=True) + ap.add_argument("--tp", type=int, required=True) + ap.add_argument( + "--client_device", + type=int, + default=2, + help="GPU index for the in-process trainer; must differ from TP worker GPUs", + ) + ap.add_argument("--prompt", default="The capital of France is") + ap.add_argument("--max_new_tokens", type=int, default=16) + ap.add_argument("--batch_per_request", type=int, default=8) + ap.add_argument("--skip_update", action="store_true", help="run only release/resume, skip the update_weights phase") + args = ap.parse_args() + + # ---------------- stage 1: baseline ---------------- + banner("baseline generate") + base_text = generate(args.url, args.prompt, args.max_new_tokens) + print(f" prompt : {args.prompt!r}") + print(f" generated: {base_text!r}") + ok("baseline generated") + + # ---------------- stage 2: release ---------------- + banner("release_memory_occupation") + before = gpu_mem_used_mib() + print(f" GPU mem before: {before[: max(args.tp, 4)]}") + code, body = post(args.url, "/release_memory_occupation", {}) + print(f" resp: {code} {body}") + if code != 200: + fail("release failed") + sys.exit(1) + time.sleep(2) + after = gpu_mem_used_mib() + print(f" GPU mem after : {after[: max(args.tp, 4)]}") + drop = sum(before[: args.tp]) - sum(after[: args.tp]) + if drop < 10_000: + fail(f"release did not free much memory (delta={drop} MiB)") + sys.exit(1) + ok(f"release freed ~{drop} MiB on TP GPUs") + + # ---------------- stage 3: resume ---------------- + banner("resume_memory_occupation") + code, body = post(args.url, "/resume_memory_occupation", {}) + print(f" resp: {code} {body}") + if code != 200: + fail("resume failed") + sys.exit(1) + time.sleep(2) + print(f" GPU mem after : {gpu_mem_used_mib()[: max(args.tp, 4)]}") + ok("resume returned success") + + banner("post-resume generate (likely garbage without weight cpu backup)") + text_after_resume = generate(args.url, args.prompt, args.max_new_tokens) + print(f" generated: {text_after_resume!r} garbage_heuristic={looks_garbage(text_after_resume)}") + + if args.skip_update: + ok("done (skipped update_weights stage)") + return + + # ---------------- stage 4: update_weights_from_tensor ---------------- + update_weights_from_disk_via_tensor_api( + url=args.url, + model_dir=args.model_dir, + tp=args.tp, + client_device=args.client_device, + batch_per_request=args.batch_per_request, + flush_cache_at_end=True, + ) + + # ---------------- stage 5: final generate ---------------- + banner("final generate (after weight reload)") + final_text = generate(args.url, args.prompt, args.max_new_tokens) + print(f" prompt : {args.prompt!r}") + print(f" generated: {final_text!r}") + if looks_garbage(final_text): + fail("final generation still looks like garbage; weight update did not stick") + sys.exit(1) + if final_text.strip() == base_text.strip(): + ok("final output matches baseline exactly") + else: + ok("final output is sensible (differs from baseline but not 
garbage)") + + print(f"\n{GREEN}ALL STAGES PASSED{RESET}") + + +if __name__ == "__main__": + main() diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py new file mode 100644 index 0000000000..dcc010b372 --- /dev/null +++ b/unit_tests/common/basemodel/test_routing_capture_manager.py @@ -0,0 +1,219 @@ +import torch +import numpy as np + + +class TestRoutingCaptureManager: + def test_capture_and_extract_basic(self): + """Test the core pipeline: capture → flush_to_kv_buffer → extract_from_gpu.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + kv_cache_size=1024, + max_capture_tokens=64, + ) + + # Simulate a batch of 10 tokens at KV-cache positions [100..109] + mem_indexes = torch.arange(100, 110, device="cuda") + + # Capture routing for each MoE layer (writes to capture buffer) + for layer_idx in range(4): + topk_ids = torch.randint(0, 64, (10, 8), device="cuda") + manager.capture(moe_layer_index=layer_idx, topk_ids=topk_ids, microbatch_index=0) + + # Flush from capture buffer to KV-indexed gpu_kv_buffer + manager.flush_to_kv_buffer(mem_indexes, num_tokens=10, microbatch_index=0) + + # Extract for those same KV-cache positions + result = manager.extract_from_gpu(mem_indexes) + assert result.shape == (4, 10, 8) + assert result.dtype == np.int8 + + def test_capture_writes_to_correct_kv_positions(self): + """Verify that captured data lands in the right KV-cache positions after flush.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=2, + topk=4, + num_experts=32, + kv_cache_size=256, + max_capture_tokens=16, + ) + + # Use non-contiguous mem_indexes to simulate real KV-cache + mem_indexes = torch.tensor([10, 50, 200], device="cuda") + + topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + topk_ids_layer1 = topk_ids + 20 + manager.capture(moe_layer_index=1, topk_ids=topk_ids_layer1, microbatch_index=0) + + # Flush to KV positions + manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0) + + # Extract and verify + result = manager.extract_from_gpu(mem_indexes) + assert result.shape == (2, 3, 4) + np.testing.assert_array_equal(result[0], topk_ids.cpu().numpy()) + np.testing.assert_array_equal(result[1], topk_ids_layer1.cpu().numpy()) + + def test_microbatch_isolation(self): + """Two microbatches writing to different KV positions don't interfere.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=4, + num_experts=32, + kv_cache_size=256, + max_capture_tokens=16, + ) + + # Microbatch 0: positions [10, 11] + mem0 = torch.tensor([10, 11], device="cuda") + ids_0 = torch.ones((2, 4), dtype=torch.int64, device="cuda") + manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0) + + # Microbatch 1: positions [20, 21] + 
mem1 = torch.tensor([20, 21], device="cuda") + ids_1 = torch.ones((2, 4), dtype=torch.int64, device="cuda") * 2 + manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1) + + # Flush each microbatch to different KV positions + manager.flush_to_kv_buffer(mem0, num_tokens=2, microbatch_index=0) + manager.flush_to_kv_buffer(mem1, num_tokens=2, microbatch_index=1) + + # Extract microbatch 0 + result0 = manager.extract_from_gpu(mem0) + assert result0.shape == (1, 2, 4) + assert result0[0, 0, 0] == 1 + + # Extract microbatch 1 + result1 = manager.extract_from_gpu(mem1) + assert result1[0, 0, 0] == 2 + + def test_dtype_selection_int8(self): + """Models with ≤127 experts use int8.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=64, + kv_cache_size=128, + max_capture_tokens=16, + ) + assert manager.dtype == torch.int8 + assert manager.np_dtype == np.int8 + assert manager.dtype_id == 1 + + def test_dtype_selection_int16(self): + """Models with >127 experts use int16.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=256, + kv_cache_size=128, + max_capture_tokens=16, + ) + assert manager.dtype == torch.int16 + assert manager.np_dtype == np.int16 + assert manager.dtype_id == 2 + + def test_extract_preserves_values(self): + """Extracted values exactly match what was captured, no dtype truncation.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=4, + num_experts=64, + kv_cache_size=64, + max_capture_tokens=16, + ) + + mem_indexes = torch.tensor([0, 1, 2], device="cuda") + + topk_ids = torch.tensor([[10, 20, 30, 40], [50, 60, 63, 1], [0, 5, 127, 3]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + # Flush then extract + manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0) + result = manager.extract_from_gpu(mem_indexes) + expected = topk_ids.cpu().numpy().astype(np.int8) + np.testing.assert_array_equal(result[0], expected) + + def test_gpu_kv_buffer_shape(self): + """Buffer shape is (num_moe_layers, kv_cache_size, topk).""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + # 127 experts fits in int8 (max value 127) + manager = RoutingCaptureManager( + num_moe_layers=48, + topk=8, + num_experts=127, + kv_cache_size=2048, + max_capture_tokens=256, + ) + assert manager.gpu_kv_buffer.shape == (48, 2048, 8) + assert manager.gpu_kv_buffer.dtype == torch.int8 + assert manager.gpu_kv_buffer.device.type == "cuda" + + # 128 experts requires int16 + manager2 = RoutingCaptureManager( + num_moe_layers=48, + topk=8, + num_experts=128, + kv_cache_size=2048, + max_capture_tokens=256, + ) + assert manager2.gpu_kv_buffer.dtype == torch.int16 + + def test_partial_token_capture(self): + """capture() only writes num_tokens rows to the buffer.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=32, + kv_cache_size=128, + max_capture_tokens=16, + ) + + # Capture only 3 tokens, flush to 5 KV positions (first 3 get data) + mem_indexes = torch.tensor([10, 11, 12, 13, 14], device="cuda") + + topk_ids = torch.tensor([[1, 2], [3, 4], [5, 6]], device="cuda") # only 3 tokens + 
manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0)
+
+ # Flush only the 3 captured tokens
+ manager.flush_to_kv_buffer(mem_indexes[:3], num_tokens=3, microbatch_index=0)
+
+ # Positions 10-12 should have data, 13-14 should be zeros (from init)
+ result_written = manager.extract_from_gpu(mem_indexes[:3])
+ np.testing.assert_array_equal(result_written[0], topk_ids.cpu().numpy().astype(np.int8))
+
+ result_unwritten = manager.extract_from_gpu(mem_indexes[3:])
+ np.testing.assert_array_equal(result_unwritten[0], np.zeros((2, 2), dtype=np.int8))
+
+ def test_capture_buffer_shape(self):
+ """Capture buffer has correct shape (max_tokens, num_moe_layers, topk)."""
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=4,
+ topk=8,
+ num_experts=64,
+ kv_cache_size=1024,
+ max_capture_tokens=256,
+ )
+ assert manager._capture_buffer[0].shape == (256, 4, 8)
+ assert manager._capture_buffer[1].shape == (256, 4, 8)
+ assert manager._capture_buffer[0].dtype == torch.int8
diff --git a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
index 605433e9d8..dfeda0b6f7 100644
--- a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
+++ b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
@@ -230,5 +230,32 @@ def test_case9():
assert torch.equal(unmerged_node_d.token_id_key, torch.tensor([6], dtype=torch.int64))
+
+def test_case10():
+ """
+ Test scenario: exercise the flush_cache function
+ """
+ print("\nTest Case 10: Testing flush_cache function\n")
+ tree = RadixCache("unique_name", 100, 0)
+ tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64))
+ tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
+ tree_node, size, values = tree.match_prefix(
+ torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True
+ )
+ assert tree_node is not None
+ assert size == 3
+ tree.flush_cache()
+ tree_node, size, values = tree.match_prefix(
+ torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True
+ )
+ assert tree_node is None
+ assert size == 0
+ assert tree.get_tree_total_tokens_num() == 0
+ assert tree.get_refed_tokens_num() == 0
+ assert len(tree.root_node.children) == 0
+ assert tree.root_node.token_id_key.numel() == 0
+ assert tree.root_node.token_mem_index_value.numel() == 0
+ assert tree.root_node.ref_counter == 1
+
+
if __name__ == "__main__":
pytest.main()
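
The release/resume/update flow exercised by test/test_api/test_rl_endpoints.py above condenses to the short trainer-side sketch below. This is a minimal illustration, not part of the patch: SERVER_URL, TP, CLIENT_DEVICE, and the example tensor name/shape are placeholders, and the endpoint payload keys simply mirror what that test script sends.

import requests
import torch

from lightllm.utils.patch_torch import monkey_patch_torch_reductions
from lightllm.utils.serializer import MultiprocessingSerializer

SERVER_URL = "http://127.0.0.1:8000"  # placeholder
TP = 2  # must match the server's tensor-parallel size
CLIENT_DEVICE = "cuda:2"  # a free GPU outside the TP worker GPUs (placeholder)

# Reductions must be patched so CUDA tensors serialise as IPC handles, not data.
monkey_patch_torch_reductions()

# 1) Drop the server's weights / KV cache / graphs.
requests.post(SERVER_URL + "/release_memory_occupation", json={}, timeout=600).raise_for_status()

# 2) Re-allocate the buffers; without --enable_weight_cpu_backup they come back empty,
#    so generation is expected to be garbage until new weights arrive.
requests.post(SERVER_URL + "/resume_memory_occupation", json={}, timeout=600).raise_for_status()

# 3) Stream one batch of named tensors; every TP rank receives the same payload and
#    the server copies from the CUDA IPC handles into its own weight buffers.
named = [("model.embed_tokens.weight", torch.zeros(8, 8, device=CLIENT_DEVICE))]  # illustrative name/shape
blob = MultiprocessingSerializer.serialize(named, output_str=True)
requests.post(
    SERVER_URL + "/update_weights_from_tensor",
    json={
        "serialized_named_tensors": [blob] * TP,
        "load_format": None,
        "flush_cache": True,  # flush the prefix cache once the last batch has landed
        "abort_all_requests": False,
    },
    timeout=600,
).raise_for_status()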
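
For reference, the dtype rule pinned down by the RoutingCaptureManager unit tests (at most 127 experts fits int8, otherwise int16) is what produces the payload savings that test_r3.py reports. A quick back-of-the-envelope check, with an assumed shape rather than a real run:

import numpy as np

shape = (57, 48, 8)  # assumed (num_tokens, num_moe_layers, topk), illustrative only
entries = int(np.prod(shape))  # 21,888 routed-expert ids

int32_bytes = entries * 4  # 87,552 bytes if ids were shipped as int32
int8_bytes = entries * 1   # 21,888 bytes for a model with at most 127 experts
int16_bytes = entries * 2  # 43,776 bytes for a model with 128 or more experts

print(f"int8 vs int32 : {(1 - int8_bytes / int32_bytes) * 100:.0f}% smaller")   # 75%
print(f"int16 vs int32: {(1 - int16_bytes / int32_bytes) * 100:.0f}% smaller")  # 50%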