[JAX] Expert Parallelism: JAX primitives + VJPs#3036
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…s, MoE example) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Greptile SummaryThis PR lands the JAX bindings for the Expert Parallelism (EP) feature: XLA FFI handlers over the
Confidence Score: 3/5The build tooling has a gap between how setup.py and build_tools/jax.py decide whether to enable NCCL EP — a fresh clone on non-Hopper hardware will have setup.py silently skip EP but jax.py hard-fail, and an unset NVTE_CUDA_ARCHS with the submodule present produces a JAX extension compiled with EP while the common library is compiled without it. The JAX primitive logic, VJP math, and C++ FFI handlers look sound, and the stub file provides clean fallback symbols. However, the EP enable/disable decision is made independently and inconsistently by setup.py and build_tools/jax.py, meaning a non-trivial class of user environments will either hit a hard build error or silently produce a mismatched binary where EP calls throw ep_not_built() at runtime. build_tools/jax.py needs its EP-disable logic aligned with setup.py's graceful auto-detect; transformer_engine/jax/cpp_extensions/ep.py warrants a second look on the callsite-frame handle_id caching contract. Important Files Changed
Reviews (1): Last reviewed commit: "Expert Parallelism: JAX bindings (FFI, c..." | Re-trigger Greptile |
| if not (submod_ep_inc / "nccl_ep.h").exists(): | ||
| raise RuntimeError( | ||
| f"NCCL EP header not found at {submod_ep_inc}/nccl_ep.h. " | ||
| "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl." | ||
| ) |
There was a problem hiding this comment.
Hard failure on missing submodule inconsistent with
setup.py auto-disable
build_tools/jax.py always hard-fails with RuntimeError when nccl_ep.h is absent, but setup.py gracefully auto-disables EP when no arch ≥ 90 is found (printing a warning instead of aborting). A user on non-Hopper hardware who does a plain git clone (no --recursive) will see setup.py complete silently with EP disabled, then hit this RuntimeError when the JAX extension is built — even though the correct behavior is to also disable EP in the JAX extension. The fix is to apply the same conditional: if NVTE_CUDA_ARCHS indicates no arch ≥ 90 (or is unset), print a warning and set build_with_nccl_ep = False rather than raising.
| # NCCL EP requires SM>=90 (Hopper+). | ||
| archs_env = os.getenv("NVTE_CUDA_ARCHS", "") | ||
| for a in archs_env.split(";"): | ||
| a_num = "".join(c for c in a if c.isdigit()) | ||
| if a_num and int(a_num) < 90: | ||
| raise RuntimeError( | ||
| f"NCCL EP requires CUDA arch >= 90 (Hopper or newer); got '{a}' in" | ||
| " NVTE_CUDA_ARCHS." | ||
| ) |
There was a problem hiding this comment.
Arch guard in
jax.py is inconsistent with setup.py's auto-disable logic
setup.py treats an empty or unset NVTE_CUDA_ARCHS as "no arch ≥ 90 → disable EP", whereas build_tools/jax.py only raises if a specific arch below 90 is explicitly listed. When NVTE_CUDA_ARCHS is unset, archs_env.split(";") yields [""], a_num is "", and the if a_num guard short-circuits — EP stays enabled. A user with no NVTE_CUDA_ARCHS set and a checked-out submodule will compile the JAX extension with NVTE_WITH_NCCL_EP defined while setup.py compiled the common library without it. Every call to EpInitialize will then hit the ep_not_built() stub and throw at runtime.
| _HANDLE_ID_CALLSITE_CACHE = {} | ||
|
|
||
|
|
||
| def ep_prepare(topk_idx, dispatch_output_per_expert_alignment=0): | ||
| """Exchange routing metadata; return ``(token_counts, EpHandle)``.""" | ||
| import sys as _sys | ||
|
|
||
| top_k = int(topk_idx.shape[-1]) | ||
| alignment = int(dispatch_output_per_expert_alignment) | ||
| # Cache handle_id by caller (file:lineno, top_k, alignment): JAX re-traces | ||
| # the same call site (e.g. custom_vjp fwd vs primal) and the resulting | ||
| # EpHandles must share the same id to compare equal in pytree aux. | ||
| f = _sys._getframe(1) | ||
| cache_key = (f.f_code.co_filename, f.f_lineno, top_k, alignment) | ||
| handle_id = _HANDLE_ID_CALLSITE_CACHE.get(cache_key) | ||
| if handle_id is None: | ||
| handle_id = ep_allocate_handle_id(top_k, alignment) | ||
| _HANDLE_ID_CALLSITE_CACHE[cache_key] = handle_id | ||
| token_counts, handle_mem = EpPreparePrimitive.outer_primitive.bind( | ||
| topk_idx, | ||
| handle_id=handle_id, | ||
| dispatch_output_per_expert_alignment=alignment, | ||
| is_outer=True, | ||
| ) | ||
| return token_counts, EpHandle(handle_mem, handle_id) |
There was a problem hiding this comment.
sys._getframe cache causes all same-config layers to share one handle_id
ep_prepare caches handle_id by (co_filename, f_lineno, top_k, alignment). When called from _dispatch_fwd in ep.py, the caller frame is always line 181 of ep.py, so every MoE layer using ep_dispatch with the same (top_k, alignment) will receive the identical handle_id. The docstring of ep_allocate_handle_id explicitly states "Distinct logical layers must each call this — sharing a handle_id across layers corrupts the routing state." While the C++ backend currently stores only config in the handle_id → HandleEntry map (per-step state lives in handle_mem), this directly contradicts the documented contract and any future NCCL EP change that attaches per-layer state to handle_id will silently break multi-layer MoE models. Additionally, sys._getframe is CPython-specific.
| } | ||
|
|
||
| private: | ||
| EpCommManager() = default; |
There was a problem hiding this comment.
If we use stateful FFI calls we could tie to EP communicator to the lifetime of the jax computation rather than the process.
There was a problem hiding this comment.
Cool to learn! I will update it.
| Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, | ||
| Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { | ||
| auto topk_dims = topk_idx.dimensions(); | ||
| NVTE_CHECK(topk_dims.size() >= 2, |
There was a problem hiding this comment.
nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?
There was a problem hiding this comment.
This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.
|
I would appreciate your help to review this PR @tdophung @jberchtold-nvidia! |
| kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:]) | ||
|
|
||
| @jax.jit | ||
| def step(idx, toks, w, lk): |
There was a problem hiding this comment.
What does lk stand for?
|
|
||
| # Keep XLA command-buffer capture OFF for the EP FFI sequence. Cuda-graph | ||
| # capture of dispatch→combine fails with CUDA_ERROR_INVALID_VALUE on graph | ||
| # destroy under recent XLA/NCCL; the non-cudaGraph path is correct since the |
There was a problem hiding this comment.
Does this mean we always have to disable cuda graph capture for now when using TE EP until this is fixed? If so, would it be better to remove the CudaGraph trait from the EP FFIs instead
There was a problem hiding this comment.
Wait, this is an old command. Let me clean them up and enable CUDAGraph! I fixed this in NCCL EP already :D.
| handle_mem_aval = jax.core.ShapedArray(leading + (handle_mem_size,), jnp.uint8) | ||
| # FFI scratch for the int32 -> int64 topk_idx upcast. | ||
| # TODO(phuong): drop once NCCL EP supports int32 topk_idx. | ||
| workspace_aval = jax.core.ShapedArray(topk_idx_aval.shape, jnp.int64) |
There was a problem hiding this comment.
I think this will fallback to jnp.int32 unless JAX_ENABLE_X64 is set, without adjusting the shape accordingly so you get a smaller amount of bytes. We may need to make this jnp.int32 and make the shape 2x bigger.
Here is jnp.zeros, I believe it also applies to primitive outputs
>>> jnp.zeros((128,), jnp.int64)
<stdin>:1: UserWarning: Explicitly requested dtype int64 requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)
>>> a = jnp.zeros((128,), jnp.int64)
>>> a.dtype
dtype('int32')
>>> a.shape
(128,)
There was a problem hiding this comment.
You are right.
There is another int32-int64 conversion with a CUDA kernel triggered underneath, and this JAX conversion should be cleaned up.
| leading = _ep_leading_dims(is_outer) | ||
| recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype) | ||
| recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32) | ||
| workspace_aval = jax.core.ShapedArray(topk_idx_aval.shape, jnp.int64) |
There was a problem hiding this comment.
Same comment as above about int64
| local_device_ids=[args.process_id], | ||
| ) | ||
| assert ( | ||
| jax.local_device_count() == 1 |
There was a problem hiding this comment.
Does this local_device_count() == 1 restriction only apply to this example or does it apply to all TE EP usage? If all TE EP usage, can we also add this assertion into ep_bootstrap
There was a problem hiding this comment.
This will apply to all, as we are expecting a single process per device setup. The current TE EP does not support single process multiple devices. Let me add this assertion to the ep_bootstrap.
Summary
Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the
nvte_ep_*C API, a Python wrapper withcustom_vjpfor autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCLncclEpDispatch/ncclEpCombineare exposed as XLA primitives and work with CUDA-graph capture.Implementation
Public Python API (
transformer_engine/jax/ep.py)ep_dispatch/ep_combinearejax.custom_vjpfunctions: forward is the FFI primitive, backward calls the matchingnvte_ep_*_bwdFFI primitive directly (noep_preparein the bwd — routing state is already cached inhandle.mem). Note thatep_dispatchalso callsep_preparein the forward path, which all-gathers and preprocesses routing maps.XLA FFI bindings (
transformer_engine/jax/csrc/extensions/ep.cpp)Five
XLA_FFI_DEFINE_HANDLER_SYMBOLentries —EpPrepareHandler,EpDispatchHandler,EpCombineHandler,EpDispatchBwdHandler,EpCombineBwdHandler— each calling the correspondingnvte_ep_*C entry point. All markedFFI_CudaGraph_Traitsso they capture cleanly.handle_idis a static FFI attribute baked at jit trace time.Primitives + Python layer (
transformer_engine/jax/cpp_extensions/ep.py, +951 lines)Standard TE primitive plumbing:
abstract_eval(shape/dtype inference),lowering,impl,outer_primitiveregistration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).Sharding (
transformer_engine/jax/sharding.py, +12 lines)Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.
Build wiring (
build_tools/jax.py, +41 lines)Threads NCCL EP linkage through the JAX
transformer_engine_jaxextension. No new top-level build flags; rides on the parent PR'sNVTE_BUILD_WITH_NCCL_EP.Tests & example
tests/jax/test_multi_process_ep.py(+690 lines): 13 tests covering bootstrap,ep_prepareshape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing),custom_vjpfwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).tests/jax/multi_process_launch_ep.sh: 4-rank launcher; setsXLA_FLAGSto keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).examples/jax/ep/ep_moe.py(+394 lines) +run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison--checkthat verifies fwd+bwd vs a single-process reference.Type of change
Checklist: