
Conversation

@vnadathur (Contributor) commented Nov 9, 2025

Tracking issue: #26567

This opens up the opportunity for MLA to get its own custom op instead of unified_attention, allowing us to explore passing q_nope and q_rope independently instead of concatenated.

Reference: #24620 and #25103

One thing to note is that the prefill path still requires a concatenated q (because of kernel expectations), so I focused on de-concatenating only the decode path.
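For illustration, here is a minimal sketch of the decode-side idea, using made-up shapes and a placeholder mla_decode function (not the actual vLLM op or its signature):

import torch

# Hypothetical MLA head dims for illustration only (DeepSeek-style).
num_tokens, num_heads = 4, 16
qk_nope_head_dim, qk_rope_head_dim = 128, 64

q_nope = torch.randn(num_tokens, num_heads, qk_nope_head_dim)
q_pe = torch.randn(num_tokens, num_heads, qk_rope_head_dim)

# Previous decode path: hand the backend one concatenated tensor,
# which it later slices apart again.
q = torch.cat([q_nope, q_pe], dim=-1)

# Direction of this PR for decode: pass the components separately so a
# custom op could consume them without the cat/slice round trip.
# mla_decode is a placeholder name, not a real vLLM op.
def mla_decode(q_nope: torch.Tensor, q_pe: torch.Tensor) -> torch.Tensor:
    # A real backend would run the latent-attention kernel here.
    return torch.cat([q_nope, q_pe], dim=-1)  # stand-in computation

out = mla_decode(q_nope, q_pe)
assert out.shape == q.shape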

cc @ProExpertProg @MatthewBonanni @LucasWilkinson

Signed-off-by: vnadathur <glvikramn@gmail.com>
Co-Authored-By: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com>
@gemini-code-assist bot left a comment

Code Review

This pull request refactors the Multi-Head Latent Attention (MLA) layers to pass the q_nope and q_pe components of the query tensor as separate arguments, instead of a single concatenated tensor. This is a good preparatory step for introducing custom attention ops that can leverage this separation for optimizations. The changes are applied consistently across the affected attention backends and model layers.

However, I've identified a potential performance regression in the prefill path for the MLACommonImpl backend. The new API forces a split -> slice -> cat sequence of operations, which is less efficient than the previous slice-only approach due to an extra memory allocation and copy. I've left a detailed comment with a suggestion to address this. Other than that, the refactoring looks solid.
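Roughly, the concern is the difference between slicing views out of an already-concatenated q versus splitting first and re-concatenating for the prefill kernel; a sketch with made-up shapes (not the MLACommonImpl code):

import torch

num_tokens, num_heads = 8, 16
nope_dim, rope_dim = 128, 64

q = torch.randn(num_tokens, num_heads, nope_dim + rope_dim)

# Slice-only approach: both pieces are views into q, no new allocation.
q_nope_view = q[..., :nope_dim]
q_pe_view = q[..., nope_dim:]

# Split-then-cat approach: rebuilding the concatenated tensor for the
# prefill kernel allocates new memory and copies both halves.
q_nope, q_pe = q[..., :nope_dim], q[..., nope_dim:]
q_recat = torch.cat([q_nope, q_pe], dim=-1)  # extra allocation + copy

assert torch.equal(q_recat, q)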

@mergify bot commented Nov 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vnadathur.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 11, 2025
@mergify mergify bot removed the needs-rebase label Nov 11, 2025
WorldExplored and others added 3 commits November 12, 2025 02:16
Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
Co-Authored-By: vnadathur <236933696+vnadathur@users.noreply.github.com>
Signed-off-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com>
Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
@MatthewBonanni (Contributor) commented:

Thanks for this work! Do you have any benchmarking results for what kind of speedup we can expect from this?

@vnadathur (Contributor, Author) commented:

@MatthewBonanni Here are the benchmark results with and without the split. There is other work in progress on refactoring the MLA backends (for example #27501); my PR just enables more of the work coming from the original issue and the adjacent PRs in the description.

Ran DeepSeek-V2.5 on 8×H100.

BENCHMARK UTILIZING SPLIT q_nope and q_rope:

============ Serving Benchmark Result ============
Successful requests:                     32        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  10.14     
Total input tokens:                      32736     
Total generated tokens:                  4096      
Request throughput (req/s):              3.16      
Output token throughput (tok/s):         404.01    
Peak output token throughput (tok/s):    488.00    
Peak concurrent requests:                16.00     
Total Token throughput (tok/s):          3632.93   
---------------Time to First Token----------------
Mean TTFT (ms):                          395.99    
Median TTFT (ms):                        346.76    
P99 TTFT (ms):                           711.16    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          16.83     
Median TPOT (ms):                        16.58     
P99 TPOT (ms):                           20.67     
---------------Inter-token Latency----------------
Mean ITL (ms):                           16.83     
Median ITL (ms):                         16.47     
P99 ITL (ms):                            19.31     
==================================================

BENCHMARK WITH THE NORMAL MLA BACKEND:

============ Serving Benchmark Result ============
Successful requests:                     32        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  10.12     
Total input tokens:                      32736     
Total generated tokens:                  4096      
Request throughput (req/s):              3.16      
Output token throughput (tok/s):         404.60    
Peak output token throughput (tok/s):    488.00    
Peak concurrent requests:                16.00     
Total Token throughput (tok/s):          3638.20   
---------------Time to First Token----------------
Mean TTFT (ms):                          387.51    
Median TTFT (ms):                        336.47    
P99 TTFT (ms):                           707.50    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          16.87     
Median TPOT (ms):                        16.61     
P99 TPOT (ms):                           20.69     
---------------Inter-token Latency----------------
Mean ITL (ms):                           16.90     
Median ITL (ms):                         16.53     
P99 ITL (ms):                            18.03     
==================================================
Config

vllm serve deepseek-ai/DeepSeek-V2.5 \
  --dtype bfloat16 \
  --tensor-parallel-size 8 \
  --trust-remote-code \
  --max-model-len 8192 \
  --gpu-memory-utilization 0.9 \
  > serve.log 2>&1 &

SERVER_PID=$!
echo "Server PID: $SERVER_PID"

vllm bench serve \
  --backend vllm \
  --model deepseek-ai/DeepSeek-V2.5 \
  --endpoint /v1/completions \
  --dataset-name random \
  --num-prompts 32 \
  --max-concurrency 8 \
  --request-rate inf \
  --trust-remote-code
