From f68cd3cab34972a899ad0069e2c4ee806e8bc6fb Mon Sep 17 00:00:00 2001 From: Gaetan Lepage Date: Sat, 4 Apr 2026 12:27:48 +0200 Subject: [PATCH 1/2] Fix JAX extension build with NVTE_UB_WITH_MPI=1 Signed-off-by: Gaetan Lepage --- build_tools/jax.py | 5 ++++- build_tools/pytorch.py | 17 +++++++++-------- build_tools/utils.py | 11 +++++++++++ 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index f07c0a202f..a7b200f915 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -3,13 +3,14 @@ # See LICENSE for license information. """JAX related extensions.""" + import os from pathlib import Path from packaging import version import setuptools -from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled +from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled, setup_mpi_flags from typing import List @@ -100,6 +101,8 @@ def setup_jax_extension( else: cxx_flags.append("-g0") + setup_mpi_flags(include_dirs, cxx_flags) + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index fdfdee9b1c..533addaf53 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -3,12 +3,19 @@ # See LICENSE for license information. """PyTorch related extensions.""" + import os from pathlib import Path import setuptools -from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled +from .utils import ( + all_files_in_dir, + cuda_version, + get_cuda_include_dirs, + debug_build_enabled, + setup_mpi_flags, +) from typing import List @@ -67,13 +74,7 @@ def setup_pytorch_extension( if version < (12, 0): raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") - if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): - assert ( - os.getenv("MPI_HOME") is not None - ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" - mpi_path = Path(os.getenv("MPI_HOME")) - include_dirs.append(mpi_path / "include") - cxx_flags.append("-DNVTE_UB_WITH_MPI") + setup_mpi_flags(include_dirs, cxx_flags) library_dirs = [] libraries = [] diff --git a/build_tools/utils.py b/build_tools/utils.py index 885901068a..bab0177c73 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -341,6 +341,17 @@ def get_frameworks() -> List[str]: return _frameworks +def setup_mpi_flags(include_dirs: List, cxx_flags: List) -> None: + """Add MPI include path and compile definition if NVTE_UB_WITH_MPI is enabled.""" + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + assert os.getenv("MPI_HOME") is not None, ( + "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" + ) + mpi_path = Path(os.getenv("MPI_HOME")) + include_dirs.append(mpi_path / "include") + cxx_flags.append("-DNVTE_UB_WITH_MPI") + + def copy_common_headers( src_dir: Union[Path, str], dst_dir: Union[Path, str], From bee67bbfad059ffe6d78680f2537adf8ced59547 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Apr 2026 10:33:23 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- build_tools/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/build_tools/utils.py b/build_tools/utils.py index bab0177c73..d0f5eab425 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -344,9 +344,9 @@ def get_frameworks() -> List[str]: def setup_mpi_flags(include_dirs: List, cxx_flags: List) -> None: """Add MPI include path and compile definition if NVTE_UB_WITH_MPI is enabled.""" if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): - assert os.getenv("MPI_HOME") is not None, ( - "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" - ) + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" mpi_path = Path(os.getenv("MPI_HOME")) include_dirs.append(mpi_path / "include") cxx_flags.append("-DNVTE_UB_WITH_MPI")