Expert Parallelism: common C API + NCCL EP backend #3034

Open
phu0ngng wants to merge 3 commits into NVIDIA:main from phu0ngng:phuong/ep-2-commwindow

@phu0ngng
Collaborator

Summary

First PR in the TE Expert Parallelism (EP) series. Lands the common C API and NCCL EP backend that later framework PRs (PyTorch, JAX) build on. No Python bindings yet — common-lib foundation plus build wiring only. Build/load works on any arch; SM and NCCL version gates fire at runtime.

Every network-bound payload tensor takes an optional NVTECommWindow. When a window is provided, the backend uses NCCL EP's symmetric-memory zero-copy path, which skips the device-to-device (D2D) memcpy from the user buffers into the symmetric staging buffers.

Implementation

Public C API (transformer_engine/common/include/transformer_engine/{ep.h,comm_window.h})

Types: NVTEEpGroupConfig, NVTEEpLayerConfig, NVTEEpHandle, NVTECommWindow (side-band {ncclWindow_t window, size_t offset}; NCCL peer handles are not carried on NVTETensor).

Lifecycle (host-only, eager):

void     nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config);
void     nvte_ep_shutdown(void);

uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size);
  • nvte_ep_initialize — borrow an external ncclComm_t for the EP sub-group and init the singleton backend.

  • nvte_ep_shutdown — tear down the backend; idempotent; does not destroy ep_comm.

  • nvte_ep_register_layer — reserve a handle_id for a layer config and report the handle_mem buffer size the caller must allocate. The pair {id, mem} becomes the per-step NVTEEpHandle.

Per-step (allocation-free, CUDA-graph capturable):

void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts,
                     size_t dispatch_output_per_expert_alignment, cudaStream_t stream);

void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx,
                      NVTETensor tokens, NVTECommWindow tokens_win,
                      NVTETensor topk_weights, NVTECommWindow topk_weights_win,
                      NVTETensor recv_tokens, NVTECommWindow recv_tokens_win,
                      NVTETensor recv_topk_weights,  NVTECommWindow recv_topk_weights_win,
                      cudaStream_t stream);

void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win,
                     NVTETensor result, cudaStream_t stream);

void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win,
                          NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win,
                          NVTETensor grad_tokens, NVTETensor grad_topk_weights, cudaStream_t stream);

void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win,
                         NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win,
                         cudaStream_t stream);
  • nvte_ep_prepare — all-gather the routing map and write routing maps to handle.mem.
  • nvte_ep_dispatch — scatter tokens and routing weights from source ranks to expert ranks. tokens, topk_weights, recv_tokens, recv_topk_weights each accept an optional symm-mem window for zero-copy.
  • nvte_ep_combine — scatter-sum expert outputs back to source ranks (unweighted; caller pre-multiplies by recv_topk_weights). expert_out accepts a window.
  • nvte_ep_dispatch_bwd — backward of dispatch; routes token and weight grads back to source. grad and g_recv_topk_weights accept windows; the gathered outputs (grad_tokens, grad_topk_weights) do not.
  • nvte_ep_combine_bwd — backward of combine; grad and grad_expert_out accept windows. Padded slots in grad_expert_out are zeroed.
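
To make the "unweighted combine" contract above concrete, here is a minimal single-process sketch in Python. It does not call NCCL or any nvte_ep_* function; toy_dispatch and toy_combine are hypothetical stand-ins that model only the arithmetic, under the assumption that dispatch delivers one copy of each token (plus its routing weight) to every selected expert and that combine is a plain sum per source token.

```python
# Toy, single-process model of dispatch -> expert compute -> combine.
# toy_dispatch / toy_combine are hypothetical names, not part of the TE API.

def toy_dispatch(tokens, topk_idx, topk_weights):
    """Return {expert_id: [(source_index, token, weight), ...]}."""
    recv = {}
    for src, (tok, experts, weights) in enumerate(zip(tokens, topk_idx, topk_weights)):
        for expert, weight in zip(experts, weights):
            recv.setdefault(expert, []).append((src, tok, weight))
    return recv


def toy_combine(num_tokens, expert_out):
    """Unweighted scatter-sum: add each expert output back into its source slot."""
    result = [0.0] * num_tokens
    for src, value in expert_out:
        result[src] += value
    return result


tokens = [1.0, 2.0]                        # scalar "tokens" for brevity
topk_idx = [[0, 1], [1, 2]]                # each token is routed to two experts
topk_weights = [[0.75, 0.25], [0.5, 0.5]]
expert_scale = {0: 10.0, 1: 100.0, 2: 1000.0}  # expert e computes x * scale[e]

recv = toy_dispatch(tokens, topk_idx, topk_weights)

# Caller-side step: multiply by the received routing weight BEFORE combine,
# because combine itself applies no weights.
expert_out = [
    (src, expert_scale[expert] * tok * weight)
    for expert, items in recv.items()
    for (src, tok, weight) in items
]

result = toy_combine(len(tokens), expert_out)
print(result)
```

If the caller skipped the pre-multiplication in this model, toy_combine would return the plain sum of expert outputs, which is the behavior the nvte_ep_combine bullet warns about.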

Backend + build

  • NCCL EP backend (transformer_engine/common/ep/): EPBackend singleton, HT-mode dispatch/combine over NCCL EP (libnccl_ep.so), group/layer registration. Internal helper make_payload_tensor() builds the per-call ncclEpTensor_t: when the caller's NVTECommWindow.window != nullptr it sets win_hdl + win_offset (zero-copy); otherwise it sets data from nvte_tensor_data(t) (HBM fallback).
  • Runtime gates (in EPBackend::initialize): SM>=90 (via cudaDeviceGetAttribute), NCCL>=2.30.4 (via ncclGetVersion), CUDA multicast/NVLS support.
  • Stub path: when NVTE_WITH_NCCL_EP=OFF, ep/ep_api_stub.cpp provides throwing nvte_ep_* stubs so framework bindings link unconditionally; failure surfaces at first nvte_ep_initialize.
  • Build wiring
    • setup.py builds libnccl_ep.so from 3rdparty/nccl by default and auto-disables NCCL EP when no requested CUDA arch is >= 90. Explicitly setting NVTE_BUILD_WITH_NCCL_EP=1 when all archs are < 90 is treated as a user error; set NVTE_BUILD_WITH_NCCL_EP=0 to opt out.
    • NCCL_HOME resolved dynamically: explicit env → /opt/nvidia/nccl, /usr/local/nccl, /usr → ldconfig -p fallback.
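
The build-time arch gate described in the build wiring bullets can be sketched as a small decision function. This is an illustration only, not the code in setup.py: resolve_nccl_ep_build is a hypothetical name, archs are assumed to be plain integers (e.g. 80, 90), and the environment value is assumed to be the string "0" or "1" when set.

```python
def resolve_nccl_ep_build(requested_archs, env_value=None):
    """Decide whether to build NCCL EP.

    requested_archs: iterable of integer SM versions (assumed format).
    env_value: value of NVTE_BUILD_WITH_NCCL_EP, or None when unset.
    """
    has_sm90_plus = any(arch >= 90 for arch in requested_archs)

    if env_value == "0":
        # Explicit opt-out always wins.
        return False
    if env_value == "1":
        if not has_sm90_plus:
            # Explicit opt-in without any arch >= 90 is a user error.
            raise ValueError(
                "NVTE_BUILD_WITH_NCCL_EP=1 requires at least one CUDA arch >= 90; "
                "set NVTE_BUILD_WITH_NCCL_EP=0 to opt out."
            )
        return True
    # Default: build NCCL EP only when some requested arch is >= 90.
    return has_sm90_plus


default_hopper = resolve_nccl_ep_build([80, 90])
default_ampere = resolve_nccl_ep_build([70, 80])
opted_out = resolve_nccl_ep_build([90], env_value="0")

try:
    resolve_nccl_ep_build([80], env_value="1")
    explicit_error = None
except ValueError as exc:
    explicit_error = str(exc)
```

Raising on the explicit opt-in case, rather than silently disabling, keeps a misconfigured build from producing a library that only fails later at nvte_ep_initialize.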

Testing

  • C++ distributed tests under tests/cpp_distributed/.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng phu0ngng requested a review from ptrendx as a code owner May 22, 2026 02:42
@greptile-apps
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR lands the foundational Expert Parallelism (EP) layer for TransformerEngine: a common C API (ep.h, comm_window.h) and a NCCL EP backend singleton that handles group/layer registration and the full forward/backward dispatch-combine cycle. No Python bindings are included yet; the change is intended as the base that PyTorch and JAX PRs will build on.

  • New C API (nvte_ep_initialize / nvte_ep_shutdown / nvte_ep_register_layer + per-step ops): thin wrappers over an EPBackend Meyers singleton that owns the ncclEpGroup_t and a handle-id cache; all per-step ops are allocation-free and designed to be CUDA-graph-capturable.
  • Build wiring: setup.py adds _discover_nccl_home / build_nccl_ep_submodule to drive the 3rdparty/nccl submodule build; auto-disables NCCL EP when no arch ≥ 90 is targeted; stub path (ep_api_stub.cpp) provides throwing symbols when NCCL EP is off.
  • Tests: new tests/cpp_distributed/ suite with a multi-process bash harness that spawns one process per GPU and exchanges ncclUniqueId via a shared temp file.
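
As a rough illustration of the file-based ncclUniqueId exchange mentioned in the Tests bullet, the sketch below shows one common pattern in Python: rank 0 publishes the id with an atomic rename and the other ranks poll for it. The actual harness is C++ plus bash and may work differently; publish_uid, wait_for_uid, the polling interval, and the fake id bytes are all assumptions. The 128-byte size corresponds to NCCL_UNIQUE_ID_BYTES in nccl.h.

```python
import os
import tempfile
import time

UID_BYTES = 128  # NCCL_UNIQUE_ID_BYTES in nccl.h


def publish_uid(path, uid):
    """Rank 0: write the id to a temp file, then rename it into place."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        f.write(uid)
        f.flush()
        os.fsync(f.fileno())
    # os.replace is atomic on the same filesystem, so readers never see
    # a partially written file at `path`.
    os.replace(tmp_path, path)


def wait_for_uid(path, timeout_s=30.0, poll_s=0.05):
    """Other ranks: poll until the complete id is available or time out."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with open(path, "rb") as f:
                data = f.read()
            if len(data) == UID_BYTES:
                return data
        except FileNotFoundError:
            pass
        time.sleep(poll_s)
    raise TimeoutError(f"no unique id at {path} after {timeout_s} s")


# Single-process demo with a fake id; a real run would obtain the bytes
# from ncclGetUniqueId on rank 0.
fake_uid = bytes(range(UID_BYTES))
with tempfile.TemporaryDirectory() as tmp_dir:
    uid_path = os.path.join(tmp_dir, "nccl_uid.bin")
    publish_uid(uid_path, fake_uid)
    received = wait_for_uid(uid_path, timeout_s=1.0)
```

For this pattern to work, every rank must agree on the same path, which is why a rank-specific default path would leave non-zero ranks waiting forever.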

Confidence Score: 4/5

Safe to merge with one build issue addressed: the public comm_window.h header pulls in <nccl.h> and exposes ncclWindow_t, which causes the stub (off) build path to fail compilation on systems with NCCL < 2.30.

The comm_window.h public header includes <nccl.h> and uses ncclWindow_t. When NVTE_WITH_NCCL_EP=OFF, CMake adds no NCCL EP include dirs, yet ep_api_stub.cpp transitively includes that header. On a machine with NCCL 2.18 (no ncclWindow_t), the stub build — the fallback for pre-EP systems — fails to compile.

transformer_engine/common/include/transformer_engine/comm_window.h and transformer_engine/common/ep/ep_api_stub.cpp need attention: the public header's unconditional ncclWindow_t dependency breaks the stub build path on pre-2.30 NCCL systems.
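
The NCCL version thresholds discussed here (2.30.4 for the runtime gate, 2.18 as the failing example) are compared as integer version codes at runtime. The Python sketch below mirrors the encoding of the NCCL_VERSION(X, Y, Z) macro in nccl.h, which is also the form ncclGetVersion reports; meets_ep_gate is a hypothetical helper, not a function in this PR.

```python
def nccl_version_code(major, minor, patch):
    """Mirror of the NCCL_VERSION(X, Y, Z) macro from nccl.h."""
    if major <= 2 and minor <= 8:
        # Legacy encoding used up to NCCL 2.8.
        return major * 1000 + minor * 100 + patch
    return major * 10000 + minor * 100 + patch


# Minimum version named in the PR description for the runtime gate.
MIN_NCCL_EP_CODE = nccl_version_code(2, 30, 4)


def meets_ep_gate(runtime_code):
    """Hypothetical check equivalent to `NCCL >= 2.30.4`."""
    return runtime_code >= MIN_NCCL_EP_CODE


code_2_18 = nccl_version_code(2, 18, 0)
code_2_30_4 = nccl_version_code(2, 30, 4)
print(MIN_NCCL_EP_CODE, meets_ep_gate(code_2_18), meets_ep_gate(code_2_30_4))
```

A runtime check of this kind only protects execution. It does not help the compile-time problem raised above, where ncclWindow_t is missing from older headers.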

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/include/transformer_engine/comm_window.h | New public C header exposing NVTECommWindow. Unconditionally includes <nccl.h> and uses ncclWindow_t, breaking stub builds on pre-NCCL-2.30 systems, the exact scenario the stub path is meant to serve. |
| transformer_engine/common/ep/ep_backend.cpp | Core EP singleton backend: group creation, layer registration, and all per-step ops (prepare/dispatch/combine and their backwards). Mutex-protected operations hold the lock across NCCL EP calls. The already-reported max_token_bytes hardcoding and ncclEpHandleConfig_t init asymmetry are notable concerns. |
| transformer_engine/common/ep/ep_api_stub.cpp | Throwing stubs for NVTE_WITH_NCCL_EP=OFF builds. Compiles correctly only if the system NCCL headers include ncclWindow_t (NCCL >= 2.30), which may not hold on older-NCCL systems where this stub path is intended to be used. |
| transformer_engine/common/include/transformer_engine/ep.h | New public C API header for Expert Parallelism: lifecycle, registration, and per-step ops. Well-documented with clear in/out annotations; inherits the nccl.h exposure issue from comm_window.h. |
| setup.py | Adds NCCL EP detection, arch-gating, and build orchestration. _discover_nccl_home and build_nccl_ep_submodule are well-structured. libnccl_ep.so rebuild is skipped if the file already exists, which won't detect submodule updates. |
| transformer_engine/common/CMakeLists.txt | Adds NCCL EP CMake wiring: header/lib discovery, rpath embed, and conditional stub vs. real backend source selection. Includes a runtime-diagnosed NCCL version log. Looks correct. |
| tests/cpp_distributed/test_ep_common.h | Shared test infrastructure: process bootstrap, RAII tensor/buffer helpers, and uid-file-based ncclUniqueId exchange. Default uid path is rank-specific (deadlocks without --uid-file), but run_test_ep.sh always provides the flag. |
| tests/cpp_distributed/run_test_ep.sh | Multi-process test harness: spawns one process per GPU, exchanges a shared UID file, collects logs, and enforces per-rank timeouts. SM < 90 skip logic is correct. |
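
The setup.py row notes that skipping the libnccl_ep.so rebuild whenever the file exists will not detect submodule updates. One possible way to address that, sketched here in Python as an assumption rather than as the PR's approach, is to store the submodule revision in a stamp file next to the library and rebuild when it differs. needs_rebuild, record_build, the stamp file name, and the revision strings are all hypothetical; a real build could obtain the revision with something like `git -C 3rdparty/nccl rev-parse HEAD`.

```python
import tempfile
from pathlib import Path


def needs_rebuild(lib_path, stamp_path, current_rev):
    """Rebuild if the library or stamp is missing, or the revision changed."""
    lib, stamp = Path(lib_path), Path(stamp_path)
    if not lib.exists() or not stamp.exists():
        return True
    return stamp.read_text().strip() != current_rev


def record_build(stamp_path, current_rev):
    """Remember which submodule revision the existing library was built from."""
    Path(stamp_path).write_text(current_rev + "\n")


with tempfile.TemporaryDirectory() as tmp_dir:
    lib = Path(tmp_dir) / "libnccl_ep.so"
    stamp = Path(tmp_dir) / "libnccl_ep.rev"  # hypothetical stamp file

    before_first_build = needs_rebuild(lib, stamp, "abc123")

    # Simulate a successful build at revision abc123.
    lib.write_bytes(b"")
    record_build(stamp, "abc123")

    same_rev = needs_rebuild(lib, stamp, "abc123")
    after_submodule_update = needs_rebuild(lib, stamp, "def456")
```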

Sequence Diagram

sequenceDiagram
    participant Caller
    participant C_API as nvte_ep_* (ep_api.cpp)
    participant Backend as EPBackend singleton
    participant NCCL_EP as ncclEp* (libnccl_ep.so)

    Caller->>C_API: nvte_ep_initialize(ep_comm, group_config)
    C_API->>Backend: EPBackend::initialize()
    Backend->>NCCL_EP: ncclEpCreateGroup()

    Caller->>C_API: "nvte_ep_register_layer(layer_config, &mem_size)"
    C_API->>Backend: register_layer()
    Backend->>NCCL_EP: ncclEpHandleMemSize()
    Backend-->>Caller: handle_id + required mem_size

    Note over Caller: Allocates handle_mem buffer

    loop Per training step
        Caller->>C_API: nvte_ep_prepare(handle, topk_idx, token_counts, stream)
        C_API->>Backend: prepare() → ncclEpUpdateHandle()
        Backend->>NCCL_EP: ncclEpUpdateHandle (AllGather routing map)

        Caller->>C_API: nvte_ep_dispatch(handle, tokens, [win], weights, [win], stream)
        C_API->>Backend: dispatch() → ncclEpDispatch()
        Backend->>NCCL_EP: ncclEpDispatch (scatter tokens to expert ranks)

        Note over Caller: Expert computation on recv_tokens

        Caller->>C_API: nvte_ep_combine(handle, expert_out, [win], result, stream)
        C_API->>Backend: combine() → ncclEpCombine()
        Backend->>NCCL_EP: ncclEpCombine (scatter-sum back to source ranks)
    end

    Caller->>C_API: nvte_ep_shutdown()
    C_API->>Backend: EPBackend::shutdown()
    Backend->>NCCL_EP: ncclEpHandleDestroy + ncclEpGroupDestroy

Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread transformer_engine/common/ep/ep_backend.cpp Outdated
Comment thread transformer_engine/common/ep/ep_backend.cpp
Comment thread setup.py Outdated
Comment on lines +162 to +164
    env_home = os.environ.get("NCCL_HOME")
    if env_home and (Path(env_home) / "include" / "nccl.h").exists():
        return env_home

P2 NCCL_HOME set to a wrong path is silently ignored

If a user sets NCCL_HOME to an incorrect prefix that doesn't contain include/nccl.h, the function falls through to the system probe list without any warning. The function should warn when NCCL_HOME is set but doesn't resolve to a valid NCCL install.

Suggested change
-    env_home = os.environ.get("NCCL_HOME")
-    if env_home and (Path(env_home) / "include" / "nccl.h").exists():
-        return env_home
+    env_home = os.environ.get("NCCL_HOME")
+    if env_home:
+        if (Path(env_home) / "include" / "nccl.h").exists():
+            return env_home
+        print(
+            f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but "
+            f"'{env_home}/include/nccl.h' was not found; falling back to system probes."
+        )
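
For context, a self-contained version of the discovery logic with the suggested warning might look like the sketch below. It is not the PR's _discover_nccl_home: discover_nccl_home_sketch is a hypothetical name, the environment and probe list are passed in as parameters so the behavior can be exercised without touching the real system, the ldconfig -p fallback is omitted, and warnings.warn is used instead of print so the message can be captured.

```python
import tempfile
import warnings
from pathlib import Path


def discover_nccl_home_sketch(env, probes):
    """Return the first prefix that contains include/nccl.h, else None."""

    def is_valid(prefix):
        return (Path(prefix) / "include" / "nccl.h").exists()

    env_home = env.get("NCCL_HOME")
    if env_home:
        if is_valid(env_home):
            return env_home
        # NCCL_HOME was set but is unusable: say so instead of silently
        # falling through to the system probes.
        warnings.warn(
            f"NCCL_HOME='{env_home}' is set but '{env_home}/include/nccl.h' "
            "was not found; falling back to system probes."
        )
    for prefix in probes:
        if is_valid(prefix):
            return prefix
    return None


with tempfile.TemporaryDirectory() as tmp_dir:
    good = str(Path(tmp_dir) / "good")
    bad = str(Path(tmp_dir) / "bad")
    (Path(good) / "include").mkdir(parents=True)
    (Path(good) / "include" / "nccl.h").write_text("/* placeholder header */\n")

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        found_with_bad_env = discover_nccl_home_sketch({"NCCL_HOME": bad}, [good])
    warning_count = len(caught)

    found_with_good_env = discover_nccl_home_sketch({"NCCL_HOME": good}, [])
    found_nowhere = discover_nccl_home_sketch({}, [bad])
```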

Comment thread setup.py Outdated
cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT;
cfg.num_experts = static_cast<unsigned int>(group_config.num_experts);
cfg.max_dispatch_tokens_per_rank = static_cast<unsigned int>(group_config.max_tokens_per_rank);
cfg.max_token_bytes = static_cast<unsigned int>(group_config.hidden_dim * sizeof(nv_bfloat16));

P1 max_token_bytes hardcoded to sizeof(nv_bfloat16) breaks float32 dispatch

cfg.max_token_bytes is computed as hidden_dim * sizeof(nv_bfloat16) (2 bytes), but nvte_dtype_to_nccl supports float32, float16, int32, int64, float8, etc. When a caller creates the EP group with this config and later dispatches float32 tokens (via nvte_ep_dispatch), the pre-allocated max_token_bytes is half the required size. NCCL EP uses this value to size internal staging buffers at group creation; dispatching a wider dtype silently overruns those buffers or triggers an internal NCCL error. NVTEEpGroupConfig needs a dtype (or max_token_element_bytes) field so callers can declare the maximum element width they will use.
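
The size mismatch can be checked with simple arithmetic. The Python sketch below uses a hypothetical hidden_dim of 4096 and a hypothetical helper max_token_bytes that sizes the buffer by the widest dtype the caller intends to dispatch; neither is part of the PR.

```python
# Element widths in bytes for a few dtypes mentioned in the comment above.
ELEMENT_BYTES = {"bfloat16": 2, "float16": 2, "float32": 4}

hidden_dim = 4096  # hypothetical value, chosen only for illustration

# What the current code reserves: hidden_dim * sizeof(nv_bfloat16).
reserved_bytes = hidden_dim * ELEMENT_BYTES["bfloat16"]

# What a float32 token of the same hidden_dim actually needs.
float32_bytes = hidden_dim * ELEMENT_BYTES["float32"]


def max_token_bytes(hidden_dim, dtypes):
    """Hypothetical sizing rule: reserve room for the widest declared dtype."""
    return hidden_dim * max(ELEMENT_BYTES[d] for d in dtypes)


sized_for_both = max_token_bytes(hidden_dim, ["bfloat16", "float32"])
shortfall = float32_bytes - reserved_bytes
print(reserved_bytes, float32_bytes, shortfall)
```

With these numbers the bfloat16-based reservation covers exactly half of a float32 token, which matches the "half the required size" statement in the comment.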

Collaborator Author

Note for myself: Need to expose this option for users to set in ep_bootstrap.

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng phu0ngng force-pushed the phuong/ep-2-commwindow branch from 099857f to 17e5126 Compare May 22, 2026 23:07
phu0ngng and others added 2 commits May 23, 2026 19:36
…em_reloc gating

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>