Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 224 additions & 0 deletions fastdeploy/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import uuid
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import requests

import uvicorn
import zmq
Expand Down Expand Up @@ -120,6 +121,51 @@
MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers
connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS)

# { ip: register_info_dict, ... }
_decode_nodes_lock = threading.Lock()
_decode_nodes_register_info: dict = {}


def _fetch_decode_node_register_info(ip: str, port: int):
"""get /register_info"""
url = f"http://{ip}:{port}/register_info"
try:
resp = requests.get(url, timeout=3)
if resp.status_code == 200:
return resp.json()
except Exception as e:
api_server_logger.error(f"_fetch_decode_node_register_info {url} error: {e}")
return None


def _poll_decode_nodes():
"""get decode node info and save"""
while True:
try:
ip_list_env = os.environ.get("D_IP_LIST", "")
port_list_env = os.environ.get("DECODE_PORTS", "")
ip_list = [ip.strip() for ip in ip_list_env.split(",") if ip.strip()]
port_list = [item.strip() for item in port_list_env.split(",") if item.strip()]
if len(ip_list) > 0 and len(port_list) > 0:
new_info: dict = {}
for ip in ip_list:
register_info = _fetch_decode_node_register_info(ip, port_list[0])
if register_info is not None:
new_info[ip] = register_info
with _decode_nodes_lock:
_decode_nodes_register_info.clear()
_decode_nodes_register_info.update(new_info)
api_server_logger.debug(f"decode node info update {len(ip_list)}")
except Exception as e:
api_server_logger.error(f"decode node info update error: {e}, {traceback.format_exc()}")
time.sleep(5)


def launch_decode_node_poller():
t = threading.Thread(target=_poll_decode_nodes, daemon=True, name="decode-node-poller")
t.start()
api_server_logger.info("launch_decode_node_poller start")


class StandaloneApplication(BaseApplication):
def __init__(self, app, options=None):
Expand Down Expand Up @@ -314,6 +360,7 @@ async def lifespan(app: FastAPI):

if llm_engine is not None and not isinstance(llm_engine, AsyncLLM):
llm_engine.engine.data_processor = engine_client.data_processor
launch_decode_node_poller()
yield
# close zmq
try:
Expand Down Expand Up @@ -770,6 +817,183 @@ async def check_redundant(request: Request):
return JSONResponse(content, status_code=status_code)


@app.get("/register_info")
def register_info() -> Response:
"""
Get the current register_info.
"""
global llm_engine
if llm_engine is None:
return Response("Engine not loaded", status_code=500)
cfg = llm_engine.cfg

enable_ep_dp = int(os.getenv("ENABLE_EP_DP_IN_FD", "1"))
# splitwise_role = os.getenv("SPLITWISE_ROLE", "mixed")
# dp_rank = int(os.getenv("DP_RANK", "0"))
splitwise_role = cfg.scheduler_config.splitwise_role
dp_rank = str(cfg.parallel_config.local_data_parallel_id)
if enable_ep_dp:
pod_name = (
os.getenv("POD_NAMESPACE", "None")
+ "_"
+ os.getenv("FD_POD_NAME", "None")
+ "_"
+ os.getenv("HOST_IP", "None")
+ "_"
+ splitwise_role
+ "_"
+ dp_rank
)
else:
pod_name = (
os.getenv("POD_NAMESPACE", "None")
+ "_"
+ os.getenv("FD_POD_NAME", "None")
+ "_"
+ os.getenv("HOST_IP", "None")
)

reg = cfg.register_info
result_dict = {
"splitwise_role": reg["role"],
"ip": reg["host_ip"],
"grpc_port": str(reg["connector_port"]),
"http_port": str(reg["port"]),
"model_server_id": pod_name,
}

result_content = json.dumps(result_dict, ensure_ascii=False)
return Response(result_content, media_type="application/json")


@app.get("/v2/health/ready")
def im_check_health(request: Request):
"""
IM check health
"""
resp = health(request)
error_info = {}
if resp.status_code != 200:
error_info["error_code"] = 1
error_info["error_msg"] = "APIServer is down"
return JSONResponse(status_code=500, content=error_info)
return Response()


@app.get("/fastdeploy/server/info")
def im_report() -> Response:
"""
IM Get PD disaggregation info of the API server.
"""
global llm_engine
if llm_engine is None:
return Response("Engine not loaded", status_code=500)
cfg = llm_engine.cfg

def process_object(obj):
if hasattr(obj, "__dict__"):
return obj.__dict__
if isinstance(obj, (set, frozenset)):
return list(obj)
return str(obj)

cfg_dict = {k: v for k, v in cfg.__dict__.items()}

enable_ep_dp = int(os.getenv("ENABLE_EP_DP_IN_FD", "1"))
# splitwise_role = os.getenv("SPLITWISE_ROLE", "mixed")
# dp_rank = int(os.getenv("DP_RANK", "0"))
splitwise_role = cfg.scheduler_config.splitwise_role
dp_rank = str(cfg.parallel_config.local_data_parallel_id)
if enable_ep_dp:
pod_name = (
os.getenv("POD_NAMESPACE", "None")
+ "_"
+ os.getenv("FD_POD_NAME", "None")
+ "_"
+ os.getenv("HOST_IP", "None")
+ "_"
+ splitwise_role
+ "_"
+ dp_rank
)
else:
pod_name = (
os.getenv("POD_NAMESPACE", "None")
+ "_"
+ os.getenv("FD_POD_NAME", "None")
+ "_"
+ os.getenv("HOST_IP", "None")
)

block_bs = float(os.getenv("BLOCK_BS", 50))
block_size = int(os.getenv("BLOCK_SIZE", 64))
max_dec_len = int(os.getenv("MAX_DEC_LEN", default="1024"))
max_seq_len = int(os.getenv("MAX_SEQ_LEN", 8192))
enc_dec_block_num = int(os.getenv("ENC_DEC_BLOCK_NUM", 2))
block_ratio = float(os.getenv("BLOCK_RATIO", 0.75))
use_filter = os.getenv("USE_FILTER", "False") == "True"
min_input_token_num = int(os.getenv("PULLER_MIN_SEQ_LEN", "-1"))
max_input_token_num = int(os.getenv("PULLER_MAX_SEQ_LEN", "-1"))
min_output_token_num = int(os.getenv("PULLER_MIN_DEC_LEN", "-1"))
max_output_token_num = int(os.getenv("PULLER_MAX_DEC_LEN", "-1"))
topp_min = float(os.getenv("TOPP_START_FILTER", "-1.0"))
topp_max = float(os.getenv("TOPP_END_FILTER", "-1.0"))
use_seed_filter = int(os.getenv("NEED_INFER_SEED", "0")) == 1
fed_member_file = os.getenv("FED_MEMBER_FILE", None)

max_query_block_num = (max(max_dec_len, max_seq_len) + block_size - 1) // block_size
total_block_num = int(block_bs * max_query_block_num)
max_block_num = int(total_block_num * block_ratio)
dec_token_num = enc_dec_block_num * block_size
input_token_num_range = [min_input_token_num, max_input_token_num]
output_token_num_range = [min_output_token_num, max_output_token_num]
topp_range = [topp_min, topp_max]
seed_range = [-1, -1]
if use_seed_filter:
seed_range = [-1, 2 ** 63 - 1]
is_master = 0
if fed_member_file:
try:
with open(fed_member_file, 'r') as f:
fed_member_list = f.read().strip().split(',')
if fed_member_list.index(os.getenv("HOST_IP", "None")) == 0 and \
dp_rank == 0:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug dp_rank 在上面已经被转成字符串,这里再和整数 0 比较,条件永远为 False。

配置了 FED_MEMBER_FILE 且当前 HOST_IP 是成员列表第一个、DP rank 为 0 时,is_master 仍会保持 0,IM 侧无法识别 master 节点。

建议修复方式:保留一个整数 rank 用于逻辑判断,只在拼接 pod_name 或写入响应时再转字符串。

dp_rank = cfg.parallel_config.local_data_parallel_id
# pod_name 拼接处使用 str(dp_rank)
if fed_member_list.index(os.getenv("HOST_IP", "None")) == 0 and dp_rank == 0:
    is_master = 1
cfg_dict["dp_rank"] = str(dp_rank)

is_master = 1
except Exception:
pass

cfg_dict["server_perf_id"] = str(uuid.uuid4())
cfg_dict["pod_name"] = pod_name
cfg_dict["model_id"] = os.getenv("MODEL_ID", "eb")
cfg_dict["model_server_type"] = splitwise_role
cfg_dict["model_version"] = str(os.getenv("MODEL_VERSION", default="None"))
cfg_dict["model_server_id"] = pod_name
cfg_dict["dp_rank"] = dp_rank
cfg_dict["enable_ep_dp"] = enable_ep_dp
cfg_dict["state"] = "running"
cfg_dict["blocks"] = max_block_num
cfg_dict["dec_token_num"] = dec_token_num
cfg_dict["block_size"] = block_size
cfg_dict["use_filter"] = use_filter
cfg_dict["input_token_num_range"] = input_token_num_range
cfg_dict["output_token_num_range"] = output_token_num_range
cfg_dict["topp_range"] = topp_range
cfg_dict["seed_range"] = seed_range
cfg_dict["is_stopping"] = "running"
cfg_dict["is_master"] = is_master
cfg_dict["container_host_ip"] = os.getenv("HOST_IP", "None")
cfg_dict["free_block_num"] = llm_engine.engine.resource_manager.available_block_num()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug 这里直接访问 llm_engine.engine.resource_manager,在 FD_ENABLE_ASYNC_LLM=1 时会让新增接口返回 500。

load_engine() 在 async 模式下把全局 llm_engine 设置为 AsyncLLMAsyncLLM 继承的 EngineServiceClient 只在子进程里创建 EngineService,主进程对象没有 .engine 属性。文件里已有生命周期代码也用 not isinstance(llm_engine, AsyncLLM) 区分了同步引擎路径。

建议修复方式:对 AsyncLLM 单独走跨进程状态查询/control API 获取 free_block_num,或在 async 模式下返回明确的不可用值;不要在 API server 主进程直接读取 llm_engine.engine.resource_manager


with _decode_nodes_lock:
connected_decode_list = [
info for info in _decode_nodes_register_info.values() if info is not None
]
cfg_dict["connected_decode_list"] = connected_decode_list

result_content = json.dumps(cfg_dict, default=process_object, ensure_ascii=False)
return Response(result_content, media_type="application/json")


def launch_api_server() -> None:
"""
启动http服务
Expand Down
Loading