From 6dad5a5ebc073dbf01c629c733cb35c2f4d1f768 Mon Sep 17 00:00:00 2001 From: RinZ27 <222222878+RinZ27@users.noreply.github.com> Date: Thu, 5 Mar 2026 09:42:54 +0700 Subject: [PATCH 1/2] fix: implement RestrictedUnpickler to mitigate CVE-2026-26220 --- lightllm/distributed/pynccl.py | 6 +- lightllm/server/api_http.py | 6 +- lightllm/server/config_server/api_http.py | 6 +- lightllm/server/core/objs/rpc_shm.py | 10 ++-- .../server/core/objs/shm_objs_io_buffer.py | 6 +- lightllm/server/httpserver/manager.py | 4 +- lightllm/server/httpserver/pd_loop.py | 8 +-- .../httpserver_for_pd_master/manager.py | 18 +++--- .../httpserver_for_pd_master/register_loop.py | 4 +- .../server/router/model_infer/infer_batch.py | 4 +- .../pd_mode/decode_node_impl/up_status.py | 4 +- .../decode_node_impl/decode_trans_process.py | 8 +-- .../pd_nixl/decode_node_impl/up_status.py | 4 +- .../pd_nixl/nixl_kv_transporter.py | 6 +- .../prefill_trans_process.py | 4 +- lightllm/utils/pickle_utils.py | 60 +++++++++++++++++++ 16 files changed, 109 insertions(+), 49 deletions(-) create mode 100644 lightllm/utils/pickle_utils.py diff --git a/lightllm/distributed/pynccl.py b/lightllm/distributed/pynccl.py index b96e0d1bab..8cc6421cab 100644 --- a/lightllm/distributed/pynccl.py +++ b/lightllm/distributed/pynccl.py @@ -21,7 +21,7 @@ import dataclasses from datetime import timedelta -import pickle +from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps import time from typing import Optional, Union, Dict, Deque, Tuple, Any from collections import deque @@ -85,7 +85,7 @@ def send_obj(self, obj: Any): """Send an object to a destination rank.""" self.expire_data() key = f"send_to/{self.dest_id}/{self.send_dst_counter}" - self.store.set(key, pickle.dumps(obj)) + self.store.set(key, safe_pickle_dumps(obj)) self.send_dst_counter += 1 self.entries.append((key, time.time())) @@ -102,7 +102,7 @@ def expire_data(self): def recv_obj(self) -> Any: """Receive an object from a source rank.""" - obj = pickle.loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}")) + obj = safe_pickle_loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}")) self.recv_src_counter += 1 return obj diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 230da5b369..3098b07f8c 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -24,7 +24,7 @@ import base64 import os from io import BytesIO -import pickle +from lightllm.utils.pickle_utils import safe_pickle_loads import setproctitle asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -307,7 +307,7 @@ async def register_and_keep_alive(websocket: WebSocket): while True: # 等待接收消息,设置超时为10秒 data = await websocket.receive_bytes() - obj = pickle.loads(data) + obj = safe_pickle_loads(data) await g_objs.httpserver_manager.put_to_handle_queue(obj) except (WebSocketDisconnect, Exception, RuntimeError) as e: @@ -328,7 +328,7 @@ async def kv_move_status(websocket: WebSocket): while True: # 等待接收消息,设置超时为10秒 data = await websocket.receive_bytes() - upkv_status = pickle.loads(data) + upkv_status = safe_pickle_loads(data) logger.info(f"received upkv_status {upkv_status} from {(client_ip, client_port)}") await g_objs.httpserver_manager.update_req_status(upkv_status) except (WebSocketDisconnect, Exception, RuntimeError) as e: diff --git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index c5505acda4..9f69f514ba 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -1,7 +1,7 @@ import time import asyncio import base64 -import pickle +from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps import setproctitle import multiprocessing as mp from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Query @@ -53,7 +53,7 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.accept() client_ip, client_port = websocket.client logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}") - registered_pd_master_obj: PD_Master_Obj = pickle.loads(await websocket.receive_bytes()) + registered_pd_master_obj: PD_Master_Obj = safe_pickle_loads(await websocket.receive_bytes()) logger.info(f"received registered_pd_master_obj {registered_pd_master_obj}") with registered_pd_master_obj_lock: registered_pd_master_objs[registered_pd_master_obj.node_id] = registered_pd_master_obj @@ -75,7 +75,7 @@ async def websocket_endpoint(websocket: WebSocket): @app.get("/registered_objects") async def get_registered_objects(): with registered_pd_master_obj_lock: - serialized_data = pickle.dumps(registered_pd_master_objs) + serialized_data = safe_pickle_dumps(registered_pd_master_objs) base64_encoded = base64.b64encode(serialized_data).decode("utf-8") return {"data": base64_encoded} diff --git a/lightllm/server/core/objs/rpc_shm.py b/lightllm/server/core/objs/rpc_shm.py index 69a3d01236..990e4b1d22 100644 --- a/lightllm/server/core/objs/rpc_shm.py +++ b/lightllm/server/core/objs/rpc_shm.py @@ -1,5 +1,5 @@ import os -import pickle +from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps import numpy as np from multiprocessing import shared_memory from typing import List @@ -24,14 +24,14 @@ def create_or_link_shm(self): return def write_func_params(self, func_name, args): - objs_bytes = pickle.dumps((func_name, args)) + objs_bytes = safe_pickle_dumps((func_name, args)) self.shm.buf.cast("i")[0] = len(objs_bytes) self.shm.buf[4 : 4 + len(objs_bytes)] = objs_bytes return def read_func_params(self): bytes_len = self.shm.buf.cast("i")[0] - func_name, args = pickle.loads(self.shm.buf[4 : 4 + bytes_len]) + func_name, args = safe_pickle_loads(self.shm.buf[4 : 4 + bytes_len]) return func_name, args @@ -45,13 +45,13 @@ def create_or_link_shm(self): return def write_func_result(self, func_name, ret): - objs_bytes = pickle.dumps((func_name, ret)) + objs_bytes = safe_pickle_dumps((func_name, ret)) self.shm.buf.cast("i")[0] = len(objs_bytes) self.shm.buf[4 : 4 + len(objs_bytes)] = objs_bytes def read_func_result(self): bytes_len = self.shm.buf.cast("i")[0] - func_name, ret = pickle.loads(self.shm.buf[4 : 4 + bytes_len]) + func_name, ret = safe_pickle_loads(self.shm.buf[4 : 4 + bytes_len]) return func_name, ret diff --git a/lightllm/server/core/objs/shm_objs_io_buffer.py b/lightllm/server/core/objs/shm_objs_io_buffer.py index 05b6087601..979aa9f3bf 100644 --- a/lightllm/server/core/objs/shm_objs_io_buffer.py +++ b/lightllm/server/core/objs/shm_objs_io_buffer.py @@ -1,5 +1,5 @@ import os -import pickle +from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps from lightllm.server.core.objs.atomic_lock import AtomicShmLock from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.envs_utils import get_unique_server_name @@ -41,14 +41,14 @@ def is_ready(self): return self.int_view[0] == self.node_world_size def write_obj(self, obj): - obj_bytes = pickle.dumps(obj) + obj_bytes = safe_pickle_dumps(obj) self.int_view[1] = len(obj_bytes) self.shm.buf[8 : 8 + len(obj_bytes)] = obj_bytes return def read_obj(self): bytes_len = self.int_view[1] - obj = pickle.loads(self.shm.buf[8 : 8 + bytes_len]) + obj = safe_pickle_loads(self.shm.buf[8 : 8 + bytes_len]) return obj def _create_or_link_shm(self): diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index d51f88cdda..5f638fc3f7 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -308,7 +308,7 @@ async def generate( f"nixl prefill node upload group_req_id {group_request_id} prompt ids len : {len(prompt_ids)}" ) await nixl_pd_upload_websocket.send( - pickle.dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids)) + safe_pickle_dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids)) ) try: await asyncio.wait_for(nixl_pd_event.wait(), timeout=80) @@ -317,7 +317,7 @@ async def generate( raise Exception(f"group_req_id {group_request_id} wait nixl_pd_event time out") decode_node_info: NIXLDecodeNodeInfo = nixl_pd_event.decode_node_info - sampling_params.nixl_params.set(pickle.dumps(decode_node_info)) + sampling_params.nixl_params.set(safe_pickle_dumps(decode_node_info)) if decode_node_info.ready_kv_len == len(prompt_ids) - 1: # 如果 decode 节点的 ready_kv_len 和 prefill encode 的 len(prompt ids) -1 相等,说明不需要进行 prefill diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index a646d4f4cc..f7d1727576 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -1,5 +1,5 @@ import asyncio -import pickle +from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps import websockets import ujson as json import socket @@ -103,7 +103,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。 while True: recv_bytes = await websocket.recv() - obj = pickle.loads(recv_bytes) + obj = safe_pickle_loads(recv_bytes) if obj[0] == ObjType.REQ: prompt, sampling_params, multimodal_params = obj[1] group_req_id = sampling_params.group_request_id @@ -183,7 +183,7 @@ async def _get_pd_master_objs(args: StartArgs) -> Optional[Dict[int, PD_Master_O response = await client.get(uri) if response.status_code == 200: base64data = response.json()["data"] - id_to_pd_master_obj = pickle.loads(base64.b64decode(base64data)) + id_to_pd_master_obj = safe_pickle_loads(base64.b64decode(base64data)) return id_to_pd_master_obj else: logger.error(f"get pd_master_objs error {response.status_code}") @@ -231,7 +231,7 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket: Clien if handle_list: load_info: dict = _get_load_info() - await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info))) + await websocket.send(safe_pickle_dumps((ObjType.TOKEN_PACKS, handle_list, load_info))) # 获取节点负载信息 diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index fda1387f3b..10cef2cac6 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -4,7 +4,7 @@ import time import datetime import ujson as json -import pickle +from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict, Optional @@ -173,7 +173,7 @@ async def fetch_stream( sampling_params.move_kv_to_decode_node.initialize(decode_node_dict if old_max_new_tokens != 1 else None) sampling_params.suggested_dp_index = -1 - await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) + await p_node.websocket.send_bytes(safe_pickle_dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) while True: await req_status.wait_to_ready() @@ -210,7 +210,7 @@ async def fetch_stream( sampling_params.suggested_dp_index = upkv_status.dp_index await d_node.websocket.send_bytes( - pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))) + safe_pickle_dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))) ) while True: @@ -244,7 +244,7 @@ async def fetch_nixl_stream( old_max_new_tokens = sampling_params.max_new_tokens sampling_params.max_new_tokens = 1 - await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) + await p_node.websocket.send_bytes(safe_pickle_dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) try: await asyncio.wait_for(nixl_np_up_prompt_ids_event.wait(), timeout=60) @@ -260,7 +260,7 @@ async def fetch_nixl_stream( sampling_params.max_new_tokens = old_max_new_tokens await d_node.websocket.send_bytes( - pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))) + safe_pickle_dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))) ) try: @@ -272,9 +272,9 @@ async def fetch_nixl_stream( # 将 decode 节点上报的当前请求使用的decode节点的信息下发给 p 节点,这样 p 节点才知道将 kv 传输给那个 d 节点。 upkv_status: NixlUpKVStatus = up_status_event.upkv_status nixl_params: bytes = upkv_status.nixl_params - decode_node_info: NIXLDecodeNodeInfo = pickle.loads(nixl_params) + decode_node_info: NIXLDecodeNodeInfo = safe_pickle_loads(nixl_params) await p_node.websocket.send_bytes( - pickle.dumps((ObjType.NIXL_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info)) + safe_pickle_dumps((ObjType.NIXL_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info)) ) first_token_gen = False @@ -392,12 +392,12 @@ async def abort( pass try: - await p_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id))) + await p_node.websocket.send_bytes(safe_pickle_dumps((ObjType.ABORT, group_request_id))) except: pass try: - await d_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id))) + await d_node.websocket.send_bytes(safe_pickle_dumps((ObjType.ABORT, group_request_id))) except: pass diff --git a/lightllm/server/httpserver_for_pd_master/register_loop.py b/lightllm/server/httpserver_for_pd_master/register_loop.py index 7e30ffc470..4bb29ef9d1 100644 --- a/lightllm/server/httpserver_for_pd_master/register_loop.py +++ b/lightllm/server/httpserver_for_pd_master/register_loop.py @@ -1,5 +1,5 @@ import asyncio -import pickle +from lightllm.utils.pickle_utils import safe_pickle_dumps import websockets import socket from lightllm.utils.net_utils import get_hostname_ip @@ -31,7 +31,7 @@ async def register_loop(manager: HttpServerManagerForPDMaster): node_id=manager.args.pd_node_id, host_ip_port=f"{manager.host_ip}:{manager.args.port}" ) - await websocket.send(pickle.dumps(pd_master_obj)) + await websocket.send(safe_pickle_dumps(pd_master_obj)) logger.info(f"Sent registration pd_master obj: {pd_master_obj}") while True: diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 1b4a1ca5cb..409932ffcb 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -3,7 +3,7 @@ import torch.distributed as dist import numpy as np import collections -import pickle +from lightllm.utils.pickle_utils import safe_pickle_loads from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Callable, Any @@ -283,7 +283,7 @@ def __init__( # nixl decode node information if self.shm_param.nixl_params.data_len > 0: - self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get()) + self.nixl_decode_node: NIXLDecodeNodeInfo = safe_pickle_loads(self.shm_param.nixl_params.get()) else: self.nixl_decode_node: NIXLDecodeNodeInfo = None diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py index 833ffecc89..6ef550c781 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py @@ -5,7 +5,7 @@ import websockets import inspect import setproctitle -import pickle +from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps from typing import Dict from dataclasses import asdict @@ -91,7 +91,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj): if pd_master_obj.node_id in self.id_to_handle_queue: task_queue = self.id_to_handle_queue[pd_master_obj.node_id] upkv_status: UpKVStatus = await task_queue.get() - await websocket.send(pickle.dumps(upkv_status)) + await websocket.send(safe_pickle_dumps(upkv_status)) logger.info(f"up status: {upkv_status}") else: await asyncio.sleep(3) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index b04cbb900a..e9eb68aa22 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -5,7 +5,7 @@ import torch.multiprocessing as mp import collections import queue -import pickle +from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps from typing import List, Dict, Union, Deque, Optional from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager import MemoryManager @@ -193,7 +193,7 @@ def dispatch_task_loop(self): up_status = NixlUpKVStatus( group_request_id=task.request_id, pd_master_node_id=task.pd_master_node_id, - nixl_params=pickle.dumps(decode_node_info), + nixl_params=safe_pickle_dumps(decode_node_info), ) self.up_status_in_queue.put(up_status) @@ -220,7 +220,7 @@ def accept_peer_task_loop( for remote_agent_name, _notify_list in notifies_dict.items(): for notify in _notify_list: try: - notify_obj = pickle.loads(notify) + notify_obj = safe_pickle_loads(notify) except: notify_obj = None @@ -236,7 +236,7 @@ def accept_peer_task_loop( try: self.transporter.send_notify_to_prefill_node( prefill_agent_name=remote_agent_name, - notify=pickle.dumps(remote_trans_task.createRetObj()), + notify=safe_pickle_dumps(remote_trans_task.createRetObj()), ) except BaseException as e: logger.error(f"send notify to prefill node failed: {str(e)}") diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py index f79fb4ea2c..ba66db5b9a 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py @@ -4,7 +4,7 @@ import threading import websockets import inspect -import pickle +from lightllm.utils.pickle_utils import safe_pickle_dumps from typing import Dict, Union from dataclasses import asdict @@ -88,7 +88,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj): if pd_master_obj.node_id in self.id_to_handle_queue: task_queue = self.id_to_handle_queue[pd_master_obj.node_id] upkv_status: Union[UpKVStatus, NixlUpKVStatus] = await task_queue.get() - await websocket.send(pickle.dumps(upkv_status)) + await websocket.send(safe_pickle_dumps(upkv_status)) logger.info(f"up kv status: {upkv_status}") else: await asyncio.sleep(3) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py index 134fbd5027..1d7630c717 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py @@ -1,4 +1,4 @@ -import pickle +from lightllm.utils.pickle_utils import safe_pickle_dumps import copy from dataclasses import dataclass from collections import defaultdict @@ -125,7 +125,7 @@ def send_readtask_to_decode_node(self, trans_task: NIXLChunckedTransTask): new_trans_task.mem_indexes = None self.nixl_agent.send_notif( remote_agent.agent_name, - pickle.dumps(new_trans_task), + safe_pickle_dumps(new_trans_task), ) return @@ -165,7 +165,7 @@ def read_blocks_paged( [trans_task.nixl_dst_page_index], src_handle, [trans_task.nixl_src_page_index], - pickle.dumps(notify_obj), + safe_pickle_dumps(notify_obj), ) if not handle: raise RuntimeError(f"make_prepped_xfer failed for task: {trans_task.to_str()}") diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index 063ce5c6a9..0ad4018215 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -5,7 +5,7 @@ import torch.multiprocessing as mp import collections import queue -import pickle +from lightllm.utils.pickle_utils import safe_pickle_loads from typing import List, Dict, Union, Deque, Optional from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager import MemoryManager @@ -211,7 +211,7 @@ def update_task_status_loop( for _, _notify_list in notifies_dict.items(): for notify in _notify_list: try: - notify_obj = pickle.loads(notify) + notify_obj = safe_pickle_loads(notify) except: notify_obj = None diff --git a/lightllm/utils/pickle_utils.py b/lightllm/utils/pickle_utils.py new file mode 100644 index 0000000000..cf974bb5d5 --- /dev/null +++ b/lightllm/utils/pickle_utils.py @@ -0,0 +1,60 @@ +import pickle +import io + +class RestrictedUnpickler(pickle.Unpickler): + """ + A restricted unpickler that only allows a whitelist of classes to be deserialized. + This mitigates the Remote Code Execution (RCE) risk associated with pickle.loads(). + """ + ALLOWED_MODULES = { + "lightllm.server.config_server.api_http", + "lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl.up_status", + "lightllm.server.router.model_infer.mode_backend.pd_nixl.decode_node_impl.decode_trans_process", + "lightllm.server.router.model_infer.mode_backend.pd_nixl.prefill_node_impl.prefill_trans_process", + "lightllm.server.httpserver.manager", + "lightllm.server.httpserver_for_pd_master.manager", + "lightllm.server.router.model_infer.infer_batch", + "lightllm.server.router.model_infer.mode_backend.pd_nixl.nixl_kv_transporter", + "lightllm.server.router.model_infer.mode_backend.pd_nixl.decode_node_impl.up_status", + "enum", + "builtins", + "collections", + "numpy", + "torch", + } + + ALLOWED_CLASSES = { + ("builtins", "list"), + ("builtins", "dict"), + ("builtins", "tuple"), + ("builtins", "set"), + ("builtins", "int"), + ("builtins", "float"), + ("builtins", "str"), + ("builtins", "bool"), + ("builtins", "NoneType"), + ("builtins", "getattr"), + ("collections", "deque"), + ("enum", "Enum"), + } + + def find_class(self, module, name): + # Only allow specific modules or classes + if module in self.ALLOWED_MODULES or (module, name) in self.ALLOWED_CLASSES: + return super().find_class(module, name) + + # Also allow classes starting with lightllm + if module.startswith("lightllm."): + return super().find_class(module, name) + + raise pickle.UnpicklingError(f"Global '{module}.{name}' is forbidden") + +def safe_pickle_loads(data): + """Safely loads a pickled object using a restricted unpickler.""" + if data is None: + return None + return RestrictedUnpickler(io.BytesIO(data)).load() + +def safe_pickle_dumps(obj): + """Dumps an object using pickle.HIGHEST_PROTOCOL.""" + return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) From a9cf27232e532ae5ec830850f9c1991b50acd995 Mon Sep 17 00:00:00 2001 From: Rin Date: Thu, 5 Mar 2026 09:48:36 +0700 Subject: [PATCH 2/2] Update lightllm/utils/pickle_utils.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- lightllm/utils/pickle_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lightllm/utils/pickle_utils.py b/lightllm/utils/pickle_utils.py index cf974bb5d5..77dd845895 100644 --- a/lightllm/utils/pickle_utils.py +++ b/lightllm/utils/pickle_utils.py @@ -42,11 +42,7 @@ def find_class(self, module, name): # Only allow specific modules or classes if module in self.ALLOWED_MODULES or (module, name) in self.ALLOWED_CLASSES: return super().find_class(module, name) - - # Also allow classes starting with lightllm - if module.startswith("lightllm."): - return super().find_class(module, name) - + raise pickle.UnpicklingError(f"Global '{module}.{name}' is forbidden") def safe_pickle_loads(data):