Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@
import fastdeploy.metrics.trace as tracing
from fastdeploy.engine.register_manager import RegisterManager
from fastdeploy.engine.request import (
CompletionOutput,
ControlRequest,
ControlResponse,
Request,
RequestMetrics,
RequestOutput,
RequestStatus,
RequestType,
Expand Down Expand Up @@ -1410,6 +1412,139 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d

return responses

def _control_abort_requests(self, control_req: ControlRequest):
    """Abort in-flight requests and return their partial results to clients.

    Args:
        control_req: ControlRequest whose args may contain:
            - abort_all (bool): abort every request currently tracked.
            - req_ids (list[str]): specific ids to abort; a raw id ``rid``
              also matches the internally suffixed form ``f"{rid}_0"``.

    Returns:
        dict with ``aborted`` (per-request info dicts) and ``not_found``
        (input ids that matched no tracked request).

    Raises:
        RuntimeError: when the V1 KV-cache scheduler is not enabled.
    """
    if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
        raise RuntimeError("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER")
    args = control_req.get_args()
    abort_all = args.get("abort_all", False)
    req_ids = args.get("req_ids", [])
    matched_input_ids = set()
    # Snapshot of ids known to the resource manager and the scheduler; kept
    # as a set for O(1) membership tests below.
    now_reqs = set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())

    # Step 1: determine the target request list.
    if abort_all:
        # All requests in running + waiting.
        target_req_ids = list(now_reqs)
    else:
        # Keep only ids that actually exist (directly or with the "_0" suffix).
        target_req_ids = []
        for rid in req_ids:
            if rid in now_reqs:
                target_req_ids.append(rid)
                matched_input_ids.add(rid)
            elif f"{rid}_0" in now_reqs:
                target_req_ids.append(f"{rid}_0")
                matched_input_ids.add(rid)

    if not target_req_ids:
        return {"aborted": [], "not_found": req_ids if not abort_all else []}

    # Step 2: collect partial results so clients receive whatever was
    # generated before the abort.
    aborted_info = []
    results = []
    for req_id in target_req_ids:
        request = self.resource_manager.requests.get(req_id)
        if request is None:
            scheduled_req = self.scheduler.requests.get(req_id)
            if scheduled_req is None:
                # Request finished or was recycled between snapshot and now.
                continue
            request = scheduled_req.raw

        partial_token_ids = list(request.output_token_ids)

        # Construct a finished response carrying the partial results.
        now = time.time()
        abort_metrics = RequestMetrics(
            arrival_time=request.metrics.arrival_time if request.metrics else now,
            inference_start_time=request.metrics.inference_start_time if request.metrics else now,
            engine_recv_latest_token_time=now,
            engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now,
            request_start_time=request.metrics.arrival_time if request.metrics else now,
        )
        result = RequestOutput(
            request_id=req_id,
            finished=True,
            outputs=CompletionOutput(
                index=0,
                send_idx=len(partial_token_ids),
                # Emit EOS so downstream token handling terminates normally.
                token_ids=[self.data_processor.eos_token_ids[0]],
            ),
            metrics=abort_metrics,
            error_code=200,
            error_msg="Aborted",
        )
        results.append(result)
        aborted_info.append(
            {
                "request_id": req_id,
                "output_token_count": len(partial_token_ids),
            }
        )

    # Step 3: execute the abort — add all requests to waiting_abort_req_id_set.
    # (The V1-scheduler guard at the top already guarantees this code path.)
    for req_id in target_req_ids:
        self.resource_manager.add_abort_req_ids(req_id)
        time.sleep(0.0001)  # brief yield so the scheduler thread can react
    if self.cfg.scheduler_config.splitwise_role != "prefill":
        self._wait_abort_complete(target_req_ids)

    # Add results to the scheduler; the engine has a thread calling
    # get_results, which then cleans up and calls send_response to reach the
    # client. When the client has disconnected, send_response ignores it.
    if self.cfg.scheduler_config.splitwise_role != "prefill":
        try:
            self.scheduler.put_results(results)
        except Exception:
            # Best-effort delivery: the client may have disconnected.
            self.llm_logger.warning(f"failed to deliver abort results for {len(results)} requests")

    not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else []

    return {"aborted": aborted_info, "not_found": not_found}

def _wait_abort_complete(self, target_req_ids, stall_timeout=1):
"""
Wait for all abort requests to complete.
- Keep monitoring as long as remaining is not empty, which means cleanup is not done yet
- If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set,
reset progress state if any, then continue monitoring
"""
target_set = set(target_req_ids)
prev_remaining_count = len(target_set)
last_progress_time = time.time()
remaining = target_set & self.resource_manager.get_reqs_in_aborting()
while remaining:
remaining = target_set & self.resource_manager.get_reqs_in_aborting()
if not remaining:
self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned")
return

current_count = len(remaining)
if current_count < prev_remaining_count:
# progress made: recycle_abort_task was called
self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}")
last_progress_time = time.time()
prev_remaining_count = current_count

if time.time() - last_progress_time > stall_timeout:
# no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9)
stuck = remaining & self.resource_manager.to_be_aborted_req_id_set
if stuck:
self.llm_logger.warning(
f"no abort progress for {stall_timeout}s, "
f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)"
)
for req_id in list(stuck):
self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}")
self.resource_manager.recycle_abort_task(req_id)
# reset progress state
last_progress_time = time.time()
prev_remaining_count = current_count - len(stuck)
# else: remaining are all in waiting_abort_req_id_set, waiting for natural flow

time.sleep(0.005)

async def _wait_all_control_responses(self, request_id: str, timeout: int):
"""Wait for control responses from all workers with a global timeout.

Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/engine/sched/resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def recycle_abort_task(self, request_id):
del self.requests[request_id]
del self.req_dict[request_id]
self.to_be_aborted_req_id_set.remove(request_id)
self.update_metrics()

def _trigger_abort(self, request_id, scheduled_reqs):
if request_id in self.requests:
Expand Down Expand Up @@ -1207,6 +1208,9 @@ def download_bos_features(bos_client, features_urls):
return None
inputs["audio_features"] = result

def get_reqs_in_aborting(self):
    """Return the ids of all requests currently in the abort pipeline.

    Combines requests still waiting to be aborted with those already
    marked to-be-aborted.
    """
    pending = self.waiting_abort_req_id_set
    in_flight = self.to_be_aborted_req_id_set
    return pending.union(in_flight)

def get_available_position(self) -> int:
position = 0
while position < self.max_num_seqs:
Expand Down
19 changes: 19 additions & 0 deletions fastdeploy/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,25 @@ async def update_weights(request: Request) -> Response:
return control_response.to_api_json_response()


@app.post("/v1/abort_requests")
async def abort_requests(request: Request):
    """Abort in-flight requests via the engine's control channel.

    Expected JSON body:
        abort_all (bool): abort every tracked request when true.
        req_ids (list[str]): specific request ids to abort.

    Returns the control response as JSON, or HTTP 400 when neither
    ``abort_all`` nor a valid ``req_ids`` list is supplied.
    """
    body = await request.json()
    abort_all = body.get("abort_all", False)
    req_ids = body.get("req_ids", None)

    # Parameter validation: require either abort_all or a non-empty id list.
    if not abort_all and not req_ids:
        return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"})
    # Reject non-list req_ids (e.g. a bare string would otherwise be
    # iterated character by character downstream).
    if req_ids is not None and not isinstance(req_ids, list):
        return JSONResponse(status_code=400, content={"error": "req_ids must be a list"})

    control_request = ControlRequest(
        request_id=f"control-{uuid.uuid4()}",
        method="abort_requests",
        args={"abort_all": abort_all, "req_ids": req_ids or []},
    )
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()


def wrap_streaming_generator(original_generator: AsyncGenerator):
"""
Wrap an async generator to release the connection semaphore when the generator is finished.
Expand Down
5 changes: 5 additions & 0 deletions fastdeploy/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ async def chat_completion_stream_generator(
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
choice.finish_reason = "recover_stop"

if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
choice.finish_reason = "abort"

inference_start_time[idx] = 0

if request.collect_metrics:
Expand Down Expand Up @@ -795,6 +798,8 @@ async def _create_chat_completion_choice(
if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]:
finish_reason = "recover_stop"

if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]:
finish_reason = "abort"
return ChatCompletionResponseChoice(
index=idx,
message=message,
Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,8 @@ async def completion_stream_generator(
output,
tool_called[idx],
)
if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
choices[-1].finish_reason = "abort"
inference_start_time[idx] = 0

send_idx = output.get("send_idx")
Expand Down Expand Up @@ -722,6 +724,8 @@ def request_output_to_completion_response(
output,
False,
)
if final_res.get("error_msg", None) is not None and "Aborted" in final_res["error_msg"]:
finish_reason = "abort"

choice_data = CompletionResponseChoice(
token_ids=token_ids,
Expand Down
46 changes: 44 additions & 2 deletions fastdeploy/router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse

from fastdeploy.router.utils import (
InstanceInfo,
Expand Down Expand Up @@ -503,6 +503,48 @@ async def health_generate():
return Response(status_code=200)


@app.post("/v1/abort_requests")
async def abort_requests(request: Request):
    """Fan an abort request out to every prefill and decode server.

    The JSON body is forwarded verbatim to each server's
    ``/v1/abort_requests`` endpoint. Only decode-server (Node D) responses
    are aggregated into the result, but connection failures on any server
    are reported as errors.
    """
    body = await request.json()
    prefill_servers = app.state.router.prefill_servers
    decode_servers = app.state.router.decode_servers
    all_servers = prefill_servers + decode_servers

    all_aborted = []
    all_not_found = []
    errors = []
    decode_start = len(prefill_servers)

    async with aiohttp.ClientSession() as session:
        tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers]
        responses = await asyncio.gather(*tasks, return_exceptions=True)

        # Response bodies must be read while the session is still open;
        # after the session closes, the connections are gone.
        for i, (server, resp) in enumerate(zip(all_servers, responses)):
            if isinstance(resp, Exception):
                # Report failures from every server, including prefill —
                # a failed abort there means the request is still running.
                errors.append({"server": server.url(), "error": str(resp)})
                continue
            if i < decode_start:
                # Prefill responses carry no client-facing output; release
                # the connection without reading the body.
                resp.release()
                continue
            if resp.status == 200:
                data = await resp.json()
                result = data.get("result") or {}
                all_aborted.extend(result.get("aborted", []))
                all_not_found.extend(result.get("not_found", []))
            else:
                errors.append({"server": server.url(), "status": resp.status})
                resp.release()

    return JSONResponse(
        content={
            "request_id": f"router-{uuid4()}",
            "status": "success" if not errors else "error",
            "error_message": None if not errors else str(errors),
            "result": {
                "aborted": all_aborted,
                "not_found": list(set(all_not_found)),
            },
        }
    )


def launch_router(router_args: RouterArgs):
app.state.router_args = router_args
print(f"Starting router with args: {router_args}")
Expand Down
Loading