diff --git a/build_tools/jax.py b/build_tools/jax.py index f07c0a202f..a7b200f915 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -3,13 +3,14 @@ # See LICENSE for license information. """JAX related extensions.""" + import os from pathlib import Path from packaging import version import setuptools -from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled +from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled, setup_mpi_flags from typing import List @@ -100,6 +101,8 @@ def setup_jax_extension( else: cxx_flags.append("-g0") + setup_mpi_flags(include_dirs, cxx_flags) + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index fdfdee9b1c..533addaf53 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -3,12 +3,19 @@ # See LICENSE for license information. """PyTorch related extensions.""" + import os from pathlib import Path import setuptools -from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled +from .utils import ( + all_files_in_dir, + cuda_version, + get_cuda_include_dirs, + debug_build_enabled, + setup_mpi_flags, +) from typing import List @@ -67,13 +74,7 @@ def setup_pytorch_extension( if version < (12, 0): raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") - if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): - assert ( - os.getenv("MPI_HOME") is not None - ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" - mpi_path = Path(os.getenv("MPI_HOME")) - include_dirs.append(mpi_path / "include") - cxx_flags.append("-DNVTE_UB_WITH_MPI") + setup_mpi_flags(include_dirs, cxx_flags) library_dirs = [] libraries = [] diff --git a/build_tools/utils.py b/build_tools/utils.py index 885901068a..d0f5eab425 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -341,6 +341,17 @@ def get_frameworks() -> List[str]: return _frameworks +def setup_mpi_flags(include_dirs: List, cxx_flags: List) -> None: + """Add MPI include path and compile definition if NVTE_UB_WITH_MPI is enabled.""" + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" + mpi_path = Path(os.getenv("MPI_HOME")) + include_dirs.append(mpi_path / "include") + cxx_flags.append("-DNVTE_UB_WITH_MPI") + + def copy_common_headers( src_dir: Union[Path, str], dst_dir: Union[Path, str],