-
Notifications
You must be signed in to change notification settings - Fork 753
[Feature]report PD info to IM #8082
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| import uuid | ||
| from collections.abc import AsyncGenerator | ||
| from contextlib import asynccontextmanager | ||
| import requests | ||
|
|
||
| import uvicorn | ||
| import zmq | ||
|
|
@@ -120,6 +121,51 @@ | |
| MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers | ||
| connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS) | ||
|
|
||
# Held around every read and write of _decode_nodes_register_info below.
_decode_nodes_lock = threading.Lock()
# Latest decode-node snapshot, shaped as { ip: register_info_dict, ... }.
# The background poller replaces its contents; request handlers read it.
_decode_nodes_register_info: dict = {}
|
|
||
|
|
||
def _fetch_decode_node_register_info(ip: str, port: int):
    """
    Fetch the registration record of a single decode node.

    Issues ``GET http://{ip}:{port}/register_info`` with a 3 second timeout.

    Args:
        ip: Host address of the decode node.
        port: HTTP port of the decode node. It is only interpolated into the
            URL, so a numeric string works as well as an int.

    Returns:
        The decoded JSON body when the node answers with HTTP 200, otherwise
        ``None`` (non-200 status, connection failure, timeout or a body that
        is not valid JSON). Failures are logged and never raised.
    """
    url = f"http://{ip}:{port}/register_info"
    try:
        resp = requests.get(url, timeout=3)
        if resp.status_code == 200:
            return resp.json()
        # A node that is reachable but answers with an error would otherwise
        # disappear from the decode list without leaving any trace in the logs.
        api_server_logger.warning(
            f"_fetch_decode_node_register_info {url} unexpected status: {resp.status_code}"
        )
    except Exception as e:
        # Best-effort probe: one unreachable node must not abort the whole poll.
        api_server_logger.error(f"_fetch_decode_node_register_info {url} error: {e}")
    return None
|
|
||
|
|
||
def _poll_decode_nodes():
    """
    Refresh the shared decode-node snapshot forever, once every 5 seconds.

    Each round reads the comma-separated ``D_IP_LIST`` and ``DECODE_PORTS``
    environment variables, queries ``/register_info`` on every listed IP and
    replaces the contents of ``_decode_nodes_register_info`` with the answers
    that came back. When either variable is empty the previous snapshot is
    left untouched. Any error inside a round is logged and the loop goes on.
    """
    while True:
        try:
            decode_ips = [
                part.strip() for part in os.environ.get("D_IP_LIST", "").split(",") if part.strip()
            ]
            decode_ports = [
                part.strip() for part in os.environ.get("DECODE_PORTS", "").split(",") if part.strip()
            ]
            if decode_ips and decode_ports:
                # Only the first configured port is used, for every decode IP.
                port = decode_ports[0]
                fetched: dict = {}
                for node_ip in decode_ips:
                    node_info = _fetch_decode_node_register_info(node_ip, port)
                    if node_info is None:
                        continue
                    fetched[node_ip] = node_info
                # Swap the snapshot under the lock so readers never observe a
                # half-updated mapping.
                with _decode_nodes_lock:
                    _decode_nodes_register_info.clear()
                    _decode_nodes_register_info.update(fetched)
                api_server_logger.debug(f"decode node info update {len(decode_ips)}")
        except Exception as e:
            api_server_logger.error(f"decode node info update error: {e}, {traceback.format_exc()}")
        time.sleep(5)
|
|
||
|
|
||
def launch_decode_node_poller():
    """Start the decode-node polling loop on a daemon thread and log the start."""
    # Daemon thread: the endless polling loop must not keep the process alive.
    poller = threading.Thread(
        target=_poll_decode_nodes,
        name="decode-node-poller",
        daemon=True,
    )
    poller.start()
    api_server_logger.info("launch_decode_node_poller start")
|
|
||
|
|
||
| class StandaloneApplication(BaseApplication): | ||
| def __init__(self, app, options=None): | ||
|
|
@@ -314,6 +360,7 @@ async def lifespan(app: FastAPI): | |
|
|
||
| if llm_engine is not None and not isinstance(llm_engine, AsyncLLM): | ||
| llm_engine.engine.data_processor = engine_client.data_processor | ||
| launch_decode_node_poller() | ||
| yield | ||
| # close zmq | ||
| try: | ||
|
|
@@ -770,6 +817,183 @@ async def check_redundant(request: Request): | |
| return JSONResponse(content, status_code=status_code) | ||
|
|
||
|
|
||
@app.get("/register_info")
def register_info() -> Response:
    """
    Return this server's registration record as a JSON response.

    The body carries the splitwise role, host IP, connector (gRPC) port and
    HTTP port taken from ``cfg.register_info``, plus a ``model_server_id``
    assembled from the ``POD_NAMESPACE``, ``FD_POD_NAME`` and ``HOST_IP``
    environment variables. Unless ``ENABLE_EP_DP_IN_FD`` is set to ``0``, the
    splitwise role and the local data-parallel id are appended to that id.

    Returns:
        A 500 plain-text response when the engine is not loaded, otherwise a
        JSON response with the fields described above.
    """
    global llm_engine
    if llm_engine is None:
        return Response("Engine not loaded", status_code=500)
    engine_cfg = llm_engine.cfg

    ep_dp_enabled = int(os.getenv("ENABLE_EP_DP_IN_FD", "1"))
    role = engine_cfg.scheduler_config.splitwise_role
    rank = str(engine_cfg.parallel_config.local_data_parallel_id)

    # namespace_pod_hostip, optionally followed by _role_rank.
    name_parts = [
        os.getenv("POD_NAMESPACE", "None"),
        os.getenv("FD_POD_NAME", "None"),
        os.getenv("HOST_IP", "None"),
    ]
    if ep_dp_enabled:
        name_parts.extend((role, rank))
    pod_name = "_".join(name_parts)

    registration = engine_cfg.register_info
    payload = json.dumps(
        {
            "splitwise_role": registration["role"],
            "ip": registration["host_ip"],
            "grpc_port": str(registration["connector_port"]),
            "http_port": str(registration["port"]),
            "model_server_id": pod_name,
        },
        ensure_ascii=False,
    )
    return Response(payload, media_type="application/json")
|
|
||
|
|
||
@app.get("/v2/health/ready")
def im_check_health(request: Request):
    """
    Readiness probe for IM, built on top of the regular health check.

    Returns:
        An empty 200 response when ``health(request)`` reports status 200,
        otherwise a 500 JSON response carrying ``error_code`` and
        ``error_msg``.
    """
    health_resp = health(request)
    if health_resp.status_code == 200:
        return Response()
    failure = {"error_code": 1, "error_msg": "APIServer is down"}
    return JSONResponse(status_code=500, content=failure)
|
|
||
|
|
||
def _im_pod_name(splitwise_role, dp_rank, enable_ep_dp):
    """Build the pod identifier reported to IM from the pod environment variables."""
    parts = [
        os.getenv("POD_NAMESPACE", "None"),
        os.getenv("FD_POD_NAME", "None"),
        os.getenv("HOST_IP", "None"),
    ]
    if enable_ep_dp:
        # One pod may host several roles / DP ranks; keep their ids distinct.
        parts += [splitwise_role, dp_rank]
    return "_".join(parts)


def _im_is_master(fed_member_file, host_ip, dp_rank):
    """Return 1 when this host is first in the member file and its DP rank is 0, else 0."""
    if not fed_member_file:
        return 0
    try:
        with open(fed_member_file, "r", encoding="utf-8") as f:
            members = [member.strip() for member in f.read().strip().split(",")]
    except (OSError, ValueError) as e:
        # Stay best-effort (report "not master") but make the failure visible.
        api_server_logger.warning(f"im_report: cannot read FED_MEMBER_FILE {fed_member_file}: {e}")
        return 0
    # dp_rank is a string here, so compare with "0"; comparing it with the
    # integer 0 is always False and would never elect a master.
    return int(members[0] == host_ip and dp_rank == "0")


def _im_free_block_num(engine_holder):
    """Return the engine's available block count, or None when it is not exposed."""
    # NOTE(review): the lifespan code treats AsyncLLM engines separately, so
    # `.engine.resource_manager` may be missing here — confirm against AsyncLLM.
    resource_manager = getattr(getattr(engine_holder, "engine", None), "resource_manager", None)
    if resource_manager is None:
        return None
    return resource_manager.available_block_num()


@app.get("/fastdeploy/server/info")
def im_report() -> Response:
    """
    Report PD-disaggregation information about this API server to IM.

    The response starts from a shallow copy of the engine config's attributes
    and adds identity fields (pod name, model id/version, role, DP rank), the
    block budget and request-filter ranges derived from environment
    variables, the master flag, the free block count and the decode nodes
    currently held in the poller's snapshot.

    Returns:
        A 500 plain-text response when the engine is not loaded, otherwise a
        JSON response. Config values that are not JSON-native are serialized
        through their ``__dict__``, as a list (sets) or as ``str(value)``.
    """
    global llm_engine
    if llm_engine is None:
        return Response("Engine not loaded", status_code=500)
    cfg = llm_engine.cfg

    def process_object(obj):
        # json.dumps fallback for values it cannot encode natively.
        if hasattr(obj, "__dict__"):
            return obj.__dict__
        if isinstance(obj, (set, frozenset)):
            return list(obj)
        return str(obj)

    # Shallow copy so the keys added below never leak into the live config.
    cfg_dict = dict(cfg.__dict__)

    enable_ep_dp = int(os.getenv("ENABLE_EP_DP_IN_FD", "1"))
    splitwise_role = cfg.scheduler_config.splitwise_role
    dp_rank = str(cfg.parallel_config.local_data_parallel_id)
    host_ip = os.getenv("HOST_IP", "None")
    pod_name = _im_pod_name(splitwise_role, dp_rank, enable_ep_dp)

    block_bs = float(os.getenv("BLOCK_BS", 50))
    block_size = int(os.getenv("BLOCK_SIZE", 64))
    max_dec_len = int(os.getenv("MAX_DEC_LEN", default="1024"))
    max_seq_len = int(os.getenv("MAX_SEQ_LEN", 8192))
    enc_dec_block_num = int(os.getenv("ENC_DEC_BLOCK_NUM", 2))
    block_ratio = float(os.getenv("BLOCK_RATIO", 0.75))
    use_filter = os.getenv("USE_FILTER", "False") == "True"
    use_seed_filter = int(os.getenv("NEED_INFER_SEED", "0")) == 1

    # Blocks needed by the longest possible query, rounded up to whole blocks.
    max_query_block_num = (max(max_dec_len, max_seq_len) + block_size - 1) // block_size
    total_block_num = int(block_bs * max_query_block_num)

    cfg_dict["server_perf_id"] = str(uuid.uuid4())
    cfg_dict["pod_name"] = pod_name
    cfg_dict["model_id"] = os.getenv("MODEL_ID", "eb")
    cfg_dict["model_server_type"] = splitwise_role
    cfg_dict["model_version"] = str(os.getenv("MODEL_VERSION", default="None"))
    cfg_dict["model_server_id"] = pod_name
    cfg_dict["dp_rank"] = dp_rank
    cfg_dict["enable_ep_dp"] = enable_ep_dp
    cfg_dict["state"] = "running"
    cfg_dict["blocks"] = int(total_block_num * block_ratio)
    cfg_dict["dec_token_num"] = enc_dec_block_num * block_size
    cfg_dict["block_size"] = block_size
    cfg_dict["use_filter"] = use_filter
    cfg_dict["input_token_num_range"] = [
        int(os.getenv("PULLER_MIN_SEQ_LEN", "-1")),
        int(os.getenv("PULLER_MAX_SEQ_LEN", "-1")),
    ]
    cfg_dict["output_token_num_range"] = [
        int(os.getenv("PULLER_MIN_DEC_LEN", "-1")),
        int(os.getenv("PULLER_MAX_DEC_LEN", "-1")),
    ]
    cfg_dict["topp_range"] = [
        float(os.getenv("TOPP_START_FILTER", "-1.0")),
        float(os.getenv("TOPP_END_FILTER", "-1.0")),
    ]
    cfg_dict["seed_range"] = [-1, 2**63 - 1] if use_seed_filter else [-1, -1]
    cfg_dict["is_stopping"] = "running"
    cfg_dict["is_master"] = _im_is_master(os.getenv("FED_MEMBER_FILE", None), host_ip, dp_rank)
    cfg_dict["container_host_ip"] = host_ip
    cfg_dict["free_block_num"] = _im_free_block_num(llm_engine)

    # Copy the snapshot under the lock; the poller thread rewrites it in place.
    with _decode_nodes_lock:
        connected_decode_list = [
            info for info in _decode_nodes_register_info.values() if info is not None
        ]
    cfg_dict["connected_decode_list"] = connected_decode_list

    result_content = json.dumps(cfg_dict, default=process_object, ensure_ascii=False)
    return Response(result_content, media_type="application/json")
|
|
||
|
|
||
| def launch_api_server() -> None: | ||
| """ | ||
| 启动http服务 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🔴 Bug
`dp_rank` 在上面已经被转成字符串,这里再和整数 `0` 比较,条件永远为 False。配置了 `FED_MEMBER_FILE` 且当前 `HOST_IP` 是成员列表第一个、DP rank 为 0 时,`is_master` 仍会保持 0,IM 侧无法识别 master 节点。
建议修复方式:保留一个整数 rank 用于逻辑判断,只在拼接 `pod_name` 或写入响应时再转字符串。