diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 533addaf53..a6c3938a00 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -76,6 +76,13 @@ def setup_pytorch_extension( setup_mpi_flags(include_dirs, cxx_flags) + # Mirror the cuSOLVERMp gate. newton_schulz.cpp is conditionally compiled + # in the common lib; the pytorch ext glob pulls the same source so it must + # see the same define, otherwise the pybind layer refers to undefined + # cusolvermp_ctx_* symbols. + if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))): + cxx_flags.append("-DNVTE_WITH_CUSOLVERMP") + library_dirs = [] libraries = [] if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 030023d949..ae66e8430a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -147,6 +147,7 @@ set(transformer_engine_cuda_sources) set(transformer_engine_cuda_arch_specific_sources) list(APPEND transformer_engine_cpp_sources + comm_handle.cpp cudnn_utils.cpp transformer_engine.cpp fused_attn/fused_attn.cpp diff --git a/transformer_engine/common/comm_handle.cpp b/transformer_engine/common/comm_handle.cpp new file mode 100644 index 0000000000..fc2e1c5b5d --- /dev/null +++ b/transformer_engine/common/comm_handle.cpp @@ -0,0 +1,48 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/comm_handle.h" + +#include "common.h" +#include "transformer_engine/nccl_comm.h" +#include "util/logging.h" + +using transformer_engine::convertNVTETensor; + +NVTEPeerHandleKind nvte_tensor_peer_handle_kind(const NVTETensor t) { + const auto* tensor = convertNVTETensor(t); + return tensor != nullptr ? tensor->peer_handle_kind : NVTE_PEER_HANDLE_NONE; +} + +void nvte_tensor_detach_peer_handle(NVTETensor t) { + auto* tensor = convertNVTETensor(t); + if (tensor == nullptr) return; + tensor->peer_handle_kind = NVTE_PEER_HANDLE_NONE; + tensor->peer_handle_data = nullptr; + tensor->peer_handle_offset = 0; +} + +void nvte_tensor_attach_nccl_window(NVTETensor t, void* window, uint64_t offset) { + auto* tensor = convertNVTETensor(t); + NVTE_CHECK(tensor != nullptr, "nvte_tensor_attach_nccl_window: invalid NVTETensor handle"); + if (window == nullptr) { + tensor->peer_handle_kind = NVTE_PEER_HANDLE_NONE; + tensor->peer_handle_data = nullptr; + tensor->peer_handle_offset = 0; + return; + } + tensor->peer_handle_kind = NVTE_PEER_HANDLE_NCCL_WINDOW; + tensor->peer_handle_data = window; + tensor->peer_handle_offset = offset; +} + +void nvte_tensor_nccl_window(const NVTETensor t, void** window, uint64_t* offset) { + const auto* tensor = convertNVTETensor(t); + const bool has_nccl = + tensor != nullptr && tensor->peer_handle_kind == NVTE_PEER_HANDLE_NCCL_WINDOW; + if (window != nullptr) *window = has_nccl ? tensor->peer_handle_data : nullptr; + if (offset != nullptr) *offset = has_nccl ? tensor->peer_handle_offset : 0; +} diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 12479f2a9c..0ef1374fb7 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -24,6 +24,7 @@ static_assert(NVTE_BUILD_NUM_PHILOX_ROUNDS > 0, #endif #include +#include #include #include @@ -179,6 +180,13 @@ struct Tensor { */ bool row_scaled_nvfp4 = false; + /*! \brief Optional borrowed peer handle for one-sided RMA against this tensor. + * ``peer_handle_kind`` selects the backend owning ``peer_handle_data``; + * the caller keeps the resource valid for the tensor's lifetime. */ + NVTEPeerHandleKind peer_handle_kind = NVTE_PEER_HANDLE_NONE; + void *peer_handle_data = nullptr; + uint64_t peer_handle_offset = 0; + /*! Map from NVTETensorParam to parameter sizes */ static constexpr size_t attr_sizes[] = { sizeof(NVTEBasicTensor), // kNVTERowwiseData diff --git a/transformer_engine/common/include/transformer_engine/comm_handle.h b/transformer_engine/common/include/transformer_engine/comm_handle.h new file mode 100644 index 0000000000..5fb1674132 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_handle.h @@ -0,0 +1,39 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file comm_handle.h + * \brief Generic peer-handle annotation on NVTETensor for one-sided RMA. + * + * The annotation is borrowed; the tensor never owns the underlying resource. + * Per-backend setters/getters live in dedicated headers (e.g. ``nccl_comm.h``). + */ + +#ifndef TRANSFORMER_ENGINE_COMM_HANDLE_H_ +#define TRANSFORMER_ENGINE_COMM_HANDLE_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Comm backend that owns a tensor's peer handle. */ +typedef enum { + NVTE_PEER_HANDLE_NONE = 0, + NVTE_PEER_HANDLE_NCCL_WINDOW = 1, +} NVTEPeerHandleKind; + +/*! \brief Peer-handle kind attached to ``t``. */ +NVTEPeerHandleKind nvte_tensor_peer_handle_kind(const NVTETensor t); + +/*! \brief Clear any peer handle attached to ``t``. */ +void nvte_tensor_detach_peer_handle(NVTETensor t); + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_COMM_HANDLE_H_ diff --git a/transformer_engine/common/include/transformer_engine/nccl_comm.h b/transformer_engine/common/include/transformer_engine/nccl_comm.h new file mode 100644 index 0000000000..b9ba5499a2 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/nccl_comm.h @@ -0,0 +1,41 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file nccl_comm.h + * \brief Attach a registered NCCL symmetric-memory window to an NVTETensor. + * + * The window is caller-owned and must outlive the tensor; ``attach`` does + * not register or rendezvous it. + */ + +#ifndef TRANSFORMER_ENGINE_NCCL_COMM_H_ +#define TRANSFORMER_ENGINE_NCCL_COMM_H_ + +#include "comm_handle.h" +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Attach an NCCL window + byte offset to ``t``. Pass ``window=NULL`` to detach. + * + * \param[in,out] t Tensor to annotate. + * \param[in] window Opaque ncclWindow_t (caller-owned), or NULL to clear. + * \param[in] offset Byte offset into the window where this tensor starts. + */ +void nvte_tensor_attach_nccl_window(NVTETensor t, void* window, uint64_t offset); + +/*! \brief Read the NCCL window + offset attached to ``t``; yields (NULL, 0) when unset. + * Either out-pointer may be NULL to skip that field. + */ +void nvte_tensor_nccl_window(const NVTETensor t, void** window, uint64_t* offset); + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_NCCL_COMM_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 8082ff07ed..cdec0f08fb 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -631,12 +631,14 @@ void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at: * Newton-Schulz (cuSolverMp) **************************************************************************************************/ +#ifdef NVTE_WITH_CUSOLVERMP int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank); void cusolvermp_ctx_destroy(int64_t ctx_ptr); void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t num_iterations, std::vector coefficients); +#endif // NVTE_WITH_CUSOLVERMP } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp index 8b24e8fdb9..3f2a57007b 100644 --- a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp +++ b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp @@ -4,6 +4,12 @@ * See LICENSE for license information. ************************************************************************/ +// Conditionally compiled: the common Newton-Schulz/cuSOLVERMp impl is gated +// behind NVTE_WITH_CUSOLVERMP in the common CMakeLists. Without the gate, the +// pytorch ext glob would pick this file up and produce undefined symbols +// (nvte_cusolvermp_ctx_*). Keep this gate aligned with common/. +#ifdef NVTE_WITH_CUSOLVERMP + #include "transformer_engine/newton_schulz.h" #include "../extensions.h" @@ -38,3 +44,5 @@ void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t } } // namespace transformer_engine::pytorch + +#endif // NVTE_WITH_CUSOLVERMP diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a4571c64e2..77c7952df1 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -589,6 +589,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::multi_tensor_compute_scale_inv_e8m0_cuda, "Fused compute E8M0 scale_inv from amax", py::call_guard()); +#ifdef NVTE_WITH_CUSOLVERMP // Newton-Schulz (cuSolverMp) m.def("cusolvermp_ctx_create", &transformer_engine::pytorch::cusolvermp_ctx_create, "Create cuSolverMp context for Newton-Schulz", py::arg("nccl_comm_ptr"), py::arg("nranks"), @@ -599,6 +600,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Newton-Schulz matrix orthogonalization", py::arg("ctx_ptr"), py::arg("m"), py::arg("n"), py::arg("x"), py::arg("num_iterations"), py::arg("coefficients"), py::call_guard()); +#endif // NVTE_WITH_CUSOLVERMP // Comm+GEMM Overlap m.def("bulk_overlap_ag_with_external_gemm",