diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 2c108dda0af..14b837e7156 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -23,6 +23,7 @@ import uuid from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +import requests import uvicorn import zmq @@ -120,6 +121,51 @@ MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS) +# { ip: register_info_dict, ... } +_decode_nodes_lock = threading.Lock() +_decode_nodes_register_info: dict = {} + + +def _fetch_decode_node_register_info(ip: str, port: int): + """get /register_info""" + url = f"http://{ip}:{port}/register_info" + try: + resp = requests.get(url, timeout=3) + if resp.status_code == 200: + return resp.json() + except Exception as e: + api_server_logger.error(f"_fetch_decode_node_register_info {url} error: {e}") + return None + + +def _poll_decode_nodes(): + """get decode node info and save""" + while True: + try: + ip_list_env = os.environ.get("D_IP_LIST", "") + port_list_env = os.environ.get("DECODE_PORTS", "") + ip_list = [ip.strip() for ip in ip_list_env.split(",") if ip.strip()] + port_list = [item.strip() for item in port_list_env.split(",") if item.strip()] + if len(ip_list) > 0 and len(port_list) > 0: + new_info: dict = {} + for ip in ip_list: + register_info = _fetch_decode_node_register_info(ip, port_list[0]) + if register_info is not None: + new_info[ip] = register_info + with _decode_nodes_lock: + _decode_nodes_register_info.clear() + _decode_nodes_register_info.update(new_info) + api_server_logger.debug(f"decode node info update {len(ip_list)}") + except Exception as e: + api_server_logger.error(f"decode node info update error: {e}, {traceback.format_exc()}") + time.sleep(5) + + +def launch_decode_node_poller(): + t = threading.Thread(target=_poll_decode_nodes, daemon=True, name="decode-node-poller") + t.start() + api_server_logger.info("launch_decode_node_poller start") + class StandaloneApplication(BaseApplication): def __init__(self, app, options=None): @@ -314,6 +360,7 @@ async def lifespan(app: FastAPI): if llm_engine is not None and not isinstance(llm_engine, AsyncLLM): llm_engine.engine.data_processor = engine_client.data_processor + launch_decode_node_poller() yield # close zmq try: @@ -770,6 +817,183 @@ async def check_redundant(request: Request): return JSONResponse(content, status_code=status_code) +@app.get("/register_info") +def register_info() -> Response: + """ + Get the current register_info. + """ + global llm_engine + if llm_engine is None: + return Response("Engine not loaded", status_code=500) + cfg = llm_engine.cfg + + enable_ep_dp = int(os.getenv("ENABLE_EP_DP_IN_FD", "1")) + # splitwise_role = os.getenv("SPLITWISE_ROLE", "mixed") + # dp_rank = int(os.getenv("DP_RANK", "0")) + splitwise_role = cfg.scheduler_config.splitwise_role + dp_rank = str(cfg.parallel_config.local_data_parallel_id) + if enable_ep_dp: + pod_name = ( + os.getenv("POD_NAMESPACE", "None") + + "_" + + os.getenv("FD_POD_NAME", "None") + + "_" + + os.getenv("HOST_IP", "None") + + "_" + + splitwise_role + + "_" + + dp_rank + ) + else: + pod_name = ( + os.getenv("POD_NAMESPACE", "None") + + "_" + + os.getenv("FD_POD_NAME", "None") + + "_" + + os.getenv("HOST_IP", "None") + ) + + reg = cfg.register_info + result_dict = { + "splitwise_role": reg["role"], + "ip": reg["host_ip"], + "grpc_port": str(reg["connector_port"]), + "http_port": str(reg["port"]), + "model_server_id": pod_name, + } + + result_content = json.dumps(result_dict, ensure_ascii=False) + return Response(result_content, media_type="application/json") + + +@app.get("/v2/health/ready") +def im_check_health(request: Request): + """ + IM check health + """ + resp = health(request) + error_info = {} + if resp.status_code != 200: + error_info["error_code"] = 1 + error_info["error_msg"] = "APIServer is down" + return JSONResponse(status_code=500, content=error_info) + return Response() + + +@app.get("/fastdeploy/server/info") +def im_report() -> Response: + """ + IM Get PD disaggregation info of the API server. + """ + global llm_engine + if llm_engine is None: + return Response("Engine not loaded", status_code=500) + cfg = llm_engine.cfg + + def process_object(obj): + if hasattr(obj, "__dict__"): + return obj.__dict__ + if isinstance(obj, (set, frozenset)): + return list(obj) + return str(obj) + + cfg_dict = {k: v for k, v in cfg.__dict__.items()} + + enable_ep_dp = int(os.getenv("ENABLE_EP_DP_IN_FD", "1")) + # splitwise_role = os.getenv("SPLITWISE_ROLE", "mixed") + # dp_rank = int(os.getenv("DP_RANK", "0")) + splitwise_role = cfg.scheduler_config.splitwise_role + dp_rank = str(cfg.parallel_config.local_data_parallel_id) + if enable_ep_dp: + pod_name = ( + os.getenv("POD_NAMESPACE", "None") + + "_" + + os.getenv("FD_POD_NAME", "None") + + "_" + + os.getenv("HOST_IP", "None") + + "_" + + splitwise_role + + "_" + + dp_rank + ) + else: + pod_name = ( + os.getenv("POD_NAMESPACE", "None") + + "_" + + os.getenv("FD_POD_NAME", "None") + + "_" + + os.getenv("HOST_IP", "None") + ) + + block_bs = float(os.getenv("BLOCK_BS", 50)) + block_size = int(os.getenv("BLOCK_SIZE", 64)) + max_dec_len = int(os.getenv("MAX_DEC_LEN", default="1024")) + max_seq_len = int(os.getenv("MAX_SEQ_LEN", 8192)) + enc_dec_block_num = int(os.getenv("ENC_DEC_BLOCK_NUM", 2)) + block_ratio = float(os.getenv("BLOCK_RATIO", 0.75)) + use_filter = os.getenv("USE_FILTER", "False") == "True" + min_input_token_num = int(os.getenv("PULLER_MIN_SEQ_LEN", "-1")) + max_input_token_num = int(os.getenv("PULLER_MAX_SEQ_LEN", "-1")) + min_output_token_num = int(os.getenv("PULLER_MIN_DEC_LEN", "-1")) + max_output_token_num = int(os.getenv("PULLER_MAX_DEC_LEN", "-1")) + topp_min = float(os.getenv("TOPP_START_FILTER", "-1.0")) + topp_max = float(os.getenv("TOPP_END_FILTER", "-1.0")) + use_seed_filter = int(os.getenv("NEED_INFER_SEED", "0")) == 1 + fed_member_file = os.getenv("FED_MEMBER_FILE", None) + + max_query_block_num = (max(max_dec_len, max_seq_len) + block_size - 1) // block_size + total_block_num = int(block_bs * max_query_block_num) + max_block_num = int(total_block_num * block_ratio) + dec_token_num = enc_dec_block_num * block_size + input_token_num_range = [min_input_token_num, max_input_token_num] + output_token_num_range = [min_output_token_num, max_output_token_num] + topp_range = [topp_min, topp_max] + seed_range = [-1, -1] + if use_seed_filter: + seed_range = [-1, 2 ** 63 - 1] + is_master = 0 + if fed_member_file: + try: + with open(fed_member_file, 'r') as f: + fed_member_list = f.read().strip().split(',') + if fed_member_list.index(os.getenv("HOST_IP", "None")) == 0 and \ + dp_rank == 0: + is_master = 1 + except Exception: + pass + + cfg_dict["server_perf_id"] = str(uuid.uuid4()) + cfg_dict["pod_name"] = pod_name + cfg_dict["model_id"] = os.getenv("MODEL_ID", "eb") + cfg_dict["model_server_type"] = splitwise_role + cfg_dict["model_version"] = str(os.getenv("MODEL_VERSION", default="None")) + cfg_dict["model_server_id"] = pod_name + cfg_dict["dp_rank"] = dp_rank + cfg_dict["enable_ep_dp"] = enable_ep_dp + cfg_dict["state"] = "running" + cfg_dict["blocks"] = max_block_num + cfg_dict["dec_token_num"] = dec_token_num + cfg_dict["block_size"] = block_size + cfg_dict["use_filter"] = use_filter + cfg_dict["input_token_num_range"] = input_token_num_range + cfg_dict["output_token_num_range"] = output_token_num_range + cfg_dict["topp_range"] = topp_range + cfg_dict["seed_range"] = seed_range + cfg_dict["is_stopping"] = "running" + cfg_dict["is_master"] = is_master + cfg_dict["container_host_ip"] = os.getenv("HOST_IP", "None") + cfg_dict["free_block_num"] = llm_engine.engine.resource_manager.available_block_num() + + with _decode_nodes_lock: + connected_decode_list = [ + info for info in _decode_nodes_register_info.values() if info is not None + ] + cfg_dict["connected_decode_list"] = connected_decode_list + + result_content = json.dumps(cfg_dict, default=process_object, ensure_ascii=False) + return Response(result_content, media_type="application/json") + + def launch_api_server() -> None: """ 启动http服务