diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 9275d70fd86a..be3e0176bff1 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -375,9 +375,9 @@ def __init__(
     def forward(
         self,
         layer: AttentionLayer,
-        hidden_states_or_cq: torch.Tensor,
-        kv_c_normed: torch.Tensor,
-        k_pe: torch.Tensor,
+        query: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
+        key: torch.Tensor,
+        value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: T,
         output: torch.Tensor | None = None,
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 37f9a4b383ce..2a7b06cf0162 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -747,12 +747,14 @@ def __init__(
 
     def forward(
         self,
-        q: torch.Tensor,
+        q_nope: torch.Tensor,
+        q_pe: torch.Tensor,
         kv_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
         output_shape: torch.Size | None = None,
     ) -> torch.Tensor:
         if self.calculate_kv_scales:
+            q = torch.cat((q_nope, q_pe), dim=-1)
             torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
 
         if self.use_direct_call:
@@ -763,10 +765,13 @@ def forward(
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
 
             if self.attn_backend.accept_output_buffer:
-                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
+                assert output_shape is not None, "output_shape must be provided."
+                output = torch.empty(
+                    output_shape, dtype=q_nope.dtype, device=q_nope.device
+                )
                 self.impl.forward(
                     self,
-                    q,
+                    (q_nope, q_pe),
                     kv_c_normed,
                     k_pe,
                     self_kv_cache,
@@ -776,13 +781,22 @@ def forward(
                 )
                 return output
             else:
                 return self.impl.forward(
-                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
+                    self,
+                    (q_nope, q_pe),
+                    kv_c_normed,
+                    k_pe,
+                    self_kv_cache,
+                    attn_metadata,
                 )
         else:
             if self.attn_backend.accept_output_buffer:
-                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
+                assert output_shape is not None, "output_shape must be provided."
+                output = torch.empty(
+                    output_shape, dtype=q_nope.dtype, device=q_nope.device
+                )
                 torch.ops.vllm.unified_mla_attention_with_output(
-                    q,
+                    q_nope,
+                    q_pe,
                     kv_c_normed,
                     k_pe,
                     output,
@@ -790,8 +804,17 @@ def forward(
                 )
                 return output
             else:
+                if self.calculate_kv_scales:
+                    forward_context = get_forward_context()
+                    attn_metadata = forward_context.attn_metadata
+                    if isinstance(attn_metadata, dict):
+                        attn_metadata = attn_metadata[self.layer_name]
+                    if getattr(attn_metadata, "enable_kv_scales_calculation", False):
+                        self.calc_kv_scales((q_nope, q_pe), kv_c_normed, k_pe)
+
                 return torch.ops.vllm.unified_mla_attention(
-                    q,
+                    q_nope,
+                    q_pe,
                     kv_c_normed,
                     k_pe,
                     self.layer_name,
@@ -802,7 +825,10 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
         self.impl.process_weights_after_loading(act_dtype)
 
     def calc_kv_scales(
-        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
+        self,
+        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
     ) -> None:
         """Optional scale calculation for MLA inputs.
@@ -813,9 +839,26 @@ def calc_kv_scales(
         k_range = getattr(self, "k_range", torch.tensor(1.0))
         v_range = getattr(self, "v_range", torch.tensor(1.0))
 
-        self._q_scale.copy_(torch.abs(q).max() / q_range)
+        if isinstance(q, (tuple, list)):
+            q_nope, q_pe = q
+        else:
+            q_nope, q_pe = q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+
+        device = q_nope.device
+        zero = torch.tensor(0.0, dtype=torch.float32, device=device)
+        q_abs_max = zero
+        if q_nope.numel() > 0:
+            q_abs_max = torch.max(q_abs_max, torch.abs(q_nope).max().to(torch.float32))
+        if q_pe.numel() > 0:
+            q_abs_max = torch.max(q_abs_max, torch.abs(q_pe).max().to(torch.float32))
+        kv_abs_max = zero
+        if kv_c_normed.numel() > 0:
+            kv_abs_max = torch.abs(kv_c_normed).max().to(torch.float32)
+
+        self._q_scale.copy_(q_abs_max / q_range)
         # kv_c_normed is the compressed KV representation; use it for k/v
-        kv_abs_max = torch.abs(kv_c_normed).max()
         self._k_scale.copy_(kv_abs_max / k_range)
         self._v_scale.copy_(kv_abs_max / v_range)
         self._q_scale_float = self._q_scale.item()
@@ -978,24 +1021,39 @@ def unified_attention_with_output_fake(
 
 
 @maybe_transfer_kv_layer
 def unified_mla_attention(
-    q: torch.Tensor,
+    q_nope: torch.Tensor,
+    q_pe: torch.Tensor,
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
-    attn_metadata, self, kv_cache = get_attention_context(layer_name)
-    output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+    self: MLAAttention = forward_context.no_compile_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    output = self.impl.forward(
+        self, (q_nope, q_pe), kv_c_normed, k_pe, kv_cache, attn_metadata
+    )
     return output
 
 
 def unified_mla_attention_fake(
-    q: torch.Tensor,
+    q_nope: torch.Tensor,
+    q_pe: torch.Tensor,
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
-    return torch.empty_like(q).contiguous()
+    head_dim = q_nope.shape[-1] + q_pe.shape[-1]
+    fake_shape = (*q_nope.shape[:-1], head_dim)
+    return torch.empty(
+        fake_shape,
+        dtype=q_nope.dtype,
+        device=q_nope.device,
+    ).contiguous()
 
 
 direct_register_custom_op(
@@ -1009,7 +1067,8 @@ def unified_mla_attention_fake(
 
 @maybe_transfer_kv_layer
 def unified_mla_attention_with_output(
-    q: torch.Tensor,
+    q_nope: torch.Tensor,
+    q_pe: torch.Tensor,
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     output: torch.Tensor,
@@ -1020,7 +1079,7 @@ def unified_mla_attention_with_output(
     attn_metadata, self, kv_cache = get_attention_context(layer_name)
     self.impl.forward(
         self,
-        q,
+        (q_nope, q_pe),
         kv_c_normed,
         k_pe,
         kv_cache,
@@ -1032,7 +1091,8 @@ def unified_mla_attention_with_output(
 
 
 def unified_mla_attention_with_output_fake(
-    q: torch.Tensor,
+    q_nope: torch.Tensor,
+    q_pe: torch.Tensor,
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     output: torch.Tensor,
diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py
index c4c44b83ae6b..b8f8d3649fa8 100644
--- a/vllm/model_executor/layers/mla.py
+++ b/vllm/model_executor/layers/mla.py
@@ -146,17 +146,18 @@ def forward_native(
         q = q.view(-1, self.num_heads, self.qk_head_dim)
         # Add head dim of 1 to k_pe
         k_pe = k_pe.unsqueeze(1)
+        q_nope = q[..., : self.qk_nope_head_dim]
+        q_pe = q[..., self.qk_nope_head_dim :]
 
         if self.rotary_emb is not None:
-            q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
-                positions, q[..., self.qk_nope_head_dim :], k_pe
-            )
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
         if self.indexer and self.is_sparse:
             _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
 
         attn_out = self.mla_attn(
-            q,
+            q_nope,
+            q_pe,
             kv_c_normed,
             k_pe,
             output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 467c01cd9d06..f892566fc0f6 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -1804,7 +1804,10 @@ def _forward_prefill(
             )
         else:
             context_output, context_lse = self._compute_prefill_context(
-                q, kv_c_and_k_pe_cache, attn_metadata, k_scale
+                q,
+                kv_c_and_k_pe_cache,
+                attn_metadata,
+                k_scale,
             )
 
             output = torch.empty_like(suffix_output)
@@ -1835,7 +1838,7 @@ def _forward_decode(
     def forward(
         self,
         layer: AttentionLayer,
-        q: torch.Tensor,
+        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
         k_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
@@ -1880,7 +1883,15 @@ def forward(
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        q = q[:num_actual_toks, ...]
+        if isinstance(q, tuple):
+            q_nope, q_pe = q
+        else:
+            q_nope, q_pe = q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+
+        q_nope = q_nope[:num_actual_toks, ...]
+        q_pe = q_pe[:num_actual_toks, ...]
         k_c_normed = k_c_normed[:num_actual_toks, ...]
         k_pe = k_pe[:num_actual_toks, ...]
 
@@ -1894,9 +1905,12 @@ def forward(
         has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
 
-        decode_q = q[:num_decode_tokens]
+        decode_q_nope = q_nope[:num_decode_tokens]
+        decode_q_pe = q_pe[:num_decode_tokens]
 
-        prefill_q = q[num_decode_tokens:]
+        prefill_q_nope = q_nope[num_decode_tokens:]
+        prefill_q_pe = q_pe[num_decode_tokens:]
+        prefill_q = torch.cat((prefill_q_nope, prefill_q_pe), dim=-1)
         prefill_k_pe = k_pe[num_decode_tokens:]
         prefill_k_c_normed = k_c_normed[num_decode_tokens:]
@@ -1927,10 +1941,6 @@ def forward(
 
         if has_decode:
             assert attn_metadata.decode is not None
-            decode_q_nope, decode_q_pe = decode_q.split(
-                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
-            )
-
             # Convert from (B, N, P) to (N, B, P)
             decode_q_nope = decode_q_nope.transpose(0, 1)
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index bb8d914d1571..6774de21c28f 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -470,7 +470,7 @@ def _forward_fp8_kv(
     def forward(
         self,
         layer: AttentionLayer,
-        q: torch.Tensor,
+        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
         k_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
@@ -499,11 +499,18 @@ def forward(
 
         # Inputs and outputs may be padded for CUDA graphs
-        q = q[:num_actual_toks, ...]
+        if isinstance(q, tuple):
+            q_nope, q_pe = q
+        else:
+            q_nope, q_pe = q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+
+        q_nope = q_nope[:num_actual_toks, ...]
+        q_pe = q_pe[:num_actual_toks, ...]
         k_c_normed = k_c_normed[:num_actual_toks, ...]
         k_pe = k_pe[:num_actual_toks, ...]
 
-        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         # Convert from (B, N, P) to (N, B, P)
         q_nope = q_nope.transpose(0, 1)
         # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
@@ -522,7 +529,7 @@ def forward(
                 NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
             )
 
-        q = torch.cat([ql_nope, q_pe], dim=-1)
+        q_combined = torch.cat([ql_nope, q_pe], dim=-1)
 
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
@@ -537,11 +544,11 @@ def forward(
 
         if self.kv_cache_dtype != "fp8_ds_mla":
             attn_out = self._forward_bf16_kv(
-                q, kv_cache, topk_indices_global, attn_metadata
+                q_combined, kv_cache, topk_indices_global, attn_metadata
             )
         else:
             attn_out = self._forward_fp8_kv(
-                q, kv_cache, topk_indices_global, attn_metadata
+                q_combined, kv_cache, topk_indices_global, attn_metadata
             )
 
         self._v_up_proj(attn_out, out=output[:num_actual_toks])