From 17e51267c9b917183f23961962b8c857e9eb2e63 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 22 May 2026 23:05:06 +0000 Subject: [PATCH 1/3] Expert Parallelism: common C API + NCCL EP v0.1 backend Signed-off-by: Phuong Nguyen --- .gitmodules | 4 + 3rdparty/nccl | 1 + qa/L1_cpp_distributed/test.sh | 3 + setup.py | 127 +++ tests/cpp_distributed/CMakeLists.txt | 91 +- tests/cpp_distributed/run_test_ep.sh | 137 +++ tests/cpp_distributed/test_ep_common.h | 308 ++++++ tests/cpp_distributed/test_ep_coverage.cu | 379 ++++++++ tests/cpp_distributed/test_ep_init.cu | 64 ++ tests/cpp_distributed/test_ep_pipeline.cu | 890 ++++++++++++++++++ transformer_engine/common/CMakeLists.txt | 90 ++ transformer_engine/common/ep/ep_api.cpp | 76 ++ transformer_engine/common/ep/ep_api_stub.cpp | 61 ++ transformer_engine/common/ep/ep_backend.cpp | 514 ++++++++++ transformer_engine/common/ep/ep_backend.h | 114 +++ .../include/transformer_engine/comm_window.h | 32 + .../common/include/transformer_engine/ep.h | 161 ++++ 17 files changed, 3050 insertions(+), 2 deletions(-) create mode 160000 3rdparty/nccl create mode 100755 tests/cpp_distributed/run_test_ep.sh create mode 100644 tests/cpp_distributed/test_ep_common.h create mode 100644 tests/cpp_distributed/test_ep_coverage.cu create mode 100644 tests/cpp_distributed/test_ep_init.cu create mode 100644 tests/cpp_distributed/test_ep_pipeline.cu create mode 100644 transformer_engine/common/ep/ep_api.cpp create mode 100644 transformer_engine/common/ep/ep_api_stub.cpp create mode 100644 transformer_engine/common/ep/ep_backend.cpp create mode 100644 transformer_engine/common/ep/ep_backend.h create mode 100644 transformer_engine/common/include/transformer_engine/comm_window.h create mode 100644 transformer_engine/common/include/transformer_engine/ep.h diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..e531c95507 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,7 @@ [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "3rdparty/nccl"] + path = 3rdparty/nccl + url = https://github.com/NVIDIA/nccl.git + branch = v2.30u1 diff --git a/3rdparty/nccl b/3rdparty/nccl new file mode 160000 index 0000000000..6a9bc953ac --- /dev/null +++ b/3rdparty/nccl @@ -0,0 +1 @@ +Subproject commit 6a9bc953ac1c4eef92d5adbe3092d4c2cb0a4c98 diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh index 8d767a4efb..7e5ce2cf0d 100755 --- a/qa/L1_cpp_distributed/test.sh +++ b/qa/L1_cpp_distributed/test.sh @@ -14,4 +14,7 @@ if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then cmake -GNinja -S. -Bbuild cmake --build build mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm + + # EP suites; runner self-skips on pre-Hopper GPUs. + bash ./run_test_ep.sh 4 ./build fi diff --git a/setup.py b/setup.py index ec277b6349..db360c8a29 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,34 @@ def setup_common_extension() -> CMakeExtension: cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr") cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}") + # NCCL EP: on by default; auto-disabled if no arch >= 90. + # Set NVTE_BUILD_WITH_NCCL_EP=0/1 to force off/on. + nccl_ep_env = os.getenv("NVTE_BUILD_WITH_NCCL_EP") + explicit_nccl_ep = nccl_ep_env is not None + build_with_nccl_ep = bool(int(nccl_ep_env)) if explicit_nccl_ep else True + + if build_with_nccl_ep: + arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()] + has_hopper_or_newer = any(t.lower() == "native" for t in arch_tokens) or any( + int(t.rstrip("af")) >= 90 for t in arch_tokens if t.rstrip("af").isdigit() + ) + if not has_hopper_or_newer: + if explicit_nccl_ep: + raise RuntimeError( + "NVTE_BUILD_WITH_NCCL_EP=1 requires at least one CUDA arch >= 90 in " + f"NVTE_CUDA_ARCHS (got '{archs}'). Add '90' or unset NVTE_BUILD_WITH_NCCL_EP." + ) + print( + "[NCCL EP] No CUDA arch >= 90 in NVTE_CUDA_ARCHS" + f" ('{archs}'); auto-disabling NCCL EP (nvte_ep_* will throw at runtime)." + ) + build_with_nccl_ep = False + + if build_with_nccl_ep: + build_nccl_ep_submodule() + else: + cmake_flags.append("-DNVTE_WITH_NCCL_EP=OFF") + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: @@ -128,6 +156,105 @@ def setup_requirements() -> Tuple[List[str], List[str]]: return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]] +def _discover_nccl_home() -> str: + """Resolve NCCL_HOME: honor env var, else probe well-known prefixes, else ldconfig.""" + env_home = os.environ.get("NCCL_HOME") + if env_home: + if (Path(env_home) / "include" / "nccl.h").exists(): + return env_home + print( + f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but " + f"'{env_home}/include/nccl.h' was not found; falling back to system probes." + ) + + for cand in ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr"): + p = Path(cand) + if (p / "include" / "nccl.h").exists() and any( + (p / "lib" / name).exists() or (p / "lib64" / name).exists() + for name in ("libnccl.so", "libnccl.so.2") + ): + return str(p) + + try: + out = subprocess.check_output(["ldconfig", "-p"], stderr=subprocess.DEVNULL).decode() + for line in out.splitlines(): + if "libnccl.so" in line and "=>" in line: + lib_path = Path(line.split("=>")[-1].strip()) + root = lib_path.parent.parent + if (root / "include" / "nccl.h").exists(): + return str(root) + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + raise RuntimeError( + "Could not locate NCCL core (nccl.h + libnccl.so). Set NCCL_HOME to the install prefix." + ) + + +def build_nccl_ep_submodule() -> str: + """Build libnccl_ep.so from the 3rdparty/nccl submodule. + + NCCL EP is on by default; the system NCCL core (libnccl.so) supplies the + headers and runtime symbols. Returns the submodule build directory. + """ + nccl_root = current_file_path / "3rdparty" / "nccl" + if not (nccl_root / "Makefile").exists(): + raise RuntimeError( + f"NCCL submodule not found at {nccl_root}. " + "Run `git submodule update --init --recursive`." + ) + + build_dir = nccl_root / "build" + nccl_ep_lib = build_dir / "lib" / "libnccl_ep.so" + + archs = cuda_archs() or "90" + arch_list = [] + for a in str(archs).split(";"): + a = a.strip().rstrip("af") + if a and a.isdigit() and int(a) >= 90: + arch_list.append(a) + if not arch_list: + arch_list = ["90"] + gencode = " ".join(f"-gencode=arch=compute_{a},code=sm_{a}" for a in arch_list) + + nproc = os.cpu_count() or 8 + env = os.environ.copy() + env["NVCC_GENCODE"] = gencode + # NCCL EP needs the core NCCL headers + libnccl.so; write NCCL EP build + # outputs to the submodule's local build/ tree. + nccl_home = _discover_nccl_home() + env["NCCL_HOME"] = nccl_home + env["NCCL_EP_BUILDDIR"] = str(build_dir) + + if not nccl_ep_lib.exists(): + print(f"[NCCL EP] Building libnccl_ep.so (gencode='{gencode}')") + subprocess.check_call( + ["make", "-j", str(nproc), "-C", "contrib/nccl_ep", "lib"], + cwd=str(nccl_root), + env=env, + ) + + # TE's CMake expects nccl.h under 3rdparty/nccl/build/include/ for its + # version check. Mirror the top-level host headers from the system NCCL + # install — DON'T mirror nccl_device/ because the submodule ships its own + # newer copy at src/include/nccl_device/ with device-side templates that + # conflict with older system versions, and the JIT include path picks the + # submodule's. + nccl_include = build_dir / "include" + nccl_include.mkdir(parents=True, exist_ok=True) + for cand in (Path(nccl_home) / "include", Path("/usr/include")): + p = Path(cand) + if (p / "nccl.h").exists(): + for name in ("nccl.h", "nccl_net.h", "nccl_tuner.h"): + src = p / name + dst = nccl_include / name + if src.exists() and not dst.exists(): + dst.symlink_to(src) + break + + return str(build_dir) + + def git_check_submodules() -> None: """ Attempt to checkout git submodules automatically during setup. diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt index 0d7258a81d..3870f57911 100644 --- a/tests/cpp_distributed/CMakeLists.txt +++ b/tests/cpp_distributed/CMakeLists.txt @@ -30,7 +30,7 @@ if(NOT DEFINED TE_LIB_PATH) get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) endif() -find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED NO_CMAKE_SYSTEM_PATH) message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) @@ -46,12 +46,99 @@ add_executable(test_comm_gemm find_package(OpenMP REQUIRED) find_package(MPI REQUIRED) + +# ── NCCL library ────────────────────────────────────────────────────────────── +# Search order: NCCL_HOME env → 3rdparty/nccl submodule build → system paths. +set(NCCL_SUBMODULE_BUILD "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build") find_library(NCCL_LIB NAMES nccl libnccl - PATH_SUFFIXES lib + HINTS $ENV{NCCL_HOME}/lib ${NCCL_SUBMODULE_BUILD}/lib + PATH_SUFFIXES lib lib64 REQUIRED) + +# NCCL headers: prefer submodule build output (has the handle_init API), +# then submodule src, then system (CUDA toolkit). +set(NCCL_SUBMODULE_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build/include") +set(NCCL_SUBMODULE_SRC_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/src/include") +if(EXISTS "${NCCL_SUBMODULE_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIR "${NCCL_SUBMODULE_INCLUDE}") +elseif(EXISTS "${NCCL_SUBMODULE_SRC_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIR "${NCCL_SUBMODULE_SRC_INCLUDE}") +elseif(DEFINED ENV{NCCL_HOME}) + set(NCCL_INCLUDE_DIR "$ENV{NCCL_HOME}/include") +endif() target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) include(GoogleTest) gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) + +# ── EP distributed tests (HT mode) ───────────────────────────────────────── +# No MPI dependency — processes are spawned by run_test_ep.sh with +# --rank / --nranks flags. ncclUniqueId exchange uses a +# shared temp file (see test_ep_common.h for details). +# Headers + libs come from the in-tree 3rdparty/nccl submodule build. +set(NCCL_EP_SUBMODULE_ROOT + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") +find_library(NCCL_EP_LIB + NAMES nccl_ep libnccl_ep + HINTS ${NCCL_EP_SUBMODULE_ROOT}/build/lib + NO_DEFAULT_PATH + REQUIRED) + +set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/contrib/nccl_ep/include") +if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h") + message(FATAL_ERROR + "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. " + "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl.") +endif() +message(STATUS "EP test: NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}") + +# Collect NCCL include dirs shared by all EP test targets (nccl_ep.h + nccl.h). +set(EP_TEST_NCCL_INCLUDES ${NCCL_EP_INCLUDE_DIR}) +if(DEFINED NCCL_INCLUDE_DIR) + list(APPEND EP_TEST_NCCL_INCLUDES ${NCCL_INCLUDE_DIR}) + message(STATUS "EP test: NCCL headers: ${NCCL_INCLUDE_DIR}") +endif() + +set(EP_TEST_COMMON_INCLUDES + ${EP_TEST_NCCL_INCLUDES} + ../../transformer_engine/common/include + ../../transformer_engine/common + ${CMAKE_CURRENT_SOURCE_DIR}) + +set(EP_TEST_COMMON_LIBS + CUDA::cuda_driver + CUDA::cudart + CUDA::nvrtc + GTest::gtest + ${TE_LIB} + ${NCCL_LIB} + ${NCCL_EP_LIB}) + +# nvrtc symbols are referenced from libtransformer_engine.so but not in its +# DT_NEEDED list (loaded via dlopen in Python). For cpp tests we link nvrtc +# explicitly with --no-as-needed so the linker keeps the dependency. +set(EP_TEST_LINK_OPTS "LINKER:--no-as-needed") + +# ── EP init tests (InitPath, HandleMemSizeQuery) ───────────────────────────── +add_executable(test_ep_init test_ep_init.cu) +target_include_directories(test_ep_init PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_init PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_init PUBLIC ${EP_TEST_LINK_OPTS}) + +# ── EP pipeline tests (dispatch, combine, bwd, integrated) ─────────────────── +add_executable(test_ep_pipeline test_ep_pipeline.cu) +target_include_directories(test_ep_pipeline PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_pipeline PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_pipeline PUBLIC ${EP_TEST_LINK_OPTS}) + +# ── EP coverage tests (multi-handle, top_k=1, empty experts, negatives, threading) ── +add_executable(test_ep_coverage test_ep_coverage.cu) +target_include_directories(test_ep_coverage PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_coverage PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_coverage PUBLIC ${EP_TEST_LINK_OPTS}) + +# Do NOT use gtest_discover_tests — these binaries require multi-process +# launch via run_test_ep.sh, not direct single-process execution. +message(STATUS "EP distributed tests enabled: ${NCCL_EP_LIB}") diff --git a/tests/cpp_distributed/run_test_ep.sh b/tests/cpp_distributed/run_test_ep.sh new file mode 100755 index 0000000000..017d3f807b --- /dev/null +++ b/tests/cpp_distributed/run_test_ep.sh @@ -0,0 +1,137 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Run TE EP distributed unit tests across multiple GPUs. +# +# Spawns one background bash process per GPU (no MPI dependency), matching the +# JAX multi-process launcher style. ncclUniqueId is exchanged via a shared +# temp file (see test_ep_common.h). Each rank builds its own ncclComm_t and +# passes it to nvte_ep_initialize. +# +# Usage: +# bash run_test_ep.sh [num_gpus] [build_dir] +# +# Defaults: +# num_gpus = number of GPUs visible to nvidia-smi +# build_dir = /build +# +# Environment variables: +# GTEST_FILTER — forwarded to all processes (e.g., "EPDispatchTest.*") +# TEST_TIMEOUT_S — per-process timeout in seconds (default: 180) + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${2:-${SCRIPT_DIR}/build}" +NUM_GPUS="${1:-$(nvidia-smi -L 2>/dev/null | wc -l)}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}" + +# Skip cleanly on pre-Hopper: NCCL EP requires SM>=90. +MIN_SM=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ + | awk -F. 'NR==1 || ($1*10+$2) 0 && MIN_SM < 90 )); then + echo "NCCL EP requires SM>=90 (lowest visible GPU is SM${MIN_SM}); SKIPPING." + exit 0 +fi + +GTEST_ARGS="${GTEST_FILTER:+--gtest_filter=${GTEST_FILTER}}" +OVERALL_FAIL=0 + +# --------------------------------------------------------------------------- +# run_suite BINARY SUITE_NAME MIN_GPUS +# --------------------------------------------------------------------------- +run_suite() { + local BINARY="$1" + local SUITE_NAME="$2" + local MIN_GPUS="${3:-2}" + + local TEST_BIN="${BUILD_DIR}/${BINARY}" + + if [[ ! -x "${TEST_BIN}" ]]; then + echo "ERROR: binary not found: ${TEST_BIN}" + echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make" + OVERALL_FAIL=1 + return + fi + + if (( NUM_GPUS < MIN_GPUS )); then + echo "${SUITE_NAME}: requires ${MIN_GPUS} GPUs, found ${NUM_GPUS}. Skipping." + return + fi + + local TMPDIR_L="${TMPDIR:-/tmp}" + local UID_FILE="${TMPDIR_L}/te_ep_uid_${BINARY}_$$" + rm -f "${UID_FILE}" + + local LOG_DIR + LOG_DIR=$(mktemp -d) + local FAIL=0 + + echo "=== ${SUITE_NAME} ===" + echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}" + echo + + # Spawn one background process per GPU. ncclUniqueId is exchanged via the + # shared UID_FILE. Each process is wrapped in `timeout` to detect hangs early. + local PIDS=() + for i in $(seq 0 $((NUM_GPUS - 1))); do + timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + "${TEST_BIN}" \ + --rank="${i}" \ + --nranks="${NUM_GPUS}" \ + --uid-file="${UID_FILE}" \ + ${GTEST_ARGS} \ + > "${LOG_DIR}/rank_${i}.log" 2>&1 & + PIDS+=($!) + done + for i in $(seq 0 $((NUM_GPUS - 1))); do + if ! wait "${PIDS[$i]}"; then + local rc=$? + FAIL=1 + if [[ $rc -eq 137 || $rc -eq 124 ]]; then + echo " rank ${i}: TIMEOUT after ${TEST_TIMEOUT_S}s (rc=${rc})" + fi + fi + done + + echo "--- Rank 0 output ---" + cat "${LOG_DIR}/rank_0.log" + + if (( FAIL )); then + for i in $(seq 1 $((NUM_GPUS - 1))); do + echo "--- Rank ${i} output ---" + cat "${LOG_DIR}/rank_${i}.log" + done + echo "=== ${SUITE_NAME}: FAILED ===" + OVERALL_FAIL=1 + else + echo "=== ${SUITE_NAME}: ALL PASSED ===" + fi + + rm -rf "${LOG_DIR}" + rm -f "${UID_FILE}" +} + +# --------------------------------------------------------------------------- +# Cleanup on abort +# --------------------------------------------------------------------------- +cleanup() { rm -f "${TMPDIR:-/tmp}"/te_ep_uid_*_"$$" 2>/dev/null || true; } +trap cleanup EXIT INT TERM + +# --------------------------------------------------------------------------- +# Run all suites +# --------------------------------------------------------------------------- +run_suite "test_ep_init" "EP Init Tests" 2 +run_suite "test_ep_pipeline" "EP Pipeline Tests" 2 +run_suite "test_ep_coverage" "EP Coverage Tests" 2 + +echo +if (( OVERALL_FAIL )); then + echo "=== SOME SUITES FAILED ===" +else + echo "=== ALL SUITES PASSED ===" +fi + +exit "${OVERALL_FAIL}" diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h new file mode 100644 index 0000000000..77baa92b0c --- /dev/null +++ b/tests/cpp_distributed/test_ep_common.h @@ -0,0 +1,308 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * Shared TE EP test infrastructure. Include once per TU; ep_bootstrap() in + * each test binary's main() populates process-level globals. + * Defaults: 4 experts/rank, hidden_dim=256, max_tokens_per_rank=64. + */ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// ── Error-checking macros ───────────────────────────────────────────────────── + +#define CHECK_NCCL(expr) \ + do { \ + ncclResult_t _err = (expr); \ + if (_err != ncclSuccess) \ + FAIL() << "NCCL error " << _err << ": " << ncclGetErrorString(_err); \ + } while (false) + +#define CHECK_CUDA(expr) \ + do { \ + cudaError_t _err = (expr); \ + if (_err != cudaSuccess) \ + FAIL() << "CUDA error " << _err << ": " << cudaGetErrorString(_err); \ + } while (false) + +#define ASSERT_CUDA_OK(expr) \ + do { \ + cudaError_t _err = (expr); \ + if (_err != cudaSuccess) { \ + fprintf(stderr, "CUDA error %d: %s\n", _err, cudaGetErrorString(_err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (false) + +#define ASSERT_NCCL_OK(expr) \ + do { \ + ncclResult_t _err = (expr); \ + if (_err != ncclSuccess) { \ + fprintf(stderr, "NCCL error %d: %s\n", _err, ncclGetErrorString(_err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (false) + +// ── Process-level state ─────────────────────────────────────────────────────── + +static int g_process_id = -1; +static int g_num_processes = -1; +static std::string g_uid_file; + +static int g_sm_major = -1; // set by ep_bootstrap; -1 until then +static int g_ep_size = -1; +static int g_num_experts = -1; +static int g_hidden_dim = 256; +static int g_max_tokens_per_rank = 64; +static bool g_ep_initialized = false; +static ncclComm_t g_ep_comm = nullptr; // owned by harness, destroyed in ep_teardown + +// ── TensorHandle RAII wrapper ───────────────────────────────────────────────── + +// View over a caller-owned device buffer; owns NVTETensor metadata only. Move-only. +struct TensorHandle { + NVTETensor tensor = nullptr; + void* dev_ptr = nullptr; + + ~TensorHandle() { + if (tensor) nvte_destroy_tensor(tensor); + } + + TensorHandle() = default; + TensorHandle(const TensorHandle&) = delete; + TensorHandle& operator=(const TensorHandle&) = delete; + + TensorHandle(TensorHandle&& o) noexcept : tensor(o.tensor), dev_ptr(o.dev_ptr) { + o.tensor = nullptr; o.dev_ptr = nullptr; + } + TensorHandle& operator=(TensorHandle&& o) noexcept { + if (this != &o) { + if (tensor) nvte_destroy_tensor(tensor); + tensor = o.tensor; dev_ptr = o.dev_ptr; + o.tensor = nullptr; o.dev_ptr = nullptr; + } + return *this; + } +}; + +static TensorHandle make_nvte_tensor(void* dev_ptr, + const std::vector& shape, + NVTEDType dtype) { + TensorHandle h; + h.dev_ptr = dev_ptr; + h.tensor = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING); + + NVTEShape s; + s.ndim = shape.size(); + for (size_t i = 0; i < shape.size(); ++i) s.data[i] = shape[i]; + + NVTEBasicTensor bt; + bt.data_ptr = dev_ptr; + bt.dtype = dtype; + bt.shape = s; + nvte_set_tensor_param_v2(h.tensor, kNVTERowwiseData, &bt, sizeof(bt)); + + return h; +} + +// RAII owner for a cudaMalloc'd device buffer; frees on destruction. +template +struct DevBuf { + T* ptr = nullptr; + size_t count = 0; + + DevBuf() = default; + explicit DevBuf(size_t n) { alloc(n); } + ~DevBuf() { reset(); } + + DevBuf(const DevBuf&) = delete; + DevBuf& operator=(const DevBuf&) = delete; + DevBuf(DevBuf&& o) noexcept : ptr(o.ptr), count(o.count) { o.ptr = nullptr; o.count = 0; } + DevBuf& operator=(DevBuf&& o) noexcept { + if (this != &o) { reset(); ptr = o.ptr; count = o.count; o.ptr = nullptr; o.count = 0; } + return *this; + } + + void alloc(size_t n) { + reset(); + count = n; + if (n > 0) { + cudaError_t e = cudaMalloc(&ptr, n * sizeof(T)); + if (e != cudaSuccess) { + fprintf(stderr, "DevBuf cudaMalloc(%zu) failed: %s\n", n * sizeof(T), + cudaGetErrorString(e)); + ptr = nullptr; + count = 0; + } + } + } + + void reset() { + if (ptr) { cudaFree(ptr); ptr = nullptr; } + count = 0; + } + + T* get() const { return ptr; } + size_t bytes() const { return count * sizeof(T); } +}; + +// ── Shared routing helper ───────────────────────────────────────────────────── + +// Balanced round-robin routing: token t on rank r maps top_k experts to +// (r * num_local_experts + t * top_k + k) % num_experts +static inline std::vector routing_balanced( + int rank, int num_tokens, int top_k, int num_experts, int num_local_experts) { + std::vector idx(num_tokens * top_k); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) + idx[t * top_k + k] = (rank * num_local_experts + t * top_k + k) % num_experts; + return idx; +} + +// ── File-based ncclUniqueId exchange ───────────────────────────────────────── + +static void exchange_unique_id(ncclUniqueId* uid) { + const size_t sz = sizeof(ncclUniqueId); + + if (g_process_id == 0) { + ASSERT_NCCL_OK(ncclGetUniqueId(uid)); + FILE* f = fopen(g_uid_file.c_str(), "wb"); + if (!f) { fprintf(stderr, "Cannot open uid file: %s\n", g_uid_file.c_str()); exit(EXIT_FAILURE); } + fwrite(uid, 1, sz, f); + fclose(f); + } else { + auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(60); + while (true) { + FILE* f = fopen(g_uid_file.c_str(), "rb"); + if (f) { + fseek(f, 0, SEEK_END); + if (static_cast(ftell(f)) >= sz) { + fseek(f, 0, SEEK_SET); + size_t n = fread(uid, 1, sz, f); + fclose(f); + if (n == sz) break; + } else { + fclose(f); + } + } + if (std::chrono::steady_clock::now() > deadline) { + fprintf(stderr, "Process %d: timed out waiting for uid file\n", g_process_id); + exit(EXIT_FAILURE); + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + } +} + +// ── CLI parsing ─────────────────────────────────────────────────────────────── + +static void ep_parse_args(int argc, char* argv[]) { + for (int i = 1; i < argc; ++i) { + std::string a(argv[i]); + if (a.rfind("--process-id=", 0) == 0) g_process_id = std::stoi(a.substr(13)); + else if (a.rfind("--rank=", 0) == 0) g_process_id = std::stoi(a.substr(7)); + else if (a.rfind("--num-processes=",0)==0) g_num_processes = std::stoi(a.substr(16)); + else if (a.rfind("--nranks=", 0) == 0) g_num_processes = std::stoi(a.substr(9)); + else if (a.rfind("--uid-file=", 0) == 0) g_uid_file = a.substr(11); + } + + if (g_process_id < 0 || g_num_processes <= 0) { + fprintf(stderr, + "Usage: %s --rank=N --nranks=N [--uid-file=path] [gtest flags]\n" + " Aliases: --process-id=N, --num-processes=N\n", + argc > 0 ? argv[0] : "test_ep"); + exit(EXIT_FAILURE); + } + + if (g_uid_file.empty()) { + const char* t = getenv("TMPDIR"); if (!t) t = "/tmp"; + g_uid_file = std::string(t) + "/te_ep_uid_" + std::to_string(g_process_id); + } +} + +// ── Bootstrap / teardown ────────────────────────────────────────────────────── + +// Returns false if the binary should exit without running tests (wrong SM, etc.). +static bool ep_bootstrap(int argc, char* argv[]) { + ep_parse_args(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + + int device_count; + cudaGetDeviceCount(&device_count); + cudaSetDevice(g_process_id % device_count); + + int device, major; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + g_sm_major = major; + if (major < 9) { + if (g_process_id == 0) + printf("SKIP: EP requires SM_90+ (device is SM_%d0)\n", major); + return false; + } + if (g_num_processes < 2) { + if (g_process_id == 0) + printf("SKIP: at least 2 processes required\n"); + return false; + } + + g_ep_size = g_num_processes; + g_num_experts = g_ep_size * 4; // 4 experts per rank + + ncclUniqueId uid{}; + exchange_unique_id(&uid); + + NVTEEpGroupConfig group_config{}; + group_config.ep_size = g_ep_size; + group_config.num_experts = g_num_experts; + group_config.max_tokens_per_rank = g_max_tokens_per_rank; + // Worst-case for top_k fan-out: ep_size * max_tokens_per_rank * 2. + group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; + group_config.hidden_dim = g_hidden_dim; + + ASSERT_NCCL_OK(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); + nvte_ep_initialize(static_cast(g_ep_comm), group_config); + + if (g_process_id == 0) { + printf("EP initialized: ep_size=%d num_experts=%d " + "hidden_dim=%d max_tokens_per_rank=%d\n", + g_ep_size, g_num_experts, g_hidden_dim, g_max_tokens_per_rank); + } + + g_ep_initialized = true; + return true; +} + +// Tear down in dependency order: backend's ep_group reads from ep_comm, +// so destroy the group first, then the comm. +static void ep_teardown() { + if (g_ep_initialized) { + nvte_ep_shutdown(); + if (g_ep_comm != nullptr) { + ncclCommDestroy(g_ep_comm); + g_ep_comm = nullptr; + } + g_ep_initialized = false; + } + if (g_process_id == 0) remove(g_uid_file.c_str()); +} diff --git a/tests/cpp_distributed/test_ep_coverage.cu b/tests/cpp_distributed/test_ep_coverage.cu new file mode 100644 index 0000000000..ef7941905d --- /dev/null +++ b/tests/cpp_distributed/test_ep_coverage.cu @@ -0,0 +1,379 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * EP C-API coverage tests (paths not exercised by the pipeline suite). + * + * MultiHandleAllocTest — distinct handle ids; each works end-to-end. + * TopK1Test — top_k=1 dispatch/combine/bwd round-trip. + * EmptyExpertsTest — alignment ∈ {0, 2, 8, 16} with experts receiving 0 tokens. + * NegativeTests — alignment mismatch and null handle_mem must throw. + */ + +#include "test_ep_common.h" + +#include +#include + +// top1 -> expert 0, top2 -> expert 2; leaves local-expert 1 empty between two +// full experts. Requires top_k >= 2 and num_experts >= 3. +static std::vector routing_skip_middle(int num_tokens, int top_k) { + std::vector idx(num_tokens * top_k); + for (int t = 0; t < num_tokens; ++t) { + idx[t * top_k + 0] = 0; + if (top_k >= 2) idx[t * top_k + 1] = 2; + for (int k = 2; k < top_k; ++k) idx[t * top_k + k] = 2 + k; // distinct stragglers + } + return idx; +} + +static std::vector tokens_constant(int num_tokens, int hidden_dim, float val) { + std::vector v(num_tokens * hidden_dim); + nv_bfloat16 b = __float2bfloat16(val); + std::fill(v.begin(), v.end(), b); + return v; +} + +namespace { + +class EpCoverageBase : public ::testing::Test { + protected: + int ep_size_, num_experts_, num_local_experts_, hidden_dim_; + int max_tokens_per_rank_; + + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2); + ASSERT_TRUE(g_ep_initialized); + ep_size_ = g_ep_size; + num_experts_ = g_num_experts; + num_local_experts_ = num_experts_ / ep_size_; + hidden_dim_ = g_hidden_dim; + max_tokens_per_rank_ = g_max_tokens_per_rank; + } + + // Helper: allocate buffers + tensor views for a single dispatch+combine. + struct Bundle { + DevBuf topk_idx; + DevBuf topk_weights; + DevBuf tokens; + DevBuf token_counts; + DevBuf handle_mem; + DevBuf recv_tokens; + DevBuf recv_topk_weights; + DevBuf result; + uint64_t handle_id = 0; + size_t handle_mem_size = 0; + size_t recv_capacity = 0; + }; + + Bundle make_bundle(int num_tokens, int top_k, int num_local_experts, + size_t alignment) { + Bundle b; + b.recv_capacity = static_cast(ep_size_) * max_tokens_per_rank_ * 2; + b.topk_idx.alloc(num_tokens * top_k); + b.topk_weights.alloc(num_tokens * top_k); + b.tokens.alloc(num_tokens * hidden_dim_); + b.token_counts.alloc(num_local_experts); + b.recv_tokens.alloc(b.recv_capacity * hidden_dim_); + b.recv_topk_weights.alloc(b.recv_capacity); + b.result.alloc(num_tokens * hidden_dim_); + NVTEEpLayerConfig cfg{num_local_experts, top_k, alignment}; + b.handle_id = nvte_ep_register_layer(cfg, &b.handle_mem_size); + b.handle_mem.alloc(b.handle_mem_size); + return b; + } +}; + +} // namespace + +// ============================================================================= +// MultiHandleAllocTest: ids are distinct and each is independently usable. +// ============================================================================= + +class MultiHandleAllocTest : public EpCoverageBase {}; + +TEST_F(MultiHandleAllocTest, IdsAreDistinct) { + NVTEEpLayerConfig cfg{num_local_experts_, /*top_k=*/2, /*alignment=*/0}; + const int kN = 8; + std::vector ids(kN); + for (int i = 0; i < kN; ++i) { + size_t sz = 0; + ids[i] = nvte_ep_register_layer(cfg, &sz); + } + for (int i = 0; i < kN; ++i) { + EXPECT_NE(ids[i], 0u) << "handle_id 0 is reserved as \"no id\""; + for (int j = i + 1; j < kN; ++j) + EXPECT_NE(ids[i], ids[j]) << "duplicate id " << ids[i] << " at indices " << i << ", " << j; + } +} + +TEST_F(MultiHandleAllocTest, TwoHandlesCoexist) { + const int num_tokens = 16, top_k = 2; + Bundle a = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); + for (Bundle* x : {&a, &b}) { + CHECK_CUDA(cudaMemcpy(x->topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(x->topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(x->tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + } + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NE(a.handle_id, b.handle_id); + + auto run_one = [&](Bundle& x) { + auto topk_idx = make_nvte_tensor(x.topk_idx.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights = make_nvte_tensor(x.topk_weights.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts = make_nvte_tensor(x.token_counts.get(), {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem = make_nvte_tensor(x.handle_mem.get(), {x.handle_mem_size}, kNVTEByte); + auto tokens = make_nvte_tensor(x.tokens.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens = make_nvte_tensor(x.recv_tokens.get(), {x.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w = make_nvte_tensor(x.recv_topk_weights.get(), {x.recv_capacity}, kNVTEFloat32); + auto result = make_nvte_tensor(x.result.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + NVTEEpHandle h{x.handle_id, handle_mem.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx.tensor, token_counts.tensor, + /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx.tensor, tokens.tensor, + NVTECommWindow{}, topk_weights.tensor, NVTECommWindow{}, + recv_tokens.tensor, NVTECommWindow{}, recv_w.tensor, + NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens.tensor, NVTECommWindow{}, + result.tensor, stream)); + }; + run_one(a); + run_one(b); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // Both round-trips must produce result == top_k * 0.5 = 1.0. + for (Bundle* x : {&a, &b}) { + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), x->result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), + static_cast(top_k) * 0.5f, 1e-2f); + } + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// TopK1Test: top_k=1 dispatch/combine round-trip, including dispatch_bwd. +// ============================================================================= + +class TopK1Test : public EpCoverageBase {}; + +TEST_F(TopK1Test, RoundTrip) { + const int num_tokens = 16, top_k = 1; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f); // top_k=1: weight is unity + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.25f); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + auto tokens_t = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + auto result_t = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, + tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, + NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, + recv_w_t.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, + NVTECommWindow{}, result_t.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // top_k=1: combine is unweighted gather, so result[t] == tokens[t]. + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), 0.25f, 1e-2f) + << "tok " << t << " hidden " << p; + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EmptyExpertsTest: alignment ∈ {0, 2, 8, 16}, only local-expert 0 receives +// tokens. Round-trip must produce result == top_k * tokens regardless of the +// per-expert padding choice. +// ============================================================================= + +class EmptyExpertsTest : public EpCoverageBase, + public ::testing::WithParamInterface {}; + +TEST_P(EmptyExpertsTest, RoundTripCorrect) { + // routing_skip_middle needs experts {0, 2, ...}; smallest viable num_experts is 3. + ASSERT_GE(num_experts_, 3); + const size_t alignment = GetParam(); + const int num_tokens = 16, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, alignment); + + // top1 -> expert 0, top2 -> expert 2; rank 0's local-expert 1 receives 0 + // tokens between two non-empty experts. + std::vector h_idx = routing_skip_middle(num_tokens, top_k); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.3f); + + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + auto tokens_t = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + auto result_t = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + alignment, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, + tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, + NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, + recv_w_t.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, + NVTECommWindow{}, result_t.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // Identity expert + uniform weights: result[t] == top_k * tokens[t]. + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float expected = static_cast(top_k) * 0.3f; + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), expected, 1e-2f) + << "alignment=" << alignment << " tok=" << t << " hidden=" << p; + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +INSTANTIATE_TEST_SUITE_P(Alignments, EmptyExpertsTest, + ::testing::Values(0, 2, 8, 16)); + +// ============================================================================= +// NegativeTests: prepare/dispatch must surface bad inputs as exceptions. +// ============================================================================= + +class NegativeTests : public EpCoverageBase {}; + +TEST_F(NegativeTests, AlignmentMismatchThrows) { + const int num_tokens = 8, top_k = 2; + // Allocate handle for alignment=0, then call prepare with alignment=16. + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/16, stream), + std::exception); + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +TEST_F(NegativeTests, NullHandleMemThrows) { + const int num_tokens = 8, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + // Construct a tensor view backed by a null device pointer. + auto null_hm_t = make_nvte_tensor(nullptr, {b.handle_mem_size}, kNVTEByte); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + NVTEEpHandle h{b.handle_id, null_hm_t.tensor}; + EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream), + std::exception); + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/tests/cpp_distributed/test_ep_init.cu b/tests/cpp_distributed/test_ep_init.cu new file mode 100644 index 0000000000..08744dfee5 --- /dev/null +++ b/tests/cpp_distributed/test_ep_init.cu @@ -0,0 +1,64 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * Unit tests for EP initialization paths. + * + * Tests: + * EPInitTest/InitPath — backend is live after init, handle_mem_size > 0 + * EPInitTest/NumLocalExperts — handle_mem_size is consistent across num_local_experts values + * + * Run via run_test_ep.sh (both uid and comm init paths are tested by the script). + */ + +#include "test_ep_common.h" + +// ── Fixture ─────────────────────────────────────────────────────────────────── + +class EPInitTest : public ::testing::Test { + protected: + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2) << "EP tests require at least 2 processes"; + ASSERT_TRUE(g_ep_initialized) << "EP not initialized"; + } +}; + +// ── Tests ───────────────────────────────────────────────────────────────────── + +TEST_F(EPInitTest, InitPath) { + int nle = g_num_experts / g_ep_size; + NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; + size_t sz = 0; + (void)nvte_ep_register_layer(cfg, &sz); + ASSERT_GT(sz, 0u) << "handle_mem_size must be > 0 after init"; + + if (g_process_id == 0) { + printf(" handle_mem : %zu bytes\n", sz); + } +} + +TEST_F(EPInitTest, NumLocalExperts) { + // handle_mem_size should be > 0 for any valid num_local_experts value. + for (int nle : {1, g_num_experts / g_ep_size}) { + NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; + size_t sz = 0; + (void)nvte_ep_register_layer(cfg, &sz); + ASSERT_GT(sz, 0u) << "num_local_experts=" << nle; + if (g_process_id == 0) + printf(" nle=%-3d handle_mem_size=%zu bytes\n", nle, sz); + } +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/tests/cpp_distributed/test_ep_pipeline.cu b/tests/cpp_distributed/test_ep_pipeline.cu new file mode 100644 index 0000000000..41f83a6d11 --- /dev/null +++ b/tests/cpp_distributed/test_ep_pipeline.cu @@ -0,0 +1,890 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * EP pipeline tests: smallest-scope first. + * + * EPDispatchTest/PrepareAndDispatch — exact recv values + per-expert counts + * EPCombineTest/Combine — round-trip: out == top_k * tokens + * EPCombineBwdTest/CombineBwdCheck — exact grad_expert values + * EPDispatchBwdTest/DispatchBwdCheck — exact grad_tokens + * EPDispatchBwdGradWeightsTest/RoundTrip — exact per-(t, k) grad_topk_weights + * EPPipelineTest/FullForwardBackward — fwd + bwd NaN/Inf check + * + * Routing: token t on rank r → expert (r * num_local_experts + t * top_k + k) % num_experts + * Token values: rank r, token t → all hidden dims = (r+1)*0.01 + t*0.001 + * + * Closed-form expected values: + * dispatch recv: multiset of source-token values routed to this rank's experts + * combine: result[t] == top_k * tokens[t] + * combine_bwd: grad_expert[slot] == d_result[t] (no weighting) + * dispatch_bwd: grad_tokens[t] == top_k * d_result[t] + */ + +#include "test_ep_common.h" + +#include +#include +#include +#include + +// ── Deterministic routing helpers ───────────────────────────────────────────── + +// Token value for (rank, t): (rank * num_tokens + t + 1) / 256. Step 1/256 is +// bf16-exact and unique across (rank, t) when rank * num_tokens + t < 256. +static inline float token_value(int rank, int t, int num_tokens) { + return static_cast(rank * num_tokens + t + 1) * (1.0f / 256.0f); +} + +static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) { + std::vector v(num_tokens * hidden_dim); + for (int t = 0; t < num_tokens; ++t) { + nv_bfloat16 val = __float2bfloat16(token_value(rank, t, num_tokens)); + for (int h = 0; h < hidden_dim; ++h) + v[t * hidden_dim + h] = val; + } + return v; +} + +static std::vector expected_token_counts( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector cnt(num_local_experts, 0); + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) ++cnt[e - base]; + } + } + return cnt; +} + +static std::vector expected_recv_values_sorted( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector vals; + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) { + float raw = token_value(src, t, num_tokens); + vals.push_back(__bfloat162float(__float2bfloat16(raw))); + } + } + } + std::sort(vals.begin(), vals.end()); + return vals; +} + +// BF16 has 7 mantissa bits; relative ULP ≈ 2^-7. Use 4× headroom for +// accumulation noise inside dispatch/combine. +static float bf16_tol(float magnitude) { + return 4.f * std::ldexp(std::fabs(magnitude) + 1e-3f, -7); +} + +static bool check_no_nan_inf(const nv_bfloat16* dev, int count, const char* name) { + std::vector h(count); + cudaMemcpy(h.data(), dev, count * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost); + for (int i = 0; i < count; ++i) { + float v = __bfloat162float(h[i]); + if (std::isnan(v) || std::isinf(v)) { + fprintf(stderr, "Rank %d: %s in %s[%d]\n", + g_process_id, std::isnan(v) ? "NaN" : "Inf", name, i); + return false; + } + } + return true; +} + +// ── Forward buffer set with RAII ────────────────────────────────────────────── + +struct EPBuffers { + // Forward + DevBuf topk_idx; + DevBuf topk_weights; + DevBuf tokens; + DevBuf token_counts; + DevBuf handle_mem; + DevBuf recv_tokens; + DevBuf recv_topk_weights; + DevBuf result; + // Backward + DevBuf grad_result; + DevBuf grad_expert; + DevBuf grad_tokens; + DevBuf g_recv_topk_weights; + DevBuf grad_topk_weights; + + uint64_t handle_id = 0; + size_t handle_mem_size = 0; + size_t recv_capacity = 0; + int top_k_ = 0; + + void alloc(int num_tokens, int top_k, int hidden_dim, int num_local_experts, + int ep_size, int max_tokens_per_rank, size_t alignment = 0) { + top_k_ = top_k; + recv_capacity = static_cast(ep_size) * max_tokens_per_rank * 2; + + topk_idx.alloc(num_tokens * top_k); + topk_weights.alloc(num_tokens * top_k); + tokens.alloc(num_tokens * hidden_dim); + token_counts.alloc(num_local_experts); + recv_tokens.alloc(recv_capacity * hidden_dim); + recv_topk_weights.alloc(recv_capacity); + result.alloc(num_tokens * hidden_dim); + + NVTEEpLayerConfig cfg{num_local_experts, top_k, alignment}; + handle_id = nvte_ep_register_layer(cfg, &handle_mem_size); + handle_mem.alloc(handle_mem_size); + + grad_result.alloc(num_tokens * hidden_dim); + grad_expert.alloc(recv_capacity * hidden_dim); + grad_tokens.alloc(num_tokens * hidden_dim); + g_recv_topk_weights.alloc(recv_capacity); + grad_topk_weights.alloc(num_tokens * top_k); + } +}; + +// Bundled NVTETensor views over an EPBuffers — one place to update the shape +// conventions when the C-API evolves. +struct EPTensors { + TensorHandle topk_idx, topk_weights, token_counts, handle_mem, tokens; + TensorHandle recv_tokens, recv_topk_weights, result; + TensorHandle grad_result, grad_expert, grad_tokens; + TensorHandle g_recv_topk_weights, grad_topk_weights; + + EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, + int num_local_experts) { + topk_idx = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + topk_weights = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + token_counts = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts}, kNVTEInt32); + handle_mem = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + tokens = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + recv_tokens = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); + recv_topk_weights = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + result = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + grad_result = make_nvte_tensor(b.grad_result.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + grad_expert = make_nvte_tensor(b.grad_expert.get(), + {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); + grad_tokens = make_nvte_tensor(b.grad_tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + g_recv_topk_weights = make_nvte_tensor(b.g_recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + grad_topk_weights = make_nvte_tensor(b.grad_topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + } +}; + +// ── Shared fixture base ─────────────────────────────────────────────────────── + +class EpOpTestBase : public ::testing::Test { + protected: + int ep_size_, num_experts_, num_local_experts_, hidden_dim_; + int max_tokens_per_rank_, top_k_, num_tokens_; + + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2); + ASSERT_TRUE(g_ep_initialized); + + ep_size_ = g_ep_size; + num_experts_ = g_num_experts; + num_local_experts_ = num_experts_ / ep_size_; + hidden_dim_ = g_hidden_dim; + max_tokens_per_rank_ = g_max_tokens_per_rank; + top_k_ = 2; + num_tokens_ = 32; + } + + void upload_inputs(EPBuffers& buf, int rank = -1) { + if (rank < 0) rank = g_process_id; + auto h_idx = routing_balanced(rank, num_tokens_, top_k_, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens_ * top_k_, 1.0f / top_k_); + auto h_tok = generate_tokens(rank, num_tokens_, hidden_dim_); + + CHECK_CUDA(cudaMemcpy(buf.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(buf.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + } + + NVTEEpLayerConfig layer_config(size_t alignment = 0) const { + return NVTEEpLayerConfig{num_local_experts_, top_k_, alignment}; + } + + // ASSERT_CUDA_OK (fprintf+exit) so this non-void helper stays legal. + int read_total_recv(const EPBuffers& buf) const { + std::vector cnt(num_local_experts_); + ASSERT_CUDA_OK(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + int total = 0; + for (int c : cnt) total += c; + return total; + } +}; + +// ============================================================================= +// EPDispatchTest: exact recv values and per-expert counts. +// ============================================================================= + +class EPDispatchTest : public EpOpTestBase {}; + +TEST_F(EPDispatchTest, PrepareAndDispatch) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + CHECK_CUDA(cudaMemset(buf.recv_tokens.get(), 0, buf.recv_tokens.bytes())); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // 1. Per-expert counts. + std::vector got_counts(num_local_experts_); + CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_, + num_experts_, num_local_experts_); + int total_recv = 0; + for (int i = 0; i < num_local_experts_; ++i) { + EXPECT_EQ(got_counts[i], exp_counts[i]) << "local expert " << i; + total_recv += exp_counts[i]; + } + ASSERT_LE(total_recv, static_cast(buf.recv_capacity)) + << "total_recv exceeded recv_capacity — overflow would corrupt downstream memory"; + + // 2. Recv values: read only the filled prefix per local-expert zone, not the + // whole recv buffer — avoids false positives from legitimate-zero token values. + std::vector h_recv(buf.recv_capacity * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_recv.data(), buf.recv_tokens.get(), + h_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + std::vector got_vals; + got_vals.reserve(total_recv); + size_t slot = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < got_counts[e]; ++i) { + got_vals.push_back(__bfloat162float(h_recv[slot * hidden_dim_])); + ++slot; + } + } + std::sort(got_vals.begin(), got_vals.end()); + + auto exp_vals = expected_recv_values_sorted(g_process_id, g_num_processes, num_tokens_, + top_k_, num_experts_, num_local_experts_); + + ASSERT_EQ(got_vals.size(), exp_vals.size()); + for (size_t i = 0; i < exp_vals.size(); ++i) + EXPECT_NEAR(got_vals[i], exp_vals[i], bf16_tol(exp_vals[i])) + << "recv value mismatch at sorted index " << i; + + // 3. recv_topk_weights: every filled slot must equal the per-token weight (1/top_k). + std::vector h_w(buf.recv_capacity); + CHECK_CUDA(cudaMemcpy(h_w.data(), buf.recv_topk_weights.get(), + h_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + const float exp_w = 1.0f / static_cast(top_k_); + for (int i = 0; i < total_recv; ++i) + EXPECT_NEAR(h_w[i], exp_w, 1e-6f) << "recv_topk_weights[" << i << "]"; + + if (g_process_id == 0) + printf(" PrepareAndDispatch: passed (recv=%d, values + weights exact)\n", total_recv); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineTest: round-trip identity expert → result == top_k * tokens. +// ============================================================================= + +class EPCombineTest : public EpOpTestBase {}; + +TEST_F(EPCombineTest, Combine) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(), + h_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + // Spot-check 3 hidden-dim positions per token to catch partial-row writes. + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int tok = 0; tok < num_tokens_; ++tok) { + float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); + for (int p : probes) { + float got = __bfloat162float(h_result[tok * hidden_dim_ + p]); + EXPECT_NEAR(got, exp, bf16_tol(exp)) + << "token " << tok << " rank " << g_process_id << " hidden " << p; + } + } + + if (g_process_id == 0) + printf(" Combine: passed (result == top_k * tokens)\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineBwdTest: filled slots in grad_expert == d_result (unweighted). +// ============================================================================= + +class EPCombineBwdTest : public EpOpTestBase {}; + +TEST_F(EPCombineBwdTest, CombineBwdCheck) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad_r(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad_r.data(), + h_grad_r.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + int total_recv = read_total_recv(buf); + + std::vector cnt(num_local_experts_); + CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + std::vector h_ge(buf.recv_capacity * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), + h_ge.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Walk filled slots by per-expert zone (no v != 0 heuristic). + const float kExpGrad = 0.1f; + size_t slot = 0; + int filled = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < cnt[e]; ++i) { + float v = __bfloat162float(h_ge[slot * hidden_dim_]); + EXPECT_NEAR(v, kExpGrad, bf16_tol(kExpGrad)) + << "grad_expert expert " << e << " slot " << i << " (linear " << slot << ")"; + ++filled; ++slot; + } + } + EXPECT_EQ(filled, total_recv); + + if (g_process_id == 0) + printf(" CombineBwdCheck: passed (filled=%d)\n", filled); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdTest: grad_tokens == top_k * d_result. +// ============================================================================= + +class EPDispatchBwdTest : public EpOpTestBase {}; + +TEST_F(EPDispatchBwdTest, DispatchBwdCheck) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_gt(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float kExpGrad = static_cast(top_k_) * 0.1f; + for (int tok = 0; tok < num_tokens_; ++tok) + EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) + << "grad_tokens token " << tok; + + if (g_process_id == 0) + printf(" DispatchBwdCheck: passed (grad_tokens == %.2f)\n", kExpGrad); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdGradWeightsTest: round-trip per-(t, k) weights. +// ============================================================================= + +class EPDispatchBwdGradWeightsTest : public EpOpTestBase {}; + +TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + // Distinct per-(rank, t, k) weights so each slot carries a unique value. + std::vector h_w(num_tokens_ * top_k_); + for (int tok = 0; tok < num_tokens_; ++tok) + for (int k = 0; k < top_k_; ++k) + h_w[tok * top_k_ + k] = 0.1f + 0.01f * tok + 0.001f * k + + 0.0001f * (g_process_id + 1); + CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, + buf.recv_topk_weights.bytes(), stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + + // Sentinel: NaN so any (t, k) the bwd kernel fails to write is immediately visible. + std::vector h_nan(num_tokens_ * top_k_, + std::numeric_limits::quiet_NaN()); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_topk_weights.get(), h_nan.data(), + h_nan.size() * sizeof(float), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + // g_recv_topk_weights := recv_topk_weights (the round-trip input). + auto g_recv_t = make_nvte_tensor(buf.recv_topk_weights.get(), + {buf.recv_capacity}, kNVTEFloat32); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, + NVTECommWindow{}, g_recv_t.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_grad_w(num_tokens_ * top_k_); + CHECK_CUDA(cudaMemcpy(h_grad_w.data(), buf.grad_topk_weights.get(), + h_grad_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + + const float kTol = 1e-5f; + int errs = 0, k0_eq_k1 = 0; + for (int tok = 0; tok < num_tokens_; ++tok) { + for (int k = 0; k < top_k_; ++k) { + float got = h_grad_w[tok * top_k_ + k]; + float exp = h_w[tok * top_k_ + k]; + if (std::isnan(got) || std::fabs(got - exp) > kTol) { + if (errs < 8) + fprintf(stderr, "Rank %d: grad_topk_weights[%d, %d]: got %.6f, expected %.6f\n", + g_process_id, tok, k, got, exp); + ++errs; + } + } + if (top_k_ >= 2 && + std::fabs(h_grad_w[tok * top_k_ + 0] - h_grad_w[tok * top_k_ + 1]) < 1e-7f) + ++k0_eq_k1; + } + EXPECT_EQ(errs, 0); + EXPECT_EQ(k0_eq_k1, 0) << "per-token-average regression: grad[t, 0] == grad[t, 1]"; + + if (g_process_id == 0 && errs == 0 && k0_eq_k1 == 0) + printf(" RoundTrip: passed (%d (t, k) gradients)\n", num_tokens_ * top_k_); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// Integrated FwdBwd: NaN/Inf check end-to-end. +// ============================================================================= + +class EPPipelineTest : public EpOpTestBase {}; + +TEST_F(EPPipelineTest, FullForwardBackward) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + if (g_process_id == 0) printf(" FullForwardBackward: passed\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPZeroCopyTest: dispatch/combine with NCCL symmetric-memory windows attached +// to payload tensors (zero-copy fast path via ncclEpTensorCreateFromWindow). +// Symm-mem requirements per spec: input&output of Dispatch, input of Combine, +// input&output of Combine bwd, input of Dispatch bwd. +// ============================================================================= + +namespace { + +// Caller-owned ncclMemAlloc'd buffer with a registered symmetric window. +// Frees in destructor (deregister + ncclMemFree). Non-copyable, move-only. +struct SymmBuf { + void* ptr = nullptr; + size_t bytes = 0; + ncclWindow_t win = nullptr; + + SymmBuf() = default; + SymmBuf(const SymmBuf&) = delete; + SymmBuf& operator=(const SymmBuf&) = delete; + SymmBuf(SymmBuf&& o) noexcept : ptr(o.ptr), bytes(o.bytes), win(o.win) { + o.ptr = nullptr; o.win = nullptr; o.bytes = 0; + } + ~SymmBuf() { + if (win) ncclCommWindowDeregister(g_ep_comm, win); + if (ptr) ncclMemFree(ptr); + } + + void alloc(size_t n_bytes) { + bytes = n_bytes; + ASSERT_NCCL_OK(ncclMemAlloc(&ptr, bytes)); + CHECK_CUDA(cudaMemset(ptr, 0, bytes)); + ASSERT_NCCL_OK(ncclCommWindowRegister(g_ep_comm, ptr, bytes, &win, + NCCL_WIN_COLL_SYMMETRIC)); + } +}; + +// Build an NVTECommWindow descriptor pointing at a SymmBuf's window (offset 0). +static inline NVTECommWindow symm_window(const SymmBuf& b) { + return NVTECommWindow{b.win, /*offset=*/0}; +} + +} // namespace + +class EPZeroCopyTest : public EpOpTestBase {}; + +// Identity round-trip with symm-mem on dispatch i/o + combine input. Bit-exact +// vs HBM reference (same routing, same input). +TEST_F(EPZeroCopyTest, IdentityAllSymm) { + // HBM reference run. + EPBuffers ref_buf; + ref_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(ref_buf); + EPTensors ref_t(ref_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t ref_hid = ref_buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, ref_t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, + ref_t.tokens.tensor, NVTECommWindow{}, ref_t.topk_weights.tensor, + NVTECommWindow{}, ref_t.recv_tokens.tensor, NVTECommWindow{}, + ref_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.recv_tokens.tensor, NVTECommWindow{}, + ref_t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector ref_recv(ref_buf.recv_capacity * hidden_dim_); + std::vector ref_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(ref_recv.data(), ref_buf.recv_tokens.get(), + ref_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(ref_result.data(), ref_buf.result.get(), + ref_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Symm-mem run: tokens, recv_tokens, combine_input (== recv_tokens) all symm. + EPBuffers sym_buf; // alloc all buffers except the symm ones. + sym_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(sym_buf); + + SymmBuf sym_tokens, sym_recv; + sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(sym_buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + + // Stage same tokens into the symm-mem input. + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors sym_t(sym_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + // Replace the tokens/recv_tokens views with ones pointing at the symm buffers. + sym_t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + sym_t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {sym_buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + uint64_t sym_hid = sym_buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, sym_t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, + sym_t.tokens.tensor, symm_window(sym_tokens), + sym_t.topk_weights.tensor, NVTECommWindow{}, + sym_t.recv_tokens.tensor, symm_window(sym_recv), + sym_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.recv_tokens.tensor, + symm_window(sym_recv), sym_t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector sym_recv_host(sym_buf.recv_capacity * hidden_dim_); + std::vector sym_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_recv_host.data(), sym_recv.ptr, + sym_recv_host.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(sym_result.data(), sym_buf.result.get(), + sym_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Compare per filled recv slot (HBM ref vs symm) and full result. + int total_recv = read_total_recv(sym_buf); + for (int i = 0; i < total_recv * hidden_dim_; ++i) + ASSERT_EQ(__bfloat162float(sym_recv_host[i]), __bfloat162float(ref_recv[i])) + << "recv mismatch at " << i; + for (size_t i = 0; i < sym_result.size(); ++i) + ASSERT_EQ(__bfloat162float(sym_result[i]), __bfloat162float(ref_result[i])) + << "result mismatch at " << i; + + if (g_process_id == 0) + printf(" IdentityAllSymm: passed (recv_slots=%d, bit-exact vs HBM)\n", total_recv); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// Same buffers, 2 iterations — catches window-lifecycle regressions where the +// symm-mem registration goes stale between calls. +TEST_F(EPZeroCopyTest, IdentityAllSymmRepeated) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + + SymmBuf sym_tokens, sym_recv; + sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + for (int iter = 0; iter < 2; ++iter) { + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, symm_window(sym_tokens), + t.topk_weights.tensor, NVTECommWindow{}, + t.recv_tokens.tensor, symm_window(sym_recv), + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, + symm_window(sym_recv), t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_res(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), buf.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + for (int tok = 0; tok < num_tokens_; ++tok) { + float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); + float got = __bfloat162float(h_res[tok * hidden_dim_]); + ASSERT_NEAR(got, exp, bf16_tol(exp)) << "iter " << iter << " tok " << tok; + } + } + + if (g_process_id == 0) + printf(" IdentityAllSymmRepeated: passed (2 iters)\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// Full forward+backward with symm-mem on every spec-mandated buffer: +// dispatch i/o, combine input, combine_bwd i/o, dispatch_bwd input. +// TODO: flaky on rank 0 (grad_tokens partial-zero) when run after the prior +// EPZeroCopyTest cases in the same binary; passes in isolation. Re-enable once +// the root cause (likely NCCL EP NVLS write→read coherence on grad_expert) is +// understood. Tracked separately. +TEST_F(EPZeroCopyTest, DISABLED_FullPipelineSymm) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + + // Symm-mem: tokens (dispatch input), recv_tokens (dispatch output AND + // combine input), grad_result (combine_bwd input), grad_expert + // (combine_bwd output AND dispatch_bwd input). + SymmBuf sym_tokens, sym_recv, sym_grad_result, sym_grad_expert; + sym_tokens .alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + sym_grad_result.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_grad_expert.alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + t.grad_result = make_nvte_tensor(sym_grad_result.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.grad_expert = make_nvte_tensor(sym_grad_expert.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, symm_window(sym_tokens), + t.topk_weights.tensor, NVTECommWindow{}, + t.recv_tokens.tensor, symm_window(sym_recv), + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, + symm_window(sym_recv), t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(sym_grad_result.ptr, h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(sym_grad_expert.ptr, 0, sym_grad_expert.bytes, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, + symm_window(sym_grad_result), t.grad_expert.tensor, + symm_window(sym_grad_expert), stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, + symm_window(sym_grad_expert), + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + std::vector h_gt(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float kExpGrad = static_cast(top_k_) * 0.1f; + for (int tok = 0; tok < num_tokens_; ++tok) + EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) + << "grad_tokens token " << tok; + + if (g_process_id == 0) printf(" FullPipelineSymm: passed\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 030023d949..c5f8dfb1ab 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -379,6 +379,96 @@ if (NVTE_WITH_CUSOLVERMP) message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}") endif() +# ── NCCL EP (on by default, HT mode only) ───────────────────────────────── +# Set -DNVTE_WITH_NCCL_EP=OFF (or NVTE_BUILD_WITH_NCCL_EP=0 in setup.py) to +# skip NCCL EP entirely — useful on older images whose system NCCL is below +# the 2.30.4 EP minimum. +option(NVTE_WITH_NCCL_EP "Build NCCL EP into libtransformer_engine.so" ON) +if(NVTE_WITH_NCCL_EP) +# SM>=90 and NCCL>=2.30.4 are gated at runtime in EPBackend::initialize. +# ── NCCL EP headers ──────────────────────────────────────────────────────── +# Headers + libs are produced by the in-tree 3rdparty/nccl submodule build +# (auto-built by setup.py via build_nccl_ep_submodule). +set(NCCL_EP_SUBMODULE_ROOT + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") +set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/contrib/nccl_ep/include") +if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h") + message(FATAL_ERROR + "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. " + "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl.") +endif() +message(STATUS "NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}") + +# ── libnccl_ep.so ────────────────────────────────────────────────────────── +set(NCCL_EP_LIB_DIR "${NCCL_EP_SUBMODULE_ROOT}/build/lib") +find_library(NCCL_EP_LIB + NAMES nccl_ep libnccl_ep + HINTS ${NCCL_EP_LIB_DIR} + NO_DEFAULT_PATH + REQUIRED) + +# ── NCCL + GIN headers ───────────────────────────────────────────────────── +# libnccl.so and all GIN headers (ncclGin.h, ncclWindow_t, ncclDevComm_t) +# ship with the base CUDA Toolkit OR the 3rdparty/nccl submodule build +# (preferred when present; auto-built by setup.py via build_nccl_ep_submodule). +if(NOT NCCL_LIB) + find_library(NCCL_LIB + NAMES nccl libnccl + HINTS ${NCCL_EP_LIB_DIR} ${CUDAToolkit_LIBRARY_DIR} + PATH_SUFFIXES lib lib64 + REQUIRED) +endif() + +set(NCCL_SUBMODULE_INCLUDE + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build/include") +if(EXISTS "${NCCL_SUBMODULE_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIRS_FOR_TE ${NCCL_SUBMODULE_INCLUDE}) +else() + set(NCCL_INCLUDE_DIRS_FOR_TE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +endif() + +# Diagnostic: log detected NCCL header version (minimum enforced at runtime). +find_file(_nvte_nccl_header_path nccl.h + PATHS ${NCCL_INCLUDE_DIRS_FOR_TE} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + NO_DEFAULT_PATH) +if(_nvte_nccl_header_path) + file(READ "${_nvte_nccl_header_path}" _nvte_nccl_h) + string(REGEX MATCH "#define[ \t]+NCCL_MAJOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_major "${CMAKE_MATCH_1}") + string(REGEX MATCH "#define[ \t]+NCCL_MINOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_minor "${CMAKE_MATCH_1}") + string(REGEX MATCH "#define[ \t]+NCCL_PATCH[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_patch "${CMAKE_MATCH_1}") + if(_nvte_nccl_major AND _nvte_nccl_minor AND _nvte_nccl_patch) + message(STATUS "NCCL header: ${_nvte_nccl_header_path} (version ${_nvte_nccl_major}.${_nvte_nccl_minor}.${_nvte_nccl_patch})") + endif() +endif() + +target_include_directories(transformer_engine PRIVATE + ${NCCL_EP_INCLUDE_DIR} + ${NCCL_INCLUDE_DIRS_FOR_TE}) # covers nccl.h + nccl_device/ + +target_link_libraries(transformer_engine PUBLIC + ${NCCL_EP_LIB} + ${NCCL_LIB}) + +# Embed rpath so the installed wheel finds libnccl_ep.so at runtime. +# libnccl.so is already on the system via the Toolkit — no rpath needed for it. +set_target_properties(transformer_engine PROPERTIES + INSTALL_RPATH "$ORIGIN;${NCCL_EP_LIB_DIR}") + +target_sources(transformer_engine PRIVATE + ep/ep_backend.cpp + ep/ep_api.cpp) + +message(STATUS "NCCL EP enabled: ${NCCL_EP_LIB}") +message(STATUS "NCCL EP include: ${NCCL_EP_INCLUDE_DIR}") +else() + # NCCL EP off: export throwing nvte_ep_* stubs so framework bindings link. + target_sources(transformer_engine PRIVATE ep/ep_api_stub.cpp) + message(STATUS "NCCL EP disabled (NVTE_WITH_NCCL_EP=OFF) — using nvte_ep_* stubs") +endif() + # Number of philox4x32 rounds for stochastic rounding (build-time constant). set(NVTE_BUILD_NUM_PHILOX_ROUNDS_STR $ENV{NVTE_BUILD_NUM_PHILOX_ROUNDS}) if (NOT NVTE_BUILD_NUM_PHILOX_ROUNDS_STR) diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp new file mode 100644 index 0000000000..89d8b38607 --- /dev/null +++ b/transformer_engine/common/ep/ep_api.cpp @@ -0,0 +1,76 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_api.cpp + * \brief nvte_ep_* C API: thin delegations to the EPBackend singleton. + */ + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "ep_backend.h" + +using transformer_engine::ep::EPBackend; + +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + EPBackend::initialize(static_cast(ep_comm), group_config); +} + +void nvte_ep_shutdown(void) { EPBackend::shutdown(); } + +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size) { + NVTE_CHECK(handle_mem_size != nullptr, "handle_mem_size must not be null"); + return EPBackend::get().register_layer(layer_config, handle_mem_size); +} + +void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts, + size_t dispatch_output_per_expert_alignment, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().prepare(handle.id, topk_idx, token_counts, mem_ptr, + dispatch_output_per_expert_alignment, stream); +} + +void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().dispatch(handle.id, mem_ptr, topk_idx, tokens, tokens_win, topk_weights, + topk_weights_win, recv_tokens, recv_tokens_win, recv_topk_weights, + recv_topk_weights_win, stream); +} + +void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().combine(handle.id, mem_ptr, expert_out, expert_out_win, result, stream); +} + +void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, NVTETensor grad_topk_weights, + cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().dispatch_bwd(handle.id, mem_ptr, grad, grad_win, g_recv_topk_weights, + g_recv_topk_weights_win, grad_tokens, grad_topk_weights, stream); +} + +void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().combine_bwd(handle.id, mem_ptr, grad, grad_win, grad_expert_out, + grad_expert_out_win, stream); +} diff --git a/transformer_engine/common/ep/ep_api_stub.cpp b/transformer_engine/common/ep/ep_api_stub.cpp new file mode 100644 index 0000000000..fe4127d87d --- /dev/null +++ b/transformer_engine/common/ep/ep_api_stub.cpp @@ -0,0 +1,61 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_api_stub.cpp + * \brief Throwing nvte_ep_* stubs compiled when NVTE_WITH_NCCL_EP=OFF. + */ + +#include + +#include "../util/logging.h" + +namespace { +[[noreturn]] void ep_not_built() { + NVTE_ERROR( + "NCCL EP is not built into this TransformerEngine. Rebuild TE with " + "NVTE_BUILD_WITH_NCCL_EP=1 and CUDA arch >= 90 (e.g. NVTE_CUDA_ARCHS=\"90\")."); +} +} // namespace + +void nvte_ep_initialize(void* /*ep_comm*/, NVTEEpGroupConfig /*group_config*/) { ep_not_built(); } + +void nvte_ep_shutdown(void) {} + +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig /*layer_config*/, size_t* /*handle_mem_size*/) { + ep_not_built(); +} + +void nvte_ep_prepare(NVTEEpHandle /*handle*/, NVTETensor /*topk_idx*/, NVTETensor /*token_counts*/, + size_t /*dispatch_output_per_expert_alignment*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch(NVTEEpHandle /*handle*/, NVTETensor /*topk_idx*/, NVTETensor /*tokens*/, + NVTECommWindow /*tokens_win*/, NVTETensor /*topk_weights*/, + NVTECommWindow /*topk_weights_win*/, NVTETensor /*recv_tokens*/, + NVTECommWindow /*recv_tokens_win*/, NVTETensor /*recv_topk_weights*/, + NVTECommWindow /*recv_topk_weights_win*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine(NVTEEpHandle /*handle*/, NVTETensor /*expert_out*/, + NVTECommWindow /*expert_out_win*/, NVTETensor /*result*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch_bwd(NVTEEpHandle /*handle*/, NVTETensor /*grad*/, NVTECommWindow /*grad_win*/, + NVTETensor /*g_recv_topk_weights*/, + NVTECommWindow /*g_recv_topk_weights_win*/, NVTETensor /*grad_tokens*/, + NVTETensor /*grad_topk_weights*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine_bwd(NVTEEpHandle /*handle*/, NVTETensor /*grad*/, NVTECommWindow /*grad_win*/, + NVTETensor /*grad_expert_out*/, NVTECommWindow /*grad_expert_out_win*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp new file mode 100644 index 0000000000..ae0f3ab888 --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -0,0 +1,514 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.cpp + * \brief EPBackend implementation. See ep_backend.h for the op flow. + */ + +#include "ep_backend.h" + +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/logging.h" + +namespace transformer_engine { +namespace ep { + +namespace { + +// Build a by-value ncclEpTensor_t descriptor. `sizes` is caller-owned and must +// outlive any NCCL EP call that consumes the descriptor. +inline ncclEpTensor_t make_tensor(void* data, unsigned int ndim, ncclDataType_t datatype, + size_t* sizes) { + ncclEpTensor_t t = NCCL_EP_TENSOR_INIT; + t.ndim = ndim; + t.datatype = datatype; + t.data = data; + t.sizes = sizes; + return t; +} + +// Payload descriptor: prefer the symmem window when set, else fall back to the +// NVTETensor's raw device pointer. +inline ncclEpTensor_t make_payload_tensor(const NVTETensor t, const NVTECommWindow& win, + unsigned int ndim, ncclDataType_t datatype, + size_t* sizes) { + ncclEpTensor_t desc = NCCL_EP_TENSOR_INIT; + desc.ndim = ndim; + desc.datatype = datatype; + desc.sizes = sizes; + if (win.window != nullptr) { + desc.win_hdl = win.window; + desc.win_offset = win.offset; + } else { + desc.data = nvte_tensor_data(t); + NVTE_CHECK(desc.data != nullptr, "payload tensor data must not be null"); + } + return desc; +} + +// RAII guard for ncclEpHandle_t — destroys on scope exit, leak-free on throw. +class ScopedEpHandle { + public: + ScopedEpHandle() = default; + explicit ScopedEpHandle(ncclEpHandle_t h) : h_(h) {} + ~ScopedEpHandle() { + if (h_ != nullptr) ncclEpHandleDestroy(h_); + } + ScopedEpHandle(const ScopedEpHandle&) = delete; + ScopedEpHandle& operator=(const ScopedEpHandle&) = delete; + ScopedEpHandle(ScopedEpHandle&& other) noexcept : h_(other.h_) { other.h_ = nullptr; } + ScopedEpHandle& operator=(ScopedEpHandle&& other) noexcept { + if (this != &other) { + if (h_ != nullptr) ncclEpHandleDestroy(h_); + h_ = other.h_; + other.h_ = nullptr; + } + return *this; + } + operator ncclEpHandle_t() const { return h_; } + ncclEpHandle_t get() const { return h_; } + + private: + ncclEpHandle_t h_ = nullptr; +}; + +} // namespace + +// --------------------------------------------------------------------------- +// Singleton + bootstrap +// --------------------------------------------------------------------------- + +EPBackend& EPBackend::instance() { + static EPBackend inst; + return inst; +} + +EPBackend& EPBackend::get() { + EPBackend& inst = instance(); + NVTE_CHECK(inst.initialized_, "EPBackend not initialized. Call nvte_ep_initialize() first."); + return inst; +} + +void EPBackend::validate_config(const NVTEEpGroupConfig& config) { + NVTE_CHECK(config.ep_size > 0, "ep_size must be positive, got ", config.ep_size); + NVTE_CHECK(config.num_experts > 0, "num_experts must be positive, got ", config.num_experts); + NVTE_CHECK(config.max_tokens_per_rank > 0, "max_tokens_per_rank must be positive, got ", + config.max_tokens_per_rank); + NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", + config.max_recv_tokens_per_rank); + NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); + NVTE_CHECK(config.hidden_dim * sizeof(nv_bfloat16) >= 16, + "hidden_dim * 2 must be >= 16 (NCCL EP 16B row alignment); got hidden_dim=", + config.hidden_dim); + NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts, + ") must be divisible by ep_size (", config.ep_size, ")"); + NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ", + config.max_num_sms); + + int device, major; + NVTE_CHECK_CUDA(cudaGetDevice(&device)); + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + NVTE_CHECK(major >= 9, + "NCCL EP requires SM_90+ (Hopper or later), " + "but current device has compute capability ", + major, ".x"); + + // NCCL EP needs CUDA multicast (NVLS); init hangs without it. + NVTE_CHECK(cuda::supports_multicast(device), + "NCCL EP requires CUDA multicast (NVLS) support on device ", device, + " but CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED reports 0."); +} + +void EPBackend::initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config) { + EPBackend& inst = instance(); + std::lock_guard lock(inst.mutex_); + NVTE_CHECK(!inst.initialized_, "EP already initialized. Call initialize only once per process."); + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + + // Runtime gate: NCCL >= 2.30.4 (matches the submodule pin). + constexpr int kMinNcclVersion = 23004; + int nccl_version = 0; + NVTE_CHECK_NCCL(ncclGetVersion(&nccl_version)); + NVTE_CHECK(nccl_version >= kMinNcclVersion, "NCCL EP requires NCCL >= 2.30.4, found ", + nccl_version / 10000, ".", (nccl_version / 100) % 100, ".", nccl_version % 100, + " at runtime."); + + validate_config(config); + + int comm_size = 0; + NVTE_CHECK_NCCL(ncclCommCount(ep_comm, &comm_size)); + NVTE_CHECK(comm_size == config.ep_size, "ep_comm size (", comm_size, ") must equal ep_size (", + config.ep_size, "). Pass the EP sub-communicator, not the world comm."); + + inst.init(ep_comm, config); +} + +void EPBackend::shutdown() { + EPBackend& inst = instance(); + std::lock_guard lock(inst.mutex_); + if (!inst.initialized_) return; + inst.handles_.clear(); + // ncclEpGroupDestroy reads from ep_comm_; destroy group while comm is still alive. + if (inst.ep_group_ != nullptr) { + ncclEpGroupDestroy(inst.ep_group_); + inst.ep_group_ = nullptr; + } + inst.ep_comm_ = nullptr; // borrowed — caller destroys + inst.initialized_ = false; +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +ncclDataType_t EPBackend::nvte_dtype_to_nccl(NVTEDType dtype) { + switch (dtype) { + case kNVTEFloat32: + return ncclFloat32; + case kNVTEFloat16: + return ncclFloat16; + case kNVTEBFloat16: + return ncclBfloat16; + case kNVTEInt32: + return ncclInt32; + case kNVTEInt64: + return ncclInt64; + case kNVTEByte: + return ncclUint8; + case kNVTEFloat8E4M3: + return ncclFloat8e4m3; + case kNVTEFloat8E5M2: + return ncclFloat8e5m2; + default: + NVTE_ERROR("Unsupported NVTEDType for NCCL EP conversion: ", static_cast(dtype)); + } + return ncclFloat32; // unreachable +} + +// Open a transient ncclEpHandle over handle_mem. Caller owns the result. +ncclEpHandle_t EPBackend::open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, + size_t dispatch_output_per_expert_alignment) { + size_t hm_sizes[1] = {handle_mem_size}; + ncclEpTensor_t routing_desc = make_tensor(handle_mem, 1, ncclUint8, hm_sizes); + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = dispatch_output_per_expert_alignment; + ncclEpHandle_t handle; + NVTE_CHECK_NCCL(ncclEpInitHandle(&handle, ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, num_topk, + &routing_desc)); + return handle; +} + +// --------------------------------------------------------------------------- +// Lifecycle +// --------------------------------------------------------------------------- + +// Static-dtor teardown: skip NCCL calls (CUDA context / borrowed ep_comm_ may +// already be gone) and release in-memory state only. +EPBackend::~EPBackend() { + std::lock_guard lock(mutex_); + if (!initialized_) return; + handles_.clear(); + ep_group_ = nullptr; + ep_comm_ = nullptr; + initialized_ = false; +} + +void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(!initialized_, "EPBackend already initialized"); + + group_config_ = group_config; + + ncclEpGroupConfig_t cfg = NCCL_EP_GROUP_CONFIG_INIT; + cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT; + cfg.num_experts = static_cast(group_config.num_experts); + cfg.max_dispatch_tokens_per_rank = static_cast(group_config.max_tokens_per_rank); + cfg.max_token_bytes = static_cast(group_config.hidden_dim * sizeof(nv_bfloat16)); + cfg.rdma_buffer_size = NCCL_EP_AUTO; + cfg.num_qp_per_rank = NCCL_EP_AUTO; + cfg.num_channels = NCCL_EP_AUTO; + cfg.max_num_sms = group_config.max_num_sms > 0 + ? static_cast(group_config.max_num_sms) + : NCCL_EP_AUTO; + // Must be > 0; NCCL EP errors out on 0. + cfg.max_recv_tokens_per_rank = static_cast(group_config.max_recv_tokens_per_rank); + + NVTE_CHECK_NCCL(ncclEpCreateGroup(&ep_group_, ep_comm, &cfg)); + + ep_comm_ = ep_comm; + + initialized_ = true; +} + +// --------------------------------------------------------------------------- +// Per-handle_id config cache +// --------------------------------------------------------------------------- + +uint64_t EPBackend::insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment) { + if (handle_cache_cap_ == 0) { + const char* cap_env = std::getenv("NVTE_EP_HANDLE_CACHE_SIZE"); + handle_cache_cap_ = (cap_env != nullptr) ? std::max(1, std::atoi(cap_env)) : 8192; + } + NVTE_CHECK(handles_.size() < handle_cache_cap_, "EP handle cache full (", handle_cache_cap_, + " entries). Raise via NVTE_EP_HANDLE_CACHE_SIZE."); + uint64_t id = next_handle_id_.fetch_add(1, std::memory_order_relaxed); + handles_.emplace(id, HandleEntry{handle_mem_size, alignment, top_k}); + return id; +} + +EPBackend::HandleEntry& EPBackend::lookup_config(uint64_t handle_id) { + auto it = handles_.find(handle_id); + NVTE_CHECK(it != handles_.end(), "ep op on handle_id=", handle_id, + " with no cached config — call ep_prepare first."); + return it->second; +} + +// --------------------------------------------------------------------------- +// Per-step operations +// --------------------------------------------------------------------------- + +uint64_t EPBackend::register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(layer_config.top_k > 0, "NVTEEpLayerConfig.top_k must be > 0"); + NVTE_CHECK(handle_mem_size != nullptr, "handle_mem_size must not be null"); + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = layer_config.dispatch_output_per_expert_alignment; + size_t hm_size = 0; + NVTE_CHECK_NCCL(ncclEpHandleMemSize(ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, &hm_size, + layer_config.top_k)); + *handle_mem_size = hm_size; + std::lock_guard lock(mutex_); + return insert_new_entry(hm_size, layer_config.top_k, + layer_config.dispatch_output_per_expert_alignment); +} + +void EPBackend::prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETensor token_counts, + void* handle_mem, size_t dispatch_output_per_expert_alignment, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape idx_shape = nvte_tensor_shape(topk_idx); + void* idx_data = nvte_tensor_data(topk_idx); + NVTE_CHECK(idx_data != nullptr, "topk_idx data must not be null"); + + const size_t num_tokens = idx_shape.data[0]; + const size_t top_k = idx_shape.ndim > 1 ? idx_shape.data[1] : 1; + const size_t num_local_experts = + static_cast(group_config_.num_experts / group_config_.ep_size); + + size_t idx_sizes[2] = {num_tokens, top_k}; + ncclEpTensor_t nccl_topk_idx = make_tensor(idx_data, 2, ncclInt64, idx_sizes); + + // ncclEpUpdateHandle writes per-expert counts via expert_counters. + size_t cnt_sizes[1] = {num_local_experts}; + ncclEpTensor_t token_counts_desc; + void* token_counts_data = (token_counts != nullptr) ? nvte_tensor_data(token_counts) : nullptr; + if (token_counts_data != nullptr) { + token_counts_desc = make_tensor(token_counts_data, 1, ncclInt32, cnt_sizes); + } + ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; + layout_info.expert_counters = (token_counts_data != nullptr) ? &token_counts_desc : nullptr; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + NVTE_CHECK(cfg.alignment == dispatch_output_per_expert_alignment, + "ep_prepare: alignment mismatch for handle_id=", handle_id, + " (cached=", cfg.alignment, ", got=", dispatch_output_per_expert_alignment, ")"); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpUpdateHandle(transient, &nccl_topk_idx, &layout_info, stream)); +} + +void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx, + const NVTETensor tokens, const NVTECommWindow& tokens_win, + const NVTETensor topk_weights, const NVTECommWindow& topk_weights_win, + NVTETensor recv_tokens, const NVTECommWindow& recv_tokens_win, + NVTETensor recv_topk_weights, const NVTECommWindow& recv_topk_weights_win, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape tok_shape = nvte_tensor_shape(tokens); + NVTEDType tok_dtype = nvte_tensor_type(tokens); + + const size_t num_tokens = tok_shape.data[0]; + const size_t hidden_dim = tok_shape.data[1]; + + size_t tok_sizes[2] = {num_tokens, hidden_dim}; + ncclEpTensor_t nccl_tokens_in = + make_payload_tensor(tokens, tokens_win, 2, nvte_dtype_to_nccl(tok_dtype), tok_sizes); + + const bool is_forward = (topk_weights != nullptr); + + // Routing is cached in handle_mem by ep_prepare; dispatch only needs + // topk_weights to reconstruct the sparse-to-dense prob map. + size_t weights_in_sizes[2] = {0, 0}; + ncclEpTensor_t nccl_topk_weights_in; + if (is_forward) { + NVTE_CHECK(topk_idx != nullptr, "topk_idx required in forward dispatch"); + NVTEShape idx_shape = nvte_tensor_shape(topk_idx); + const size_t top_k = idx_shape.ndim > 1 ? idx_shape.data[1] : 1; + weights_in_sizes[0] = num_tokens; + weights_in_sizes[1] = top_k; + nccl_topk_weights_in = + make_payload_tensor(topk_weights, topk_weights_win, 2, ncclFloat32, weights_in_sizes); + } + + NVTEShape recv_shape = nvte_tensor_shape(recv_tokens); + NVTEDType recv_dtype = nvte_tensor_type(recv_tokens); + + size_t recv_sizes[2] = {recv_shape.data[0], recv_shape.data[1]}; + ncclEpTensor_t nccl_tokens_out = make_payload_tensor(recv_tokens, recv_tokens_win, 2, + nvte_dtype_to_nccl(recv_dtype), recv_sizes); + + size_t weights_out_sizes[1] = {recv_shape.data[0]}; + ncclEpTensor_t nccl_topk_weights_out; + if (is_forward) { + NVTE_CHECK(recv_topk_weights != nullptr, + "recv_topk_weights must not be null in forward dispatch"); + NVTEShape recv_w_shape = nvte_tensor_shape(recv_topk_weights); + NVTE_CHECK(recv_w_shape.ndim == 1, "recv_topk_weights must be 1D [recv_capacity]"); + nccl_topk_weights_out = make_payload_tensor(recv_topk_weights, recv_topk_weights_win, 1, + ncclFloat32, weights_out_sizes); + } + + ncclEpDispatchInputs_t in_struct = NCCL_EP_DISPATCH_INPUTS_INIT; + in_struct.tokens = &nccl_tokens_in; + in_struct.topk_weights = is_forward ? &nccl_topk_weights_in : nullptr; + + ncclEpDispatchOutputs_t out_struct = NCCL_EP_DISPATCH_OUTPUTS_INIT; + out_struct.tokens = &nccl_tokens_out; + out_struct.topk_weights = is_forward ? &nccl_topk_weights_out : nullptr; + + ncclEpDispatchConfig_t dispatch_cfg = NCCL_EP_DISPATCH_CONFIG_INIT; + dispatch_cfg.pass_direction = is_forward ? NCCL_EP_FWD_PASS : NCCL_EP_BWD_PASS; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpDispatch(transient, &in_struct, &out_struct, + /*layout_info=*/nullptr, &dispatch_cfg, stream)); +} + +void EPBackend::combine(uint64_t handle_id, void* handle_mem, const NVTETensor expert_out, + const NVTECommWindow& expert_out_win, NVTETensor result, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape exp_shape = nvte_tensor_shape(expert_out); + NVTEDType exp_dtype = nvte_tensor_type(expert_out); + + size_t exp_sizes[2] = {exp_shape.data[0], exp_shape.data[1]}; + ncclEpTensor_t nccl_expert_in = + make_payload_tensor(expert_out, expert_out_win, 2, nvte_dtype_to_nccl(exp_dtype), exp_sizes); + + NVTEShape res_shape = nvte_tensor_shape(result); + void* res_data = nvte_tensor_data(result); + NVTEDType res_dtype = nvte_tensor_type(result); + NVTE_CHECK(res_data != nullptr, "result data must not be null"); + + size_t res_sizes[2] = {res_shape.data[0], res_shape.data[1]}; + ncclEpTensor_t nccl_result_out = + make_tensor(res_data, 2, nvte_dtype_to_nccl(res_dtype), res_sizes); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_expert_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_result_out; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, /*config=*/nullptr, stream)); +} + +void EPBackend::dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape g_shape = nvte_tensor_shape(grad); + NVTEDType g_dtype = nvte_tensor_type(grad); + size_t g_sizes[2] = {g_shape.data[0], g_shape.data[1]}; + ncclEpTensor_t nccl_tok_in = + make_payload_tensor(grad, grad_win, 2, nvte_dtype_to_nccl(g_dtype), g_sizes); + + // g_recv_topk_weights must be 1D [recv_capacity] — caller flattens. + NVTEShape gw_shape = nvte_tensor_shape(g_recv_topk_weights); + NVTE_CHECK(gw_shape.ndim == 1, + "g_recv_topk_weights must be 1D [recv_capacity]; caller must flatten leading dims"); + size_t gw_sizes[1] = {gw_shape.data[0]}; + ncclEpTensor_t nccl_w_in = + make_payload_tensor(g_recv_topk_weights, g_recv_topk_weights_win, 1, ncclFloat32, gw_sizes); + + NVTEShape gt_shape = nvte_tensor_shape(grad_tokens); + void* gt_data = nvte_tensor_data(grad_tokens); + NVTE_CHECK(gt_data != nullptr, "grad_tokens data must not be null"); + size_t gt_sizes[2] = {gt_shape.data[0], gt_shape.data[1]}; + ncclEpTensor_t nccl_tok_out = make_tensor(gt_data, 2, nvte_dtype_to_nccl(g_dtype), gt_sizes); + + NVTEShape gtw_shape = nvte_tensor_shape(grad_topk_weights); + void* gtw_data = nvte_tensor_data(grad_topk_weights); + NVTE_CHECK(gtw_data != nullptr, "grad_topk_weights data must not be null"); + NVTE_CHECK(gtw_shape.ndim == 2, "grad_topk_weights must be 2D [T, top_k]"); + size_t gtw_sizes[2] = {gtw_shape.data[0], gtw_shape.data[1]}; + ncclEpTensor_t nccl_w_out = make_tensor(gtw_data, 2, ncclFloat32, gtw_sizes); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_tok_in; + in_struct.topk_weights = &nccl_w_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_tok_out; + out_struct.topk_weights = &nccl_w_out; + + ncclEpCombineConfig_t cfg = NCCL_EP_COMBINE_CONFIG_INIT; + cfg.pass_direction = NCCL_EP_BWD_PASS; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& entry = lookup_config(handle_id); + transient = ScopedEpHandle( + open_handle(handle_mem, entry.handle_mem_size, entry.top_k, entry.alignment)); + } + NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, &cfg, stream)); +} + +void EPBackend::combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, NVTETensor grad_expert_out, + const NVTECommWindow& grad_expert_out_win, cudaStream_t stream) { + // Backward of combine = reverse-direction dispatch. + dispatch(handle_id, handle_mem, /*topk_idx=*/nullptr, grad, grad_win, /*topk_weights=*/nullptr, + /*topk_weights_win=*/NVTECommWindow{}, grad_expert_out, grad_expert_out_win, + /*recv_topk_weights=*/nullptr, /*recv_topk_weights_win=*/NVTECommWindow{}, stream); +} + +} // namespace ep +} // namespace transformer_engine diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h new file mode 100644 index 0000000000..18307ebb4f --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.h @@ -0,0 +1,114 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.h + * \brief Internal NCCL EP singleton; not part of the public API. + * + * Per handle_id the cache stores config only (no device pointers), so + * handle_mem may be relocated between ops. Cap: NVTE_EP_HANDLE_CACHE_SIZE + * (default 8192); overflow throws. + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ +#define TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace transformer_engine { +namespace ep { + +/*! \brief EP backend singleton — owns the NCCL EP group; borrows the comm. */ +class EPBackend { + public: + /*! \brief Access the singleton. Aborts if not initialized. */ + static EPBackend& get(); + + /*! \brief Bootstrap from an existing EP sub-communicator. + * ep_comm is borrowed; the caller keeps it alive until shutdown() returns + * and must span exactly config.ep_size ranks. + */ + static void initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + /*! \brief Tear down the backend. Idempotent. Does not destroy ep_comm_. */ + static void shutdown(); + + // Host-only: reserve a fresh handle_id, cache the layer config, and report + // the handle_mem buffer size the caller must allocate. + uint64_t register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size); + + void prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETensor token_counts, + void* handle_mem, size_t dispatch_output_per_expert_alignment, cudaStream_t stream); + + void dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx, + const NVTETensor tokens, const NVTECommWindow& tokens_win, + const NVTETensor topk_weights, const NVTECommWindow& topk_weights_win, + NVTETensor recv_tokens, const NVTECommWindow& recv_tokens_win, + NVTETensor recv_topk_weights, const NVTECommWindow& recv_topk_weights_win, + cudaStream_t stream); + + void combine(uint64_t handle_id, void* handle_mem, const NVTETensor expert_out, + const NVTECommWindow& expert_out_win, NVTETensor result, cudaStream_t stream); + + // g_recv_topk_weights: 1D [recv_capacity] f32; grad_topk_weights: 2D [T, top_k] f32. + void dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream); + + void combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, NVTETensor grad_expert_out, + const NVTECommWindow& grad_expert_out_win, cudaStream_t stream); + + private: + EPBackend() = default; + ~EPBackend(); + EPBackend(const EPBackend&) = delete; + EPBackend& operator=(const EPBackend&) = delete; + + // ep_comm is borrowed — caller retains ownership across the backend lifetime. + void init(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + static EPBackend& instance(); // Meyers singleton accessor + static void validate_config(const NVTEEpGroupConfig& config); + + static ncclDataType_t nvte_dtype_to_nccl(NVTEDType dtype); + // Open a transient ncclEpHandle over handle_mem. num_topk=-1 for paths + // that don't carry per-token weights. + ncclEpHandle_t open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, + size_t dispatch_output_per_expert_alignment); + + ncclEpGroup_t ep_group_{nullptr}; + ncclComm_t ep_comm_{nullptr}; + NVTEEpGroupConfig group_config_{}; + bool initialized_{false}; + std::mutex mutex_; + struct HandleEntry { + size_t handle_mem_size; + size_t alignment; + int top_k; + }; + std::unordered_map handles_; + std::atomic next_handle_id_{1}; // 0 reserved as "no id" + size_t handle_cache_cap_{0}; // set lazily from NVTE_EP_HANDLE_CACHE_SIZE + + // Caller must hold mutex_. Throws on cap overflow. + uint64_t insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment); + HandleEntry& lookup_config(uint64_t handle_id); +}; + +} // namespace ep +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_window.h b/transformer_engine/common/include/transformer_engine/comm_window.h new file mode 100644 index 0000000000..088ea7f0c3 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_window.h @@ -0,0 +1,32 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file comm_window.h + * \brief Borrowed symmetric-memory window + offset for zero-copy one-sided ops. + * Pass ``{NULL, 0}`` to use the raw-pointer path. + */ + +#ifndef TRANSFORMER_ENGINE_COMM_WINDOW_H_ +#define TRANSFORMER_ENGINE_COMM_WINDOW_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief NCCL window + byte offset for a zero-copy payload tensor. */ +typedef struct { + ncclWindow_t window; /*!< NCCL window, or NULL to use the raw data pointer. */ + uint64_t offset; /*!< Byte offset of the payload within ``window``. */ +} NVTECommWindow; + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_COMM_WINDOW_H_ diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h new file mode 100644 index 0000000000..8c3a06b5f0 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -0,0 +1,161 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep.h + * \brief Public C API for Expert Parallelism. Per-step ops are allocation-free + * and CUDA graph-capturable. + */ + +#ifndef TRANSFORMER_ENGINE_EP_H_ +#define TRANSFORMER_ENGINE_EP_H_ + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* ── Config structs ─────────────────────────────────────────────────────── */ + +/*! \brief Group-level EP configuration (fixed for the EP group lifetime). */ +typedef struct { + int ep_size; /*!< EP world size. */ + int num_experts; /*!< Total experts across all ranks. */ + int max_tokens_per_rank; /*!< Upper bound on tokens this rank sends per dispatch. */ + /*! Upper bound on tokens received per dispatch (worst-case top_k fan-out; must be > 0). */ + int max_recv_tokens_per_rank; + int hidden_dim; /*!< Token hidden dimension. */ + int max_num_sms; /*!< Max SMs for EP kernels. 0 = auto. */ + /*! 0 (default): throw on relocated handle_mem for a cached handle_id. 1: silently rebuild. */ + int allow_handle_mem_reloc; +} NVTEEpGroupConfig; + +/*! \brief Per-layer EP configuration. */ +typedef struct { + int num_local_experts; /*!< Reserved for ABI stability (derived from group config). */ + int top_k; /*!< Per-token expert fan-out. Required. */ + size_t dispatch_output_per_expert_alignment; + /*!< Per-expert zone alignment in tokens (pow2; 0/1 = no padding). Must match + * between nvte_ep_register_layer and nvte_ep_prepare. */ +} NVTEEpLayerConfig; + +/* ── Bootstrap ──────────────────────────────────────────────────────────── */ + +/*! \brief Bootstrap from an existing NCCL EP sub-communicator. Requires SM>=90. + * + * ep_comm is borrowed and must span exactly group_config.ep_size ranks. + * Re-init after shutdown is allowed; double-init throws. + * + * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group. + * \param[in] group_config Group-level EP configuration. + */ +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config); + +/*! \brief Tear down the EP backend. Idempotent. Does not destroy ep_comm. */ +void nvte_ep_shutdown(void); + +/* ── Layer registration (host-only, eager) ───────────────────────────────── */ + +/*! \brief Reserve a handle_id for a layer config and report the handle_mem buffer + * size the caller must allocate. Host-only. + * + * \param[in] layer_config Per-layer EP configuration. + * \param[out] handle_mem_size Bytes the caller must allocate for handle_mem. + * \return uint64_t handle_id (non-zero). + */ +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size); + +/*! \brief Per-step handle: the registered handle_id paired with its handle_mem buffer. */ +typedef struct { + uint64_t id; /*!< Handle id from nvte_ep_register_layer. */ + NVTETensor mem; /*!< Caller-allocated handle_mem buffer (size from nvte_ep_register_layer). */ +} NVTEEpHandle; + +/* ── Per-step ops (all allocation-free, CUDA graph-capturable) ──────────── */ + +/*! \brief AllGather the routing map; write per-expert counts and cache routing + * metadata in handle.mem for the subsequent dispatch/combine. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] topk_idx [T, top_k] int64 routing indices. + * \param[out] token_counts [num_local_experts] int32 counts. + * \param[in] dispatch_output_per_expert_alignment Must match the handle_mem sizing. + * \param[in] stream CUDA stream. + */ +void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts, + size_t dispatch_output_per_expert_alignment, cudaStream_t stream); + +/*! \brief Dispatch tokens (and routing weights) to expert ranks. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] topk_idx [T, top_k] int64 sparse routing indices. + * \param[in] tokens [T, hidden_dim] input tokens. + * \param[in] tokens_win Optional symmem window for ``tokens``. + * \param[in] topk_weights [T, top_k] float32 weights, or null in backward. + * \param[in] topk_weights_win Optional symmem window for ``topk_weights``. + * \param[out] recv_tokens [recv_T, hidden_dim] received tokens. + * \param[in] recv_tokens_win Optional symmem window for ``recv_tokens``. + * \param[out] recv_topk_weights [recv_T] float32 per-slot weights, or null in backward. + * \param[in] recv_topk_weights_win Optional symmem window for ``recv_topk_weights``. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream); + +/*! \brief Scatter-sum expert outputs back to originating ranks. Unweighted — + * caller must pre-multiply expert_out by recv_topk_weights (and the + * valid-slot mask) before calling. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] expert_out [recv_T, hidden_dim] pre-weighted expert outputs. + * \param[in] expert_out_win Optional symmem window for ``expert_out``. + * \param[out] result [T, hidden_dim] combined output. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream); + +/*! \brief Backward of dispatch — routes token and weight grads back to source. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] grad [recv_capacity, hidden_dim] grad w.r.t. recv_tokens. + * \param[in] grad_win Optional symmem window for ``grad``. + * \param[in] g_recv_topk_weights [recv_capacity] f32 grad w.r.t. recv_topk_weights. + * \param[in] g_recv_topk_weights_win Optional symmem window for ``g_recv_topk_weights``. + * \param[out] grad_tokens [T, hidden_dim] grad w.r.t. tokens. + * \param[out] grad_topk_weights [T, top_k] f32 grad w.r.t. topk_weights. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, NVTETensor grad_topk_weights, + cudaStream_t stream); + +/*! \brief Backward of combine. Padded slots in grad_expert_out are zeroed. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] grad [T, hidden_dim] grad w.r.t. result. + * \param[in] grad_win Optional symmem window for ``grad``. + * \param[out] grad_expert_out [recv_capacity, hidden_dim] grad w.r.t. expert_out. + * \param[in] grad_expert_out_win Optional symmem window for ``grad_expert_out``. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_EP_H_ From 69b01961c1e0c0a5683d0e76c833f7e626ca528a Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 22 May 2026 23:05:43 +0000 Subject: [PATCH 2/3] Expert Parallelism: JAX bindings (FFI, custom_vjp, multi-process tests, MoE example) Signed-off-by: Phuong Nguyen --- build_tools/jax.py | 41 +- examples/jax/ep/ep_moe.py | 396 ++++++++ examples/jax/ep/run_test_ep.sh | 85 ++ tests/jax/multi_process_launch_ep.sh | 67 ++ tests/jax/test_multi_process_ep.py | 690 +++++++++++++ .../jax/cpp_extensions/__init__.py | 1 + transformer_engine/jax/cpp_extensions/ep.py | 955 ++++++++++++++++++ transformer_engine/jax/csrc/extensions.h | 19 + transformer_engine/jax/csrc/extensions/ep.cpp | 457 +++++++++ .../jax/csrc/extensions/pybind.cpp | 18 + transformer_engine/jax/ep.py | 303 ++++++ transformer_engine/jax/sharding.py | 12 + 12 files changed, 3042 insertions(+), 2 deletions(-) create mode 100644 examples/jax/ep/ep_moe.py create mode 100755 examples/jax/ep/run_test_ep.sh create mode 100755 tests/jax/multi_process_launch_ep.sh create mode 100644 tests/jax/test_multi_process_ep.py create mode 100644 transformer_engine/jax/cpp_extensions/ep.py create mode 100644 transformer_engine/jax/csrc/extensions/ep.cpp create mode 100644 transformer_engine/jax/ep.py diff --git a/build_tools/jax.py b/build_tools/jax.py index a7b200f915..49c5001d18 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -103,13 +103,50 @@ def setup_jax_extension( setup_mpi_flags(include_dirs, cxx_flags) + # NCCL EP is on by default. Set NVTE_BUILD_WITH_NCCL_EP=0 to skip it. + build_with_nccl_ep = bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))) + libraries = [] + submod_lib_dir = None + submod_nccl_inc = None + if build_with_nccl_ep: + cxx_flags.append("-DNVTE_WITH_NCCL_EP") + # Headers + libs come from the in-tree 3rdparty/nccl submodule build + # (auto-produced by setup.py). + libraries = ["nccl", "nccl_ep"] + # NCCL EP requires SM>=90 (Hopper+). + archs_env = os.getenv("NVTE_CUDA_ARCHS", "") + for a in archs_env.split(";"): + a_num = "".join(c for c in a if c.isdigit()) + if a_num and int(a_num) < 90: + raise RuntimeError( + f"NCCL EP requires CUDA arch >= 90 (Hopper or newer); got '{a}' in" + " NVTE_CUDA_ARCHS." + ) + submod_root = (common_header_files / ".." / "3rdparty" / "nccl").resolve() + submod_ep_inc = submod_root / "contrib" / "nccl_ep" / "include" + if not (submod_ep_inc / "nccl_ep.h").exists(): + raise RuntimeError( + f"NCCL EP header not found at {submod_ep_inc}/nccl_ep.h. " + "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl." + ) + include_dirs.append(submod_ep_inc) + submod_lib_dir = submod_root / "build" / "lib" + submod_nccl_inc = submod_root / "build" / "include" + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension - return Pybind11Extension( + ext = Pybind11Extension( "transformer_engine_jax", sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], extra_compile_args=cxx_flags, - libraries=["nccl"], + libraries=libraries, ) + if submod_lib_dir is not None: + ext.library_dirs.append(str(submod_lib_dir)) + ext.runtime_library_dirs.append(str(submod_lib_dir)) + # Prefer submodule's nccl.h when present (matches the C++ side). + if (submod_nccl_inc / "nccl.h").exists(): + ext.include_dirs.insert(0, str(submod_nccl_inc)) + return ext diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py new file mode 100644 index 0000000000..8dcac02a04 --- /dev/null +++ b/examples/jax/ep/ep_moe.py @@ -0,0 +1,396 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""End-to-end MoE example: dispatch -> batched expert linear -> combine, fwd + bwd. + +One process per GPU. Run via run_test_ep.sh. +""" + +import argparse +import sys + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.ep import ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +# ── Setup ─────────────────────────────────────────────────────────────────── + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-JAX EP MoE example (fwd + bwd)") + p.add_argument("--coordinator-address", required=True) + p.add_argument("--process-id", type=int, required=True) + p.add_argument("--num-processes", type=int, required=True) + p.add_argument("--num-tokens", type=int, default=8, help="Per-rank token count.") + p.add_argument("--top-k", type=int, default=2) + p.add_argument("--hidden", type=int, default=32) + p.add_argument("--hidden-out", type=int, default=32) + p.add_argument( + "--num-experts", + type=int, + default=None, + help="Total experts across the EP group. Default: num_processes.", + ) + p.add_argument("--dp-size", type=int, default=None, help="Default: num_procs // ep_size.") + p.add_argument( + "--check", + action="store_true", + default=True, + help="Verify fwd+bwd against a single-rank numpy reference.", + ) + return p.parse_args() + + +def _distributed_init(args): + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_processes, + process_id=args.process_id, + local_device_ids=[args.process_id], + ) + assert ( + jax.local_device_count() == 1 + ), f"EP example requires 1 GPU per process; got {jax.local_device_count()}" + + +def _build_mesh_and_resource(args): + """Pick a (2, 2) mesh by default. Override via --dp-size.""" + n = args.num_processes + if n < 4: + raise ValueError(f"num_processes ({n}) must be >= 4 for NCCL EP") + if args.dp_size is None: + if n != 4: + raise ValueError( + f"default mesh expects exactly 4 ranks (got {n}); pass --dp-size to override" + ) + args.dp_size = 2 + assert n % args.dp_size == 0, f"num_processes={n} not divisible by dp_size={args.dp_size}" + args.ep_size = n // args.dp_size + if args.num_experts is None: + args.num_experts = args.num_processes + assert args.num_experts % args.ep_size == 0 + args.num_local_experts = args.num_experts // args.ep_size + args.recv_capacity_per_rank = args.ep_size * args.num_tokens * args.top_k + + devs = np.asarray(jax.devices()).reshape(args.dp_size, args.ep_size) + mesh = Mesh(devs, ("dp", "ep")) + mr = MeshResource(dp_resource="dp", ep_resource="ep") + return mesh, mr + + +def _make_routing(dp_color, num_tokens, top_k, num_experts, num_local_experts): + """Deterministic routing: topk_idx[t, k] = (dp_color*NLE + t*K + k) % E.""" + topk_idx = np.empty((num_tokens, top_k), dtype=np.int32) + for t in range(num_tokens): + for k in range(top_k): + topk_idx[t, k] = (dp_color * num_local_experts + t * top_k + k) % num_experts + return topk_idx + + +def _make_inputs(args): + """Build 3D ``[B, S, H]`` arrays sharded ``(("dp","ep"), None, None)``. + + B = num_processes (sharded across the compound (dp,ep) axis so each rank + holds one slot); S = args.num_tokens. Global numpy views (rank-0 + reference) are kept 2D for the legacy reference implementation. + """ + T, K, H, H_out = args.num_tokens, args.top_k, args.hidden, args.hidden_out + E = args.num_experts + dp_size = args.dp_size + ep_size = args.ep_size + num_procs = args.num_processes + dp_color = args.process_id // ep_size + + rng_dp = np.random.default_rng(seed=42 + dp_color) + tokens_np = (rng_dp.standard_normal((T, H), dtype=np.float32) * 0.5).astype(np.float32) + topk_idx_np = _make_routing(dp_color, T, K, E, args.num_local_experts) + w_np = np.full((T, K), 1.0 / K, dtype=np.float32) + + tokens_global_np = np.concatenate( + [ + ( + np.random.default_rng(seed=42 + c).standard_normal((T, H), dtype=np.float32) * 0.5 + ).astype(np.float32) + for c in range(dp_size) + ], + axis=0, + ) + topk_idx_global_np = np.concatenate( + [_make_routing(c, T, K, E, args.num_local_experts) for c in range(dp_size)], axis=0 + ) + w_global_np = np.full((dp_size * T, K), 1.0 / K, dtype=np.float32) + + # Same seed on every rank → identical kernel array everywhere. + rng = np.random.default_rng(seed=42) + kernels_np = (rng.standard_normal((E, H, H_out), dtype=np.float32) * (1.0 / np.sqrt(H))).astype( + np.float32 + ) + + # Each rank contributes one [1, T, ...] slab; the global shape is + # [num_procs, T, ...] sharded on the first dim across (dp, ep). + mesh = args.mesh + dpep_spec = NamedSharding(mesh, PartitionSpec(("dp", "ep"), None, None)) + tokens = jax.make_array_from_process_local_data( + dpep_spec, tokens_np[None, :, :].astype(np.float32), (num_procs, T, H) + ).astype(jnp.bfloat16) + topk_idx = jax.make_array_from_process_local_data( + dpep_spec, topk_idx_np[None, :, :], (num_procs, T, K) + ) + topk_w = jax.make_array_from_process_local_data(dpep_spec, w_np[None, :, :], (num_procs, T, K)) + kernels = jnp.asarray(kernels_np, dtype=jnp.bfloat16) + return ( + tokens_global_np, + topk_idx_global_np, + w_global_np, + kernels_np, + tokens, + topk_idx, + topk_w, + kernels, + ) + + +# ── MoE step ──────────────────────────────────────────────────────────────── + + +def _batched_expert_linear(recv_tokens, kernels, num_local_experts, dp_size, ep_size): + """Per-expert linear. ``recv_tokens`` is 3D ``[num_procs, recv_pr, H]`` + (compound (dp,ep) leading); ``kernels`` is 4D ``[ep_size, NLE, H, H_out]``, + broadcast over the dp axis. Output matches ``recv_tokens``' 3D layout + with ``H_out`` in place of ``H``.""" + num_procs, recv_pr, H = recv_tokens.shape + H_out = kernels.shape[-1] + slots_per_expert = recv_pr // num_local_experts + # [num_procs, recv_pr, H] -> [dp, ep, NLE, slots, H] + grouped = recv_tokens.reshape(dp_size, ep_size, num_local_experts, slots_per_expert, H) + # Contract H; batch over (ep, NLE) which are present on both sides. + out = jax.lax.dot_general( + grouped, + kernels.astype(grouped.dtype), + dimension_numbers=(((4,), (2,)), ((1, 2), (0, 1))), + ) + # Output dim order from dot_general: batch dims first, then remaining lhs, rhs. + # batch=(ep,NLE), lhs_remaining=(dp,slots), rhs_remaining=(H_out,) + # → shape [ep, NLE, dp, slots, H_out]. Permute to [dp, ep, NLE, slots, H_out]. + out = jnp.transpose(out, (2, 0, 1, 3, 4)) + return out.reshape(num_procs, recv_pr, H_out) + + +def _moe_step(args, topk_idx, tokens, topk_w, kernels): + """Jit'd MoE step: dispatch -> batched per-expert linear -> combine. + + Inputs are 3D ``[B, S, H]`` with the first dim compound-sharded across + ``("dp","ep")``. Combine returns the same 3D shape. + """ + B = args.num_processes + S = args.num_tokens + NLE = args.num_local_experts + dp_size, ep_size = args.dp_size, args.ep_size + mesh = args.mesh + in_spec = PartitionSpec(("dp", "ep"), None, None) # [B, S, ...] + ep3 = PartitionSpec(("dp", "ep"), None, None) # [num_procs, recv_pr, H] + ep2 = PartitionSpec(("dp", "ep"), None) # [num_procs, recv_pr] + # Kernels are EP-replicated across dp colors; shard only the ep-rank axis. + kernel_spec = PartitionSpec("ep", None, None, None) + + kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:]) + + @jax.jit + def step(topk_idx, tokens, topk_w, local_kernels): + topk_idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec)) + tokens = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec)) + topk_w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec)) + local_kernels = jax.lax.with_sharding_constraint( + local_kernels, NamedSharding(mesh, kernel_spec) + ) + slots_per_expert = args.recv_capacity_per_rank // NLE + recv_tokens, recv_topk_w, handle, _tc = ep_dispatch( + topk_idx, + tokens, + topk_w, + args.recv_capacity_per_rank, + dispatch_output_per_expert_alignment=slots_per_expert, + ) + recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3)) + recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2)) + expert_out = _batched_expert_linear(recv_tokens, local_kernels, NLE, dp_size, ep_size) + expert_out = jax.lax.with_sharding_constraint(expert_out, NamedSharding(mesh, ep3)) + return ep_combine( + handle, + _tc, + expert_out, + recv_topk_w, + num_local_tokens=(B, S), + out_sharding=(("dp", "ep"), None, None), + ) + + return step(topk_idx, tokens, topk_w, kernels) + + +# ── Reference (numerical check) ───────────────────────────────────────────── + + +def _reference_moe(tokens, topk_idx, topk_w, kernels): + """Single-rank dense MoE reference. tokens [T, H], output [T, H_out].""" + T, K = topk_idx.shape + H_out = kernels.shape[-1] + out = np.zeros((T, H_out), dtype=np.float32) + for t in range(T): + tok = tokens[t].astype(np.float32) + for k in range(K): + e = int(topk_idx[t, k]) + out[t] += float(topk_w[t, k]) * (tok @ kernels[e].astype(np.float32)) + return out + + +def _reference_grad(tokens, topk_idx, topk_w, kernels): + """d/dtokens of 0.5 * sum(ref_out**2) — used by --check to validate bwd.""" + T, K = topk_idx.shape + H = tokens.shape[-1] + ref_out = _reference_moe(tokens, topk_idx, topk_w, kernels) + grad = np.zeros((T, H), dtype=np.float32) + for t in range(T): + mixed = np.zeros_like(kernels[0]) + for k in range(K): + mixed = mixed + float(topk_w[t, k]) * kernels[int(topk_idx[t, k])] + grad[t] = ref_out[t] @ mixed.T + return ref_out, grad + + +# ── Main ──────────────────────────────────────────────────────────────────── + + +def main(): + args = _parse_args() + _distributed_init(args) + + dev = jax.local_devices()[0] + cap = getattr(dev, "compute_capability", None) + if cap is not None: + major, minor = (int(x) for x in str(cap).split(".")) + if major * 10 + minor < 90: + print(f"[ep_moe] SKIPPED: NCCL EP requires SM>=90 (got SM{major}{minor})") + return + + args.mesh, args.mr = _build_mesh_and_resource(args) + + with args.mesh, global_shard_guard(args.mr): + ep_bootstrap( + world_size=args.num_processes, + rank=args.process_id, + ep_size=args.ep_size, + num_experts=args.num_experts, + max_tokens_per_rank=args.num_tokens, + recv_capacity_per_rank=args.recv_capacity_per_rank, + hidden_dim=args.hidden, + ) + + ( + tokens_global_np, + topk_idx_global_np, + w_global_np, + kernels_np, + tokens, + topk_idx, + topk_w, + kernels, + ) = _make_inputs(args) + + def loss_fn(toks, idx, w, kern): + out = _moe_step(args, idx, toks, w, kern) + return 0.5 * (out.astype(jnp.float32) ** 2).sum(), out + + (loss, out_fwd), grad_tokens = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))( + tokens, topk_idx, topk_w, kernels + ) + grad_tokens.block_until_ready() + out_fwd.block_until_ready() + + if args.process_id == 0: + print( + f"[ep_moe] loss={float(loss):.4f} grad_tokens.shape={grad_tokens.shape} " + f"dp={args.dp_size} ep={args.ep_size} " + f"num_experts={args.num_experts} recv_pr={args.recv_capacity_per_rank}" + ) + + if args.check: + + def _norm(spec, ndim): + return tuple(spec) + (None,) * (ndim - len(spec)) + + # JAX may collapse a size-1 mesh axis: when dp_size==1 the spec can + # appear as ``(("dp","ep"),...)`` or ``("ep",...)``. Accept both. + if args.dp_size > 1: + acceptable_specs = ((("dp", "ep"), None, None),) + else: + acceptable_specs = ((("dp", "ep"), None, None), ("ep", None, None)) + assert ( + _norm(out_fwd.sharding.spec, out_fwd.ndim) in acceptable_specs + ), f"out_fwd.sharding.spec={out_fwd.sharding.spec} (expected one of {acceptable_specs})" + assert _norm(grad_tokens.sharding.spec, grad_tokens.ndim) in acceptable_specs, ( + f"grad_tokens.sharding.spec={grad_tokens.sharding.spec}" + f" (expected one of {acceptable_specs})" + ) + + replicated = NamedSharding(args.mesh, jax.sharding.PartitionSpec()) + out_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))(out_fwd) + grad_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))( + grad_tokens + ) + out_global.block_until_ready() + grad_global.block_until_ready() + + ref_out, ref_grad = _reference_grad( + tokens_global_np, topk_idx_global_np, w_global_np, kernels_np + ) + ref_loss = 0.5 * float((ref_out.astype(np.float32) ** 2).sum()) + # 3D global ``[num_procs, S, H]`` with num_procs = dp * ep. Each EP + # column in a DP color sees identical inputs (and produces identical + # outputs), so collapse the ep dim to one replica before flattening + # to 2D against the dp-only reference. + dp_size, ep_size = args.dp_size, args.ep_size + global_out = ( + np.asarray(out_global.addressable_shards[0].data.astype(jnp.float32)) + .reshape(dp_size, ep_size, -1, ref_out.shape[-1])[:, 0] + .reshape(-1, ref_out.shape[-1]) + ) + global_grad = ( + np.asarray(grad_global.addressable_shards[0].data.astype(jnp.float32)) + .reshape(dp_size, ep_size, -1, ref_grad.shape[-1])[:, 0] + .reshape(-1, ref_grad.shape[-1]) + ) + if args.process_id == 0: + fwd_diff = np.abs(global_out - ref_out) + grad_diff = np.abs(global_grad - ref_grad) + print( + f"[ep_moe] DEBUG loss={float(loss):.4f} ref_loss(global)={ref_loss:.4f} " + f"ratio={float(loss) / max(ref_loss, 1e-9):.4f} (expected ~1.0)" + ) + print(f"[ep_moe] DEBUG fwd max-abs-diff per row: {fwd_diff.max(axis=1)}") + print(f"[ep_moe] DEBUG grad max-abs-diff per row: {grad_diff.max(axis=1)}") + np.testing.assert_allclose( + global_out, + ref_out, + rtol=5e-2, + atol=5e-2, + err_msg=f"rank {args.process_id}: fwd mismatch", + ) + np.testing.assert_allclose( + global_grad, + ref_grad, + rtol=5e-2, + atol=5e-2, + err_msg=f"rank {args.process_id}: bwd mismatch", + ) + if args.process_id == 0: + print(f"[ep_moe] --check PASSED (ref_out.sum()={float(ref_out.sum()):.4f})") + + +if __name__ == "__main__": + main() + sys.exit(0) diff --git a/examples/jax/ep/run_test_ep.sh b/examples/jax/ep/run_test_ep.sh new file mode 100755 index 0000000000..55b958f146 --- /dev/null +++ b/examples/jax/ep/run_test_ep.sh @@ -0,0 +1,85 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +#!/bin/bash + +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "NCCL EP requires at least 4 GPUs (found ${NUM_GPUS}); SKIPPING." + exit 0 +fi +# Default mesh is (2, 2); use exactly 4 ranks even on larger boxes. +NUM_GPUS="${NVTE_EP_NUM_RANKS:-4}" + +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + +# NCCL EP requires NVLink P2P among ranks on the node. +echo "*** Checking NVLINK support ***" +NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) +NVLINK_EXIT_CODE=$? +if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] \ + || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then + echo "NVLINK is not supported on this platform — EP example requires NVLINK; SKIPPING" + exit 0 +fi +echo "NVLINK support detected" + +SCRIPT="$TE_PATH/examples/jax/ep/ep_moe.py" +export PYTHONPATH="${TE_PATH}${PYTHONPATH:+:${PYTHONPATH}}" +COORD="${COORD:-127.0.0.1:12345}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-300}" + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_graph_min_graph_size=1" +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +# Stage NCCL EP JIT cubins on tmpfs to keep build/iteration fast. +: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +echo +echo "*** Executing ep_moe.py across $NUM_GPUS GPUs ***" + +PIDS=() +cleanup() { + for pid in "${PIDS[@]}"; do + kill -0 "$pid" 2>/dev/null && kill -KILL "$pid" 2>/dev/null || true + done +} +trap cleanup EXIT INT TERM + +EXTRA_ARGS=${EXTRA_ARGS:-"--check"} + +for ((i=1; i "stdout_rank_${i}.txt" 2>&1 & + PIDS+=($!) +done +timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + python -u "$SCRIPT" \ + --coordinator-address "$COORD" --process-id "0" --num-processes "$NUM_GPUS" \ + $EXTRA_ARGS 2>&1 | tee stdout_rank_0.txt +wait + +HAS_FAILURE=0 +if grep -qE "FAILED|Traceback|ERROR" stdout_rank_0.txt; then + echo "... ep_moe FAILED" + HAS_FAILURE=1 +elif ! grep -qE "\[ep_moe\]" stdout_rank_0.txt; then + echo "... ep_moe INVALID (rank 0 produced no summary line)" + for ((i=1; i/dev/null + done + HAS_FAILURE=1 +else + echo "... ep_moe PASSED" +fi +rm -f stdout_rank_*.txt +exit $HAS_FAILURE diff --git a/tests/jax/multi_process_launch_ep.sh b/tests/jax/multi_process_launch_ep.sh new file mode 100755 index 0000000000..a37ffc2952 --- /dev/null +++ b/tests/jax/multi_process_launch_ep.sh @@ -0,0 +1,67 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +#!/bin/bash + +SCRIPT_NAMES="${SCRIPT_NAMES:-test_multi_process_ep.py}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}" + + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_graph_min_graph_size=1" + +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +NUM_RUNS=$(nvidia-smi -L | wc -l) + +if [ "${NUM_RUNS}" -lt 4 ]; then + echo "NCCL EP requires at least 4 GPUs (found ${NUM_RUNS}); SKIPPING." + exit 0 +fi +# Default test mesh is (2, 2); use exactly 4 ranks even on larger boxes. +NUM_RUNS="${NVTE_TEST_EP_NUM_RANKS:-4}" + +OVERALL_RET=0 + +for SCRIPT_NAME in $SCRIPT_NAMES; do + echo "=== Running ${SCRIPT_NAME} ===" + for ((i=1; i stdout_rank_${i}.txt 2>&1 & + done + + timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS 2>&1 | tee stdout_multi_process.txt + + wait + + RET=0 + if grep -q "FAILED" stdout_multi_process.txt; then + RET=1 + fi + # Treat missing test summary on rank 0 as hang/crash rather than silent success. + if ! grep -qE "Ran [0-9]+ test|^OK$|PASSED" stdout_multi_process.txt; then + echo "ERROR: rank 0 produced no test summary for ${SCRIPT_NAME} — likely a hang or early crash." + echo " NCCL EP requires NVLS multicast; check NCCL_DEBUG=INFO output." + RET=1 + fi + if [ "$RET" -ne 0 ]; then + for ((i=1; i/dev/null || echo "(no log)" + done + fi + + rm -f stdout_multi_process.txt stdout_rank_*.txt + if [ "$RET" -ne 0 ]; then + OVERALL_RET=1 + fi +done + +exit "$OVERALL_RET" diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py new file mode 100644 index 0000000000..0658ad9750 --- /dev/null +++ b/tests/jax/test_multi_process_ep.py @@ -0,0 +1,690 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Multi-process unit tests for the TE-JAX Expert Parallelism (EP) primitives. + +Default mesh is (dp=2, ep=2); override via ``NVTE_TEST_EP_MESH=DPxEP``. +Coverage: + + - ``ep_bootstrap`` rejects when ``ep_resource`` is unset. + - Individual primitives (``ep_prepare``, ``ep_dispatch_fwd``, ``ep_combine_fwd``) + round-trip an identity expert → output ≈ tokens. + - ``ep_dispatch`` custom_vjp: ``grad_tokens ≈ TOP_K · tokens`` (closed form). + - ``ep_combine`` custom_vjp: ``max|grad_eo| ≈ eo_const / TOP_K`` (closed form). + - ``ep_dispatch`` custom_vjp: exact per-(t, k) ``grad_topk_weights`` under + skewed upstream gradients (no k-axis averaging). + - HLO reshard guard: compile-only, no XLA collectives outside the EP FFI. + +Launch via tests/jax/multi_process_launch_ep.sh (one process per GPU). +""" + +import os +import sys +import unittest + +import jax +import jax.experimental.multihost_utils as jmu +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.jax.ep import ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.cpp_extensions.ep import ( + ep_prepare, + ep_dispatch_fwd, + ep_combine_fwd, +) + + +# ── Test config ───────────────────────────────────────────────────────────── +# NCCL EP requires NUM_LOCAL_EXPERTS*ep % 4 == 0 (TMA alignment in +# device/hybridep_adapter.cu:511). With NUM_LOCAL_EXPERTS=2, ep must be even. + +NUM_LOCAL_EXPERTS = 2 # per-rank → num_experts = NLE * EP +HIDDEN_DIM = 32 +TOP_K = 2 +TOKENS_PER_DP_SHARD = 4 # per device along dp + + +def _factor_dp_ep(num_procs): + """Default to a (2, 2) mesh. Override via ``NVTE_TEST_EP_MESH=DPxEP``. + + NUM_LOCAL_EXPERTS*ep must be a multiple of 4 for NCCL EP's TMA alignment. + """ + override = os.environ.get("NVTE_TEST_EP_MESH") + if override: + dp_str, ep_str = override.lower().split("x") + dp, ep = int(dp_str), int(ep_str) + if dp * ep != num_procs: + raise ValueError( + f"NVTE_TEST_EP_MESH={override!r} does not multiply to num_procs={num_procs}" + ) + if (NUM_LOCAL_EXPERTS * ep) % 4 != 0: + raise ValueError( + f"NUM_LOCAL_EXPERTS*ep ({NUM_LOCAL_EXPERTS}*{ep}) must be a multiple of 4 " + "for NCCL EP TMA alignment" + ) + return dp, ep + if num_procs != 4: + raise ValueError( + f"default mesh expects exactly 4 ranks (got {num_procs}); set " + "NVTE_TEST_EP_MESH=DPxEP to override" + ) + return 2, 2 + + +def _build_mesh(dp, ep): + devs = np.asarray(jax.devices()).reshape(dp, ep) + return Mesh(devs, ("dp", "ep")) + + +def _local_device_sm(): + """Return SM major*10+minor of the first local CUDA device, or None.""" + try: + dev = jax.local_devices()[0] + cap = getattr(dev, "compute_capability", None) + if cap is None: + return None + major, minor = (int(x) for x in str(cap).split(".")) + return major * 10 + minor + except Exception: + return None + + +class TestEP(unittest.TestCase): + @classmethod + def setUpClass(cls): + sm = _local_device_sm() + if sm is not None and sm < 90: + raise unittest.SkipTest(f"NCCL EP requires SM>=90 (got SM{sm})") + cls.num_procs = jax.process_count() + cls.rank = jax.process_index() + cls.dp, cls.ep = _factor_dp_ep(cls.num_procs) + cls.num_experts = NUM_LOCAL_EXPERTS * cls.ep + # recv_capacity is per-DP-group (NCCL EP comms isolated per DP color). + # Under PartitionSpec(("dp","ep"), None) each EP group sees + # T_global/dp = TOKENS_PER_DP_SHARD tokens total; pad for routing skew. + T_per_ep_group = TOKENS_PER_DP_SHARD + active_experts = min(cls.num_experts, T_per_ep_group * TOP_K) + overconc = cls.num_experts // active_experts + cls.recv_capacity_per_rank = ( + NUM_LOCAL_EXPERTS * max(T_per_ep_group * TOP_K, 16) * overconc * 2 + ) + cls.mesh = _build_mesh(cls.dp, cls.ep) + cls.mr = MeshResource(dp_resource="dp", ep_resource="ep") + with cls.mesh, global_shard_guard(cls.mr): + ep_bootstrap( + world_size=cls.num_procs, + rank=cls.rank, + ep_size=cls.ep, + num_experts=cls.num_experts, + max_tokens_per_rank=TOKENS_PER_DP_SHARD, + recv_capacity_per_rank=cls.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + ) + + # ── Bootstrap precondition ──────────────────────────────────────────── + + def test_bootstrap_rejects_missing_ep_axis(self): + """ep_bootstrap raises when MeshResource has no ep_resource.""" + with self.mesh, global_shard_guard(MeshResource()): + with self.assertRaisesRegex(ValueError, "ep_resource"): + ep_bootstrap( + world_size=self.num_procs, + rank=self.rank, + ep_size=self.ep, + num_experts=self.num_experts, + max_tokens_per_rank=TOKENS_PER_DP_SHARD, + recv_capacity_per_rank=self.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + ) + + # ── Helpers ─────────────────────────────────────────────────────────── + + def _make_identity_inputs(self, nonuniform=False): + """Identity routing + uniform weights — combined output ≈ tokens. + + ``nonuniform=False``: ``(t*TOP_K+k) % E`` (round-robin, near-balanced). + ``nonuniform=True``: ``top1=0`` for every token, ``top2=1+(t%(E-1))`` — + expert 0 absorbs the entire batch while the others split the second + slot evenly. Exercises a skewed per-expert load. + """ + T_global = TOKENS_PER_DP_SHARD * self.dp + E = self.num_experts + topk_idx = np.empty((T_global, TOP_K), dtype=np.int32) + if nonuniform: + assert TOP_K == 2, "non-uniform pattern assumes top_k=2" + for t in range(T_global): + topk_idx[t, 0] = 0 + topk_idx[t, 1] = 1 + (t % (E - 1)) + else: + for t in range(T_global): + for k in range(TOP_K): + topk_idx[t, k] = (t * TOP_K + k) % E + topk_idx = jnp.asarray(topk_idx) + topk_weights = jnp.full((T_global, TOP_K), 1.0 / TOP_K, dtype=jnp.float32) + tokens = jnp.asarray( + np.linspace(0.1, 0.9, T_global * HIDDEN_DIM, dtype=np.float32).reshape( + T_global, HIDDEN_DIM + ), + dtype=jnp.bfloat16, + ) + return T_global, topk_idx, tokens, topk_weights + + def _make_random_inputs(self, seed=42, nonuniform=True): + """Random tokens + skewed top-2 routing (top1=0 always; top2 varies). + + Non-uniform load by default — guarantees expert 0 receives every token + while the rest of the experts split the second slot. Use + ``nonuniform=False`` for a balanced (t%E, (t+1)%E) pattern. + """ + T_dp = TOKENS_PER_DP_SHARD * self.dp + E = self.num_experts + rng = np.random.default_rng(seed=seed) + tokens = jnp.asarray( + rng.standard_normal((T_dp, HIDDEN_DIM), dtype=np.float32) * 0.5, + dtype=jnp.bfloat16, + ) + topk_idx_np = np.empty((T_dp, TOP_K), dtype=np.int32) + if nonuniform: + assert TOP_K == 2, "non-uniform pattern assumes top_k=2" + for t in range(T_dp): + topk_idx_np[t, 0] = 0 + topk_idx_np[t, 1] = 1 + (t % (E - 1)) + else: + for t in range(T_dp): + a, b = t % E, (t + 1) % E + topk_idx_np[t, 0], topk_idx_np[t, 1] = (a, b) if a < b else (b, a) + topk_idx = jnp.asarray(topk_idx_np) + topk_weights = jnp.asarray(np.full((T_dp, TOP_K), 1.0 / TOP_K, dtype=np.float32)) + return T_dp, tokens, topk_idx, topk_weights + + # ── Individual primitives (cpp_extensions level) ────────────────────── + + def test_two_prepares_distinct_handle_ids(self): + """Two ep_prepare sites with matching (top_k, alignment) must produce + distinct handle_ids — distinct logical layers cannot share a + HandleEntry. Verified by tracing through jit so the primitive's + outer_primitive.bind path is exercised.""" + _T, topk_idx, _tokens, _w = self._make_identity_inputs() + captured: list = [] + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + + @jax.jit + def run(idx): + _tc_a, ha = ep_prepare(idx) + _tc_b, hb = ep_prepare(idx) + captured.append((ha.handle_id, hb.handle_id)) + return ha.handle_mem, hb.handle_mem + + hm_a, hm_b = run(idx_s) + hm_a.block_until_ready() + hm_b.block_until_ready() + id_a, id_b = captured[0] + self.assertNotEqual(id_a, id_b, "two ep_prepare calls returned the same handle_id") + + def test_primitive_prepare(self): + """ep_prepare returns the expected shapes and a valid handle id.""" + T_global, topk_idx, _tokens, _w = self._make_identity_inputs() + del T_global + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + + @jax.jit + def run(idx): + tc, handle = ep_prepare(idx) + return tc, handle.handle_mem + + tc, hm = run(idx_s) + tc.block_until_ready() + self.assertEqual(tc.shape, (self.dp * self.ep, NUM_LOCAL_EXPERTS)) + self.assertEqual(hm.shape[0], self.dp * self.ep) + self.assertGreater(hm.shape[1], 0) + + def _run_identity_round_trip(self, nonuniform): + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=nonuniform) + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + _tc, handle = ep_prepare(idx) + recv_t, recv_w, handle = ep_dispatch_fwd( + handle, idx, toks, w, self.recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + # Apply the weighted hadamard inline (combine FFI is unweighted). + mask = (recv_w != 0).astype(jnp.float32)[..., None] + weighted = (recv_t.astype(jnp.float32) * recv_w[..., None] * mask).astype( + recv_t.dtype + ) + weighted = jax.lax.with_sharding_constraint( + weighted, NamedSharding(self.mesh, ep_spec_3d) + ) + out = ep_combine_fwd( + handle, weighted, T_global, out_partition_spec=(("dp", "ep"), None) + ) + return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + # Allgather so the rank-0 numpy comparison sees the full global tensor. + out_global = jmu.process_allgather(out, tiled=True) + + # Identity expert + uniform weights → out ≈ tokens (rank-0 check). + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_primitive_dispatch_combine_identity_uniform(self): + """Round-robin routing → identity round-trip via the primitive layer.""" + self._run_identity_round_trip(nonuniform=False) + + def test_primitive_dispatch_combine_identity_nonuniform(self): + """Skewed routing (top1=0 always) → identity round-trip via the primitive layer.""" + self._run_identity_round_trip(nonuniform=True) + + def test_primitive_dispatch_combine_identity_bwd_uniform(self): + """Bwd through identity round-trip: ∇(0.5 ||out||²) w.r.t. tokens ≈ tokens. + + Identity routing + uniform top-k weights ⇒ dispatch∘combine is the + identity, so loss = 0.5||tokens||² and ∇_tokens loss = tokens. + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(toks): + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + recv_t, recv_w, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + out = ep_combine( + handle, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) + ) + return 0.5 * (out.astype(jnp.float32) ** 2).sum() + + grad = jax.jit(jax.grad(loss_fn))(tokens) + grad.block_until_ready() + grad_global = jmu.process_allgather(grad, tiled=True) + + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(grad_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_combine_3d_input_output(self): + """3D input ``[B, S, H]`` sharded on the first dim only — + ``(("dp","ep"), None, None)`` here — dispatch accepts the rank-3 shape + and combine returns a matching 3D ``[B, S, H]`` output. End-to-end + round trip recovers the original tokens under identity routing + + uniform top-k weights.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + # B is sharded across all (dp*ep) ranks; S held in one piece per rank. + B, S, H = T_global, 1, tokens.shape[-1] + tokens_3d = tokens.reshape(B, S, H) + topk_idx_3d = topk_idx.reshape(B, S, -1) + topk_w_3d = topk_w.reshape(B, S, -1) + spec_3d = PartitionSpec(("dp", "ep"), None, None) + out_spec_3d = (("dp", "ep"), None, None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx_3d, NamedSharding(self.mesh, spec_3d)) + tok_s = jax.lax.with_sharding_constraint(tokens_3d, NamedSharding(self.mesh, spec_3d)) + w_s = jax.lax.with_sharding_constraint(topk_w_3d, NamedSharding(self.mesh, spec_3d)) + + ep_t = PartitionSpec(("dp", "ep"), None, None) + ep_w = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + recv_t, recv_w, handle, _tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) + out = ep_combine( + handle, + _tc, + recv_t, + recv_w, + num_local_tokens=(B, S), + out_sharding=out_spec_3d, + ) + return out + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + out_global = jmu.process_allgather(out, tiled=True) + + if self.rank == 0: + self.assertEqual(out_global.shape, (B, S, H)) + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens_3d.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_combine_dp_only_first_dim(self): + """Input sharded ``("dp", None)`` (no ep on leading) — dispatch must + accept it. JAX SPMD slices the missing ep axis locally so the kernel + still sees ``T/(dp*ep)`` tokens per rank.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + dp_only = PartitionSpec("dp", None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_only)) + tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_only)) + w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_only)) + + ep_t = PartitionSpec(("dp", "ep"), None, None) + ep_w = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + recv_t, recv_w, handle, _tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) + out = ep_combine( + handle, + _tc, + recv_t, + recv_w, + num_local_tokens=T_global, + out_sharding=(("dp", "ep"), None), + ) + return out + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + out_global = jmu.process_allgather(out, tiled=True) + + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + # ── Custom-VJP tests ───────────────────────────────────────────────── + + def test_dispatch_vjp_fwd_bwd(self): + """ep_dispatch fwd + jax.grad w.r.t. tokens. + + Identity routing + loss = 0.5||recv_tokens||² ⇒ each token appears + TOP_K times in recv_tokens (all routes fit recv_capacity), so + grad_tokens = TOP_K * tokens (closed form). + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs() + del T_global + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(toks): + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + recv_tokens, _recv_w, _handle, _tc = ep_dispatch( + idx, toks, w, self.recv_capacity_per_rank + ) + recv_tokens = jax.lax.with_sharding_constraint( + recv_tokens, NamedSharding(self.mesh, ep_spec_3d) + ) + return 0.5 * (recv_tokens.astype(jnp.float32) ** 2).sum() + + loss, grad_tokens = jax.jit(jax.value_and_grad(loss_fn))(tokens) + grad_tokens.block_until_ready() + grad_global = jmu.process_allgather(grad_tokens, tiled=True) + + self.assertTrue(np.isfinite(float(loss))) + self.assertEqual(grad_tokens.shape, tokens.shape) + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(grad_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)) * float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + def test_combine_vjp_fwd_bwd(self): + """ep_combine fwd + jax.grad w.r.t. expert_out. + + Identity routing + constant eo=c + uniform topk_w ⇒ combined[t] = c + (sum_k topk_w = 1) and grad_eo[e, s, h] = recv_w[e, s] * c at filled + slots — so max|grad_eo| ≈ c / TOP_K. + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs() + eo_const = 0.5 + expert_out = jnp.full( + (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM), + eo_const, + dtype=jnp.bfloat16, + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(eo): + eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d)) + toks = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + _recv_tokens, recv_w, handle, tc = ep_dispatch( + idx, toks, w, self.recv_capacity_per_rank + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, PartitionSpec(("dp", "ep"), None)) + ) + combined = ep_combine(handle, tc, eo, recv_w, T_global) + # Pin combined to dp-sharded so autodiff transpose feeds + # ep_combine_bwd a per-shard cotangent. + combined = jax.lax.with_sharding_constraint( + combined, NamedSharding(self.mesh, dp_spec) + ) + return 0.5 * (combined.astype(jnp.float32) ** 2).sum() + + loss, grad_eo = jax.jit(jax.value_and_grad(loss_fn))(expert_out) + grad_eo.block_until_ready() + + self.assertTrue(np.isfinite(float(loss))) + self.assertEqual(grad_eo.shape, expert_out.shape) + for shard in grad_eo.addressable_shards: + arr = np.asarray(shard.data.astype(jnp.float32)) + self.assertTrue(np.all(np.isfinite(arr))) + self.assertGreater(arr.max(), 0.0, "grad_eo has no positive entry on filled slots") + np.testing.assert_allclose( + arr.max(), + eo_const / float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_bwd_exact_per_k_topk_weights(self): + """Distinct per-(t, k) upstream grads ⇒ grad[t, 0] != grad[t, 1] for all t. + + Guards against a regression where the bwd would average across the k + axis (per-token mean instead of per-slot exact recovery). + """ + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + dp_spec = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(idx_in, tok_in, w_in): + idx_in = jax.lax.with_sharding_constraint(idx_in, NamedSharding(self.mesh, dp_spec)) + tok_in = jax.lax.with_sharding_constraint(tok_in, NamedSharding(self.mesh, dp_spec)) + w_in = jax.lax.with_sharding_constraint(w_in, NamedSharding(self.mesh, dp_spec)) + _recv_t, recv_w, _h, _tc = ep_dispatch( + idx_in, tok_in, w_in, self.recv_capacity_per_rank + ) + # Per-slot index scale ⇒ each slot's contribution differs. + scale = jnp.asarray( + np.arange(recv_w.size, dtype=np.float32).reshape(recv_w.shape) + 1.0 + ) + return jnp.sum(recv_w * scale) + + grad_topk_w = jax.jit(jax.grad(loss_fn, argnums=2))(topk_idx, tokens, topk_w) + grad_topk_w.block_until_ready() + grad_global = jmu.process_allgather(grad_topk_w, tiled=True) + + if self.rank == 0: + grad_np = np.asarray(grad_global).astype(np.float32) + mismatch = sum(int(abs(grad_np[t, 0] - grad_np[t, 1]) < 1e-6) for t in range(T_dp)) + self.assertEqual( + mismatch, + 0, + f"Expected grad[t, 0] != grad[t, 1] for all {T_dp} tokens under skewed " + f"upstream scaling; got {mismatch} tokens with grad[t, 0] == grad[t, 1].", + ) + + # ── HLO reshard guard ──────────────────────────────────────────────── + # Compile-only: assert XLA inserts no cross-device collectives outside + # the EP FFI. EP-axis flux is carried by the FFI itself. + + def test_z_no_unexpected_reshard_in_hlo_fwd(self): + """Compiled fwd HLO must not insert XLA collectives outside the EP FFI.""" + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + @jax.jit + def run(idx, toks, w): + idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) + recv_t, recv_w, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + out = ep_combine( + handle, tc, recv_t, recv_w, T_dp, out_sharding=(("dp", "ep"), None) + ) + return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) + + compiled = run.lower(topk_idx, tokens, topk_w).compile() + hlo = compiled.as_text() + # Match instruction names; "all-gather-start" and "all-gather-done" + # bracket a single async all-gather. + for op in ("all-gather-start", "all-to-all", "collective-permute"): + self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in fwd HLO:\n{hlo}") + # XLA drops trailing-None entries from the spec; compare as a tuple. + # JAX collapses size-1 mesh axes, so dp=1 reduces ("dp","ep") to "ep". + expected = (("dp", "ep"),) if self.dp > 1 else ("ep",) + self.assertEqual(tuple(compiled.output_shardings.spec), expected) + + def test_z_no_unexpected_reshard_in_hlo_bwd(self): + """Compiled bwd HLO must not insert XLA collectives outside the EP FFI.""" + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + rng = np.random.default_rng(seed=44) + expert_out = jnp.asarray( + rng.standard_normal( + (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM), dtype=np.float32 + ) + * 0.5, + dtype=jnp.bfloat16, + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def fwd(eo, toks, idx, w): + eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d)) + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) + _rt, rw, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d)) + combined = ep_combine(handle, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)) + return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec)) + + # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd + # the expected sharding without relying on XLA-transpose propagation. + def bwd_only(eo, toks, idx, w, g): + _y, vjp_fn = jax.vjp(fwd, eo, toks, idx, w) + g = jax.lax.with_sharding_constraint(g, NamedSharding(self.mesh, dp_spec)) + grads = vjp_fn(g) + return ( + jax.lax.with_sharding_constraint( + grads[0], NamedSharding(self.mesh, ep_spec_3d) + ), + jax.lax.with_sharding_constraint(grads[1], NamedSharding(self.mesh, dp_spec)), + ) + + g_seed = jnp.ones((T_dp, HIDDEN_DIM), dtype=jnp.bfloat16) + compiled = ( + jax.jit(bwd_only).lower(expert_out, tokens, topk_idx, topk_w, g_seed).compile() + ) + hlo = compiled.as_text() + for op in ("all-gather-start", "all-to-all", "collective-permute"): + self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in bwd HLO:\n{hlo}") + + +# ── Entry point ────────────────────────────────────────────────────────────── + + +if __name__ == "__main__": + if len(sys.argv) < 4: + print("Usage: python test_multi_process_ep.py ") + sys.exit(1) + + coord_addr = sys.argv[1] + proc_id = int(sys.argv[2]) + num_procs = int(sys.argv[3]) + + jax.distributed.initialize( + coordinator_address=coord_addr, + num_processes=num_procs, + process_id=proc_id, + local_device_ids=[proc_id], + ) + + loader = unittest.TestLoader() + target = os.environ.get("TARGET_TEST") + if target: + name = target.split(".")[-1] + suite = loader.loadTestsFromName(name, TestEP) + else: + suite = loader.loadTestsFromTestCase(TestEP) + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + sys.exit(0 if result.wasSuccessful() else 1) diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index fe1f93dc7a..604da5e1b7 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -10,4 +10,5 @@ from .softmax import * from .gemm import * from .router import * +from .ep import * from .topk import * diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py new file mode 100644 index 0000000000..7d112ad5f4 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -0,0 +1,955 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for Expert Parallelism (EP). + +Sharding model: + - EpPrepare / EpDispatch outputs carry a single leading ``num_procs`` dim. + Sharded compound ``(dp_resource, ep_resource)`` when DP is set, else + ``ep_resource`` alone. + - EpDispatch inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``; only the first + dim may be sharded, with axis ∈ {ep, (dp, ep), dp, None}. Trailing dims + must be replicated. ``dp`` alone gets ``ep`` folded in locally. + - EpCombine output sharding comes from ``out_sharding`` or defaults to the + compound ``(dp, ep)`` axis on the leading dim. +""" + +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi +from jax.sharding import NamedSharding, PartitionSpec +import jax.tree_util as jtu + +import transformer_engine_jax +from .base import BasePrimitive, register_primitive +from ..sharding import global_mesh_resource + +__all__ = [ + "EpConfig", + "EpHandle", + "set_ep_config", + "get_ep_config", + "get_ep_num_local_experts", + "ep_allocate_handle_id", + "ep_prepare", + "ep_dispatch_fwd", + "ep_combine_fwd", + "ep_dispatch_bwd", + "ep_combine_bwd", +] + + +# Routing-state container threaded through dispatch/combine/*_bwd. +@jtu.register_pytree_node_class +class EpHandle: + def __init__(self, handle_mem, handle_id): + self.handle_mem = handle_mem + self.handle_id = int(handle_id) + + def tree_flatten(self): + return (self.handle_mem,), (self.handle_id,) + + @classmethod + def tree_unflatten(cls, aux, children): + return cls(children[0], aux[0]) + + def __repr__(self): + return f"EpHandle(handle_id={self.handle_id})" + + +# ── Module-level EP config ────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class EpConfig: + """Immutable Python view of the EP bootstrap config (see ep_bootstrap).""" + + world_size: int + rank: int + ep_size: int + num_experts: int + num_local_experts: int + max_tokens_per_rank: int + recv_capacity_per_rank: int + hidden_dim: int + + +_ep_config: EpConfig = None + + +def set_ep_config(config: EpConfig) -> None: + """Cache the EP config for abstract-eval / sharding helpers. Call once.""" + global _ep_config + _ep_config = config + + +def get_ep_config() -> EpConfig: + if _ep_config is None: + raise RuntimeError("EpConfig has not been set. Did you call ep_bootstrap()?") + return _ep_config + + +def get_ep_num_local_experts() -> int: + return get_ep_config().num_local_experts + + +# handle_id -> handle_mem buffer size in bytes. +_HANDLE_MEM_SIZE_BY_ID: dict = {} + + +def ep_allocate_handle_id(top_k: int, dispatch_output_per_expert_alignment: int = 0) -> int: + """Reserve a fresh handle_id for an EP layer. + + Distinct logical layers must each call this — sharing a handle_id across + layers corrupts the routing state, even when (top_k, alignment) match. + """ + handle_id, handle_mem_size = transformer_engine_jax.ep_register_layer( + int(top_k), int(dispatch_output_per_expert_alignment) + ) + handle_id = int(handle_id) + _HANDLE_MEM_SIZE_BY_ID[handle_id] = int(handle_mem_size) + return handle_id + + +def _ep_handle_mem_size(handle_id: int) -> int: + """Return the handle_mem byte size for an id from ep_allocate_handle_id.""" + try: + return _HANDLE_MEM_SIZE_BY_ID[int(handle_id)] + except KeyError as e: + raise RuntimeError( + f"handle_id={handle_id} not registered; call ep_allocate_handle_id first." + ) from e + + +def _leading_axis_ok(spec, ep_axis, outer_axes=()): + # Only the first dim may carry sharding; remaining dims must be replicated. + # The first dim's axis must be one of: + # ``ep_axis`` alone, + # a tuple of dp/fsdp axes (no ep — ep gets sliced in locally), + # a tuple ending in ``ep_axis`` with dp/fsdp axes before it. + # Examples on a (dp, ep) mesh: 2D ``(ep, None)``, ``(("dp","ep"), None)``, + # ``("dp", None)``; 3D ``(ep, None, None)``, ``(("dp","ep"), None, None)``, + # ``("dp", None, None)``. + if len(spec) < 2 or ep_axis is None: + return False + if any(ax is not None for ax in spec[1:]): + return False # only first dim sharded + leading = spec[0] + allowed_outers = {a for a in outer_axes if a is not None} + allowed = allowed_outers | {ep_axis, None} + elts = leading if isinstance(leading, tuple) else (leading,) + return all(a in allowed for a in elts) + + +def _canonical_input_spec(spec, ndim): + """Canonical input PartitionSpec the primitive demands JAX deliver. + + Sharding lives entirely on the first dim. If ``spec[0]`` already includes + ``ep_resource``, returned unchanged. Otherwise ``ep_resource`` is folded + into the first-dim axis tuple, e.g. ``"dp"`` → ``("dp","ep")``. The added + ep axis is a local slice (the missing dim was replicated), no cross-device + comm. + """ + gsr = global_mesh_resource() + ep = gsr.ep_resource + leading = spec[0] + present = leading if isinstance(leading, tuple) else (leading,) if leading is not None else () + if ep in present: + return PartitionSpec(*spec) + if leading is None: + new_leading = ep + elif isinstance(leading, tuple): + new_leading = (*leading, ep) + else: + new_leading = (leading, ep) + return PartitionSpec(new_leading, *([None] * (ndim - 1))) + + +def _dispatch_input_outer_axes(): + """dp/fsdp axes allowed as outer companions to ep_resource on dispatch input.""" + gsr = global_mesh_resource() + return tuple(a for a in (gsr.dp_resource, gsr.fsdp_resource) if a is not None) + + +def _ep_outer_axis(): + """The single dp/fsdp axis (if any) sitting outside ep on EP-output tensors. + + When set, EP-output globals carry an extra leading ``dp_size`` dim so SPMD + sees each DP color's slab as distinct (rather than replicated across DP). + """ + gsr = global_mesh_resource() + return gsr.dp_resource or gsr.fsdp_resource + + +def _ep_leading_dims(is_outer): + """Single leading dim of an EP-output tensor: ``(dp*ep,)`` (or ``(ep,)`` when + DP is unset) globally; ``(1,)`` per shard.""" + cfg = get_ep_config() + outer = _ep_outer_axis() + if not is_outer: + return (1,) + return (cfg.world_size,) if outer is not None else (cfg.ep_size,) + + +def _ep_output_spec(*trailing): + """PartitionSpec for an EP-output tensor: ``(("dp","ep"), *trailing)`` when + DP is set (compound leading axis on a single dim), else ``("ep",*trailing)``.""" + gsr = global_mesh_resource() + outer = _ep_outer_axis() + if outer is None: + return PartitionSpec(gsr.ep_resource, *trailing) + return PartitionSpec((outer, gsr.ep_resource), *trailing) + + +def _ep_spec_ok(spec, trailing_count): + """Accept ``(ep, *[None])`` (no DP) or ``((dp,ep), *[None])`` / + ``(("dp",), *[None])`` / ``("dp", *[None])`` / ``(None, *[None])`` (with DP) + on an EP-output tensor's single leading dim. JAX may collapse a size-1 + mesh axis to ``None`` (matters for dp_size=1 like 1x4).""" + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer = _ep_outer_axis() + expected_len = 1 + trailing_count + if len(spec) != expected_len: + return False + if any(ax is not None for ax in spec[1:]): + return False + leading = spec[0] + if outer is None: + return leading == ep_axis + allowed = {ep_axis, outer, None} + elts = leading if isinstance(leading, tuple) else (leading,) + return all(a in allowed for a in elts) + + +# ── ep_prepare ────────────────────────────────────────────────────────────── + + +class EpPreparePrimitive(BasePrimitive): + name = "te_ep_prepare_ffi" + multiple_results = True + impl_static_args = (1, 2, 3) # handle_id, dispatch_output_per_expert_alignment, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(topk_idx_aval, *, handle_id, dispatch_output_per_expert_alignment, is_outer): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + del dispatch_output_per_expert_alignment + cfg = get_ep_config() + num_local_experts = cfg.num_local_experts + assert ( + len(topk_idx_aval.shape) >= 2 + ), f"topk_idx must be at least 2D [..., top_k], got shape {topk_idx_aval.shape}" + handle_mem_size = _ep_handle_mem_size(handle_id) + leading = _ep_leading_dims(is_outer) + token_counts_aval = jax.core.ShapedArray(leading + (num_local_experts,), jnp.int32) + handle_mem_aval = jax.core.ShapedArray(leading + (handle_mem_size,), jnp.uint8) + # FFI scratch for the int32 -> int64 topk_idx upcast. int32 with last + # dim doubled to keep the int64 byte count without JAX_ENABLE_X64. + # TODO(phuong): drop once NCCL EP supports int32 topk_idx. + workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,) + workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32) + return token_counts_aval, handle_mem_aval, workspace_aval + + @staticmethod + def outer_abstract(topk_idx_aval, *, handle_id, dispatch_output_per_expert_alignment, is_outer): + del is_outer + avals = EpPreparePrimitive.abstract( + topk_idx_aval, + handle_id=handle_id, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + is_outer=True, + ) + return avals[:2] + + @staticmethod + def lowering(ctx, topk_idx, *, handle_id, dispatch_output_per_expert_alignment, is_outer): + del is_outer + return ffi.ffi_lowering(EpPreparePrimitive.name)( + ctx, + topk_idx, + handle_id=int(handle_id), + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + ) + + @staticmethod + def impl(topk_idx, handle_id, dispatch_output_per_expert_alignment, is_outer): + assert EpPreparePrimitive.inner_primitive is not None + token_counts, handle_mem, _workspace = EpPreparePrimitive.inner_primitive.bind( + topk_idx, + handle_id=handle_id, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + is_outer=is_outer, + ) + return token_counts, handle_mem + + @staticmethod + def batcher( + batched_args, batch_dims, *, handle_id, dispatch_output_per_expert_alignment, is_outer + ): + raise NotImplementedError("EpPreparePrimitive does not support vmap") + + @staticmethod + def partition( + handle_id, dispatch_output_per_expert_alignment, is_outer, mesh, arg_infos, result_infos + ): + del is_outer, result_infos + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer_axes = _dispatch_input_outer_axes() + idx_spec = arg_infos[0].sharding.spec + if not _leading_axis_ok(idx_spec, ep_axis, outer_axes): + raise NotImplementedError( + "EpPrepare: topk_idx leading dims must shard on ep_resource" + f" ('{ep_axis}') and/or {outer_axes}, with the topk dim replicated;" + f" got spec={idx_spec}." + ) + idx_ndim = len(arg_infos[0].shape) + arg_shardings = (NamedSharding(mesh, _canonical_input_spec(idx_spec, idx_ndim)),) + tc_sharding = NamedSharding(mesh, _ep_output_spec(None)) + hm_sharding = NamedSharding(mesh, _ep_output_spec(None)) + + def sharded_impl(topk_idx): + return EpPreparePrimitive.impl( + topk_idx, handle_id, dispatch_output_per_expert_alignment, False + ) + + return mesh, sharded_impl, (tc_sharding, hm_sharding), arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args + # for this primitive are (handle_id, dispatch_alignment, is_outer). + value_types = args[-2] + topk_idx_rank = len(value_types[0].shape) + in_axes = " ".join(f"L{i}" for i in range(topk_idx_rank - 1)) + " topk" + return f"{in_axes} -> EPL nle, EPL hm" + + +register_primitive(EpPreparePrimitive) + + +# ── ep_dispatch ───────────────────────────────────────────────────────────── + + +class EpDispatchPrimitive(BasePrimitive): + name = "te_ep_dispatch_ffi" + multiple_results = True + impl_static_args = (4, 5, 6, 7) # handle_id, recv_capacity_per_rank, top_k, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + topk_idx_aval, + tokens_aval, + topk_weights_aval, + *, + handle_id, + recv_capacity_per_rank, + top_k, + is_outer, + ): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + del handle_id, topk_weights_aval, top_k, handle_mem_aval + assert ( + len(tokens_aval.shape) >= 2 + ), f"tokens must be at least 2D [..., H], got shape {tokens_aval.shape}" + recv_pr = recv_capacity_per_rank + tok_dtype = dtypes.canonicalize_dtype(tokens_aval.dtype) + hidden_dim = tokens_aval.shape[-1] + leading = _ep_leading_dims(is_outer) + recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype) + recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32) + # int32 with last dim doubled to keep the int64 byte count without JAX_ENABLE_X64. + workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,) + workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32) + return (recv_tokens_aval, recv_topk_weights_aval, workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = True + avals = EpDispatchPrimitive.abstract(*args, **kwargs) + return avals[:2] + + @staticmethod + def lowering( + ctx, + handle_mem, + topk_idx, + tokens, + topk_weights, + *, + handle_id, + recv_capacity_per_rank, + top_k, + is_outer, + ): + del recv_capacity_per_rank, is_outer + return ffi.ffi_lowering(EpDispatchPrimitive.name)( + ctx, + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id=int(handle_id), + top_k=top_k, + ) + + @staticmethod + def impl( + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id, + recv_capacity_per_rank, + top_k, + is_outer, + ): + assert EpDispatchPrimitive.inner_primitive is not None + recv_tokens, recv_topk_weights, _workspace = EpDispatchPrimitive.inner_primitive.bind( + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id=handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + top_k=top_k, + is_outer=is_outer, + ) + return recv_tokens, recv_topk_weights + + @staticmethod + def batcher(batched_args, batch_dims, *, handle_id, recv_capacity_per_rank, top_k, is_outer): + raise NotImplementedError("EpDispatchPrimitive does not support vmap") + + @staticmethod + def partition( + handle_id, recv_capacity_per_rank, top_k, is_outer, mesh, arg_infos, result_infos + ): + del is_outer, result_infos + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer_axes = _dispatch_input_outer_axes() + tokens_spec = arg_infos[2].sharding.spec + if not _leading_axis_ok(tokens_spec, ep_axis, outer_axes): + raise NotImplementedError( + "EpDispatch: tokens leading dims must shard on ep_resource" + f" ('{ep_axis}') and/or {outer_axes}, hidden dim replicated;" + f" got spec={tokens_spec}." + ) + idx_spec = arg_infos[1].sharding.spec + tw_spec = arg_infos[3].sharding.spec + arg_shardings = ( + arg_infos[0].sharding, + NamedSharding(mesh, _canonical_input_spec(idx_spec, len(arg_infos[1].shape))), + NamedSharding(mesh, _canonical_input_spec(tokens_spec, len(arg_infos[2].shape))), + NamedSharding(mesh, _canonical_input_spec(tw_spec, len(arg_infos[3].shape))), + ) + out_shardings = ( + NamedSharding(mesh, _ep_output_spec(None, None)), + NamedSharding(mesh, _ep_output_spec(None)), + ) + + def sharded_impl(handle_mem, topk_idx, tokens, topk_weights): + return EpDispatchPrimitive.impl( + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id, + recv_capacity_per_rank, + top_k, + False, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args + # for this primitive are (handle_id, recv_capacity_per_rank, top_k, is_outer). + value_types = args[-2] + # Inputs: handle_mem, topk_idx, tokens, topk_weights. + idx_rank = len(value_types[1].shape) + tok_rank = len(value_types[2].shape) + tw_rank = len(value_types[3].shape) + idx_axes = " ".join(f"I{i}" for i in range(idx_rank - 1)) + " topk_in" + tok_axes = " ".join(f"T{i}" for i in range(tok_rank - 1)) + " H" + tw_axes = " ".join(f"W{i}" for i in range(tw_rank - 1)) + " topk" + return f"EPL hm, {idx_axes}, {tok_axes}, {tw_axes} -> EPL recv_pr H, EPL recv_pr" + + +register_primitive(EpDispatchPrimitive) + + +# ── ep_combine ────────────────────────────────────────────────────────────── +# `expert_out` here is the post-weight buffer; ep.ep_combine applies the +# hadamard before calling. + + +def _normalize_leading_shape(s): + return s if isinstance(s, tuple) else (int(s),) + + +def _prod(seq): + p = 1 + for x in seq: + p *= int(x) + return p + + +def _resolve_out_partition_spec(out_partition_spec, num_leading): + """Pick the combine output PartitionSpec. + + Defaults to a compound leading axis ``(dp_resource, ep_resource)`` when a + DP/FSDP axis is set on the active MeshResource, else just ``ep_resource``. + This matches the input sharding so XLA does not need collective-permutes + in the bwd path. + """ + if out_partition_spec is not None: + assert len(out_partition_spec) == num_leading + 1, ( + f"out_partition_spec length {len(out_partition_spec)} must equal num_leading" + f" + 1 ({num_leading + 1})" + ) + return tuple(out_partition_spec) + gsr = global_mesh_resource() + if gsr.ep_resource is None: + raise ValueError( + "ep_combine: ep_resource is not set on the active MeshResource;" + " pass out_sharding=... explicitly." + ) + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource + return (leading,) + (None,) * num_leading + + +def _per_shard_leading(out_leading_shape, resolved_spec, mesh): + """Per-shard leading shape given resolved partition spec and mesh.""" + per_shard = list(out_leading_shape) + for i, ax in enumerate(resolved_spec[: len(out_leading_shape)]): + if ax is None: + continue + axes = ax if isinstance(ax, tuple) else (ax,) + factor = 1 + for a in axes: + factor *= mesh.shape[a] + assert ( + per_shard[i] % factor == 0 + ), f"leading dim {per_shard[i]} not divisible by shard factor {factor} on axes {axes}" + per_shard[i] //= factor + return tuple(per_shard) + + +class EpCombinePrimitive(BasePrimitive): + name = "te_ep_combine_ffi" + multiple_results = False + impl_static_args = (2, 3, 4) # handle_id, out_leading_shape, out_partition_spec + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + expert_out_aval, + *, + handle_id, + out_leading_shape, + out_partition_spec, + ): + del handle_id, out_partition_spec, handle_mem_aval + assert ( + len(expert_out_aval.shape) == 3 + ), f"expert_out must be 3D [num_procs, recv_pr, H], got shape {expert_out_aval.shape}" + eo_dtype = dtypes.canonicalize_dtype(expert_out_aval.dtype) + hidden_dim = expert_out_aval.shape[-1] + out_shape = tuple(out_leading_shape) + (hidden_dim,) + return jax.core.ShapedArray(out_shape, eo_dtype) + + @staticmethod + def lowering( + ctx, + handle_mem, + expert_out, + *, + handle_id, + out_leading_shape, + out_partition_spec, + ): + del out_partition_spec + return ffi.ffi_lowering(EpCombinePrimitive.name)( + ctx, + handle_mem, + expert_out, + handle_id=int(handle_id), + num_local_tokens=_prod(out_leading_shape), + ) + + @staticmethod + def impl(handle_mem, expert_out, handle_id, out_leading_shape, out_partition_spec): + assert EpCombinePrimitive.inner_primitive is not None + return EpCombinePrimitive.inner_primitive.bind( + handle_mem, + expert_out, + handle_id=handle_id, + out_leading_shape=out_leading_shape, + out_partition_spec=out_partition_spec, + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, handle_id, out_leading_shape, out_partition_spec): + raise NotImplementedError("EpCombinePrimitive does not support vmap") + + @staticmethod + def partition(handle_id, out_leading_shape, out_partition_spec, mesh, arg_infos, result_infos): + del result_infos + eo_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(eo_spec, trailing_count=2): + raise NotImplementedError( + "EpCombine: expert_out must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={eo_spec}." + ) + resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) + per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh) + arg_shardings = tuple(a.sharding for a in arg_infos) + out_sharding = NamedSharding(mesh, PartitionSpec(*resolved)) + + def sharded_impl(handle_mem, expert_out): + return EpCombinePrimitive.impl( + handle_mem, expert_out, handle_id, per_shard_leading, out_partition_spec + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args: + # (handle_id, out_leading_shape, out_partition_spec). + result_types = args[-1] + out_rank = len(result_types[0].shape) + out_axes = " ".join(f"O{i}" for i in range(out_rank - 1)) + " H" + return f"EPL hm, EPL recv_pr H -> {out_axes}" + + +register_primitive(EpCombinePrimitive) + + +# ── ep_dispatch_bwd ───────────────────────────────────────────────────────── + + +class EpDispatchBwdPrimitive(BasePrimitive): + name = "te_ep_dispatch_bwd_ffi" + multiple_results = True + impl_static_args = (3, 4, 5, 6) # handle_id, top_k, out_leading_shape, out_partition_spec + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + grad_aval, + g_recv_topk_weights_aval, + *, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + del handle_id, g_recv_topk_weights_aval, out_partition_spec, handle_mem_aval + assert ( + len(grad_aval.shape) == 3 + ), f"grad must be 3D [num_procs, recv_pr, H], got shape {grad_aval.shape}" + g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype) + hidden_dim = grad_aval.shape[-1] + result_aval = jax.core.ShapedArray(tuple(out_leading_shape) + (hidden_dim,), g_dtype) + grad_topk_weights_aval = jax.core.ShapedArray( + tuple(out_leading_shape) + (top_k,), jnp.float32 + ) + return result_aval, grad_topk_weights_aval + + @staticmethod + def lowering( + ctx, + handle_mem, + grad, + g_recv_topk_weights, + *, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + del out_partition_spec + return ffi.ffi_lowering(EpDispatchBwdPrimitive.name)( + ctx, + handle_mem, + grad, + g_recv_topk_weights, + handle_id=int(handle_id), + num_local_tokens=_prod(out_leading_shape), + top_k=top_k, + ) + + @staticmethod + def impl( + handle_mem, + grad, + g_recv_topk_weights, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + assert EpDispatchBwdPrimitive.inner_primitive is not None + return EpDispatchBwdPrimitive.inner_primitive.bind( + handle_mem, + grad, + g_recv_topk_weights, + handle_id=handle_id, + top_k=top_k, + out_leading_shape=out_leading_shape, + out_partition_spec=out_partition_spec, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + raise NotImplementedError("EpDispatchBwdPrimitive does not support vmap") + + @staticmethod + def partition( + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + mesh, + arg_infos, + result_infos, + ): + del result_infos + g_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(g_spec, trailing_count=2): + raise NotImplementedError( + "EpDispatchBwd: grad must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={g_spec}." + ) + gw_spec = arg_infos[2].sharding.spec + if not _ep_spec_ok(gw_spec, trailing_count=1): + raise NotImplementedError( + "EpDispatchBwd: g_recv_topk_weights must be sharded as" + " PartitionSpec(ep_resource, None) (or ((dp, ep), None) when dp/fsdp is set)" + f" over [num_procs, recv_pr]; got spec={gw_spec}." + ) + resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) + per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh) + arg_shardings = tuple(a.sharding for a in arg_infos) + out_shardings = [ + NamedSharding(mesh, PartitionSpec(*resolved)), + NamedSharding(mesh, PartitionSpec(*resolved, None)), + ] + + def sharded_impl(handle_mem, grad, g_recv_topk_weights): + return EpDispatchBwdPrimitive.impl( + handle_mem, + grad, + g_recv_topk_weights, + handle_id, + top_k, + per_shard_leading, + out_partition_spec, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Result rank + # follows out_leading_shape (static arg #2): rank = len(out_leading) + 1. + result_types = args[-1] + out_rank = len(result_types[0].shape) + out_axes = " ".join(f"O{i}" for i in range(out_rank - 1)) + return f"EPL hm, EPL recv_pr H, EPL recv_pr -> {out_axes} H, {out_axes} k" + + +register_primitive(EpDispatchBwdPrimitive) + + +# ── ep_combine_bwd ────────────────────────────────────────────────────────── + + +class EpCombineBwdPrimitive(BasePrimitive): + name = "te_ep_combine_bwd_ffi" + multiple_results = False + impl_static_args = (2, 3, 4) # handle_id, recv_capacity_per_rank, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(handle_mem_aval, grad_aval, *, handle_id, recv_capacity_per_rank, is_outer): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + del handle_id, handle_mem_aval + assert ( + len(grad_aval.shape) >= 2 + ), f"grad must be at least 2D [..., H], got shape {grad_aval.shape}" + g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype) + hidden_dim = grad_aval.shape[-1] + leading = _ep_leading_dims(is_outer) + return jax.core.ShapedArray(leading + (recv_capacity_per_rank, hidden_dim), g_dtype) + + @staticmethod + def outer_abstract(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = True + return EpCombineBwdPrimitive.abstract(*args, **kwargs) + + @staticmethod + def lowering(ctx, handle_mem, grad, *, handle_id, recv_capacity_per_rank, is_outer): + del recv_capacity_per_rank, is_outer + return ffi.ffi_lowering(EpCombineBwdPrimitive.name)( + ctx, + handle_mem, + grad, + handle_id=int(handle_id), + ) + + @staticmethod + def impl(handle_mem, grad, handle_id, recv_capacity_per_rank, is_outer): + assert EpCombineBwdPrimitive.inner_primitive is not None + return EpCombineBwdPrimitive.inner_primitive.bind( + handle_mem, + grad, + handle_id=handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=is_outer, + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, handle_id, recv_capacity_per_rank, is_outer): + raise NotImplementedError("EpCombineBwdPrimitive does not support vmap") + + @staticmethod + def partition(handle_id, recv_capacity_per_rank, is_outer, mesh, arg_infos, result_infos): + del is_outer, result_infos + arg_shardings = tuple(a.sharding for a in arg_infos) + out_sharding = NamedSharding(mesh, _ep_output_spec(None, None)) + + def sharded_impl(handle_mem, grad): + return EpCombineBwdPrimitive.impl( + handle_mem, grad, handle_id, recv_capacity_per_rank, False + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # T axes are dynamic-rank based on the actual cotangent shape. + value_types = args[-2] + g_rank = len(value_types[1].shape) + g_axes = " ".join(f"T{i}" for i in range(g_rank - 1)) + " H" + return f"EPL hm, {g_axes} -> EPL recv_pr H" + + +register_primitive(EpCombineBwdPrimitive) + + +# ── Public-ish helpers (used by jax/ep.py) ────────────────────────────────── + + +_HANDLE_ID_CALLSITE_CACHE = {} + + +def ep_prepare(topk_idx, dispatch_output_per_expert_alignment=0): + """Exchange routing metadata; return ``(token_counts, EpHandle)``.""" + import sys as _sys + + top_k = int(topk_idx.shape[-1]) + alignment = int(dispatch_output_per_expert_alignment) + # Cache handle_id by caller (file:lineno, top_k, alignment): JAX re-traces + # the same call site (e.g. custom_vjp fwd vs primal) and the resulting + # EpHandles must share the same id to compare equal in pytree aux. + f = _sys._getframe(1) + cache_key = (f.f_code.co_filename, f.f_lineno, top_k, alignment) + handle_id = _HANDLE_ID_CALLSITE_CACHE.get(cache_key) + if handle_id is None: + handle_id = ep_allocate_handle_id(top_k, alignment) + _HANDLE_ID_CALLSITE_CACHE[cache_key] = handle_id + token_counts, handle_mem = EpPreparePrimitive.outer_primitive.bind( + topk_idx, + handle_id=handle_id, + dispatch_output_per_expert_alignment=alignment, + is_outer=True, + ) + return token_counts, EpHandle(handle_mem, handle_id) + + +def ep_dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank): + """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights, handle).""" + top_k = int(topk_weights.shape[-1]) + recv_tokens, recv_topk_weights = EpDispatchPrimitive.outer_primitive.bind( + handle.handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id=handle.handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + top_k=top_k, + is_outer=True, + ) + return recv_tokens, recv_topk_weights, handle + + +def ep_combine_fwd(handle, expert_out, num_local_tokens, out_partition_spec=None): + """Gather expert outputs back to home ranks. expert_out is pre-weighted.""" + out_leading = _normalize_leading_shape(num_local_tokens) + return EpCombinePrimitive.outer_primitive.bind( + handle.handle_mem, + expert_out, + handle_id=handle.handle_id, + out_leading_shape=out_leading, + out_partition_spec=out_partition_spec, + ) + + +def ep_dispatch_bwd( + handle, grad, g_recv_topk_weights, top_k, num_local_tokens, out_partition_spec=None +): + """Backward of dispatch; returns (grad_tokens, grad_topk_weights).""" + out_leading = _normalize_leading_shape(num_local_tokens) + return EpDispatchBwdPrimitive.outer_primitive.bind( + handle.handle_mem, + grad, + g_recv_topk_weights, + handle_id=handle.handle_id, + top_k=int(top_k), + out_leading_shape=out_leading, + out_partition_spec=out_partition_spec, + ) + + +def ep_combine_bwd(handle, grad, recv_capacity_per_rank): + """Backward of combine; returns grad_expert_out [num_procs, recv_capacity_per_rank, H].""" + return EpCombineBwdPrimitive.outer_primitive.bind( + handle.handle_mem, + grad, + handle_id=handle.handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=True, + ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2ecfedc8a2..7f7bee84a9 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -199,6 +199,25 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); +// EP bootstrap (called once per process) +void EpInitialize(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms); +// EP shutdown — registered as a Python atexit hook so it runs before +// C++ static destructors of the JAX extension and libtransformer_engine.so. +void EpShutdown(); +// Host-only: register an EP layer. Returns (handle_id, handle_mem_size) where +// handle_id is baked into each FFI op as a static int64 attribute (no D2H sync +// per op) and handle_mem_size sizes the caller's handle_mem buffer. +pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment); + +// EP FFI handlers +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpPrepareHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchBwdHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineBwdHandler); + // TopK XLA_FFI_DECLARE_HANDLER_SYMBOL(TopkHandler); pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp new file mode 100644 index 0000000000..e2c50135aa --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -0,0 +1,457 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifdef NVTE_WITH_NCCL_EP + +#include "transformer_engine/ep.h" + +#include + +#include +#include +#include + +#include "../extensions.h" +#include "common.h" +#include "transformer_engine/gemm.h" + +namespace transformer_engine { +namespace jax { + +namespace { + +// Process-lifetime owner of the EP ncclComm_t. Created from a broadcast +// ncclUniqueId during EpInitialize; destroyed by EpShutdown (registered as a +// Python atexit hook from ep.py so it runs before C++ static destructors). +class EpCommManager { + public: + static EpCommManager& get() { + static EpCommManager inst; + return inst; + } + + void init_from_uid(const uint8_t* uid_bytes, int ep_size, int rank_within_group) { + std::lock_guard lock(mutex_); + NVTE_CHECK(comm_ == nullptr, "EP comm already initialized for this process"); + ncclUniqueId uid; + std::memcpy(&uid, uid_bytes, sizeof(uid)); + NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, ep_size, uid, rank_within_group)); + } + + ncclComm_t comm() const { return comm_; } + + void shutdown() { + std::lock_guard lock(mutex_); + if (comm_ == nullptr) return; + ncclCommDestroy(comm_); + comm_ = nullptr; + } + + private: + EpCommManager() = default; + // Intentionally no NCCL teardown in the destructor: this runs at static-dtor + // time, after Python has finalized and possibly after the CUDA driver + // detaches the context. Calling ncclCommDestroy there has been observed to + // hang or report cudartUnloading. Normal teardown goes through the Python + // atexit hook (shutdown_ep_communicator) registered from ep.py; any path + // that skips that (os._exit, fatal signal) leaks the comm, which the OS + // reaps on process exit. + ~EpCommManager() = default; + EpCommManager(const EpCommManager&) = delete; + EpCommManager& operator=(const EpCommManager&) = delete; + + std::mutex mutex_; + ncclComm_t comm_{nullptr}; +}; + +} // namespace + +// handle_id is baked at jit trace time and carried as a static FFI attribute. + +struct EpPrepareConfig { + int64_t handle_id; + int64_t dispatch_output_per_expert_alignment; +}; + +struct EpDispatchConfig { + int64_t handle_id; + int64_t top_k; +}; + +struct EpCombineConfig { + int64_t handle_id; + int64_t num_local_tokens; +}; + +struct EpDispatchBwdConfig { + int64_t handle_id; + int64_t num_local_tokens; + int64_t top_k; +}; + +struct EpCombineBwdConfig { + int64_t handle_id; +}; + +// ── Bootstrap helpers ───────────────────────────────────────────────────────── + +void EpInitialize(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms) { + std::string uid_str = unique_id_bytes_obj; + NVTE_CHECK(static_cast(uid_str.size()) >= 128, + "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); + EpCommManager::get().init_from_uid(reinterpret_cast(uid_str.data()), ep_size, + rank_within_group); + NVTEEpGroupConfig cfg{.ep_size = ep_size, + .num_experts = num_experts, + .max_tokens_per_rank = max_tokens_per_rank, + .max_recv_tokens_per_rank = max_recv_tokens_per_rank, + .hidden_dim = hidden_dim, + .max_num_sms = max_num_sms}; + // If common rejects the config (validate_config / ncclEpCreateGroup), roll + // the comm back so the two singletons don't end up in inconsistent states + // and the comm doesn't strand until process exit. + try { + nvte_ep_initialize(static_cast(EpCommManager::get().comm()), cfg); + } catch (...) { + EpCommManager::get().shutdown(); + throw; + } +} + +void EpShutdown() { + // Order matters: ep_group_ in common reads from the comm, so tear it down + // first, then destroy the comm. + nvte_ep_shutdown(); + EpCommManager::get().shutdown(); +} + +pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment) { + NVTEEpLayerConfig layer_cfg{0, top_k, dispatch_output_per_expert_alignment}; + size_t handle_mem_size = 0; + uint64_t handle_id = nvte_ep_register_layer(layer_cfg, &handle_mem_size); + return pybind11::make_tuple(handle_id, handle_mem_size); +} + +// ── ep_prepare ──────────────────────────────────────────────────────────────── + +Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, + Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { + auto topk_dims = topk_idx.dimensions(); + NVTE_CHECK(topk_dims.size() >= 2, + "topk_idx must be at least 2D [..., top_k], got ndim=", topk_dims.size()); + auto idx_etype = topk_idx.element_type(); + NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32, + "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype)); + + std::vector topk_shape = {product(topk_dims, 0, topk_dims.size() - 1), + static_cast(topk_dims.back())}; + // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream. + // TODO(phuong): drop once NCCL EP accepts int32. + void* topk_idx_data = topk_idx.untyped_data(); + if (idx_etype == ::xla::ffi::DataType::S32) { + const size_t n = topk_shape[0] * topk_shape[1]; + NVTE_CHECK(static_cast(workspace->element_count()) >= n, + "workspace too small for int32 → int64 upcast: element_count=", + workspace->element_count(), " < required ", n); + int64_t* ws = reinterpret_cast(workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream); + topk_idx_data = ws; + } + auto topk_idx_ = TensorWrapper(topk_idx_data, topk_shape, DType::kInt64); + + std::vector tc_shape = {static_cast(token_counts->element_count())}; + auto token_counts_ = TensorWrapper(token_counts->untyped_data(), tc_shape, DType::kInt32); + + std::vector hm_shape = {static_cast(handle_mem->element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem->untyped_data(), hm_shape, DType::kByte); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + nvte_ep_prepare(handle, topk_idx_.data(), token_counts_.data(), + static_cast(config.dispatch_output_per_expert_alignment), stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // topk_idx + .Ret() // token_counts + .Ret() // handle_mem + .Ret() // workspace (FFI scratch) + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_dispatch ─────────────────────────────────────────────────────────────── + +Error_Type EpDispatchFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type topk_idx, + Buffer_Type tokens, Buffer_Type topk_weights, Result_Type recv_tokens, + Result_Type recv_topk_weights, Result_Type workspace, + EpDispatchConfig config) { + auto token_dims = tokens.dimensions(); + NVTE_CHECK(token_dims.size() >= 2, + "tokens must be at least 2D [..., H], got ndim=", token_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + auto idx_dims = topk_idx.dimensions(); + NVTE_CHECK(idx_dims.size() >= 2, + "topk_idx must be at least 2D [..., top_k], got ndim=", idx_dims.size()); + auto idx_etype = topk_idx.element_type(); + NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32, + "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype)); + NVTE_CHECK(static_cast(idx_dims.back()) == config.top_k, "top_k attr (", config.top_k, + ") must match topk_idx last dim (", idx_dims.back(), ")"); + std::vector idx_shape = {product(idx_dims, 0, idx_dims.size() - 1), + static_cast(idx_dims.back())}; + // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream. + // TODO(phuong): drop once NCCL EP accepts int32. + void* topk_idx_data = topk_idx.untyped_data(); + if (idx_etype == ::xla::ffi::DataType::S32) { + const size_t n = idx_shape[0] * idx_shape[1]; + NVTE_CHECK(static_cast(workspace->element_count()) >= n, + "workspace too small for int32 → int64 upcast: element_count=", + workspace->element_count(), " < required ", n); + int64_t* ws = reinterpret_cast(workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream); + topk_idx_data = ws; + } + auto topk_idx_ = TensorWrapper(topk_idx_data, idx_shape, DType::kInt64); + + const size_t T_flat = product(token_dims, 0, token_dims.size() - 1); + const size_t H = static_cast(token_dims.back()); + std::vector tok_shape = {T_flat, H}; + auto token_dtype = convert_ffi_datatype_to_te_dtype(tokens.element_type()); + auto tokens_ = TensorWrapper(tokens.untyped_data(), tok_shape, token_dtype); + + auto tw_dims = topk_weights.dimensions(); + NVTE_CHECK(tw_dims.size() >= 2, + "topk_weights must be at least 2D [..., top_k], got ndim=", tw_dims.size()); + std::vector tw_shape = {product(tw_dims, 0, tw_dims.size() - 1), + static_cast(tw_dims.back())}; + auto topk_weights_ = TensorWrapper(topk_weights.untyped_data(), tw_shape, DType::kFloat32); + + // recv_tokens: flatten any leading dims into recv_capacity_per_rank. + auto recv_dims = recv_tokens->dimensions(); + NVTE_CHECK(recv_dims.size() >= 2, + "recv_tokens must be at least 2D [..., recv_pr, H]; got ndim=", recv_dims.size()); + const size_t recv_capacity_per_rank = product(recv_dims, 0, recv_dims.size() - 1); + std::vector recv_shape = {recv_capacity_per_rank, H}; + auto recv_tokens_ = TensorWrapper(recv_tokens->untyped_data(), recv_shape, token_dtype); + + auto recv_w_dims = recv_topk_weights->dimensions(); + NVTE_CHECK(recv_w_dims.size() >= 1, + "recv_topk_weights must be at least 1D; got ndim=", recv_w_dims.size()); + const size_t recv_w_total = product(recv_w_dims, 0, recv_w_dims.size()); + NVTE_CHECK(recv_w_total == recv_capacity_per_rank, "recv_topk_weights total (", recv_w_total, + ") must match recv_tokens recv_pr (", recv_capacity_per_rank, ")"); + std::vector recv_w_shape = {recv_capacity_per_rank}; + auto recv_topk_weights_ = + TensorWrapper(recv_topk_weights->untyped_data(), recv_w_shape, DType::kFloat32); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_dispatch(handle, topk_idx_.data(), tokens_.data(), no_win, topk_weights_.data(), no_win, + recv_tokens_.data(), no_win, recv_topk_weights_.data(), no_win, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchHandler, EpDispatchFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // topk_idx + .Arg() // tokens + .Arg() // topk_weights + .Ret() // recv_tokens + .Ret() // recv_topk_weights + .Ret() // workspace (FFI scratch) + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_combine ──────────────────────────────────────────────────────────────── + +Error_Type EpCombineFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type expert_out, + Result_Type result, EpCombineConfig config) { + auto eo_dims = expert_out.dimensions(); + NVTE_CHECK(eo_dims.size() >= 2, + "expert_out must be at least 2D [..., recv_pr, H]; got ndim=", eo_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t recv_capacity_per_rank = product(eo_dims, 0, eo_dims.size() - 1); + const size_t H = static_cast(eo_dims.back()); + std::vector eo_shape = {recv_capacity_per_rank, H}; + auto eo_dtype = convert_ffi_datatype_to_te_dtype(expert_out.element_type()); + auto expert_out_ = TensorWrapper(expert_out.untyped_data(), eo_shape, eo_dtype); + + auto res_dims = result->dimensions(); + NVTE_CHECK(res_dims.size() >= 2, + "result must be at least 2D [..., H]; got ndim=", res_dims.size()); + const size_t res_T_flat = product(res_dims, 0, res_dims.size() - 1); + NVTE_CHECK(static_cast(res_T_flat) == config.num_local_tokens, + "result leading-dim product (", res_T_flat, ") must equal num_local_tokens (", + config.num_local_tokens, ")"); + std::vector res_shape = {res_T_flat, H}; + auto result_ = TensorWrapper(result->untyped_data(), res_shape, eo_dtype); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_combine(handle, expert_out_.data(), no_win, result_.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineHandler, EpCombineFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // expert_out + .Ret() // result + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_dispatch_bwd ─────────────────────────────────────────────────────────── + +Error_Type EpDispatchBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, + Buffer_Type g_recv_topk_weights, Result_Type grad_tokens, + Result_Type grad_topk_weights, EpDispatchBwdConfig config) { + auto grad_dims = grad.dimensions(); + NVTE_CHECK(grad_dims.size() >= 2, + "grad must be at least 2D [..., recv_pr, H]; got ndim=", grad_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t recv_capacity_per_rank = product(grad_dims, 0, grad_dims.size() - 1); + const size_t H = static_cast(grad_dims.back()); + std::vector g_shape = {recv_capacity_per_rank, H}; + auto g_dtype = convert_ffi_datatype_to_te_dtype(grad.element_type()); + auto grad_ = TensorWrapper(grad.untyped_data(), g_shape, g_dtype); + + auto gw_dims = g_recv_topk_weights.dimensions(); + NVTE_CHECK( + gw_dims.size() >= 1, + "g_recv_topk_weights rank must flatten to recv_capacity_per_rank; got ndim=", gw_dims.size()); + const size_t gw_total = product(gw_dims, 0, gw_dims.size()); + NVTE_CHECK(gw_total == recv_capacity_per_rank, "g_recv_topk_weights total (", gw_total, + ") must match grad recv_pr (", recv_capacity_per_rank, ")"); + std::vector gw_shape = {recv_capacity_per_rank}; + auto g_recv_topk_weights_ = + TensorWrapper(g_recv_topk_weights.untyped_data(), gw_shape, DType::kFloat32); + + auto out_dims = grad_tokens->dimensions(); + NVTE_CHECK(out_dims.size() >= 2, + "grad_tokens must be at least 2D [..., H], got ndim=", out_dims.size()); + const size_t T_flat = product(out_dims, 0, out_dims.size() - 1); + NVTE_CHECK(static_cast(T_flat) == config.num_local_tokens, + "grad_tokens leading-dim product (", T_flat, ") must equal num_local_tokens (", + config.num_local_tokens, ")"); + std::vector out_shape = {T_flat, H}; + auto grad_tokens_ = TensorWrapper(grad_tokens->untyped_data(), out_shape, g_dtype); + + auto gtw_dims = grad_topk_weights->dimensions(); + NVTE_CHECK(gtw_dims.size() >= 2, + "grad_topk_weights must be at least 2D [..., top_k]; got ndim=", gtw_dims.size()); + const size_t gtw_T_flat = product(gtw_dims, 0, gtw_dims.size() - 1); + NVTE_CHECK(gtw_T_flat == T_flat, "grad_topk_weights leading-dim product (", gtw_T_flat, + ") must equal grad_tokens leading-dim product (", T_flat, ")"); + const size_t top_k = static_cast(gtw_dims.back()); + NVTE_CHECK(static_cast(top_k) == config.top_k, "top_k attr (", config.top_k, + ") must match grad_topk_weights last dim (", top_k, ")"); + std::vector gtw_shape = {T_flat, top_k}; + auto grad_topk_weights_ = + TensorWrapper(grad_topk_weights->untyped_data(), gtw_shape, DType::kFloat32); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_dispatch_bwd(handle, grad_.data(), no_win, g_recv_topk_weights_.data(), no_win, + grad_tokens_.data(), grad_topk_weights_.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchBwdHandler, EpDispatchBwdFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // grad (w.r.t. recv_tokens) + .Arg() // g_recv_topk_weights + .Ret() // grad_tokens + .Ret() // grad_topk_weights + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_combine_bwd ──────────────────────────────────────────────────────────── + +Error_Type EpCombineBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, + Result_Type grad_expert_out, EpCombineBwdConfig config) { + auto grad_dims = grad.dimensions(); + NVTE_CHECK(grad_dims.size() >= 2, + "grad must be at least 2D [..., H], got ndim=", grad_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t T_flat = product(grad_dims, 0, grad_dims.size() - 1); + const size_t H = static_cast(grad_dims.back()); + std::vector g_shape = {T_flat, H}; + auto g_dtype = convert_ffi_datatype_to_te_dtype(grad.element_type()); + auto grad_ = TensorWrapper(grad.untyped_data(), g_shape, g_dtype); + + auto out_dims = grad_expert_out->dimensions(); + NVTE_CHECK(out_dims.size() >= 2, + "grad_expert_out must be at least 2D [..., recv_pr, H]; got ndim=", out_dims.size()); + const size_t recv_capacity_per_rank = product(out_dims, 0, out_dims.size() - 1); + const size_t out_H = static_cast(out_dims.back()); + NVTE_CHECK(out_H == H, "grad_expert_out hidden dim (", out_H, ") must match grad H (", H, ")"); + std::vector out_shape = {recv_capacity_per_rank, H}; + auto grad_expert_out_ = TensorWrapper(grad_expert_out->untyped_data(), out_shape, g_dtype); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_combine_bwd(handle, grad_.data(), no_win, grad_expert_out_.data(), no_win, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineBwdHandler, EpCombineBwdFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // grad (w.r.t. result) + .Ret() // grad_expert_out + .Attrs(), + FFI_CudaGraph_Traits); + +} // namespace jax +} // namespace transformer_engine + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::EpPrepareConfig, ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("dispatch_output_per_expert_alignment")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpDispatchConfig, + ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("top_k")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpCombineConfig, + ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("num_local_tokens")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpDispatchBwdConfig, + ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("num_local_tokens"), + ::xla::ffi::StructMember("top_k")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpCombineBwdConfig, + ::xla::ffi::StructMember("handle_id")); + +#endif // NVTE_WITH_NCCL_EP diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 70d0403b3e..b34f8739ee 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -101,6 +101,15 @@ pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); +#ifdef NVTE_WITH_NCCL_EP + // Expert Parallelism + dict["te_ep_prepare_ffi"] = EncapsulateFFI(EpPrepareHandler); + dict["te_ep_dispatch_ffi"] = EncapsulateFFI(EpDispatchHandler); + dict["te_ep_combine_ffi"] = EncapsulateFFI(EpCombineHandler); + dict["te_ep_dispatch_bwd_ffi"] = EncapsulateFFI(EpDispatchBwdHandler); + dict["te_ep_combine_bwd_ffi"] = EncapsulateFFI(EpCombineBwdHandler); +#endif // NVTE_WITH_NCCL_EP + // TopK dict["te_topk_ffi"] = EncapsulateFFI(TopkHandler); @@ -127,6 +136,15 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); m.def("get_grouped_gemm_setup_workspace_size", &nvte_get_grouped_gemm_setup_workspace_size); +#ifdef NVTE_WITH_NCCL_EP + m.def("initialize_ep_communicator", &EpInitialize, pybind11::arg("unique_id_bytes"), + pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), + pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), + pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0); + m.def("shutdown_ep_communicator", &EpShutdown); + m.def("ep_register_layer", &EpRegisterLayer, pybind11::arg("top_k"), + pybind11::arg("dispatch_output_per_expert_alignment") = 0); +#endif // NVTE_WITH_NCCL_EP pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py new file mode 100644 index 0000000000..40d07bc3d4 --- /dev/null +++ b/transformer_engine/jax/ep.py @@ -0,0 +1,303 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX Expert Parallelism (EP) API.""" + +import atexit +import ctypes +from functools import partial + +import jax +import jax.numpy as jnp +import jax.experimental.multihost_utils as jmu +import numpy as np + +import transformer_engine_jax +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax.cpp_extensions.ep import EpHandle +from transformer_engine.jax.sharding import global_mesh_resource, get_mesh_axis_size + +ep_prepare = tex.ep_prepare + +__all__ = [ + "EpHandle", + "ep_bootstrap", + "ep_prepare", + "ep_dispatch", + "ep_combine", +] + +_atexit_registered = False + + +# ── Bootstrap ──────────────────────────────────────────────────────────────── + + +def ep_bootstrap( + world_size, + rank, + ep_size, + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + max_num_sms=0, +): + """Initialize the EP communicator. Call once per process before any EP op. + + max_num_sms caps the SMs allotted to EP kernels (0 = auto). + """ + if world_size < 2: + raise ValueError( + f"ep_bootstrap requires world_size >= 2 (got {world_size}); NCCL EP needs" + " at least 2 ranks to form a group." + ) + if world_size % ep_size != 0: + raise ValueError( + f"world_size ({world_size}) must be divisible by ep_size ({ep_size}); otherwise" + " some EP groups would have fewer than ep_size ranks and ncclCommInitRank would hang." + ) + if num_experts % ep_size != 0: + raise ValueError(f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size}).") + if jax.local_device_count() != 1: + raise ValueError( + "ep_bootstrap requires one local device per process (got" + f" jax.local_device_count() = {jax.local_device_count()}); NCCL EP does not" + " support single-process multi-device setups." + ) + UID_SIZE = 128 + dp_color = rank // ep_size + rank_within_group = rank % ep_size + is_color_root = rank_within_group == 0 + if is_color_root: + try: + from nccl import get_unique_id + + uid_bytes = bytes(get_unique_id())[:UID_SIZE] + except ImportError: + libnccl = ctypes.CDLL("libnccl.so.2", use_errno=True) + uid_arr = (ctypes.c_uint8 * UID_SIZE)() + ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p)) + assert ret == 0, f"ncclGetUniqueId failed with code {ret}" + uid_bytes = bytes(uid_arr) + else: + uid_bytes = bytes(UID_SIZE) + + uid_arr = jnp.frombuffer(uid_bytes, dtype=jnp.uint8) + all_uids = jmu.process_allgather(uid_arr).reshape(world_size, UID_SIZE) + uid_bytes = bytes(np.asarray(all_uids[dp_color * ep_size]).tolist()) + + ep_resource = global_mesh_resource().ep_resource + if ep_resource is None: + raise ValueError( + "ep_bootstrap requires MeshResource.ep_resource to be set; enter a" + " global_shard_guard(MeshResource(..., ep_resource=)) before bootstrap." + ) + mesh_ep_size = get_mesh_axis_size(ep_resource) + if mesh_ep_size != ep_size: + raise ValueError( + f"ep_bootstrap: EpConfig.ep_size ({ep_size}) does not match mesh axis" + f" '{ep_resource}' size ({mesh_ep_size})." + ) + + transformer_engine_jax.initialize_ep_communicator( + uid_bytes, + ep_size, + rank_within_group, + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + max_num_sms=int(max_num_sms), + ) + + # Shutdown ordering: + # - Python atexit is LIFO. ep_bootstrap runs jmu.process_allgather first, + # which assumes jax.distributed.initialize() ran earlier, so JAX's + # distributed atexit hooks are already registered before this one. Ours + # therefore fires first at exit — fine, because EpShutdown only touches + # NCCL (ncclEpGroupDestroy + ncclCommDestroy) and does not depend on + # JAX's coordination service. Do not add JAX calls to EpShutdown. + # - Running before C++ static destructors avoids the cudartUnloading + # hazard; the C++ destructors are intentionally no-ops. + global _atexit_registered + if not _atexit_registered: + atexit.register(transformer_engine_jax.shutdown_ep_communicator) + _atexit_registered = True + + tex.ep.set_ep_config( + tex.ep.EpConfig( + world_size=world_size, + rank=rank, + ep_size=ep_size, + num_experts=num_experts, + num_local_experts=num_experts // ep_size, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=hidden_dim, + ) + ) + + +# ── ep_dispatch (custom_vjp) ───────────────────────────────────────────────── + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4)) +def ep_dispatch( + topk_idx, + tokens, + topk_weights, + recv_capacity_per_rank, + dispatch_output_per_expert_alignment=0, +): + """Scatter tokens and weights to expert ranks. + + Inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``. Only the leading dim may + be sharded — axis ∈ {ep, (dp, ep), dp, None}; trailing dims replicated. + + Args: + topk_idx: ``[..., top_k]`` int32/int64 routing indices. + tokens: ``[..., H]`` activations (matching leading dims). + topk_weights: ``[..., top_k]`` float32 routing weights. + recv_capacity_per_rank: STATIC int. Per-rank recv slot count. + dispatch_output_per_expert_alignment: STATIC int. Per-expert slot + alignment; 0 disables. + + Returns: + ``(recv_tokens, recv_topk_weights, handle, token_counts)`` where + ``recv_tokens`` is 3D ``[num_procs, recv_capacity_per_rank, H]`` + sharded ``(("dp","ep"), None, None)`` (or ``("ep", None, None)`` if + DP is unset), and ``recv_topk_weights`` is 2D + ``[num_procs, recv_capacity_per_rank]`` similarly sharded. Pass + ``handle`` to the matching ``ep_combine``. + """ + return _dispatch_fwd( + topk_idx, + tokens, + topk_weights, + recv_capacity_per_rank, + dispatch_output_per_expert_alignment, + )[0] + + +def _dispatch_fwd( + topk_idx, + tokens, + topk_weights, + recv_capacity_per_rank, + dispatch_output_per_expert_alignment, +): + top_k = int(topk_weights.shape[-1]) + token_counts, handle = tex.ep_prepare(topk_idx, dispatch_output_per_expert_alignment) + recv_tokens, recv_topk_weights, handle = tex.ep_dispatch_fwd( + handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank + ) + out_leading = tuple(tokens.shape[:-1]) + primal = (recv_tokens, recv_topk_weights, handle, token_counts) + return primal, (handle, out_leading, top_k) + + +def _dispatch_bwd(recv_capacity_per_rank, dispatch_output_per_expert_alignment, res, g_outputs): + del recv_capacity_per_rank, dispatch_output_per_expert_alignment + handle, out_leading, top_k = res + # Re-pin cotangent sharding: XLA transpose can drop the EP axis on a + # single-fwd-output cotangent, landing a global tensor in the FFI. + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, ep_axis) if outer is not None else ep_axis + g_recv_tokens = jax.lax.with_sharding_constraint( + g_outputs[0], jax.sharding.PartitionSpec(leading, None, None) + ) + g_recv_topk_weights = jax.lax.with_sharding_constraint( + g_outputs[1], jax.sharding.PartitionSpec(leading, None) + ) + grad_tokens, grad_topk_weights = tex.ep_dispatch_bwd( + handle, g_recv_tokens, g_recv_topk_weights, top_k, out_leading + ) + return (None, grad_tokens, grad_topk_weights) + + +ep_dispatch.defvjp(_dispatch_fwd, _dispatch_bwd) + + +# ── ep_combine (custom_vjp) ────────────────────────────────────────────────── + + +@partial(jax.custom_vjp, nondiff_argnums=(4, 5)) +def ep_combine( + handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding=None +): + """Reduce weighted expert outputs back to source ranks. + + Args: + handle: ``EpHandle`` from a matching ``ep_dispatch`` call. + token_counts: ``[num_procs, num_local_experts]`` int32 (passed through). + expert_out: ``[num_procs, recv_capacity_per_rank, H]`` post-FFN activations. + recv_topk_weights: ``[num_procs, recv_capacity_per_rank]`` float32 weights + returned by ``ep_dispatch``. + num_local_tokens: STATIC int or tuple. int → 2D output ``[T, H]``; + tuple → N-D output ``[*tuple, H]``. + out_sharding: STATIC optional ``PartitionSpec`` tuple for the + output. Defaults to ``(("dp","ep"), *None)`` when + DP is set, else ``("ep", *None)``. Pass a custom + spec to override; only the leading dim may be + sharded. + + Returns: + ``[..., H]`` combined output shaped per ``num_local_tokens``. + """ + return _combine_fwd( + handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding + )[0] + + +def _make_valid_mask(recv_topk_weights, dtype): + # recv_topk_weights == 0 marks a padded slot. + return (recv_topk_weights != 0).astype(dtype)[..., None] + + +def _combine_fwd( + handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding +): + del token_counts + w = recv_topk_weights[..., None] + mask = _make_valid_mask(recv_topk_weights, jnp.float32) + weighted = (expert_out.astype(jnp.float32) * w * mask).astype(expert_out.dtype) + result = tex.ep_combine_fwd(handle, weighted, num_local_tokens, out_partition_spec=out_sharding) + return result, (handle, recv_topk_weights, expert_out) + + +def _combine_bwd(_num_local_tokens, _out_sharding, res, g_result): + handle, recv_topk_weights, expert_out = res + # expert_out is [..., recv_pr, H]; pull recv_pr from the second-to-last dim. + recv_capacity_per_rank = expert_out.shape[-2] + # Re-pin cotangent sharding: same XLA-transpose workaround as _dispatch_bwd. + gsr = global_mesh_resource() + if _out_sharding is not None: + spec = jax.sharding.PartitionSpec(*_out_sharding) + else: + ep_axis = gsr.ep_resource + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, ep_axis) if outer is not None and ep_axis is not None else ep_axis + spec = ( + jax.sharding.PartitionSpec(leading, *([None] * (g_result.ndim - 1))) + if leading is not None + else None + ) + if spec is not None: + g_result = jax.lax.with_sharding_constraint(g_result, spec) + grad_weighted = tex.ep_combine_bwd(handle, g_result, recv_capacity_per_rank) + w = recv_topk_weights[..., None] + mask = _make_valid_mask(recv_topk_weights, jnp.float32) + grad_weighted_f32 = grad_weighted.astype(jnp.float32) + grad_expert_out = (grad_weighted_f32 * w * mask).astype(grad_weighted.dtype) + grad_recv_topk_weights = ( + (grad_weighted_f32 * expert_out.astype(jnp.float32) * mask) + .sum(axis=-1) + .astype(recv_topk_weights.dtype) + ) + return (None, None, grad_expert_out, grad_recv_topk_weights) + + +ep_combine.defvjp(_combine_fwd, _combine_bwd) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 9b13412c14..863a530c8b 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -332,6 +332,12 @@ class MeshResource: fsdp_resource: Axis name for full-sharded data parallelism, default is None pp_resource: Axis name for pipeline parallelism (layer sharding), default is None cp_resource: Axis name for context parallelism (sequence sharding), default is None + ep_resource: Axis name for expert parallelism. Dispatch input tokens + must be sharded on their leading dim by ``ep_resource`` (alone or + compound with ``dp_resource`` / ``fsdp_resource`` as outer, e.g. + ``PartitionSpec(("dp", "ep"), None, None)``). Dispatch output + ``[ep_size, recv_capacity, H]`` is always sharded by ``ep_resource`` + on the leading ``ep_size`` dim. """ dp_resource: str = None @@ -340,6 +346,7 @@ class MeshResource: fsdp_resource: str = None pp_resource: str = None cp_resource: str = None + ep_resource: str = None _GLOBAL_MESH_RESOURCE = None @@ -441,3 +448,8 @@ def dp_or_fsdp_axis_size(): dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource) fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource) return dp_size if dp_size > 1 else fsdp_size + + +def ep_axis_size(): + """Get the size of the dispatch/EP axis (ep_resource). Returns 1 if unset.""" + return get_mesh_axis_size(global_mesh_resource().ep_resource) From 8cb8de46efc1c2b1edac5f4dba346d29da21bd0b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Sat, 23 May 2026 00:31:54 +0000 Subject: [PATCH 3/3] JAX EP: tie NCCL comm lifetime to JAX executables via XLA stateful FFI Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/base.py | 11 + transformer_engine/jax/csrc/extensions.h | 21 +- transformer_engine/jax/csrc/extensions/ep.cpp | 273 +++++++++++------- .../jax/csrc/extensions/pybind.cpp | 28 +- transformer_engine/jax/ep.py | 15 +- 5 files changed, 222 insertions(+), 126 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 6eb588c849..2cdef4bfe7 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -266,6 +266,17 @@ def _gspmd_wrapper(*args, **kwargs): for _name, _value in transformer_engine_jax.registrations().items(): ffi.register_ffi_target(_name, _value, platform="CUDA") +# Register EpInstanceState (no-op when TE is built without NCCL EP). +if hasattr(transformer_engine_jax, "get_ep_instance_state_type_id"): + ffi.register_ffi_type( + "EpInstanceState", + { + "type_id": transformer_engine_jax.get_ep_instance_state_type_id(), + "type_info": transformer_engine_jax.get_ep_instance_state_type_info(), + }, + platform="CUDA", + ) + def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): """ diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 7f7bee84a9..c2987aeef4 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -199,19 +199,20 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); -// EP bootstrap (called once per process) -void EpInitialize(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, - int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms); -// EP shutdown — registered as a Python atexit hook so it runs before -// C++ static destructors of the JAX extension and libtransformer_engine.so. -void EpShutdown(); -// Host-only: register an EP layer. Returns (handle_id, handle_mem_size) where -// handle_id is baked into each FFI op as a static int64 attribute (no D2H sync -// per op) and handle_mem_size sizes the caller's handle_mem buffer. +// Bootstrap EP (eager NCCL comm init); anchor released by ReleaseEpResources. +void SetEpBootstrapParams(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms); +void ReleaseEpResources(); +// Register an EP layer; returns (handle_id, handle_mem_size). pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment); +// EpInstanceState type_id / type_info capsules for jax.ffi.register_ffi_type. +pybind11::capsule GetEpInstanceStateTypeIdCapsule(); +pybind11::capsule GetEpInstanceStateTypeInfoCapsule(); + // EP FFI handlers +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpInstantiateHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(EpPrepareHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineHandler); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index e2c50135aa..5dc05de0ae 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -10,8 +10,10 @@ #include +#include #include #include +#include #include #include "../extensions.h" @@ -21,52 +23,85 @@ namespace transformer_engine { namespace jax { -namespace { +// NCCL comm + EPBackend lifetime tracks live JAX executables via XLA stateful FFI. + +struct EpBootstrapParams { + std::array uid_bytes{}; + int ep_size = 0; + int rank_within_group = 0; + int num_experts = 0; + int max_tokens_per_rank = 0; + int max_recv_tokens_per_rank = 0; + int hidden_dim = 0; + int max_num_sms = 0; +}; -// Process-lifetime owner of the EP ncclComm_t. Created from a broadcast -// ncclUniqueId during EpInitialize; destroyed by EpShutdown (registered as a -// Python atexit hook from ep.py so it runs before C++ static destructors). -class EpCommManager { +class EpResources { public: - static EpCommManager& get() { - static EpCommManager inst; - return inst; - } - - void init_from_uid(const uint8_t* uid_bytes, int ep_size, int rank_within_group) { - std::lock_guard lock(mutex_); - NVTE_CHECK(comm_ == nullptr, "EP comm already initialized for this process"); + explicit EpResources(const EpBootstrapParams& p) { ncclUniqueId uid; - std::memcpy(&uid, uid_bytes, sizeof(uid)); - NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, ep_size, uid, rank_within_group)); + std::memcpy(&uid, p.uid_bytes.data(), sizeof(uid)); + NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, p.ep_size, uid, p.rank_within_group)); + NVTEEpGroupConfig cfg{.ep_size = p.ep_size, + .num_experts = p.num_experts, + .max_tokens_per_rank = p.max_tokens_per_rank, + .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, + .hidden_dim = p.hidden_dim, + .max_num_sms = p.max_num_sms}; + try { + nvte_ep_initialize(static_cast(comm_), cfg); + } catch (...) { + ncclCommDestroy(comm_); + comm_ = nullptr; + throw; + } } - ncclComm_t comm() const { return comm_; } - - void shutdown() { - std::lock_guard lock(mutex_); + ~EpResources() { if (comm_ == nullptr) return; + nvte_ep_shutdown(); ncclCommDestroy(comm_); - comm_ = nullptr; } + EpResources(const EpResources&) = delete; + EpResources& operator=(const EpResources&) = delete; + + ncclComm_t comm() const { return comm_; } + private: - EpCommManager() = default; - // Intentionally no NCCL teardown in the destructor: this runs at static-dtor - // time, after Python has finalized and possibly after the CUDA driver - // detaches the context. Calling ncclCommDestroy there has been observed to - // hang or report cudartUnloading. Normal teardown goes through the Python - // atexit hook (shutdown_ep_communicator) registered from ep.py; any path - // that skips that (os._exit, fatal signal) leaks the comm, which the OS - // reaps on process exit. - ~EpCommManager() = default; - EpCommManager(const EpCommManager&) = delete; - EpCommManager& operator=(const EpCommManager&) = delete; - - std::mutex mutex_; ncclComm_t comm_{nullptr}; }; +struct EpInstanceState { + static ::xla::ffi::TypeId id; + static ::xla::ffi::TypeInfo info; + std::shared_ptr resources; +}; + +::xla::ffi::TypeId EpInstanceState::id = {}; +::xla::ffi::TypeInfo EpInstanceState::info = ::xla::ffi::MakeTypeInfo(); + +namespace { + +std::mutex g_ep_mu; +EpBootstrapParams g_ep_params; +bool g_ep_params_set = false; +std::weak_ptr g_ep_resources_weak; +// Python-held anchor so trace-time ep_register_layer finds EPBackend ready. +std::shared_ptr g_ep_resources_anchor; + +std::shared_ptr AcquireEpResources() { + std::lock_guard lock(g_ep_mu); + NVTE_CHECK(g_ep_params_set, + "EP bootstrap params not set; call transformer_engine_jax." + "set_ep_bootstrap_params() (typically via ep_bootstrap) first."); + auto sp = g_ep_resources_weak.lock(); + if (sp) return sp; + sp = std::make_shared(g_ep_params); + g_ep_resources_weak = sp; + return sp; +} + } // namespace // handle_id is baked at jit trace time and carried as a static FFI attribute. @@ -98,36 +133,44 @@ struct EpCombineBwdConfig { // ── Bootstrap helpers ───────────────────────────────────────────────────────── -void EpInitialize(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, - int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms) { +// Caches uid + group config and eagerly creates the NCCL comm (ranks +// synchronize via the UID broadcast). +void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms) { std::string uid_str = unique_id_bytes_obj; NVTE_CHECK(static_cast(uid_str.size()) >= 128, "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); - EpCommManager::get().init_from_uid(reinterpret_cast(uid_str.data()), ep_size, - rank_within_group); - NVTEEpGroupConfig cfg{.ep_size = ep_size, - .num_experts = num_experts, - .max_tokens_per_rank = max_tokens_per_rank, - .max_recv_tokens_per_rank = max_recv_tokens_per_rank, - .hidden_dim = hidden_dim, - .max_num_sms = max_num_sms}; - // If common rejects the config (validate_config / ncclEpCreateGroup), roll - // the comm back so the two singletons don't end up in inconsistent states - // and the comm doesn't strand until process exit. - try { - nvte_ep_initialize(static_cast(EpCommManager::get().comm()), cfg); - } catch (...) { - EpCommManager::get().shutdown(); - throw; + std::shared_ptr anchor; + { + std::lock_guard lock(g_ep_mu); + NVTE_CHECK(!g_ep_resources_anchor, + "EP bootstrap already initialized; call release_ep_resources() before re-init."); + std::memcpy(g_ep_params.uid_bytes.data(), uid_str.data(), 128); + g_ep_params.ep_size = ep_size; + g_ep_params.rank_within_group = rank_within_group; + g_ep_params.num_experts = num_experts; + g_ep_params.max_tokens_per_rank = max_tokens_per_rank; + g_ep_params.max_recv_tokens_per_rank = max_recv_tokens_per_rank; + g_ep_params.hidden_dim = hidden_dim; + g_ep_params.max_num_sms = max_num_sms; + g_ep_params_set = true; } + // Acquire outside the lock: EpResources ctor runs ncclCommInitRank which is + // a collective and may block on peer ranks. + anchor = AcquireEpResources(); + std::lock_guard lock(g_ep_mu); + g_ep_resources_anchor = std::move(anchor); } -void EpShutdown() { - // Order matters: ep_group_ in common reads from the comm, so tear it down - // first, then destroy the comm. - nvte_ep_shutdown(); - EpCommManager::get().shutdown(); +// Drops the anchor; comm tears down once the last executable also releases. +void ReleaseEpResources() { + std::shared_ptr to_drop; + { + std::lock_guard lock(g_ep_mu); + to_drop = std::move(g_ep_resources_anchor); + } + // to_drop dtor runs outside the lock. } pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment) { @@ -137,10 +180,35 @@ pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_ali return pybind11::make_tuple(handle_id, handle_mem_size); } +pybind11::capsule GetEpInstanceStateTypeIdCapsule() { + return pybind11::capsule(static_cast(&EpInstanceState::id), "xla.ffi.type_id"); +} + +pybind11::capsule GetEpInstanceStateTypeInfoCapsule() { + return pybind11::capsule(static_cast(&EpInstanceState::info), "xla.ffi.type_info"); +} + +// ── Instantiate handler ───────────────────────────────────────────────────── + +static ::xla::ffi::ErrorOr> EpInstantiateImpl() { + auto state = std::make_unique(); + try { + state->resources = AcquireEpResources(); + } catch (const std::exception& e) { + return ::xla::ffi::Unexpected( + ::xla::ffi::Error::Internal(std::string("EP instantiate failed: ") + e.what())); + } + return state; +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpInstantiateHandler, EpInstantiateImpl, FFI::BindInstantiate()); + // ── ep_prepare ──────────────────────────────────────────────────────────────── -Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, - Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { +Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type topk_idx, + Result_Type token_counts, Result_Type handle_mem, Result_Type workspace, + EpPrepareConfig config) { + (void)ep_state; // lifetime only. auto topk_dims = topk_idx.dimensions(); NVTE_CHECK(topk_dims.size() >= 2, "topk_idx must be at least 2D [..., top_k], got ndim=", topk_dims.size()); @@ -178,20 +246,22 @@ Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type t XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, FFI::Bind() - .Ctx() // stream - .Arg() // topk_idx - .Ret() // token_counts - .Ret() // handle_mem - .Ret() // workspace (FFI scratch) + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // topk_idx + .Ret() // token_counts + .Ret() // handle_mem + .Ret() // workspace (FFI scratch) .Attrs(), FFI_CudaGraph_Traits); // ── ep_dispatch ─────────────────────────────────────────────────────────────── -Error_Type EpDispatchFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type topk_idx, - Buffer_Type tokens, Buffer_Type topk_weights, Result_Type recv_tokens, - Result_Type recv_topk_weights, Result_Type workspace, - EpDispatchConfig config) { +Error_Type EpDispatchFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type topk_idx, Buffer_Type tokens, Buffer_Type topk_weights, + Result_Type recv_tokens, Result_Type recv_topk_weights, + Result_Type workspace, EpDispatchConfig config) { + (void)ep_state; auto token_dims = tokens.dimensions(); NVTE_CHECK(token_dims.size() >= 2, "tokens must be at least 2D [..., H], got ndim=", token_dims.size()); @@ -264,21 +334,23 @@ Error_Type EpDispatchFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Typ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchHandler, EpDispatchFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // topk_idx - .Arg() // tokens - .Arg() // topk_weights - .Ret() // recv_tokens - .Ret() // recv_topk_weights - .Ret() // workspace (FFI scratch) + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // topk_idx + .Arg() // tokens + .Arg() // topk_weights + .Ret() // recv_tokens + .Ret() // recv_topk_weights + .Ret() // workspace (FFI scratch) .Attrs(), FFI_CudaGraph_Traits); // ── ep_combine ──────────────────────────────────────────────────────────────── -Error_Type EpCombineFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type expert_out, - Result_Type result, EpCombineConfig config) { +Error_Type EpCombineFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type expert_out, Result_Type result, EpCombineConfig config) { + (void)ep_state; auto eo_dims = expert_out.dimensions(); NVTE_CHECK(eo_dims.size() >= 2, "expert_out must be at least 2D [..., recv_pr, H]; got ndim=", eo_dims.size()); @@ -311,18 +383,21 @@ Error_Type EpCombineFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineHandler, EpCombineFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // expert_out - .Ret() // result + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // expert_out + .Ret() // result .Attrs(), FFI_CudaGraph_Traits); // ── ep_dispatch_bwd ─────────────────────────────────────────────────────────── -Error_Type EpDispatchBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, - Buffer_Type g_recv_topk_weights, Result_Type grad_tokens, - Result_Type grad_topk_weights, EpDispatchBwdConfig config) { +Error_Type EpDispatchBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type grad, Buffer_Type g_recv_topk_weights, + Result_Type grad_tokens, Result_Type grad_topk_weights, + EpDispatchBwdConfig config) { + (void)ep_state; auto grad_dims = grad.dimensions(); NVTE_CHECK(grad_dims.size() >= 2, "grad must be at least 2D [..., recv_pr, H]; got ndim=", grad_dims.size()); @@ -380,19 +455,22 @@ Error_Type EpDispatchBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchBwdHandler, EpDispatchBwdFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // grad (w.r.t. recv_tokens) - .Arg() // g_recv_topk_weights - .Ret() // grad_tokens - .Ret() // grad_topk_weights + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // grad (w.r.t. recv_tokens) + .Arg() // g_recv_topk_weights + .Ret() // grad_tokens + .Ret() // grad_topk_weights .Attrs(), FFI_CudaGraph_Traits); // ── ep_combine_bwd ──────────────────────────────────────────────────────────── -Error_Type EpCombineBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, - Result_Type grad_expert_out, EpCombineBwdConfig config) { +Error_Type EpCombineBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type grad, Result_Type grad_expert_out, + EpCombineBwdConfig config) { + (void)ep_state; auto grad_dims = grad.dimensions(); NVTE_CHECK(grad_dims.size() >= 2, "grad must be at least 2D [..., H], got ndim=", grad_dims.size()); @@ -424,10 +502,11 @@ Error_Type EpCombineBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_T XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineBwdHandler, EpCombineBwdFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // grad (w.r.t. result) - .Ret() // grad_expert_out + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // grad (w.r.t. result) + .Ret() // grad_expert_out .Attrs(), FFI_CudaGraph_Traits); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index b34f8739ee..0304f37691 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -102,12 +102,22 @@ pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); #ifdef NVTE_WITH_NCCL_EP - // Expert Parallelism - dict["te_ep_prepare_ffi"] = EncapsulateFFI(EpPrepareHandler); - dict["te_ep_dispatch_ffi"] = EncapsulateFFI(EpDispatchHandler); - dict["te_ep_combine_ffi"] = EncapsulateFFI(EpCombineHandler); - dict["te_ep_dispatch_bwd_ffi"] = EncapsulateFFI(EpDispatchBwdHandler); - dict["te_ep_combine_bwd_ffi"] = EncapsulateFFI(EpCombineBwdHandler); + // Expert Parallelism (instantiate handler pins NCCL comm to executable lifetime). + dict["te_ep_prepare_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpPrepareHandler)); + dict["te_ep_dispatch_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpDispatchHandler)); + dict["te_ep_combine_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpCombineHandler)); + dict["te_ep_dispatch_bwd_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpDispatchBwdHandler)); + dict["te_ep_combine_bwd_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpCombineBwdHandler)); #endif // NVTE_WITH_NCCL_EP // TopK @@ -137,13 +147,15 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); m.def("get_grouped_gemm_setup_workspace_size", &nvte_get_grouped_gemm_setup_workspace_size); #ifdef NVTE_WITH_NCCL_EP - m.def("initialize_ep_communicator", &EpInitialize, pybind11::arg("unique_id_bytes"), + m.def("set_ep_bootstrap_params", &SetEpBootstrapParams, pybind11::arg("unique_id_bytes"), pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0); - m.def("shutdown_ep_communicator", &EpShutdown); + m.def("release_ep_resources", &ReleaseEpResources); m.def("ep_register_layer", &EpRegisterLayer, pybind11::arg("top_k"), pybind11::arg("dispatch_output_per_expert_alignment") = 0); + m.def("get_ep_instance_state_type_id", &GetEpInstanceStateTypeIdCapsule); + m.def("get_ep_instance_state_type_info", &GetEpInstanceStateTypeInfoCapsule); #endif // NVTE_WITH_NCCL_EP pybind11::enum_(m, "DType", pybind11::module_local()) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 40d07bc3d4..d2850defaf 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -100,7 +100,8 @@ def ep_bootstrap( f" '{ep_resource}' size ({mesh_ep_size})." ) - transformer_engine_jax.initialize_ep_communicator( + # Eager NCCL init while ranks are barrier-synced by the UID broadcast above. + transformer_engine_jax.set_ep_bootstrap_params( uid_bytes, ep_size, rank_within_group, @@ -111,18 +112,10 @@ def ep_bootstrap( max_num_sms=int(max_num_sms), ) - # Shutdown ordering: - # - Python atexit is LIFO. ep_bootstrap runs jmu.process_allgather first, - # which assumes jax.distributed.initialize() ran earlier, so JAX's - # distributed atexit hooks are already registered before this one. Ours - # therefore fires first at exit — fine, because EpShutdown only touches - # NCCL (ncclEpGroupDestroy + ncclCommDestroy) and does not depend on - # JAX's coordination service. Do not add JAX calls to EpShutdown. - # - Running before C++ static destructors avoids the cudartUnloading - # hazard; the C++ destructors are intentionally no-ops. + # Release the C++ anchor at interpreter shutdown so RAII can tear down NCCL. global _atexit_registered if not _atexit_registered: - atexit.register(transformer_engine_jax.shutdown_ep_communicator) + atexit.register(transformer_engine_jax.release_ep_resources) _atexit_registered = True tex.ep.set_ep_config(