Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
496e3ab
initial impl
tdophung Apr 21, 2026
f453137
clean up any link to Maxtext. Permutation backends. clean up foward b…
tdophung Apr 22, 2026
0044bf2
add distributed test.
tdophung Apr 23, 2026
d78bc01
refactor to a2a from roe
tdophung Apr 30, 2026
6f87629
fix test_distributed issues with unpopulated LogicallyPartition pytre…
tdophung Apr 30, 2026
6aeb491
add option to choose weight fsdp sharding axis
tdophung May 5, 2026
25e1eb8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 5, 2026
d7fef5a
address greptile comments
tdophung May 6, 2026
3a51708
address jeremys comments + relax the sum(group_size) <= dim_m constra…
nvjax May 7, 2026
dafaad4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2026
27c18fe
revert C++ changes and will put in a new branch, tighten distributed…
tdophung May 12, 2026
abbb2c6
address more comments: ep_resource look up, perm backend enum, accept…
tdophung May 12, 2026
b375db7
tests/jax/test_distributed_moe_block.py
tdophung May 12, 2026
37c871c
change naming and add message for experimental feature
tdophung May 12, 2026
ddf5d90
[JAX] Refactor MoEBlock into a unified MoE custom_vjp, add tests
tdophung May 21, 2026
7080d3b
Merge branch 'main' into teddy/moe_block
tdophung May 21, 2026
84166d0
test(jax): parametrize MP MoE tests over (recipe, backend)
tdophung May 21, 2026
ee9b3ce
[JAX] Address review comments
tdophung May 22, 2026
2b69d72
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions qa/L0_jax_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh"
wait

# MoE custom_vjp distributed suite. Runs one Python process per GPU
# via tests/jax/run_multiprocess_moe_vjp.sh (mirrors the pattern in
# examples/jax/encoder/run_test_multiprocessing_encoder.sh). Requires
# >=4 visible GPUs.
TE_PATH=$TE_PATH bash $TE_PATH/tests/jax/run_multiprocess_moe_vjp.sh \
|| test_fail "test_multiprocess_moe_vjp.py"
# Exercise the multi-GPU tutorial in docs/examples/jax (needs >= 4 GPUs;
# auto-skips otherwise).
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax_distributed.xml -k multi_gpu $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax (multi-GPU)"
Expand Down
14 changes: 14 additions & 0 deletions tests/jax/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ def pytest_sessionfinish(self, session, exitstatus):
print("=" * 80)


def pytest_addoption(parser):
"""CLI options used by multiprocess JAX tests.

``--num-process`` and ``--process-id`` let a multiprocess launcher
(see ``tests/jax/run_multiprocess_moe_vjp.sh``) fork one pytest
process per GPU and tell each child its rank, so the test module
can call ``jax.distributed.initialize(...)`` with the right
``local_device_ids``. Both default to 0; non-multiprocess tests
ignore them.
"""
parser.addoption("--num-process", action="store", default=0)
parser.addoption("--process-id", action="store", default=0)


def pytest_configure(config):
config.addinivalue_line(
"markers",
Expand Down
130 changes: 130 additions & 0 deletions tests/jax/run_multiprocess_moe_vjp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#!/usr/bin/env bash
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
#
# Multiprocess (one-GPU-per-process) launcher for the unified MoE VJP
# test suite. Forks one pytest invocation per visible GPU, passing each
# its own --num-process=N --process-id=i, and waits for all of them.
# Each child calls jax.distributed.initialize(..., local_device_ids=
# process_id) so each Python process only sees its one GPU as a local
# device and the participating processes form a global mesh.

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TE_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
TEST_FILE="$TE_ROOT/tests/jax/test_multiprocess_moe_vjp.py"
PYTEST_INI="$TE_ROOT/tests/jax/pytest.ini"

NUM_GPUS="${NUM_GPUS:-$(nvidia-smi -L | wc -l)}"
if [ "$NUM_GPUS" -lt 4 ]; then
echo "[run_multiprocess_moe_vjp.sh] need >=4 GPUs (got $NUM_GPUS); aborting" >&2
exit 1
fi

export XLA_PYTHON_CLIENT_PREALLOCATE="${XLA_PYTHON_CLIENT_PREALLOCATE:-false}"
export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.5}"
export MOE_VJP_COORDINATOR_ADDRESS="${MOE_VJP_COORDINATOR_ADDRESS:-127.0.0.1:13456}"

echo "============================================================"
echo "MoE VJP MULTIPROCESS test (one process per GPU, ${NUM_GPUS} GPUs)"
echo " test file : $TEST_FILE"
echo " coordinator : $MOE_VJP_COORDINATOR_ADDRESS"
echo " XLA_PYTHON_CLIENT_PREALLOCATE: $XLA_PYTHON_CLIENT_PREALLOCATE"
echo " XLA_PYTHON_CLIENT_MEM_FRACTION: $XLA_PYTHON_CLIENT_MEM_FRACTION"
echo "============================================================"

# Per-process logs. MOE_VJP_MP_LOG_DIR can be set to a host-mounted dir
# (e.g. when running inside a container that throws away /tmp on exit)
# so logs survive for postmortem inspection. Defaults to a fresh /tmp.
if [ -n "${MOE_VJP_MP_LOG_DIR:-}" ]; then
LOG_DIR="$MOE_VJP_MP_LOG_DIR"
mkdir -p "$LOG_DIR"
else
LOG_DIR=$(mktemp -d -t moe_vjp_mp_XXXXXX)
fi
echo "Per-process logs: $LOG_DIR"

PIDS=()

cleanup() {
for pid in "${PIDS[@]:-}"; do
if kill -0 "$pid" 2>/dev/null; then
kill -TERM "$pid" 2>/dev/null || true
fi
done
sleep 1
for pid in "${PIDS[@]:-}"; do
if kill -0 "$pid" 2>/dev/null; then
kill -KILL "$pid" 2>/dev/null || true
fi
done
}
trap cleanup EXIT INT TERM

# Launch one pytest per GPU. Process 0 streams to stdout; others log
# only to file so the live output isn't a mosaic.
for i in $(seq 0 $((NUM_GPUS - 1))); do
LOG_FILE="$LOG_DIR/proc_${i}.log"
PYTEST_CMD=(
python3 -m pytest -c "$PYTEST_INI"
"$TEST_FILE"
-p no:typeguard
-v -s
--num-process="$NUM_GPUS"
--process-id="$i"
)
if [ "$i" -eq 0 ]; then
echo "=== Live output from process 0 ==="
"${PYTEST_CMD[@]}" 2>&1 | tee "$LOG_FILE" &
else
"${PYTEST_CMD[@]}" > "$LOG_FILE" 2>&1 &
fi
PIDS+=("$!")
done

# Wait for all and collect exit codes.
EXITS=()
for pid in "${PIDS[@]}"; do
if wait "$pid"; then
EXITS+=("0")
else
EXITS+=("$?")
fi
done

# Summary.
echo
echo "============================================================"
echo "Per-process exit codes:"
for i in "${!EXITS[@]}"; do
echo " proc $i -> ${EXITS[$i]}"
done

# Final pass/fail. Any non-zero in any process fails the suite, but
# we tolerate non-zero on the non-zero processes only if proc 0
# reports PASS (this matches the encoder launcher's logic). Simplest
# strict rule: any non-zero is a failure.
FAILED=0
for e in "${EXITS[@]}"; do
if [ "$e" != "0" ]; then
FAILED=1
break
fi
done

echo
if [ "$FAILED" -eq 0 ]; then
echo "[run_multiprocess_moe_vjp.sh] all processes PASSED"
if [ -z "${MOE_VJP_MP_LOG_DIR:-}" ]; then
rm -rf "$LOG_DIR"
fi
exit 0
fi

echo "[run_multiprocess_moe_vjp.sh] at least one process FAILED"
echo " retaining logs at $LOG_DIR for diagnosis"
echo " process 0 tail:"
tail -20 "$LOG_DIR/proc_0.log" 2>/dev/null || true
exit 1
Loading
Loading