fix(CP, MLA): CP works fine with MLA in a2a cp_comm_type #2826
zhujian19891203 wants to merge 1 commit into NVIDIA:main
Conversation
Greptile Summary

This PR relaxes a constraint in `context_parallel.py`, adding "a2a" to the list of cp_comm_type values permitted with MLA.

Key changes:

Analysis:
Confidence Score: 3/5

The change is likely correct in isolation but lacks test coverage and silently enables an untested edge case (MLA + sliding-window + a2a). The one-line addition is structurally sound: flash_attn_a2a_communicate handles q, k, and v independently and preserves each tensor's head dimension, so asymmetric MLA dims are compatible with the A2A path. However, no automated tests were added to validate this, and MLA + sliding-window attention is a previously impossible combination that is now silently enabled.

Important files changed: transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py — specifically the guard on lines 4067-4077 and the new MLA + sliding-window combination.
```python
enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
    "p2p",
    "a2a",
    "a2a+p2p",
], f"Context parallelism does not support MLA with {cp_comm_type=}!"
```
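To make the reviewer's "structurally sound" claim concrete, here is a hypothetical single-process sketch (not TransformerEngine's actual code) that simulates the head/sequence exchange an all-to-all performs: heads are scattered across CP ranks and sequence chunks are gathered, while the last (head-size) dimension is never touched. This is why MLA's asymmetric k/v head dims can pass through the "a2a" path unchanged; `simulate_a2a` and all shapes below are illustrative assumptions.

```python
import torch

cp_size = 4
seq_per_rank, n_heads = 8, 8


def simulate_a2a(t: torch.Tensor, cp_size: int) -> torch.Tensor:
    # Each rank holds [seq_local, n_heads, head_dim]. After the a2a it holds
    # [seq_local * cp_size, n_heads // cp_size, head_dim]: heads are split
    # across ranks, sequence chunks are gathered. head_dim is left untouched.
    chunks = t.chunk(cp_size, dim=1)                    # split heads per peer
    return torch.cat([chunks[0]] * cp_size, dim=0)      # gather sequence shards


k = torch.randn(seq_per_rank, n_heads, 192)  # MLA-style: k head dim 192
v = torch.randn(seq_per_rank, n_heads, 128)  # MLA-style: v head dim 128

k_out = simulate_a2a(k, cp_size)
v_out = simulate_a2a(v, cp_size)
assert k_out.shape == (seq_per_rank * cp_size, n_heads // cp_size, 192)
assert v_out.shape == (seq_per_rank * cp_size, n_heads // cp_size, 128)
```

Because the exchange only reshuffles the sequence and head axes, k and v keep their (different) head sizes independently, which is the property the guard change relies on.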
No tests added for new MLA + a2a combination
The PR author acknowledges in the checklist that tests have not been added. The new "a2a" entry in the MLA allowlist enables a code path (AttnFuncWithCPAndQKVOA2A + MLA) that has not been exercised by any automated test. While the implementation appears structurally sound (the A2A communication handles q, k, v independently so asymmetric head dims are preserved), the lack of coverage means regressions in forward pass accuracy, backward pass gradients, or FP8-quantized variants could go undetected.
Consider adding a test case (similar to existing CP x MLA tests) that covers:

- `cp_comm_type="a2a"` with `k.shape[-1] != v.shape[-1]`
- Both `use_fused_attention=True` and `use_fused_attention=False` variants
- Gradient correctness (`torch.autograd.gradcheck` or comparison against a non-CP reference)
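A distributed CP test needs a multi-GPU harness, but the non-CP reference side of such a comparison is cheap to sketch. The following hypothetical snippet (not part of the suggested TE test itself; all shapes are assumptions) shows that plain PyTorch SDPA already accepts the MLA shape in question, where q/k share a head dim and v's differs, so a CP test could compare the "a2a" output against this kind of baseline:

```python
import torch
import torch.nn.functional as F

b, h, s = 2, 8, 64
qk_dim, v_dim = 192, 128  # MLA-style asymmetry: k.shape[-1] != v.shape[-1]

q = torch.randn(b, h, s, qk_dim)
k = torch.randn(b, h, s, qk_dim)
v = torch.randn(b, h, s, v_dim)

# SDPA allows value head dim != query/key head dim; the output
# inherits v's head dim, just as the MLA path expects.
out = F.scaled_dot_product_attention(q, k, v)
assert out.shape == (b, h, s, v_dim)
```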
Description
I believe CP works fine with MLA under the a2a cp_comm_type, so we can simply enable it. I ran some relevant experiments, and the results matched expectations. If my understanding is incorrect, please point it out. Thank you.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: