 #
 import hashlib
 import pickle
+from collections import defaultdict
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 import torch
 from vllm.config import VllmConfig
@@ -98,7 +99,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.request_block_infos: dict[str, RequestBlockInfo] = {}
         # dump tasks record request -> block -> list[task]
         self.dump_tasks: dict[str, dict[str, List[Task]]] = {}
-        self.layerwise_load_tasks: dict[str, dict[str, tuple[Task, Task]]] = {}
+        self.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict(dict)
         self.is_mla = self._vllm_config.model_config.is_deepseek_mla
         self.num_layers = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config
@@ -261,62 +262,43 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
 
         self.layerwise_load_tasks.clear()
         self.current_layer = 0
+        need_wait_tasks = []
         for request in metadata.requests:
             if not request.load_blocks:
                 continue
 
             storage_block_ids = [block[0] for block in request.load_blocks]
             vllm_block_ids = [block[1] for block in request.load_blocks]
-            blocks_len = len(storage_block_ids)
             self._load_req_to_blocks.setdefault(request.request_id, set()).update(
                 vllm_block_ids
             )
+            is_load_async = request.load_async
+            total_offsets = []
+            total_tensors = []
+            storage_block_ids = storage_block_ids * (1 if self.is_mla else 2)
             for layer_name, kv_layer in self.kv_caches.items():
                 tensors, offsets = self.get_tensor_and_offset_layerwise(
                     vllm_block_ids, kv_layer, layer_name
                 )
-                k_task_id = self.connector.load(
-                    storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
-                )
-                v_task_id = None
-                if not self.is_mla:
-                    v_task_id = self.connector.load(
-                        storage_block_ids,
-                        offsets[blocks_len:],
-                        tensors[blocks_len:],
-                    )
-                if request.request_id not in self.layerwise_load_tasks:
-                    self.layerwise_load_tasks[request.request_id] = {}
-                self.layerwise_load_tasks[request.request_id][layer_name] = (
-                    k_task_id,
-                    v_task_id,
+                if self.use_layerwise and not is_load_async:
+                    task_id = self.connector.load(storage_block_ids, offsets, tensors)
+                    self.layerwise_load_tasks[request.request_id][layer_name] = task_id
+                    continue
+                else:
+                    total_offsets.extend(offsets)
+                    total_tensors.extend(tensors)
+            if total_offsets and total_tensors:
+                storage_block_ids = storage_block_ids * self.num_layers
+                task_id = self.connector.load(
+                    storage_block_ids, total_offsets, total_tensors
                 )
-
-            if request.load_async and request.request_id in self.layerwise_load_tasks:
-                for _, (k_task, v_task) in self.layerwise_load_tasks[
-                    request.request_id
-                ].items():
-                    if request.request_id not in self._need_load_reqs:
-                        self._need_load_reqs[request.request_id] = []
-                    self._need_load_reqs[request.request_id].append(k_task)
-                    if not self.is_mla:
-                        self._need_load_reqs[request.request_id].append(v_task)
-                self.layerwise_load_tasks.pop(request.request_id)
-                continue
-
-            if (
-                not self.use_layerwise
-                and request.request_id in self.layerwise_load_tasks
-            ):
-                for _, (k_task, v_task) in self.layerwise_load_tasks[
-                    request.request_id
-                ].items():
-                    if self.connector.wait(k_task) != 0:
-                        self._load_failed_reqs.add(request.request_id)
-                        break
-                    if v_task and self.connector.wait(v_task) != 0:
-                        self._load_failed_reqs.add(request.request_id)
-                        break
+                if is_load_async:
+                    self._need_load_reqs[request.request_id] = task_id
+                else:
+                    need_wait_tasks.append((request.request_id, task_id))
+        for req_id, task_id in need_wait_tasks:
+            if self.connector.wait(task_id) != 0:
+                self._load_failed_reqs.add(req_id)
 
     def wait_for_layer_load(self, layer_name: str) -> None:
         """
@@ -340,20 +322,13 @@ def wait_for_layer_load(self, layer_name: str) -> None:
         for request_id, layer_to_task in self.layerwise_load_tasks.items():
             if request_id in self._load_failed_reqs:
                 continue
-            k_task, v_task = layer_to_task[layer_name]
-            if self.connector.wait(k_task) != 0:
+            task = layer_to_task[layer_name]
+            if self.connector.wait(task) != 0:
                 self._load_failed_reqs.add(request_id)
                 logger.error(
                     f"Failed to load block for request {request_id} on layer {layer_name}"
                 )
                 continue
-            if not self.is_mla:
-                if self.connector.wait(v_task) != 0:
-                    self._load_failed_reqs.add(request_id)
-                    logger.error(
-                        f"Failed to load block for request {request_id} on layer {layer_name}"
-                    )
-                    continue
             logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.")
 
     def save_kv_layer(
@@ -437,6 +412,8 @@ def wait_for_save(self) -> Optional[dict[str, list[str]]]:
         """
         if hasattr(self, "kv_role") and self.kv_role == "kv_consumer":
             return
+        if self.is_mla and self.rank != 0:
+            return
         # request id -> succeed dumped blocks
         success_dumped_blocks: dict[str, list[str]] = {}
 
@@ -455,36 +432,34 @@ def wait_for_tasks():
             self.dump_tasks.clear()
             return success_dumped_blocks if success_dumped_blocks else None
 
+        req_to_dump_blocks: dict[str, list[str]] = {}
+        need_dump_tasks: dict[str, Task] = {}
         for request in metadata.requests:
             if not request.dump_blocks:
                 continue
 
             storage_block_ids = [block[0] for block in request.dump_blocks]
             vllm_block_ids = [block[1] for block in request.dump_blocks]
-            blocks_len = len(storage_block_ids)
+            req_to_dump_blocks[request.request_id] = storage_block_ids
+            total_offsets = []
+            total_tensors = []
+            total_block_ids = (
+                storage_block_ids * (1 if self.is_mla else 2) * self.num_layers
+            )
             for layer_name, kv_layer in self.kv_caches.items():
                 tensors, offsets = self.get_tensor_and_offset_layerwise(
                     vllm_block_ids, kv_layer, layer_name
                 )
-                for block_id, offset, tensor in zip(
-                    storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
-                ):
-                    task = self.connector.dump([block_id], [offset], [tensor])
-                    self.dump_tasks.setdefault(request.request_id, {}).setdefault(
-                        block_id, []
-                    ).append(task)
-                if not self.is_mla:
-                    for block_id, offset, tensor in zip(
-                        storage_block_ids,
-                        offsets[blocks_len:],
-                        tensors[blocks_len:],
-                    ):
-                        task = self.connector.dump([block_id], [offset], [tensor])
-                        self.dump_tasks.setdefault(request.request_id, {}).setdefault(
-                            block_id, []
-                        ).append(task)
-        wait_for_tasks()
-        self.dump_tasks.clear()
+                total_offsets.extend(offsets)
+                total_tensors.extend(tensors)
+            task_id = self.connector.dump(total_block_ids, total_offsets, total_tensors)
+            need_dump_tasks[request.request_id] = task_id
+
+        for req_id, task_id in need_dump_tasks.items():
+            if self.connector.wait(task_id) != 0:
+                logger.error(f"Failed to dump blocks for req {req_id}")
+            else:
+                success_dumped_blocks[req_id] = req_to_dump_blocks[req_id]
         return success_dumped_blocks if success_dumped_blocks else None
 
     def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: