6 changes: 3 additions & 3 deletions vllm/attention/backends/abstract.py
@@ -375,9 +375,9 @@ def __init__(
def forward(
self,
layer: AttentionLayer,
hidden_states_or_cq: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
query: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
output: torch.Tensor | None = None,
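A minimal sketch (not part of the diff) of how a backend implementation might normalize the new `query` union type into separate nope/rope components; the helper name and head dims are illustrative assumptions.

import torch

def split_query(
    q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    qk_nope_head_dim: int = 128,
    qk_rope_head_dim: int = 64,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Tuple inputs already carry (q_nope, q_pe); a single tensor is split
    # along the last dim into its nope and rope parts.
    if isinstance(q, tuple):
        return q
    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    return q_nope, q_pe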
96 changes: 78 additions & 18 deletions vllm/attention/layer.py
@@ -747,12 +747,14 @@ def __init__(

def forward(
self,
q: torch.Tensor,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
if self.calculate_kv_scales:
q = torch.cat((q_nope, q_pe), dim=-1)
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

if self.use_direct_call:
@@ -763,10 +765,13 @@ def forward(
self_kv_cache = self.kv_cache[forward_context.virtual_engine]

if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
assert output_shape is not None, "output_shape must be provided."
output = torch.empty(
output_shape, dtype=q_nope.dtype, device=q_nope.device
)
self.impl.forward(
self,
q,
(q_nope, q_pe),
kv_c_normed,
k_pe,
self_kv_cache,
@@ -776,22 +781,40 @@
return output
else:
return self.impl.forward(
self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
self,
(q_nope, q_pe),
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
)
else:
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
assert output_shape is not None, "output_shape must be provided."
output = torch.empty(
output_shape, dtype=q_nope.dtype, device=q_nope.device
)
torch.ops.vllm.unified_mla_attention_with_output(
q,
q_nope,
q_pe,
kv_c_normed,
k_pe,
output,
self.layer_name,
)
return output
else:
if self.calculate_kv_scales:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
if getattr(attn_metadata, "enable_kv_scales_calculation", False):
self.calc_kv_scales((q_nope, q_pe), kv_c_normed, k_pe)

return torch.ops.vllm.unified_mla_attention(
q,
q_nope,
q_pe,
kv_c_normed,
k_pe,
self.layer_name,
@@ -802,7 +825,10 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
self.impl.process_weights_after_loading(act_dtype)

def calc_kv_scales(
self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
) -> None:
"""Optional scale calculation for MLA inputs.

@@ -813,9 +839,26 @@ def calc_kv_scales(
k_range = getattr(self, "k_range", torch.tensor(1.0))
v_range = getattr(self, "v_range", torch.tensor(1.0))

self._q_scale.copy_(torch.abs(q).max() / q_range)
if isinstance(q, (tuple, list)):
q_nope, q_pe = q
else:
q_nope, q_pe = q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)

device = q_nope.device
zero = torch.tensor(0.0, dtype=torch.float32, device=device)
q_abs_max = zero
if q_nope.numel() > 0:
q_abs_max = torch.max(q_abs_max, torch.abs(q_nope).max().to(torch.float32))
if q_pe.numel() > 0:
q_abs_max = torch.max(q_abs_max, torch.abs(q_pe).max().to(torch.float32))
kv_abs_max = zero
if kv_c_normed.numel() > 0:
kv_abs_max = torch.abs(kv_c_normed).max().to(torch.float32)

self._q_scale.copy_(q_abs_max / q_range)
# kv_c_normed is the compressed KV representation; use it for k/v
kv_abs_max = torch.abs(kv_c_normed).max()
self._k_scale.copy_(kv_abs_max / k_range)
self._v_scale.copy_(kv_abs_max / v_range)
self._q_scale_float = self._q_scale.item()
@@ -978,24 +1021,39 @@ def unified_attention_with_output_fake(

@maybe_transfer_kv_layer
def unified_mla_attention(
q: torch.Tensor,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(
self, (q_nope, q_pe), kv_c_normed, k_pe, kv_cache, attn_metadata
)

return output


def unified_mla_attention_fake(
q: torch.Tensor,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()
head_dim = q_nope.shape[-1] + q_pe.shape[-1]
fake_shape = (*q_nope.shape[:-1], head_dim)
return torch.empty(
fake_shape,
dtype=q_nope.dtype,
device=q_nope.device,
).contiguous()


direct_register_custom_op(
@@ -1009,7 +1067,8 @@ def unified_mla_attention_fake(

@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
q: torch.Tensor,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
@@ -1020,7 +1079,7 @@ def unified_mla_attention_with_output(
attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward(
self,
q,
(q_nope, q_pe),
kv_c_normed,
k_pe,
kv_cache,
@@ -1032,7 +1091,8 @@ def unified_mla_attention_with_output_fake(


def unified_mla_attention_with_output_fake(
q: torch.Tensor,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
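A minimal sketch (independent of the diff) checking the reduction used in calc_kv_scales: taking the max absolute value over q_nope and q_pe separately gives the same result as reducing over their concatenation, so passing the query as a tuple does not change the computed _q_scale. Shapes are illustrative assumptions.

import torch

q_nope = torch.randn(4, 8, 128)
q_pe = torch.randn(4, 8, 64)

# Per-part reduction, as in the tuple code path.
per_part = torch.max(q_nope.abs().max(), q_pe.abs().max())

# Reduction over the concatenated query, as in the old single-tensor path.
q_cat = torch.cat((q_nope, q_pe), dim=-1)
assert torch.equal(per_part, q_cat.abs().max())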
9 changes: 5 additions & 4 deletions vllm/model_executor/layers/mla.py
@@ -146,17 +146,18 @@ def forward_native(
q = q.view(-1, self.num_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q_nope = q[..., : self.qk_nope_head_dim]
q_pe = q[..., self.qk_nope_head_dim :]

if self.rotary_emb is not None:
q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim :], k_pe
)
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

if self.indexer and self.is_sparse:
_topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)

attn_out = self.mla_attn(
q,
q_nope,
q_pe,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
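A minimal sketch (not from the diff) of the slicing used in forward_native: basic indexing along the last dim yields views of q rather than copies before the rotary embedding is applied. Dimensions are illustrative assumptions.

import torch

num_tokens, num_heads = 16, 8
qk_nope_head_dim, qk_rope_head_dim = 128, 64
q = torch.randn(num_tokens, num_heads, qk_nope_head_dim + qk_rope_head_dim)

q_nope = q[..., :qk_nope_head_dim]
q_pe = q[..., qk_nope_head_dim:]

assert q_nope.data_ptr() == q.data_ptr()  # a view into q, not a copy
assert q_pe.shape == (num_tokens, num_heads, qk_rope_head_dim)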
28 changes: 19 additions & 9 deletions vllm/v1/attention/backends/mla/common.py
@@ -1804,7 +1804,10 @@ def _forward_prefill(
)
else:
context_output, context_lse = self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
q,
kv_c_and_k_pe_cache,
attn_metadata,
k_scale,
)

output = torch.empty_like(suffix_output)
@@ -1835,7 +1838,7 @@ def _forward_decode(
def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
@@ -1880,7 +1883,15 @@ def forward(
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
output = output[:num_actual_toks, ...]
q = q[:num_actual_toks, ...]
if isinstance(q, tuple):
q_nope, q_pe = q
else:
q_nope, q_pe = q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)

q_nope = q_nope[:num_actual_toks, ...]
q_pe = q_pe[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]

@@ -1894,9 +1905,12 @@ def forward(
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens

decode_q = q[:num_decode_tokens]
decode_q_nope = q_nope[:num_decode_tokens]
decode_q_pe = q_pe[:num_decode_tokens]

prefill_q = q[num_decode_tokens:]
prefill_q_nope = q_nope[num_decode_tokens:]
prefill_q_pe = q_pe[num_decode_tokens:]
prefill_q = torch.cat((prefill_q_nope, prefill_q_pe), dim=-1)
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]

@@ -1927,10 +1941,6 @@ def forward(
if has_decode:
assert attn_metadata.decode is not None

decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)

# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)

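A minimal sketch (not from the diff) of the decode/prefill partitioning in the common MLA forward: decode tokens come first, so plain slicing separates the two paths, and only the prefill path re-concatenates the nope and rope parts. Token counts and shapes are illustrative assumptions.

import torch

num_decode_tokens, num_prefill_tokens, num_heads = 3, 5, 8
total = num_decode_tokens + num_prefill_tokens
q_nope = torch.randn(total, num_heads, 128)
q_pe = torch.randn(total, num_heads, 64)

decode_q_nope = q_nope[:num_decode_tokens]
decode_q_pe = q_pe[:num_decode_tokens]

prefill_q = torch.cat(
    (q_nope[num_decode_tokens:], q_pe[num_decode_tokens:]), dim=-1
)
assert prefill_q.shape == (num_prefill_tokens, num_heads, 192)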
19 changes: 13 additions & 6 deletions vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -470,7 +470,7 @@ def _forward_fp8_kv(
def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
@@ -499,11 +499,18 @@ def forward(

# Inputs and outputs may be padded for CUDA graphs

q = q[:num_actual_toks, ...]
if isinstance(q, tuple):
q_nope, q_pe = q
else:
q_nope, q_pe = q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)

q_nope = q_nope[:num_actual_toks, ...]
q_pe = q_pe[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]

q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
@@ -522,7 +529,7 @@ def forward(
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
)

q = torch.cat([ql_nope, q_pe], dim=-1)
q_combined = torch.cat([ql_nope, q_pe], dim=-1)

# write the latent and rope to kv cache
if kv_cache.numel() > 0:
@@ -537,11 +544,11 @@ def forward(

if self.kv_cache_dtype != "fp8_ds_mla":
attn_out = self._forward_bf16_kv(
q, kv_cache, topk_indices_global, attn_metadata
q_combined, kv_cache, topk_indices_global, attn_metadata
)
else:
attn_out = self._forward_fp8_kv(
q, kv_cache, topk_indices_global, attn_metadata
q_combined, kv_cache, topk_indices_global, attn_metadata
)

self._v_up_proj(attn_out, out=output[:num_actual_toks])
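A minimal sketch (not from the diff) of how the sparse path builds its combined latent-space query: q_nope is projected into the KV latent space and then re-joined with the rope part. The W_UK shape and all dims are illustrative assumptions.

import torch

N, B, P, L, R = 8, 4, 128, 512, 64  # heads, tokens, nope dim, latent dim, rope dim
q_nope = torch.randn(B, N, P).transpose(0, 1)   # (N, B, P)
q_pe = torch.randn(B, N, R).transpose(0, 1)     # (N, B, R)
W_UK = torch.randn(N, P, L)                     # assumed per-head up-projection

ql_nope = torch.bmm(q_nope, W_UK)               # (N, B, P) x (N, P, L) -> (N, B, L)
q_combined = torch.cat([ql_nope, q_pe], dim=-1)
assert q_combined.shape == (N, B, L + R)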