Skip to content

Commit a7d051e

Browse files
authored
[Fix] Dump/load all tensors when use_layerwise=False (#351)
* [Fix] Dump/load all tensors * [Fix] Wait for dump tasks after all tasks are sent * [Fix] Simplify the storage_block_ids calculation
1 parent 018c5ef commit a7d051e

File tree

2 files changed

+58
-77
lines changed

2 files changed

+58
-77
lines changed

test/test_uc_connector.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import random
2626
import secrets
2727
import unittest
28+
from collections import defaultdict
2829
from typing import List, Union
2930
from unittest.mock import MagicMock, Mock, patch
3031

@@ -106,12 +107,14 @@ def init_uc(
106107
ucconnector.dump_tasks: dict[str, dict[str, List[Task]]] = {}
107108
ucconnector.total_tp_size = self.total_tp_size
108109
ucconnector._connector_metadata = metadata
109-
ucconnector.layerwise_load_tasks: dict[
110-
str, dict[str, tuple[Task, Task]]
111-
] = {}
110+
ucconnector.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict(
111+
dict
112+
)
112113
ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
113114
ucconnector._load_failed_reqs: set[str] = set()
114115
ucconnector._load_req_to_blocks: dict[str, set[int]] = {}
116+
ucconnector.num_layers = 48
117+
ucconnector.is_mla = False
115118
return ucconnector
116119

117120
def test_get_num_new_matched_tokens_hit_all_on_storage(self):
@@ -508,6 +511,7 @@ def test_wait_for_save_not_layerwise_invalid_para(self):
508511
ucconnector.block_size = self.block_size
509512
ucconnector.use_layerwise = False
510513
ucconnector._connector_metadata = Mock()
514+
ucconnector.is_mla = False
511515

512516
with self.assertRaises(AssertionError):
513517
ucconnector.wait_for_save()
@@ -542,6 +546,7 @@ def mock_wait(task: Task) -> int:
542546
)
543547
forward_context = Mock()
544548
ucconnector.start_load_kv(forward_context)
549+
assert mock_connector.load.call_count == 1
545550

546551
def test_start_load_kv_invalid_para(self):
547552
with patch.object(UnifiedCacheConnectorV1, "__init__", return_value=None):
@@ -559,6 +564,7 @@ def test_start_load_kv_layerwise_success(self):
559564
req_meta1.load_blocks = [
560565
(secrets.token_hex(8), i) for i in range(self.block_number)
561566
]
567+
req_meta1.load_async = False
562568

563569
metadata = UCConnectorV1Metadata()
564570
metadata.requests = [req_meta1]
@@ -575,7 +581,7 @@ def mock_load(
575581
ucconnector = self.init_uc(mock_connector, metadata=metadata)
576582
forward_context = Mock()
577583
ucconnector.start_load_kv(forward_context)
578-
assert mock_connector.load.call_count == 2 * self.num_layers
584+
assert mock_connector.load.call_count == self.num_layers
579585

580586

581587
if __name__ == "__main__":

ucm/integration/vllm/uc_connector.py

Lines changed: 48 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
#
2626
import hashlib
2727
import pickle
28+
from collections import defaultdict
2829
from dataclasses import dataclass, field
2930
from enum import Enum
30-
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union
31+
from typing import TYPE_CHECKING, Any, List, Optional, Union
3132

3233
import torch
3334
from vllm.config import VllmConfig
@@ -98,7 +99,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
9899
self.request_block_infos: dict[str, RequestBlockInfo] = {}
99100
# dump tasks record request -> block -> list[task]
100101
self.dump_tasks: dict[str, dict[str, List[Task]]] = {}
101-
self.layerwise_load_tasks: dict[str, dict[str, tuple[Task, Task]]] = {}
102+
self.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict(dict)
102103
self.is_mla = self._vllm_config.model_config.is_deepseek_mla
103104
self.num_layers = vllm_config.model_config.get_num_layers(
104105
vllm_config.parallel_config
@@ -261,62 +262,43 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
261262

262263
self.layerwise_load_tasks.clear()
263264
self.current_layer = 0
265+
need_wait_tasks = []
264266
for request in metadata.requests:
265267
if not request.load_blocks:
266268
continue
267269

268270
storage_block_ids = [block[0] for block in request.load_blocks]
269271
vllm_block_ids = [block[1] for block in request.load_blocks]
270-
blocks_len = len(storage_block_ids)
271272
self._load_req_to_blocks.setdefault(request.request_id, set()).update(
272273
vllm_block_ids
273274
)
275+
is_load_async = request.load_async
276+
total_offsets = []
277+
total_tensors = []
278+
storage_block_ids = storage_block_ids * (1 if self.is_mla else 2)
274279
for layer_name, kv_layer in self.kv_caches.items():
275280
tensors, offsets = self.get_tensor_and_offset_layerwise(
276281
vllm_block_ids, kv_layer, layer_name
277282
)
278-
k_task_id = self.connector.load(
279-
storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
280-
)
281-
v_task_id = None
282-
if not self.is_mla:
283-
v_task_id = self.connector.load(
284-
storage_block_ids,
285-
offsets[blocks_len:],
286-
tensors[blocks_len:],
287-
)
288-
if request.request_id not in self.layerwise_load_tasks:
289-
self.layerwise_load_tasks[request.request_id] = {}
290-
self.layerwise_load_tasks[request.request_id][layer_name] = (
291-
k_task_id,
292-
v_task_id,
283+
if self.use_layerwise and not is_load_async:
284+
task_id = self.connector.load(storage_block_ids, offsets, tensors)
285+
self.layerwise_load_tasks[request.request_id][layer_name] = task_id
286+
continue
287+
else:
288+
total_offsets.extend(offsets)
289+
total_tensors.extend(tensors)
290+
if total_offsets and total_tensors:
291+
storage_block_ids = storage_block_ids * self.num_layers
292+
task_id = self.connector.load(
293+
storage_block_ids, total_offsets, total_tensors
293294
)
294-
295-
if request.load_async and request.request_id in self.layerwise_load_tasks:
296-
for _, (k_task, v_task) in self.layerwise_load_tasks[
297-
request.request_id
298-
].items():
299-
if request.request_id not in self._need_load_reqs:
300-
self._need_load_reqs[request.request_id] = []
301-
self._need_load_reqs[request.request_id].append(k_task)
302-
if not self.is_mla:
303-
self._need_load_reqs[request.request_id].append(v_task)
304-
self.layerwise_load_tasks.pop(request.request_id)
305-
continue
306-
307-
if (
308-
not self.use_layerwise
309-
and request.request_id in self.layerwise_load_tasks
310-
):
311-
for _, (k_task, v_task) in self.layerwise_load_tasks[
312-
request.request_id
313-
].items():
314-
if self.connector.wait(k_task) != 0:
315-
self._load_failed_reqs.add(request.request_id)
316-
break
317-
if v_task and self.connector.wait(v_task) != 0:
318-
self._load_failed_reqs.add(request.request_id)
319-
break
295+
if is_load_async:
296+
self._need_load_reqs[request.request_id] = task_id
297+
else:
298+
need_wait_tasks.append(task_id)
299+
for task_id in need_wait_tasks:
300+
if self.connector.wait(task_id) != 0:
301+
self._load_failed_reqs.add(request.request_id)
320302

321303
def wait_for_layer_load(self, layer_name: str) -> None:
322304
"""
@@ -340,20 +322,13 @@ def wait_for_layer_load(self, layer_name: str) -> None:
340322
for request_id, layer_to_task in self.layerwise_load_tasks.items():
341323
if request_id in self._load_failed_reqs:
342324
continue
343-
k_task, v_task = layer_to_task[layer_name]
344-
if self.connector.wait(k_task) != 0:
325+
task = layer_to_task[layer_name]
326+
if self.connector.wait(task) != 0:
345327
self._load_failed_reqs.add(request_id)
346328
logger.error(
347329
f"Failed to load block for request {request_id} on layer {layer_name}"
348330
)
349331
continue
350-
if not self.is_mla:
351-
if self.connector.wait(v_task) != 0:
352-
self._load_failed_reqs.add(request_id)
353-
logger.error(
354-
f"Failed to load block for request {request_id} on layer {layer_name}"
355-
)
356-
continue
357332
logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.")
358333

359334
def save_kv_layer(
@@ -437,6 +412,8 @@ def wait_for_save(self) -> Optional[dict[str, list[str]]]:
437412
"""
438413
if hasattr(self, "kv_role") and self.kv_role == "kv_consumer":
439414
return
415+
if self.is_mla and self.rank != 0:
416+
return
440417
# request id -> succeed dumped blocks
441418
success_dumped_blocks: dict[str, list[str]] = {}
442419

@@ -455,36 +432,34 @@ def wait_for_tasks():
455432
self.dump_tasks.clear()
456433
return success_dumped_blocks if success_dumped_blocks else None
457434

435+
req_to_dump_blocks: dict[str, list[str]] = {}
436+
need_dump_tasks: dict[str, Task] = {}
458437
for request in metadata.requests:
459438
if not request.dump_blocks:
460439
continue
461440

462441
storage_block_ids = [block[0] for block in request.dump_blocks]
463442
vllm_block_ids = [block[1] for block in request.dump_blocks]
464-
blocks_len = len(storage_block_ids)
443+
req_to_dump_blocks[request.request_id] = storage_block_ids
444+
total_offsets = []
445+
total_tensors = []
446+
total_block_ids = (
447+
storage_block_ids * (1 if self.is_mla else 2) * self.num_layers
448+
)
465449
for layer_name, kv_layer in self.kv_caches.items():
466450
tensors, offsets = self.get_tensor_and_offset_layerwise(
467451
vllm_block_ids, kv_layer, layer_name
468452
)
469-
for block_id, offset, tensor in zip(
470-
storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
471-
):
472-
task = self.connector.dump([block_id], [offset], [tensor])
473-
self.dump_tasks.setdefault(request.request_id, {}).setdefault(
474-
block_id, []
475-
).append(task)
476-
if not self.is_mla:
477-
for block_id, offset, tensor in zip(
478-
storage_block_ids,
479-
offsets[blocks_len:],
480-
tensors[blocks_len:],
481-
):
482-
task = self.connector.dump([block_id], [offset], [tensor])
483-
self.dump_tasks.setdefault(request.request_id, {}).setdefault(
484-
block_id, []
485-
).append(task)
486-
wait_for_tasks()
487-
self.dump_tasks.clear()
453+
total_offsets.extend(offsets)
454+
total_tensors.extend(tensors)
455+
task_id = self.connector.dump(total_block_ids, total_offsets, total_tensors)
456+
need_dump_tasks[request.request_id] = task_id
457+
458+
for req_id, task_id in need_dump_tasks.items():
459+
if self.connector.wait(task_id) != 0:
460+
logger.error(f"Failed to dump blocks for req {request.request_id}")
461+
else:
462+
success_dumped_blocks[req_id] = req_to_dump_blocks[req_id]
488463
return success_dumped_blocks if success_dumped_blocks else None
489464

490465
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:

0 commit comments

Comments
 (0)