diff --git a/lightllm/distributed/pynccl.py b/lightllm/distributed/pynccl.py index b96e0d1bab..6998caee34 100644 --- a/lightllm/distributed/pynccl.py +++ b/lightllm/distributed/pynccl.py @@ -26,6 +26,7 @@ from typing import Optional, Union, Dict, Deque, Tuple, Any from collections import deque import logging +import ujson as json # ===================== import region ===================== import torch @@ -85,7 +86,7 @@ def send_obj(self, obj: Any): """Send an object to a destination rank.""" self.expire_data() key = f"send_to/{self.dest_id}/{self.send_dst_counter}" - self.store.set(key, pickle.dumps(obj)) + self.store.set(key, json.dumps(obj).encode()) self.send_dst_counter += 1 self.entries.append((key, time.time())) @@ -102,7 +103,7 @@ def expire_data(self): def recv_obj(self) -> Any: """Receive an object from a source rank.""" - obj = pickle.loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}")) + obj = json.loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}").decode()) self.recv_src_counter += 1 return obj diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 230da5b369..ae1584d5d6 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -307,7 +307,7 @@ async def register_and_keep_alive(websocket: WebSocket): while True: # 等待接收消息,设置超时为10秒 data = await websocket.receive_bytes() - obj = pickle.loads(data) + obj = json.loads(data.decode()) await g_objs.httpserver_manager.put_to_handle_queue(obj) except (WebSocketDisconnect, Exception, RuntimeError) as e: @@ -328,7 +328,7 @@ async def kv_move_status(websocket: WebSocket): while True: # 等待接收消息,设置超时为10秒 data = await websocket.receive_bytes() - upkv_status = pickle.loads(data) + upkv_status = json.loads(data.decode()) logger.info(f"received upkv_status {upkv_status} from {(client_ip, client_port)}") await g_objs.httpserver_manager.update_req_status(upkv_status) except (WebSocketDisconnect, Exception, RuntimeError) as e: diff --git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index c5505acda4..bb3e40a4a5 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -3,6 +3,7 @@ import base64 import pickle import setproctitle +import ujson as json import multiprocessing as mp from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Query from threading import Lock @@ -53,7 +54,7 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.accept() client_ip, client_port = websocket.client logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}") - registered_pd_master_obj: PD_Master_Obj = pickle.loads(await websocket.receive_bytes()) + registered_pd_master_obj: PD_Master_Obj = json.loads((await websocket.receive_bytes()).decode()) logger.info(f"received registered_pd_master_obj {registered_pd_master_obj}") with registered_pd_master_obj_lock: registered_pd_master_objs[registered_pd_master_obj.node_id] = registered_pd_master_obj @@ -75,7 +76,7 @@ async def websocket_endpoint(websocket: WebSocket): @app.get("/registered_objects") async def get_registered_objects(): with registered_pd_master_obj_lock: - serialized_data = pickle.dumps(registered_pd_master_objs) + serialized_data = json.dumps(registered_pd_master_objs).encode() base64_encoded = base64.b64encode(serialized_data).decode("utf-8") return {"data": base64_encoded} diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index d51f88cdda..f5f2a313cd 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -10,6 +10,7 @@ import hashlib import datetime import pickle +import ujson as json from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -308,7 +309,7 @@ async def generate( f"nixl prefill node upload group_req_id {group_request_id} prompt ids len : {len(prompt_ids)}" ) await nixl_pd_upload_websocket.send( - pickle.dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids)) + json.dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids)).encode() ) try: await asyncio.wait_for(nixl_pd_event.wait(), timeout=80) @@ -317,7 +318,7 @@ async def generate( raise Exception(f"group_req_id {group_request_id} wait nixl_pd_event time out") decode_node_info: NIXLDecodeNodeInfo = nixl_pd_event.decode_node_info - sampling_params.nixl_params.set(pickle.dumps(decode_node_info)) + sampling_params.nixl_params.set(json.dumps(decode_node_info).encode()) if decode_node_info.ready_kv_len == len(prompt_ids) - 1: # 如果 decode 节点的 ready_kv_len 和 prefill encode 的 len(prompt ids) -1 相等,说明不需要进行 prefill diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index a646d4f4cc..e127aac474 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -103,7 +103,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。 while True: recv_bytes = await websocket.recv() - obj = pickle.loads(recv_bytes) + obj = json.loads(recv_bytes.decode()) if obj[0] == ObjType.REQ: prompt, sampling_params, multimodal_params = obj[1] group_req_id = sampling_params.group_request_id @@ -183,7 +183,7 @@ async def _get_pd_master_objs(args: StartArgs) -> Optional[Dict[int, PD_Master_O response = await client.get(uri) if response.status_code == 200: base64data = response.json()["data"] - id_to_pd_master_obj = pickle.loads(base64.b64decode(base64data)) + id_to_pd_master_obj = json.loads(base64.b64decode(base64data).decode()) return id_to_pd_master_obj else: logger.error(f"get pd_master_objs error {response.status_code}") @@ -231,7 +231,7 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket: Clien if handle_list: load_info: dict = _get_load_info() - await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info))) + await websocket.send(json.dumps((ObjType.TOKEN_PACKS, handle_list, load_info)).encode()) # 获取节点负载信息 diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index fda1387f3b..9e186b6019 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -173,7 +173,9 @@ async def fetch_stream( sampling_params.move_kv_to_decode_node.initialize(decode_node_dict if old_max_new_tokens != 1 else None) sampling_params.suggested_dp_index = -1 - await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) + await p_node.websocket.send_bytes( + json.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params))).encode() + ) while True: await req_status.wait_to_ready() @@ -210,7 +212,7 @@ async def fetch_stream( sampling_params.suggested_dp_index = upkv_status.dp_index await d_node.websocket.send_bytes( - pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))) + json.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))).encode() ) while True: @@ -244,7 +246,9 @@ async def fetch_nixl_stream( old_max_new_tokens = sampling_params.max_new_tokens sampling_params.max_new_tokens = 1 - await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) + await p_node.websocket.send_bytes( + json.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params))).encode() + ) try: await asyncio.wait_for(nixl_np_up_prompt_ids_event.wait(), timeout=60) @@ -260,7 +264,7 @@ async def fetch_nixl_stream( sampling_params.max_new_tokens = old_max_new_tokens await d_node.websocket.send_bytes( - pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))) + json.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))).encode() ) try: @@ -272,9 +276,9 @@ async def fetch_nixl_stream( # 将 decode 节点上报的当前请求使用的decode节点的信息下发给 p 节点,这样 p 节点才知道将 kv 传输给那个 d 节点。 upkv_status: NixlUpKVStatus = up_status_event.upkv_status nixl_params: bytes = upkv_status.nixl_params - decode_node_info: NIXLDecodeNodeInfo = pickle.loads(nixl_params) + decode_node_info: NIXLDecodeNodeInfo = json.loads(nixl_params.decode()) await p_node.websocket.send_bytes( - pickle.dumps((ObjType.NIXL_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info)) + json.dumps((ObjType.NIXL_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info)).encode() ) first_token_gen = False @@ -392,12 +396,12 @@ async def abort( pass try: - await p_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id))) + await p_node.websocket.send_bytes(json.dumps((ObjType.ABORT, group_request_id)).encode()) except: pass try: - await d_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id))) + await d_node.websocket.send_bytes(json.dumps((ObjType.ABORT, group_request_id)).encode()) except: pass diff --git a/lightllm/server/httpserver_for_pd_master/register_loop.py b/lightllm/server/httpserver_for_pd_master/register_loop.py index 7e30ffc470..afc94c3850 100644 --- a/lightllm/server/httpserver_for_pd_master/register_loop.py +++ b/lightllm/server/httpserver_for_pd_master/register_loop.py @@ -1,7 +1,7 @@ import asyncio -import pickle import websockets import socket +import ujson as json from lightllm.utils.net_utils import get_hostname_ip from lightllm.utils.log_utils import init_logger from lightllm.server.httpserver_for_pd_master.manager import HttpServerManagerForPDMaster @@ -31,7 +31,7 @@ async def register_loop(manager: HttpServerManagerForPDMaster): node_id=manager.args.pd_node_id, host_ip_port=f"{manager.host_ip}:{manager.args.port}" ) - await websocket.send(pickle.dumps(pd_master_obj)) + await websocket.send(json.dumps(pd_master_obj).encode()) logger.info(f"Sent registration pd_master obj: {pd_master_obj}") while True: diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 1b4a1ca5cb..f6e29a5cb7 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -3,7 +3,7 @@ import torch.distributed as dist import numpy as np import collections -import pickle +import ujson as json from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Callable, Any @@ -283,7 +283,7 @@ def __init__( # nixl decode node information if self.shm_param.nixl_params.data_len > 0: - self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get()) + self.nixl_decode_node: NIXLDecodeNodeInfo = json.loads(self.shm_param.nixl_params.get().decode()) else: self.nixl_decode_node: NIXLDecodeNodeInfo = None diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py index 833ffecc89..86b93997fb 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py @@ -1,5 +1,5 @@ import time -import json +import ujson as json import asyncio import threading import websockets @@ -91,7 +91,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj): if pd_master_obj.node_id in self.id_to_handle_queue: task_queue = self.id_to_handle_queue[pd_master_obj.node_id] upkv_status: UpKVStatus = await task_queue.get() - await websocket.send(pickle.dumps(upkv_status)) + await websocket.send(json.dumps(upkv_status).encode()) logger.info(f"up status: {upkv_status}") else: await asyncio.sleep(3) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index b04cbb900a..2ee93b10af 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -5,7 +5,7 @@ import torch.multiprocessing as mp import collections import queue -import pickle +import ujson as json from typing import List, Dict, Union, Deque, Optional from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager import MemoryManager @@ -193,7 +193,7 @@ def dispatch_task_loop(self): up_status = NixlUpKVStatus( group_request_id=task.request_id, pd_master_node_id=task.pd_master_node_id, - nixl_params=pickle.dumps(decode_node_info), + nixl_params=json.dumps(decode_node_info).encode(), ) self.up_status_in_queue.put(up_status) @@ -220,7 +220,7 @@ def accept_peer_task_loop( for remote_agent_name, _notify_list in notifies_dict.items(): for notify in _notify_list: try: - notify_obj = pickle.loads(notify) + notify_obj = json.loads(notify.decode()) except: notify_obj = None @@ -236,7 +236,7 @@ def accept_peer_task_loop( try: self.transporter.send_notify_to_prefill_node( prefill_agent_name=remote_agent_name, - notify=pickle.dumps(remote_trans_task.createRetObj()), + notify=json.dumps(remote_trans_task.createRetObj()).encode(), ) except BaseException as e: logger.error(f"send notify to prefill node failed: {str(e)}") diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py index f79fb4ea2c..85e8c597ad 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py @@ -88,7 +88,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj): if pd_master_obj.node_id in self.id_to_handle_queue: task_queue = self.id_to_handle_queue[pd_master_obj.node_id] upkv_status: Union[UpKVStatus, NixlUpKVStatus] = await task_queue.get() - await websocket.send(pickle.dumps(upkv_status)) + await websocket.send(json.dumps(upkv_status).encode()) logger.info(f"up kv status: {upkv_status}") else: await asyncio.sleep(3) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py index 134fbd5027..f7cc2fe8d0 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py @@ -1,5 +1,6 @@ import pickle import copy +import ujson as json from dataclasses import dataclass from collections import defaultdict from typing import Dict, List, Any, Optional, Tuple @@ -125,7 +126,7 @@ def send_readtask_to_decode_node(self, trans_task: NIXLChunckedTransTask): new_trans_task.mem_indexes = None self.nixl_agent.send_notif( remote_agent.agent_name, - pickle.dumps(new_trans_task), + json.dumps(new_trans_task).encode(), ) return @@ -165,7 +166,7 @@ def read_blocks_paged( [trans_task.nixl_dst_page_index], src_handle, [trans_task.nixl_src_page_index], - pickle.dumps(notify_obj), + json.dumps(notify_obj).encode(), ) if not handle: raise RuntimeError(f"make_prepped_xfer failed for task: {trans_task.to_str()}") diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index 063ce5c6a9..f11e3d86b7 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -5,7 +5,7 @@ import torch.multiprocessing as mp import collections import queue -import pickle +import ujson as json from typing import List, Dict, Union, Deque, Optional from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager import MemoryManager @@ -211,7 +211,7 @@ def update_task_status_loop( for _, _notify_list in notifies_dict.items(): for notify in _notify_list: try: - notify_obj = pickle.loads(notify) + notify_obj = json.loads(notify.decode()) except: notify_obj = None