[JAX] Expert Parallelism: JAX primitives + VJPs#3036

Open
phu0ngng wants to merge 2 commits into NVIDIA:main from phu0ngng:phuong/ep-3-jax

Collaborator

@phu0ngng phu0ngng commented May 22, 2026

Summary

Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the nvte_ep_* C API, a Python wrapper with custom_vjp for autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCL ncclEpDispatch/ncclEpCombine are exposed as XLA primitives and work with CUDA-graph capture.

Implementation

Public Python API (transformer_engine/jax/ep.py)

from transformer_engine.jax.ep import (
    EpHandle,        # opaque (id, handle_mem) pair from ep_prepare
    ep_bootstrap,    # one-shot per-process: init NCCL comm + nvte_ep_initialize
    ep_dispatch,     # custom_vjp-wrapped dispatch
    ep_combine,      # custom_vjp-wrapped combine
)

ep_dispatch / ep_combine are jax.custom_vjp functions: forward is the FFI primitive, backward calls the matching nvte_ep_*_bwd FFI primitive directly (no ep_prepare in the bwd — routing state is already cached in handle.mem). Note that ep_dispatch also calls ep_prepare in the forward path, which all-gathers and preprocesses routing maps.
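The custom_vjp arrangement described above can be sketched in plain JAX. This is only an illustration under stated assumptions, not the PR's code: `make_toy_dispatch` is a hypothetical name, a row gather stands in for the FFI primitive, and the routing permutation plays the role of the state cached in the handle.

```python
import jax
import jax.numpy as jnp


def make_toy_dispatch(perm):
    """Build a custom_vjp 'dispatch' for a fixed routing permutation.

    Hypothetical stand-in: a row gather replaces the EP FFI primitive.
    """

    @jax.custom_vjp
    def dispatch(tokens):
        return tokens[perm]

    def dispatch_fwd(tokens):
        # Forward pass: run the op and save the routing state as a residual.
        return tokens[perm], perm

    def dispatch_bwd(saved_perm, grad_out):
        # Backward pass: reuse the saved routing; nothing is re-prepared.
        grad_tokens = jnp.zeros_like(grad_out).at[saved_perm].set(grad_out)
        return (grad_tokens,)

    dispatch.defvjp(dispatch_fwd, dispatch_bwd)
    return dispatch


perm = jnp.array([2, 0, 1])
tokens = jnp.arange(6.0).reshape(3, 2)
weights = jnp.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
dispatch = make_toy_dispatch(perm)

out = dispatch(tokens)
# Gradient of a weighted sum: each output row's weight flows back to its source row.
grads = jax.grad(lambda t: jnp.sum(dispatch(t) * weights))(tokens)
```

The backward function scatters the incoming gradient through the saved permutation, which mirrors the idea that the backward pass needs only the routing state saved in the forward pass.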

XLA FFI bindings (transformer_engine/jax/csrc/extensions/ep.cpp)

Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries — EpPrepareHandler, EpDispatchHandler, EpCombineHandler, EpDispatchBwdHandler, EpCombineBwdHandler — each calling the corresponding nvte_ep_* C entry point. All marked FFI_CudaGraph_Traits so they capture cleanly. handle_id is a static FFI attribute baked at jit trace time.

Primitives + Python layer (transformer_engine/jax/cpp_extensions/ep.py, +951 lines)

Standard TE primitive plumbing: abstract_eval (shape/dtype inference), lowering, impl, outer_primitive registration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).

Sharding (transformer_engine/jax/sharding.py, +12 lines)

Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.

Build wiring (build_tools/jax.py, +41 lines)

Threads NCCL EP linkage through the JAX transformer_engine_jax extension. No new top-level build flags; rides on the parent PR's NVTE_BUILD_WITH_NCCL_EP.

Tests & example

  • tests/jax/test_multi_process_ep.py (+690 lines): 13 tests covering bootstrap, ep_prepare shape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing), custom_vjp fwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).
  • tests/jax/multi_process_launch_ep.sh: 4-rank launcher; sets XLA_FLAGS to keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).
  • examples/jax/ep/ep_moe.py (+394 lines) + run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison --check that verifies fwd+bwd vs a single-process reference.
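The "dispatch/combine identity" property that the tests exercise can be illustrated with a single-process toy in pure Python. This is a sketch under assumptions, not the NCCL EP algorithm: `toy_dispatch` and `toy_combine` are hypothetical names, tokens are scalars, experts apply the identity function, and each token's top-k weights sum to 1.

```python
def toy_dispatch(tokens, topk_idx, num_experts):
    """Route each token to its top-k experts; returns one bucket per expert."""
    buckets = [[] for _ in range(num_experts)]
    for token_id, experts in enumerate(topk_idx):
        for slot, expert in enumerate(experts):
            buckets[expert].append((token_id, slot, tokens[token_id]))
    return buckets


def toy_combine(buckets, topk_weights, num_tokens):
    """Weighted sum of expert outputs back into the original token order."""
    out = [0.0] * num_tokens
    for bucket in buckets:
        for token_id, slot, value in bucket:
            out[token_id] += topk_weights[token_id][slot] * value
    return out


tokens = [1.0, 2.0, 3.0, 4.0]
weights = [[0.5, 0.5], [0.25, 0.75], [0.5, 0.5], [0.75, 0.25]]

# Uniform routing: tokens spread evenly over 4 experts.
uniform_idx = [[0, 1], [2, 3], [0, 1], [2, 3]]
# Skewed routing: every token lands on experts 0 and 1 only.
skewed_idx = [[0, 1], [0, 1], [1, 0], [0, 1]]

uniform_out = toy_combine(toy_dispatch(tokens, uniform_idx, 4), weights, 4)
skewed_out = toy_combine(toy_dispatch(tokens, skewed_idx, 4), weights, 4)
skewed_counts = [len(b) for b in toy_dispatch(tokens, skewed_idx, 4)]
```

With identity experts and normalized weights, the round trip returns the input tokens under both routing patterns, even though the per-expert token counts differ sharply in the skewed case.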

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

phu0ngng added 2 commits May 22, 2026 02:03
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…s, MoE example)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR lands the JAX bindings for the Expert Parallelism (EP) feature: XLA FFI handlers over the nvte_ep_* C API, jax.custom_vjp-wrapped ep_dispatch/ep_combine with full backward passes, mesh-aware SPMD sharding rules, build wiring for the NCCL EP submodule, a 13-test multi-process suite, and an end-to-end MoE example.

  • JAX primitives (cpp_extensions/ep.py, csrc/extensions/ep.cpp): Five FFI handlers (EpPrepare, EpDispatch, EpCombine, EpDispatchBwd, EpCombineBwd) all marked FFI_CudaGraph_Traits; abstract-eval, lowering, per-primitive sharding rules, and Shardy rules are provided. The VJP math in ep.py correctly pre-weights the expert output before ep_combine_fwd and unweights in _combine_bwd.
  • Build wiring (build_tools/jax.py): Threads nccl and nccl_ep library linkage into the JAX pybind11 extension; NCCL EP defaults to enabled but the arch-guard logic diverges from setup.py's graceful auto-disable, producing hard failures or mismatched binaries in several common environments.
  • Common C++ backend (ep_backend.cpp, ep_api_stub.cpp): Singleton EPBackend with SM≥90/NVLS checks at init time; throwing stubs for all nvte_ep_* symbols when built without EP.

Confidence Score: 3/5

The build tooling has a gap between how setup.py and build_tools/jax.py decide whether to enable NCCL EP — a fresh clone on non-Hopper hardware will have setup.py silently skip EP but jax.py hard-fail, and an unset NVTE_CUDA_ARCHS with the submodule present produces a JAX extension compiled with EP while the common library is compiled without it.

The JAX primitive logic, VJP math, and C++ FFI handlers look sound, and the stub file provides clean fallback symbols. However, the EP enable/disable decision is made independently and inconsistently by setup.py and build_tools/jax.py, meaning a non-trivial class of user environments will either hit a hard build error or silently produce a mismatched binary where EP calls throw ep_not_built() at runtime.

build_tools/jax.py needs its EP-disable logic aligned with setup.py's graceful auto-detect; transformer_engine/jax/cpp_extensions/ep.py warrants a second look on the callsite-frame handle_id caching contract.

Important Files Changed

| Filename | Overview |
| --- | --- |
| build_tools/jax.py | Adds NCCL EP linkage; defaults to enabled but arch-guard logic diverges from setup.py, causing hard failures on non-Hopper builds or when the submodule is not initialized. |
| transformer_engine/jax/cpp_extensions/ep.py | +951 lines of JAX primitives (EpPrepare/Dispatch/Combine forward+backward); sharding rules and lowering look correct, but the callsite-frame handle_id cache contradicts its own documented contract for multi-layer models. |
| transformer_engine/jax/ep.py | Public API: ep_bootstrap, ep_dispatch (custom_vjp), ep_combine (custom_vjp); VJP math and sharding constraint re-pinning look correct; redundant Python divisibility assertion after C++ already validates it. |
| transformer_engine/jax/csrc/extensions/ep.cpp | Five XLA FFI handlers for EP ops; int32→int64 upcast handled on-stream; RAII comm cleanup in EpCommManager; all handlers marked FFI_CudaGraph_Traits for graph capture. |
| transformer_engine/common/ep/ep_backend.cpp | Core EP backend singleton; per-handle config cache (handle_id → HandleEntry); prepare/dispatch/combine/bwd ops delegate to NCCL EP with ScopedEpHandle RAII; SM≥90 and NVLS multicast checks at initialize time. |
| transformer_engine/jax/sharding.py | Adds ep_resource field to MeshResource and ep_axis_size() helper; straightforward additive change, no existing functionality affected. |
| transformer_engine/common/ep/ep_api_stub.cpp | Throwing stubs for all nvte_ep_* symbols when built without EP; provides clear error message directing users to rebuild with NVTE_BUILD_WITH_NCCL_EP=1. |
| transformer_engine/common/include/transformer_engine/ep.h | Public C API header for EP: well-documented structs and per-step ops; clean C linkage with no opaque types exposed externally. |

Reviews (1): Last reviewed commit: "Expert Parallelism: JAX bindings (FFI, c..."

Comment thread build_tools/jax.py
Comment on lines +127 to +131
if not (submod_ep_inc / "nccl_ep.h").exists():
    raise RuntimeError(
        f"NCCL EP header not found at {submod_ep_inc}/nccl_ep.h. "
        "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl."
    )
Contributor

P1 Hard failure on missing submodule inconsistent with setup.py auto-disable

build_tools/jax.py always hard-fails with RuntimeError when nccl_ep.h is absent, but setup.py gracefully auto-disables EP when no arch ≥ 90 is found (printing a warning instead of aborting). A user on non-Hopper hardware who does a plain git clone (no --recursive) will see setup.py complete silently with EP disabled, then hit this RuntimeError when the JAX extension is built — even though the correct behavior is to also disable EP in the JAX extension. The fix is to apply the same conditional: if NVTE_CUDA_ARCHS indicates no arch ≥ 90 (or is unset), print a warning and set build_with_nccl_ep = False rather than raising.
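One way the suggested fix could look is sketched below. The names `parse_cuda_archs` and `resolve_nccl_ep` are hypothetical, and the enable rule (at least one listed arch ≥ 90) is an assumption inferred from this comment's description of setup.py, not a copy of its actual logic.

```python
import warnings


def parse_cuda_archs(archs_env):
    """Extract numeric arch values from a ';'-separated string such as '90a;100'."""
    archs = []
    for entry in archs_env.split(";"):
        digits = "".join(c for c in entry if c.isdigit())
        if digits:
            archs.append(int(digits))
    return archs


def resolve_nccl_ep(archs_env, header_present):
    """Decide whether to build NCCL EP (hypothetical helper).

    Assumed rule: warn and disable when no arch >= 90 is listed (including
    an empty or unset value); raise only when EP is wanted but the header
    is missing.
    """
    archs = parse_cuda_archs(archs_env)
    if not any(arch >= 90 for arch in archs):
        warnings.warn("No CUDA arch >= 90 in NVTE_CUDA_ARCHS; disabling NCCL EP.")
        return False
    if not header_present:
        raise RuntimeError(
            "NCCL EP header not found. Run `git submodule update --init --recursive`."
        )
    return True
```

Under this sketch a plain clone on non-Hopper hardware gets a warning and a build without EP, while a Hopper build with a missing submodule still fails loudly.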

Comment thread build_tools/jax.py
Comment on lines +116 to +124
# NCCL EP requires SM>=90 (Hopper+).
archs_env = os.getenv("NVTE_CUDA_ARCHS", "")
for a in archs_env.split(";"):
    a_num = "".join(c for c in a if c.isdigit())
    if a_num and int(a_num) < 90:
        raise RuntimeError(
            f"NCCL EP requires CUDA arch >= 90 (Hopper or newer); got '{a}' in"
            " NVTE_CUDA_ARCHS."
        )
Contributor

P1 Arch guard in jax.py is inconsistent with setup.py's auto-disable logic

setup.py treats an empty or unset NVTE_CUDA_ARCHS as "no arch ≥ 90 → disable EP", whereas build_tools/jax.py only raises if a specific arch below 90 is explicitly listed. When NVTE_CUDA_ARCHS is unset, archs_env.split(";") yields [""], a_num is "", and the if a_num guard short-circuits — EP stays enabled. A user with no NVTE_CUDA_ARCHS set and a checked-out submodule will compile the JAX extension with NVTE_WITH_NCCL_EP defined while setup.py compiled the common library without it. Every call to EpInitialize will then hit the ep_not_built() stub and throw at runtime.
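The short-circuit described here can be reproduced in a few lines. `jax_guard_raises` mirrors the quoted loop but returns a flag instead of raising; `setup_style_enables` is a hypothetical reading of the setup.py rule as this comment describes it, not the actual setup.py code.

```python
def jax_guard_raises(archs_env):
    """Mirror of the quoted guard: trips only on an explicitly listed arch < 90."""
    for a in archs_env.split(";"):
        a_num = "".join(c for c in a if c.isdigit())
        if a_num and int(a_num) < 90:
            return True
    return False


def setup_style_enables(archs_env):
    """Assumed setup.py rule: enable EP only if some listed arch is >= 90."""
    nums = ["".join(c for c in a if c.isdigit()) for a in archs_env.split(";")]
    return any(n and int(n) >= 90 for n in nums)


# An unset variable becomes "", and "".split(";") is [""], not [].
unset_entries = "".split(";")

# With the variable unset, the JAX-side guard does not trip (EP stays on)
# while the assumed setup.py rule turns EP off: the two builds disagree.
jax_keeps_ep_on = not jax_guard_raises("")
setup_turns_ep_off = not setup_style_enables("")
```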

Comment on lines +872 to +896
_HANDLE_ID_CALLSITE_CACHE = {}


def ep_prepare(topk_idx, dispatch_output_per_expert_alignment=0):
    """Exchange routing metadata; return ``(token_counts, EpHandle)``."""
    import sys as _sys

    top_k = int(topk_idx.shape[-1])
    alignment = int(dispatch_output_per_expert_alignment)
    # Cache handle_id by caller (file:lineno, top_k, alignment): JAX re-traces
    # the same call site (e.g. custom_vjp fwd vs primal) and the resulting
    # EpHandles must share the same id to compare equal in pytree aux.
    f = _sys._getframe(1)
    cache_key = (f.f_code.co_filename, f.f_lineno, top_k, alignment)
    handle_id = _HANDLE_ID_CALLSITE_CACHE.get(cache_key)
    if handle_id is None:
        handle_id = ep_allocate_handle_id(top_k, alignment)
        _HANDLE_ID_CALLSITE_CACHE[cache_key] = handle_id
    token_counts, handle_mem = EpPreparePrimitive.outer_primitive.bind(
        topk_idx,
        handle_id=handle_id,
        dispatch_output_per_expert_alignment=alignment,
        is_outer=True,
    )
    return token_counts, EpHandle(handle_mem, handle_id)
Contributor

P2 sys._getframe cache causes all same-config layers to share one handle_id

ep_prepare caches handle_id by (co_filename, f_lineno, top_k, alignment). When called from _dispatch_fwd in ep.py, the caller frame is always line 181 of ep.py, so every MoE layer using ep_dispatch with the same (top_k, alignment) will receive the identical handle_id. The docstring of ep_allocate_handle_id explicitly states "Distinct logical layers must each call this — sharing a handle_id across layers corrupts the routing state." While the C++ backend currently stores only config in the handle_id → HandleEntry map (per-step state lives in handle_mem), this directly contradicts the documented contract and any future NCCL EP change that attaches per-layer state to handle_id will silently break multi-layer MoE models. Additionally, sys._getframe is CPython-specific.
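The collapse can be demonstrated without JAX. In this sketch `fake_allocate_handle_id` is a hypothetical stand-in for ep_allocate_handle_id, and `prepare_by_layer` is one possible alternative (an explicit caller-supplied key), offered as an assumption rather than as the fix the PR will adopt.

```python
import itertools
import sys

_next_id = itertools.count()


def fake_allocate_handle_id(top_k, alignment):
    """Hypothetical stand-in for ep_allocate_handle_id: returns a fresh id."""
    return next(_next_id)


_CALLSITE_CACHE = {}


def prepare_by_callsite(top_k, alignment=0):
    """Cache keyed on the caller's (file, line), as in the quoted code."""
    frame = sys._getframe(1)  # CPython-specific
    key = (frame.f_code.co_filename, frame.f_lineno, top_k, alignment)
    if key not in _CALLSITE_CACHE:
        _CALLSITE_CACHE[key] = fake_allocate_handle_id(top_k, alignment)
    return _CALLSITE_CACHE[key]


def dispatch_fwd(top_k):
    # Every layer reaches prepare through this one line, so the key never varies.
    return prepare_by_callsite(top_k)


_LAYER_CACHE = {}


def prepare_by_layer(layer_name, top_k, alignment=0):
    """Alternative: cache keyed on an explicit, caller-supplied layer name."""
    key = (layer_name, top_k, alignment)
    if key not in _LAYER_CACHE:
        _LAYER_CACHE[key] = fake_allocate_handle_id(top_k, alignment)
    return _LAYER_CACHE[key]


shared = [dispatch_fwd(top_k=2) for _ in range(3)]  # three "layers"
distinct = [prepare_by_layer(f"layer{i}", top_k=2) for i in range(3)]
retraced = prepare_by_layer("layer0", top_k=2)  # same layer traced again
```

An explicit key keeps the property the original comment wants (re-tracing the same layer yields the same id) while giving each logical layer its own id.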

}

private:
EpCommManager() = default;
Collaborator

If we use stateful FFI calls, we could tie the EP communicator to the lifetime of the JAX computation rather than the process.

Collaborator Author

Cool to learn! I will update it.

Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts,
Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) {
auto topk_dims = topk_idx.dimensions();
NVTE_CHECK(topk_dims.size() >= 2,
Collaborator

nit: can we return FFI InvalidArgument instead of using an NVTE_CHECK for these inputs?

Collaborator Author

This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.

@phu0ngng phu0ngng requested a review from tdophung May 22, 2026 15:51
@phu0ngng
Collaborator Author

I would appreciate your help reviewing this PR, @tdophung @jberchtold-nvidia!
Please focus on the JAX-side changes, as the TE/Common ones will be discussed in #3034.

Comment thread examples/jax/ep/ep_moe.py
kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:])

@jax.jit
def step(idx, toks, w, lk):
Collaborator

What does lk stand for?


# Keep XLA command-buffer capture OFF for the EP FFI sequence. Cuda-graph
# capture of dispatch→combine fails with CUDA_ERROR_INVALID_VALUE on graph
# destroy under recent XLA/NCCL; the non-cudaGraph path is correct since the
Collaborator

Does this mean we always have to disable CUDA graph capture when using TE EP until this is fixed? If so, would it be better to remove the CudaGraph trait from the EP FFIs instead?

Collaborator Author

@phu0ngng phu0ngng May 22, 2026

Wait, this is an old command. Let me clean them up and enable CUDAGraph! I fixed this in NCCL EP already :D.

handle_mem_aval = jax.core.ShapedArray(leading + (handle_mem_size,), jnp.uint8)
# FFI scratch for the int32 -> int64 topk_idx upcast.
# TODO(phuong): drop once NCCL EP supports int32 topk_idx.
workspace_aval = jax.core.ShapedArray(topk_idx_aval.shape, jnp.int64)
Collaborator

I think this will fall back to jnp.int32 unless JAX_ENABLE_X64 is set, without adjusting the shape accordingly, so you get half as many bytes as intended. We may need to make this jnp.int32 and make the shape 2x bigger.

Here is the behavior with jnp.zeros; I believe it also applies to primitive outputs:

>>> jnp.zeros((128,), jnp.int64)
<stdin>:1: UserWarning: Explicitly requested dtype int64 requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)
>>> a = jnp.zeros((128,), jnp.int64)
>>> a.dtype
dtype('int32')
>>> a.shape
(128,)
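The byte arithmetic behind the suggested workaround can be checked with the standard library alone. `int64_scratch_shape_as_int32` and `nbytes` are hypothetical helpers, and the `(128, 8)` shape is an arbitrary example rather than a value from the PR.

```python
import math


def int64_scratch_shape_as_int32(shape):
    """Shape of an int32 buffer with the same byte size as an int64 buffer of `shape`."""
    if not shape:
        return (2,)
    return tuple(shape[:-1]) + (shape[-1] * 2,)


def nbytes(shape, itemsize):
    """Total bytes for a dense array of the given shape and element size."""
    return math.prod(shape) * itemsize


topk_shape = (128, 8)  # arbitrary example shape

wanted = nbytes(topk_shape, 8)     # int64 scratch the FFI would expect
truncated = nbytes(topk_shape, 4)  # bytes allocated if int64 silently becomes int32
fixed_shape = int64_scratch_shape_as_int32(topk_shape)
fixed = nbytes(fixed_shape, 4)     # int32 buffer with the last dim doubled
```

Doubling the last dimension of an int32 buffer restores the intended byte count without depending on the x64 setting.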

Collaborator Author

You are right.
There is another int32-int64 conversion with a CUDA kernel triggered underneath, and this JAX conversion should be cleaned up.

leading = _ep_leading_dims(is_outer)
recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype)
recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32)
workspace_aval = jax.core.ShapedArray(topk_idx_aval.shape, jnp.int64)
Collaborator

Same comment as above about int64

Comment thread examples/jax/ep/ep_moe.py
local_device_ids=[args.process_id],
)
assert (
jax.local_device_count() == 1
Collaborator

Does this local_device_count() == 1 restriction apply only to this example, or to all TE EP usage? If it applies to all TE EP usage, can we also add this assertion to ep_bootstrap?

Collaborator Author

This will apply to all, as we are expecting a single process per device setup. The current TE EP does not support single process multiple devices. Let me add this assertion to the ep_bootstrap.
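A guard of this kind might look like the following sketch. `check_one_device_per_process` is a hypothetical helper; it takes the device count as an argument so it can be exercised without JAX, and how it would be wired into ep_bootstrap is an assumption.

```python
def check_one_device_per_process(local_device_count):
    """Raise if the process owns more than one device (hypothetical helper)."""
    if local_device_count != 1:
        raise RuntimeError(
            "TE EP expects one process per device, but this process sees "
            f"{local_device_count} local devices."
        )


def try_check(count):
    """Return the error message produced for `count`, or None if it passes."""
    try:
        check_one_device_per_process(count)
    except RuntimeError as err:
        return str(err)
    return None


ok_result = try_check(1)
bad_result = try_check(4)
```

In real use the argument would presumably be `jax.local_device_count()`, checked at the top of ep_bootstrap. A RuntimeError is used here rather than a bare assert so the check survives `python -O`.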
