Skip to content

Commit c94b793

Browse files
authored
[Fix] Fix iteration bug for async load task (#357)
1 parent de63b7c commit c94b793

File tree

1 file changed

+14
-23
lines changed

1 file changed

+14
-23
lines changed

ucm/integration/vllm/uc_connector.py

Lines changed: 14 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
106106
)
107107
self.element_size = vllm_config.model_config.dtype.itemsize
108108
self.kv_role = vllm_config.kv_transfer_config.kv_role
109-
self._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
109+
self._need_load_reqs: dict[str, Union[list[int], Task]] = {}
110110
self._load_failed_reqs: set[str] = set()
111111
self._load_req_to_blocks: dict[str, set[int]] = {}
112112
self.num_head = vllm_config.model_config.get_num_kv_heads(
@@ -466,32 +466,23 @@ def wait_for_tasks():
466466
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
467467
"""Get the finished recving and sending requests."""
468468
done_recving: set[str] = set()
469-
for req_id, tasks in self._need_load_reqs.items():
469+
for req_id, task in self._need_load_reqs.items():
470470
if req_id in self._load_failed_reqs:
471471
done_recving.add(req_id)
472472
continue
473-
unfinished_tasks = []
474-
for task in tasks:
475-
ret, finish = self.connector.check(task)
476-
if ret != 0:
477-
logger.error(
478-
f"Task {task} failed, check return {ret} for request {req_id}"
479-
)
480-
self._load_failed_reqs.add(req_id)
481-
break
482-
if not finish:
483-
unfinished_tasks.append(task)
484-
continue
485-
wret = self.connector.wait(task)
486-
if wret != 0:
487-
logger.error(
488-
f"Task {task} failed, wait return {wret} for request {req_id}"
489-
)
490-
self._load_failed_reqs.add(req_id)
491-
break
492-
if unfinished_tasks:
493-
self._need_load_reqs[req_id] = unfinished_tasks
473+
ret, finish = self.connector.check(task)
474+
if ret != 0:
475+
logger.error(
476+
f"Task {task} failed, check return {ret} for request {req_id}"
477+
)
478+
self._load_failed_reqs.add(req_id)
479+
elif not finish:
494480
continue
481+
elif (wret := self.connector.wait(task)) != 0:
482+
logger.error(
483+
f"Task {task} failed, wait return {wret} for request {req_id}"
484+
)
485+
self._load_failed_reqs.add(req_id)
495486
done_recving.add(req_id)
496487

497488
# remove the finished requests

0 commit comments

Comments
 (0)