
[None][fix] visual_gen UlyssesAttention: pass post-A2A seq_len to inner backend #13486

Open
karljang wants to merge 1 commit into NVIDIA:main from karljang:fix/visual-gen-ulysses-seqlen

Conversation


karljang (Collaborator) commented Apr 27, 2026

Summary

UlyssesAttention in visual_gen forwards a sharded seq_len kwarg to the
inner attention backend even after its all-to-all has gathered the
sequence dim back to full length. The inner backend uses seq_len for
2D reshapes (the Sage path and the regular fused-QKV path in trtllm.py),
so the row count no longer matches the actual tensor, causing either a
ValueError or a q_hidden_size assertion inside the FMHA kernel
runner whenever Ulysses sequence parallelism is on (world_size > 1).

This PR overrides kwargs["seq_len"] (and seq_len_kv in the unfused
path) with the post-A2A tensor shape inside UlyssesAttention.forward,
so the inner backend sees a length that matches what it actually
received.

Root cause trace

  1. Caller (tensorrt_llm/_torch/visual_gen/modules/attention.py:_attn_impl)
    reads seq_len = q.shape[1] from the pre-A2A (sharded) tensor —
    e.g. 4096 / 8 = 512 under 8-way Ulysses — and passes it to
    self.attn.forward(..., seq_len=512, seq_len_kv=512).
  2. UlyssesAttention.forward runs all_to_all_4d/5d: tensors come back
    with full seq_len (4096). kwargs["seq_len"] is still 512.
  3. Inner backend (tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py):
    • Sage path (lines 276–278): q.reshape(batch_size * seq_len, -1)
    • Fused path (line 292): qkv.reshape(batch_size * seq_len, -1)
      Both use the stale 512 against a full-length (4096-token) tensor →
      the kernel runner asserts on q_hidden_size (see the sketch below).
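
A minimal repro of the arithmetic behind the failure (sizes and the BSND layout are illustrative assumptions, not taken verbatim from the code):

```python
import torch

# Illustrative sizes; layout assumed BSND (batch, seq, heads, head_dim).
batch, full_seq, heads, dim = 2, 4096, 24, 128
ulysses_size = 8
stale_seq = full_seq // ulysses_size  # 512, read from the pre-A2A shard

q = torch.randn(batch, full_seq, heads, dim)  # post-A2A: full sequence

# The 2D reshape with the stale value can still "succeed" when the element
# count divides evenly, but each row then packs ulysses_size tokens' worth
# of features:
q2d = q.reshape(batch * stale_seq, -1)
print(q2d.shape[-1])  # 24576 == ulysses_size * heads * dim
print(heads * dim)    # 3072, the q_hidden_size the kernel runner expects,
                      # hence the assertion

# With the post-A2A length, rows and hidden size are consistent again:
q2d_ok = q.reshape(batch * full_seq, -1)
assert q2d_ok.shape[-1] == heads * dim
```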

Fix

A two-line addition in each of _forward_fused and _forward_unfused,
with the shape captured before any HND transpose so the seq axis is
read from the correct dimension.
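
In sketch form (the helper below is hypothetical; only the kwargs keys, the shape reads, and the fused/unfused distinction come from this PR):

```python
import torch

def override_seq_lens(q, k, kwargs, fused):
    """Hypothetical stand-in for the two-line override; must run after the
    all-to-all and before any HND transpose, while seq is still dim 1."""
    kwargs["seq_len"] = q.shape[1]  # post-A2A (gathered) query length
    # Fused QKV self-attention: Q and K share the gathered length.
    # Unfused cross-attention: KV keeps its own, possibly different, length.
    kwargs["seq_len_kv"] = q.shape[1] if fused else k.shape[1]
    return kwargs

# Example: 4096 gathered image tokens vs. 77 text tokens (cross-attention).
q = torch.randn(2, 4096, 24, 128)
k = torch.randn(2, 77, 24, 128)
print(override_seq_lens(q, k, {}, fused=False))
# {'seq_len': 4096, 'seq_len_kv': 77}
```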

Correctness

  • world_size == 1: A2A is skipped; q.shape[1] already equals the
    caller's seq_len. Override is a no-op.
  • Self-attention (fused QKV): seq_len_kv = seq_len (Q and K share
    length).
  • Cross-attention (unfused): kv_seq_len_full = k.shape[1] set
    separately, so different Q/KV lengths are preserved.
  • HND backends: shape values are captured before the transpose at
    lines 132–138 (see the shape check below).
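
A small shape check showing why the capture must precede the HND transpose (sizes illustrative):

```python
import torch

# After a BSND -> BHND transpose, dim 1 holds the head count, not the
# sequence length, so the lengths must be read before transposing.
q = torch.randn(2, 4096, 24, 128)  # BSND, post-A2A
seq_len_full = q.shape[1]          # 4096: correct

q_hnd = q.transpose(1, 2)          # layout used by HND backends
print(q_hnd.shape[1])              # 24: the head count, wrong as a seq_len
```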

Test plan

  • FLUX.1-dev with multi-rank Ulysses + Sage attention — verified
    locally: without this PR, fails at q_hidden_size assertion in
    the FMHA kernel runner; with the PR applied, runs cleanly.
  • FLUX.1-dev / FLUX.2-dev single-rank baseline (no --ulysses_size) —
    expected to be bit-exact vs main (the override is a no-op when
    world_size == 1).

🤖 Generated with Claude Code

Summary by CodeRabbit

Bug Fixes

  • Fixed assertion failures in parallel attention backend kernel execution by ensuring sequence lengths are correctly handled during distributed tensor operations.

@karljang karljang requested a review from a team as a code owner April 27, 2026 04:32
…er backend

The caller (modules/attention.py:_attn_impl) computes seq_len from the
local sharded tensor before calling self.attn.forward(...). When
self.attn is UlyssesAttention, the call goes through an all_to_all
that gathers the full sequence dimension. The kwargs forwarded to the
inner backend (TRTLLM, including its Sage path) still carry the
sharded seq_len, but the post-A2A tensor has full seq_len rows.

Inside the inner backend, seq_len is used in 2D reshapes:
  - trtllm.py:276–278 (Sage):  q.reshape(batch_size * seq_len, -1)
  - trtllm.py:292 (fused):     qkv.reshape(batch_size * seq_len, -1)

With Ulysses world_size > 1, this reshape uses the sharded value but
operates on a tensor with full-seq rows, leading to either a
ValueError or a q_hidden_size assertion deeper in the FMHA kernel
runner.

Fix: in both UlyssesAttention._forward_fused and _forward_unfused,
override kwargs["seq_len"] (and seq_len_kv where applicable) with the
post-A2A tensor shape before delegating to the inner backend. The
override is captured before any HND transpose so the seq dim is taken
from the correct axis.

Single-rank case is unchanged: the A2A is skipped and q.shape[1]
already equals the sharded seq_len passed in.

Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>

coderabbitai Bot commented Apr 27, 2026

No actionable comments were generated in the recent review. 🎉

📥 Commits

Reviewing files that changed from the base of the PR and between 2f745de and 60174a8.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py

📝 Walkthrough

Modified a wrapper in the attention backend to override sequence length parameters (seq_len and seq_len_kv) to use post-all-to-all values in both fused and unfused forward paths, ensuring tensor reshapes align with gathered dimensions.

Changes

  • Attention Backend Parameter Override
    (tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py): added
    logic to force sequence length parameters to use post-all-to-all
    gathered values by overwriting kwargs["seq_len"] and
    kwargs["seq_len_kv"] in both fused and unfused forward execution paths.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 0.00%, below the
    required threshold of 80.00%. Resolution: write docstrings for the
    functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
  • Title check ✅ Passed: the title clearly and specifically describes the
    fix (overriding the post-A2A sequence length in UlyssesAttention before
    passing it to the inner backend), which matches the core change in the PR.
  • Description check ✅ Passed: the description covers the summary, root
    cause trace, fix explanation, correctness analysis, and test plan,
    though it lacks the explicit Test Coverage section and PR Checklist
    items specified in the template.
  • Linked Issues check ✅ Passed: skipped because no linked issues were
    found for this pull request.
  • Out of Scope Changes check ✅ Passed: skipped because no linked issues
    were found for this pull request.



@karljang karljang force-pushed the fix/visual-gen-ulysses-seqlen branch from 60174a8 to e885a55 Compare April 27, 2026 04:36
@karljang karljang marked this pull request as draft April 27, 2026 04:47
@karljang (Collaborator, Author) commented:

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator) commented:

PR_Github #45648 [ run ] triggered by Bot. Commit: e885a55 Link to invocation

@karljang karljang marked this pull request as ready for review April 27, 2026 05:01
@karljang karljang requested a review from NVShreyas April 27, 2026 05:27
@karljang (Collaborator, Author) commented:

Hi @NVShreyas,
Could you please review this? I had to make these changes to resolve an issue I encountered while sweeping the performance of a DiT model.

@tensorrt-cicd (Collaborator) commented:

PR_Github #45648 [ run ] completed with state SUCCESS. Commit: e885a55
/LLM/main/L0_MergeRequest_PR pipeline #35861 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@karljang (Collaborator, Author) commented:

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator) commented:

PR_Github #45747 [ run ] triggered by Bot. Commit: e885a55 Link to invocation

