diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..33f6a4e78d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,4 +6,8 @@ url = https://github.com/NVIDIA/cudnn-frontend.git [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass - url = https://github.com/NVIDIA/cutlass.git + url = https://github.com/NVIDIA/cutlass.git +[submodule "3rdparty/nccl"] + path = 3rdparty/nccl + url = https://github.com/NVIDIA/nccl.git + branch = v2.30u1 diff --git a/3rdparty/nccl b/3rdparty/nccl new file mode 160000 index 0000000000..6a9bc953ac --- /dev/null +++ b/3rdparty/nccl @@ -0,0 +1 @@ +Subproject commit 6a9bc953ac1c4eef92d5adbe3092d4c2cb0a4c98 diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 533addaf53..cf3a91c881 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -76,6 +76,12 @@ def setup_pytorch_extension( setup_mpi_flags(include_dirs, cxx_flags) + # Mirror the NCCL EP gate from setup.py / common CMake. When disabled, the + # ep.cpp source no-ops at the #ifdef boundary; without the define it would + # produce undefined references to nvte_ep_*. + if bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))): + cxx_flags.append("-DNVTE_WITH_NCCL_EP") + library_dirs = [] libraries = [] if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py new file mode 100644 index 0000000000..221f851b92 --- /dev/null +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -0,0 +1,396 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""PyTorch EP perf bench — raw + autograd dispatch / combine on a single EP group. + +One process per GPU; launched via run_ep_bench.sh (torchrun). + +Stages (each timed in its own loop): + * dispatch_raw — ``_ep_dispatch_raw`` (no autograd, no prepare) + * ep_dispatch_fwd — ``ep_dispatch`` (autograd wrapper, forward only) + * ep_dispatch_fwd_bwd — ``ep_dispatch`` + ``.backward()`` on ``0.5*||recv||²`` + * combine_raw — ``_ep_combine_raw`` (no autograd, no weighting) + * ep_combine_fwd — ``ep_combine`` (autograd wrapper, forward only) + * ep_combine_fwd_bwd — ``ep_combine`` + ``.backward()`` + +``ep_prepare`` runs once outside the timed loops. Wall-clock per iter measured +with ``perf_counter_ns`` and NVTX ranges (if available) frame each stage so +``nsys`` can attribute kernels. ``--kineto DIR`` dumps a Chrome trace plus a +per-kernel summary table on rank 0. +""" + +import argparse +import os +import sys +import time +from contextlib import nullcontext + +import numpy as np +import torch +import torch.distributed as dist + +from transformer_engine.pytorch.ep import ( + EpBuffer, + EpHandle, + ep_bootstrap, + ep_combine, + ep_dispatch, + ep_prepare, + _ep_combine_raw, + _ep_dispatch_raw, +) + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-PyTorch EP perf bench") + p.add_argument("--tokens-per-rank", type=int, default=8192) + p.add_argument("--hidden", type=int, default=7168) + p.add_argument("--top-k", type=int, default=8) + p.add_argument("--num-experts", type=int, default=256) + p.add_argument("--warmup", type=int, default=2) + p.add_argument("--iters", type=int, default=10) + p.add_argument( + "--max-num-sms", + type=int, + default=0, + help="Max SMs for dispatch/combine/preprocess kernels (0 = auto).", + ) + p.add_argument( + "--zero-copy", + action="store_true", + default=False, + help="Use symm-mem-backed EpBuffer (zero-copy path).", + ) + p.add_argument( + "--kineto", + default=None, + help="If set, dump a Kineto Chrome trace + per-kernel summary into this dir (rank 0).", + ) + p.add_argument( + "--cuda-graph", + action="store_true", + default=False, + help=( + "Capture each stage into a CUDA graph and time replay() instead of the eager call. " + "Raw + fwd-only stages use torch.cuda.graph; fwd+bwd stages use " + "torch.cuda.make_graphed_callables to capture forward and backward together." + ), + ) + p.add_argument( + "--mode-label", + default=None, + help="Optional suffix for NVTX range names (e.g. 'fused' / 'unfused').", + ) + return p.parse_args() + + +def _nvtx_funcs(): + """Return push/pop helpers using torch.cuda.nvtx if available.""" + try: + push = torch.cuda.nvtx.range_push + pop = torch.cuda.nvtx.range_pop + return push, pop + except AttributeError: + return lambda _name: None, lambda: None + + +def _device_sm() -> int: + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def _make_inputs(rank, world_size, T, H, K, E, device): + """Round-robin identity routing + uniform top-k weights.""" + topk_idx = np.empty((T, K), dtype=np.int64) + for t in range(T): + for k in range(K): + topk_idx[t, k] = ((rank * T + t) * K + k) % E + rng = np.random.default_rng(seed=42 + rank) + tokens_np = (rng.standard_normal((T, H), dtype=np.float32) * 0.5).astype(np.float32) + return ( + torch.from_numpy(topk_idx).to(device), + torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16), + torch.full((T, K), 1.0 / K, dtype=torch.float32, device=device), + ) + + +def _time_stage_us(name, fn, iters, nvtx_suffix, push, pop): + """Time ``fn`` for ``iters`` iterations after one untimed warmup; returns mean µs.""" + # Run iters+1 times; drop the first (autotune outlier) and frame NVTX from iter 1. + total_ns = 0 + counted = 0 + for i in range(iters + 1): + if i == 1: + push(f"{name}{nvtx_suffix}") + torch.cuda.synchronize() + t0 = time.perf_counter_ns() + fn() + torch.cuda.synchronize() + dt = time.perf_counter_ns() - t0 + if i == 0: + continue + total_ns += dt + counted += 1 + pop() + return total_ns / 1e3 / counted + + +def main(): + args = _parse_args() + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank))) + device = torch.device("cuda", torch.cuda.current_device()) + + if _device_sm() < 90: + if rank == 0: + print(f"[ep_bench] SKIPPED: EP requires SM>=90 (got SM{_device_sm()})") + dist.destroy_process_group() + return + if world_size < 4: + if rank == 0: + print(f"[ep_bench] SKIPPED: EP requires >=4 ranks (got {world_size})") + dist.destroy_process_group() + return + + ep_size = world_size + E = args.num_experts + assert E % ep_size == 0, f"num_experts ({E}) must be divisible by ep_size ({ep_size})" + num_local_experts = E // ep_size + T = args.tokens_per_rank + H = args.hidden + K = args.top_k + # Conservative cap: every token could land on every local expert. + recv_pr = world_size * T * K // 2 + if rank == 0: + print( + f"[ep_bench] world={world_size} ep={ep_size} T={T} H={H} K={K} " + f"E={E} (local={num_local_experts}) recv_pr={recv_pr} zero_copy={args.zero_copy}" + + (f" mode={args.mode_label}" if args.mode_label else ""), + flush=True, + ) + + ep_group = dist.new_group(ranks=list(range(world_size)), backend="nccl") + ep_bootstrap( + ep_group, + num_experts=E, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=H, + max_num_sms=args.max_num_sms, + ) + + topk_idx, tokens, topk_w = _make_inputs(rank, world_size, T, H, K, E, device) + + handle = EpHandle( + top_k=K, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=H, + num_local_experts=num_local_experts, + ) + buffer = EpBuffer( + handle, + ep_group=ep_group if args.zero_copy else None, + use_symm_mem=args.zero_copy, + ) + + # ── Prepare once outside the timed loops ────────────────────────────── + ep_prepare(handle, topk_idx) + torch.cuda.synchronize() + + # Raw-path scratch buffers (fixed-size, mutated in place). + recv_tokens = torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device) + recv_w = torch.empty(recv_pr, dtype=torch.float32, device=device) + + # Pre-dispatch a steady recv_tokens / recv_w so combine stages have valid input. + _ep_dispatch_raw(handle, topk_idx, tokens, topk_w, recv_tokens, recv_w) + torch.cuda.synchronize() + expert_out = recv_tokens.clone() # fp-equivalent stand-in for an MLP output. + + nvtx_suffix = f"[{args.mode_label}]" if args.mode_label else "" + push, pop = _nvtx_funcs() + + # ── Stage closures ──────────────────────────────────────────────────── + # + # Persistent inputs for the fwd+bwd stages so the same memory is reused + # across iters (required by make_graphed_callables and matches the eager + # path). + tokens_p = tokens.detach().clone().requires_grad_(True) + eo_p = recv_tokens.detach().clone().requires_grad_(True) + + # Stand-in callables; the cuda-graph branch below swaps in graphed versions. + fwd_bwd_dispatch_fn = lambda x: ep_dispatch( # noqa: E731 + handle, buffer, x, topk_idx, topk_w, zero_copy=args.zero_copy + )[0] + fwd_bwd_combine_fn = lambda eo: ep_combine( # noqa: E731 + handle, buffer, eo, recv_w, zero_copy=args.zero_copy + ) + + def _dispatch_raw(): + _ep_dispatch_raw(handle, topk_idx, tokens, topk_w, recv_tokens, recv_w) + + def _combine_raw(): + out_buf = torch.empty(T, H, dtype=torch.bfloat16, device=device) + _ep_combine_raw(handle, expert_out, out_buf) + + def _ep_dispatch_fwd(): + ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w, zero_copy=args.zero_copy) + + def _ep_dispatch_fwd_bwd(): + tokens_p.grad = None + r = fwd_bwd_dispatch_fn(tokens_p) + (0.5 * (r.float() ** 2).sum()).backward() + + def _ep_combine_fwd(): + ep_combine(handle, buffer, recv_tokens, recv_w, zero_copy=args.zero_copy) + + def _ep_combine_fwd_bwd(): + eo_p.grad = None + out = fwd_bwd_combine_fn(eo_p) + (0.5 * (out.float() ** 2).sum()).backward() + + stages = [ + ("dispatch_raw", _dispatch_raw, True), + ("ep_dispatch_fwd", _ep_dispatch_fwd, True), + ("ep_dispatch_fwd_bwd", _ep_dispatch_fwd_bwd, False), + ("combine_raw", _combine_raw, True), + ("ep_combine_fwd", _ep_combine_fwd, True), + ("ep_combine_fwd_bwd", _ep_combine_fwd_bwd, False), + ] + # Third tuple element: True = direct torch.cuda.graph capture; False = use + # make_graphed_callables (autograd-aware) instead. + + # ── Warmup ─────────────────────────────────────────────────────────── + for _ in range(args.warmup): + for _name, fn, _capt in stages: + fn() + torch.cuda.synchronize() + + # ── Optional CUDA-graph capture ────────────────────────────────────── + # Capture each capturable stage on a side stream and time .replay() + # instead of the eager call. Outputs allocated inside the + # autograd.Function's forward go through the per-capture private pool + # so addresses stay stable across replays. + captured_runners = {} + if args.cuda_graph: + # Graph fwd+bwd of the autograd-wrapped ops via make_graphed_callables. + class _DispatchMod(torch.nn.Module): + def forward(self, x): + return ep_dispatch(handle, buffer, x, topk_idx, topk_w, zero_copy=args.zero_copy)[0] + + class _CombineMod(torch.nn.Module): + def forward(self, eo): + return ep_combine(handle, buffer, eo, recv_w, zero_copy=args.zero_copy) + + disp_mod = _DispatchMod().cuda() + comb_mod = _CombineMod().cuda() + g_disp, g_comb = torch.cuda.make_graphed_callables( + (disp_mod, comb_mod), + ((tokens_p,), (eo_p,)), + ) + fwd_bwd_dispatch_fn = g_disp + fwd_bwd_combine_fn = g_comb + + # Direct torch.cuda.graph capture for raw + fwd-only stages. + side = torch.cuda.Stream() + side.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(side): + for name, fn, direct_capturable in stages: + if not direct_capturable: + continue + fn() # prime the allocator for stable replay addresses + torch.cuda.synchronize() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + captured_runners[name] = g + torch.cuda.current_stream().wait_stream(side) + torch.cuda.synchronize() + + # ── Optional Kineto profiling ──────────────────────────────────────── + kineto_ctx = nullcontext() + if args.kineto and rank == 0: + os.makedirs(args.kineto, exist_ok=True) + kineto_ctx = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=False, + with_stack=False, + ) + + # ── Timed loops ────────────────────────────────────────────────────── + results = {} + with kineto_ctx as prof: + for name, fn, _ in stages: + runner = fn + if name in captured_runners: + # Time replay() instead of the eager call. + graph = captured_runners[name] + runner = graph.replay + results[name] = _time_stage_us(name, runner, args.iters, nvtx_suffix, push, pop) + + if rank == 0: + label = f" [{args.mode_label}]" if args.mode_label else "" + print("", flush=True) + print(f"| stage | mean wall (us){label} |", flush=True) + print("|----------------------|---------------:|", flush=True) + for name in ( + "dispatch_raw", + "ep_dispatch_fwd", + "ep_dispatch_fwd_bwd", + "combine_raw", + "ep_combine_fwd", + "ep_combine_fwd_bwd", + ): + print(f"| {name:20s} | {results[name]:14.1f} |", flush=True) + print( + "| (dispatch fwd-raw) |" + f" {results['ep_dispatch_fwd'] - results['dispatch_raw']:14.1f} |", + flush=True, + ) + print( + "| (dispatch bwd-fwd) |" + f" {results['ep_dispatch_fwd_bwd'] - results['ep_dispatch_fwd']:14.1f} |", + flush=True, + ) + print( + "| (combine fwd-raw) |" + f" {results['ep_combine_fwd'] - results['combine_raw']:14.1f} |", + flush=True, + ) + print( + "| (combine bwd-fwd) |" + f" {results['ep_combine_fwd_bwd'] - results['ep_combine_fwd']:14.1f} |", + flush=True, + ) + print("", flush=True) + + if args.kineto and rank == 0 and prof is not None: + trace_path = os.path.join(args.kineto, "ep_bench_trace.json") + prof.export_chrome_trace(trace_path) + print(f"[ep_bench] kineto trace: {trace_path}", flush=True) + print( + prof.key_averages().table(sort_by="cuda_time_total", row_limit=30), + flush=True, + ) + kern_csv = os.path.join(args.kineto, "ep_bench_kernels.csv") + with open(kern_csv, "w") as f: + f.write("name,cuda_time_us,cpu_time_us,count\n") + for evt in prof.key_averages(): + if evt.device_time_total == 0 and evt.cpu_time_total == 0: + continue + f.write(f"{evt.key},{evt.device_time_total},{evt.cpu_time_total},{evt.count}\n") + print(f"[ep_bench] per-kernel CSV: {kern_csv}", flush=True) + + dist.barrier() + dist.destroy_process_group() + sys.stdout.flush() + sys.stderr.flush() + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/ep/bench/run_ep_bench.sh b/examples/pytorch/ep/bench/run_ep_bench.sh new file mode 100755 index 0000000000..a8e1fdc173 --- /dev/null +++ b/examples/pytorch/ep/bench/run_ep_bench.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Launcher for examples/pytorch/ep/bench/ep_bench.py. +# Examples: +# bash run_ep_bench.sh # plain run, stdout only +# bash run_ep_bench.sh --zero-copy # symm-mem-backed EpBuffer +# bash run_ep_bench.sh --cuda-graph # capture + replay each stage as a CUDA graph +# bash run_ep_bench.sh --kineto # Chrome trace + per-kernel CSV (rank 0) +# bash run_ep_bench.sh --nsys # nsys profile on rank 0 -> results/pyt_nsys.nsys-rep + +set -uo pipefail + +NSYS=0; KINETO=0; ZERO_COPY=0; CGRAPH=0 +for a in "$@"; do + case "$a" in + --nsys) NSYS=1 ;; + --kineto) KINETO=1 ;; + --zero-copy) ZERO_COPY=1 ;; + --cuda-graph) CGRAPH=1 ;; + *) echo "unknown arg: $a" >&2; exit 2 ;; + esac +done +if [ "${NSYS}" -eq 1 ] && [ "${KINETO}" -eq 1 ]; then + echo "--nsys and --kineto both attach CUPTI; pick one." >&2; exit 2 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." && pwd)" +RESULTS="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS}" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +NUM_GPUS="${NUM_GPUS:-${DETECTED_GPUS}}" +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "EP bench requires >=4 GPUs (found ${NUM_GPUS}); SKIPPING."; exit 0 +fi +if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi + +: "${TIMEOUT_S:=1800}" +: "${NCCL_EP_JIT_CACHE_DIR:=${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)}" +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "${NCCL_EP_JIT_CACHE_DIR}" + +# Silence per-call symm-mem fallback warning when running the HBM path. +export NVTE_EP_SILENCE_NONSYMM_WARN="${NVTE_EP_SILENCE_NONSYMM_WARN:-1}" + +EXTRA_ARGS=() +TAG="pyt" +[ "${ZERO_COPY}" -eq 1 ] && EXTRA_ARGS+=(--zero-copy) && TAG="${TAG}_zc" +[ "${CGRAPH}" -eq 1 ] && EXTRA_ARGS+=(--cuda-graph) && TAG="${TAG}_cg" +if [ "${KINETO}" -eq 1 ]; then + EXTRA_ARGS+=(--kineto "${RESULTS}/kineto_${TAG}") +fi + +LAUNCH=(torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_GPUS}" + "${SCRIPT_DIR}/ep_bench.py" "${EXTRA_ARGS[@]}") + +if [ "${NSYS}" -eq 1 ]; then + NSYS_CMD=(nsys profile + --output "${RESULTS}/pyt_${TAG}_nsys" + --force-overwrite=true + --trace=cuda,nvtx + --gpu-metrics-devices=none + --cuda-um-cpu-page-faults=false + --cuda-um-gpu-page-faults=false) + echo "[run_ep_bench] launching with nsys (results/${TAG}_nsys.nsys-rep)" + timeout --foreground --signal=TERM "${TIMEOUT_S}" "${NSYS_CMD[@]}" "${LAUNCH[@]}" + RC=$? +else + timeout --foreground --signal=TERM "${TIMEOUT_S}" "${LAUNCH[@]}" + RC=$? +fi +exit $RC diff --git a/examples/pytorch/ep/bench/run_nccl_ep_bench.sh b/examples/pytorch/ep/bench/run_nccl_ep_bench.sh new file mode 100755 index 0000000000..8f6da04a00 --- /dev/null +++ b/examples/pytorch/ep/bench/run_nccl_ep_bench.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Launcher for the native NCCL EP ``ep_bench`` (baseline for PyTorch comparison). +# Usage: +# bash run_nccl_ep_bench.sh # plain run, stdout only +# bash run_nccl_ep_bench.sh --nsys # nsys → results/nccl_ep_nsys.nsys-rep + +set -uo pipefail + +NSYS=0 +for a in "$@"; do + case "$a" in + --nsys) NSYS=1 ;; + *) echo "unknown arg: $a" >&2; exit 2 ;; + esac +done + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." && pwd)" +RESULTS="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS}" + +BIN="${TE_REPO_ROOT}/3rdparty/nccl/build/test/nccl_ep/ep_bench" +LIB="${TE_REPO_ROOT}/3rdparty/nccl/build/lib" +[ -x "${BIN}" ] || { echo "ep_bench not built at ${BIN}" >&2; exit 2; } + +NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "NCCL EP bench requires >=4 GPUs (found ${NUM_GPUS}); SKIPPING."; exit 0 +fi +if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi + +if [ "${NSYS}" -eq 1 ]; then + ITERS=10 +else + ITERS=50 +fi +ARGS=(--algorithm ht --layout em --tokens 2048 --hidden 7168 --top-k 8 + --experts 256 --warmup 5 --iters "${ITERS}") +[ "${NSYS}" -eq 1 ] && ARGS+=(--profile) # enables NVTX ranges + cudaProfilerStart/Stop + +CMD=(/usr/local/mpi/bin/mpirun --allow-run-as-root --oversubscribe -np "${NUM_GPUS}" + -x LD_LIBRARY_PATH="${LIB}:${LD_LIBRARY_PATH:-}" + "${BIN}" "${ARGS[@]}") + +if [ "${NSYS}" -eq 1 ]; then + CMD=(nsys profile + --output "${RESULTS}/nccl_ep_nsys" + --force-overwrite=true + --capture-range=cudaProfilerApi + --capture-range-end=stop + --trace=cuda,nvtx,osrt + "${CMD[@]}") +fi + +[ "${NSYS}" -eq 1 ] && SUFFIX="_nsys" || SUFFIX="" +LOG="${RESULTS}/stdout_nccl_ep${SUFFIX}.txt" +"${CMD[@]}" 2>&1 | tee "${LOG}" +echo "Done. Log: ${LOG}" diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py new file mode 100644 index 0000000000..ef87e4b239 --- /dev/null +++ b/examples/pytorch/ep/ep_moe.py @@ -0,0 +1,239 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""End-to-end MoE example: dispatch → batched expert linear → combine, fwd + bwd. + +One process per GPU; launched via run_test_ep.sh (torchrun). +""" + +import argparse +import os +import sys + +import numpy as np +import torch +import torch.distributed as dist + +from transformer_engine.pytorch.ep import ( + EpHandle, + EpBuffer, + ep_bootstrap, + ep_dispatch, + ep_combine, + symm_mem_alloc, +) + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-PyTorch EP MoE example (fwd + bwd)") + p.add_argument("--num-tokens", type=int, default=8, help="Per-rank token count.") + p.add_argument("--top-k", type=int, default=2) + p.add_argument("--hidden", type=int, default=32) + p.add_argument("--hidden-out", type=int, default=32) + p.add_argument("--num-experts", type=int, default=None) + p.add_argument("--check", action="store_true", default=True) + p.add_argument( + "--benchmark", + action="store_true", + help="Time fwd over HBM and symm-mem buffers.", + ) + p.add_argument("--benchmark-iters", type=int, default=20) + p.add_argument("--benchmark-warmup", type=int, default=5) + return p.parse_args() + + +def _make_routing(rank, T, K, E, num_local_experts): + """Deterministic routing: topk_idx[t, k] = (rank*NLE + t*K + k) % E.""" + topk_idx = np.empty((T, K), dtype=np.int64) + for t in range(T): + for k in range(K): + topk_idx[t, k] = (rank * num_local_experts + t * K + k) % E + return topk_idx + + +def _batched_expert_linear(recv_tokens, kernels, num_local_experts): + """Per-expert linear via bmm; ``recv_pr // num_local_experts`` slots per expert.""" + recv_pr, _H = recv_tokens.shape + H_out = kernels.shape[-1] + slots_per_expert = recv_pr // num_local_experts + grouped = recv_tokens.view(num_local_experts, slots_per_expert, recv_tokens.shape[-1]) + out = torch.bmm(grouped, kernels.to(grouped.dtype)) + return out.view(recv_pr, H_out) + + +def _reference_moe(tokens, topk_idx, topk_w, kernels): + T, K = topk_idx.shape + H_out = kernels.shape[-1] + out = np.zeros((T, H_out), dtype=np.float32) + for t in range(T): + tok = tokens[t].astype(np.float32) + for k in range(K): + e = int(topk_idx[t, k]) + out[t] += float(topk_w[t, k]) * (tok @ kernels[e].astype(np.float32)) + return out + + +def _reference_grad(tokens, topk_idx, topk_w, kernels): + T, K = topk_idx.shape + H = tokens.shape[-1] + ref_out = _reference_moe(tokens, topk_idx, topk_w, kernels) + grad = np.zeros((T, H), dtype=np.float32) + for t in range(T): + mixed = np.zeros_like(kernels[0]) + for k in range(K): + mixed = mixed + float(topk_w[t, k]) * kernels[int(topk_idx[t, k])] + grad[t] = ref_out[t] @ mixed.T + return ref_out, grad + + +def main(): + args = _parse_args() + + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank))) + device = torch.device("cuda", torch.cuda.current_device()) + + major, minor = torch.cuda.get_device_capability() + if major * 10 + minor < 90: + if rank == 0: + print(f"[ep_moe] SKIPPED: EP requires SM>=90 (got SM{major}{minor})") + dist.destroy_process_group() + return + + if world_size < 4: + if rank == 0: + print(f"[ep_moe] SKIPPED: EP requires >= 4 ranks (got {world_size})") + dist.destroy_process_group() + return + + ep_size = world_size + num_experts = args.num_experts if args.num_experts is not None else world_size + assert num_experts % ep_size == 0 + num_local_experts = num_experts // ep_size + T = args.num_tokens + recv_pr = ep_size * T * args.top_k + + ep_group = dist.new_group(ranks=list(range(world_size)), backend="nccl") + ep_bootstrap( + ep_group, + num_experts=num_experts, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=args.hidden, + ) + + rng = np.random.default_rng(seed=42 + rank) + tokens_np = (rng.standard_normal((T, args.hidden), dtype=np.float32) * 0.5).astype(np.float32) + topk_idx_np = _make_routing(rank, T, args.top_k, num_experts, num_local_experts) + w_np = np.full((T, args.top_k), 1.0 / args.top_k, dtype=np.float32) + # Same seed across ranks → identical kernel array everywhere. + kr = np.random.default_rng(seed=42) + kernels_np = ( + kr.standard_normal((num_experts, args.hidden, args.hidden_out), dtype=np.float32) + * (1.0 / np.sqrt(args.hidden)) + ).astype(np.float32) + + tokens = ( + torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16).requires_grad_(True) + ) + topk_idx = torch.from_numpy(topk_idx_np).to(device) + topk_w = torch.from_numpy(w_np).to(device) + kernels_local = torch.from_numpy( + kernels_np[rank * num_local_experts : (rank + 1) * num_local_experts] + ).to(device=device, dtype=torch.bfloat16) + + handle = EpHandle( + top_k=args.top_k, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=args.hidden, + num_local_experts=num_local_experts, + ) + buffer = EpBuffer(handle, ep_group) + + recv_t, recv_w_out, _tc = ep_dispatch(handle, buffer, tokens, topk_idx, topk_w) + expert_out = _batched_expert_linear(recv_t, kernels_local, num_local_experts) + out = ep_combine(handle, buffer, expert_out, recv_w_out) + + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + + if rank == 0: + print( + f"[ep_moe] loss={float(loss):.4f} grad_tokens.shape={tuple(tokens.grad.shape)} " + f"ep={ep_size} num_experts={num_experts} recv_pr={recv_pr}" + ) + + if args.benchmark: + # Compare dispatch + expert + combine over HBM vs symm-mem payload buffers. + import time + + def _time(label, tokens_buf, buffer_obj): + torch.cuda.synchronize() + dist.barrier() + for _ in range(args.benchmark_warmup): + rt, rw, _tc = ep_dispatch(handle, buffer_obj, tokens_buf, topk_idx, topk_w) + eo = _batched_expert_linear(rt, kernels_local, num_local_experts) + ep_combine(handle, buffer_obj, eo, rw) + torch.cuda.synchronize() + dist.barrier() + t0 = time.perf_counter() + for _ in range(args.benchmark_iters): + rt, rw, _tc = ep_dispatch(handle, buffer_obj, tokens_buf, topk_idx, topk_w) + eo = _batched_expert_linear(rt, kernels_local, num_local_experts) + ep_combine(handle, buffer_obj, eo, rw) + torch.cuda.synchronize() + dt_ms = (time.perf_counter() - t0) * 1000.0 / args.benchmark_iters + if rank == 0: + print( + f"[ep_moe --benchmark] {label}: {dt_ms:.3f} ms/iter " + f"(iters={args.benchmark_iters})" + ) + return dt_ms + + buffer_hbm = EpBuffer(handle, ep_group=None, use_symm_mem=False) + hbm_ms = _time("regular HBM", tokens.detach(), buffer_hbm) + + # Place the dispatch input in symm-mem too for the fast-path comparison. + try: + tokens_sm = symm_mem_alloc((T, args.hidden), torch.bfloat16, ep_group, device=device) + tokens_sm.copy_(tokens.detach()) + symm_ms = _time("symm-mem", tokens_sm, buffer) + if rank == 0: + print(f"[ep_moe --benchmark] speedup: {hbm_ms / symm_ms:.2f}x") + except RuntimeError as e: + if rank == 0: + print(f"[ep_moe --benchmark] symm-mem path skipped: {e}") + + if args.check: + # All-gather inputs/outputs/grads for a global reference comparison. + global_tokens = [torch.empty_like(tokens) for _ in range(world_size)] + global_topk_idx = [torch.empty_like(topk_idx) for _ in range(world_size)] + global_topk_w = [torch.empty_like(topk_w) for _ in range(world_size)] + global_out = [torch.empty_like(out) for _ in range(world_size)] + global_grad = [torch.empty_like(tokens.grad) for _ in range(world_size)] + dist.all_gather(global_tokens, tokens.detach()) + dist.all_gather(global_topk_idx, topk_idx) + dist.all_gather(global_topk_w, topk_w) + dist.all_gather(global_out, out.detach()) + dist.all_gather(global_grad, tokens.grad) + if rank == 0: + all_tokens = torch.cat(global_tokens).float().cpu().numpy() + all_idx = torch.cat(global_topk_idx).cpu().numpy() + all_w = torch.cat(global_topk_w).cpu().numpy() + all_out = torch.cat(global_out).float().cpu().numpy() + all_grad = torch.cat(global_grad).float().cpu().numpy() + ref_out, ref_grad = _reference_grad(all_tokens, all_idx, all_w, kernels_np) + np.testing.assert_allclose(all_out, ref_out, rtol=5e-2, atol=5e-2) + np.testing.assert_allclose(all_grad, ref_grad, rtol=5e-2, atol=5e-2) + print(f"[ep_moe] --check PASSED (ref_out.sum()={float(ref_out.sum()):.4f})") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + sys.exit(0) diff --git a/examples/pytorch/ep/run_test_ep.sh b/examples/pytorch/ep/run_test_ep.sh new file mode 100755 index 0000000000..9727590cea --- /dev/null +++ b/examples/pytorch/ep/run_test_ep.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -uo pipefail + +DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +NUM_GPUS="${NUM_GPUS:-${DETECTED_GPUS}}" +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "EP requires >= 4 GPUs (found ${NUM_GPUS}); SKIPPING." + exit 0 +fi +if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi + +: ${TE_PATH:=/opt/transformerengine} +: ${TEST_TIMEOUT_S:=120} + +SCRIPT="${TE_PATH}/examples/pytorch/ep/ep_moe.py" +export PYTHONPATH="${TE_PATH}${PYTHONPATH:+:${PYTHONPATH}}" + +# Stage JIT cubins on tmpfs for fast iteration. +: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +echo "*** Executing ep_moe.py across ${NUM_GPUS} GPUs (timeout=${TEST_TIMEOUT_S}s) ***" +timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_GPUS}" \ + "${SCRIPT}" --check 2>&1 | tee stdout_ep_moe.txt +RC=${PIPESTATUS[0]} + +RET=0 +if [ "${RC}" -ne 0 ]; then RET=1; fi +if grep -qE "FAILED|Traceback" stdout_ep_moe.txt; then RET=1; fi +rm -f stdout_ep_moe.txt +exit $RET diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh index 8d767a4efb..7e5ce2cf0d 100755 --- a/qa/L1_cpp_distributed/test.sh +++ b/qa/L1_cpp_distributed/test.sh @@ -14,4 +14,7 @@ if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then cmake -GNinja -S. -Bbuild cmake --build build mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm + + # EP suites; runner self-skips on pre-Hopper GPUs. + bash ./run_test_ep.sh 4 ./build fi diff --git a/setup.py b/setup.py index ec277b6349..db360c8a29 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,34 @@ def setup_common_extension() -> CMakeExtension: cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr") cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}") + # NCCL EP: on by default; auto-disabled if no arch >= 90. + # Set NVTE_BUILD_WITH_NCCL_EP=0/1 to force off/on. + nccl_ep_env = os.getenv("NVTE_BUILD_WITH_NCCL_EP") + explicit_nccl_ep = nccl_ep_env is not None + build_with_nccl_ep = bool(int(nccl_ep_env)) if explicit_nccl_ep else True + + if build_with_nccl_ep: + arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()] + has_hopper_or_newer = any(t.lower() == "native" for t in arch_tokens) or any( + int(t.rstrip("af")) >= 90 for t in arch_tokens if t.rstrip("af").isdigit() + ) + if not has_hopper_or_newer: + if explicit_nccl_ep: + raise RuntimeError( + "NVTE_BUILD_WITH_NCCL_EP=1 requires at least one CUDA arch >= 90 in " + f"NVTE_CUDA_ARCHS (got '{archs}'). Add '90' or unset NVTE_BUILD_WITH_NCCL_EP." + ) + print( + "[NCCL EP] No CUDA arch >= 90 in NVTE_CUDA_ARCHS" + f" ('{archs}'); auto-disabling NCCL EP (nvte_ep_* will throw at runtime)." + ) + build_with_nccl_ep = False + + if build_with_nccl_ep: + build_nccl_ep_submodule() + else: + cmake_flags.append("-DNVTE_WITH_NCCL_EP=OFF") + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: @@ -128,6 +156,105 @@ def setup_requirements() -> Tuple[List[str], List[str]]: return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]] +def _discover_nccl_home() -> str: + """Resolve NCCL_HOME: honor env var, else probe well-known prefixes, else ldconfig.""" + env_home = os.environ.get("NCCL_HOME") + if env_home: + if (Path(env_home) / "include" / "nccl.h").exists(): + return env_home + print( + f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but " + f"'{env_home}/include/nccl.h' was not found; falling back to system probes." + ) + + for cand in ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr"): + p = Path(cand) + if (p / "include" / "nccl.h").exists() and any( + (p / "lib" / name).exists() or (p / "lib64" / name).exists() + for name in ("libnccl.so", "libnccl.so.2") + ): + return str(p) + + try: + out = subprocess.check_output(["ldconfig", "-p"], stderr=subprocess.DEVNULL).decode() + for line in out.splitlines(): + if "libnccl.so" in line and "=>" in line: + lib_path = Path(line.split("=>")[-1].strip()) + root = lib_path.parent.parent + if (root / "include" / "nccl.h").exists(): + return str(root) + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + raise RuntimeError( + "Could not locate NCCL core (nccl.h + libnccl.so). Set NCCL_HOME to the install prefix." + ) + + +def build_nccl_ep_submodule() -> str: + """Build libnccl_ep.so from the 3rdparty/nccl submodule. + + NCCL EP is on by default; the system NCCL core (libnccl.so) supplies the + headers and runtime symbols. Returns the submodule build directory. + """ + nccl_root = current_file_path / "3rdparty" / "nccl" + if not (nccl_root / "Makefile").exists(): + raise RuntimeError( + f"NCCL submodule not found at {nccl_root}. " + "Run `git submodule update --init --recursive`." + ) + + build_dir = nccl_root / "build" + nccl_ep_lib = build_dir / "lib" / "libnccl_ep.so" + + archs = cuda_archs() or "90" + arch_list = [] + for a in str(archs).split(";"): + a = a.strip().rstrip("af") + if a and a.isdigit() and int(a) >= 90: + arch_list.append(a) + if not arch_list: + arch_list = ["90"] + gencode = " ".join(f"-gencode=arch=compute_{a},code=sm_{a}" for a in arch_list) + + nproc = os.cpu_count() or 8 + env = os.environ.copy() + env["NVCC_GENCODE"] = gencode + # NCCL EP needs the core NCCL headers + libnccl.so; write NCCL EP build + # outputs to the submodule's local build/ tree. + nccl_home = _discover_nccl_home() + env["NCCL_HOME"] = nccl_home + env["NCCL_EP_BUILDDIR"] = str(build_dir) + + if not nccl_ep_lib.exists(): + print(f"[NCCL EP] Building libnccl_ep.so (gencode='{gencode}')") + subprocess.check_call( + ["make", "-j", str(nproc), "-C", "contrib/nccl_ep", "lib"], + cwd=str(nccl_root), + env=env, + ) + + # TE's CMake expects nccl.h under 3rdparty/nccl/build/include/ for its + # version check. Mirror the top-level host headers from the system NCCL + # install — DON'T mirror nccl_device/ because the submodule ships its own + # newer copy at src/include/nccl_device/ with device-side templates that + # conflict with older system versions, and the JIT include path picks the + # submodule's. + nccl_include = build_dir / "include" + nccl_include.mkdir(parents=True, exist_ok=True) + for cand in (Path(nccl_home) / "include", Path("/usr/include")): + p = Path(cand) + if (p / "nccl.h").exists(): + for name in ("nccl.h", "nccl_net.h", "nccl_tuner.h"): + src = p / name + dst = nccl_include / name + if src.exists() and not dst.exists(): + dst.symlink_to(src) + break + + return str(build_dir) + + def git_check_submodules() -> None: """ Attempt to checkout git submodules automatically during setup. diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt index 0d7258a81d..3870f57911 100644 --- a/tests/cpp_distributed/CMakeLists.txt +++ b/tests/cpp_distributed/CMakeLists.txt @@ -30,7 +30,7 @@ if(NOT DEFINED TE_LIB_PATH) get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) endif() -find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED NO_CMAKE_SYSTEM_PATH) message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) @@ -46,12 +46,99 @@ add_executable(test_comm_gemm find_package(OpenMP REQUIRED) find_package(MPI REQUIRED) + +# ── NCCL library ────────────────────────────────────────────────────────────── +# Search order: NCCL_HOME env → 3rdparty/nccl submodule build → system paths. +set(NCCL_SUBMODULE_BUILD "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build") find_library(NCCL_LIB NAMES nccl libnccl - PATH_SUFFIXES lib + HINTS $ENV{NCCL_HOME}/lib ${NCCL_SUBMODULE_BUILD}/lib + PATH_SUFFIXES lib lib64 REQUIRED) + +# NCCL headers: prefer submodule build output (has the handle_init API), +# then submodule src, then system (CUDA toolkit). +set(NCCL_SUBMODULE_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build/include") +set(NCCL_SUBMODULE_SRC_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/src/include") +if(EXISTS "${NCCL_SUBMODULE_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIR "${NCCL_SUBMODULE_INCLUDE}") +elseif(EXISTS "${NCCL_SUBMODULE_SRC_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIR "${NCCL_SUBMODULE_SRC_INCLUDE}") +elseif(DEFINED ENV{NCCL_HOME}) + set(NCCL_INCLUDE_DIR "$ENV{NCCL_HOME}/include") +endif() target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) include(GoogleTest) gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) + +# ── EP distributed tests (HT mode) ───────────────────────────────────────── +# No MPI dependency — processes are spawned by run_test_ep.sh with +# --rank / --nranks flags. ncclUniqueId exchange uses a +# shared temp file (see test_ep_common.h for details). +# Headers + libs come from the in-tree 3rdparty/nccl submodule build. +set(NCCL_EP_SUBMODULE_ROOT + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") +find_library(NCCL_EP_LIB + NAMES nccl_ep libnccl_ep + HINTS ${NCCL_EP_SUBMODULE_ROOT}/build/lib + NO_DEFAULT_PATH + REQUIRED) + +set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/contrib/nccl_ep/include") +if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h") + message(FATAL_ERROR + "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. " + "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl.") +endif() +message(STATUS "EP test: NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}") + +# Collect NCCL include dirs shared by all EP test targets (nccl_ep.h + nccl.h). +set(EP_TEST_NCCL_INCLUDES ${NCCL_EP_INCLUDE_DIR}) +if(DEFINED NCCL_INCLUDE_DIR) + list(APPEND EP_TEST_NCCL_INCLUDES ${NCCL_INCLUDE_DIR}) + message(STATUS "EP test: NCCL headers: ${NCCL_INCLUDE_DIR}") +endif() + +set(EP_TEST_COMMON_INCLUDES + ${EP_TEST_NCCL_INCLUDES} + ../../transformer_engine/common/include + ../../transformer_engine/common + ${CMAKE_CURRENT_SOURCE_DIR}) + +set(EP_TEST_COMMON_LIBS + CUDA::cuda_driver + CUDA::cudart + CUDA::nvrtc + GTest::gtest + ${TE_LIB} + ${NCCL_LIB} + ${NCCL_EP_LIB}) + +# nvrtc symbols are referenced from libtransformer_engine.so but not in its +# DT_NEEDED list (loaded via dlopen in Python). For cpp tests we link nvrtc +# explicitly with --no-as-needed so the linker keeps the dependency. +set(EP_TEST_LINK_OPTS "LINKER:--no-as-needed") + +# ── EP init tests (InitPath, HandleMemSizeQuery) ───────────────────────────── +add_executable(test_ep_init test_ep_init.cu) +target_include_directories(test_ep_init PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_init PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_init PUBLIC ${EP_TEST_LINK_OPTS}) + +# ── EP pipeline tests (dispatch, combine, bwd, integrated) ─────────────────── +add_executable(test_ep_pipeline test_ep_pipeline.cu) +target_include_directories(test_ep_pipeline PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_pipeline PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_pipeline PUBLIC ${EP_TEST_LINK_OPTS}) + +# ── EP coverage tests (multi-handle, top_k=1, empty experts, negatives, threading) ── +add_executable(test_ep_coverage test_ep_coverage.cu) +target_include_directories(test_ep_coverage PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_coverage PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_coverage PUBLIC ${EP_TEST_LINK_OPTS}) + +# Do NOT use gtest_discover_tests — these binaries require multi-process +# launch via run_test_ep.sh, not direct single-process execution. +message(STATUS "EP distributed tests enabled: ${NCCL_EP_LIB}") diff --git a/tests/cpp_distributed/run_test_ep.sh b/tests/cpp_distributed/run_test_ep.sh new file mode 100755 index 0000000000..017d3f807b --- /dev/null +++ b/tests/cpp_distributed/run_test_ep.sh @@ -0,0 +1,137 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Run TE EP distributed unit tests across multiple GPUs. +# +# Spawns one background bash process per GPU (no MPI dependency), matching the +# JAX multi-process launcher style. ncclUniqueId is exchanged via a shared +# temp file (see test_ep_common.h). Each rank builds its own ncclComm_t and +# passes it to nvte_ep_initialize. +# +# Usage: +# bash run_test_ep.sh [num_gpus] [build_dir] +# +# Defaults: +# num_gpus = number of GPUs visible to nvidia-smi +# build_dir = /build +# +# Environment variables: +# GTEST_FILTER — forwarded to all processes (e.g., "EPDispatchTest.*") +# TEST_TIMEOUT_S — per-process timeout in seconds (default: 180) + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${2:-${SCRIPT_DIR}/build}" +NUM_GPUS="${1:-$(nvidia-smi -L 2>/dev/null | wc -l)}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}" + +# Skip cleanly on pre-Hopper: NCCL EP requires SM>=90. +MIN_SM=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ + | awk -F. 'NR==1 || ($1*10+$2) 0 && MIN_SM < 90 )); then + echo "NCCL EP requires SM>=90 (lowest visible GPU is SM${MIN_SM}); SKIPPING." + exit 0 +fi + +GTEST_ARGS="${GTEST_FILTER:+--gtest_filter=${GTEST_FILTER}}" +OVERALL_FAIL=0 + +# --------------------------------------------------------------------------- +# run_suite BINARY SUITE_NAME MIN_GPUS +# --------------------------------------------------------------------------- +run_suite() { + local BINARY="$1" + local SUITE_NAME="$2" + local MIN_GPUS="${3:-2}" + + local TEST_BIN="${BUILD_DIR}/${BINARY}" + + if [[ ! -x "${TEST_BIN}" ]]; then + echo "ERROR: binary not found: ${TEST_BIN}" + echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make" + OVERALL_FAIL=1 + return + fi + + if (( NUM_GPUS < MIN_GPUS )); then + echo "${SUITE_NAME}: requires ${MIN_GPUS} GPUs, found ${NUM_GPUS}. Skipping." + return + fi + + local TMPDIR_L="${TMPDIR:-/tmp}" + local UID_FILE="${TMPDIR_L}/te_ep_uid_${BINARY}_$$" + rm -f "${UID_FILE}" + + local LOG_DIR + LOG_DIR=$(mktemp -d) + local FAIL=0 + + echo "=== ${SUITE_NAME} ===" + echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}" + echo + + # Spawn one background process per GPU. ncclUniqueId is exchanged via the + # shared UID_FILE. Each process is wrapped in `timeout` to detect hangs early. + local PIDS=() + for i in $(seq 0 $((NUM_GPUS - 1))); do + timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + "${TEST_BIN}" \ + --rank="${i}" \ + --nranks="${NUM_GPUS}" \ + --uid-file="${UID_FILE}" \ + ${GTEST_ARGS} \ + > "${LOG_DIR}/rank_${i}.log" 2>&1 & + PIDS+=($!) + done + for i in $(seq 0 $((NUM_GPUS - 1))); do + if ! wait "${PIDS[$i]}"; then + local rc=$? + FAIL=1 + if [[ $rc -eq 137 || $rc -eq 124 ]]; then + echo " rank ${i}: TIMEOUT after ${TEST_TIMEOUT_S}s (rc=${rc})" + fi + fi + done + + echo "--- Rank 0 output ---" + cat "${LOG_DIR}/rank_0.log" + + if (( FAIL )); then + for i in $(seq 1 $((NUM_GPUS - 1))); do + echo "--- Rank ${i} output ---" + cat "${LOG_DIR}/rank_${i}.log" + done + echo "=== ${SUITE_NAME}: FAILED ===" + OVERALL_FAIL=1 + else + echo "=== ${SUITE_NAME}: ALL PASSED ===" + fi + + rm -rf "${LOG_DIR}" + rm -f "${UID_FILE}" +} + +# --------------------------------------------------------------------------- +# Cleanup on abort +# --------------------------------------------------------------------------- +cleanup() { rm -f "${TMPDIR:-/tmp}"/te_ep_uid_*_"$$" 2>/dev/null || true; } +trap cleanup EXIT INT TERM + +# --------------------------------------------------------------------------- +# Run all suites +# --------------------------------------------------------------------------- +run_suite "test_ep_init" "EP Init Tests" 2 +run_suite "test_ep_pipeline" "EP Pipeline Tests" 2 +run_suite "test_ep_coverage" "EP Coverage Tests" 2 + +echo +if (( OVERALL_FAIL )); then + echo "=== SOME SUITES FAILED ===" +else + echo "=== ALL SUITES PASSED ===" +fi + +exit "${OVERALL_FAIL}" diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h new file mode 100644 index 0000000000..77baa92b0c --- /dev/null +++ b/tests/cpp_distributed/test_ep_common.h @@ -0,0 +1,308 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * Shared TE EP test infrastructure. Include once per TU; ep_bootstrap() in + * each test binary's main() populates process-level globals. + * Defaults: 4 experts/rank, hidden_dim=256, max_tokens_per_rank=64. + */ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// ── Error-checking macros ───────────────────────────────────────────────────── + +#define CHECK_NCCL(expr) \ + do { \ + ncclResult_t _err = (expr); \ + if (_err != ncclSuccess) \ + FAIL() << "NCCL error " << _err << ": " << ncclGetErrorString(_err); \ + } while (false) + +#define CHECK_CUDA(expr) \ + do { \ + cudaError_t _err = (expr); \ + if (_err != cudaSuccess) \ + FAIL() << "CUDA error " << _err << ": " << cudaGetErrorString(_err); \ + } while (false) + +#define ASSERT_CUDA_OK(expr) \ + do { \ + cudaError_t _err = (expr); \ + if (_err != cudaSuccess) { \ + fprintf(stderr, "CUDA error %d: %s\n", _err, cudaGetErrorString(_err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (false) + +#define ASSERT_NCCL_OK(expr) \ + do { \ + ncclResult_t _err = (expr); \ + if (_err != ncclSuccess) { \ + fprintf(stderr, "NCCL error %d: %s\n", _err, ncclGetErrorString(_err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (false) + +// ── Process-level state ─────────────────────────────────────────────────────── + +static int g_process_id = -1; +static int g_num_processes = -1; +static std::string g_uid_file; + +static int g_sm_major = -1; // set by ep_bootstrap; -1 until then +static int g_ep_size = -1; +static int g_num_experts = -1; +static int g_hidden_dim = 256; +static int g_max_tokens_per_rank = 64; +static bool g_ep_initialized = false; +static ncclComm_t g_ep_comm = nullptr; // owned by harness, destroyed in ep_teardown + +// ── TensorHandle RAII wrapper ───────────────────────────────────────────────── + +// View over a caller-owned device buffer; owns NVTETensor metadata only. Move-only. +struct TensorHandle { + NVTETensor tensor = nullptr; + void* dev_ptr = nullptr; + + ~TensorHandle() { + if (tensor) nvte_destroy_tensor(tensor); + } + + TensorHandle() = default; + TensorHandle(const TensorHandle&) = delete; + TensorHandle& operator=(const TensorHandle&) = delete; + + TensorHandle(TensorHandle&& o) noexcept : tensor(o.tensor), dev_ptr(o.dev_ptr) { + o.tensor = nullptr; o.dev_ptr = nullptr; + } + TensorHandle& operator=(TensorHandle&& o) noexcept { + if (this != &o) { + if (tensor) nvte_destroy_tensor(tensor); + tensor = o.tensor; dev_ptr = o.dev_ptr; + o.tensor = nullptr; o.dev_ptr = nullptr; + } + return *this; + } +}; + +static TensorHandle make_nvte_tensor(void* dev_ptr, + const std::vector& shape, + NVTEDType dtype) { + TensorHandle h; + h.dev_ptr = dev_ptr; + h.tensor = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING); + + NVTEShape s; + s.ndim = shape.size(); + for (size_t i = 0; i < shape.size(); ++i) s.data[i] = shape[i]; + + NVTEBasicTensor bt; + bt.data_ptr = dev_ptr; + bt.dtype = dtype; + bt.shape = s; + nvte_set_tensor_param_v2(h.tensor, kNVTERowwiseData, &bt, sizeof(bt)); + + return h; +} + +// RAII owner for a cudaMalloc'd device buffer; frees on destruction. +template +struct DevBuf { + T* ptr = nullptr; + size_t count = 0; + + DevBuf() = default; + explicit DevBuf(size_t n) { alloc(n); } + ~DevBuf() { reset(); } + + DevBuf(const DevBuf&) = delete; + DevBuf& operator=(const DevBuf&) = delete; + DevBuf(DevBuf&& o) noexcept : ptr(o.ptr), count(o.count) { o.ptr = nullptr; o.count = 0; } + DevBuf& operator=(DevBuf&& o) noexcept { + if (this != &o) { reset(); ptr = o.ptr; count = o.count; o.ptr = nullptr; o.count = 0; } + return *this; + } + + void alloc(size_t n) { + reset(); + count = n; + if (n > 0) { + cudaError_t e = cudaMalloc(&ptr, n * sizeof(T)); + if (e != cudaSuccess) { + fprintf(stderr, "DevBuf cudaMalloc(%zu) failed: %s\n", n * sizeof(T), + cudaGetErrorString(e)); + ptr = nullptr; + count = 0; + } + } + } + + void reset() { + if (ptr) { cudaFree(ptr); ptr = nullptr; } + count = 0; + } + + T* get() const { return ptr; } + size_t bytes() const { return count * sizeof(T); } +}; + +// ── Shared routing helper ───────────────────────────────────────────────────── + +// Balanced round-robin routing: token t on rank r maps top_k experts to +// (r * num_local_experts + t * top_k + k) % num_experts +static inline std::vector routing_balanced( + int rank, int num_tokens, int top_k, int num_experts, int num_local_experts) { + std::vector idx(num_tokens * top_k); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) + idx[t * top_k + k] = (rank * num_local_experts + t * top_k + k) % num_experts; + return idx; +} + +// ── File-based ncclUniqueId exchange ───────────────────────────────────────── + +static void exchange_unique_id(ncclUniqueId* uid) { + const size_t sz = sizeof(ncclUniqueId); + + if (g_process_id == 0) { + ASSERT_NCCL_OK(ncclGetUniqueId(uid)); + FILE* f = fopen(g_uid_file.c_str(), "wb"); + if (!f) { fprintf(stderr, "Cannot open uid file: %s\n", g_uid_file.c_str()); exit(EXIT_FAILURE); } + fwrite(uid, 1, sz, f); + fclose(f); + } else { + auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(60); + while (true) { + FILE* f = fopen(g_uid_file.c_str(), "rb"); + if (f) { + fseek(f, 0, SEEK_END); + if (static_cast(ftell(f)) >= sz) { + fseek(f, 0, SEEK_SET); + size_t n = fread(uid, 1, sz, f); + fclose(f); + if (n == sz) break; + } else { + fclose(f); + } + } + if (std::chrono::steady_clock::now() > deadline) { + fprintf(stderr, "Process %d: timed out waiting for uid file\n", g_process_id); + exit(EXIT_FAILURE); + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + } +} + +// ── CLI parsing ─────────────────────────────────────────────────────────────── + +static void ep_parse_args(int argc, char* argv[]) { + for (int i = 1; i < argc; ++i) { + std::string a(argv[i]); + if (a.rfind("--process-id=", 0) == 0) g_process_id = std::stoi(a.substr(13)); + else if (a.rfind("--rank=", 0) == 0) g_process_id = std::stoi(a.substr(7)); + else if (a.rfind("--num-processes=",0)==0) g_num_processes = std::stoi(a.substr(16)); + else if (a.rfind("--nranks=", 0) == 0) g_num_processes = std::stoi(a.substr(9)); + else if (a.rfind("--uid-file=", 0) == 0) g_uid_file = a.substr(11); + } + + if (g_process_id < 0 || g_num_processes <= 0) { + fprintf(stderr, + "Usage: %s --rank=N --nranks=N [--uid-file=path] [gtest flags]\n" + " Aliases: --process-id=N, --num-processes=N\n", + argc > 0 ? argv[0] : "test_ep"); + exit(EXIT_FAILURE); + } + + if (g_uid_file.empty()) { + const char* t = getenv("TMPDIR"); if (!t) t = "/tmp"; + g_uid_file = std::string(t) + "/te_ep_uid_" + std::to_string(g_process_id); + } +} + +// ── Bootstrap / teardown ────────────────────────────────────────────────────── + +// Returns false if the binary should exit without running tests (wrong SM, etc.). +static bool ep_bootstrap(int argc, char* argv[]) { + ep_parse_args(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + + int device_count; + cudaGetDeviceCount(&device_count); + cudaSetDevice(g_process_id % device_count); + + int device, major; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + g_sm_major = major; + if (major < 9) { + if (g_process_id == 0) + printf("SKIP: EP requires SM_90+ (device is SM_%d0)\n", major); + return false; + } + if (g_num_processes < 2) { + if (g_process_id == 0) + printf("SKIP: at least 2 processes required\n"); + return false; + } + + g_ep_size = g_num_processes; + g_num_experts = g_ep_size * 4; // 4 experts per rank + + ncclUniqueId uid{}; + exchange_unique_id(&uid); + + NVTEEpGroupConfig group_config{}; + group_config.ep_size = g_ep_size; + group_config.num_experts = g_num_experts; + group_config.max_tokens_per_rank = g_max_tokens_per_rank; + // Worst-case for top_k fan-out: ep_size * max_tokens_per_rank * 2. + group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; + group_config.hidden_dim = g_hidden_dim; + + ASSERT_NCCL_OK(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); + nvte_ep_initialize(static_cast(g_ep_comm), group_config); + + if (g_process_id == 0) { + printf("EP initialized: ep_size=%d num_experts=%d " + "hidden_dim=%d max_tokens_per_rank=%d\n", + g_ep_size, g_num_experts, g_hidden_dim, g_max_tokens_per_rank); + } + + g_ep_initialized = true; + return true; +} + +// Tear down in dependency order: backend's ep_group reads from ep_comm, +// so destroy the group first, then the comm. +static void ep_teardown() { + if (g_ep_initialized) { + nvte_ep_shutdown(); + if (g_ep_comm != nullptr) { + ncclCommDestroy(g_ep_comm); + g_ep_comm = nullptr; + } + g_ep_initialized = false; + } + if (g_process_id == 0) remove(g_uid_file.c_str()); +} diff --git a/tests/cpp_distributed/test_ep_coverage.cu b/tests/cpp_distributed/test_ep_coverage.cu new file mode 100644 index 0000000000..ef7941905d --- /dev/null +++ b/tests/cpp_distributed/test_ep_coverage.cu @@ -0,0 +1,379 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * EP C-API coverage tests (paths not exercised by the pipeline suite). + * + * MultiHandleAllocTest — distinct handle ids; each works end-to-end. + * TopK1Test — top_k=1 dispatch/combine/bwd round-trip. + * EmptyExpertsTest — alignment ∈ {0, 2, 8, 16} with experts receiving 0 tokens. + * NegativeTests — alignment mismatch and null handle_mem must throw. + */ + +#include "test_ep_common.h" + +#include +#include + +// top1 -> expert 0, top2 -> expert 2; leaves local-expert 1 empty between two +// full experts. Requires top_k >= 2 and num_experts >= 3. +static std::vector routing_skip_middle(int num_tokens, int top_k) { + std::vector idx(num_tokens * top_k); + for (int t = 0; t < num_tokens; ++t) { + idx[t * top_k + 0] = 0; + if (top_k >= 2) idx[t * top_k + 1] = 2; + for (int k = 2; k < top_k; ++k) idx[t * top_k + k] = 2 + k; // distinct stragglers + } + return idx; +} + +static std::vector tokens_constant(int num_tokens, int hidden_dim, float val) { + std::vector v(num_tokens * hidden_dim); + nv_bfloat16 b = __float2bfloat16(val); + std::fill(v.begin(), v.end(), b); + return v; +} + +namespace { + +class EpCoverageBase : public ::testing::Test { + protected: + int ep_size_, num_experts_, num_local_experts_, hidden_dim_; + int max_tokens_per_rank_; + + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2); + ASSERT_TRUE(g_ep_initialized); + ep_size_ = g_ep_size; + num_experts_ = g_num_experts; + num_local_experts_ = num_experts_ / ep_size_; + hidden_dim_ = g_hidden_dim; + max_tokens_per_rank_ = g_max_tokens_per_rank; + } + + // Helper: allocate buffers + tensor views for a single dispatch+combine. + struct Bundle { + DevBuf topk_idx; + DevBuf topk_weights; + DevBuf tokens; + DevBuf token_counts; + DevBuf handle_mem; + DevBuf recv_tokens; + DevBuf recv_topk_weights; + DevBuf result; + uint64_t handle_id = 0; + size_t handle_mem_size = 0; + size_t recv_capacity = 0; + }; + + Bundle make_bundle(int num_tokens, int top_k, int num_local_experts, + size_t alignment) { + Bundle b; + b.recv_capacity = static_cast(ep_size_) * max_tokens_per_rank_ * 2; + b.topk_idx.alloc(num_tokens * top_k); + b.topk_weights.alloc(num_tokens * top_k); + b.tokens.alloc(num_tokens * hidden_dim_); + b.token_counts.alloc(num_local_experts); + b.recv_tokens.alloc(b.recv_capacity * hidden_dim_); + b.recv_topk_weights.alloc(b.recv_capacity); + b.result.alloc(num_tokens * hidden_dim_); + NVTEEpLayerConfig cfg{num_local_experts, top_k, alignment}; + b.handle_id = nvte_ep_register_layer(cfg, &b.handle_mem_size); + b.handle_mem.alloc(b.handle_mem_size); + return b; + } +}; + +} // namespace + +// ============================================================================= +// MultiHandleAllocTest: ids are distinct and each is independently usable. +// ============================================================================= + +class MultiHandleAllocTest : public EpCoverageBase {}; + +TEST_F(MultiHandleAllocTest, IdsAreDistinct) { + NVTEEpLayerConfig cfg{num_local_experts_, /*top_k=*/2, /*alignment=*/0}; + const int kN = 8; + std::vector ids(kN); + for (int i = 0; i < kN; ++i) { + size_t sz = 0; + ids[i] = nvte_ep_register_layer(cfg, &sz); + } + for (int i = 0; i < kN; ++i) { + EXPECT_NE(ids[i], 0u) << "handle_id 0 is reserved as \"no id\""; + for (int j = i + 1; j < kN; ++j) + EXPECT_NE(ids[i], ids[j]) << "duplicate id " << ids[i] << " at indices " << i << ", " << j; + } +} + +TEST_F(MultiHandleAllocTest, TwoHandlesCoexist) { + const int num_tokens = 16, top_k = 2; + Bundle a = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); + for (Bundle* x : {&a, &b}) { + CHECK_CUDA(cudaMemcpy(x->topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(x->topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(x->tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + } + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NE(a.handle_id, b.handle_id); + + auto run_one = [&](Bundle& x) { + auto topk_idx = make_nvte_tensor(x.topk_idx.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights = make_nvte_tensor(x.topk_weights.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts = make_nvte_tensor(x.token_counts.get(), {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem = make_nvte_tensor(x.handle_mem.get(), {x.handle_mem_size}, kNVTEByte); + auto tokens = make_nvte_tensor(x.tokens.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens = make_nvte_tensor(x.recv_tokens.get(), {x.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w = make_nvte_tensor(x.recv_topk_weights.get(), {x.recv_capacity}, kNVTEFloat32); + auto result = make_nvte_tensor(x.result.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + NVTEEpHandle h{x.handle_id, handle_mem.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx.tensor, token_counts.tensor, + /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx.tensor, tokens.tensor, + NVTECommWindow{}, topk_weights.tensor, NVTECommWindow{}, + recv_tokens.tensor, NVTECommWindow{}, recv_w.tensor, + NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens.tensor, NVTECommWindow{}, + result.tensor, stream)); + }; + run_one(a); + run_one(b); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // Both round-trips must produce result == top_k * 0.5 = 1.0. + for (Bundle* x : {&a, &b}) { + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), x->result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), + static_cast(top_k) * 0.5f, 1e-2f); + } + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// TopK1Test: top_k=1 dispatch/combine round-trip, including dispatch_bwd. +// ============================================================================= + +class TopK1Test : public EpCoverageBase {}; + +TEST_F(TopK1Test, RoundTrip) { + const int num_tokens = 16, top_k = 1; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f); // top_k=1: weight is unity + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.25f); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + auto tokens_t = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + auto result_t = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, + tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, + NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, + recv_w_t.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, + NVTECommWindow{}, result_t.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // top_k=1: combine is unweighted gather, so result[t] == tokens[t]. + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), 0.25f, 1e-2f) + << "tok " << t << " hidden " << p; + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EmptyExpertsTest: alignment ∈ {0, 2, 8, 16}, only local-expert 0 receives +// tokens. Round-trip must produce result == top_k * tokens regardless of the +// per-expert padding choice. +// ============================================================================= + +class EmptyExpertsTest : public EpCoverageBase, + public ::testing::WithParamInterface {}; + +TEST_P(EmptyExpertsTest, RoundTripCorrect) { + // routing_skip_middle needs experts {0, 2, ...}; smallest viable num_experts is 3. + ASSERT_GE(num_experts_, 3); + const size_t alignment = GetParam(); + const int num_tokens = 16, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, alignment); + + // top1 -> expert 0, top2 -> expert 2; rank 0's local-expert 1 receives 0 + // tokens between two non-empty experts. + std::vector h_idx = routing_skip_middle(num_tokens, top_k); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.3f); + + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + auto tokens_t = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + auto result_t = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + alignment, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, + tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, + NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, + recv_w_t.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, + NVTECommWindow{}, result_t.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // Identity expert + uniform weights: result[t] == top_k * tokens[t]. + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float expected = static_cast(top_k) * 0.3f; + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), expected, 1e-2f) + << "alignment=" << alignment << " tok=" << t << " hidden=" << p; + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +INSTANTIATE_TEST_SUITE_P(Alignments, EmptyExpertsTest, + ::testing::Values(0, 2, 8, 16)); + +// ============================================================================= +// NegativeTests: prepare/dispatch must surface bad inputs as exceptions. +// ============================================================================= + +class NegativeTests : public EpCoverageBase {}; + +TEST_F(NegativeTests, AlignmentMismatchThrows) { + const int num_tokens = 8, top_k = 2; + // Allocate handle for alignment=0, then call prepare with alignment=16. + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/16, stream), + std::exception); + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +TEST_F(NegativeTests, NullHandleMemThrows) { + const int num_tokens = 8, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + // Construct a tensor view backed by a null device pointer. + auto null_hm_t = make_nvte_tensor(nullptr, {b.handle_mem_size}, kNVTEByte); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + NVTEEpHandle h{b.handle_id, null_hm_t.tensor}; + EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream), + std::exception); + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/tests/cpp_distributed/test_ep_init.cu b/tests/cpp_distributed/test_ep_init.cu new file mode 100644 index 0000000000..08744dfee5 --- /dev/null +++ b/tests/cpp_distributed/test_ep_init.cu @@ -0,0 +1,64 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * Unit tests for EP initialization paths. + * + * Tests: + * EPInitTest/InitPath — backend is live after init, handle_mem_size > 0 + * EPInitTest/NumLocalExperts — handle_mem_size is consistent across num_local_experts values + * + * Run via run_test_ep.sh (both uid and comm init paths are tested by the script). + */ + +#include "test_ep_common.h" + +// ── Fixture ─────────────────────────────────────────────────────────────────── + +class EPInitTest : public ::testing::Test { + protected: + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2) << "EP tests require at least 2 processes"; + ASSERT_TRUE(g_ep_initialized) << "EP not initialized"; + } +}; + +// ── Tests ───────────────────────────────────────────────────────────────────── + +TEST_F(EPInitTest, InitPath) { + int nle = g_num_experts / g_ep_size; + NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; + size_t sz = 0; + (void)nvte_ep_register_layer(cfg, &sz); + ASSERT_GT(sz, 0u) << "handle_mem_size must be > 0 after init"; + + if (g_process_id == 0) { + printf(" handle_mem : %zu bytes\n", sz); + } +} + +TEST_F(EPInitTest, NumLocalExperts) { + // handle_mem_size should be > 0 for any valid num_local_experts value. + for (int nle : {1, g_num_experts / g_ep_size}) { + NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; + size_t sz = 0; + (void)nvte_ep_register_layer(cfg, &sz); + ASSERT_GT(sz, 0u) << "num_local_experts=" << nle; + if (g_process_id == 0) + printf(" nle=%-3d handle_mem_size=%zu bytes\n", nle, sz); + } +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/tests/cpp_distributed/test_ep_pipeline.cu b/tests/cpp_distributed/test_ep_pipeline.cu new file mode 100644 index 0000000000..41f83a6d11 --- /dev/null +++ b/tests/cpp_distributed/test_ep_pipeline.cu @@ -0,0 +1,890 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * EP pipeline tests: smallest-scope first. + * + * EPDispatchTest/PrepareAndDispatch — exact recv values + per-expert counts + * EPCombineTest/Combine — round-trip: out == top_k * tokens + * EPCombineBwdTest/CombineBwdCheck — exact grad_expert values + * EPDispatchBwdTest/DispatchBwdCheck — exact grad_tokens + * EPDispatchBwdGradWeightsTest/RoundTrip — exact per-(t, k) grad_topk_weights + * EPPipelineTest/FullForwardBackward — fwd + bwd NaN/Inf check + * + * Routing: token t on rank r → expert (r * num_local_experts + t * top_k + k) % num_experts + * Token values: rank r, token t → all hidden dims = (r+1)*0.01 + t*0.001 + * + * Closed-form expected values: + * dispatch recv: multiset of source-token values routed to this rank's experts + * combine: result[t] == top_k * tokens[t] + * combine_bwd: grad_expert[slot] == d_result[t] (no weighting) + * dispatch_bwd: grad_tokens[t] == top_k * d_result[t] + */ + +#include "test_ep_common.h" + +#include +#include +#include +#include + +// ── Deterministic routing helpers ───────────────────────────────────────────── + +// Token value for (rank, t): (rank * num_tokens + t + 1) / 256. Step 1/256 is +// bf16-exact and unique across (rank, t) when rank * num_tokens + t < 256. +static inline float token_value(int rank, int t, int num_tokens) { + return static_cast(rank * num_tokens + t + 1) * (1.0f / 256.0f); +} + +static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) { + std::vector v(num_tokens * hidden_dim); + for (int t = 0; t < num_tokens; ++t) { + nv_bfloat16 val = __float2bfloat16(token_value(rank, t, num_tokens)); + for (int h = 0; h < hidden_dim; ++h) + v[t * hidden_dim + h] = val; + } + return v; +} + +static std::vector expected_token_counts( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector cnt(num_local_experts, 0); + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) ++cnt[e - base]; + } + } + return cnt; +} + +static std::vector expected_recv_values_sorted( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector vals; + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) { + float raw = token_value(src, t, num_tokens); + vals.push_back(__bfloat162float(__float2bfloat16(raw))); + } + } + } + std::sort(vals.begin(), vals.end()); + return vals; +} + +// BF16 has 7 mantissa bits; relative ULP ≈ 2^-7. Use 4× headroom for +// accumulation noise inside dispatch/combine. +static float bf16_tol(float magnitude) { + return 4.f * std::ldexp(std::fabs(magnitude) + 1e-3f, -7); +} + +static bool check_no_nan_inf(const nv_bfloat16* dev, int count, const char* name) { + std::vector h(count); + cudaMemcpy(h.data(), dev, count * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost); + for (int i = 0; i < count; ++i) { + float v = __bfloat162float(h[i]); + if (std::isnan(v) || std::isinf(v)) { + fprintf(stderr, "Rank %d: %s in %s[%d]\n", + g_process_id, std::isnan(v) ? "NaN" : "Inf", name, i); + return false; + } + } + return true; +} + +// ── Forward buffer set with RAII ────────────────────────────────────────────── + +struct EPBuffers { + // Forward + DevBuf topk_idx; + DevBuf topk_weights; + DevBuf tokens; + DevBuf token_counts; + DevBuf handle_mem; + DevBuf recv_tokens; + DevBuf recv_topk_weights; + DevBuf result; + // Backward + DevBuf grad_result; + DevBuf grad_expert; + DevBuf grad_tokens; + DevBuf g_recv_topk_weights; + DevBuf grad_topk_weights; + + uint64_t handle_id = 0; + size_t handle_mem_size = 0; + size_t recv_capacity = 0; + int top_k_ = 0; + + void alloc(int num_tokens, int top_k, int hidden_dim, int num_local_experts, + int ep_size, int max_tokens_per_rank, size_t alignment = 0) { + top_k_ = top_k; + recv_capacity = static_cast(ep_size) * max_tokens_per_rank * 2; + + topk_idx.alloc(num_tokens * top_k); + topk_weights.alloc(num_tokens * top_k); + tokens.alloc(num_tokens * hidden_dim); + token_counts.alloc(num_local_experts); + recv_tokens.alloc(recv_capacity * hidden_dim); + recv_topk_weights.alloc(recv_capacity); + result.alloc(num_tokens * hidden_dim); + + NVTEEpLayerConfig cfg{num_local_experts, top_k, alignment}; + handle_id = nvte_ep_register_layer(cfg, &handle_mem_size); + handle_mem.alloc(handle_mem_size); + + grad_result.alloc(num_tokens * hidden_dim); + grad_expert.alloc(recv_capacity * hidden_dim); + grad_tokens.alloc(num_tokens * hidden_dim); + g_recv_topk_weights.alloc(recv_capacity); + grad_topk_weights.alloc(num_tokens * top_k); + } +}; + +// Bundled NVTETensor views over an EPBuffers — one place to update the shape +// conventions when the C-API evolves. +struct EPTensors { + TensorHandle topk_idx, topk_weights, token_counts, handle_mem, tokens; + TensorHandle recv_tokens, recv_topk_weights, result; + TensorHandle grad_result, grad_expert, grad_tokens; + TensorHandle g_recv_topk_weights, grad_topk_weights; + + EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, + int num_local_experts) { + topk_idx = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + topk_weights = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + token_counts = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts}, kNVTEInt32); + handle_mem = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + tokens = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + recv_tokens = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); + recv_topk_weights = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + result = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + grad_result = make_nvte_tensor(b.grad_result.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + grad_expert = make_nvte_tensor(b.grad_expert.get(), + {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); + grad_tokens = make_nvte_tensor(b.grad_tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + g_recv_topk_weights = make_nvte_tensor(b.g_recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + grad_topk_weights = make_nvte_tensor(b.grad_topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + } +}; + +// ── Shared fixture base ─────────────────────────────────────────────────────── + +class EpOpTestBase : public ::testing::Test { + protected: + int ep_size_, num_experts_, num_local_experts_, hidden_dim_; + int max_tokens_per_rank_, top_k_, num_tokens_; + + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2); + ASSERT_TRUE(g_ep_initialized); + + ep_size_ = g_ep_size; + num_experts_ = g_num_experts; + num_local_experts_ = num_experts_ / ep_size_; + hidden_dim_ = g_hidden_dim; + max_tokens_per_rank_ = g_max_tokens_per_rank; + top_k_ = 2; + num_tokens_ = 32; + } + + void upload_inputs(EPBuffers& buf, int rank = -1) { + if (rank < 0) rank = g_process_id; + auto h_idx = routing_balanced(rank, num_tokens_, top_k_, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens_ * top_k_, 1.0f / top_k_); + auto h_tok = generate_tokens(rank, num_tokens_, hidden_dim_); + + CHECK_CUDA(cudaMemcpy(buf.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(buf.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + } + + NVTEEpLayerConfig layer_config(size_t alignment = 0) const { + return NVTEEpLayerConfig{num_local_experts_, top_k_, alignment}; + } + + // ASSERT_CUDA_OK (fprintf+exit) so this non-void helper stays legal. + int read_total_recv(const EPBuffers& buf) const { + std::vector cnt(num_local_experts_); + ASSERT_CUDA_OK(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + int total = 0; + for (int c : cnt) total += c; + return total; + } +}; + +// ============================================================================= +// EPDispatchTest: exact recv values and per-expert counts. +// ============================================================================= + +class EPDispatchTest : public EpOpTestBase {}; + +TEST_F(EPDispatchTest, PrepareAndDispatch) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + CHECK_CUDA(cudaMemset(buf.recv_tokens.get(), 0, buf.recv_tokens.bytes())); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // 1. Per-expert counts. + std::vector got_counts(num_local_experts_); + CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_, + num_experts_, num_local_experts_); + int total_recv = 0; + for (int i = 0; i < num_local_experts_; ++i) { + EXPECT_EQ(got_counts[i], exp_counts[i]) << "local expert " << i; + total_recv += exp_counts[i]; + } + ASSERT_LE(total_recv, static_cast(buf.recv_capacity)) + << "total_recv exceeded recv_capacity — overflow would corrupt downstream memory"; + + // 2. Recv values: read only the filled prefix per local-expert zone, not the + // whole recv buffer — avoids false positives from legitimate-zero token values. + std::vector h_recv(buf.recv_capacity * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_recv.data(), buf.recv_tokens.get(), + h_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + std::vector got_vals; + got_vals.reserve(total_recv); + size_t slot = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < got_counts[e]; ++i) { + got_vals.push_back(__bfloat162float(h_recv[slot * hidden_dim_])); + ++slot; + } + } + std::sort(got_vals.begin(), got_vals.end()); + + auto exp_vals = expected_recv_values_sorted(g_process_id, g_num_processes, num_tokens_, + top_k_, num_experts_, num_local_experts_); + + ASSERT_EQ(got_vals.size(), exp_vals.size()); + for (size_t i = 0; i < exp_vals.size(); ++i) + EXPECT_NEAR(got_vals[i], exp_vals[i], bf16_tol(exp_vals[i])) + << "recv value mismatch at sorted index " << i; + + // 3. recv_topk_weights: every filled slot must equal the per-token weight (1/top_k). + std::vector h_w(buf.recv_capacity); + CHECK_CUDA(cudaMemcpy(h_w.data(), buf.recv_topk_weights.get(), + h_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + const float exp_w = 1.0f / static_cast(top_k_); + for (int i = 0; i < total_recv; ++i) + EXPECT_NEAR(h_w[i], exp_w, 1e-6f) << "recv_topk_weights[" << i << "]"; + + if (g_process_id == 0) + printf(" PrepareAndDispatch: passed (recv=%d, values + weights exact)\n", total_recv); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineTest: round-trip identity expert → result == top_k * tokens. +// ============================================================================= + +class EPCombineTest : public EpOpTestBase {}; + +TEST_F(EPCombineTest, Combine) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(), + h_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + // Spot-check 3 hidden-dim positions per token to catch partial-row writes. + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int tok = 0; tok < num_tokens_; ++tok) { + float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); + for (int p : probes) { + float got = __bfloat162float(h_result[tok * hidden_dim_ + p]); + EXPECT_NEAR(got, exp, bf16_tol(exp)) + << "token " << tok << " rank " << g_process_id << " hidden " << p; + } + } + + if (g_process_id == 0) + printf(" Combine: passed (result == top_k * tokens)\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineBwdTest: filled slots in grad_expert == d_result (unweighted). +// ============================================================================= + +class EPCombineBwdTest : public EpOpTestBase {}; + +TEST_F(EPCombineBwdTest, CombineBwdCheck) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad_r(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad_r.data(), + h_grad_r.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + int total_recv = read_total_recv(buf); + + std::vector cnt(num_local_experts_); + CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + std::vector h_ge(buf.recv_capacity * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), + h_ge.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Walk filled slots by per-expert zone (no v != 0 heuristic). + const float kExpGrad = 0.1f; + size_t slot = 0; + int filled = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < cnt[e]; ++i) { + float v = __bfloat162float(h_ge[slot * hidden_dim_]); + EXPECT_NEAR(v, kExpGrad, bf16_tol(kExpGrad)) + << "grad_expert expert " << e << " slot " << i << " (linear " << slot << ")"; + ++filled; ++slot; + } + } + EXPECT_EQ(filled, total_recv); + + if (g_process_id == 0) + printf(" CombineBwdCheck: passed (filled=%d)\n", filled); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdTest: grad_tokens == top_k * d_result. +// ============================================================================= + +class EPDispatchBwdTest : public EpOpTestBase {}; + +TEST_F(EPDispatchBwdTest, DispatchBwdCheck) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_gt(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float kExpGrad = static_cast(top_k_) * 0.1f; + for (int tok = 0; tok < num_tokens_; ++tok) + EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) + << "grad_tokens token " << tok; + + if (g_process_id == 0) + printf(" DispatchBwdCheck: passed (grad_tokens == %.2f)\n", kExpGrad); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdGradWeightsTest: round-trip per-(t, k) weights. +// ============================================================================= + +class EPDispatchBwdGradWeightsTest : public EpOpTestBase {}; + +TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + // Distinct per-(rank, t, k) weights so each slot carries a unique value. + std::vector h_w(num_tokens_ * top_k_); + for (int tok = 0; tok < num_tokens_; ++tok) + for (int k = 0; k < top_k_; ++k) + h_w[tok * top_k_ + k] = 0.1f + 0.01f * tok + 0.001f * k + + 0.0001f * (g_process_id + 1); + CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, + buf.recv_topk_weights.bytes(), stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + + // Sentinel: NaN so any (t, k) the bwd kernel fails to write is immediately visible. + std::vector h_nan(num_tokens_ * top_k_, + std::numeric_limits::quiet_NaN()); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_topk_weights.get(), h_nan.data(), + h_nan.size() * sizeof(float), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + // g_recv_topk_weights := recv_topk_weights (the round-trip input). + auto g_recv_t = make_nvte_tensor(buf.recv_topk_weights.get(), + {buf.recv_capacity}, kNVTEFloat32); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, + NVTECommWindow{}, g_recv_t.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_grad_w(num_tokens_ * top_k_); + CHECK_CUDA(cudaMemcpy(h_grad_w.data(), buf.grad_topk_weights.get(), + h_grad_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + + const float kTol = 1e-5f; + int errs = 0, k0_eq_k1 = 0; + for (int tok = 0; tok < num_tokens_; ++tok) { + for (int k = 0; k < top_k_; ++k) { + float got = h_grad_w[tok * top_k_ + k]; + float exp = h_w[tok * top_k_ + k]; + if (std::isnan(got) || std::fabs(got - exp) > kTol) { + if (errs < 8) + fprintf(stderr, "Rank %d: grad_topk_weights[%d, %d]: got %.6f, expected %.6f\n", + g_process_id, tok, k, got, exp); + ++errs; + } + } + if (top_k_ >= 2 && + std::fabs(h_grad_w[tok * top_k_ + 0] - h_grad_w[tok * top_k_ + 1]) < 1e-7f) + ++k0_eq_k1; + } + EXPECT_EQ(errs, 0); + EXPECT_EQ(k0_eq_k1, 0) << "per-token-average regression: grad[t, 0] == grad[t, 1]"; + + if (g_process_id == 0 && errs == 0 && k0_eq_k1 == 0) + printf(" RoundTrip: passed (%d (t, k) gradients)\n", num_tokens_ * top_k_); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// Integrated FwdBwd: NaN/Inf check end-to-end. +// ============================================================================= + +class EPPipelineTest : public EpOpTestBase {}; + +TEST_F(EPPipelineTest, FullForwardBackward) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + if (g_process_id == 0) printf(" FullForwardBackward: passed\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPZeroCopyTest: dispatch/combine with NCCL symmetric-memory windows attached +// to payload tensors (zero-copy fast path via ncclEpTensorCreateFromWindow). +// Symm-mem requirements per spec: input&output of Dispatch, input of Combine, +// input&output of Combine bwd, input of Dispatch bwd. +// ============================================================================= + +namespace { + +// Caller-owned ncclMemAlloc'd buffer with a registered symmetric window. +// Frees in destructor (deregister + ncclMemFree). Non-copyable, move-only. +struct SymmBuf { + void* ptr = nullptr; + size_t bytes = 0; + ncclWindow_t win = nullptr; + + SymmBuf() = default; + SymmBuf(const SymmBuf&) = delete; + SymmBuf& operator=(const SymmBuf&) = delete; + SymmBuf(SymmBuf&& o) noexcept : ptr(o.ptr), bytes(o.bytes), win(o.win) { + o.ptr = nullptr; o.win = nullptr; o.bytes = 0; + } + ~SymmBuf() { + if (win) ncclCommWindowDeregister(g_ep_comm, win); + if (ptr) ncclMemFree(ptr); + } + + void alloc(size_t n_bytes) { + bytes = n_bytes; + ASSERT_NCCL_OK(ncclMemAlloc(&ptr, bytes)); + CHECK_CUDA(cudaMemset(ptr, 0, bytes)); + ASSERT_NCCL_OK(ncclCommWindowRegister(g_ep_comm, ptr, bytes, &win, + NCCL_WIN_COLL_SYMMETRIC)); + } +}; + +// Build an NVTECommWindow descriptor pointing at a SymmBuf's window (offset 0). +static inline NVTECommWindow symm_window(const SymmBuf& b) { + return NVTECommWindow{b.win, /*offset=*/0}; +} + +} // namespace + +class EPZeroCopyTest : public EpOpTestBase {}; + +// Identity round-trip with symm-mem on dispatch i/o + combine input. Bit-exact +// vs HBM reference (same routing, same input). +TEST_F(EPZeroCopyTest, IdentityAllSymm) { + // HBM reference run. + EPBuffers ref_buf; + ref_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(ref_buf); + EPTensors ref_t(ref_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t ref_hid = ref_buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, ref_t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, + ref_t.tokens.tensor, NVTECommWindow{}, ref_t.topk_weights.tensor, + NVTECommWindow{}, ref_t.recv_tokens.tensor, NVTECommWindow{}, + ref_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.recv_tokens.tensor, NVTECommWindow{}, + ref_t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector ref_recv(ref_buf.recv_capacity * hidden_dim_); + std::vector ref_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(ref_recv.data(), ref_buf.recv_tokens.get(), + ref_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(ref_result.data(), ref_buf.result.get(), + ref_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Symm-mem run: tokens, recv_tokens, combine_input (== recv_tokens) all symm. + EPBuffers sym_buf; // alloc all buffers except the symm ones. + sym_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(sym_buf); + + SymmBuf sym_tokens, sym_recv; + sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(sym_buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + + // Stage same tokens into the symm-mem input. + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors sym_t(sym_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + // Replace the tokens/recv_tokens views with ones pointing at the symm buffers. + sym_t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + sym_t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {sym_buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + uint64_t sym_hid = sym_buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, sym_t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, + sym_t.tokens.tensor, symm_window(sym_tokens), + sym_t.topk_weights.tensor, NVTECommWindow{}, + sym_t.recv_tokens.tensor, symm_window(sym_recv), + sym_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.recv_tokens.tensor, + symm_window(sym_recv), sym_t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector sym_recv_host(sym_buf.recv_capacity * hidden_dim_); + std::vector sym_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_recv_host.data(), sym_recv.ptr, + sym_recv_host.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(sym_result.data(), sym_buf.result.get(), + sym_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Compare per filled recv slot (HBM ref vs symm) and full result. + int total_recv = read_total_recv(sym_buf); + for (int i = 0; i < total_recv * hidden_dim_; ++i) + ASSERT_EQ(__bfloat162float(sym_recv_host[i]), __bfloat162float(ref_recv[i])) + << "recv mismatch at " << i; + for (size_t i = 0; i < sym_result.size(); ++i) + ASSERT_EQ(__bfloat162float(sym_result[i]), __bfloat162float(ref_result[i])) + << "result mismatch at " << i; + + if (g_process_id == 0) + printf(" IdentityAllSymm: passed (recv_slots=%d, bit-exact vs HBM)\n", total_recv); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// Same buffers, 2 iterations — catches window-lifecycle regressions where the +// symm-mem registration goes stale between calls. +TEST_F(EPZeroCopyTest, IdentityAllSymmRepeated) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + + SymmBuf sym_tokens, sym_recv; + sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + for (int iter = 0; iter < 2; ++iter) { + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, symm_window(sym_tokens), + t.topk_weights.tensor, NVTECommWindow{}, + t.recv_tokens.tensor, symm_window(sym_recv), + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, + symm_window(sym_recv), t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_res(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), buf.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + for (int tok = 0; tok < num_tokens_; ++tok) { + float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); + float got = __bfloat162float(h_res[tok * hidden_dim_]); + ASSERT_NEAR(got, exp, bf16_tol(exp)) << "iter " << iter << " tok " << tok; + } + } + + if (g_process_id == 0) + printf(" IdentityAllSymmRepeated: passed (2 iters)\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// Full forward+backward with symm-mem on every spec-mandated buffer: +// dispatch i/o, combine input, combine_bwd i/o, dispatch_bwd input. +// TODO: flaky on rank 0 (grad_tokens partial-zero) when run after the prior +// EPZeroCopyTest cases in the same binary; passes in isolation. Re-enable once +// the root cause (likely NCCL EP NVLS write→read coherence on grad_expert) is +// understood. Tracked separately. +TEST_F(EPZeroCopyTest, DISABLED_FullPipelineSymm) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + + // Symm-mem: tokens (dispatch input), recv_tokens (dispatch output AND + // combine input), grad_result (combine_bwd input), grad_expert + // (combine_bwd output AND dispatch_bwd input). + SymmBuf sym_tokens, sym_recv, sym_grad_result, sym_grad_expert; + sym_tokens .alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + sym_grad_result.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_grad_expert.alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + t.grad_result = make_nvte_tensor(sym_grad_result.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.grad_expert = make_nvte_tensor(sym_grad_expert.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, symm_window(sym_tokens), + t.topk_weights.tensor, NVTECommWindow{}, + t.recv_tokens.tensor, symm_window(sym_recv), + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, + symm_window(sym_recv), t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(sym_grad_result.ptr, h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(sym_grad_expert.ptr, 0, sym_grad_expert.bytes, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, + symm_window(sym_grad_result), t.grad_expert.tensor, + symm_window(sym_grad_expert), stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, + symm_window(sym_grad_expert), + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + std::vector h_gt(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float kExpGrad = static_cast(top_k_) * 0.1f; + for (int tok = 0; tok < num_tokens_; ++tok) + EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) + << "grad_tokens token " << tok; + + if (g_process_id == 0) printf(" FullPipelineSymm: passed\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py new file mode 100644 index 0000000000..d3abfebb82 --- /dev/null +++ b/tests/pytorch/distributed/run_ep.py @@ -0,0 +1,689 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Multi-process PyTorch EP tests, launched via torchrun (one process per GPU).""" + +import os +import sys +import unittest + +import numpy as np +import torch +import torch.distributed as dist + +from transformer_engine.pytorch.ep import ( + EpHandle, + EpBuffer, + ep_bootstrap, + ep_prepare, + ep_dispatch, + ep_combine, + symm_mem_alloc, + _ep_combine_raw, + _ep_dispatch_raw, + _zero_copy_scope, +) + +# Must come after the transformer_engine import so libtransformer_engine.so is loaded. +import transformer_engine_torch as tex + + +NUM_LOCAL_EXPERTS = 2 +HIDDEN_DIM = 32 +TOP_K = 2 +TOKENS_PER_RANK = 4 + + +def _device_sm() -> int: + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def _build_ep_group(): + """EP group spanning all ranks of the default PG.""" + world_pg = dist.distributed_c10d._get_default_group() + ranks = list(range(world_pg.size())) + return dist.new_group(ranks=ranks, backend="nccl") + + +def _make_identity_inputs(rank, ep_size, nonuniform=False, device="cuda"): + """Per-rank identity routing + uniform weights so combine ≈ tokens.""" + T = TOKENS_PER_RANK + E = ep_size * NUM_LOCAL_EXPERTS + topk_idx = np.empty((T, TOP_K), dtype=np.int64) + if nonuniform: + assert TOP_K == 2 + for t in range(T): + topk_idx[t, 0] = 0 + topk_idx[t, 1] = 1 + (t % (E - 1)) + else: + base = rank * T + for t in range(T): + for k in range(TOP_K): + topk_idx[t, k] = ((base + t) * TOP_K + k) % E + tokens_np = np.linspace( + 0.1 + rank * 0.01, 0.9 + rank * 0.01, T * HIDDEN_DIM, dtype=np.float32 + ).reshape(T, HIDDEN_DIM) + topk_weights = np.full((T, TOP_K), 1.0 / TOP_K, dtype=np.float32) + return ( + torch.from_numpy(topk_idx).to(device), + torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16), + torch.from_numpy(topk_weights).to(device), + ) + + +class _Cfg: + rank: int + world_size: int + ep_size: int + num_experts: int + recv_capacity_per_rank: int + device: torch.device + + +def _make_cfg() -> _Cfg: + cfg = _Cfg() + cfg.rank = dist.get_rank() + cfg.world_size = dist.get_world_size() + cfg.ep_size = cfg.world_size + cfg.num_experts = NUM_LOCAL_EXPERTS * cfg.ep_size + T = TOKENS_PER_RANK + active = min(cfg.num_experts, T * cfg.ep_size * TOP_K) + overconc = cfg.num_experts // active + cfg.recv_capacity_per_rank = NUM_LOCAL_EXPERTS * max(T * cfg.ep_size * TOP_K, 16) * overconc * 2 + cfg.device = torch.device("cuda", torch.cuda.current_device()) + return cfg + + +# ── Test class ─────────────────────────────────────────────────────────────── + + +class TestEP(unittest.TestCase): + cfg: _Cfg + ep_group: dist.ProcessGroup + + @classmethod + def setUpClass(cls): + if _device_sm() < 90: + raise unittest.SkipTest(f"NCCL EP requires SM>=90 (got SM{_device_sm()})") + cls.cfg = _make_cfg() + cls.ep_group = _build_ep_group() + ep_bootstrap( + cls.ep_group, + num_experts=cls.cfg.num_experts, + max_tokens_per_rank=TOKENS_PER_RANK, + recv_capacity_per_rank=cls.cfg.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + allow_handle_mem_reloc=False, + ) + + def _make_handle(self, alignment=0, top_k=TOP_K): + return EpHandle( + top_k=top_k, + max_tokens_per_rank=TOKENS_PER_RANK, + recv_capacity_per_rank=self.cfg.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + num_local_experts=NUM_LOCAL_EXPERTS, + alignment=alignment, + ) + + def _make_buffers(self, dtype=torch.bfloat16): + """Allocate raw recv buffers + token_counts for the primitive (non-autograd) tests.""" + rc = self.cfg.recv_capacity_per_rank + return ( + torch.empty(rc, HIDDEN_DIM, dtype=dtype, device=self.cfg.device), + torch.empty(rc, dtype=torch.float32, device=self.cfg.device), + torch.empty(NUM_LOCAL_EXPERTS, dtype=torch.int32, device=self.cfg.device), + ) + + def _make_ep_buffer(self, handle, use_symm_mem=False): + return EpBuffer( + handle, + ep_group=self.ep_group if use_symm_mem else None, + use_symm_mem=use_symm_mem, + ) + + @staticmethod + def _weighted(recv_tokens, recv_w): + """fp32 per-slot weighting + cast back, matching the combine forward path.""" + mask = (recv_w != 0).to(torch.float32).unsqueeze(-1) + return (recv_tokens.float() * recv_w.unsqueeze(-1).float() * mask).to(recv_tokens.dtype) + + # ── prepare ────────────────────────────────────────────────────────── + + def test_primitive_prepare(self): + handle = self._make_handle() + topk_idx, _toks, _w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + token_counts = ep_prepare(handle, topk_idx) + torch.cuda.synchronize() + self.assertEqual(token_counts.shape, (NUM_LOCAL_EXPERTS,)) + # Global recv total == global send * TOP_K. + local = int(token_counts.sum().item()) + total = torch.tensor([local], dtype=torch.int64, device=self.cfg.device) + dist.all_reduce(total, op=dist.ReduceOp.SUM, group=self.ep_group) + self.assertEqual(int(total.item()), self.cfg.world_size * TOKENS_PER_RANK * TOP_K) + + # ── identity round-trip via primitives ─────────────────────────────── + + def _run_identity_round_trip(self, nonuniform): + handle = self._make_handle() + topk_idx, tokens, w = _make_identity_inputs( + self.cfg.rank, self.cfg.ep_size, nonuniform=nonuniform + ) + recv_tokens, recv_w, _ = self._make_buffers() + ep_prepare(handle, topk_idx) + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + result = torch.empty_like(tokens) + _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + torch.cuda.synchronize() + torch.testing.assert_close(result.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + def test_primitive_dispatch_combine_identity_uniform(self): + self._run_identity_round_trip(nonuniform=False) + + def test_primitive_dispatch_combine_identity_nonuniform(self): + self._run_identity_round_trip(nonuniform=True) + + def test_3d_input_round_trip(self): + """3D (B, S, H) inputs round-trip identically to 2D — leading dims are flattened to T.""" + handle = self._make_handle() + topk_idx, tokens_2d, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + B, S = 2, TOKENS_PER_RANK // 2 + assert B * S == TOKENS_PER_RANK + tokens_3d = tokens_2d.view(B, S, HIDDEN_DIM) + topk_idx_3d = topk_idx.view(B, S, TOP_K) + w_3d = w.view(B, S, TOP_K) + recv_tokens, recv_w, _ = self._make_buffers() + ep_prepare(handle, topk_idx_3d) + _ep_dispatch_raw(handle, topk_idx_3d, tokens_3d, w_3d, recv_tokens, recv_w) + result = torch.empty_like(tokens_3d) + _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + torch.cuda.synchronize() + self.assertEqual(result.shape, (B, S, HIDDEN_DIM)) + torch.testing.assert_close(result.float(), tokens_3d.float(), atol=5e-2, rtol=5e-2) + + # ── autograd ───────────────────────────────────────────────────────── + + def test_dispatch_fwd_bwd(self): + """0.5*||recv_tokens||² ⇒ grad_tokens ≈ TOP_K * tokens.""" + handle = self._make_handle() + buffer = self._make_ep_buffer(handle) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + recv_t, _recv_w, _tc = ep_dispatch(handle, buffer, tokens_p, topk_idx, w) + loss = 0.5 * (recv_t.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + torch.testing.assert_close( + tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 + ) + + def test_combine_fwd_bwd(self): + """const eo=c, uniform w ⇒ max|grad_eo| ≈ c / TOP_K.""" + handle = self._make_handle() + buffer = self._make_ep_buffer(handle) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + _recv_t, recv_w_out, _tc = ep_dispatch(handle, buffer, tokens, topk_idx, w) + eo_const = 0.5 + eo = torch.full( + (self.cfg.recv_capacity_per_rank, HIDDEN_DIM), + eo_const, + dtype=torch.bfloat16, + device=self.cfg.device, + requires_grad=True, + ) + out = ep_combine(handle, buffer, eo, recv_w_out) + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + arr = eo.grad.float().cpu().numpy() + self.assertTrue(np.all(np.isfinite(arr))) + self.assertGreater(arr.max(), 0.0) + np.testing.assert_allclose(arr.max(), eo_const / float(TOP_K), atol=5e-2, rtol=5e-2) + + # ── coverage: top_k=1 + alignment ──────────────────────────────────── + + def test_dispatch_combine_top_k_1_all_to_expert_0(self): + handle = self._make_handle(top_k=1) + T = TOKENS_PER_RANK + topk_idx = torch.zeros(T, 1, dtype=torch.int64, device=self.cfg.device) + w = torch.ones(T, 1, dtype=torch.float32, device=self.cfg.device) + tokens = torch.from_numpy( + np.linspace(0.1, 0.9, T * HIDDEN_DIM, dtype=np.float32).reshape(T, HIDDEN_DIM) + ).to(device=self.cfg.device, dtype=torch.bfloat16) + recv_tokens, recv_w, _ = self._make_buffers() + token_counts = ep_prepare(handle, topk_idx) + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + result = torch.empty_like(tokens) + _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + torch.cuda.synchronize() + torch.testing.assert_close(result.float(), tokens.float(), atol=5e-2, rtol=5e-2) + # Rank 0 owns expert 0 and receives world*T tokens; other ranks see 0. + tc = token_counts.cpu().numpy() + if self.cfg.rank == 0: + self.assertEqual(int(tc[0]), self.cfg.world_size * T) + else: + self.assertEqual(int(tc[0]), 0) + if NUM_LOCAL_EXPERTS > 1: + self.assertEqual(int(tc[1:].sum()), 0) + + def test_dispatch_combine_alignment(self): + alignment = 8 + handle = self._make_handle(alignment=alignment) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, _ = self._make_buffers() + ep_prepare(handle, topk_idx) + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + result = torch.empty_like(tokens) + _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + torch.cuda.synchronize() + torch.testing.assert_close(result.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + # ── Integration: CUDA graph, autocast, torch.compile ───────────────── + + def _moe_step(self, handle, buffer, topk_idx, tokens, w): + recv_t, recv_w_out, _tc = ep_dispatch(handle, buffer, tokens, topk_idx, w) + return ep_combine(handle, buffer, recv_t, recv_w_out) + + def test_cuda_graph_capture(self): + """Capture dispatch+combine via the raw ops; replay must be bit-stable.""" + handle = self._make_handle() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, _ = self._make_buffers() + result = torch.empty_like(tokens) + + def step(): + ep_prepare(handle, topk_idx) + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + + for _ in range(3): + step() + torch.cuda.synchronize() + + # Routing is fixed per layer, so prepare runs once before capture and only + # dispatch+combine go into the graph. + ep_prepare(handle, topk_idx) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + with torch.cuda.graph(graph): + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + torch.cuda.current_stream().wait_stream(s) + torch.cuda.synchronize() + + ref = result.clone() + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + torch.testing.assert_close(result.float(), ref.float(), atol=0, rtol=0) + + def test_autocast_bf16(self): + """EP under autocast must preserve dtype and identity round-trip.""" + handle = self._make_handle() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, _ = self._make_buffers() + result = torch.empty_like(tokens) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ep_prepare(handle, topk_idx) + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + torch.cuda.synchronize() + self.assertEqual(recv_tokens.dtype, torch.bfloat16) + self.assertEqual(result.dtype, torch.bfloat16) + torch.testing.assert_close(result.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + def test_torch_compile_fullgraph(self): + """Raw EP pipeline under torch.compile(fullgraph=True) must not graph-break.""" + handle = self._make_handle() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, token_counts = self._make_buffers() + result = torch.empty_like(tokens) + alignment = handle.alignment + handle_id = handle.handle_id + handle_mem = handle.handle_mem + + def step(handle_mem, topk_idx, tokens, w, recv_tokens, recv_w, token_counts, result): + torch.ops.transformer_engine_ep.prepare( + handle_mem, handle_id, topk_idx, token_counts, alignment + ) + torch.ops.transformer_engine_ep.dispatch( + handle_mem, handle_id, topk_idx, tokens, w, recv_tokens, recv_w + ) + mask = (recv_w != 0).to(torch.float32).unsqueeze(-1) + weighted = (recv_tokens.float() * recv_w.unsqueeze(-1).float() * mask).to( + recv_tokens.dtype + ) + torch.ops.transformer_engine_ep.combine(handle_mem, handle_id, weighted, result) + return result + + ref = torch.empty_like(tokens) + step(handle_mem, topk_idx, tokens, w, recv_tokens, recv_w, token_counts, ref) + torch.cuda.synchronize() + ref_clone = ref.clone() + + recv_tokens.zero_() + recv_w.zero_() + token_counts.zero_() + result.zero_() + compiled = torch.compile(step, fullgraph=True, dynamic=False) + out = compiled(handle_mem, topk_idx, tokens, w, recv_tokens, recv_w, token_counts, result) + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), ref_clone.float(), atol=5e-2, rtol=5e-2) + + # ── Zero-copy via NCCL symmetric memory ────────────────────────────── + + def _try_symm_alloc(self, shape, dtype): + """Allocate a symm-mem tensor or skip if the backend is unavailable.""" + try: + return symm_mem_alloc(shape, dtype, self.ep_group, device=self.cfg.device) + except Exception as e: + self.skipTest(f"NCCL symmetric memory unavailable: {e}") + + def test_zero_copy_dispatch_combine_identity(self): + """Symm-mem payload buffers must match the HBM path bit-for-bit.""" + handle = self._make_handle() + topk_idx, tokens_hbm, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + rc = self.cfg.recv_capacity_per_rank + + tokens_sm = self._try_symm_alloc((TOKENS_PER_RANK, HIDDEN_DIM), torch.bfloat16) + recv_tokens_sm = self._try_symm_alloc((rc, HIDDEN_DIM), torch.bfloat16) + expert_out_sm = self._try_symm_alloc((rc, HIDDEN_DIM), torch.bfloat16) + # Guard against the test silently degrading to the HBM path. + from torch.distributed._symmetric_memory import is_symm_mem_tensor + + self.assertTrue(is_symm_mem_tensor(tokens_sm), "tokens_sm not symm-mem backed") + self.assertTrue(is_symm_mem_tensor(recv_tokens_sm), "recv_tokens_sm not symm-mem backed") + self.assertTrue(is_symm_mem_tensor(expert_out_sm), "expert_out_sm not symm-mem backed") + tokens_sm.copy_(tokens_hbm) + recv_w = torch.empty(rc, dtype=torch.float32, device=self.cfg.device) + + # Symm-mem path. + ep_prepare(handle, topk_idx) + _ep_dispatch_raw(handle, topk_idx, tokens_sm, w, recv_tokens_sm, recv_w) + expert_out_sm.copy_(self._weighted(recv_tokens_sm, recv_w)) + result_sm = torch.empty_like(tokens_hbm) + _ep_combine_raw(handle, expert_out_sm, result_sm) + torch.cuda.synchronize() + + # HBM reference. + handle_ref = self._make_handle() + recv_tokens_hbm, recv_w_hbm, _ = self._make_buffers() + result_hbm = torch.empty_like(tokens_hbm) + ep_prepare(handle_ref, topk_idx) + _ep_dispatch_raw(handle_ref, topk_idx, tokens_hbm, w, recv_tokens_hbm, recv_w_hbm) + _ep_combine_raw(handle_ref, self._weighted(recv_tokens_hbm, recv_w_hbm), result_hbm) + torch.cuda.synchronize() + + torch.testing.assert_close(result_sm.float(), tokens_hbm.float(), atol=5e-2, rtol=5e-2) + torch.testing.assert_close(result_sm, result_hbm, atol=0, rtol=0) + + def test_zero_copy_cuda_graph_capture(self): + """Capture dispatch+combine over symm-mem payload buffers; replay must be bit-stable.""" + handle = self._make_handle() + topk_idx, tokens_hbm, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + rc = self.cfg.recv_capacity_per_rank + tokens = self._try_symm_alloc((TOKENS_PER_RANK, HIDDEN_DIM), torch.bfloat16) + recv_tokens = self._try_symm_alloc((rc, HIDDEN_DIM), torch.bfloat16) + expert_out = self._try_symm_alloc((rc, HIDDEN_DIM), torch.bfloat16) + tokens.copy_(tokens_hbm) + recv_w = torch.empty(rc, dtype=torch.float32, device=self.cfg.device) + result = torch.empty_like(tokens_hbm) + + def step(): + ep_prepare(handle, topk_idx) + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + expert_out.copy_(self._weighted(recv_tokens, recv_w)) + _ep_combine_raw(handle, expert_out, result) + + for _ in range(3): + step() + torch.cuda.synchronize() + + ep_prepare(handle, topk_idx) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + with torch.cuda.graph(graph): + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + expert_out.copy_(self._weighted(recv_tokens, recv_w)) + _ep_combine_raw(handle, expert_out, result) + torch.cuda.current_stream().wait_stream(s) + torch.cuda.synchronize() + + ref = result.clone() + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + torch.testing.assert_close(result.float(), ref.float(), atol=0, rtol=0) + + def test_zero_copy_autograd_combine(self): + """ep_combine autograd path must keep EpBuffer's symm-mem annotation on combine_in.""" + handle = self._make_handle() + buffer = self._make_ep_buffer(handle, use_symm_mem=True) + from torch.distributed._symmetric_memory import is_symm_mem_tensor + + self.assertTrue(is_symm_mem_tensor(buffer.combine_in)) + self.assertTrue(is_symm_mem_tensor(buffer.recv_tokens)) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + out = self._moe_step(handle, buffer, topk_idx, tokens_p, w) + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) + torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + def test_zero_copy_falls_back_when_not_registered(self): + """Plain torch.empty tensors take the staged-copy fallback correctly.""" + try: + from torch.distributed._symmetric_memory import is_symm_mem_tensor + except ImportError: + is_symm_mem_tensor = None + + handle = self._make_handle() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, _ = self._make_buffers() + if is_symm_mem_tensor is not None: + self.assertFalse(is_symm_mem_tensor(tokens)) + self.assertFalse(is_symm_mem_tensor(recv_tokens)) + result = torch.empty_like(tokens) + ep_prepare(handle, topk_idx) + _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + torch.cuda.synchronize() + torch.testing.assert_close(result.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + def test_gradient_checkpointing(self): + from torch.utils.checkpoint import checkpoint + + handle = self._make_handle() + buffer = self._make_ep_buffer(handle) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + + def step(t): + return self._moe_step(handle, buffer, topk_idx, t, w) + + out = checkpoint(step, tokens_p, use_reentrant=False) + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + def test_autocast_bf16_autograd(self): + """Autocast must not change result/grad dtype through the autograd path.""" + handle = self._make_handle() + buffer = self._make_ep_buffer(handle) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + out = self._moe_step(handle, buffer, topk_idx, tokens_p, w) + self.assertEqual(out.dtype, torch.bfloat16) + self.assertEqual(buffer.recv_tokens.dtype, torch.bfloat16) + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + self.assertEqual(tokens_p.grad.dtype, torch.bfloat16) + torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + # ── Snapshot / scope / multi-iter ──────────────────────────────────── + + def test_zero_copy_scope_nested(self): + """Nested _zero_copy_scope must save/restore the host-side toggle.""" + initial = tex.ep_get_zero_copy() + try: + with _zero_copy_scope(True): + self.assertTrue(tex.ep_get_zero_copy()) + with _zero_copy_scope(False): + self.assertFalse(tex.ep_get_zero_copy()) + with _zero_copy_scope(True): + self.assertTrue(tex.ep_get_zero_copy()) + self.assertFalse(tex.ep_get_zero_copy()) + self.assertTrue(tex.ep_get_zero_copy()) + self.assertEqual(tex.ep_get_zero_copy(), initial) + finally: + tex.ep_set_zero_copy(initial) + + def test_topk_int32_raises_clear_error(self): + """int32 topk_idx must error with a message pointing to .long().""" + handle = self._make_handle() + topk_idx_int32 = torch.zeros( + TOKENS_PER_RANK, TOP_K, dtype=torch.int32, device=self.cfg.device + ) + with self.assertRaises(RuntimeError) as cm: + ep_prepare(handle, topk_idx_int32) + msg = str(cm.exception) + self.assertIn("topk_idx", msg) + self.assertIn(".long()", msg) + + def test_dispatch_fwd_bwd_multiple_iterations(self): + """5 fwd+bwd iters on the same EpHandle + EpBuffer must be bit-stable.""" + handle = self._make_handle() + buffer = self._make_ep_buffer(handle) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + + def one_step(): + tokens_p = tokens.detach().clone().requires_grad_(True) + out = self._moe_step(handle, buffer, topk_idx, tokens_p, w) + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + return out.detach().clone(), tokens_p.grad.detach().clone() + + out_ref, grad_ref = one_step() + torch.cuda.synchronize() + for _ in range(4): + out_i, grad_i = one_step() + torch.cuda.synchronize() + torch.testing.assert_close(out_i, out_ref, atol=0, rtol=0) + torch.testing.assert_close(grad_i, grad_ref, atol=0, rtol=0) + + def test_compile_fullgraph_with_new_api(self): + """torch.compile(fullgraph=True) on public ep_dispatch+ep_combine; forward only.""" + import torch._dynamo + + torch._dynamo.reset() + handle = self._make_handle() + buffer = self._make_ep_buffer(handle) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + + def step(tokens, topk_idx, w): + recv_t, recv_w_out, _ = ep_dispatch(handle, buffer, tokens, topk_idx, w) + return ep_combine(handle, buffer, recv_t, recv_w_out) + + with torch.no_grad(): + ref = step(tokens, topk_idx, w).detach().clone() + torch.cuda.synchronize() + + compiled = torch.compile(step, fullgraph=True, dynamic=False) + with torch.no_grad(): + out = compiled(tokens, topk_idx, w) + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), ref.float(), atol=5e-2, rtol=5e-2) + + def test_pp_1f1b_handle_mem_snapshot(self): + """fwd₀ → fwd₁ → bwd₀ → bwd₁ with different routings; each bwd uses its own snapshot.""" + handle = self._make_handle() + buffer0 = self._make_ep_buffer(handle) + buffer1 = self._make_ep_buffer(handle) + T, H = TOKENS_PER_RANK, HIDDEN_DIM + # Routing 0 → experts 0, 1; routing 1 → experts 2, 3. + idx0 = torch.zeros(T, TOP_K, dtype=torch.int64, device=self.cfg.device) + idx0[:, 1] = 1 + idx1 = torch.full((T, TOP_K), 2, dtype=torch.int64, device=self.cfg.device) + idx1[:, 1] = 3 + w = torch.full((T, TOP_K), 1.0 / TOP_K, dtype=torch.float32, device=self.cfg.device) + tokens0 = torch.full( + (T, H), 0.1 + self.cfg.rank * 0.01, dtype=torch.bfloat16, device=self.cfg.device + ) + tokens1 = torch.full( + (T, H), 0.5 + self.cfg.rank * 0.01, dtype=torch.bfloat16, device=self.cfg.device + ) + t0 = tokens0.detach().clone().requires_grad_(True) + t1 = tokens1.detach().clone().requires_grad_(True) + r0, _, _ = ep_dispatch(handle, buffer0, t0, idx0, w) + r1, _, _ = ep_dispatch(handle, buffer1, t1, idx1, w) + loss0 = 0.5 * (r0.float() ** 2).sum() + loss1 = 0.5 * (r1.float() ** 2).sum() + loss0.backward() + loss1.backward() + torch.cuda.synchronize() + torch.testing.assert_close( + t0.grad.float(), tokens0.float() * float(TOP_K), atol=5e-2, rtol=5e-2 + ) + torch.testing.assert_close( + t1.grad.float(), tokens1.float() * float(TOP_K), atol=5e-2, rtol=5e-2 + ) + + def test_record_stream(self): + """EpBuffer.record_stream(s) records on both owned tensors.""" + handle = self._make_handle() + buffer = self._make_ep_buffer(handle, use_symm_mem=True) + s = torch.cuda.Stream() + buffer.record_stream(s) + with torch.cuda.stream(s): + buffer.recv_tokens.add_(0) + buffer.combine_in.add_(0) + torch.cuda.synchronize() + + def test_from_external_round_trip(self): + """EpBuffer.from_external with caller-allocated symm-mem tensors must round-trip.""" + handle = self._make_handle() + rc = self.cfg.recv_capacity_per_rank + recv_tokens = self._try_symm_alloc((rc, HIDDEN_DIM), torch.bfloat16) + combine_in = self._try_symm_alloc((rc, HIDDEN_DIM), torch.bfloat16) + buffer = EpBuffer.from_external(handle, recv_tokens=recv_tokens, combine_in=combine_in) + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + out = self._moe_step(handle, buffer, topk_idx, tokens_p, w) + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) + torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + +# ── Entry point ────────────────────────────────────────────────────────────── + + +def _init_distributed(): + dist.init_process_group(backend="nccl") + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + +if __name__ == "__main__": + _init_distributed() + loader = unittest.TestLoader() + suite = loader.loadTestsFromTestCase(TestEP) + runner = unittest.TextTestRunner(stream=sys.stdout, verbosity=2) + result = runner.run(suite) + dist.barrier() + dist.destroy_process_group() + sys.exit(0 if result.wasSuccessful() else 1) diff --git a/tests/pytorch/distributed/run_test_ep.sh b/tests/pytorch/distributed/run_test_ep.sh new file mode 100755 index 0000000000..cad1fc0a40 --- /dev/null +++ b/tests/pytorch/distributed/run_test_ep.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Launcher for tests/pytorch/distributed/run_ep.py. Auto-detects GPU count. +# Short timeout by default to surface hangs early. + +set -uo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +if [ "${DETECTED_GPUS}" -lt 4 ]; then + echo "EP requires >= 4 GPUs (found ${DETECTED_GPUS}); SKIPPING." + exit 0 +fi +NUM_RANKS="${NVTE_TEST_EP_NUM_RANKS:-${DETECTED_GPUS}}" +if [ "${NUM_RANKS}" -gt 8 ]; then NUM_RANKS=8; fi + +# Short timeout to detect hangs early. +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-120}" + +# Stage NCCL EP JIT cubins on tmpfs to keep iteration fast. +: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +SCRIPT="${SCRIPT_DIR}/run_ep.py" +echo "=== Running ${SCRIPT} on ${NUM_RANKS} GPUs (timeout=${TEST_TIMEOUT_S}s) ===" + +# setsid + kill-after so SIGKILL takes down the whole process group, not just torchrun. +setsid timeout --foreground --kill-after=10 --signal=TERM "${TEST_TIMEOUT_S}" \ + torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_RANKS}" \ + "${SCRIPT}" 2>&1 | tee stdout_ep.txt +RC=${PIPESTATUS[0]} +pkill -9 -f "tests/pytorch/distributed/run_ep.py" 2>/dev/null || true + +RET=0 +if [ "${RC}" -ne 0 ]; then + echo "torchrun exited with ${RC}" + RET=1 +fi +# Match unittest failure markers and unhandled Python tracebacks at column 0 +# (don't fail on stack frames in informational logs). +if grep -qE "^FAILED|^Traceback" stdout_ep.txt; then RET=1; fi +if ! grep -qE "Ran [0-9]+ test|^OK$" stdout_ep.txt; then + echo "ERROR: no test summary — likely hang or early crash" + RET=1 +fi + +if [ -z "${KEEP_EP_LOGS:-}" ]; then rm -f stdout_ep.txt; fi +exit $RET diff --git a/tests/pytorch/distributed/test_ep.py b/tests/pytorch/distributed/test_ep.py new file mode 100644 index 0000000000..81eef9a3c1 --- /dev/null +++ b/tests/pytorch/distributed/test_ep.py @@ -0,0 +1,31 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Pytest driver — spawns run_ep.py under torchrun and asserts the suite passed.""" + +import os +import subprocess +from pathlib import Path + +import pytest +import torch + +TEST_ROOT = Path(__file__).parent.resolve() +WORKER = TEST_ROOT / "run_ep.py" +LAUNCHER = TEST_ROOT / "run_test_ep.sh" + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="EP requires >= 4 GPUs") +def test_multi_process_ep(): + """Launch the EP unit-test suite across all visible GPUs. + + Short timeout so a hang on any rank surfaces fast rather than burning CI time. + """ + timeout_s = int(os.environ.get("NVTE_TEST_EP_TIMEOUT_S", "180")) + proc = subprocess.run( + ["bash", str(LAUNCHER)], + env={**os.environ, "KEEP_EP_LOGS": "1", "TEST_TIMEOUT_S": str(timeout_s)}, + timeout=timeout_s + 30, + check=False, + ) + assert proc.returncode == 0, f"EP test suite failed (rc={proc.returncode})" diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 030023d949..c5f8dfb1ab 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -379,6 +379,96 @@ if (NVTE_WITH_CUSOLVERMP) message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}") endif() +# ── NCCL EP (on by default, HT mode only) ───────────────────────────────── +# Set -DNVTE_WITH_NCCL_EP=OFF (or NVTE_BUILD_WITH_NCCL_EP=0 in setup.py) to +# skip NCCL EP entirely — useful on older images whose system NCCL is below +# the 2.30.4 EP minimum. +option(NVTE_WITH_NCCL_EP "Build NCCL EP into libtransformer_engine.so" ON) +if(NVTE_WITH_NCCL_EP) +# SM>=90 and NCCL>=2.30.4 are gated at runtime in EPBackend::initialize. +# ── NCCL EP headers ──────────────────────────────────────────────────────── +# Headers + libs are produced by the in-tree 3rdparty/nccl submodule build +# (auto-built by setup.py via build_nccl_ep_submodule). +set(NCCL_EP_SUBMODULE_ROOT + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") +set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/contrib/nccl_ep/include") +if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h") + message(FATAL_ERROR + "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. " + "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl.") +endif() +message(STATUS "NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}") + +# ── libnccl_ep.so ────────────────────────────────────────────────────────── +set(NCCL_EP_LIB_DIR "${NCCL_EP_SUBMODULE_ROOT}/build/lib") +find_library(NCCL_EP_LIB + NAMES nccl_ep libnccl_ep + HINTS ${NCCL_EP_LIB_DIR} + NO_DEFAULT_PATH + REQUIRED) + +# ── NCCL + GIN headers ───────────────────────────────────────────────────── +# libnccl.so and all GIN headers (ncclGin.h, ncclWindow_t, ncclDevComm_t) +# ship with the base CUDA Toolkit OR the 3rdparty/nccl submodule build +# (preferred when present; auto-built by setup.py via build_nccl_ep_submodule). +if(NOT NCCL_LIB) + find_library(NCCL_LIB + NAMES nccl libnccl + HINTS ${NCCL_EP_LIB_DIR} ${CUDAToolkit_LIBRARY_DIR} + PATH_SUFFIXES lib lib64 + REQUIRED) +endif() + +set(NCCL_SUBMODULE_INCLUDE + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build/include") +if(EXISTS "${NCCL_SUBMODULE_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIRS_FOR_TE ${NCCL_SUBMODULE_INCLUDE}) +else() + set(NCCL_INCLUDE_DIRS_FOR_TE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +endif() + +# Diagnostic: log detected NCCL header version (minimum enforced at runtime). +find_file(_nvte_nccl_header_path nccl.h + PATHS ${NCCL_INCLUDE_DIRS_FOR_TE} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + NO_DEFAULT_PATH) +if(_nvte_nccl_header_path) + file(READ "${_nvte_nccl_header_path}" _nvte_nccl_h) + string(REGEX MATCH "#define[ \t]+NCCL_MAJOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_major "${CMAKE_MATCH_1}") + string(REGEX MATCH "#define[ \t]+NCCL_MINOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_minor "${CMAKE_MATCH_1}") + string(REGEX MATCH "#define[ \t]+NCCL_PATCH[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_patch "${CMAKE_MATCH_1}") + if(_nvte_nccl_major AND _nvte_nccl_minor AND _nvte_nccl_patch) + message(STATUS "NCCL header: ${_nvte_nccl_header_path} (version ${_nvte_nccl_major}.${_nvte_nccl_minor}.${_nvte_nccl_patch})") + endif() +endif() + +target_include_directories(transformer_engine PRIVATE + ${NCCL_EP_INCLUDE_DIR} + ${NCCL_INCLUDE_DIRS_FOR_TE}) # covers nccl.h + nccl_device/ + +target_link_libraries(transformer_engine PUBLIC + ${NCCL_EP_LIB} + ${NCCL_LIB}) + +# Embed rpath so the installed wheel finds libnccl_ep.so at runtime. +# libnccl.so is already on the system via the Toolkit — no rpath needed for it. +set_target_properties(transformer_engine PROPERTIES + INSTALL_RPATH "$ORIGIN;${NCCL_EP_LIB_DIR}") + +target_sources(transformer_engine PRIVATE + ep/ep_backend.cpp + ep/ep_api.cpp) + +message(STATUS "NCCL EP enabled: ${NCCL_EP_LIB}") +message(STATUS "NCCL EP include: ${NCCL_EP_INCLUDE_DIR}") +else() + # NCCL EP off: export throwing nvte_ep_* stubs so framework bindings link. + target_sources(transformer_engine PRIVATE ep/ep_api_stub.cpp) + message(STATUS "NCCL EP disabled (NVTE_WITH_NCCL_EP=OFF) — using nvte_ep_* stubs") +endif() + # Number of philox4x32 rounds for stochastic rounding (build-time constant). set(NVTE_BUILD_NUM_PHILOX_ROUNDS_STR $ENV{NVTE_BUILD_NUM_PHILOX_ROUNDS}) if (NOT NVTE_BUILD_NUM_PHILOX_ROUNDS_STR) diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp new file mode 100644 index 0000000000..89d8b38607 --- /dev/null +++ b/transformer_engine/common/ep/ep_api.cpp @@ -0,0 +1,76 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_api.cpp + * \brief nvte_ep_* C API: thin delegations to the EPBackend singleton. + */ + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "ep_backend.h" + +using transformer_engine::ep::EPBackend; + +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + EPBackend::initialize(static_cast(ep_comm), group_config); +} + +void nvte_ep_shutdown(void) { EPBackend::shutdown(); } + +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size) { + NVTE_CHECK(handle_mem_size != nullptr, "handle_mem_size must not be null"); + return EPBackend::get().register_layer(layer_config, handle_mem_size); +} + +void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts, + size_t dispatch_output_per_expert_alignment, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().prepare(handle.id, topk_idx, token_counts, mem_ptr, + dispatch_output_per_expert_alignment, stream); +} + +void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().dispatch(handle.id, mem_ptr, topk_idx, tokens, tokens_win, topk_weights, + topk_weights_win, recv_tokens, recv_tokens_win, recv_topk_weights, + recv_topk_weights_win, stream); +} + +void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().combine(handle.id, mem_ptr, expert_out, expert_out_win, result, stream); +} + +void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, NVTETensor grad_topk_weights, + cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().dispatch_bwd(handle.id, mem_ptr, grad, grad_win, g_recv_topk_weights, + g_recv_topk_weights_win, grad_tokens, grad_topk_weights, stream); +} + +void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().combine_bwd(handle.id, mem_ptr, grad, grad_win, grad_expert_out, + grad_expert_out_win, stream); +} diff --git a/transformer_engine/common/ep/ep_api_stub.cpp b/transformer_engine/common/ep/ep_api_stub.cpp new file mode 100644 index 0000000000..fe4127d87d --- /dev/null +++ b/transformer_engine/common/ep/ep_api_stub.cpp @@ -0,0 +1,61 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_api_stub.cpp + * \brief Throwing nvte_ep_* stubs compiled when NVTE_WITH_NCCL_EP=OFF. + */ + +#include + +#include "../util/logging.h" + +namespace { +[[noreturn]] void ep_not_built() { + NVTE_ERROR( + "NCCL EP is not built into this TransformerEngine. Rebuild TE with " + "NVTE_BUILD_WITH_NCCL_EP=1 and CUDA arch >= 90 (e.g. NVTE_CUDA_ARCHS=\"90\")."); +} +} // namespace + +void nvte_ep_initialize(void* /*ep_comm*/, NVTEEpGroupConfig /*group_config*/) { ep_not_built(); } + +void nvte_ep_shutdown(void) {} + +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig /*layer_config*/, size_t* /*handle_mem_size*/) { + ep_not_built(); +} + +void nvte_ep_prepare(NVTEEpHandle /*handle*/, NVTETensor /*topk_idx*/, NVTETensor /*token_counts*/, + size_t /*dispatch_output_per_expert_alignment*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch(NVTEEpHandle /*handle*/, NVTETensor /*topk_idx*/, NVTETensor /*tokens*/, + NVTECommWindow /*tokens_win*/, NVTETensor /*topk_weights*/, + NVTECommWindow /*topk_weights_win*/, NVTETensor /*recv_tokens*/, + NVTECommWindow /*recv_tokens_win*/, NVTETensor /*recv_topk_weights*/, + NVTECommWindow /*recv_topk_weights_win*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine(NVTEEpHandle /*handle*/, NVTETensor /*expert_out*/, + NVTECommWindow /*expert_out_win*/, NVTETensor /*result*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch_bwd(NVTEEpHandle /*handle*/, NVTETensor /*grad*/, NVTECommWindow /*grad_win*/, + NVTETensor /*g_recv_topk_weights*/, + NVTECommWindow /*g_recv_topk_weights_win*/, NVTETensor /*grad_tokens*/, + NVTETensor /*grad_topk_weights*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine_bwd(NVTEEpHandle /*handle*/, NVTETensor /*grad*/, NVTECommWindow /*grad_win*/, + NVTETensor /*grad_expert_out*/, NVTECommWindow /*grad_expert_out_win*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp new file mode 100644 index 0000000000..ae0f3ab888 --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -0,0 +1,514 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.cpp + * \brief EPBackend implementation. See ep_backend.h for the op flow. + */ + +#include "ep_backend.h" + +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/logging.h" + +namespace transformer_engine { +namespace ep { + +namespace { + +// Build a by-value ncclEpTensor_t descriptor. `sizes` is caller-owned and must +// outlive any NCCL EP call that consumes the descriptor. +inline ncclEpTensor_t make_tensor(void* data, unsigned int ndim, ncclDataType_t datatype, + size_t* sizes) { + ncclEpTensor_t t = NCCL_EP_TENSOR_INIT; + t.ndim = ndim; + t.datatype = datatype; + t.data = data; + t.sizes = sizes; + return t; +} + +// Payload descriptor: prefer the symmem window when set, else fall back to the +// NVTETensor's raw device pointer. +inline ncclEpTensor_t make_payload_tensor(const NVTETensor t, const NVTECommWindow& win, + unsigned int ndim, ncclDataType_t datatype, + size_t* sizes) { + ncclEpTensor_t desc = NCCL_EP_TENSOR_INIT; + desc.ndim = ndim; + desc.datatype = datatype; + desc.sizes = sizes; + if (win.window != nullptr) { + desc.win_hdl = win.window; + desc.win_offset = win.offset; + } else { + desc.data = nvte_tensor_data(t); + NVTE_CHECK(desc.data != nullptr, "payload tensor data must not be null"); + } + return desc; +} + +// RAII guard for ncclEpHandle_t — destroys on scope exit, leak-free on throw. +class ScopedEpHandle { + public: + ScopedEpHandle() = default; + explicit ScopedEpHandle(ncclEpHandle_t h) : h_(h) {} + ~ScopedEpHandle() { + if (h_ != nullptr) ncclEpHandleDestroy(h_); + } + ScopedEpHandle(const ScopedEpHandle&) = delete; + ScopedEpHandle& operator=(const ScopedEpHandle&) = delete; + ScopedEpHandle(ScopedEpHandle&& other) noexcept : h_(other.h_) { other.h_ = nullptr; } + ScopedEpHandle& operator=(ScopedEpHandle&& other) noexcept { + if (this != &other) { + if (h_ != nullptr) ncclEpHandleDestroy(h_); + h_ = other.h_; + other.h_ = nullptr; + } + return *this; + } + operator ncclEpHandle_t() const { return h_; } + ncclEpHandle_t get() const { return h_; } + + private: + ncclEpHandle_t h_ = nullptr; +}; + +} // namespace + +// --------------------------------------------------------------------------- +// Singleton + bootstrap +// --------------------------------------------------------------------------- + +EPBackend& EPBackend::instance() { + static EPBackend inst; + return inst; +} + +EPBackend& EPBackend::get() { + EPBackend& inst = instance(); + NVTE_CHECK(inst.initialized_, "EPBackend not initialized. Call nvte_ep_initialize() first."); + return inst; +} + +void EPBackend::validate_config(const NVTEEpGroupConfig& config) { + NVTE_CHECK(config.ep_size > 0, "ep_size must be positive, got ", config.ep_size); + NVTE_CHECK(config.num_experts > 0, "num_experts must be positive, got ", config.num_experts); + NVTE_CHECK(config.max_tokens_per_rank > 0, "max_tokens_per_rank must be positive, got ", + config.max_tokens_per_rank); + NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", + config.max_recv_tokens_per_rank); + NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); + NVTE_CHECK(config.hidden_dim * sizeof(nv_bfloat16) >= 16, + "hidden_dim * 2 must be >= 16 (NCCL EP 16B row alignment); got hidden_dim=", + config.hidden_dim); + NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts, + ") must be divisible by ep_size (", config.ep_size, ")"); + NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ", + config.max_num_sms); + + int device, major; + NVTE_CHECK_CUDA(cudaGetDevice(&device)); + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + NVTE_CHECK(major >= 9, + "NCCL EP requires SM_90+ (Hopper or later), " + "but current device has compute capability ", + major, ".x"); + + // NCCL EP needs CUDA multicast (NVLS); init hangs without it. + NVTE_CHECK(cuda::supports_multicast(device), + "NCCL EP requires CUDA multicast (NVLS) support on device ", device, + " but CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED reports 0."); +} + +void EPBackend::initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config) { + EPBackend& inst = instance(); + std::lock_guard lock(inst.mutex_); + NVTE_CHECK(!inst.initialized_, "EP already initialized. Call initialize only once per process."); + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + + // Runtime gate: NCCL >= 2.30.4 (matches the submodule pin). + constexpr int kMinNcclVersion = 23004; + int nccl_version = 0; + NVTE_CHECK_NCCL(ncclGetVersion(&nccl_version)); + NVTE_CHECK(nccl_version >= kMinNcclVersion, "NCCL EP requires NCCL >= 2.30.4, found ", + nccl_version / 10000, ".", (nccl_version / 100) % 100, ".", nccl_version % 100, + " at runtime."); + + validate_config(config); + + int comm_size = 0; + NVTE_CHECK_NCCL(ncclCommCount(ep_comm, &comm_size)); + NVTE_CHECK(comm_size == config.ep_size, "ep_comm size (", comm_size, ") must equal ep_size (", + config.ep_size, "). Pass the EP sub-communicator, not the world comm."); + + inst.init(ep_comm, config); +} + +void EPBackend::shutdown() { + EPBackend& inst = instance(); + std::lock_guard lock(inst.mutex_); + if (!inst.initialized_) return; + inst.handles_.clear(); + // ncclEpGroupDestroy reads from ep_comm_; destroy group while comm is still alive. + if (inst.ep_group_ != nullptr) { + ncclEpGroupDestroy(inst.ep_group_); + inst.ep_group_ = nullptr; + } + inst.ep_comm_ = nullptr; // borrowed — caller destroys + inst.initialized_ = false; +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +ncclDataType_t EPBackend::nvte_dtype_to_nccl(NVTEDType dtype) { + switch (dtype) { + case kNVTEFloat32: + return ncclFloat32; + case kNVTEFloat16: + return ncclFloat16; + case kNVTEBFloat16: + return ncclBfloat16; + case kNVTEInt32: + return ncclInt32; + case kNVTEInt64: + return ncclInt64; + case kNVTEByte: + return ncclUint8; + case kNVTEFloat8E4M3: + return ncclFloat8e4m3; + case kNVTEFloat8E5M2: + return ncclFloat8e5m2; + default: + NVTE_ERROR("Unsupported NVTEDType for NCCL EP conversion: ", static_cast(dtype)); + } + return ncclFloat32; // unreachable +} + +// Open a transient ncclEpHandle over handle_mem. Caller owns the result. +ncclEpHandle_t EPBackend::open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, + size_t dispatch_output_per_expert_alignment) { + size_t hm_sizes[1] = {handle_mem_size}; + ncclEpTensor_t routing_desc = make_tensor(handle_mem, 1, ncclUint8, hm_sizes); + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = dispatch_output_per_expert_alignment; + ncclEpHandle_t handle; + NVTE_CHECK_NCCL(ncclEpInitHandle(&handle, ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, num_topk, + &routing_desc)); + return handle; +} + +// --------------------------------------------------------------------------- +// Lifecycle +// --------------------------------------------------------------------------- + +// Static-dtor teardown: skip NCCL calls (CUDA context / borrowed ep_comm_ may +// already be gone) and release in-memory state only. +EPBackend::~EPBackend() { + std::lock_guard lock(mutex_); + if (!initialized_) return; + handles_.clear(); + ep_group_ = nullptr; + ep_comm_ = nullptr; + initialized_ = false; +} + +void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(!initialized_, "EPBackend already initialized"); + + group_config_ = group_config; + + ncclEpGroupConfig_t cfg = NCCL_EP_GROUP_CONFIG_INIT; + cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT; + cfg.num_experts = static_cast(group_config.num_experts); + cfg.max_dispatch_tokens_per_rank = static_cast(group_config.max_tokens_per_rank); + cfg.max_token_bytes = static_cast(group_config.hidden_dim * sizeof(nv_bfloat16)); + cfg.rdma_buffer_size = NCCL_EP_AUTO; + cfg.num_qp_per_rank = NCCL_EP_AUTO; + cfg.num_channels = NCCL_EP_AUTO; + cfg.max_num_sms = group_config.max_num_sms > 0 + ? static_cast(group_config.max_num_sms) + : NCCL_EP_AUTO; + // Must be > 0; NCCL EP errors out on 0. + cfg.max_recv_tokens_per_rank = static_cast(group_config.max_recv_tokens_per_rank); + + NVTE_CHECK_NCCL(ncclEpCreateGroup(&ep_group_, ep_comm, &cfg)); + + ep_comm_ = ep_comm; + + initialized_ = true; +} + +// --------------------------------------------------------------------------- +// Per-handle_id config cache +// --------------------------------------------------------------------------- + +uint64_t EPBackend::insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment) { + if (handle_cache_cap_ == 0) { + const char* cap_env = std::getenv("NVTE_EP_HANDLE_CACHE_SIZE"); + handle_cache_cap_ = (cap_env != nullptr) ? std::max(1, std::atoi(cap_env)) : 8192; + } + NVTE_CHECK(handles_.size() < handle_cache_cap_, "EP handle cache full (", handle_cache_cap_, + " entries). Raise via NVTE_EP_HANDLE_CACHE_SIZE."); + uint64_t id = next_handle_id_.fetch_add(1, std::memory_order_relaxed); + handles_.emplace(id, HandleEntry{handle_mem_size, alignment, top_k}); + return id; +} + +EPBackend::HandleEntry& EPBackend::lookup_config(uint64_t handle_id) { + auto it = handles_.find(handle_id); + NVTE_CHECK(it != handles_.end(), "ep op on handle_id=", handle_id, + " with no cached config — call ep_prepare first."); + return it->second; +} + +// --------------------------------------------------------------------------- +// Per-step operations +// --------------------------------------------------------------------------- + +uint64_t EPBackend::register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(layer_config.top_k > 0, "NVTEEpLayerConfig.top_k must be > 0"); + NVTE_CHECK(handle_mem_size != nullptr, "handle_mem_size must not be null"); + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = layer_config.dispatch_output_per_expert_alignment; + size_t hm_size = 0; + NVTE_CHECK_NCCL(ncclEpHandleMemSize(ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, &hm_size, + layer_config.top_k)); + *handle_mem_size = hm_size; + std::lock_guard lock(mutex_); + return insert_new_entry(hm_size, layer_config.top_k, + layer_config.dispatch_output_per_expert_alignment); +} + +void EPBackend::prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETensor token_counts, + void* handle_mem, size_t dispatch_output_per_expert_alignment, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape idx_shape = nvte_tensor_shape(topk_idx); + void* idx_data = nvte_tensor_data(topk_idx); + NVTE_CHECK(idx_data != nullptr, "topk_idx data must not be null"); + + const size_t num_tokens = idx_shape.data[0]; + const size_t top_k = idx_shape.ndim > 1 ? idx_shape.data[1] : 1; + const size_t num_local_experts = + static_cast(group_config_.num_experts / group_config_.ep_size); + + size_t idx_sizes[2] = {num_tokens, top_k}; + ncclEpTensor_t nccl_topk_idx = make_tensor(idx_data, 2, ncclInt64, idx_sizes); + + // ncclEpUpdateHandle writes per-expert counts via expert_counters. + size_t cnt_sizes[1] = {num_local_experts}; + ncclEpTensor_t token_counts_desc; + void* token_counts_data = (token_counts != nullptr) ? nvte_tensor_data(token_counts) : nullptr; + if (token_counts_data != nullptr) { + token_counts_desc = make_tensor(token_counts_data, 1, ncclInt32, cnt_sizes); + } + ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; + layout_info.expert_counters = (token_counts_data != nullptr) ? &token_counts_desc : nullptr; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + NVTE_CHECK(cfg.alignment == dispatch_output_per_expert_alignment, + "ep_prepare: alignment mismatch for handle_id=", handle_id, + " (cached=", cfg.alignment, ", got=", dispatch_output_per_expert_alignment, ")"); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpUpdateHandle(transient, &nccl_topk_idx, &layout_info, stream)); +} + +void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx, + const NVTETensor tokens, const NVTECommWindow& tokens_win, + const NVTETensor topk_weights, const NVTECommWindow& topk_weights_win, + NVTETensor recv_tokens, const NVTECommWindow& recv_tokens_win, + NVTETensor recv_topk_weights, const NVTECommWindow& recv_topk_weights_win, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape tok_shape = nvte_tensor_shape(tokens); + NVTEDType tok_dtype = nvte_tensor_type(tokens); + + const size_t num_tokens = tok_shape.data[0]; + const size_t hidden_dim = tok_shape.data[1]; + + size_t tok_sizes[2] = {num_tokens, hidden_dim}; + ncclEpTensor_t nccl_tokens_in = + make_payload_tensor(tokens, tokens_win, 2, nvte_dtype_to_nccl(tok_dtype), tok_sizes); + + const bool is_forward = (topk_weights != nullptr); + + // Routing is cached in handle_mem by ep_prepare; dispatch only needs + // topk_weights to reconstruct the sparse-to-dense prob map. + size_t weights_in_sizes[2] = {0, 0}; + ncclEpTensor_t nccl_topk_weights_in; + if (is_forward) { + NVTE_CHECK(topk_idx != nullptr, "topk_idx required in forward dispatch"); + NVTEShape idx_shape = nvte_tensor_shape(topk_idx); + const size_t top_k = idx_shape.ndim > 1 ? idx_shape.data[1] : 1; + weights_in_sizes[0] = num_tokens; + weights_in_sizes[1] = top_k; + nccl_topk_weights_in = + make_payload_tensor(topk_weights, topk_weights_win, 2, ncclFloat32, weights_in_sizes); + } + + NVTEShape recv_shape = nvte_tensor_shape(recv_tokens); + NVTEDType recv_dtype = nvte_tensor_type(recv_tokens); + + size_t recv_sizes[2] = {recv_shape.data[0], recv_shape.data[1]}; + ncclEpTensor_t nccl_tokens_out = make_payload_tensor(recv_tokens, recv_tokens_win, 2, + nvte_dtype_to_nccl(recv_dtype), recv_sizes); + + size_t weights_out_sizes[1] = {recv_shape.data[0]}; + ncclEpTensor_t nccl_topk_weights_out; + if (is_forward) { + NVTE_CHECK(recv_topk_weights != nullptr, + "recv_topk_weights must not be null in forward dispatch"); + NVTEShape recv_w_shape = nvte_tensor_shape(recv_topk_weights); + NVTE_CHECK(recv_w_shape.ndim == 1, "recv_topk_weights must be 1D [recv_capacity]"); + nccl_topk_weights_out = make_payload_tensor(recv_topk_weights, recv_topk_weights_win, 1, + ncclFloat32, weights_out_sizes); + } + + ncclEpDispatchInputs_t in_struct = NCCL_EP_DISPATCH_INPUTS_INIT; + in_struct.tokens = &nccl_tokens_in; + in_struct.topk_weights = is_forward ? &nccl_topk_weights_in : nullptr; + + ncclEpDispatchOutputs_t out_struct = NCCL_EP_DISPATCH_OUTPUTS_INIT; + out_struct.tokens = &nccl_tokens_out; + out_struct.topk_weights = is_forward ? &nccl_topk_weights_out : nullptr; + + ncclEpDispatchConfig_t dispatch_cfg = NCCL_EP_DISPATCH_CONFIG_INIT; + dispatch_cfg.pass_direction = is_forward ? NCCL_EP_FWD_PASS : NCCL_EP_BWD_PASS; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpDispatch(transient, &in_struct, &out_struct, + /*layout_info=*/nullptr, &dispatch_cfg, stream)); +} + +void EPBackend::combine(uint64_t handle_id, void* handle_mem, const NVTETensor expert_out, + const NVTECommWindow& expert_out_win, NVTETensor result, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape exp_shape = nvte_tensor_shape(expert_out); + NVTEDType exp_dtype = nvte_tensor_type(expert_out); + + size_t exp_sizes[2] = {exp_shape.data[0], exp_shape.data[1]}; + ncclEpTensor_t nccl_expert_in = + make_payload_tensor(expert_out, expert_out_win, 2, nvte_dtype_to_nccl(exp_dtype), exp_sizes); + + NVTEShape res_shape = nvte_tensor_shape(result); + void* res_data = nvte_tensor_data(result); + NVTEDType res_dtype = nvte_tensor_type(result); + NVTE_CHECK(res_data != nullptr, "result data must not be null"); + + size_t res_sizes[2] = {res_shape.data[0], res_shape.data[1]}; + ncclEpTensor_t nccl_result_out = + make_tensor(res_data, 2, nvte_dtype_to_nccl(res_dtype), res_sizes); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_expert_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_result_out; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, /*config=*/nullptr, stream)); +} + +void EPBackend::dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape g_shape = nvte_tensor_shape(grad); + NVTEDType g_dtype = nvte_tensor_type(grad); + size_t g_sizes[2] = {g_shape.data[0], g_shape.data[1]}; + ncclEpTensor_t nccl_tok_in = + make_payload_tensor(grad, grad_win, 2, nvte_dtype_to_nccl(g_dtype), g_sizes); + + // g_recv_topk_weights must be 1D [recv_capacity] — caller flattens. + NVTEShape gw_shape = nvte_tensor_shape(g_recv_topk_weights); + NVTE_CHECK(gw_shape.ndim == 1, + "g_recv_topk_weights must be 1D [recv_capacity]; caller must flatten leading dims"); + size_t gw_sizes[1] = {gw_shape.data[0]}; + ncclEpTensor_t nccl_w_in = + make_payload_tensor(g_recv_topk_weights, g_recv_topk_weights_win, 1, ncclFloat32, gw_sizes); + + NVTEShape gt_shape = nvte_tensor_shape(grad_tokens); + void* gt_data = nvte_tensor_data(grad_tokens); + NVTE_CHECK(gt_data != nullptr, "grad_tokens data must not be null"); + size_t gt_sizes[2] = {gt_shape.data[0], gt_shape.data[1]}; + ncclEpTensor_t nccl_tok_out = make_tensor(gt_data, 2, nvte_dtype_to_nccl(g_dtype), gt_sizes); + + NVTEShape gtw_shape = nvte_tensor_shape(grad_topk_weights); + void* gtw_data = nvte_tensor_data(grad_topk_weights); + NVTE_CHECK(gtw_data != nullptr, "grad_topk_weights data must not be null"); + NVTE_CHECK(gtw_shape.ndim == 2, "grad_topk_weights must be 2D [T, top_k]"); + size_t gtw_sizes[2] = {gtw_shape.data[0], gtw_shape.data[1]}; + ncclEpTensor_t nccl_w_out = make_tensor(gtw_data, 2, ncclFloat32, gtw_sizes); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_tok_in; + in_struct.topk_weights = &nccl_w_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_tok_out; + out_struct.topk_weights = &nccl_w_out; + + ncclEpCombineConfig_t cfg = NCCL_EP_COMBINE_CONFIG_INIT; + cfg.pass_direction = NCCL_EP_BWD_PASS; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& entry = lookup_config(handle_id); + transient = ScopedEpHandle( + open_handle(handle_mem, entry.handle_mem_size, entry.top_k, entry.alignment)); + } + NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, &cfg, stream)); +} + +void EPBackend::combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, NVTETensor grad_expert_out, + const NVTECommWindow& grad_expert_out_win, cudaStream_t stream) { + // Backward of combine = reverse-direction dispatch. + dispatch(handle_id, handle_mem, /*topk_idx=*/nullptr, grad, grad_win, /*topk_weights=*/nullptr, + /*topk_weights_win=*/NVTECommWindow{}, grad_expert_out, grad_expert_out_win, + /*recv_topk_weights=*/nullptr, /*recv_topk_weights_win=*/NVTECommWindow{}, stream); +} + +} // namespace ep +} // namespace transformer_engine diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h new file mode 100644 index 0000000000..18307ebb4f --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.h @@ -0,0 +1,114 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.h + * \brief Internal NCCL EP singleton; not part of the public API. + * + * Per handle_id the cache stores config only (no device pointers), so + * handle_mem may be relocated between ops. Cap: NVTE_EP_HANDLE_CACHE_SIZE + * (default 8192); overflow throws. + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ +#define TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace transformer_engine { +namespace ep { + +/*! \brief EP backend singleton — owns the NCCL EP group; borrows the comm. */ +class EPBackend { + public: + /*! \brief Access the singleton. Aborts if not initialized. */ + static EPBackend& get(); + + /*! \brief Bootstrap from an existing EP sub-communicator. + * ep_comm is borrowed; the caller keeps it alive until shutdown() returns + * and must span exactly config.ep_size ranks. + */ + static void initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + /*! \brief Tear down the backend. Idempotent. Does not destroy ep_comm_. */ + static void shutdown(); + + // Host-only: reserve a fresh handle_id, cache the layer config, and report + // the handle_mem buffer size the caller must allocate. + uint64_t register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size); + + void prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETensor token_counts, + void* handle_mem, size_t dispatch_output_per_expert_alignment, cudaStream_t stream); + + void dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx, + const NVTETensor tokens, const NVTECommWindow& tokens_win, + const NVTETensor topk_weights, const NVTECommWindow& topk_weights_win, + NVTETensor recv_tokens, const NVTECommWindow& recv_tokens_win, + NVTETensor recv_topk_weights, const NVTECommWindow& recv_topk_weights_win, + cudaStream_t stream); + + void combine(uint64_t handle_id, void* handle_mem, const NVTETensor expert_out, + const NVTECommWindow& expert_out_win, NVTETensor result, cudaStream_t stream); + + // g_recv_topk_weights: 1D [recv_capacity] f32; grad_topk_weights: 2D [T, top_k] f32. + void dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream); + + void combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, NVTETensor grad_expert_out, + const NVTECommWindow& grad_expert_out_win, cudaStream_t stream); + + private: + EPBackend() = default; + ~EPBackend(); + EPBackend(const EPBackend&) = delete; + EPBackend& operator=(const EPBackend&) = delete; + + // ep_comm is borrowed — caller retains ownership across the backend lifetime. + void init(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + static EPBackend& instance(); // Meyers singleton accessor + static void validate_config(const NVTEEpGroupConfig& config); + + static ncclDataType_t nvte_dtype_to_nccl(NVTEDType dtype); + // Open a transient ncclEpHandle over handle_mem. num_topk=-1 for paths + // that don't carry per-token weights. + ncclEpHandle_t open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, + size_t dispatch_output_per_expert_alignment); + + ncclEpGroup_t ep_group_{nullptr}; + ncclComm_t ep_comm_{nullptr}; + NVTEEpGroupConfig group_config_{}; + bool initialized_{false}; + std::mutex mutex_; + struct HandleEntry { + size_t handle_mem_size; + size_t alignment; + int top_k; + }; + std::unordered_map handles_; + std::atomic next_handle_id_{1}; // 0 reserved as "no id" + size_t handle_cache_cap_{0}; // set lazily from NVTE_EP_HANDLE_CACHE_SIZE + + // Caller must hold mutex_. Throws on cap overflow. + uint64_t insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment); + HandleEntry& lookup_config(uint64_t handle_id); +}; + +} // namespace ep +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_window.h b/transformer_engine/common/include/transformer_engine/comm_window.h new file mode 100644 index 0000000000..088ea7f0c3 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_window.h @@ -0,0 +1,32 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file comm_window.h + * \brief Borrowed symmetric-memory window + offset for zero-copy one-sided ops. + * Pass ``{NULL, 0}`` to use the raw-pointer path. + */ + +#ifndef TRANSFORMER_ENGINE_COMM_WINDOW_H_ +#define TRANSFORMER_ENGINE_COMM_WINDOW_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief NCCL window + byte offset for a zero-copy payload tensor. */ +typedef struct { + ncclWindow_t window; /*!< NCCL window, or NULL to use the raw data pointer. */ + uint64_t offset; /*!< Byte offset of the payload within ``window``. */ +} NVTECommWindow; + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_COMM_WINDOW_H_ diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h new file mode 100644 index 0000000000..8c3a06b5f0 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -0,0 +1,161 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep.h + * \brief Public C API for Expert Parallelism. Per-step ops are allocation-free + * and CUDA graph-capturable. + */ + +#ifndef TRANSFORMER_ENGINE_EP_H_ +#define TRANSFORMER_ENGINE_EP_H_ + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* ── Config structs ─────────────────────────────────────────────────────── */ + +/*! \brief Group-level EP configuration (fixed for the EP group lifetime). */ +typedef struct { + int ep_size; /*!< EP world size. */ + int num_experts; /*!< Total experts across all ranks. */ + int max_tokens_per_rank; /*!< Upper bound on tokens this rank sends per dispatch. */ + /*! Upper bound on tokens received per dispatch (worst-case top_k fan-out; must be > 0). */ + int max_recv_tokens_per_rank; + int hidden_dim; /*!< Token hidden dimension. */ + int max_num_sms; /*!< Max SMs for EP kernels. 0 = auto. */ + /*! 0 (default): throw on relocated handle_mem for a cached handle_id. 1: silently rebuild. */ + int allow_handle_mem_reloc; +} NVTEEpGroupConfig; + +/*! \brief Per-layer EP configuration. */ +typedef struct { + int num_local_experts; /*!< Reserved for ABI stability (derived from group config). */ + int top_k; /*!< Per-token expert fan-out. Required. */ + size_t dispatch_output_per_expert_alignment; + /*!< Per-expert zone alignment in tokens (pow2; 0/1 = no padding). Must match + * between nvte_ep_register_layer and nvte_ep_prepare. */ +} NVTEEpLayerConfig; + +/* ── Bootstrap ──────────────────────────────────────────────────────────── */ + +/*! \brief Bootstrap from an existing NCCL EP sub-communicator. Requires SM>=90. + * + * ep_comm is borrowed and must span exactly group_config.ep_size ranks. + * Re-init after shutdown is allowed; double-init throws. + * + * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group. + * \param[in] group_config Group-level EP configuration. + */ +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config); + +/*! \brief Tear down the EP backend. Idempotent. Does not destroy ep_comm. */ +void nvte_ep_shutdown(void); + +/* ── Layer registration (host-only, eager) ───────────────────────────────── */ + +/*! \brief Reserve a handle_id for a layer config and report the handle_mem buffer + * size the caller must allocate. Host-only. + * + * \param[in] layer_config Per-layer EP configuration. + * \param[out] handle_mem_size Bytes the caller must allocate for handle_mem. + * \return uint64_t handle_id (non-zero). + */ +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size); + +/*! \brief Per-step handle: the registered handle_id paired with its handle_mem buffer. */ +typedef struct { + uint64_t id; /*!< Handle id from nvte_ep_register_layer. */ + NVTETensor mem; /*!< Caller-allocated handle_mem buffer (size from nvte_ep_register_layer). */ +} NVTEEpHandle; + +/* ── Per-step ops (all allocation-free, CUDA graph-capturable) ──────────── */ + +/*! \brief AllGather the routing map; write per-expert counts and cache routing + * metadata in handle.mem for the subsequent dispatch/combine. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] topk_idx [T, top_k] int64 routing indices. + * \param[out] token_counts [num_local_experts] int32 counts. + * \param[in] dispatch_output_per_expert_alignment Must match the handle_mem sizing. + * \param[in] stream CUDA stream. + */ +void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts, + size_t dispatch_output_per_expert_alignment, cudaStream_t stream); + +/*! \brief Dispatch tokens (and routing weights) to expert ranks. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] topk_idx [T, top_k] int64 sparse routing indices. + * \param[in] tokens [T, hidden_dim] input tokens. + * \param[in] tokens_win Optional symmem window for ``tokens``. + * \param[in] topk_weights [T, top_k] float32 weights, or null in backward. + * \param[in] topk_weights_win Optional symmem window for ``topk_weights``. + * \param[out] recv_tokens [recv_T, hidden_dim] received tokens. + * \param[in] recv_tokens_win Optional symmem window for ``recv_tokens``. + * \param[out] recv_topk_weights [recv_T] float32 per-slot weights, or null in backward. + * \param[in] recv_topk_weights_win Optional symmem window for ``recv_topk_weights``. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream); + +/*! \brief Scatter-sum expert outputs back to originating ranks. Unweighted — + * caller must pre-multiply expert_out by recv_topk_weights (and the + * valid-slot mask) before calling. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] expert_out [recv_T, hidden_dim] pre-weighted expert outputs. + * \param[in] expert_out_win Optional symmem window for ``expert_out``. + * \param[out] result [T, hidden_dim] combined output. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream); + +/*! \brief Backward of dispatch — routes token and weight grads back to source. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] grad [recv_capacity, hidden_dim] grad w.r.t. recv_tokens. + * \param[in] grad_win Optional symmem window for ``grad``. + * \param[in] g_recv_topk_weights [recv_capacity] f32 grad w.r.t. recv_topk_weights. + * \param[in] g_recv_topk_weights_win Optional symmem window for ``g_recv_topk_weights``. + * \param[out] grad_tokens [T, hidden_dim] grad w.r.t. tokens. + * \param[out] grad_topk_weights [T, top_k] f32 grad w.r.t. topk_weights. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, NVTETensor grad_topk_weights, + cudaStream_t stream); + +/*! \brief Backward of combine. Padded slots in grad_expert_out are zeroed. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] grad [T, hidden_dim] grad w.r.t. result. + * \param[in] grad_win Optional symmem window for ``grad``. + * \param[out] grad_expert_out [recv_capacity, hidden_dim] grad w.r.t. expert_out. + * \param[in] grad_expert_out_win Optional symmem window for ``grad_expert_out``. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_EP_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 8082ff07ed..f9d6f3acf3 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -606,6 +606,49 @@ void inplace_multi_tensor_swizzle_scales_for_gemm_unchecked(std::vector ep_register_layer(int64_t top_k, + int64_t dispatch_output_per_expert_alignment); + +void ep_prepare(at::Tensor handle_mem, int64_t handle_id, at::Tensor topk_idx, + at::Tensor token_counts, int64_t dispatch_output_per_expert_alignment); + +void ep_dispatch(at::Tensor handle_mem, int64_t handle_id, at::Tensor topk_idx, at::Tensor tokens, + at::Tensor topk_weights, at::Tensor recv_tokens, at::Tensor recv_topk_weights); + +void ep_combine(at::Tensor handle_mem, int64_t handle_id, at::Tensor expert_out, at::Tensor result); + +void ep_dispatch_bwd(at::Tensor handle_mem, int64_t handle_id, at::Tensor grad, + at::Tensor g_recv_topk_weights, at::Tensor grad_tokens, + at::Tensor grad_topk_weights); + +void ep_combine_bwd(at::Tensor handle_mem, int64_t handle_id, at::Tensor grad, + at::Tensor grad_expert_out); + +// Registers the EP pybind functions on `m`. Defined under NVTE_WITH_NCCL_EP. +void register_ep_bindings(pybind11::module_ &m); + /*************************************************************************************************** * NVSHMEM APIs **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp new file mode 100644 index 0000000000..66d1283a2d --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -0,0 +1,407 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifdef NVTE_WITH_NCCL_EP + +#include "transformer_engine/ep.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "transformer_engine/comm_window.h" + +#ifdef NCCL_HAS_SYMMEM_SUPPORT +#include +#endif + +#include "../common.h" +#include "../extensions.h" +#include "transformer_engine/gemm.h" + +namespace transformer_engine::pytorch { + +namespace { + +// EP process group name, captured at ep_bootstrap. Used by per-step ops to +// look up SymmetricMemory for payload tensors. Empty until bootstrap. +std::string g_ep_group_name; // NOLINT(runtime/string) + +// NCCL comm for the EP sub-group, owned by this process. Destroyed in +// ep_finalize AFTER nvte_ep_shutdown releases the backend's borrowed reference. +void* g_ep_nccl_comm = nullptr; + +// When false, per-step ops skip symm-mem window annotation and the backend +// takes the staged-copy path. Atomic so the Python-side `_zero_copy_scope` +// toggle is safe against concurrent ep_dispatch/combine (which release the GIL). +std::atomic g_zero_copy_enabled{true}; + +// Warn-once per role when an EP payload tensor isn't symm-mem-backed (the +// staged-copy fallback is correct but slower). Set NVTE_EP_SILENCE_NONSYMM_WARN=1 +// to silence (CI or known-non-symm paths). +void warn_if_not_symm(const at::Tensor& t, const char* role) { +#ifdef NCCL_HAS_SYMMEM_SUPPORT + if (g_ep_group_name.empty()) return; + if (c10d::symmetric_memory::is_symm_mem_tensor(t)) return; + static const bool silenced = []() { + const char* e = std::getenv("NVTE_EP_SILENCE_NONSYMM_WARN"); + return e != nullptr && e[0] != '\0' && e[0] != '0'; + }(); + if (silenced) return; + static std::atomic warned_mask{0}; + uint32_t h = 5381; + for (const char* p = role; *p; ++p) h = h * 33u + static_cast(*p); + const uint32_t bit = 1u << (h & 31u); + if ((warned_mask.fetch_or(bit) & bit) != 0) return; + std::fprintf(stderr, + "[NVTE EP] WARNING: %s tensor is not backed by an NCCL symmetric memory " + "window; falling back to staged copy. Allocate this buffer via the " + "framework's symm-mem API and rendezvous it for the zero-copy fast path. " + "Set NVTE_EP_SILENCE_NONSYMM_WARN=1 to suppress.\n", + role); +#else + (void)t; + (void)role; +#endif +} + +// Build the NVTECommWindow descriptor for ``t`` so the backend can take the +// zero-copy one-sided path. Returns ``{nullptr, 0}`` when symm-mem is disabled, +// not yet bootstrapped, or ``t`` isn't symm-mem-backed — in which case the +// backend falls back to the raw-pointer staged path. +NVTECommWindow maybe_make_window(const at::Tensor& t) { +#ifdef NCCL_HAS_SYMMEM_SUPPORT + if (!g_zero_copy_enabled.load(std::memory_order_relaxed)) return NVTECommWindow{nullptr, 0}; + if (g_ep_group_name.empty()) return NVTECommWindow{nullptr, 0}; + if (!c10d::symmetric_memory::is_symm_mem_tensor(t)) return NVTECommWindow{nullptr, 0}; + auto sm = c10d::symmetric_memory::get_symmetric_memory(t, g_ep_group_name); + NVTE_CHECK(sm != nullptr, + "EP payload tensor is symm-mem-backed but not rendezvoused on the EP group; " + "call symm_mem_alloc (or rendezvous explicitly at allocation time) before EP ops."); + auto* nccl_sm = dynamic_cast(sm.get()); + NVTE_CHECK(nccl_sm != nullptr, + "Symm-mem backend mismatch: expected NCCLSymmetricMemory. Set the backend to " + "\"NCCL\" before allocating EP payload buffers."); + return NVTECommWindow{static_cast(nccl_sm->get_window()), + static_cast(nccl_sm->get_offset())}; +#else + (void)t; + return NVTECommWindow{nullptr, 0}; +#endif +} + +// The backend only accepts int64 topk_idx. The PyTorch wrapper enforces this +// at the boundary so the per-step ops don't need an upcast workspace. +void check_topk_idx_int64(at::Tensor topk_idx) { + NVTE_CHECK(topk_idx.is_contiguous(), "topk_idx must be contiguous"); + NVTE_CHECK(topk_idx.scalar_type() == at::kLong, + "topk_idx must be int64; got dtype=", c10::toString(topk_idx.scalar_type()), + ". Cast with topk_idx.long() before calling."); +} + +using Shape = std::vector; + +} // namespace + +void ep_set_zero_copy(bool enabled) { + g_zero_copy_enabled.store(enabled, std::memory_order_relaxed); +} + +bool ep_get_zero_copy() { return g_zero_copy_enabled.load(std::memory_order_relaxed); } + +// ── Bootstrap ──────────────────────────────────────────────────────────────── +// +// POD-only across the pybind boundary; no c10d types (their ABI churns across +// PyTorch releases). Python orchestrates the uniqueId broadcast over dist. + +constexpr int kEpUniqueIdSize = 128; + +py::bytes ep_get_unique_id() { + static_assert(sizeof(ncclUniqueId) == kEpUniqueIdSize, + "ncclUniqueId size mismatch with kEpUniqueIdSize"); + static_assert(std::is_trivially_copyable::value, + "ncclUniqueId must be trivially copyable for byte-level memcpy"); + ncclUniqueId uid{}; + ncclResult_t ret = ncclGetUniqueId(&uid); + NVTE_CHECK(ret == ncclSuccess, "ncclGetUniqueId returned ", static_cast(ret)); + return py::bytes(uid.internal, kEpUniqueIdSize); +} + +void ep_initialize(const std::string& unique_id_bytes, int64_t rank, int64_t ep_size, + const std::string& group_name, int64_t num_experts, int64_t max_tokens_per_rank, + int64_t max_recv_tokens_per_rank, int64_t hidden_dim, int64_t max_num_sms, + bool allow_handle_mem_reloc) { + NVTE_CHECK(static_cast(unique_id_bytes.size()) == kEpUniqueIdSize, + "unique_id_bytes must be ", kEpUniqueIdSize, " bytes, got ", unique_id_bytes.size()); + NVTE_CHECK(!group_name.empty(), "group_name must be non-empty (used for symm-mem lookup)"); + NVTE_CHECK(g_ep_nccl_comm == nullptr, "ep_initialize called twice without ep_finalize"); + NVTEEpGroupConfig cfg{ + /*ep_size=*/static_cast(ep_size), + /*num_experts=*/static_cast(num_experts), + /*max_tokens_per_rank=*/static_cast(max_tokens_per_rank), + /*max_recv_tokens_per_rank=*/static_cast(max_recv_tokens_per_rank), + /*hidden_dim=*/static_cast(hidden_dim), + /*max_num_sms=*/static_cast(max_num_sms), + /*allow_handle_mem_reloc=*/allow_handle_mem_reloc ? 1 : 0, + }; + + // Copy bytes into a typed ncclUniqueId so the ABI is unambiguous when + // passing it by value to ncclCommInitRank. + ncclUniqueId uid{}; + std::memcpy(uid.internal, unique_id_bytes.data(), kEpUniqueIdSize); + ncclComm_t ep_comm = nullptr; + ncclResult_t rc = + ncclCommInitRank(&ep_comm, static_cast(ep_size), uid, static_cast(rank)); + NVTE_CHECK(rc == ncclSuccess, "ncclCommInitRank returned ", static_cast(rc)); + NVTE_CHECK(ep_comm != nullptr, "ncclCommInitRank produced a null comm"); + + // Destroy the comm if nvte_ep_initialize throws — otherwise the rank-collective + // ncclCommInitRank succeeds but no later finalize is registered to free it. + try { + nvte_ep_initialize(static_cast(ep_comm), cfg); + } catch (...) { + (void)ncclCommDestroy(ep_comm); + throw; + } + g_ep_nccl_comm = static_cast(ep_comm); + // Cache for symm-mem lookup in per-step ops (see maybe_make_window). + g_ep_group_name = group_name; +} + +void ep_finalize() { + nvte_ep_shutdown(); + // Destroy the comm AFTER the backend has released its borrowed reference. + if (g_ep_nccl_comm != nullptr) { + (void)ncclCommDestroy(static_cast(g_ep_nccl_comm)); + g_ep_nccl_comm = nullptr; + } + g_ep_group_name.clear(); +} + +std::tuple ep_register_layer(int64_t top_k, + int64_t dispatch_output_per_expert_alignment) { + NVTEEpLayerConfig layer_cfg{0, static_cast(top_k), + static_cast(dispatch_output_per_expert_alignment)}; + size_t handle_mem_size = 0; + uint64_t handle_id = nvte_ep_register_layer(layer_cfg, &handle_mem_size); + return std::make_tuple(static_cast(handle_id), static_cast(handle_mem_size)); +} + +// ── Per-step ops ───────────────────────────────────────────────────────────── + +void ep_prepare(at::Tensor handle_mem, int64_t handle_id, at::Tensor topk_idx, + at::Tensor token_counts, int64_t dispatch_output_per_expert_alignment) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(topk_idx.dim() >= 2, "topk_idx must be at least 2D [..., top_k]"); + check_topk_idx_int64(topk_idx); + const size_t T_flat = topk_idx.numel() / topk_idx.size(-1); + const size_t topk_n = static_cast(topk_idx.size(-1)); + + auto topk_idx_te = + makeTransformerEngineTensor(topk_idx.data_ptr(), Shape{T_flat, topk_n}, DType::kInt64); + auto token_counts_te = makeTransformerEngineTensor( + token_counts.data_ptr(), Shape{static_cast(token_counts.numel())}, DType::kInt32); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + + NVTEEpHandle handle{static_cast(handle_id), handle_mem_te.data()}; + nvte_ep_prepare(handle, topk_idx_te.data(), token_counts_te.data(), + static_cast(dispatch_output_per_expert_alignment), stream); +} + +void ep_dispatch(at::Tensor handle_mem, int64_t handle_id, at::Tensor topk_idx, at::Tensor tokens, + at::Tensor topk_weights, at::Tensor recv_tokens, at::Tensor recv_topk_weights) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(tokens.dim() >= 2, "tokens must be at least 2D [..., H]"); + NVTE_CHECK(topk_idx.dim() >= 2, "topk_idx must be at least 2D [..., top_k]"); + NVTE_CHECK(topk_weights.dim() >= 2, "topk_weights must be at least 2D [..., top_k]"); + NVTE_CHECK(recv_tokens.dim() >= 2, "recv_tokens must be at least 2D [..., recv_pr, H]"); + check_topk_idx_int64(topk_idx); + + const size_t H = static_cast(tokens.size(-1)); + const size_t T_flat = tokens.numel() / H; + const size_t topk_n = static_cast(topk_idx.size(-1)); + const size_t recv_pr = recv_tokens.numel() / H; + + NVTE_CHECK(static_cast(topk_weights.size(-1)) == topk_n, + "topk_weights last dim must equal topk_idx last dim"); + NVTE_CHECK(static_cast(recv_topk_weights.numel()) == recv_pr, + "recv_topk_weights total size must equal recv_tokens recv_pr"); + NVTE_CHECK(recv_tokens.scalar_type() == tokens.scalar_type(), "recv_tokens dtype (", + c10::toString(recv_tokens.scalar_type()), ") must match tokens dtype (", + c10::toString(tokens.scalar_type()), ")"); + + auto tok_dtype = GetTransformerEngineDType(tokens.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto topk_idx_te = + makeTransformerEngineTensor(topk_idx.data_ptr(), Shape{T_flat, topk_n}, DType::kInt64); + auto tokens_te = makeTransformerEngineTensor(tokens.data_ptr(), Shape{T_flat, H}, tok_dtype); + warn_if_not_symm(tokens, "dispatch input (tokens)"); + NVTECommWindow tokens_win = maybe_make_window(tokens); + auto topk_w_te = + makeTransformerEngineTensor(topk_weights.data_ptr(), Shape{T_flat, topk_n}, DType::kFloat32); + // topk_weights symm-mem backing is nice-to-have, not required — silently + // fall back to the staged path if the caller didn't allocate it via symm-mem. + NVTECommWindow topk_weights_win = maybe_make_window(topk_weights); + auto recv_tokens_te = + makeTransformerEngineTensor(recv_tokens.data_ptr(), Shape{recv_pr, H}, tok_dtype); + warn_if_not_symm(recv_tokens, "dispatch output (recv_tokens)"); + NVTECommWindow recv_tokens_win = maybe_make_window(recv_tokens); + auto recv_topk_w_te = + makeTransformerEngineTensor(recv_topk_weights.data_ptr(), Shape{recv_pr}, DType::kFloat32); + NVTECommWindow recv_topk_weights_win = maybe_make_window(recv_topk_weights); + + NVTEEpHandle handle{static_cast(handle_id), handle_mem_te.data()}; + nvte_ep_dispatch(handle, topk_idx_te.data(), tokens_te.data(), tokens_win, topk_w_te.data(), + topk_weights_win, recv_tokens_te.data(), recv_tokens_win, recv_topk_w_te.data(), + recv_topk_weights_win, stream); +} + +void ep_combine(at::Tensor handle_mem, int64_t handle_id, at::Tensor expert_out, + at::Tensor result) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(expert_out.dim() >= 2, "expert_out must be at least 2D [..., recv_pr, H]"); + NVTE_CHECK(result.dim() >= 2, "result must be at least 2D [..., H]"); + + const size_t H = static_cast(expert_out.size(-1)); + const size_t recv_pr = expert_out.numel() / H; + const size_t T_flat = result.numel() / H; + NVTE_CHECK(static_cast(result.size(-1)) == H, + "result hidden dim must equal expert_out hidden dim"); + NVTE_CHECK(result.scalar_type() == expert_out.scalar_type(), "result dtype (", + c10::toString(result.scalar_type()), ") must match expert_out dtype (", + c10::toString(expert_out.scalar_type()), ")"); + + auto eo_dtype = GetTransformerEngineDType(expert_out.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto expert_out_te = + makeTransformerEngineTensor(expert_out.data_ptr(), Shape{recv_pr, H}, eo_dtype); + warn_if_not_symm(expert_out, "combine input (expert_out)"); + NVTECommWindow expert_out_win = maybe_make_window(expert_out); + // combine ``result`` is local accumulation (not cross-rank put/get); leave it + // un-annotated so the backend uses the raw-pointer path regardless of how it + // was allocated. + auto result_te = makeTransformerEngineTensor(result.data_ptr(), Shape{T_flat, H}, eo_dtype); + + NVTEEpHandle handle{static_cast(handle_id), handle_mem_te.data()}; + nvte_ep_combine(handle, expert_out_te.data(), expert_out_win, result_te.data(), stream); +} + +void ep_dispatch_bwd(at::Tensor handle_mem, int64_t handle_id, at::Tensor grad, + at::Tensor g_recv_topk_weights, at::Tensor grad_tokens, + at::Tensor grad_topk_weights) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., recv_pr, H]"); + NVTE_CHECK(grad_tokens.dim() >= 2, "grad_tokens must be at least 2D [..., H]"); + NVTE_CHECK(grad_topk_weights.dim() >= 2, "grad_topk_weights must be at least 2D [..., top_k]"); + + const size_t H = static_cast(grad.size(-1)); + const size_t recv_pr = grad.numel() / H; + const size_t T_flat = grad_tokens.numel() / H; + const size_t topk_n = static_cast(grad_topk_weights.size(-1)); + NVTE_CHECK(static_cast(g_recv_topk_weights.numel()) == recv_pr, + "g_recv_topk_weights total size must equal grad recv_pr"); + NVTE_CHECK(static_cast(grad_tokens.size(-1)) == H, + "grad_tokens hidden dim must equal grad H"); + NVTE_CHECK(static_cast(grad_topk_weights.numel()) == T_flat * topk_n, + "grad_topk_weights numel (", grad_topk_weights.numel(), + ") must equal T_flat * top_k (", T_flat * topk_n, ")"); + NVTE_CHECK(grad_tokens.scalar_type() == grad.scalar_type(), "grad_tokens dtype (", + c10::toString(grad_tokens.scalar_type()), ") must match grad dtype (", + c10::toString(grad.scalar_type()), ")"); + + auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto grad_te = makeTransformerEngineTensor(grad.data_ptr(), Shape{recv_pr, H}, g_dtype); + warn_if_not_symm(grad, "dispatch_bwd input (grad)"); + NVTECommWindow grad_win = maybe_make_window(grad); + auto g_recv_w_te = + makeTransformerEngineTensor(g_recv_topk_weights.data_ptr(), Shape{recv_pr}, DType::kFloat32); + NVTECommWindow g_recv_topk_weights_win = maybe_make_window(g_recv_topk_weights); + auto grad_tokens_te = + makeTransformerEngineTensor(grad_tokens.data_ptr(), Shape{T_flat, H}, g_dtype); + auto grad_topk_w_te = makeTransformerEngineTensor(grad_topk_weights.data_ptr(), + Shape{T_flat, topk_n}, DType::kFloat32); + + NVTEEpHandle handle{static_cast(handle_id), handle_mem_te.data()}; + nvte_ep_dispatch_bwd(handle, grad_te.data(), grad_win, g_recv_w_te.data(), + g_recv_topk_weights_win, grad_tokens_te.data(), grad_topk_w_te.data(), + stream); +} + +void ep_combine_bwd(at::Tensor handle_mem, int64_t handle_id, at::Tensor grad, + at::Tensor grad_expert_out) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., H]"); + NVTE_CHECK(grad_expert_out.dim() >= 2, "grad_expert_out must be at least 2D [..., recv_pr, H]"); + + const size_t H = static_cast(grad.size(-1)); + const size_t T_flat = grad.numel() / H; + const size_t recv_pr = grad_expert_out.numel() / H; + NVTE_CHECK(static_cast(grad_expert_out.size(-1)) == H, + "grad_expert_out hidden dim must match grad H"); + NVTE_CHECK(grad_expert_out.scalar_type() == grad.scalar_type(), "grad_expert_out dtype (", + c10::toString(grad_expert_out.scalar_type()), ") must match grad dtype (", + c10::toString(grad.scalar_type()), ")"); + + auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto grad_te = makeTransformerEngineTensor(grad.data_ptr(), Shape{T_flat, H}, g_dtype); + NVTECommWindow grad_win = maybe_make_window(grad); + auto grad_eo_te = + makeTransformerEngineTensor(grad_expert_out.data_ptr(), Shape{recv_pr, H}, g_dtype); + NVTECommWindow grad_eo_win = maybe_make_window(grad_expert_out); + + NVTEEpHandle handle{static_cast(handle_id), handle_mem_te.data()}; + nvte_ep_combine_bwd(handle, grad_te.data(), grad_win, grad_eo_te.data(), grad_eo_win, stream); +} + +void register_ep_bindings(pybind11::module_& m) { + namespace py = pybind11; + m.def("ep_get_unique_id", &ep_get_unique_id, + "Generate a fresh ncclUniqueId (128 raw bytes). Call on rank 0 only."); + m.def("ep_initialize", &ep_initialize, "Initialize the EP backend from a broadcast ncclUniqueId.", + py::arg("unique_id_bytes"), py::arg("rank"), py::arg("ep_size"), py::arg("group_name"), + py::arg("num_experts"), py::arg("max_tokens_per_rank"), py::arg("max_recv_tokens_per_rank"), + py::arg("hidden_dim"), py::arg("max_num_sms") = 0, + py::arg("allow_handle_mem_reloc") = false, py::call_guard()); + m.def("ep_finalize", &ep_finalize, "Tear down the EP backend. Idempotent.", + py::call_guard()); + m.def("ep_set_zero_copy", &ep_set_zero_copy, "Toggle EP zero-copy symm-mem annotation.", + py::arg("enabled")); + m.def("ep_get_zero_copy", &ep_get_zero_copy, "Return the current EP zero-copy toggle state."); + m.def("ep_register_layer", &ep_register_layer, + "Register an EP layer; returns (handle_id, handle_mem_size_bytes).", py::arg("top_k"), + py::arg("dispatch_output_per_expert_alignment") = 0); + m.def("ep_prepare", &ep_prepare, "EP prepare", py::call_guard()); + m.def("ep_dispatch", &ep_dispatch, "EP dispatch", py::call_guard()); + m.def("ep_combine", &ep_combine, "EP combine", py::call_guard()); + m.def("ep_dispatch_bwd", &ep_dispatch_bwd, "EP dispatch backward", + py::call_guard()); + m.def("ep_combine_bwd", &ep_combine_bwd, "EP combine backward", + py::call_guard()); +} + +} // namespace transformer_engine::pytorch + +#endif // NVTE_WITH_NCCL_EP diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a4571c64e2..98154cffeb 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -230,6 +230,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); +#ifdef NVTE_WITH_NCCL_EP + transformer_engine::pytorch::register_ep_bindings(m); +#endif // NVTE_WITH_NCCL_EP + // Permutation functions m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD", py::call_guard()); diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py new file mode 100644 index 0000000000..76677cd177 --- /dev/null +++ b/transformer_engine/pytorch/ep.py @@ -0,0 +1,647 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""PyTorch Expert Parallelism (EP) API.""" + +from __future__ import annotations + +import atexit +import contextlib +from typing import Optional + +import torch +import torch.distributed as dist + +import transformer_engine_torch as tex + + +__all__ = [ + "EpHandle", + "EpBuffer", + "ep_bootstrap", + "ep_finalize", + "ep_prepare", + "ep_dispatch", + "ep_combine", + "symm_mem_alloc", +] + + +# ── Symmetric-memory buffer allocator ──────────────────────────────────────── + + +def symm_mem_alloc( + shape, + dtype: torch.dtype, + ep_group: dist.ProcessGroup, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """Allocate and rendezvous a symm-mem buffer on ``ep_group`` for the EP zero-copy path. + + Collective on ``ep_group``. + """ + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + try: + from torch.distributed import _symmetric_memory as _symm_mem + except ImportError as e: + raise RuntimeError( + "torch.distributed._symmetric_memory is unavailable; symm_mem_alloc " + "requires PyTorch built with NCCL symm-mem support." + ) from e + if _symm_mem.get_backend(device) != "NCCL": + _symm_mem.set_backend("NCCL") + t = _symm_mem.empty(*shape, dtype=dtype, device=device) + _symm_mem.rendezvous(t, group=ep_group.group_name) + return t + + +# ── Bootstrap ──────────────────────────────────────────────────────────────── + + +_BOOTSTRAPPED = False +_ATEXIT_REGISTERED = False + + +def _atexit_finalize() -> None: + """Best-effort teardown at interpreter shutdown.""" + global _BOOTSTRAPPED + if _BOOTSTRAPPED: + try: + tex.ep_finalize() + except Exception: + import traceback + + traceback.print_exc() + finally: + _BOOTSTRAPPED = False + + +def ep_bootstrap( + ep_group: dist.ProcessGroup, + num_experts: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + max_num_sms: int = 0, + allow_handle_mem_reloc: bool = False, +) -> None: + """Initialize the EP communicator. Call once per process before any EP op. + + ``ep_group`` must span exactly the EP ranks; requires ``ep_group.size() >= 4``. + """ + global _BOOTSTRAPPED, _ATEXIT_REGISTERED + if _BOOTSTRAPPED: + raise RuntimeError("ep_bootstrap was already called in this process") + if ep_group.size() < 4: + raise ValueError(f"ep_bootstrap requires ep_group.size() >= 4 (got {ep_group.size()}).") + + rank = ep_group.rank() + ep_size = ep_group.size() + device = torch.device("cuda", torch.cuda.current_device()) + + # dist.broadcast's src is the global rank; translate group-rank 0. + src_global = dist.get_global_rank(ep_group, 0) + uid_bytes = tex.ep_get_unique_id() if rank == 0 else b"\x00" * 128 + uid_tensor = torch.frombuffer(bytearray(uid_bytes), dtype=torch.uint8).to(device) + dist.broadcast(uid_tensor, src=src_global, group=ep_group) + uid_bytes = bytes(uid_tensor.cpu().numpy().tobytes()) + + tex.ep_initialize( + uid_bytes, + int(rank), + int(ep_size), + str(ep_group.group_name), + int(num_experts), + int(max_tokens_per_rank), + int(recv_capacity_per_rank), + int(hidden_dim), + int(max_num_sms), + bool(allow_handle_mem_reloc), + ) + _BOOTSTRAPPED = True + if not _ATEXIT_REGISTERED: + atexit.register(_atexit_finalize) + _ATEXIT_REGISTERED = True + + +def ep_finalize() -> None: + """Explicit teardown. Idempotent. Call before destroying the process group.""" + _atexit_finalize() + + +# ── Handle ─────────────────────────────────────────────────────────────────── + + +class EpHandle: + """Routing context for one EP layer. Construct once at module init; reuse per step. + + Single-use per step: do not share across concurrently in-flight + ``ep_dispatch`` / ``ep_combine`` calls (e.g. PP-1F1B microbatches). + """ + + __slots__ = ( + "handle_mem", + "handle_id", + "top_k", + "alignment", + "max_tokens_per_rank", + "recv_capacity_per_rank", + "hidden_dim", + "num_local_experts", + "payload_dtype", + ) + + def __init__( + self, + top_k: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + num_local_experts: int, + alignment: int = 0, + device: Optional[torch.device] = None, + payload_dtype: torch.dtype = torch.bfloat16, + ) -> None: + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + self.top_k = int(top_k) + self.alignment = int(alignment) + self.max_tokens_per_rank = int(max_tokens_per_rank) + self.recv_capacity_per_rank = int(recv_capacity_per_rank) + self.hidden_dim = int(hidden_dim) + self.num_local_experts = int(num_local_experts) + self.payload_dtype = payload_dtype + handle_id, size_bytes = tex.ep_register_layer(self.top_k, self.alignment) + self.handle_id = int(handle_id) + self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) + + +# ── Buffer ─────────────────────────────────────────────────────────────────── + + +class EpBuffer: + """Symm-mem-backed payload buffers (``recv_tokens``, ``combine_in``) for one EP layer. + + Construct once at layer init (collective rendezvous on ``ep_group``). + ``use_symm_mem=False`` falls back to plain HBM for debug runs. + + Multi-stream usage: call :meth:`record_stream` from every stream that + touches the buffer outside its allocation stream, otherwise the caching + allocator can reclaim memory that peers' symm-mem windows still point at. + """ + + __slots__ = ("recv_tokens", "combine_in") + + def __init__( + self, + handle: EpHandle, + ep_group: Optional[dist.ProcessGroup] = None, + *, + use_symm_mem: bool = True, + device: Optional[torch.device] = None, + ) -> None: + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + shape = (handle.recv_capacity_per_rank, handle.hidden_dim) + if use_symm_mem: + if ep_group is None: + raise ValueError("EpBuffer(use_symm_mem=True) requires ep_group.") + self.recv_tokens = symm_mem_alloc(shape, handle.payload_dtype, ep_group, device=device) + self.combine_in = symm_mem_alloc(shape, handle.payload_dtype, ep_group, device=device) + else: + self.recv_tokens = torch.empty(shape, dtype=handle.payload_dtype, device=device) + self.combine_in = torch.empty(shape, dtype=handle.payload_dtype, device=device) + + @classmethod + def from_external( + cls, + handle: EpHandle, + *, + recv_tokens: torch.Tensor, + combine_in: torch.Tensor, + ) -> "EpBuffer": + """Construct from caller-allocated buffers (e.g. slices of a shared symm-mem pool). + + Both tensors must have shape ``(handle.recv_capacity_per_rank, handle.hidden_dim)`` + and dtype ``handle.payload_dtype``. + """ + expected = (handle.recv_capacity_per_rank, handle.hidden_dim) + if tuple(recv_tokens.shape) != expected: + raise ValueError(f"recv_tokens shape {tuple(recv_tokens.shape)} != expected {expected}") + if tuple(combine_in.shape) != expected: + raise ValueError(f"combine_in shape {tuple(combine_in.shape)} != expected {expected}") + if recv_tokens.dtype != handle.payload_dtype or combine_in.dtype != handle.payload_dtype: + raise ValueError( + f"buffer dtype must match handle.payload_dtype ({handle.payload_dtype})" + ) + inst = cls.__new__(cls) + inst.recv_tokens = recv_tokens + inst.combine_in = combine_in + return inst + + def record_stream(self, stream: torch.cuda.Stream) -> None: + """Record ``stream`` as a user of both owned tensors so the caching allocator + defers reclaim until ``stream`` has caught up.""" + self.recv_tokens.record_stream(stream) + self.combine_in.record_stream(stream) + + +# ── torch.library custom ops (so they don't graph-break under torch.compile) ─ + +_LIB = "transformer_engine_ep" + + +@torch.library.custom_op( + f"{_LIB}::prepare", + mutates_args=("handle_mem", "token_counts"), + device_types="cuda", +) +def _prepare_op( + handle_mem: torch.Tensor, + handle_id: int, + topk_idx: torch.Tensor, + token_counts: torch.Tensor, + alignment: int, +) -> None: + tex.ep_prepare(handle_mem, handle_id, topk_idx, token_counts, alignment) + + +@_prepare_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::dispatch", + mutates_args=("recv_tokens", "recv_topk_weights"), + device_types="cuda", +) +def _dispatch_op( + handle_mem: torch.Tensor, + handle_id: int, + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + recv_tokens: torch.Tensor, + recv_topk_weights: torch.Tensor, +) -> None: + tex.ep_dispatch( + handle_mem, + handle_id, + topk_idx, + tokens, + topk_weights, + recv_tokens, + recv_topk_weights, + ) + + +@_dispatch_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::combine", + mutates_args=("result",), + device_types="cuda", +) +def _combine_op( + handle_mem: torch.Tensor, + handle_id: int, + expert_out: torch.Tensor, + result: torch.Tensor, +) -> None: + tex.ep_combine(handle_mem, handle_id, expert_out, result) + + +@_combine_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::dispatch_bwd", + mutates_args=("grad_tokens", "grad_topk_weights"), + device_types="cuda", +) +def _dispatch_bwd_op( + handle_mem: torch.Tensor, + handle_id: int, + grad: torch.Tensor, + g_recv_topk_weights: torch.Tensor, + grad_tokens: torch.Tensor, + grad_topk_weights: torch.Tensor, +) -> None: + tex.ep_dispatch_bwd( + handle_mem, handle_id, grad, g_recv_topk_weights, grad_tokens, grad_topk_weights + ) + + +@_dispatch_bwd_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::combine_bwd", + mutates_args=("grad_expert_out",), + device_types="cuda", +) +def _combine_bwd_op( + handle_mem: torch.Tensor, + handle_id: int, + grad: torch.Tensor, + grad_expert_out: torch.Tensor, +) -> None: + tex.ep_combine_bwd(handle_mem, handle_id, grad, grad_expert_out) + + +@_combine_bwd_op.register_fake +def _(*args, **kw): + return None + + +# ── Non-autograd primitives ────────────────────────────────────────────────── + + +def ep_prepare(handle: EpHandle, topk_idx: torch.Tensor) -> torch.Tensor: + """AllGather the routing map; fills ``handle.handle_mem`` and returns ``token_counts`` + (int32, shape ``[num_local_experts]``). + + ``topk_idx`` must be int64. + """ + token_counts = torch.empty( + handle.num_local_experts, dtype=torch.int32, device=handle.handle_mem.device + ) + torch.ops.transformer_engine_ep.prepare( + handle.handle_mem, + handle.handle_id, + topk_idx, + token_counts, + handle.alignment, + ) + return token_counts + + +def _ep_dispatch_raw( + handle: EpHandle, + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + recv_tokens: torch.Tensor, + recv_topk_weights: torch.Tensor, +) -> None: + """Raw dispatch — no autograd, no prepare. Caller must run ``ep_prepare`` first.""" + torch.ops.transformer_engine_ep.dispatch( + handle.handle_mem, + handle.handle_id, + topk_idx, + tokens, + topk_weights, + recv_tokens, + recv_topk_weights, + ) + + +def _ep_combine_raw(handle: EpHandle, expert_out: torch.Tensor, result: torch.Tensor) -> None: + """Raw combine — no autograd, no weighting. Caller pre-weights ``expert_out``.""" + torch.ops.transformer_engine_ep.combine(handle.handle_mem, handle.handle_id, expert_out, result) + + +# ── autograd.Function wrappers ─────────────────────────────────────────────── + + +class _EpDispatch(torch.autograd.Function): + """Autograd-aware prepare + dispatch. Snapshots ``handle_mem`` so backward survives a + later ``ep_prepare`` on the same handle (e.g. across PP-1F1B microbatches).""" + + @staticmethod + def forward( # type: ignore[override] + ctx, + handle_mem: torch.Tensor, + handle_id: int, + alignment: int, + recv_tokens: torch.Tensor, + num_local_experts: int, + zero_copy: bool, + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + ): + device = tokens.device + recv_capacity = recv_tokens.shape[0] + token_counts = torch.empty(num_local_experts, dtype=torch.int32, device=device) + recv_topk_weights = torch.empty(recv_capacity, dtype=torch.float32, device=device) + with _zero_copy_scope(zero_copy): + torch.ops.transformer_engine_ep.prepare( + handle_mem, handle_id, topk_idx, token_counts, alignment + ) + torch.ops.transformer_engine_ep.dispatch( + handle_mem, + handle_id, + topk_idx, + tokens, + topk_weights, + recv_tokens, + recv_topk_weights, + ) + ctx.handle_mem_snapshot = handle_mem.detach().clone() + ctx.handle_id = handle_id + ctx.zero_copy = zero_copy + ctx.tokens_shape = tokens.shape + ctx.tokens_dtype = tokens.dtype + ctx.topk_weights_shape = topk_weights.shape + ctx.topk_weights_dtype = topk_weights.dtype + ctx.recv_capacity = recv_capacity + ctx.hidden_dim = tokens.shape[-1] + ctx.mark_non_differentiable(token_counts) + # Detach so the long-lived buffer isn't tracked as a differentiable output. + return recv_tokens.detach(), recv_topk_weights, token_counts + + @staticmethod + def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] + device = ctx.handle_mem_snapshot.device + if g_recv_tokens is None: + g_recv_tokens = torch.zeros( + ctx.recv_capacity, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device + ) + if g_recv_topk_weights is None: + g_recv_topk_weights = torch.zeros(ctx.recv_capacity, dtype=torch.float32, device=device) + grad_tokens = torch.empty(ctx.tokens_shape, dtype=ctx.tokens_dtype, device=device) + grad_topk_weights = torch.empty( + ctx.topk_weights_shape, dtype=ctx.topk_weights_dtype, device=device + ) + with _zero_copy_scope(ctx.zero_copy): + torch.ops.transformer_engine_ep.dispatch_bwd( + ctx.handle_mem_snapshot, + ctx.handle_id, + g_recv_tokens.contiguous(), + g_recv_topk_weights.contiguous(), + grad_tokens, + grad_topk_weights, + ) + return ( + None, # handle_mem + None, # handle_id + None, # alignment + None, # recv_tokens + None, # num_local_experts + None, # zero_copy + None, # topk_idx + grad_tokens, + grad_topk_weights, + ) + + +class _EpCombine(torch.autograd.Function): + """Autograd-aware weight + combine. Snapshots ``handle_mem`` so backward survives a + later ``ep_prepare`` on the same handle (e.g. across PP-1F1B microbatches).""" + + @staticmethod + def forward( # type: ignore[override] + ctx, + handle_mem: torch.Tensor, + handle_id: int, + combine_in: torch.Tensor, + num_local_tokens: int, + hidden_dim: int, + zero_copy: bool, + expert_out: torch.Tensor, + recv_topk_weights: torch.Tensor, + ): + device = expert_out.device + w = recv_topk_weights.unsqueeze(-1).to(torch.float32) + mask = (recv_topk_weights != 0).unsqueeze(-1).to(torch.float32) + combine_in.copy_((expert_out.to(torch.float32) * w * mask).to(combine_in.dtype)) + result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) + with _zero_copy_scope(zero_copy): + torch.ops.transformer_engine_ep.combine(handle_mem, handle_id, combine_in, result) + ctx.save_for_backward(expert_out, recv_topk_weights) + ctx.handle_mem_snapshot = handle_mem.detach().clone() + ctx.handle_id = handle_id + ctx.zero_copy = zero_copy + return result + + @staticmethod + def backward(ctx, g_result): # type: ignore[override] + expert_out, recv_topk_weights = ctx.saved_tensors + grad_combine_in = torch.empty_like(expert_out) + with _zero_copy_scope(ctx.zero_copy): + torch.ops.transformer_engine_ep.combine_bwd( + ctx.handle_mem_snapshot, ctx.handle_id, g_result.contiguous(), grad_combine_in + ) + w = recv_topk_weights.unsqueeze(-1).to(torch.float32) + mask = (recv_topk_weights != 0).unsqueeze(-1).to(torch.float32) + gci_f32 = grad_combine_in.to(torch.float32) + grad_expert_out = (gci_f32 * w * mask).to(expert_out.dtype) + grad_recv_topk_weights = ( + (gci_f32 * expert_out.to(torch.float32) * mask).sum(-1).to(recv_topk_weights.dtype) + ) + return ( + None, # handle_mem + None, # handle_id + None, # combine_in + None, # num_local_tokens + None, # hidden_dim + None, # zero_copy + grad_expert_out, + grad_recv_topk_weights, + ) + + +# ── Public high-level wrappers ─────────────────────────────────────────────── + + +# FP8 dispatch is not yet supported by the common backend. +_FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2) + + +@contextlib.contextmanager +def _zero_copy_scope(enabled: bool): + """Set the symm-mem zero-copy toggle for the scope, saving + restoring the prior value. + + No-op under ``torch.compile`` (pybind getter/setter aren't dynamo-traceable); + callers must pre-set the global toggle before entering the compiled region. + """ + if torch.compiler.is_compiling(): + yield + return + prev = tex.ep_get_zero_copy() + if prev == enabled: + yield + return + tex.ep_set_zero_copy(enabled) + try: + yield + finally: + tex.ep_set_zero_copy(prev) + + +def _reject_fp8(*tensors: torch.Tensor) -> None: + if torch.compiler.is_compiling(): + return + for t in tensors: + if t.dtype in _FP8_DTYPES: + raise NotImplementedError( + f"FP8 dispatch/combine not supported (got dtype={t.dtype}); " + "quantize outside the EP boundary." + ) + + +def ep_dispatch( + handle: EpHandle, + buffer: EpBuffer, + tokens: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + *, + zero_copy: bool = True, +): + """Run prepare + dispatch with autograd. ``topk_idx`` must be int64. + + Returns ``(recv_tokens, recv_topk_weights, token_counts)``; + ``token_counts`` is non-differentiable. + """ + _reject_fp8(tokens, buffer.recv_tokens) + return _EpDispatch.apply( + handle.handle_mem, + handle.handle_id, + handle.alignment, + buffer.recv_tokens, + handle.num_local_experts, + zero_copy, + topk_idx, + tokens, + topk_weights, + ) + + +def ep_combine( + handle: EpHandle, + buffer: EpBuffer, + expert_out: torch.Tensor, + recv_topk_weights: torch.Tensor, + *, + num_local_tokens: Optional[int] = None, + zero_copy: bool = True, +): + """Apply per-slot weighting then combine, with autograd. + + Result shape is ``(num_local_tokens, handle.hidden_dim)``; defaults to + ``handle.max_tokens_per_rank`` rows. + """ + _reject_fp8(expert_out, buffer.combine_in) + if num_local_tokens is None: + num_local_tokens = handle.max_tokens_per_rank + return _EpCombine.apply( + handle.handle_mem, + handle.handle_id, + buffer.combine_in, + num_local_tokens, + handle.hidden_dim, + zero_copy, + expert_out, + recv_topk_weights, + )