diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 8dfea644a5..9f6b4944e6 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. +import copy import os import sys import logging @@ -29,6 +30,15 @@ ) from utils import ModelConfig, compare_and_assert +# Pool mode (NVTE_CP_POOL_PG=1) only: shared CP collective groups, created once +# per pool by run_attention_with_cp_pool.main() and reused across every case in +# that pool. world_size and the rank set don't change per case, so re-creating +# these per call would be wasted NCCL setup (~50-100 ms each). Single-shot +# subprocess mode leaves these None / [] and run_dpa_with_cp creates/destroys +# its own groups inline. +_pool_cp_comm_group = None +_pool_cp_comm_sub_groups: list = [] + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -209,10 +219,13 @@ def run_dpa_with_cp( os.environ["NVTE_FUSED_ATTN"] = "0" if kernel_backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" - config = model_configs_flash_attn[model] + # Deep-copy: the module-level dict is shared across pool cases; the + # THD branch below rewrites attn_mask_type in place, which would + # otherwise leak into subsequent cases reusing the same model key. + config = copy.deepcopy(model_configs_flash_attn[model]) if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - config = model_configs_fused_attn[model] + config = copy.deepcopy(model_configs_fused_attn[model]) assert config.attn_mask_type in [ "causal", "no_mask", @@ -226,6 +239,9 @@ def run_dpa_with_cp( # set up distributed group rank = int(os.getenv("RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) + # When NVTE_CP_POOL_PG=1, the pool runner owns the lifecycle of the main + # process group across many cases; here we only reuse it. + _pool_managed_pg = os.getenv("NVTE_CP_POOL_PG", "0") == "1" if dist.is_initialized(): world_size = dist.get_world_size() rank = dist.get_rank() @@ -234,25 +250,35 @@ def run_dpa_with_cp( device = rank % device_count torch.cuda.set_device(device) logging.info(f"[Rank {rank}] Setup: world_size {world_size}") - dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) - - # set up communication group for CP + if not _pool_managed_pg: + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + + # Set up communication group for CP. In pool mode, the pool worker has + # already pre-created world-scoped and a2a+p2p sub-groups once and stashed + # them in module-level pointers; we reuse those and the pool destroys them + # at shutdown. In single-shot mode we create them per call and destroy in + # the finally below. cp_comm_ranks = range(world_size) assert rank in cp_comm_ranks - cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - if cp_comm_type == "a2a+p2p": - assert world_size % 2 == 0, ( - "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size" - " = 2." - ) - cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] - cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] - cp_comm_sub_groups = [] - for sub_ranks in cp_comm_sub_ranks: - sub_group = dist.new_group(sub_ranks, backend="nccl") - if rank in sub_ranks: - cp_comm_sub_groups.append(sub_group) - + _reusing_pool_groups = _pool_managed_pg and _pool_cp_comm_group is not None + cp_comm_group = None + cp_comm_sub_groups: list = [] + if _reusing_pool_groups: + cp_comm_group = _pool_cp_comm_group + cp_comm_sub_groups = _pool_cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else [] + else: + cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if cp_comm_type == "a2a+p2p": + assert world_size % 2 == 0, ( + "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has" + " cp_size = 2." + ) + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + cp_comm_sub_groups.append(sub_group) if dtype == "fp8": if scaling_mode == "delayed": fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) @@ -564,7 +590,10 @@ def run_dpa_with_cp( seq_kv_size = dbias.shape[-1] # Reshape to split seq_q dimension dbias = dbias.view( - *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + *shape_before_seq, + 2 * world_size, + seq_q_size // (2 * world_size), + seq_kv_size, ) # Index select on the newly created dimension (now at position seq_q_dim) dbias = dbias.index_select(seq_q_dim, seq_idx) @@ -754,7 +783,14 @@ def run_dpa_with_cp( ) elif qkv_format == "thd": compare_and_assert( - t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + t, + tensors_cp[i], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, ) else: compare_and_assert( @@ -762,8 +798,28 @@ def run_dpa_with_cp( ) logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") - # destroy distribution group - dist.destroy_process_group() + # Teardown on the success path. Pool mode: cp_comm_group / cp_comm_sub_groups + # point at pool-shared groups owned by the pool runner (which destroys them + # at pool shutdown), and the main PG is also pool-owned — both branches + # below are no-ops. Single-shot mode: destroy what we created here. If the + # body above raises, we skip this — the subprocess dies at function return + # and NCCL releases the communicators with the process. + if not _reusing_pool_groups: + if cp_comm_group is not None: + try: + dist.destroy_process_group(cp_comm_group) + except Exception: + pass + for g in cp_comm_sub_groups: + try: + dist.destroy_process_group(g) + except Exception: + pass + if not _pool_managed_pg: + try: + dist.destroy_process_group() + except Exception: + pass def main(**kwargs): diff --git a/tests/pytorch/attention/run_attention_with_cp_pool.py b/tests/pytorch/attention/run_attention_with_cp_pool.py new file mode 100644 index 0000000000..3e5f64a429 --- /dev/null +++ b/tests/pytorch/attention/run_attention_with_cp_pool.py @@ -0,0 +1,221 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Persistent worker for batched CP attention tests. + +Launched ONCE per (pytest session, world_size) by torchrun. All ranks init +NCCL, then enter a dispatch loop: + + rank 0: + read one JSON request line from stdin + broadcast it to all ranks + all ranks: + call run_dpa_with_cp(**kwargs) — the same work function the + per-case subprocess design uses, with NVTE_CP_POOL_PG=1 so the + function reuses our PG instead of re-initing it + torch.cuda.empty_cache() per case + all ranks gather (ok, error_msg) to rank 0 + rank 0: + write one JSON response line to stdout + +Protocol (line-delimited JSON over rank-0 stdio): + request : {"op": "run", "kwargs": {...}} + {"op": "shutdown"} + response: {"ok": true} + {"ok": false, "error": "first failing rank's traceback"} +""" +import json +import os +import sys +import time +import traceback + +import torch +import torch.distributed as dist + +# Make sibling modules importable when launched directly. +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from run_attention_with_cp import run_dpa_with_cp +from transformer_engine.pytorch.quantization import FP8GlobalStateManager + + +def _recv_request(rank: int) -> dict: + box = [None] + if rank == 0: + line = sys.stdin.readline() + box[0] = {"op": "shutdown"} if not line else json.loads(line) + dist.broadcast_object_list(box, src=0) + return box[0] + + +def _send_response(rank: int, payload: dict) -> None: + if rank == 0: + sys.stdout.write(json.dumps(payload) + "\n") + sys.stdout.flush() + + +def _silence_non_rank0_stdout(rank: int) -> None: + """Redirect non-rank-0 stdout to /dev/null at fd level. + + All ranks share rank 0's stdout fd (torchrun inherits it from the launcher), + so Python/library writes on rank>0 would interleave with rank 0's JSON + protocol on the parent's pipe. Closing fd 1 at the OS level on rank>0 + catches both Python (``print``) and C-level (NCCL, etc.) writes. + """ + if rank == 0: + return + devnull = os.open(os.devnull, os.O_WRONLY) + os.dup2(devnull, 1) + os.close(devnull) + sys.stdout = open(1, "w", closefd=False) + + +def _reset_between_cases() -> None: + """Drop state that would otherwise cascade across cases. + + Matches the per-case startup of the single-shot worker + (``_run_single_config`` on the per-case-subprocess branch): identical RNG + seed at the start of every case, FP8 state cleared, allocator clean. + ``run_dpa_with_cp`` re-sets ``NVTE_FUSED_ATTN``/``NVTE_FLASH_ATTN`` + unconditionally and pops the other transient env vars itself, so no + explicit pop is needed here. + """ + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + FP8GlobalStateManager.reset() + torch.cuda.empty_cache() + # Invalidate DPA's module-level backend cache so the per-case + # NVTE_FLASH_ATTN/NVTE_FUSED_ATTN env-var toggle actually takes effect + # instead of reusing the previous case's resolved backend. + try: + from transformer_engine.pytorch.attention.dot_product_attention import dot_product_attention + + dot_product_attention._attention_backends["backend_selection_requires_update"] = True + except (ImportError, AttributeError, KeyError): + pass + + +_case_counter = 0 + + +def _run_one(req: dict, rank: int) -> tuple[bool, str]: + global _case_counter + op = req["op"] + if op != "run": + return False, f"unknown op: {op}" + # Reset BEFORE the case so the first case also starts from a known RNG seed + # and clean FP8 state — same as the single-shot worker's per-process startup. + _reset_between_cases() + t0 = time.monotonic() + ok = True + err = "" + try: + run_dpa_with_cp(**req.get("kwargs", {})) + except Exception: + ok = False + err = f"[Rank {rank}] {traceback.format_exc()}" + wall = time.monotonic() - t0 + # Per-case wall time on rank 0, opt-in via NVTE_CP_POOL_TIMING=1. + # Used to tune POOL_SUBMIT_TIMEOUT_SEC against the observed distribution. + if rank == 0 and int(os.environ.get("NVTE_CP_POOL_TIMING", "0")): + _case_counter += 1 + sys.stderr.write( + f"[POOL-TIMING] case_idx={_case_counter} " + f"world_size={int(os.environ.get('WORLD_SIZE', 0))} " + f"wall_s={wall:.3f} ok={ok}\n" + ) + sys.stderr.flush() + return ok, err + + +def _create_cp_comm_groups(rank: int, world_size: int) -> tuple: + """Pre-create the CP collective groups for this pool. + + world_size and the rank set are constant for the lifetime of one pool, so + the world group and the a2a+p2p sub-groups are deterministic. Creating + them once here and reusing them across every case eliminates ~50-100 ms + of NCCL setup per case (cyanguwa's review feedback on PR #2993). + + Returns ``(world_group, a2a_p2p_sub_groups)``. ``a2a_p2p_sub_groups`` is + empty when world_size is too small to support a2a+p2p (needs an even + world_size ≥ 4); cases with cp_comm_type='a2a+p2p' wouldn't be routed to + such a pool anyway. + """ + world_group = dist.new_group(range(world_size), backend="nccl") + sub_groups: list = [] + if world_size >= 4 and world_size % 2 == 0: + # Mirror the layout in run_attention_with_cp.py: cp_size/2 pairs along + # axis 0, plus 2 stride-2 groups along axis 1. + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + sub_groups.append(sub_group) + return world_group, sub_groups + + +def main() -> None: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + _silence_non_rank0_stdout(rank) + torch.cuda.set_device(rank % torch.cuda.device_count()) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + os.environ["NVTE_CP_POOL_PG"] = "1" + + # Stash pool-shared CP groups on the run_attention_with_cp module so + # run_dpa_with_cp can read them per case. Imported here (after the env var + # is set) to keep import-time side effects minimal. + import run_attention_with_cp as _rac + + _rac._pool_cp_comm_group, _rac._pool_cp_comm_sub_groups = _create_cp_comm_groups( + rank, world_size + ) + + try: + while True: + req = _recv_request(rank) + if req.get("op") == "shutdown": + break + + ok, msg = _run_one(req, rank) + + gathered: list[tuple[bool, str]] = [None] * world_size # type: ignore[list-item] + # gather_object is itself a collective synchronization point — if + # every rank reached it, none is ahead. No extra barrier needed. + dist.gather_object((ok, msg), gathered if rank == 0 else None, dst=0) + + if rank == 0: + all_ok = all(o for o, _ in gathered) + if all_ok: + _send_response(rank, {"ok": True}) + else: + first_err = next(m for o, m in gathered if not o) + _send_response(rank, {"ok": False, "error": first_err}) + # Release the allocator cache so this pool doesn't squat on + # GPUs that an overlapping different-world-size pool needs. + torch.cuda.empty_cache() + finally: + # Tear down pool-shared CP groups before the main PG (NCCL requires + # sub-groups to be destroyed first). Each destroy is independently + # guarded so a wedged communicator on one group doesn't leak the rest. + if _rac._pool_cp_comm_group is not None: + try: + dist.destroy_process_group(_rac._pool_cp_comm_group) + except Exception: + pass + for g in _rac._pool_cp_comm_sub_groups: + try: + dist.destroy_process_group(g) + except Exception: + pass + _rac._pool_cp_comm_group = None + _rac._pool_cp_comm_sub_groups = [] + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 23d1bfdd85..f0d2c27c12 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -2,12 +2,18 @@ # # See LICENSE for license information. +import json import os +import select +import signal import subprocess import sys +import threading +import time import pathlib import logging import copy +from collections import deque import pytest import torch from transformer_engine.pytorch import ( @@ -24,7 +30,7 @@ _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) -from utils import ModelConfig, get_available_attention_backends, run_distributed +from utils import ModelConfig, get_available_attention_backends pytest_logging_level = logging.getLevelName(logging.root.level) @@ -60,19 +66,228 @@ } -def get_bash_arguments(num_gpus_per_node, **kwargs): - args = [ - "python3", - "-m", - "torch.distributed.launch", - "--nproc-per-node=" + str(num_gpus_per_node), - ] - te_path = os.getenv("TE_PATH", "/opt/transformerengine") - script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py") - args.append(script_path) - for k, v in kwargs.items(): - args.append(f"{k}={v}") - return args +# --- Persistent pool runner ----------------------------------------------- +# +# Each (world_size) is served by one long-lived torchrun running +# run_attention_with_cp_pool.py. We submit one work item per pytest case over +# rank-0 stdin and read one JSON response from rank-0 stdout. Replaces +# the per-case torchrun launch path; init/destroy NCCL once per pool, not +# once per case. +# +# Why two pool sizes: cp_comm_type="a2a+p2p" needs world_size=4; everything +# else uses world_size=2. We can't resize an active PG, so we keep one pool +# per world_size and route each case to the right one. Pools are spawned +# lazily on first use so a session that only exercises 2-GPU cases never +# pays the 4-GPU init cost. + +# Per-case wall is ~5 s p50 / ~15 s max on H100 (test_essential=True). +# 90 s gives ~6× headroom over the slowest observed case while still detecting +# a genuine hang within ~1.5 min instead of ~10 min. Override with the env var +# if a slower machine or expanded test matrix needs more room. +POOL_SUBMIT_TIMEOUT_SEC = float(os.getenv("NVTE_CP_POOL_TIMEOUT_SEC", "90")) + + +class PoolWorker: + # Crash-path AssertionErrors include the tail of the worker's stderr so CI + # JUnit XML shows the actual failure cause (NCCL/CUDA messages, Python + # traceback) inline with the failing test, not just "pool worker died". + # Equivalent in spirit to PR #2965's run_distributed() stderr capture. + _STDERR_BUFFER_LINES = 200 # ring cap (~40 KB ceiling) + _STDERR_TAIL_CHARS = 4000 # how much to attach to the AssertionError + + def __init__(self, world_size: int): + self.world_size = world_size + self.proc: subprocess.Popen | None = None + self._stderr_buf: deque[str] = deque(maxlen=self._STDERR_BUFFER_LINES) + + def _spawn(self) -> None: + te_path = os.getenv("TE_PATH", "/opt/transformerengine") + worker = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp_pool.py") + cmd = [ + sys.executable, + "-m", + "torch.distributed.run", + f"--nproc-per-node={self.world_size}", + "--standalone", # picks a free rendezvous port + worker, + ] + # stderr=PIPE so we can capture the tail for crash-path AssertionErrors; + # a daemon drainer thread also echoes each line to sys.stderr so pytest's + # per-test stderr capture still works. The thread is daemon, so it + # self-terminates when the pipe closes — no tracking needed. + self.proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + env={**os.environ, "PYTHONUNBUFFERED": "1"}, + # Own process group so _kill can killpg all ranks in one shot; + # without this, terminating the launcher PID leaves rank workers + # as orphans holding CUDA/NCCL state. + start_new_session=True, + ) + self._stderr_buf.clear() + threading.Thread(target=self._drain_stderr, daemon=True).start() + + def _drain_stderr(self) -> None: + proc = self.proc + if proc is None or proc.stderr is None: + return + for line in iter(proc.stderr.readline, ""): + self._stderr_buf.append(line) + sys.stderr.write(line) + sys.stderr.flush() + + def _diag(self, msg: str) -> str: + tail = "".join(self._stderr_buf)[-self._STDERR_TAIL_CHARS :] + if not tail.strip(): + return msg + return f"{msg}\n\n--- pool worker stderr (tail) ---\n{tail}" + + def _ensure_alive(self) -> None: + if self.proc is None or self.proc.poll() is not None: + self._spawn() + + def _killpg(self, sig: int) -> None: + try: + os.killpg(self.proc.pid, sig) + except ProcessLookupError: + pass + + def _kill(self) -> None: + # Kill the whole process group so rank workers don't survive as orphans. + if self.proc and self.proc.poll() is None: + self._killpg(signal.SIGTERM) + try: + self.proc.wait(timeout=5) + except subprocess.TimeoutExpired: + self._killpg(signal.SIGKILL) + self.proc.wait() + self.proc = None + + # One retry on pool-infrastructure failures (worker died / timed out / broken + # pipe). Test-assertion failures from the worker carry the full per-rank + # traceback in resp["error"] and propagate without retry. Every retry leaves + # a [POOL-RETRY] line in stderr so pytest's capture surfaces + # flake patterns in JUnit XML for offline analysis. + _MAX_RETRIES = 1 + + def submit(self, kwargs: dict, timeout: float = POOL_SUBMIT_TIMEOUT_SEC) -> None: + first_err = None + for attempt in range(self._MAX_RETRIES + 1): + try: + return self._submit_once(kwargs, timeout) + except AssertionError as e: + msg_head = str(e).splitlines()[0] + infrastructure_flake = ( + "pool worker died" in msg_head + or "timed out" in msg_head + or "before request could be sent" in msg_head + ) + if not infrastructure_flake or attempt == self._MAX_RETRIES: + if first_err is not None: + sys.stderr.write( + f"[POOL-RETRY-FAIL] world_size={self.world_size}: " + "both attempts died; first error was: " + f"{str(first_err).splitlines()[0]!r}\n" + ) + sys.stderr.flush() + raise + first_err = e + sys.stderr.write( + f"[POOL-RETRY] world_size={self.world_size} attempt {attempt + 1} " + f"died: {msg_head!r}; respawning pool and retrying\n" + ) + sys.stderr.flush() + raise first_err # unreachable; loop either returns or raises + + def _submit_once(self, kwargs: dict, timeout: float) -> None: + self._ensure_alive() + req = json.dumps({"op": "run", "kwargs": kwargs}) + "\n" + try: + self.proc.stdin.write(req) + self.proc.stdin.flush() + except BrokenPipeError: + msg = self._diag("pool worker died before request could be sent") + self._kill() + raise AssertionError(msg) + + # Worker redirects non-rank-0 stdout to /dev/null at fd level, so + # rank 0's JSON line is the only thing that arrives on this pipe. + # select() on a pipe fd is Linux/macOS only — on Windows the select + # module only accepts sockets. CP attention tests run on Linux GPU + # hosts so this is fine; flag if portability is ever needed. + ready, _, _ = select.select([self.proc.stdout], [], [], timeout) + if not ready: + msg = self._diag( + f"pool worker (world_size={self.world_size}) timed out after " + f"{timeout}s; pool killed and will be respawned for the next case" + ) + self._kill() + raise AssertionError(msg) + + line = self.proc.stdout.readline() + if not line: + msg = self._diag("pool worker died mid-request") + self._kill() + raise AssertionError(msg) + + # A stray non-JSON line from rank 0 would desynchronize the protocol; + # turn it into a clear test failure rather than a raw JSONDecodeError. + try: + resp = json.loads(line) + except json.JSONDecodeError as e: + self._kill() + raise AssertionError( + self._diag(f"pool worker JSON protocol broke: {e!r}; line={line!r}") + ) + + if not resp["ok"]: + # Discard the pool so half-aborted CUDA/NCCL/FP8 state from the + # failed case doesn't leak into the next. resp["error"] already + # carries the per-rank traceback via gather_object. + self._kill() + raise AssertionError(resp["error"]) + + def shutdown(self) -> None: + if self.proc and self.proc.poll() is None: + try: + self.proc.stdin.write(json.dumps({"op": "shutdown"}) + "\n") + self.proc.stdin.flush() + self.proc.stdin.close() + except BrokenPipeError: + pass + try: + self.proc.wait(timeout=30) + except subprocess.TimeoutExpired: + self._kill() + self.proc = None + + +@pytest.fixture(scope="session") +def cp_pool(): + """Returns a callable: cp_pool(world_size) -> PoolWorker.""" + pools: dict[int, PoolWorker] = {} + + def _get(world_size: int) -> PoolWorker: + if world_size > torch.cuda.device_count(): + pytest.skip(f"Test requires {world_size} GPUs, but found {torch.cuda.device_count()}") + if world_size not in pools: + pools[world_size] = PoolWorker(world_size) + return pools[world_size] + + yield _get + for p in pools.values(): + p.shutdown() + + +def _submit(pool: PoolWorker, **kwargs) -> None: + # run_dpa_with_cp expects all kwargs as strings (it does e.g. + # `fp8_bwd == "True"`), matching the old argv-based path. Serialize + # everything as strings so we don't accidentally change semantics. + pool.submit({k: str(v) for k, v in kwargs.items()}) dtypes = ["bf16", "fp16"] @@ -91,10 +306,9 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) -def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): +def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 - if num_gpus > torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + pool = cp_pool(num_gpus) config = model_configs_flash_attn[model] config.context_parallel = True @@ -140,16 +354,14 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if not flash_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FlashAttention", - cp_comm_type=cp_comm_type, - log_level=pytest_logging_level, - ), + _submit( + pool, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FlashAttention", + cp_comm_type=cp_comm_type, + log_level=pytest_logging_level, ) @@ -274,15 +486,23 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( - dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O + cp_pool, + dtype, + model, + qkv_format, + cp_comm_type, + fp8_bwd, + fp8_mha, + fp8_dpa, + scaling_mode, + f16_O, ): config = model_configs_fused_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 - if num_gpus > torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()} GPUs.") + pool = cp_pool(num_gpus) if get_device_compute_capability() < (9, 0) and qkv_format == "thd": pytest.skip("Only sm90+ architectures support THD format!") @@ -404,21 +624,24 @@ def test_cp_with_fused_attention( if not fused_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FusedAttention", - cp_comm_type=cp_comm_type, - fp8_bwd=fp8_bwd, - fp8_dpa=fp8_dpa, - fp8_mha=fp8_mha, - scaling_mode=scaling_mode, - f16_O=f16_O, - is_training=is_training, - deterministic=_deterministic, - log_level=pytest_logging_level, - ), + if _deterministic and config.softmax_type != "vanilla": + pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") + if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: + pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + + _submit( + pool, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, + fp8_bwd=fp8_bwd, + fp8_dpa=fp8_dpa, + fp8_mha=fp8_mha, + scaling_mode=scaling_mode, + f16_O=f16_O, + is_training=is_training, + deterministic=_deterministic, + log_level=pytest_logging_level, ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 3db0417bdb..35684625a5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3277,6 +3277,12 @@ def forward( elif o_format == "sbhd": out_f16[i - 1].copy_(out_per_step[i - 1]) if return_max_logit: + # max_logit_per_step[i-1] was written on flash_attn_streams[i-1] + # (cp_stream for i-1=1). The torch.maximum below runs on the + # default stream, so without this wait the read can race with + # the write. The post-loop wait_stream(cp_stream) is too late. + # No-op when flash_attn_streams[i-1] is current_stream(). + torch.cuda.current_stream().wait_stream(flash_attn_streams[i - 1]) max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream)