@@ -106,7 +106,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
106106 )
107107 self .element_size = vllm_config .model_config .dtype .itemsize
108108 self .kv_role = vllm_config .kv_transfer_config .kv_role
109- self ._need_load_reqs : dict [str , Union [list [int ], list [ Task ] ]] = {}
109+ self ._need_load_reqs : dict [str , Union [list [int ], Task ]] = {}
110110 self ._load_failed_reqs : set [str ] = set ()
111111 self ._load_req_to_blocks : dict [str , set [int ]] = {}
112112 self .num_head = vllm_config .model_config .get_num_kv_heads (
@@ -466,32 +466,23 @@ def wait_for_tasks():
466466 def get_finished (self , finished_req_ids : set [str ]) -> tuple [set [str ], set [str ]]:
467467 """Get the finished recving and sending requests."""
468468 done_recving : set [str ] = set ()
469- for req_id , tasks in self ._need_load_reqs .items ():
469+ for req_id , task in self ._need_load_reqs .items ():
470470 if req_id in self ._load_failed_reqs :
471471 done_recving .add (req_id )
472472 continue
473- unfinished_tasks = []
474- for task in tasks :
475- ret , finish = self .connector .check (task )
476- if ret != 0 :
477- logger .error (
478- f"Task { task } failed, check return { ret } for request { req_id } "
479- )
480- self ._load_failed_reqs .add (req_id )
481- break
482- if not finish :
483- unfinished_tasks .append (task )
484- continue
485- wret = self .connector .wait (task )
486- if wret != 0 :
487- logger .error (
488- f"Task { task } failed, wait return { wret } for request { req_id } "
489- )
490- self ._load_failed_reqs .add (req_id )
491- break
492- if unfinished_tasks :
493- self ._need_load_reqs [req_id ] = unfinished_tasks
473+ ret , finish = self .connector .check (task )
474+ if ret != 0 :
475+ logger .error (
476+ f"Task { task } failed, check return { ret } for request { req_id } "
477+ )
478+ self ._load_failed_reqs .add (req_id )
479+ elif not finish :
494480 continue
481+ elif (wret := self .connector .wait (task )) != 0 :
482+ logger .error (
483+ f"Task { task } failed, wait return { wret } for request { req_id } "
484+ )
485+ self ._load_failed_reqs .add (req_id )
495486 done_recving .add (req_id )
496487
497488 # remove the finished requests
0 commit comments