
fix(CP, MLA): CP works fine with MLA in a2a cp_comm_type#2826

Open
zhujian19891203 wants to merge 1 commit into NVIDIA:main from 021ai:CP_MLA_a2a

Conversation

@zhujian19891203
Contributor

Description

I believe CP works fine with MLA under the a2a cp_comm_type, so we can simply enable it. I conducted some relevant experiments, and the results were consistent with expectations. If my understanding is incorrect, please point it out. Thank you.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Only the context_parallel.py file is modified

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Apr 2, 2026

Greptile Summary

This PR relaxes a constraint in attn_forward_func_with_cp, allowing Multi-head Latent Attention (MLA) — characterized by k.shape[-1] != v.shape[-1] — to be used with cp_comm_type = "a2a" (Ulysses-style all-to-all context parallelism). Previously only "p2p" and "a2a+p2p" were permitted.

Key changes:

  • Added "a2a" to the allowlist in the assert not enable_mla or cp_comm_type in [...] guard in context_parallel.py

Analysis:

  • The change is technically plausible: flash_attn_a2a_communicate handles q, k, v as independent tensors, reshaping only along the head-count dimension (x.shape[-2]) while preserving each tensor's head-dimension (x.shape[-1]). MLA's asymmetric head dimensions (k_dim ≠ v_dim) are therefore carried through the A2A scatter/gather without conflict.
  • The A2A path (AttnFuncWithCPAndQKVOA2A) makes no assumption that k and v share the same head dimension, and there is no MLA-specific code path anywhere else in the file.
  • No tests were added to validate the new combination, and no tests for the a2a + MLA path exist in the test suite.
  • A previously impossible combination is now silently enabled: MLA + sliding-window attention + a2a, since sliding_window_attn already permits a2a and MLA now also permits a2a.
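The head-dimension argument above can be sketched directly: an Ulysses-style a2a repartitions only the head axis (dim -2), leaving each tensor's own head dimension (dim -1) untouched. Below is a minimal NumPy illustration of that layout property; `scatter_heads` is a hypothetical stand-in, not TE's actual `flash_attn_a2a_communicate`:

```python
import numpy as np

cp_size, local_seq, heads = 2, 4, 8
q = np.zeros((local_seq, heads, 128))  # q head dim
k = np.zeros((local_seq, heads, 192))  # MLA: k head dim differs from v's
v = np.zeros((local_seq, heads, 128))  # v head dim

def scatter_heads(x, cp_size):
    """Reshape for the a2a send: split only the head axis (dim -2)."""
    s, h, d = x.shape
    return x.reshape(s, cp_size, h // cp_size, d)

# Each tensor is reshaped independently, so asymmetric MLA head dims
# (k: 192, v: 128) pass through without conflict.
for t in (q, k, v):
    assert scatter_heads(t, cp_size).shape[-1] == t.shape[-1]
```

Because the last axis is never split or merged, nothing in this repartitioning requires k and v to agree on head dimension, which is the crux of the bot's analysis.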

Confidence Score: 3/5

The change is likely correct in isolation but lacks test coverage and silently enables an untested edge case (MLA + sliding-window + a2a)

The one-line addition is structurally sound: flash_attn_a2a_communicate handles q, k, v independently and preserves each tensor's head dimension, so asymmetric MLA dims are compatible with the A2A path. However, no automated tests were added to validate this, and the interaction with sliding-window attention is a previously-impossible combination that is now silently enabled.

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py — specifically the guard on lines 4067-4077 and the new MLA + sliding-window combination

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Single-line allowlist change adds "a2a" to MLA-compatible cp_comm_types; structurally sound but lacks test coverage and silently enables untested MLA + sliding-window + a2a combination

Comments Outside Diff (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 4067-4077 (link)

    P1 Newly enabled MLA + sliding-window + a2a combination is untested

    With this change, the combination of MLA (k.shape[-1] != v.shape[-1]) + sliding-window attention + cp_comm_type="a2a" becomes valid for the first time, because:

    • sliding_window_attn allows "a2a" (lines 4067-4070)
    • MLA now also allows "a2a" (this PR)

    Previously neither "p2p" nor "a2a+p2p" permitted sliding-window attention, so this 3-way combination was unreachable. It is now silently enabled without any dedicated validation.

    If MLA + sliding-window + a2a is not a supported combination (e.g. the underlying fused_attn_fwd kernel or flash_attn_fwd does not support asymmetric head dims together with sliding-window), this could produce silent wrong results rather than a clear error.

    Consider either:

    1. Adding an explicit guard (if unsupported), or
    2. Confirming the combination is correct and adding a test for it.
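Option 1 could look like the sketch below. The `check_cp_mla_support` helper, the `sliding_window_attn` flag wiring, and the choice of `NotImplementedError` are all hypothetical illustrations, not TE's actual code; only the allowlist assert mirrors the guard quoted from context_parallel.py:

```python
def check_cp_mla_support(enable_mla: bool, sliding_window_attn: bool,
                         cp_comm_type: str) -> None:
    """Hypothetical guard mirroring the MLA allowlist in context_parallel.py."""
    assert not enable_mla or cp_comm_type in [
        "p2p",
        "a2a",
        "a2a+p2p",
    ], f"Context parallelism does not support MLA with {cp_comm_type=}!"
    # Explicitly reject the previously-unreachable 3-way combination
    # until it has dedicated test coverage.
    if enable_mla and sliding_window_attn and cp_comm_type == "a2a":
        raise NotImplementedError(
            "MLA + sliding-window attention + cp_comm_type='a2a' is untested."
        )
```

A loud error here turns a potential silent-wrong-result case into an immediate, debuggable failure.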

Reviews (1): Last reviewed commit: "fix(CP, MLA): CP works fine with MLA in ..."

Comment on lines 4072 to 4077
enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
    "p2p",
    "a2a",
    "a2a+p2p",
], f"Context parallelism does not support MLA with {cp_comm_type=}!"

P1 No tests added for new MLA + a2a combination

The PR author acknowledges in the checklist that tests have not been added. The new "a2a" entry in the MLA allowlist enables a code path (AttnFuncWithCPAndQKVOA2A + MLA) that has not been exercised by any automated test. While the implementation appears structurally sound (the A2A communication handles q, k, v independently so asymmetric head dims are preserved), the lack of coverage means regressions in forward pass accuracy, backward pass gradients, or FP8-quantized variants could go undetected.

Consider adding a test case (similar to existing CP x MLA tests) that covers:

  • cp_comm_type="a2a" with k.shape[-1] != v.shape[-1]
  • Both use_fused_attention=True and use_fused_attention=False variants
  • Gradient correctness (torch.autograd.gradcheck or compare against a non-CP reference)
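A single-process sketch of the third bullet (forward pass only, no gradients, plain NumPy rather than TE's fused kernels): attention heads are independent, so computing disjoint head groups "per rank" and concatenating, which is the effect of the a2a layout, should match a non-CP reference even with asymmetric MLA head dims. All names below are illustrative:

```python
import numpy as np

def attn(q, k, v):
    """Naive per-head attention; q, k: (s, h, dk), v: (s, h, dv)."""
    s = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))   # stable softmax
    p /= p.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", p, v)

rng = np.random.default_rng(0)
seq, heads, dk, dv, cp = 16, 4, 192, 128, 2   # MLA: dk != dv
q = rng.standard_normal((seq, heads, dk))
k = rng.standard_normal((seq, heads, dk))
v = rng.standard_normal((seq, heads, dv))

ref = attn(q, k, v)                            # non-CP reference
hpr = heads // cp                              # heads per "rank"
parts = [attn(q[:, r * hpr:(r + 1) * hpr],     # each rank: full sequence,
              k[:, r * hpr:(r + 1) * hpr],     # disjoint head group
              v[:, r * hpr:(r + 1) * hpr]) for r in range(cp)]
out = np.concatenate(parts, axis=1)
assert np.allclose(out, ref)
```

A real test in the TE suite would instead call the CP entry point under `cp_comm_type="a2a"` with MLA-shaped k/v and compare against the non-CP attention path, for both fused and non-fused backends.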

@ptrendx ptrendx requested a review from cyanguwa April 3, 2026 01:11
@ptrendx ptrendx added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Apr 3, 2026
