[None][fix] visual_gen UlyssesAttention: pass post-A2A seq_len to inner backend #13486
karljang wants to merge 1 commit into NVIDIA:main from
Conversation
…er backend

The caller (modules/attention.py:_attn_impl) computes seq_len from the local sharded tensor before calling self.attn.forward(...). When self.attn is UlyssesAttention, the call goes through an all_to_all that gathers the full sequence dimension. The kwargs forwarded to the inner backend (TRTLLM, including its Sage path) still carry the sharded seq_len, but the post-A2A tensor has full seq_len rows.

Inside the inner backend, seq_len is used in 2D reshapes:
- trtllm.py:276-278 (Sage): q.reshape(batch_size * seq_len, -1)
- trtllm.py:292: qkv.reshape(batch_size * seq_len, -1)

With Ulysses world_size > 1, this reshape uses the sharded value but operates on a tensor with full-seq rows, leading to either a ValueError or a q_hidden_size assertion deeper in the FMHA kernel runner.

Fix: in both UlyssesAttention._forward_fused and _forward_unfused, override kwargs["seq_len"] (and seq_len_kv where applicable) with the post-A2A tensor shape before delegating to the inner backend. The override is captured before any HND transpose, so the seq dim is taken from the correct axis. The single-rank case is unchanged: the A2A is skipped and q.shape[1] already equals the sharded seq_len passed in.

Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>
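The shape mismatch described above can be reproduced with plain NumPy (a minimal sketch; the concrete batch and head sizes are made-up examples, not values from the repo):

```python
import numpy as np

batch, full_len, world = 2, 4096, 8
sharded = full_len // world            # 512: the stale value left in kwargs
heads_local, head_dim = 16 // world, 64

# Post-A2A tensor: full sequence length, heads divided across ranks.
q = np.zeros((batch, full_len, heads_local, head_dim))

# Reshape with the stale sharded seq_len: the row count is world times too
# small, so each row absorbs world times too many elements -- the inflated
# hidden size is what trips the q_hidden_size assertion downstream.
flat_bad = q.reshape(batch * sharded, -1)
assert flat_bad.shape[1] == world * heads_local * head_dim   # 1024, 8x too big

# Reshape with the post-A2A seq_len (the fix): hidden size comes out right.
flat_ok = q.reshape(batch * full_len, -1)
assert flat_ok.shape[1] == heads_local * head_dim            # 128, as expected
```

When the stale row count does not divide the element count evenly, the reshape itself raises a ValueError instead, which matches the "either a ValueError or a q_hidden_size assertion" behavior described in the commit.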
60174a8 to e885a55
/bot run --disable-fail-fast

PR_Github #45648 [ run ] triggered by Bot. Commit:

Hi @NVShreyas,

PR_Github #45648 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #45747 [ run ] triggered by Bot. Commit:
Summary
`UlyssesAttention` in visual_gen forwards a sharded `seq_len` kwarg into the inner attention backend even after its all-to-all has gathered the sequence dim back to full length. The inner backend uses `seq_len` for 2D reshapes (the Sage path and the regular fused-QKV path in `trtllm.py`), so the row count mismatches the actual tensor, causing either a `ValueError` or a `q_hidden_size` assertion inside the FMHA kernel runner whenever Ulysses sequence parallelism is on (`world_size > 1`).

This PR overrides `kwargs["seq_len"]` (and `seq_len_kv` in the unfused path) with the post-A2A tensor shape inside `UlyssesAttention.forward`, so the inner backend sees a length that matches what it actually received.
Root cause trace

1. The caller (`tensorrt_llm/_torch/visual_gen/modules/attention.py:_attn_impl`) reads `seq_len = q.shape[1]` from the pre-A2A (sharded) tensor (e.g. 4096 / 8 = 512 under 8-way Ulysses) and passes it to `self.attn.forward(..., seq_len=512, seq_len_kv=512)`.
2. `UlyssesAttention.forward` runs `all_to_all_4d/5d`: tensors come back with full seq_len (4096), but `kwargs["seq_len"]` is still 512.
3. The inner backend (`tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py`) runs `q.reshape(batch_size * seq_len, -1)` (Sage path) and `qkv.reshape(batch_size * seq_len, -1)` (fused-QKV path). Both use the stale 512 against a 4096-row tensor, so the kernel runner asserts on `q_hidden_size`.

Fix
A two-line addition in each of `_forward_fused` and `_forward_unfused`, captured before any HND transpose so the seq axis is taken from the correct dimension.
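The override can be sketched against stub shapes (a minimal sketch; `StubInner`, `UlyssesLike`, and the shape-tuple "tensors" are illustrative stand-ins, not the repo's actual classes):

```python
# Stand-ins for the inner backend and the Ulysses wrapper; only the
# kwargs["seq_len"] / kwargs["seq_len_kv"] override mirrors the real fix.
class StubInner:
    def forward(self, q_shape, k_shape, **kwargs):
        # The inner backend only sees seq_len kwargs that match its inputs.
        return kwargs["seq_len"], kwargs["seq_len_kv"]

class UlyssesLike:
    def __init__(self, world_size):
        self.world_size = world_size
        self.inner = StubInner()

    def _all_to_all(self, shape):
        # Stand-in for all_to_all_4d/5d: the sequence dim grows by
        # world_size while the head dim shrinks by the same factor.
        b, s, h, d = shape
        return (b, s * self.world_size, h // self.world_size, d)

    def forward_unfused(self, q_shape, k_shape, **kwargs):
        q_shape = self._all_to_all(q_shape)
        k_shape = self._all_to_all(k_shape)
        # The fix: override the stale sharded values with the post-A2A
        # sequence lengths, captured before any HND transpose would move
        # the sequence axis away from dim 1.
        kwargs["seq_len"] = q_shape[1]
        kwargs["seq_len_kv"] = k_shape[1]
        return self.inner.forward(q_shape, k_shape, **kwargs)

attn = UlyssesLike(world_size=8)
sharded = (2, 512, 16, 64)  # (batch, local seq, heads, head_dim)
print(attn.forward_unfused(sharded, sharded, seq_len=512, seq_len_kv=512))
# → (4096, 4096): the inner backend sees the gathered lengths
```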
Correctness

- `world_size == 1`: the A2A is skipped; `q.shape[1]` already equals the caller's `seq_len`, so the override is a no-op.
- Fused path: `seq_len_kv = seq_len` (Q and K share length).
- Unfused path: `kv_seq_len_full = k.shape[1]` is set separately, so different Q/KV lengths are preserved.
- line 132–138.
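The single-rank no-op claim reduces to simple shape arithmetic (a sketch; the helper name is made up for illustration):

```python
# The A2A multiplies the local sequence length by world_size; with
# world_size == 1 it is skipped, so the override writes back the same value.
def post_a2a_seq_len(sharded_seq_len, world_size):
    return sharded_seq_len * world_size

assert post_a2a_seq_len(512, 8) == 4096   # Ulysses: override changes the kwarg
assert post_a2a_seq_len(4096, 1) == 4096  # single rank: override is a no-op
```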
Test plan

- Ran locally: without this PR, fails at the `q_hidden_size` assertion in the FMHA kernel runner; with the PR applied, runs cleanly.
- Single-rank run (no `--ulysses_size`): bit-exact regression vs. main expected (the override is a no-op when `world_size == 1`).

🤖 Generated with Claude Code