diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 806125bd027..bf88cb4447b 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -40,9 +40,11 @@ import fastdeploy.metrics.trace as tracing from fastdeploy.engine.register_manager import RegisterManager from fastdeploy.engine.request import ( + CompletionOutput, ControlRequest, ControlResponse, Request, + RequestMetrics, RequestOutput, RequestStatus, RequestType, @@ -1410,6 +1412,139 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d return responses + def _control_abort_requests(self, control_req: ControlRequest): + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER") + args = control_req.get_args() + abort_all = args.get("abort_all", False) + req_ids = args.get("req_ids", []) + matched_input_ids = set() + now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) + + # Step 1: Determine target request list + if abort_all: + # all requests in running + waiting + target_req_ids = now_reqs + else: + # filter out requests that actually exist + target_req_ids = [] + for rid in req_ids: + if rid in now_reqs: + target_req_ids.append(rid) + matched_input_ids.add(rid) + elif f"{rid}_0" in now_reqs: + target_req_ids.append(f"{rid}_0") + matched_input_ids.add(rid) + + if not target_req_ids: + return {"aborted": [], "not_found": req_ids if not abort_all else []} + + # Step 2: Collect partial results + aborted_info = [] + results = [] + for req_id in target_req_ids: + request = self.resource_manager.requests.get(req_id) + if request is None: + scheduled_req = self.scheduler.requests.get(req_id) + if scheduled_req is None: + continue + request = scheduled_req.raw + + partial_token_ids = list(request.output_token_ids) + + # Construct finished response with partial results + now = time.time() + abort_metrics = RequestMetrics( + arrival_time=request.metrics.arrival_time if request.metrics else now, + inference_start_time=request.metrics.inference_start_time if request.metrics else now, + engine_recv_latest_token_time=now, + engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now, + request_start_time=request.metrics.arrival_time if request.metrics else now, + ) + result = RequestOutput( + request_id=req_id, + finished=True, + outputs=CompletionOutput( + index=0, + send_idx=len(partial_token_ids), + token_ids=[self.data_processor.eos_token_ids[0]], + ), + metrics=abort_metrics, + error_code=200, + error_msg="Aborted", + ) + results.append(result) + aborted_info.append( + { + "request_id": req_id, + "output_token_count": len(partial_token_ids), + } + ) + + # Step 3: Execute abort — add all requests to waiting_abort_req_id_set + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + for req_id in target_req_ids: + self.resource_manager.add_abort_req_ids(req_id) + time.sleep(0.0001) + if self.cfg.scheduler_config.splitwise_role != "prefill": + self._wait_abort_complete(target_req_ids) + + # Add results to scheduler, engine will have a thread calling get_results, + # then cleanup and call send_response to send to client. 
+        # If the client has already disconnected, send_response simply ignores the result
+        if self.cfg.scheduler_config.splitwise_role != "prefill":
+            try:
+                # self.send_response_server.send_response(req_id, [result])
+                self.scheduler.put_results(results)
+            except Exception:
+                pass  # client may have disconnected
+
+        not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else []
+
+        return {"aborted": aborted_info, "not_found": not_found}
+
+    def _wait_abort_complete(self, target_req_ids, stall_timeout=1):
+        """
+        Wait for all abort requests to complete.
+        - Keep monitoring while the remaining set is non-empty, i.e. cleanup has not finished yet
+        - If no progress is made within stall_timeout seconds, force-clean requests stuck in
+          to_be_aborted_req_id_set, reset the progress state, then continue monitoring
+        """
+        target_set = set(target_req_ids)
+        prev_remaining_count = len(target_set)
+        last_progress_time = time.time()
+        remaining = target_set & self.resource_manager.get_reqs_in_aborting()
+        while remaining:
+            remaining = target_set & self.resource_manager.get_reqs_in_aborting()
+            if not remaining:
+                self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned")
+                return
+
+            current_count = len(remaining)
+            if current_count < prev_remaining_count:
+                # progress was made: recycle_abort_task has been called for some requests
+                self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}")
+                last_progress_time = time.time()
+                prev_remaining_count = current_count
+
+            if time.time() - last_progress_time > stall_timeout:
+                # stall timeout: only clean up requests stuck in to_be_aborted (worker hasn't returned -9 yet)
+                stuck = remaining & self.resource_manager.to_be_aborted_req_id_set
+                if stuck:
+                    self.llm_logger.warning(
+                        f"no abort progress for {stall_timeout}s, "
+                        f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)"
+                    )
+                    for req_id in list(stuck):
+                        self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}")
+                        self.resource_manager.recycle_abort_task(req_id)
+                    # reset progress state
+                    last_progress_time = time.time()
+                    prev_remaining_count = current_count - len(stuck)
+                # else: the remaining requests are all in waiting_abort_req_id_set and will be recycled by the normal flow
+
+            time.sleep(0.005)
+
     async def _wait_all_control_responses(self, request_id: str, timeout: int):
         """Wait for control responses from all workers with a global timeout.
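
Note: the dict returned by _control_abort_requests above is what the control response ultimately reports back to the HTTP layer (the router below reads it via data.get("result")). A minimal sketch of its shape, with invented ids and counts:

    # Illustrative only, not part of the patch
    example_result = {
        "aborted": [{"request_id": "req-123_0", "output_token_count": 57}],
        "not_found": ["req-456"],
    }
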
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index b0425d779d1..ced828db4ca 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -282,6 +282,7 @@ def recycle_abort_task(self, request_id):
         del self.requests[request_id]
         del self.req_dict[request_id]
         self.to_be_aborted_req_id_set.remove(request_id)
+        self.update_metrics()

     def _trigger_abort(self, request_id, scheduled_reqs):
         if request_id in self.requests:
@@ -1207,6 +1208,9 @@ def download_bos_features(bos_client, features_urls):
                 return None
             inputs["audio_features"] = result

+    def get_reqs_in_aborting(self):
+        return self.waiting_abort_req_id_set | self.to_be_aborted_req_id_set
+
     def get_available_position(self) -> int:
         position = 0
         while position < self.max_num_seqs:
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 492f308268a..bf3c400042e 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -441,6 +441,25 @@ async def update_weights(request: Request) -> Response:
     return control_response.to_api_json_response()


+@app.post("/v1/abort_requests")
+async def abort_requests(request: Request):
+    body = await request.json()
+    abort_all = body.get("abort_all", False)
+    req_ids = body.get("req_ids", None)
+
+    # Validate parameters
+    if not abort_all and not req_ids:
+        return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"})
+
+    control_request = ControlRequest(
+        request_id=f"control-{uuid.uuid4()}",
+        method="abort_requests",
+        args={"abort_all": abort_all, "req_ids": req_ids or []},
+    )
+    control_response = await app.state.engine_client.run_control_method(control_request)
+    return control_response.to_api_json_response()
+
+
 def wrap_streaming_generator(original_generator: AsyncGenerator):
     """
     Wrap an async generator to release the connection semaphore when the generator is finished.
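
A minimal client-side sketch of the new API-server endpoint (the host/port and the use of the requests library are assumptions, not part of this change):

    import requests

    # Abort two specific requests; send {"abort_all": true} instead to abort everything in flight.
    # Omitting both fields returns HTTP 400, per the validation above.
    resp = requests.post(
        "http://localhost:8000/v1/abort_requests",
        json={"abort_all": False, "req_ids": ["req-1", "req-2"]},
    )
    print(resp.json())  # "result" carries the {"aborted": [...], "not_found": [...]} payload
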
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index ce6d7e97925..1064a80495c 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -462,6 +462,9 @@ async def chat_completion_stream_generator( if res.get("error_msg") is not None and "Recover" in res["error_msg"]: choice.finish_reason = "recover_stop" + if res.get("error_msg") is not None and "Aborted" in res["error_msg"]: + choice.finish_reason = "abort" + inference_start_time[idx] = 0 if request.collect_metrics: @@ -795,6 +798,8 @@ async def _create_chat_completion_choice( if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]: finish_reason = "recover_stop" + if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]: + finish_reason = "abort" return ChatCompletionResponseChoice( index=idx, message=message, diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 3dafb270905..2690390a80d 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -581,6 +581,8 @@ async def completion_stream_generator( output, tool_called[idx], ) + if res.get("error_msg") is not None and "Aborted" in res["error_msg"]: + choices[-1].finish_reason = "abort" inference_start_time[idx] = 0 send_idx = output.get("send_idx") @@ -722,6 +724,8 @@ def request_output_to_completion_response( output, False, ) + if final_res.get("error_msg", None) is not None and "Aborted" in final_res["error_msg"]: + finish_reason = "abort" choice_data = CompletionResponseChoice( token_ids=token_ids, diff --git a/fastdeploy/router/router.py b/fastdeploy/router/router.py index d64542b6ccc..960a64e7f58 100644 --- a/fastdeploy/router/router.py +++ b/fastdeploy/router/router.py @@ -17,8 +17,8 @@ import aiohttp import uvicorn -from fastapi import FastAPI, HTTPException -from fastapi.responses import ORJSONResponse, Response, StreamingResponse +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse from fastdeploy.router.utils import ( InstanceInfo, @@ -503,6 +503,48 @@ async def health_generate(): return Response(status_code=200) +@app.post("/v1/abort_requests") +async def abort_requests(request: Request): + body = await request.json() + prefill_servers = app.state.router.prefill_servers + decode_servers = app.state.router.decode_servers + all_servers = prefill_servers + decode_servers + + async with aiohttp.ClientSession() as session: + tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers] + responses = await asyncio.gather(*tasks, return_exceptions=True) + + # Aggregate results from Node D only + all_aborted = [] + all_not_found = [] + errors = [] + decode_start = len(prefill_servers) + for i, (server, resp) in enumerate(zip(all_servers, responses)): + if i < decode_start: + continue + if isinstance(resp, Exception): + errors.append({"server": server.url(), "error": str(resp)}) + elif resp.status == 200: + data = await resp.json() + result = data.get("result") or {} + all_aborted.extend(result.get("aborted", [])) + all_not_found.extend(result.get("not_found", [])) + else: + errors.append({"server": server.url(), "status": resp.status}) + + return JSONResponse( + content={ + "request_id": f"router-{uuid4()}", + "status": "success" if not errors else "error", + 
"error_message": None if not errors else str(errors), + "result": { + "aborted": all_aborted, + "not_found": list(set(all_not_found)), + }, + } + ) + + def launch_router(router_args: RouterArgs): app.state.router_args = router_args print(f"Starting router with args: {router_args}")