1- From 8c02671e05ed23d7a0c9dc112f8474b26d579f99 Mon Sep 17 00:00:00 2001
2- From: harrisonyhq <harrisonyhq@gmail.com>
3- Date: Wed, 5 Nov 2025 00:22:36 -0800
4- Subject: [PATCH 3/3] [Patch2] UCM patch for sparsed attention
1+ From 0431022b90649f7115b89b61aaf2a0f83e896d5a Mon Sep 17 00:00:00 2001
2+ From: wenxinwang <wangwenxin21@huawei.com>
3+ Date: Mon, 10 Nov 2025 20:35:47 +0800
4+ Subject: [PATCH] adapt to deepseek patch
55
66---
7- vllm/attention/layer.py | 43 ++++++++++++++++++
8- vllm/v1/core/kv_cache_manager.py | 7 ++-
9- vllm/v1/core/sched/output.py | 3 ++
10- vllm/v1/core/sched/scheduler.py | 26 ++++++++++-
11- vllm/v1/worker/block_table.py | 13 ++++++
12- vllm/v1/worker/gpu_model_runner.py | 70 +++++++++++++++++++++++++-----
13- vllm/v1/worker/gpu_worker.py | 2 +
14- 7 files changed, 151 insertions(+), 13 deletions(-)
7+ vllm/attention/layer.py | 49 ++++++++++++-
8+ .../kv_transfer/kv_connector/utils.py | 5 ++
9+ .../v1/shared_storage_connector.py | 7 +-
10+ vllm/v1/attention/backends/mla/common.py | 10 ++-
11+ vllm/v1/core/kv_cache_manager.py | 7 +-
12+ vllm/v1/core/sched/output.py | 3 +
13+ vllm/v1/core/sched/scheduler.py | 37 +++++++---
14+ vllm/v1/worker/block_table.py | 13 ++++
15+ vllm/v1/worker/gpu_model_runner.py | 71 +++++++++++++++----
16+ vllm/v1/worker/gpu_worker.py | 2 +
17+ 10 files changed, 171 insertions(+), 33 deletions(-)
1518
1619diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
17- index f0ad68b16..d55f3d689 100644
20+ index f0ad68b16..728ab99fd 100644
1821--- a/vllm/attention/layer.py
1922+++ b/vllm/attention/layer.py
20- @@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
23+ @@ -2,7 +2,6 @@
24+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
25+ """Attention layer."""
26+ from typing import Any, Dict, List, Optional
27+ -
28+ import torch
29+ import torch.nn as nn
30+ import torch.nn.functional as F
31+ @@ -22,6 +21,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
2132 from vllm.platforms import _Backend, current_platform
2233 from vllm.utils import direct_register_custom_op
2334 from vllm.v1.attention.backends.utils import validate_kv_sharing_target
2435+ from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
2536
2637
2738 class Attention(nn.Module):
28- @@ -409,9 +410,11 @@ def unified_attention(
39+ @@ -409,9 +409,10 @@ def unified_attention(
2940 attn_metadata = attn_metadata[layer_name]
3041 self = forward_context.no_compile_layers[layer_name]
3142 kv_cache = self.kv_cache[forward_context.virtual_engine]
3243+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
3344 output = self.impl.forward(self, query, key, value, kv_cache,
3445 attn_metadata)
35-
46+ -
3647+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
3748 maybe_save_kv_layer_to_connector(layer_name, kv_cache)
3849 return output
3950
40- @@ -449,6 +452,7 @@ def unified_attention_with_output(
51+ @@ -449,6 +450,8 @@ def unified_attention_with_output(
4152 attn_metadata = attn_metadata[layer_name]
4253 self = forward_context.no_compile_layers[layer_name]
4354 kv_cache = self.kv_cache[forward_context.virtual_engine]
44- + maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
55+ + if not self.use_mla:
56+ + maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
4557 self.impl.forward(self,
4658 query,
4759 key,
48- @@ -458,6 +462,7 @@ def unified_attention_with_output(
60+ @@ -457,7 +460,8 @@ def unified_attention_with_output(
61+ attn_metadata,
4962 output=output,
5063 output_scale=output_scale)
51-
52- + maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
64+ -
65+ + if not self.use_mla:
66+ + maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
5367 maybe_save_kv_layer_to_connector(layer_name, kv_cache)
5468
5569
56- @@ -479,3 +484,41 @@ direct_register_custom_op(
70+ @@ -479,3 +483,42 @@ direct_register_custom_op(
5771 fake_impl=unified_attention_with_output_fake,
5872 dispatch_key=current_platform.dispatch_key,
5973 )
6074+
61- +
6275+ def maybe_execute_sparse_attention_begin(
6376+ query: torch.Tensor,
6477+ key: torch.Tensor,
6578+ value: torch.Tensor,
6679+ layer_name: str,
6780+ forward_context: ForwardContext,
81+ + phase: Optional[str] = None,
6882+ ):
6983+ if not has_ucm_sparse():
7084+ return
@@ -75,7 +89,7 @@ index f0ad68b16..d55f3d689 100644
7589+ if attn_metadata is None:
7690+ return
7791+
78- + ucm_sparse.attention_begin(query, key, value, layer_name, forward_context)
92+ + ucm_sparse.attention_begin(query, key, value, layer_name, forward_context, phase)
7993+
8094+ def maybe_execute_sparse_attention_finished(
8195+ query: torch.Tensor,
@@ -84,6 +98,7 @@ index f0ad68b16..d55f3d689 100644
8498+ attn_output: torch.Tensor,
8599+ layer_name: str,
86100+ forward_context: ForwardContext,
101+ + phase: Optional[str] = None,
87102+ ):
88103+ if not has_ucm_sparse():
89104+ return
@@ -94,8 +109,101 @@ index f0ad68b16..d55f3d689 100644
94109+ if attn_metadata is None:
95110+ return
96111+
97- + ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context)
98- \ No newline at end of file
112+ + ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context, phase)
113+ diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
114+ index b63bf5965..155597c51 100644
115+ --- a/vllm/distributed/kv_transfer/kv_connector/utils.py
116+ +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
117+ @@ -3,6 +3,11 @@
118+ """
119+ KV cache helper for store.
120+ """
121+ + from collections import defaultdict
122+ + from collections.abc import Sequence
123+ + from concurrent.futures import CancelledError, Future
124+ + from typing import Optional, cast
125+ +
126+ import torch
127+
128+ from collections import defaultdict
129+ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
130+ index 3c574d065..223106def 100644
131+ --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
132+ +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
133+ @@ -2,7 +2,7 @@
134+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
135+ import hashlib
136+ import os
137+ - from dataclasses import dataclass
138+ + from dataclasses import dataclass, field
139+ from typing import TYPE_CHECKING
140+
141+ import safetensors
142+ @@ -53,10 +53,7 @@ class ReqMeta:
143+
144+ @dataclass
145+ class SharedStorageConnectorMetadata(KVConnectorMetadata):
146+ - requests: list[ReqMeta]
147+ -
148+ - def __init__(self):
149+ - self.requests = []
150+ + requests: list[ReqMeta] = field(default_factory=list)
151+
152+ def add_request(
153+ self,
154+ diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
155+ index f2aaf59a4..b56f62b39 100644
156+ --- a/vllm/v1/attention/backends/mla/common.py
157+ +++ b/vllm/v1/attention/backends/mla/common.py
158+ @@ -200,6 +200,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
159+ MLAAttentionImpl)
160+ from vllm.attention.backends.utils import get_mla_dims
161+ from vllm.attention.ops.merge_attn_states import merge_attn_states
162+ + from vllm.forward_context import ForwardContext, get_forward_context
163+ from vllm.attention.utils.fa_utils import get_flash_attn_version
164+ from vllm.logger import init_logger
165+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
166+ @@ -211,6 +212,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
167+ CommonAttentionMetadata)
168+ from vllm.v1.kv_cache_interface import AttentionSpec
169+ from vllm.v1.worker.block_table import BlockTable
170+ + from vllm.attention.layer import (maybe_execute_sparse_attention_begin, maybe_execute_sparse_attention_finished)
171+
172+ try:
173+ from vllm.vllm_flash_attn import flash_attn_varlen_func
174+ @@ -908,7 +910,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
175+ output: Optional[torch.Tensor] = None,
176+ output_scale: Optional[torch.Tensor] = None,
177+ ) -> torch.Tensor:
178+ -
179+ + forward_context: ForwardContext = get_forward_context()
180+ assert output is not None, "Output tensor must be provided."
181+
182+ if output_scale is not None:
183+ @@ -957,10 +959,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
184+ )
185+
186+ if has_prefill:
187+ + maybe_execute_sparse_attention_begin(prefill_q, prefill_k_c_normed, prefill_k_pe, layer.layer_name, forward_context, "prefill")
188+ output[num_decode_tokens:] = self._forward_prefill(
189+ prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
190+ attn_metadata)
191+ -
192+ + maybe_execute_sparse_attention_finished(prefill_q, prefill_k_c_normed, prefill_k_pe, output[num_decode_tokens:], layer.layer_name, forward_context, "prefill")
193+ if has_decode:
194+ assert attn_metadata.decode is not None
195+ decode_q_nope, decode_q_pe = decode_q.split(
196+ @@ -971,8 +974,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
197+ decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
198+ # Convert from (N, B, L) to (B, N, L)
199+ decode_ql_nope = decode_ql_nope.transpose(0, 1)
200+ -
201+ + maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe], dim=-1), decode_ql_nope, decode_q_pe, layer.layer_name, forward_context, "decode")
202+ output[:num_decode_tokens] = self._forward_decode(
203+ decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
204+ + maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe], dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode")
205+
206+ return output_padded
99207diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
100208index 6937455e7..bf9aec864 100644
101209--- a/vllm/v1/core/kv_cache_manager.py
@@ -136,7 +244,7 @@ index 6937455e7..bf9aec864 100644
136244 if new_computed_blocks is not None:
137245 new_computed_block_list = new_computed_blocks.blocks
138246diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
139- index c94e421c0..f6f170e10 100644
247+ index c94e421c0..fff0eeb42 100644
140248--- a/vllm/v1/core/sched/output.py
141249+++ b/vllm/v1/core/sched/output.py
142250@@ -157,3 +157,6 @@ class SchedulerOutput:
@@ -146,9 +254,8 @@ index c94e421c0..f6f170e10 100644
146254+
147255+ # modified slots by sparse algorithm
148256+ req_sparsed_slots: dict[str, int] = None
149- \ No newline at end of file
150257diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
151- index 2d4fd4d59..8268c1409 100644
258+ index 2d4fd4d59..e99a51788 100644
152259--- a/vllm/v1/core/sched/scheduler.py
153260+++ b/vllm/v1/core/sched/scheduler.py
154261@@ -35,6 +35,8 @@ from vllm.v1.request import Request, RequestStatus
@@ -230,7 +337,42 @@ index 2d4fd4d59..8268c1409 100644
230337 # finished_req_ids is an existing state in the scheduler,
231338 # instead of being newly scheduled in this step.
232339 # It contains the request IDs that are finished in between
233- @@ -955,6 +975,8 @@ class Scheduler(SchedulerInterface):
340+ @@ -809,16 +829,12 @@ class Scheduler(SchedulerInterface):
341+ new_logprobs = None
342+ new_token_ids = generated_token_ids
343+ kv_transfer_params = None
344+ +
345+ if model_runner_output.finished_dumping is not None:
346+ request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, []))
347+ is_prefill = request.num_output_tokens == 0
348+ if is_prefill:
349+ - if isinstance(self.connector, MultiConnector):
350+ - for c in self.connector._connectors:
351+ - if hasattr(c, 'connector') and hasattr(c.connector, 'commit'):
352+ - c.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
353+ - else:
354+ - self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
355+ + self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
356+
357+ # Append generated tokens and check for stop. Note that if
358+ # a request is still being prefilled, we expect the model runner
359+ @@ -870,7 +886,6 @@ class Scheduler(SchedulerInterface):
360+ spec_token_ids[req_index])
361+ else:
362+ request.spec_token_ids = spec_token_ids[req_index]
363+ -
364+ # Get prompt logprobs for this request.
365+ prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
366+ if new_token_ids or pooler_output is not None \
367+ @@ -897,6 +912,7 @@ class Scheduler(SchedulerInterface):
368+
369+ if not stopped:
370+ new_running.append(request)
371+ +
372+ self.running = new_running
373+
374+ # KV Connector: update state for finished KV Transfers.
375+ @@ -955,6 +971,8 @@ class Scheduler(SchedulerInterface):
234376 def add_request(self, request: Request) -> None:
235377 self.waiting.add_request(request)
236378 self.requests[request.request_id] = request
@@ -239,7 +381,7 @@ index 2d4fd4d59..8268c1409 100644
239381 if self.log_stats:
240382 request.record_event(EngineCoreEventType.QUEUED)
241383
242- @@ -1004,6 +1026,8 @@ class Scheduler(SchedulerInterface):
384+ @@ -1004,6 +1022,8 @@ class Scheduler(SchedulerInterface):
243385
244386 def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
245387 assert request.is_finished()
@@ -248,6 +390,14 @@ index 2d4fd4d59..8268c1409 100644
248390
249391 delay_free_blocks, kv_xfer_params = self._connector_finished(request)
250392 self.encoder_cache_manager.free(request)
393+ @@ -1155,7 +1175,6 @@ class Scheduler(SchedulerInterface):
394+ logger.debug("Finished sending KV transfer for request %s", req_id)
395+ self._free_blocks(self.requests[req_id])
396+
397+ -
398+ def _update_requests_with_invalid_blocks(
399+ self, requests: Iterable[Request],
400+ invalid_block_ids: set[int]) -> tuple[set[str], int]:
251401diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
252402index 8f4e8d64c..f45e39f5c 100644
253403--- a/vllm/v1/worker/block_table.py
@@ -280,7 +430,7 @@ index 8f4e8d64c..f45e39f5c 100644
280430 for i, block_table in enumerate(self.block_tables):
281431 block_table.add_row(block_ids[i], row_idx)
282432diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
283- index c3df1d5d2..6341efc70 100644
433+ index c3df1d5d2..dbf1ea7d7 100644
284434--- a/vllm/v1/worker/gpu_model_runner.py
285435+++ b/vllm/v1/worker/gpu_model_runner.py
286436@@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager
@@ -347,7 +497,7 @@ index c3df1d5d2..6341efc70 100644
347497 self.input_batch.block_table.append_row(new_block_ids, req_index)
348498
349499 # For the last rank, we don't need to update the token_ids_cpu
350- @@ -639,6 +647,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
500+ @@ -639,6 +647,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
351501 if self.uses_mrope:
352502 self._calc_mrope_positions(scheduler_output)
353503
@@ -364,11 +514,10 @@ index c3df1d5d2..6341efc70 100644
364514+ offset = 0 if req_index == 0 else cu_num_tokens[req_index - 1] # TODO: support MTP
365515+ if is_sparsed_request:
366516+ sparsed_positions[offset] = req_sparsed_slots[req_id] - 1
367- +
368517 # Get token indices.
369518 # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
370519 # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
371- @@ -668,11 +690,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
520+ @@ -668,11 +689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
372521 # block_size.
373522 block_table_indices = (
374523 req_indices * block_table.max_num_blocks_per_req +
@@ -382,7 +531,7 @@ index c3df1d5d2..6341efc70 100644
382531 np.add(
383532 block_numbers * block_size,
384533 block_offsets,
385- @@ -682,9 +704,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
534+ @@ -682,9 +703,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
386535 self.query_start_loc_np[0] = 0
387536 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
388537
@@ -397,7 +546,7 @@ index c3df1d5d2..6341efc70 100644
397546
398547 # Copy the tensors to the GPU.
399548 self.input_ids[:total_num_scheduled_tokens].copy_(
400- @@ -696,6 +720,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
549+ @@ -696,6 +719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
401550 non_blocking=True)
402551 else:
403552 # Common case (1D positions)
@@ -406,15 +555,15 @@ index c3df1d5d2..6341efc70 100644
406555 self.positions[:total_num_scheduled_tokens].copy_(
407556 self.positions_cpu[:total_num_scheduled_tokens],
408557 non_blocking=True)
409- @@ -1386,6 +1412,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
558+ @@ -1386,6 +1411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
410559 skip_cuda_graphs=skip_cuda_graphs,
411560 ):
412561 self.maybe_setup_kv_connector(scheduler_output)
413562+ self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata)
414563
415564 model_output = self.model(
416565 input_ids=input_ids,
417- @@ -1395,6 +1422,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
566+ @@ -1395,6 +1421,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
418567 )
419568
420569 finished_dumping = self.maybe_wait_for_kv_save()
@@ -423,7 +572,12 @@ index c3df1d5d2..6341efc70 100644
423572 finished_sending, finished_recving = (
424573 self.get_finished_kv_transfers(scheduler_output))
425574 invalid_block_ids = self.get_block_ids_with_load_errors()
426- @@ -1745,6 +1774,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
575+ @@ -1741,10 +1769,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
576+ kv_connector.start_load_kv(get_forward_context())
577+
578+ @staticmethod
579+ - def maybe_wait_for_kv_save():
580+ + def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]:
427581 if has_kv_transfer_group():
428582 return get_kv_transfer_group().wait_for_save()
429583
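
For reference, below is a minimal sketch of the hook protocol that the patched maybe_execute_sparse_attention_begin / maybe_execute_sparse_attention_finished helpers in vllm/attention/layer.py call into. Only the hook signatures, the optional phase argument, and the "prefill"/"decode" phase values used on the MLA path are taken from the patch above; the NoOpUcmSparse class name is a hypothetical placeholder for whatever object ucm.sparse.state.get_ucm_sparse() actually returns, not part of the ucm package.

from typing import Optional

import torch


class NoOpUcmSparse:
    # Hypothetical stand-in for the object returned by ucm.sparse.state.get_ucm_sparse().
    # It only mirrors the call signatures used by the patch; a real sparse method would
    # perform KV-block selection and bookkeeping instead of passing.

    def attention_begin(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        layer_name: str,
        forward_context,              # vllm.forward_context.ForwardContext
        phase: Optional[str] = None,  # "prefill" / "decode" on the MLA path, None elsewhere
    ) -> None:
        # Called immediately before the attention kernel runs for this layer.
        pass

    def attention_finished(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_output: torch.Tensor,
        layer_name: str,
        forward_context,
        phase: Optional[str] = None,
    ) -> None:
        # Called immediately after the kernel writes attn_output for this layer.
        pass

Once such an object is registered so that has_ucm_sparse() returns True, every non-MLA attention layer invokes these hooks around its forward pass, and the MLA backend invokes them separately for the prefill and decode portions of the batch with the corresponding phase string.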