Commit feb498b

[BugFix] esa & update patch (#350)
1 parent 9e1401b commit feb498b

File tree

4 files changed: +297 -84 lines

ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch

Lines changed: 194 additions & 40 deletions
@@ -1,70 +1,84 @@
-From 8c02671e05ed23d7a0c9dc112f8474b26d579f99 Mon Sep 17 00:00:00 2001
-From: harrisonyhq <harrisonyhq@gmail.com>
-Date: Wed, 5 Nov 2025 00:22:36 -0800
-Subject: [PATCH 3/3] [Patch2] UCM patch for sparsed attention
+From 0431022b90649f7115b89b61aaf2a0f83e896d5a Mon Sep 17 00:00:00 2001
+From: wenxinwang <wangwenxin21@huawei.com>
+Date: Mon, 10 Nov 2025 20:35:47 +0800
+Subject: [PATCH] adapt to deepseek patch
 
 ---
-vllm/attention/layer.py | 43 ++++++++++++++++++
-vllm/v1/core/kv_cache_manager.py | 7 ++-
-vllm/v1/core/sched/output.py | 3 ++
-vllm/v1/core/sched/scheduler.py | 26 ++++++++++-
-vllm/v1/worker/block_table.py | 13 ++++++
-vllm/v1/worker/gpu_model_runner.py | 70 +++++++++++++++++++++++++-----
-vllm/v1/worker/gpu_worker.py | 2 +
-7 files changed, 151 insertions(+), 13 deletions(-)
+vllm/attention/layer.py | 49 ++++++++++++-
+.../kv_transfer/kv_connector/utils.py | 5 ++
+.../v1/shared_storage_connector.py | 7 +-
+vllm/v1/attention/backends/mla/common.py | 10 ++-
+vllm/v1/core/kv_cache_manager.py | 7 +-
+vllm/v1/core/sched/output.py | 3 +
+vllm/v1/core/sched/scheduler.py | 37 +++++++---
+vllm/v1/worker/block_table.py | 13 ++++
+vllm/v1/worker/gpu_model_runner.py | 71 +++++++++++++++----
+vllm/v1/worker/gpu_worker.py | 2 +
+10 files changed, 171 insertions(+), 33 deletions(-)
 
 diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
-index f0ad68b16..d55f3d689 100644
+index f0ad68b16..728ab99fd 100644
 --- a/vllm/attention/layer.py
 +++ b/vllm/attention/layer.py
-@@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+@@ -2,7 +2,6 @@
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Attention layer."""
+from typing import Any, Dict, List, Optional
+-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+@@ -22,6 +21,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.utils import validate_kv_sharing_target
 +from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
 
 
 class Attention(nn.Module):
-@@ -409,9 +410,11 @@ def unified_attention(
+@@ -409,9 +409,10 @@ def unified_attention(
 attn_metadata = attn_metadata[layer_name]
 self = forward_context.no_compile_layers[layer_name]
 kv_cache = self.kv_cache[forward_context.virtual_engine]
 + maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
 output = self.impl.forward(self, query, key, value, kv_cache,
 attn_metadata)
-
+-
 + maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
 maybe_save_kv_layer_to_connector(layer_name, kv_cache)
 return output
 
-@@ -449,6 +452,7 @@ def unified_attention_with_output(
+@@ -449,6 +450,8 @@ def unified_attention_with_output(
 attn_metadata = attn_metadata[layer_name]
 self = forward_context.no_compile_layers[layer_name]
 kv_cache = self.kv_cache[forward_context.virtual_engine]
-+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
++ if not self.use_mla:
++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
 self.impl.forward(self,
 query,
 key,
-@@ -458,6 +462,7 @@ def unified_attention_with_output(
+@@ -457,7 +460,8 @@ def unified_attention_with_output(
+attn_metadata,
 output=output,
 output_scale=output_scale)
-
-+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
+-
++ if not self.use_mla:
++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
 maybe_save_kv_layer_to_connector(layer_name, kv_cache)
 
 
-@@ -479,3 +484,41 @@ direct_register_custom_op(
+@@ -479,3 +483,42 @@ direct_register_custom_op(
 fake_impl=unified_attention_with_output_fake,
 dispatch_key=current_platform.dispatch_key,
 )
 +
-+
 +def maybe_execute_sparse_attention_begin(
 + query: torch.Tensor,
 + key: torch.Tensor,
 + value: torch.Tensor,
 + layer_name: str,
 + forward_context: ForwardContext,
++ phase: Optional[str] = None,
 +):
 + if not has_ucm_sparse():
 + return
@@ -75,7 +89,7 @@ index f0ad68b16..d55f3d689 100644
 + if attn_metadata is None:
 + return
 +
-+ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context)
++ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context, phase)
 +
 +def maybe_execute_sparse_attention_finished(
 + query: torch.Tensor,
@@ -84,6 +98,7 @@ index f0ad68b16..d55f3d689 100644
 + attn_output: torch.Tensor,
 + layer_name: str,
 + forward_context: ForwardContext,
++ phase: Optional[str] = None,
 +):
 + if not has_ucm_sparse():
 + return
@@ -94,8 +109,101 @@ index f0ad68b16..d55f3d689 100644
 + if attn_metadata is None:
 + return
 +
-+ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context)
\ No newline at end of file
++ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context, phase)
+diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
+index b63bf5965..155597c51 100644
+--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
++++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
+@@ -3,6 +3,11 @@
+"""
+KV cache helper for store.
+"""
++from collections import defaultdict
++from collections.abc import Sequence
++from concurrent.futures import CancelledError, Future
++from typing import Optional, cast
++
+import torch
+
+from collections import defaultdict
+diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+index 3c574d065..223106def 100644
+--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
++++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+@@ -2,7 +2,7 @@
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import hashlib
+import os
+-from dataclasses import dataclass
++from dataclasses import dataclass, field
+from typing import TYPE_CHECKING
+
+import safetensors
+@@ -53,10 +53,7 @@ class ReqMeta:
+
+@dataclass
+class SharedStorageConnectorMetadata(KVConnectorMetadata):
+- requests: list[ReqMeta]
+-
+- def __init__(self):
+- self.requests = []
++ requests: list[ReqMeta] = field(default_factory=list)
+
+def add_request(
+self,
+diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
+index f2aaf59a4..b56f62b39 100644
+--- a/vllm/v1/attention/backends/mla/common.py
++++ b/vllm/v1/attention/backends/mla/common.py
+@@ -200,6 +200,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
+MLAAttentionImpl)
+from vllm.attention.backends.utils import get_mla_dims
+from vllm.attention.ops.merge_attn_states import merge_attn_states
++from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.attention.utils.fa_utils import get_flash_attn_version
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+@@ -211,6 +212,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+CommonAttentionMetadata)
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
++from vllm.attention.layer import (maybe_execute_sparse_attention_begin, maybe_execute_sparse_attention_finished)
+
+try:
+from vllm.vllm_flash_attn import flash_attn_varlen_func
+@@ -908,7 +910,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
+output: Optional[torch.Tensor] = None,
+output_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+-
++ forward_context: ForwardContext = get_forward_context()
+assert output is not None, "Output tensor must be provided."
+
+if output_scale is not None:
+@@ -957,10 +959,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
+)
+
+if has_prefill:
++ maybe_execute_sparse_attention_begin(prefill_q, prefill_k_c_normed, prefill_k_pe, layer.layer_name, forward_context, "prefill")
+output[num_decode_tokens:] = self._forward_prefill(
+prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+attn_metadata)
+-
++ maybe_execute_sparse_attention_finished(prefill_q, prefill_k_c_normed, prefill_k_pe, output[num_decode_tokens:], layer.layer_name, forward_context, "prefill")
+if has_decode:
+assert attn_metadata.decode is not None
+decode_q_nope, decode_q_pe = decode_q.split(
+@@ -971,8 +974,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
+decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+# Convert from (N, B, L) to (B, N, L)
+decode_ql_nope = decode_ql_nope.transpose(0, 1)
+-
++ maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, layer.layer_name, forward_context, "decode")
+output[:num_decode_tokens] = self._forward_decode(
+decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
++ maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode")
+
+return output_padded
 diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
 index 6937455e7..bf9aec864 100644
 --- a/vllm/v1/core/kv_cache_manager.py
@@ -136,7 +244,7 @@ index 6937455e7..bf9aec864 100644
 if new_computed_blocks is not None:
 new_computed_block_list = new_computed_blocks.blocks
 diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
-index c94e421c0..f6f170e10 100644
+index c94e421c0..fff0eeb42 100644
 --- a/vllm/v1/core/sched/output.py
 +++ b/vllm/v1/core/sched/output.py
 @@ -157,3 +157,6 @@ class SchedulerOutput:
@@ -146,9 +254,8 @@ index c94e421c0..f6f170e10 100644
 +
 + # modified slots by sparse algorithm
 + req_sparsed_slots: dict[str, int] = None
\ No newline at end of file
 diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
-index 2d4fd4d59..8268c1409 100644
+index 2d4fd4d59..e99a51788 100644
 --- a/vllm/v1/core/sched/scheduler.py
 +++ b/vllm/v1/core/sched/scheduler.py
 @@ -35,6 +35,8 @@ from vllm.v1.request import Request, RequestStatus
@@ -230,7 +337,42 @@ index 2d4fd4d59..8268c1409 100644
 # finished_req_ids is an existing state in the scheduler,
 # instead of being newly scheduled in this step.
 # It contains the request IDs that are finished in between
-@@ -955,6 +975,8 @@ class Scheduler(SchedulerInterface):
+@@ -809,16 +829,12 @@ class Scheduler(SchedulerInterface):
+new_logprobs = None
+new_token_ids = generated_token_ids
+kv_transfer_params = None
++
+if model_runner_output.finished_dumping is not None:
+request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, []))
+is_prefill = request.num_output_tokens == 0
+if is_prefill:
+- if isinstance(self.connector, MultiConnector):
+- for c in self.connector._connectors:
+- if hasattr(c, 'connector') and hasattr(c.connector, 'commit'):
+- c.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
+- else:
+- self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
++ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
+
+# Append generated tokens and check for stop. Note that if
+# a request is still being prefilled, we expect the model runner
+@@ -870,7 +886,6 @@ class Scheduler(SchedulerInterface):
+spec_token_ids[req_index])
+else:
+request.spec_token_ids = spec_token_ids[req_index]
+-
+# Get prompt logprobs for this request.
+prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
+if new_token_ids or pooler_output is not None \
+@@ -897,6 +912,7 @@ class Scheduler(SchedulerInterface):
+
+if not stopped:
+new_running.append(request)
++
+self.running = new_running
+
+# KV Connector: update state for finished KV Transfers.
+@@ -955,6 +971,8 @@ class Scheduler(SchedulerInterface):
 def add_request(self, request: Request) -> None:
 self.waiting.add_request(request)
 self.requests[request.request_id] = request
@@ -239,7 +381,7 @@ index 2d4fd4d59..8268c1409 100644
 if self.log_stats:
 request.record_event(EngineCoreEventType.QUEUED)
 
-@@ -1004,6 +1026,8 @@ class Scheduler(SchedulerInterface):
+@@ -1004,6 +1022,8 @@ class Scheduler(SchedulerInterface):
 
 def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
 assert request.is_finished()
@@ -248,6 +390,14 @@ index 2d4fd4d59..8268c1409 100644
 
 delay_free_blocks, kv_xfer_params = self._connector_finished(request)
 self.encoder_cache_manager.free(request)
+@@ -1155,7 +1175,6 @@ class Scheduler(SchedulerInterface):
+logger.debug("Finished sending KV transfer for request %s", req_id)
+self._free_blocks(self.requests[req_id])
+
+-
+def _update_requests_with_invalid_blocks(
+self, requests: Iterable[Request],
+invalid_block_ids: set[int]) -> tuple[set[str], int]:
 diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
 index 8f4e8d64c..f45e39f5c 100644
 --- a/vllm/v1/worker/block_table.py
@@ -280,7 +430,7 @@ index 8f4e8d64c..f45e39f5c 100644
 for i, block_table in enumerate(self.block_tables):
 block_table.add_row(block_ids[i], row_idx)
 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
-index c3df1d5d2..6341efc70 100644
+index c3df1d5d2..dbf1ea7d7 100644
 --- a/vllm/v1/worker/gpu_model_runner.py
 +++ b/vllm/v1/worker/gpu_model_runner.py
 @@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager
@@ -347,7 +497,7 @@ index c3df1d5d2..6341efc70 100644
 self.input_batch.block_table.append_row(new_block_ids, req_index)
 
 # For the last rank, we don't need to update the token_ids_cpu
-@@ -639,6 +647,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+@@ -639,6 +647,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 if self.uses_mrope:
 self._calc_mrope_positions(scheduler_output)
 
@@ -364,11 +514,10 @@ index c3df1d5d2..6341efc70 100644
 + offset = 0 if req_index == 0 else cu_num_tokens[req_index - 1] # TODO: support MTP
 + if is_sparsed_request:
 + sparsed_positions[offset] = req_sparsed_slots[req_id] - 1
-+
 # Get token indices.
 # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
 # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
-@@ -668,11 +690,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+@@ -668,11 +689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 # block_size.
 block_table_indices = (
 req_indices * block_table.max_num_blocks_per_req +
@@ -382,7 +531,7 @@ index c3df1d5d2..6341efc70 100644
 np.add(
 block_numbers * block_size,
 block_offsets,
-@@ -682,9 +704,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+@@ -682,9 +703,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 self.query_start_loc_np[0] = 0
 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
 
@@ -397,7 +546,7 @@ index c3df1d5d2..6341efc70 100644
 
 # Copy the tensors to the GPU.
 self.input_ids[:total_num_scheduled_tokens].copy_(
-@@ -696,6 +720,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+@@ -696,6 +719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 non_blocking=True)
 else:
 # Common case (1D positions)
@@ -406,15 +555,15 @@ index c3df1d5d2..6341efc70 100644
 self.positions[:total_num_scheduled_tokens].copy_(
 self.positions_cpu[:total_num_scheduled_tokens],
 non_blocking=True)
-@@ -1386,6 +1412,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+@@ -1386,6 +1411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 skip_cuda_graphs=skip_cuda_graphs,
 ):
 self.maybe_setup_kv_connector(scheduler_output)
 + self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata)
 
 model_output = self.model(
 input_ids=input_ids,
-@@ -1395,6 +1422,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+@@ -1395,6 +1421,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 )
 
 finished_dumping = self.maybe_wait_for_kv_save()
@@ -423,7 +572,12 @@ index c3df1d5d2..6341efc70 100644
 finished_sending, finished_recving = (
 self.get_finished_kv_transfers(scheduler_output))
 invalid_block_ids = self.get_block_ids_with_load_errors()
-@@ -1745,6 +1774,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+@@ -1741,10 +1769,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+kv_connector.start_load_kv(get_forward_context())
+
+@staticmethod
+- def maybe_wait_for_kv_save():
++ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]:
 if has_kv_transfer_group():
 return get_kv_transfer_group().wait_for_save()
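For readers skimming the patch, the control flow it installs around every attention call reduces to the sketch below. The names has_ucm_sparse, get_ucm_sparse, attention_begin, attention_finished, and the optional phase argument ("prefill"/"decode" on the MLA path) come from the patch itself; the Stub* classes and the __main__ driver are hypothetical stand-ins added only so the sketch runs on its own, and the real hooks read attn_metadata from vLLM's forward context rather than from a stub object. This is a minimal sketch under those assumptions, not the UCM implementation.

# Hedged sketch of the sparse-attention hook flow added by this patch.
from typing import Optional

import torch


class StubUcmSparse:
    # Stand-in for the object returned by get_ucm_sparse() in the patch.
    def attention_begin(self, q, k, v, layer_name, ctx, phase=None):
        print(f"attention_begin: {layer_name} phase={phase}")

    def attention_finished(self, q, k, v, out, layer_name, ctx, phase=None):
        print(f"attention_finished: {layer_name} phase={phase}")


_SPARSE: Optional[StubUcmSparse] = StubUcmSparse()


def has_ucm_sparse() -> bool:
    # Patch counterpart: ucm.sparse.state.has_ucm_sparse
    return _SPARSE is not None


def get_ucm_sparse() -> StubUcmSparse:
    # Patch counterpart: ucm.sparse.state.get_ucm_sparse
    assert _SPARSE is not None
    return _SPARSE


def maybe_execute_sparse_attention_begin(q, k, v, layer_name, ctx, phase=None):
    # Mirrors the guards visible in the diff: no-op when no sparse backend is
    # registered or when there is no attention metadata for this step.
    if not has_ucm_sparse():
        return
    if getattr(ctx, "attn_metadata", None) is None:
        return
    get_ucm_sparse().attention_begin(q, k, v, layer_name, ctx, phase)


def maybe_execute_sparse_attention_finished(q, k, v, out, layer_name, ctx, phase=None):
    if not has_ucm_sparse():
        return
    if getattr(ctx, "attn_metadata", None) is None:
        return
    get_ucm_sparse().attention_finished(q, k, v, out, layer_name, ctx, phase)


class StubContext:
    # Anything non-None here lets the hooks fire in this sketch.
    attn_metadata = object()


if __name__ == "__main__":
    q = k = v = torch.zeros(1, 8)
    ctx = StubContext()
    maybe_execute_sparse_attention_begin(q, k, v, "layers.0.attn", ctx, "prefill")
    out = q  # placeholder for the attention output
    maybe_execute_sparse_attention_finished(q, k, v, out, "layers.0.attn", ctx, "prefill")

Note the split visible in the diff itself: the MLA backend invokes the hooks separately around its prefill and decode halves with an explicit phase, while the generic unified_attention_with_output path guards the same hooks with "if not self.use_mla" so MLA layers are not instrumented twice.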