diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 2eb307aa48..6ff5d140a3 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -895,13 +895,188 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs ) - if get_cudnn_version() >= (9, 3, 0): - logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False") - # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run - pad_between_seqs = False - test_dot_product_attention( - dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs + # if get_cudnn_version() >= (9, 3, 0): + # logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False") + # # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run + # pad_between_seqs = False + # test_dot_product_attention( + # dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs + # ) + + +def test_fused_attn_split_q(): + """Verify front-padding by splitting Q into two halves and running attention in two steps. + + Uses causal attention. Three runs on the same Q/K/V data (tightly packed THD): + 1) Full: all Q tokens valid, padding_causal_bottom_right → reference output + 2) Step 1: first half of each seq's Q valid (tail-padding), padding_causal + 3) Step 2: second half valid (front-padding), padding_causal_bottom_right + + Step 1 uses padding_causal (top-left aligned) so Q[i] attends to K[0..i], + matching the full run's first half. + Step 2 uses padding_causal_bottom_right so Q[i] attends to K[0..i+s/2], + matching the full run's second half. + + Forward: step1 first-halves + step2 second-halves == full output. + Backward: dQ from step1/step2 at valid positions == dQ from full run. + """ + from transformer_engine.pytorch.constants import TE_DType + + dtype = torch.bfloat16 + batch_size = 2 + num_heads = 16 + head_dim = 64 + max_seqlen = 2048 + + # Even sequence lengths so they split cleanly + seqlens = torch.tensor([1504, 1826], dtype=torch.int32, device="cuda") + half_seqlens = seqlens // 2 # [752, 913] + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) + total_tokens = cu_seqlens[-1].item() # 3330 + + # Create tightly packed Q, K, V (no padding between sequences) + torch.manual_seed(42) + q_data = 0.1 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda") + k_data = 0.1 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda") + v_data = 0.1 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda") + + # Create output gradient for backward + d_out = 0.001 * torch.randint( + 0, 200, (total_tokens, num_heads * head_dim), dtype=dtype, device="cuda" + ) + + backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + common_kwargs = dict( + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + k=k_data, + v=v_data, + fake_dtype=dtype, + fused_attention_backend=backend, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_kv_padded=cu_seqlens, + qkv_layout="thd_thd_thd", + attn_bias_type="no_bias", + ) + + bwd_common_kwargs = dict( + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + k=k_data, + v=v_data, + fake_dtype=dtype, + dqkv_dtype=TE_DType[dtype], + fused_attention_backend=backend, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_kv_padded=cu_seqlens, + qkv_layout="thd_thd_thd", + attn_bias_type="no_bias", + ) + + # --- Full run (reference): all Q tokens valid --- + out_full, aux_full, *_ = fused_attn_fwd( + is_training=True, + cu_seqlens_q=cu_seqlens, + q=q_data, + cu_seqlens_q_padded=cu_seqlens, + attn_mask_type="padding_causal_bottom_right", + **common_kwargs, + ) + dq_full, dk_full, dv_full, *_ = fused_attn_bwd( + cu_seqlens_q=cu_seqlens, + q=q_data, + o=out_full, + d_o=d_out, + aux_ctx_tensors=aux_full, + cu_seqlens_q_padded=cu_seqlens, + attn_mask_type="padding_causal_bottom_right", + **bwd_common_kwargs, + ) + + # cu_seqlens_q for half-length queries: [0, 752, 1665] + cu_seqlens_q_half = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q_half[1:] = torch.cumsum(half_seqlens, dim=0) + + # --- Step 1: first half of each seq's Q valid, second half is tail-padding --- + # padding_causal (top-left aligned): Q[i] attends to K[0..i] + out_step1, aux_step1, *_ = fused_attn_fwd( + is_training=True, + cu_seqlens_q=cu_seqlens_q_half, + q=q_data, + cu_seqlens_q_padded=cu_seqlens, + attn_mask_type="padding_causal", + **common_kwargs, + ) + dq_step1, dk_step1, dv_step1, *_ = fused_attn_bwd( + cu_seqlens_q=cu_seqlens_q_half, + q=q_data, + o=out_step1, + d_o=d_out, + aux_ctx_tensors=aux_step1, + cu_seqlens_q_padded=cu_seqlens, + attn_mask_type="padding_causal", + **bwd_common_kwargs, + ) + + # --- Step 2: second half of each seq's Q valid, first half is front-padding --- + # padding_causal_bottom_right: Q[i] attends to K[0..i + seqlen_kv - seqlen_q] + cu_seqlens_q_padded_step2 = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") + for i in range(batch_size): + cu_seqlens_q_padded_step2[i] = cu_seqlens[i] + half_seqlens[i] + cu_seqlens_q_padded_step2[batch_size] = total_tokens + + out_step2, aux_step2, *_ = fused_attn_fwd( + is_training=True, + cu_seqlens_q=cu_seqlens_q_half, + q=q_data, + cu_seqlens_q_padded=cu_seqlens_q_padded_step2, + attn_mask_type="padding_causal_bottom_right", + **common_kwargs, + ) + dq_step2, dk_step2, dv_step2, *_ = fused_attn_bwd( + cu_seqlens_q=cu_seqlens_q_half, + q=q_data, + o=out_step2, + d_o=d_out, + aux_ctx_tensors=aux_step2, + cu_seqlens_q_padded=cu_seqlens_q_padded_step2, + attn_mask_type="padding_causal_bottom_right", + **bwd_common_kwargs, + ) + + # --- Compare forward: stitch step1 + step2 vs full --- + fwd_tols = dict(atol=0, rtol=0) + bwd_tols = dict(atol=1e-6, rtol=1e-4) # bf16 backward has minor numerical variance + for i in range(batch_size): + s = cu_seqlens[i].item() + h = half_seqlens[i].item() + e = cu_seqlens[i + 1].item() + + diff1 = (out_full[s : s + h] - out_step1[s : s + h]).abs().max().item() + diff2 = (out_full[s + h : e] - out_step2[s + h : e]).abs().max().item() + logging.info( + f"[test_fused_attn_split_q]: fwd seq {i}: " + f"first_half max_diff={diff1}, second_half max_diff={diff2}" + ) + torch.testing.assert_close(out_full[s : s + h], out_step1[s : s + h], **fwd_tols) + torch.testing.assert_close(out_full[s + h : e], out_step2[s + h : e], **fwd_tols) + + # --- Compare backward: dQ at valid positions --- + for i in range(batch_size): + s = cu_seqlens[i].item() + h = half_seqlens[i].item() + e = cu_seqlens[i + 1].item() + + dq_diff1 = (dq_full[s : s + h] - dq_step1[s : s + h]).abs().max().item() + dq_diff2 = (dq_full[s + h : e] - dq_step2[s + h : e]).abs().max().item() + logging.info( + f"[test_fused_attn_split_q]: bwd dQ seq {i}: " + f"first_half max_diff={dq_diff1}, second_half max_diff={dq_diff2}" ) + torch.testing.assert_close(dq_full[s : s + h], dq_step1[s : s + h], **bwd_tols) + torch.testing.assert_close(dq_full[s + h : e], dq_step2[s + h : e], **bwd_tols) def _run_dot_product_attention( diff --git a/tests/pytorch/attention/test_thd_ag_cp.py b/tests/pytorch/attention/test_thd_ag_cp.py new file mode 100644 index 0000000000..4738f7b181 --- /dev/null +++ b/tests/pytorch/attention/test_thd_ag_cp.py @@ -0,0 +1,154 @@ +""" +Standalone test for THD AllGather-based Context Parallelism (forward + backward). + +Run with: + torchrun --nproc_per_node=2 tests/pytorch/attention/test_thd_ag_cp.py +""" + +import os +import sys +import logging +import torch +import torch.distributed as dist + +# Force fused attention backend +os.environ["NVTE_FLASH_ATTN"] = "0" +os.environ["NVTE_FUSED_ATTN"] = "1" + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger(__name__) + + +def run_test(): + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + assert world_size == 2, "This test requires exactly 2 GPUs" + + device_count = torch.cuda.device_count() + torch.cuda.set_device(rank % device_count) + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + + cp_group = dist.new_group(range(world_size), backend="nccl") + cp_stream = torch.cuda.Stream() + + # Config + batch_size = 3 + num_heads = 16 + head_dim = 64 + dtype = torch.bfloat16 + atol, rtol = 2.5e-2, 2.5e-2 + + # Sequence lengths must be divisible by 2*world_size=4 + seqlens = torch.tensor([256, 512, 1024], dtype=torch.int32) + assert all(s % (2 * world_size) == 0 for s in seqlens) + + # Build cu_seqlens (no padding between seqs) + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32) + cu_seqlens[1:] = seqlens.cumsum(0) + cu_seqlens = cu_seqlens.cuda() + cu_seqlens_padded = cu_seqlens.clone() + total_tokens = cu_seqlens[-1].item() + + # Create global Q/K/V data (same on all ranks via same seed) + torch.manual_seed(42) + q_global = 0.02 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda") + k_global = 0.02 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda") + v_global = 0.02 * torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda") + dout_global = 0.02 * torch.randn(total_tokens, num_heads * head_dim, dtype=dtype, device="cuda") + + # ============ Run without CP (single-GPU reference) ============ + log.info(f"[Rank {rank}] Running without CP (reference)") + core_attn = te.DotProductAttention( + num_heads, + head_dim, + num_gqa_groups=num_heads, + attention_dropout=0.0, + qkv_format="thd", + attn_mask_type="padding_causal", + ).cuda() + + q_ref = q_global.clone().requires_grad_() + k_ref = k_global.clone().requires_grad_() + v_ref = v_global.clone().requires_grad_() + + out_ref = core_attn( + q_ref, + k_ref, + v_ref, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + out_ref.backward(dout_global.clone()) + dq_ref, dk_ref, dv_ref = q_ref.grad, k_ref.grad, v_ref.grad + log.info(f"[Rank {rank}] Reference out shape: {out_ref.shape}") + + # ============ Run with CP (AllGather) ============ + log.info(f"[Rank {rank}] Running with CP (all_gather)") + + # Partition Q/K/V for this CP rank + seq_idx = tex.thd_get_partitioned_indices(cu_seqlens_padded, total_tokens, world_size, rank) + seq_idx_kv = seq_idx # same since self-attention + + q_cp = q_global.index_select(0, seq_idx).contiguous().requires_grad_() + k_cp = k_global.index_select(0, seq_idx_kv).contiguous().requires_grad_() + v_cp = v_global.index_select(0, seq_idx_kv).contiguous().requires_grad_() + dout_cp = dout_global.index_select(0, seq_idx).contiguous() + + # Set up CP group + core_attn.set_context_parallel_group( + cp_group, + list(range(world_size)), + cp_stream, + "all_gather", + ) + + out_cp = core_attn( + q_cp, + k_cp, + v_cp, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + out_cp.backward(dout_cp) + dq_cp, dk_cp, dv_cp = q_cp.grad, k_cp.grad, v_cp.grad + log.info(f"[Rank {rank}] CP out shape: {out_cp.shape}") + + # ============ Compare ============ + # Extract reference outputs for this rank's partition + out_ref_part = out_ref.detach().index_select(0, seq_idx).contiguous() + dq_ref_part = dq_ref.index_select(0, seq_idx).contiguous() + dk_ref_part = dk_ref.index_select(0, seq_idx_kv).contiguous() + dv_ref_part = dv_ref.index_select(0, seq_idx_kv).contiguous() + + passed = True + for name, ref, cp in [ + ("out", out_ref_part, out_cp.detach()), + ("dq", dq_ref_part, dq_cp), + ("dk", dk_ref_part, dk_cp), + ("dv", dv_ref_part, dv_cp), + ]: + max_diff = (ref - cp).abs().max().item() + log.info(f"[Rank {rank}] {name}: max_diff = {max_diff}") + try: + torch.testing.assert_close(ref, cp, atol=atol, rtol=rtol) + log.info(f"[Rank {rank}] {name}: PASSED") + except AssertionError as e: + log.error(f"[Rank {rank}] {name}: FAILED - {e}") + passed = False + + dist.destroy_process_group() + if not passed: + log.error(f"[Rank {rank}] TEST FAILED") + sys.exit(1) + log.info(f"[Rank {rank}] ALL TESTS PASSED") + + +if __name__ == "__main__": + run_test() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 64cccaac6e..f8a3733c41 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2834,9 +2834,14 @@ def forward( causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type - assert not padding, f"{attn_mask_type} mask type is not supported!" + if qkv_format == "thd": + # THD always uses padding mask types; per-step masks set internally + assert padding, f"THD format requires padding mask type, got {attn_mask_type}!" + else: + assert not padding, f"{attn_mask_type} mask type is not supported!" if use_fused_attention and causal and "bottom_right" not in attn_mask_type: - attn_mask_type = attn_mask_type + "_bottom_right" + if qkv_format != "thd": + attn_mask_type = attn_mask_type + "_bottom_right" assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" assert ( @@ -2874,14 +2879,18 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - seq_dim = qkv_format.index("s") - assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 - ), "Sequence length per GPU needs to be divisible by 2!" + if qkv_format == "thd": + # Save original cu_seqlens before division (needed for KV reordering) + cu_seqlens_q_original = cu_seqlens_q.clone() + else: + seq_dim = qkv_format.index("s") + assert ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" + # Divide by 2*cp_size to get per-chunk values max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) if use_fused_attention or qkv_format == "thd": @@ -2891,24 +2900,36 @@ def forward( else: cu_seqlens_q_padded = None - # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] - q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] - k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + if qkv_format != "thd": + # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] + q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] + k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] - # [s, b, h, d] -> [cp, s, b, h, d] + # AllGather K/V across CP ranks k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] - k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) - v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) - k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) - v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - k_ag = k_ag.view(-1, *k.shape[1:]) - v_ag = v_ag.view(-1, *v.shape[1:]) + if qkv_format == "thd": + # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d] + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + k_ag, cu_seqlens_q_original, chunk_ids_for_kv_ag, cp_size + ) + v_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + v_ag, cu_seqlens_q_original, chunk_ids_for_kv_ag, cp_size + ) + else: + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + cp_stream.wait_stream(torch.cuda.current_stream()) # create two streams to resolve wave quantization issue of Flash Attn in each step @@ -2925,35 +2946,82 @@ def forward( max_logit_per_step = [None, None] max_logit = None + # Pre-compute THD-specific per-step cu_seqlens + if qkv_format == "thd": + cu_seqlens_kv_full = cu_seqlens_q * (2 * cp_size) + # Rank-level padded offsets (2 chunks per sequence on this rank) + cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2 + chunk_sizes_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + # Step 0: first chunk valid, starts at allocation start + thd_cu_seqlens_q_padded_per_step = [cu_seqlens_q_padded_rank, None] + # Step 1: second chunk valid, starts at midpoint (front-padded) + thd_cu_seqlens_q_padded_per_step[1] = cu_seqlens_q_padded_rank.clone() + thd_cu_seqlens_q_padded_per_step[1][:-1] += chunk_sizes_q + if causal: + # Q is always the last chunk in the visible KV range, + # so bottom_right alignment is always correct. + # (When seqlen_q == seqlen_kv, bottom_right == top-left.) + thd_attn_mask_type_per_step = [ + "padding_causal_bottom_right", + "padding_causal_bottom_right", + ] + else: + thd_attn_mask_type_per_step = ["padding", "padding"] + for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() - kv_seq_range_per_step[i], window_size_per_step[i] = ( - get_kv_seq_info_after_all_gather( - local_seq_chunk_ids[i], - cp_size, - max_seqlen_q, - max_seqlen_kv, - window_size, - causal, + if qkv_format in ["bshd", "sbhd"]: + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_ = q.select(seq_dim, i).contiguous() + kv_seq_range_per_step[i], window_size_per_step[i] = ( + get_kv_seq_info_after_all_gather( + local_seq_chunk_ids[i], + cp_size, + max_seqlen_q, + max_seqlen_kv, + window_size, + causal, + ) ) - ) - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i][0], - kv_seq_range_per_step[i][1], - ) - max_seqlen_kv_ = seq_end_idx - seq_start_idx - if use_fused_attention or qkv_format == "thd": - cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( - k.shape[1], max_seqlen_kv_, k.device + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], ) - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + max_seqlen_kv_ = seq_end_idx - seq_start_idx + if use_fused_attention: + cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( + k.shape[1], max_seqlen_kv_, k.device + ) + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + elif qkv_format == "thd": + # THD: no tensor slicing — use cu_seqlens to select valid + # tokens per step. Q stays [t_rank, h, d], K/V stay [t_full, h, d]. + q_ = q + k_ = k_ag + v_ = v_ag + chunk_id = local_seq_chunk_ids[i] + max_seqlen_kv_ = max_seqlen_kv * (chunk_id + 1) + # KV visible range: first (chunk_id+1) chunks per sequence + cu_seqlens_kv_per_step[i] = cu_seqlens_q * (chunk_id + 1) + # Window size + if window_size is None: + window_size_per_step[i] = (-1, 0) if causal else (-1, -1) + else: + window_size_per_step[i] = window_size if use_fused_attention: + # Set per-step parameters for THD vs bshd/sbhd + if qkv_format == "thd": + attn_mask_type_ = thd_attn_mask_type_per_step[i] + cu_seqlens_q_padded_ = thd_cu_seqlens_q_padded_per_step[i] + cu_seqlens_kv_padded_ = cu_seqlens_kv_full + else: + attn_mask_type_ = attn_mask_type + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_per_step[i] ( out_per_step[i], [softmax_lse_per_step[i], rng_states[i]], @@ -2972,11 +3040,11 @@ def forward( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, + attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, attn_bias=attn_bias, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, window_size=window_size_per_step[i], return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), @@ -3025,6 +3093,17 @@ def forward( out[:, i - 1].copy_(out_per_step[i - 1]) elif qkv_format == "sbhd": out[i - 1].copy_(out_per_step[i - 1]) + elif qkv_format == "thd": + # Copy valid token ranges from this step's output + # Step 0 wrote at first-chunk positions, + # Step 1 wrote at second-chunk positions + step_padded = thd_cu_seqlens_q_padded_per_step[i - 1] + batch_size = cu_seqlens_q.shape[0] - 1 + for b in range(batch_size): + s = step_padded[b].item() + sz = chunk_sizes_q[b].item() + out[s : s + sz].copy_(out_per_step[i - 1][s : s + sz]) + if return_max_logit: max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) @@ -3034,7 +3113,9 @@ def forward( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group ) - if use_fused_attention: + if qkv_format == "thd": + pass # out is already [t_rank, h, d], no reshape needed + elif use_fused_attention: if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) elif qkv_format == "sbhd": @@ -3068,6 +3149,9 @@ def forward( ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 + if qkv_format == "thd": + ctx.thd_attn_mask_type_per_step = thd_attn_mask_type_per_step + ctx.max_seqlen_kv = max_seqlen_kv nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if return_max_logit: return out, max_logit @@ -3089,7 +3173,8 @@ def backward(ctx, dout, *_args): kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step - seq_dim = ctx.qkv_format.index("s") + if ctx.qkv_format != "thd": + seq_dim = ctx.qkv_format.index("s") qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format dout = dout.view(q.shape) @@ -3105,19 +3190,36 @@ def backward(ctx, dout, *_args): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() - # [s, b, h, d] -> [cp, s, b, h, d] + # AllGather K/V across CP ranks k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] - k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) - v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) - k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) - v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - k_ag = k_ag.view(-1, *k.shape[1:]) - v_ag = v_ag.view(-1, *v.shape[1:]) + if ctx.qkv_format == "thd": + # [cp*t, h, d] -> reorder to contiguous per-sequence order + cu_seqlens_kv_full = cu_seqlens_q * (2 * cp_size) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + k_ag, cu_seqlens_kv_full, chunk_ids_for_kv_ag, cp_size + ) + v_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + v_ag, cu_seqlens_kv_full, chunk_ids_for_kv_ag, cp_size + ) + # Pre-compute THD per-step values (same as forward) + cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2 + chunk_sizes_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + thd_cu_seqlens_q_padded_per_step = [cu_seqlens_q_padded_rank, None] + thd_cu_seqlens_q_padded_per_step[1] = cu_seqlens_q_padded_rank.clone() + thd_cu_seqlens_q_padded_per_step[1][:-1] += chunk_sizes_q + else: + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] @@ -3156,20 +3258,39 @@ def backward(ctx, dout, *_args): for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i][0], - kv_seq_range_per_step[i][1], - ) - max_seqlen_kv = seq_end_idx - seq_start_idx - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] - out_ = out_per_step[i] - dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) + if ctx.qkv_format == "thd": + # THD: use full tensors + cu_seqlens (mirrors forward) + q_ = q + k_ = k_ag + v_ = v_ag + chunk_id = local_seq_chunk_ids[i] + max_seqlen_kv = ctx.max_seqlen_kv * (chunk_id + 1) + out_ = out_per_step[i] + dout_ = dout + else: + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_ = q.select(seq_dim, i).contiguous() + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv = seq_end_idx - seq_start_idx + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + out_ = out_per_step[i] + dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: + # Set per-step parameters for THD + if ctx.qkv_format == "thd": + attn_mask_type_ = ctx.thd_attn_mask_type_per_step[i] + cu_seqlens_q_padded_ = thd_cu_seqlens_q_padded_per_step[i] + cu_seqlens_kv_padded_ = cu_seqlens_kv_full + else: + attn_mask_type_ = ctx.attn_mask_type + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_per_step[i] aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, @@ -3185,12 +3306,12 @@ def backward(ctx, dout, *_args): TE_DType[dout.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, - attn_mask_type=ctx.attn_mask_type, + attn_mask_type=attn_mask_type_, attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], deterministic=ctx.deterministic, @@ -3236,44 +3357,69 @@ def backward(ctx, dout, *_args): if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - if ctx.qkv_format == "bshd": - dq[:, i - 1].copy_(dq_per_step[i - 1]) - elif ctx.qkv_format == "sbhd": - dq[i - 1].copy_(dq_per_step[i - 1]) - # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] - dk_per_step[i - 1], dv_per_step[i - 1] = [ - x.movedim(seq_dim, 0).contiguous() - for x in [dk_per_step[i - 1], dv_per_step[i - 1]] - ] - # wait until dkv update of last step is done - if i > 1: - flash_attn_streams[i - 1].wait_event(dkv_update_done) - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i - 1][0], - kv_seq_range_per_step[i - 1][1], - ) - dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) - dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) - if i < len(local_seq_chunk_ids): - flash_attn_streams[i - 1].record_event(dkv_update_done) + if ctx.qkv_format == "thd": + # Copy valid dQ ranges (same positions as forward output) + step_padded = thd_cu_seqlens_q_padded_per_step[i - 1] + batch_size = cu_seqlens_q.shape[0] - 1 + for b in range(batch_size): + s = step_padded[b].item() + sz = chunk_sizes_q[b].item() + dq[s : s + sz].copy_(dq_per_step[i - 1][s : s + sz]) + # dK/dV: add full tensor (kernel zeros non-valid positions) + if i > 1: + flash_attn_streams[i - 1].wait_event(dkv_update_done) + dk.add_(dk_per_step[i - 1]) + dv.add_(dv_per_step[i - 1]) + if i < len(local_seq_chunk_ids): + flash_attn_streams[i - 1].record_event(dkv_update_done) + else: + if ctx.qkv_format == "bshd": + dq[:, i - 1].copy_(dq_per_step[i - 1]) + elif ctx.qkv_format == "sbhd": + dq[i - 1].copy_(dq_per_step[i - 1]) + # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] + dk_per_step[i - 1], dv_per_step[i - 1] = [ + x.movedim(seq_dim, 0).contiguous() + for x in [dk_per_step[i - 1], dv_per_step[i - 1]] + ] + # wait until dkv update of last step is done + if i > 1: + flash_attn_streams[i - 1].wait_event(dkv_update_done) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i - 1][0], + kv_seq_range_per_step[i - 1][1], + ) + dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) + dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) + if i < len(local_seq_chunk_ids): + flash_attn_streams[i - 1].record_event(dkv_update_done) torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] - dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) - dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) - dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) - dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - dk = dk.view(-1, *dk.shape[-3:]) - dv = dv.view(-1, *dv.shape[-3:]) - dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) - dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) - - dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) - dk = dk.movedim(0, seq_dim).contiguous() - dv = dv.movedim(0, seq_dim).contiguous() + if ctx.qkv_format == "thd": + # Reverse-reorder dK/dV from contiguous order back to dual-chunk order, + # then reduce-scatter across CP ranks + dk = reorder_seq_chunks_before_a2a_after_attn_thd(dk, cu_seqlens_kv_full, cp_size) + dv = reorder_seq_chunks_before_a2a_after_attn_thd(dv, cu_seqlens_kv_full, cp_size) + dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) + dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) + # dQ is already [t_rank, h, d], no reshape needed + else: + # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] + dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) + dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) + dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) + dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + dk = dk.view(-1, *dk.shape[-3:]) + dv = dv.view(-1, *dv.shape[-3:]) + dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) + dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) + + dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) + dk = dk.movedim(0, seq_dim).contiguous() + dv = dv.movedim(0, seq_dim).contiguous() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 7653296c78..be6cc6bc84 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -317,7 +317,6 @@ def fused_attn_fwd( raise ValueError(f"Unsupported backend {fused_attention_backend}") # execute kernel - output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv,