Skip to content

[PyTorch] Add fp32_partial_output support for CP P2P ring attention#1

Open
vedaanta wants to merge 1 commit into
mainfrom
cp-fp32-partial-output
Open

[PyTorch] Add fp32_partial_output support for CP P2P ring attention#1
vedaanta wants to merge 1 commit into
mainfrom
cp-fp32-partial-output

Conversation

@vedaanta
Copy link
Copy Markdown
Owner

Summary

  • Adds a fp32_partial_output flag to AttnFuncWithCPAndKVP2P (CP ring P2P attention) that casts each per-step partial output to float32 before LSE-correction merging across CP ranks, improving numerical stability for fp16/bf16 inputs
  • Threads the flag through cp_p2p_fwd_fused_attnattn_forward_func_with_cpFusedAttentionBackend.fused_attentionDotProductAttention.forward
  • Uses a software cast (out_per_step.to(float32)) after the cuDNN kernel call; the fake_dtype=torch.float32 path is wired and ready for when the cuDNN kernel gains native fp16/bf16→fp32 output support
  • Fixes the THD format output-buffer dtype to match fp32 when fp32_partial_output=True
  • Updates AttnFuncWithCPAndKVP2P.backward return tuple for the new forward input

Test plan

  • test_cp_with_fused_attention_fp32_partial_output covers all combinations of:
    • dtype: bf16, fp16
    • model: MHA causal, MHA non-causal, GQA causal
    • qkv_format: sbhd, thd
    • cp_comm_type: p2p, a2a+p2p
  • All existing CP attention tests continue to pass (flag defaults to False, no behavior change)

🤖 Generated with Claude Code

Extends context-parallel ring P2P attention (AttnFuncWithCPAndKVP2P) with a
new `fp32_partial_output` flag that accumulates per-step partial attention
outputs in float32 before LSE-correction merging, improving numerical
stability across CP ranks for fp16/bf16 inputs.

Implementation uses a software cast (out_per_step.to(float32)) after the
cuDNN kernel call. The fake_dtype=torch.float32 path is wired and ready for
when the cuDNN kernel gains native fp16/bf16→fp32 output support.

Changes:
- context_parallel.py: add fp32_partial_output param to cp_p2p_fwd_fused_attn
  and AttnFuncWithCPAndKVP2P.forward; fix THD out-buffer dtype; update
  backward return tuple for new forward input count
- backends.py: thread fp32_partial_output through FusedAttentionBackend
- dot_product_attention.py: expose fp32_partial_output in DotProductAttention.forward
- tests: add test_cp_with_fused_attention_fp32_partial_output covering
  bf16/fp16, MHA/GQA, causal/non-causal, sbhd/thd, p2p/a2a+p2p

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@vedaanta vedaanta force-pushed the cp-fp32-partial-output branch from 20d3e05 to 877d459 Compare April 20, 2026 20:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant