Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 181 additions & 6 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,13 +895,188 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
if get_cudnn_version() >= (9, 3, 0):
logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
# cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
pad_between_seqs = False
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
# if get_cudnn_version() >= (9, 3, 0):
# logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
# # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
# pad_between_seqs = False
# test_dot_product_attention(
# dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
# )


def test_fused_attn_split_q():
    """Check that front-padded Q halves reproduce a single full attention pass.

    Causal attention is run three times over identical tightly packed THD
    Q/K/V data:
      1) Full reference: every Q token valid, padding_causal_bottom_right.
      2) Step 1: only the first half of each sequence's Q is valid (the rest
         is tail-padding), using padding_causal.
      3) Step 2: only the second half is valid (front-padding), using
         padding_causal_bottom_right.

    With top-left alignment (step 1), Q[i] attends to K[0..i], matching the
    first half of the reference run. With bottom-right alignment (step 2),
    Q[i] attends to K[0..i + s/2], matching the reference's second half.

    Forward check: step1 first halves + step2 second halves == full output.
    Backward check: dQ at valid positions matches the full run's dQ.
    """
    from transformer_engine.pytorch.constants import TE_DType

    dtype = torch.bfloat16
    batch_size = 2
    num_heads = 16
    head_dim = 64
    max_seqlen = 2048

    # Even lengths so every sequence splits cleanly into two equal halves.
    seqlens = torch.tensor([1504, 1826], dtype=torch.int32, device="cuda")
    half_seqlens = seqlens // 2  # [752, 913]
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    total_tokens = cu_seqlens[-1].item()  # 3330

    # Tightly packed Q, K, V in THD layout (no padding between sequences).
    torch.manual_seed(42)
    q_data = 0.1 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")
    k_data = 0.1 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")
    v_data = 0.1 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")

    # Incoming gradient for the backward passes.
    d_out = 0.001 * torch.randint(
        0, 200, (total_tokens, num_heads * head_dim), dtype=dtype, device="cuda"
    )

    backend = FusedAttnBackend["F16_arbitrary_seqlen"]

    # Keyword arguments shared by every forward call.
    common_kwargs = dict(
        max_seqlen_q=max_seqlen,
        max_seqlen_kv=max_seqlen,
        k=k_data,
        v=v_data,
        fake_dtype=dtype,
        fused_attention_backend=backend,
        cu_seqlens_kv=cu_seqlens,
        cu_seqlens_kv_padded=cu_seqlens,
        qkv_layout="thd_thd_thd",
        attn_bias_type="no_bias",
    )

    # Keyword arguments shared by every backward call.
    bwd_common_kwargs = dict(
        max_seqlen_q=max_seqlen,
        max_seqlen_kv=max_seqlen,
        k=k_data,
        v=v_data,
        fake_dtype=dtype,
        dqkv_dtype=TE_DType[dtype],
        fused_attention_backend=backend,
        cu_seqlens_kv=cu_seqlens,
        cu_seqlens_kv_padded=cu_seqlens,
        qkv_layout="thd_thd_thd",
        attn_bias_type="no_bias",
    )

    def _run(cu_q, cu_q_padded, mask_type):
        """Forward + backward with the given Q seqlens/offsets; return (out, dQ)."""
        out, aux, *_ = fused_attn_fwd(
            is_training=True,
            cu_seqlens_q=cu_q,
            q=q_data,
            cu_seqlens_q_padded=cu_q_padded,
            attn_mask_type=mask_type,
            **common_kwargs,
        )
        dq, *_ = fused_attn_bwd(
            cu_seqlens_q=cu_q,
            q=q_data,
            o=out,
            d_o=d_out,
            aux_ctx_tensors=aux,
            cu_seqlens_q_padded=cu_q_padded,
            attn_mask_type=mask_type,
            **bwd_common_kwargs,
        )
        return out, dq

    # --- Full run (reference): all Q tokens valid ---
    out_full, dq_full = _run(cu_seqlens, cu_seqlens, "padding_causal_bottom_right")

    # cu_seqlens_q for half-length queries: [0, 752, 1665]
    cu_seqlens_q_half = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q_half[1:] = torch.cumsum(half_seqlens, dim=0)

    # --- Step 1: first halves valid, second halves are tail-padding ---
    # padding_causal (top-left aligned): Q[i] attends to K[0..i].
    out_step1, dq_step1 = _run(cu_seqlens_q_half, cu_seqlens, "padding_causal")

    # --- Step 2: second halves valid, first halves are front-padding ---
    # padding_causal_bottom_right: Q[i] attends to K[0..i + seqlen_kv - seqlen_q].
    # Padded Q offsets point at the midpoint of each sequence.
    cu_seqlens_q_padded_step2 = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q_padded_step2[:-1] = cu_seqlens[:-1] + half_seqlens
    cu_seqlens_q_padded_step2[-1] = total_tokens

    out_step2, dq_step2 = _run(
        cu_seqlens_q_half, cu_seqlens_q_padded_step2, "padding_causal_bottom_right"
    )

    # Per-sequence (start, half-length, end) token offsets.
    bounds = [
        (cu_seqlens[i].item(), half_seqlens[i].item(), cu_seqlens[i + 1].item())
        for i in range(batch_size)
    ]

    fwd_tols = dict(atol=0, rtol=0)
    bwd_tols = dict(atol=1e-6, rtol=1e-4)  # bf16 backward has minor numerical variance

    # --- Compare forward: stitch step1 + step2 vs full ---
    for i, (s, h, e) in enumerate(bounds):
        diff1 = (out_full[s : s + h] - out_step1[s : s + h]).abs().max().item()
        diff2 = (out_full[s + h : e] - out_step2[s + h : e]).abs().max().item()
        logging.info(
            f"[test_fused_attn_split_q]: fwd seq {i}: "
            f"first_half max_diff={diff1}, second_half max_diff={diff2}"
        )
        torch.testing.assert_close(out_full[s : s + h], out_step1[s : s + h], **fwd_tols)
        torch.testing.assert_close(out_full[s + h : e], out_step2[s + h : e], **fwd_tols)

    # --- Compare backward: dQ at valid positions ---
    for i, (s, h, e) in enumerate(bounds):
        dq_diff1 = (dq_full[s : s + h] - dq_step1[s : s + h]).abs().max().item()
        dq_diff2 = (dq_full[s + h : e] - dq_step2[s + h : e]).abs().max().item()
        logging.info(
            f"[test_fused_attn_split_q]: bwd dQ seq {i}: "
            f"first_half max_diff={dq_diff1}, second_half max_diff={dq_diff2}"
        )
        torch.testing.assert_close(dq_full[s : s + h], dq_step1[s : s + h], **bwd_tols)
        torch.testing.assert_close(dq_full[s + h : e], dq_step2[s + h : e], **bwd_tols)


def _run_dot_product_attention(
Expand Down
154 changes: 154 additions & 0 deletions tests/pytorch/attention/test_thd_ag_cp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
Standalone test for THD AllGather-based Context Parallelism (forward + backward).

Run with:
torchrun --nproc_per_node=2 tests/pytorch/attention/test_thd_ag_cp.py
"""

import os
import sys
import logging
import torch
import torch.distributed as dist

# Force fused attention backend
# NOTE(review): these env vars are set before the transformer_engine imports
# below — presumably backend selection reads them at import/first use; confirm.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"

import transformer_engine.pytorch as te
import transformer_engine_torch as tex

# Per-rank logger; messages below are prefixed with the rank explicitly.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


def run_test():
    """Compare THD AllGather context-parallel attention with a 1-GPU reference.

    Every rank seeds identically, builds the same global Q/K/V/d_out, runs a
    full-sequence reference pass, then reruns attention with context
    parallelism ("all_gather") on its own token partition. Outputs and
    gradients of the partition must match the reference slice within bf16
    tolerances. Exits with status 1 on any mismatch.
    """
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    assert world_size == 2, "This test requires exactly 2 GPUs"

    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)

    cp_group = dist.new_group(range(world_size), backend="nccl")
    cp_stream = torch.cuda.Stream()

    # Test configuration.
    batch_size = 3
    num_heads = 16
    head_dim = 64
    dtype = torch.bfloat16
    atol, rtol = 2.5e-2, 2.5e-2

    # Every sequence length must split evenly into 2*world_size=4 chunks.
    seqlens = torch.tensor([256, 512, 1024], dtype=torch.int32)
    assert all(s % (2 * world_size) == 0 for s in seqlens)

    # Cumulative sequence offsets; no padding between sequences, so the
    # padded offsets are identical.
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32)
    cu_seqlens[1:] = seqlens.cumsum(0)
    cu_seqlens = cu_seqlens.cuda()
    cu_seqlens_padded = cu_seqlens.clone()
    total_tokens = cu_seqlens[-1].item()

    # Same seed on every rank -> identical global tensors everywhere.
    torch.manual_seed(42)
    q_global = 0.02 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")
    k_global = 0.02 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")
    v_global = 0.02 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")
    dout_global = 0.02 * torch.randn(total_tokens, num_heads * head_dim, dtype=dtype, device="cuda")

    # ============ Run without CP (single-GPU reference) ============
    log.info(f"[Rank {rank}] Running without CP (reference)")
    core_attn = te.DotProductAttention(
        num_heads,
        head_dim,
        num_gqa_groups=num_heads,
        attention_dropout=0.0,
        qkv_format="thd",
        attn_mask_type="padding_causal",
    ).cuda()

    # Sequence-length arguments are identical for the reference and CP calls.
    seqlen_kwargs = dict(
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        cu_seqlens_q_padded=cu_seqlens_padded,
        cu_seqlens_kv_padded=cu_seqlens_padded,
    )

    q_ref = q_global.clone().requires_grad_()
    k_ref = k_global.clone().requires_grad_()
    v_ref = v_global.clone().requires_grad_()

    out_ref = core_attn(q_ref, k_ref, v_ref, **seqlen_kwargs)
    out_ref.backward(dout_global.clone())
    dq_ref, dk_ref, dv_ref = q_ref.grad, k_ref.grad, v_ref.grad
    log.info(f"[Rank {rank}] Reference out shape: {out_ref.shape}")

    # ============ Run with CP (AllGather) ============
    log.info(f"[Rank {rank}] Running with CP (all_gather)")

    # Token indices owned by this CP rank; self-attention, so Q and KV
    # share the same partitioning.
    seq_idx = tex.thd_get_partitioned_indices(cu_seqlens_padded, total_tokens, world_size, rank)
    seq_idx_kv = seq_idx

    q_cp = q_global.index_select(0, seq_idx).contiguous().requires_grad_()
    k_cp = k_global.index_select(0, seq_idx_kv).contiguous().requires_grad_()
    v_cp = v_global.index_select(0, seq_idx_kv).contiguous().requires_grad_()
    dout_cp = dout_global.index_select(0, seq_idx).contiguous()

    # Switch the module into context-parallel mode.
    core_attn.set_context_parallel_group(
        cp_group,
        list(range(world_size)),
        cp_stream,
        "all_gather",
    )

    out_cp = core_attn(q_cp, k_cp, v_cp, **seqlen_kwargs)
    out_cp.backward(dout_cp)
    dq_cp, dk_cp, dv_cp = q_cp.grad, k_cp.grad, v_cp.grad
    log.info(f"[Rank {rank}] CP out shape: {out_cp.shape}")

    # ============ Compare ============
    # Slice this rank's partition out of the reference results.
    out_ref_part = out_ref.detach().index_select(0, seq_idx).contiguous()
    dq_ref_part = dq_ref.index_select(0, seq_idx).contiguous()
    dk_ref_part = dk_ref.index_select(0, seq_idx_kv).contiguous()
    dv_ref_part = dv_ref.index_select(0, seq_idx_kv).contiguous()

    comparisons = (
        ("out", out_ref_part, out_cp.detach()),
        ("dq", dq_ref_part, dq_cp),
        ("dk", dk_ref_part, dk_cp),
        ("dv", dv_ref_part, dv_cp),
    )
    passed = True
    for name, ref, cp in comparisons:
        max_diff = (ref - cp).abs().max().item()
        log.info(f"[Rank {rank}] {name}: max_diff = {max_diff}")
        try:
            torch.testing.assert_close(ref, cp, atol=atol, rtol=rtol)
        except AssertionError as e:
            # Record the failure but keep comparing the remaining tensors.
            log.error(f"[Rank {rank}] {name}: FAILED - {e}")
            passed = False
        else:
            log.info(f"[Rank {rank}] {name}: PASSED")

    dist.destroy_process_group()
    if not passed:
        log.error(f"[Rank {rank}] TEST FAILED")
        sys.exit(1)
    log.info(f"[Rank {rank}] ALL TESTS PASSED")


# Entry point: launched once per rank via torchrun (see module docstring).
if __name__ == "__main__":
    run_test()
Loading