diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 123362ce10..200f0134ad 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -97,8 +97,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, CheckOutputTensor(*output_tensor, "output", false); // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); + const auto [rows, cols] = input_tensor->flat_2d_dims(); auto dtype = input_tensor->dtype(); const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; if (row_scaled_nvfp4) { @@ -246,8 +245,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens CheckOutputTensor(*output_tensor, "output", false); // Choose kernel - int32_t rows = grad_tensor->flat_first_dim(); - int32_t cols = grad_tensor->flat_last_dim(); + const auto [rows, cols] = grad_tensor->flat_2d_dims(); auto dtype = grad_tensor->dtype(); NVTE_CHECK(!output_tensor->row_scaled_nvfp4, "Backward NVFP4 quantization does not support row-scaled outputs."); @@ -368,8 +366,7 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou // output list here is allowed to have empty tensor // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); + const auto [rows, cols] = input_tensor->flat_2d_dims(); auto dtype = input_tensor->dtype(); NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index 96a42b494d..dfd140535a 100644 --- a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -391,8 +391,7 @@ void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T using namespace quantize_2D_kernel; checkCuDriverContext(stream); - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); + const auto [rows, cols] = input.flat_2d_dims(); const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); const size_t blocks_Y = chunks_Y; diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 6441a567a6..1face261bd 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -261,8 +261,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); + const auto [rows, cols] = input.flat_2d_dims(); const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 1549a292d8..456d926539 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -578,8 +578,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); // Tensor dimensions - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); + const auto [rows, cols] = input.flat_2d_dims(); // Tensor chunk handled by each CUDA block constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index d549a050ee..1fc00a5a2b 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -94,8 +94,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const bool row_scaled_nvfp4 = input.row_scaled_nvfp4; constexpr int FP4_BLOCK_SIZE = 16; - const size_t N = input.flat_first_dim(); - const size_t M = input.flat_last_dim(); + const auto [N, M] = input.flat_2d_dims(); NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); diff --git a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh index a2f3dac15a..a017a02242 100644 --- a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh @@ -783,8 +783,7 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); + const auto [rows, cols] = input.flat_2d_dims(); NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 9e4aef5a1c..e285f6b719 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -121,8 +121,7 @@ inline void compute_rowwise_amax(const Tensor &input, const Tensor *noop, Tensor #if FP4_TYPE_SUPPORTED using namespace rowwise_amax_kernel; - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); + const auto [rows, cols] = input.flat_2d_dims(); NVTE_CHECK(cols % ROWWISE_AMAX_SF_VEC_SIZE == 0, "Row-scaled NVFP4 quantization requires last dim divisible by ", ROWWISE_AMAX_SF_VEC_SIZE, "."); @@ -1359,8 +1358,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, "Transposed scaling tensor must be allocated"); } - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); + const auto [rows, cols] = input.flat_2d_dims(); NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index 8adda82131..53537b36f5 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -718,8 +718,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, "Transposed scaling tensor must be allocated"); } - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); + const auto [rows, cols] = input.flat_2d_dims(); NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index ce389c2006..60632b99d8 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -130,12 +130,9 @@ int64_t block_size(NVTECommGemmCtx* ctx, int64_t global_size) { void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, const Tensor* a, const Tensor* b, const Tensor* d, bool transa, bool transb) { - const auto a0 = a->flat_first_dim(); - const auto a1 = a->flat_last_dim(); - const auto b0 = b->flat_first_dim(); - const auto b1 = b->flat_last_dim(); - const auto d0 = d->flat_first_dim(); - const auto d1 = d->flat_last_dim(); + const auto [a0, a1] = a->flat_2d_dims(); + const auto [b0, b1] = b->flat_2d_dims(); + const auto [d0, d1] = d->flat_2d_dims(); if (transa) { NVTE_CHECK(a1 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a1); @@ -169,12 +166,9 @@ void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, const Tensor* a, const Tensor* b, const Tensor* d, bool transa, bool transb) { - const auto a0 = a->flat_first_dim(); - const auto a1 = a->flat_last_dim(); - const auto b0 = b->flat_first_dim(); - const auto b1 = b->flat_last_dim(); - const auto d0 = d->flat_first_dim(); - const auto d1 = d->flat_last_dim(); + const auto [a0, a1] = a->flat_2d_dims(); + const auto [b0, b1] = b->flat_2d_dims(); + const auto [d0, d1] = d->flat_2d_dims(); if (transa) { NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0); @@ -213,12 +207,9 @@ void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, const Tensor* a, const Tensor* b, const Tensor* d, bool transa, bool transb) { - const auto a0 = a->flat_first_dim(); - const auto a1 = a->flat_last_dim(); - const auto b0 = b->flat_first_dim(); - const auto b1 = b->flat_last_dim(); - const auto d0 = d->flat_first_dim(); - const auto d1 = d->flat_last_dim(); + const auto [a0, a1] = a->flat_2d_dims(); + const auto [b0, b1] = b->flat_2d_dims(); + const auto [d0, d1] = d->flat_2d_dims(); if (transa) { NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 12479f2a9c..71bda1faea 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -16,6 +16,7 @@ static_assert(NVTE_BUILD_NUM_PHILOX_ROUNDS > 0, "NVTE_BUILD_NUM_PHILOX_ROUNDS must be a positive integer."); +#include #include #include #include @@ -26,13 +27,16 @@ static_assert(NVTE_BUILD_NUM_PHILOX_ROUNDS > 0, #include #include +#include +#include #include -#include -#include +#include +#include #include +#include #include #include -#include +#include #include #include "./nvtx.h" @@ -101,7 +105,9 @@ inline size_t product(const std::vector &shape, const size_t begin, cons return ret; } -inline size_t product(const std::vector &shape) { +template ::value>> +inline size_t product(const Container &shape) { size_t ret = 1; for (const auto &elem : shape) { ret *= elem; @@ -113,24 +119,135 @@ size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype); size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last, const DType buffer_dtype); +/*! \brief Tensor shape + * + * Wraps NVTEShape with an interface similar to std::vector. + */ +class Shape { + public: + using value_type = size_t; + using size_type = size_t; + using iterator = size_t *; + using const_iterator = const size_t *; + + /*! Maximum number of dimensions this shape can hold. */ + static constexpr size_type max_ndim = std::extent_v; + + constexpr Shape() noexcept = default; + + explicit constexpr Shape(const NVTEShape &shape) noexcept : data_{shape} {} + + Shape(std::initializer_list shape) { + NVTE_CHECK(shape.size() <= max_ndim, "Too many dimensions (requested ", shape.size(), + ", max is ", max_ndim, ")."); + data_.ndim = shape.size(); + std::copy(shape.begin(), shape.end(), data_.data); + } + + // Construct from any container of integers + template ::value>> + explicit Shape(const Container &shape) { + NVTE_CHECK(shape.size() <= max_ndim, "Too many dimensions (requested ", shape.size(), + ", max is ", max_ndim, ")."); + data_.ndim = shape.size(); + std::copy(shape.begin(), shape.end(), data_.data); + } + + constexpr operator NVTEShape() const noexcept { return data_; } + + /*! Cast to std::vector */ + explicit operator std::vector() const { + return std::vector(data_.data, data_.data + data_.ndim); + } + + constexpr size_type size() const noexcept { return data_.ndim; } + constexpr bool empty() const noexcept { return data_.ndim == 0; } + static constexpr size_type capacity() noexcept { return max_ndim; } + + value_type *data() noexcept { return data_.data; } + constexpr const value_type *data() const noexcept { return data_.data; } + + iterator begin() noexcept { return data_.data; } + constexpr const_iterator begin() const noexcept { return data_.data; } + constexpr const_iterator cbegin() const noexcept { return data_.data; } + iterator end() noexcept { return data_.data + data_.ndim; } + constexpr const_iterator end() const noexcept { return data_.data + data_.ndim; } + constexpr const_iterator cend() const noexcept { return data_.data + data_.ndim; } + + const value_type &at(size_type i) const { + NVTE_CHECK(i < data_.ndim, "Attempted to access out-of-bounds entry (requested ", i, + ", size is ", data_.ndim, ")."); + return data_.data[i]; + } + value_type &at(size_type i) { return const_cast(std::as_const(*this).at(i)); } + + value_type &operator[](size_type i) noexcept { return data_.data[i]; } + constexpr const value_type &operator[](size_type i) const noexcept { return data_.data[i]; } + + value_type &front() noexcept { return data_.data[0]; } + constexpr const value_type &front() const noexcept { return data_.data[0]; } + + value_type &back() noexcept { return data_.data[data_.ndim - 1]; } + constexpr const value_type &back() const noexcept { return data_.data[data_.ndim - 1]; } + + void push_back(size_type value) { + NVTE_CHECK(data_.ndim < max_ndim, "Cannot add dimension: shape is at maximum capacity (", + max_ndim, ")."); + data_.data[data_.ndim++] = value; + } + + void resize(size_type count) { + NVTE_CHECK(count <= max_ndim, "Too many dimensions (requested ", count, ", max is ", max_ndim, + ")."); + data_.ndim = count; + } + + void clear() noexcept { data_.ndim = 0; } + + friend bool operator==(const Shape &lhs, const Shape &rhs) noexcept { + return lhs.data_.ndim == rhs.data_.ndim && + std::equal(lhs.data_.data, lhs.data_.data + lhs.data_.ndim, rhs.data_.data); + } + friend bool operator!=(const Shape &lhs, const Shape &rhs) noexcept { return !(lhs == rhs); } + + template ::value>> + friend bool operator==(const Shape &lhs, const Container &rhs) { + return lhs == Shape(rhs); + } + template + friend bool operator==(const T &lhs, const Shape &rhs) { + return rhs == lhs; + } + template + friend bool operator!=(const Shape &lhs, const T &rhs) { + return !(lhs == rhs); + } + template + friend bool operator!=(const T &lhs, const Shape &rhs) { + return !(rhs == lhs); + } + + private: + NVTEShape data_{}; +}; + struct SimpleTensor { void *dptr; - std::vector shape; + Shape shape; DType dtype; SimpleTensor(void *dptr, std::vector shape, DType dtype) - : dptr{dptr}, shape{std::move(shape)}, dtype{dtype} {} + : dptr{dptr}, shape(shape), dtype{dtype} {} - SimpleTensor(const NVTEBasicTensor &tensor) // NOLINT - : dptr(tensor.data_ptr), - shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim), - dtype(static_cast(tensor.dtype)) {} + SimpleTensor() : SimpleTensor(nullptr, {0}, DType::kFloat32) {} - SimpleTensor() : SimpleTensor(nullptr, std::vector{0}, DType::kFloat32) {} + SimpleTensor(const NVTEBasicTensor &tensor) // NOLINT + : dptr(tensor.data_ptr), shape(tensor.shape), dtype(static_cast(tensor.dtype)) {} operator NVTEBasicTensor() const { - return {dptr, static_cast(dtype), - nvte_make_shape(this->shape.data(), this->shape.size())}; + return {dptr, static_cast(dtype), static_cast(shape)}; } /*! Number of tensor elements. */ @@ -212,10 +329,11 @@ struct Tensor { /*! Number of tensor elements. */ size_t numel() const { - if (!has_data() && has_columnwise_data()) { - return product(columnwise_data.shape); + size_t ret = 1; + for (const size_t dim : shape()) { + ret *= dim; } - return product(data.shape); + return ret; } /*! Whether the tensor data buffer is not uninitialized. @@ -256,7 +374,7 @@ struct Tensor { * different shape, e.g. the column-wise data for some tensor * formats are transposed. */ - std::vector shape() const { + Shape shape() const { // Each tensor format interprets its data differently switch (scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: @@ -266,9 +384,8 @@ struct Tensor { // Row-wise data shape matches tensor logical shape, // column-wise data shape is transpose of logical shape if (!has_data() && has_columnwise_data()) { - std::vector ret; + Shape ret; if (!columnwise_data.shape.empty()) { - ret.reserve(columnwise_data.shape.size()); for (size_t i = 1; i < columnwise_data.shape.size(); i++) { ret.push_back(columnwise_data.shape[i]); } @@ -291,35 +408,36 @@ struct Tensor { } } - /*! Matrix height after tensor is flattened to 2D + /*! Matrix dimensions after flattening tensor to 2D. * * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted * as a (D1*D2*...*D(n-1), Dn) matrix. */ - size_t flat_first_dim() const { - const auto &full_shape = shape(); - size_t ret = 1; - if (!full_shape.empty()) { - for (size_t i = 0; i < full_shape.size() - 1; i++) { - ret *= full_shape[i]; - } + std::array flat_2d_dims() const { + const auto s = shape(); + if (s.empty()) { + return {1, 1}; } - return ret; + size_t first_dim = 1; + for (size_t i = 0; i + 1 < s.size(); ++i) { + first_dim *= s[i]; + } + return {first_dim, s.back()}; } + /*! Matrix height after tensor is flattened to 2D + * + * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted + * as a (D1*D2*...*D(n-1), Dn) matrix. + */ + size_t flat_first_dim() const { return flat_2d_dims()[0]; } + /*! Matrix width after tensor is flattened to 2D * * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted * as a (D1*D2*...*D(n-1), Dn) matrix. */ - size_t flat_last_dim() const { - const auto &full_shape = shape(); - if (full_shape.empty()) { - return 1; - } else { - return full_shape.back(); - } - } + size_t flat_last_dim() const { return flat_2d_dims()[1]; } }; struct GroupedTensor { @@ -1029,9 +1147,10 @@ inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) { size_t typeToSize(const DType type); size_t typeToNumBits(const DType type); -void CheckNoopTensor(const Tensor &t, const std::string &name); -void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale_inv_shapes = true); -void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); +void CheckNoopTensor(const Tensor &t, const std::string_view &name); +void CheckInputTensor(const Tensor &t, const std::string_view &name, + bool check_scale_inv_shapes = true); +void CheckOutputTensor(const Tensor &t, const std::string_view &name, bool allow_empty = false); /*! \brief Update a tensor's FP8 scale-inverse * @@ -1066,9 +1185,9 @@ GroupedTensor *convertNVTEGroupedTensor(const NVTEGroupedTensor tensor); GroupedTensor *convertNVTEGroupedTensorCheck(const NVTEGroupedTensor tensor); // Helper functions for GroupedTensor validation -void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &name); -void CheckInputGroupedTensor(const GroupedTensor &t, const std::string &name); -void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string &name, +void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string_view &name); +void CheckInputGroupedTensor(const GroupedTensor &t, const std::string_view &name); +void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string_view &name, bool allow_empty = false); } // namespace transformer_engine diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 8589d7045d..e59e9c00c9 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -322,10 +322,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, "cuBLAS GEMM does not support row-scaled NVFP4 inputs."); // Tensor dims in row-major order - const int A0 = inputA->flat_first_dim(); - const int A1 = inputA->flat_last_dim(); - const int B0 = inputB->flat_first_dim(); - const int B1 = inputB->flat_last_dim(); + const auto [A0, A1] = inputA->flat_2d_dims(); + const auto [B0, B1] = inputB->flat_2d_dims(); // GEMM dims in column-major order const int m = transa == CUBLAS_OP_T ? A0 : A1; diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu index 0c3a5e9299..3acb94bebd 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -46,8 +46,8 @@ namespace { using namespace cute; -// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor -using cute::Tensor; +using cute::Shape; // Avoid conflict with transformer_engine::Shape +using cute::Tensor; // Avoid conflict with transformer_engine::Tensor constexpr int kMaxTensorsPerKernel = 64; constexpr int kNVFP4BlockSize = 16; diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index e6de366f52..09ef0e295f 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -36,8 +36,9 @@ namespace detail { namespace { using namespace cute; -using cute:: - Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor + +using cute::Shape; // Avoid conflict with transformer_engine::Shape +using cute::Tensor; // Avoid conflict with transformer_engine::Tensor using Stride2D = cute::Stride>; diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 1265f2711c..a1f1f6a819 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -46,8 +46,8 @@ namespace { using namespace cute; -// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor -using cute::Tensor; +using cute::Shape; // Avoid conflict with transformer_engine::Shape +using cute::Tensor; // Avoid conflict with transformer_engine::Tensor constexpr int kMaxTensorsPerKernel = 64; diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 957935668c..e9ae22cead 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -38,7 +38,9 @@ namespace detail { namespace { using namespace cute; -using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor + +using cute::Tensor; // Avoid conflict with transformer_engine::Tensor +using cute::Shape; // Avoid conflict with transformer_engine::Shape // calculate the global encode scale factor for a given global amax. __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index 99060ab627..c41fc983ab 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -48,6 +48,9 @@ namespace { using namespace cute; +using cute::Tensor; // Avoid conflict with transformer_engine::Tensor +using cute::Shape; // Avoid conflict with transformer_engine::Shape + struct CLCResponse { uint32_t data[4] = {0}; }; constexpr int kFp4ConvertChunkElements = 8; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 045ae88893..68a5de57a4 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -131,6 +131,20 @@ typedef void *NVTETensor; */ NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode); +/*! \brief Create a batch of new TE tensors. + * + * Equivalent to calling nvte_create_tensor N times with the same + * scaling mode. Before use, each tensor's parameters need to be set. + * TE tensors are just wrappers on top of raw data and do not own + * memory. + * + * \param[in] scaling_mode Scaling mode shared by all tensors. + * \param[out] tensors Caller-allocated array of length N to + * receive the new tensors. + * \param[in] N Number of tensors to create. + */ +void nvte_create_tensors(NVTEScalingMode scaling_mode, NVTETensor *tensors, size_t N); + /*! \brief Destroy a TE tensor. * * Since the TE tensor does not own memory, the underlying @@ -140,6 +154,17 @@ NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode); */ void nvte_destroy_tensor(NVTETensor tensor); +/*! \brief Destroy a batch of TE tensors. + * + * Equivalent to calling nvte_destroy_tensor N times. Since TE tensors + * do not own memory, the underlying data is not freed during this + * operation. Null entries are ignored. + * + * \param[in] tensors Array of tensors to be destroyed. + * \param[in] N Number of tensors in the array. + */ +void nvte_destroy_tensors(NVTETensor *tensors, size_t N); + /*! \brief Get a raw pointer to the tensor's rowwise data. * * \param[in] tensor Tensor. diff --git a/transformer_engine/common/include/transformer_engine/utils.h b/transformer_engine/common/include/transformer_engine/utils.h index eca6f359ea..4273ba5f2c 100644 --- a/transformer_engine/common/include/transformer_engine/utils.h +++ b/transformer_engine/common/include/transformer_engine/utils.h @@ -5,13 +5,14 @@ ************************************************************************/ /*! \file utils.h - * \brief Utility functions (e.g. host-to-device pointer copies). + * \brief Utility functions (e.g. host-to-device value stores). */ #ifndef TRANSFORMER_ENGINE_UTILS_H_ #define TRANSFORMER_ENGINE_UTILS_H_ #include +#include #include #include @@ -19,12 +20,23 @@ extern "C" { #endif -/*! \brief Copy an array of device pointers (held on host) into a device tensor. +/*! \brief Copy a small host buffer into device memory. * - * \param[in] host_ptrs Host array of device pointer values cast to uint64_t. - * \param[out] output NVTETensor whose rowwise data buffer receives the pointer values. - * \param[in] count Number of pointers. - * \param[in] stream CUDA stream used for the operation. + * The data is copied into kernel arguments, so the host buffer may + * be freed immediately after this call returns. This is compatible + * with CUDA Graphs. + * + * \param[in] host_ptr Source in host memory. + * \param[out] device_ptr Destination in device memory. + * \param[in] num_bytes Size of the value in bytes. + * \param[in] stream CUDA stream for the operation. + */ +void nvte_load_value_on_device(const void *host_ptr, void *device_ptr, size_t num_bytes, + cudaStream_t stream); + +/*! \deprecated Use nvte_load_value_on_device instead. + * + * \brief Copy an array of device pointers (held on host) into a device tensor. */ void nvte_convert_pointers_to_tensor(const uint64_t *host_ptrs, NVTETensor output, int64_t count, cudaStream_t stream); diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index 7bd5a1bbd0..51f2ac8496 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -100,7 +100,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size gamma_in_weight_dtype); if (workspace->data.numel() == 0) { - workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.shape = Shape(plan->getWorkspaceShape()); workspace->data.dtype = DType::kByte; return; } @@ -185,7 +185,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te gamma_in_weight_dtype); if (workspace->data.numel() == 0) { - workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.shape = Shape(plan->getWorkspaceShape()); workspace->data.dtype = DType::kByte; return; } else { diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index adf2ccee04..76ecff67a5 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -85,7 +85,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens gamma_in_weight_dtype); if (workspace->data.numel() == 0) { - workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.shape = Shape(plan->getWorkspaceShape()); workspace->data.dtype = DType::kByte; return; } @@ -162,7 +162,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const gamma_in_weight_dtype); if (workspace->data.numel() == 0) { - workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.shape = Shape(plan->getWorkspaceShape()); workspace->data.dtype = DType::kByte; return; } else { @@ -233,7 +233,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const gamma_in_weight_dtype); if (workspace->data.numel() == 0) { - workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.shape = Shape(plan->getWorkspaceShape()); workspace->data.dtype = DType::kByte; return; } else { diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index c7ed407a59..b0d837479a 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -25,14 +25,15 @@ constexpr int MXFP8_BLOCK_SIZE = 32; constexpr int NVFP4_BLOCK_SIZE = 16; int get_max_dynamic_smem() { - static int max_smem = -1; - if (max_smem < 0) { - int device; + auto query_max_smem = []() -> int { + int device{0}, max_smem{0}; NVTE_CHECK_CUDA(cudaGetDevice(&device)); NVTE_CHECK_CUDA( cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); - } - return max_smem; + return max_smem; + }; + static int cached_val = query_max_smem(); + return cached_val; } constexpr __device__ __host__ int TB_DIM = 32; @@ -1456,10 +1457,8 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, // We don't allow empty tensors. They should be filtered out before calling this function. NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty."); - CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]", - check_scale_inv_shapes); - CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]", - check_scale_inv_shapes); + CheckInputTensor(*input[i], "scaling_factor_input", check_scale_inv_shapes); + CheckInputTensor(*output[i], "scaling_factor_output", check_scale_inv_shapes); all_has_data = all_has_data && input[i]->scale_inv.has_data(); all_has_columnwise_data = (all_has_columnwise_data && input[i]->columnwise_scale_inv.has_data()); @@ -1540,17 +1539,18 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, const int pos = kernel_args.num_tensors; kernel_args.m_list[pos] = m; kernel_args.k_list[pos] = k; + const auto [first_dim, last_dim] = input[i]->flat_2d_dims(); if (!all_nvfp4 || all_has_data) { int block_scale_size = all_nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE; kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); kernel_args.output_list[pos] = output[i]->scale_inv.dptr; - kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); - kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / block_scale_size; + kernel_args.original_m_list[pos] = first_dim; + kernel_args.original_k_list[pos] = last_dim / block_scale_size; } else { kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; - kernel_args.original_m_list[pos] = input[i]->flat_last_dim(); - kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / NVFP4_BLOCK_SIZE; + kernel_args.original_m_list[pos] = last_dim; + kernel_args.original_k_list[pos] = first_dim / NVFP4_BLOCK_SIZE; } kernel_args.num_tensors++; } @@ -1609,8 +1609,9 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; kernel_args.m_list[pos] = m; kernel_args.k_list[pos] = k; - kernel_args.original_m_list[pos] = input[i]->flat_last_dim(); - kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / MXFP8_BLOCK_SIZE; + const auto [first_dim, last_dim] = input[i]->flat_2d_dims(); + kernel_args.original_m_list[pos] = last_dim; + kernel_args.original_k_list[pos] = first_dim / MXFP8_BLOCK_SIZE; kernel_args.num_tensors++; } // Launch the remaining tensors @@ -1958,6 +1959,8 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen using namespace transformer_engine; NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); std::vector input_list, output_list; + input_list.reserve(num_tensors); + output_list.reserve(num_tensors); for (size_t i = 0; i < num_tensors; i++) { input_list.push_back(convertNVTETensorCheck(inputs[i])); output_list.push_back(convertNVTETensorCheck(outputs[i])); diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1a52d76019..e77660cb63 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -51,7 +52,7 @@ std::string to_string(const NVTEScalingMode &mode) { return "Invalid Scaling"; } -void CheckNoopTensor(const Tensor &t, const std::string &name) { +void CheckNoopTensor(const Tensor &t, const std::string_view &name) { if (t.data.has_data()) { NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(), "."); @@ -60,7 +61,7 @@ void CheckNoopTensor(const Tensor &t, const std::string &name) { } } -void CheckScaleTensorShape(const Tensor &t, const std::string &name) { +void CheckScaleTensorShape(const Tensor &t, const std::string_view &name) { NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!"); if (is_tensor_scaling(t.scaling_mode)) { if (is_fp8_dtype(t.dtype())) { @@ -91,60 +92,55 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { } else { if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { // Need (4, 128) alignment even for e8 scaling factor - auto block_alignment = std::vector{128ul, 4ul}; - size_t expected_x, expected_y, alignment; - const size_t block_size_rowwise = 32; - const size_t block_size_colwise = 32; + constexpr std::array block_alignment{128ul, 4ul}; + const auto [first_dim, last_dim] = t.flat_2d_dims(); if (t.has_data()) { - alignment = block_alignment[0]; - expected_x = - DIVUP(DIVUP(t.flat_first_dim(), static_cast(1)), alignment) * alignment; - alignment = block_alignment[1]; - expected_y = - DIVUP(DIVUP(t.flat_last_dim(), static_cast(block_size_rowwise)), alignment) * - alignment; - - const auto &expected = std::vector{expected_x, expected_y}; + constexpr std::array block_shape{1, 32}; + const std::array expected{ + DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[0]), + DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[1])}; NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected, ", got ", t.scale_inv.shape, ")"); } if (t.has_columnwise_data()) { - alignment = block_alignment[1]; - expected_x = - DIVUP(DIVUP(t.flat_first_dim(), static_cast(block_size_colwise)), alignment) * - alignment; - alignment = block_alignment[0]; - expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; - - const auto &expected = std::vector{expected_x, expected_y}; + constexpr std::array block_shape{32, 1}; + const std::array expected{ + DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[1]), + DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[0])}; NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", t.columnwise_scale_inv.shape, ")"); } } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) { + const auto [first_dim, last_dim] = t.flat_2d_dims(); + if (t.has_data()) { - const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_first_dim(), 128); - const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_last_dim(), 16lu), 4); - const auto &expected = std::vector{expected_y, expected_x}; + constexpr std::array block_shape{1, 16}; + constexpr std::array block_alignment{128, 4}; + const std::array expected{ + DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[0]), + DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[1])}; NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected, ", got ", t.scale_inv.shape, ")"); } if (t.has_columnwise_data()) { - const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_last_dim(), 128); - const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_first_dim(), 16lu), 4); - const auto &expected = std::vector{expected_y, expected_x}; + constexpr std::array block_shape{1, 16}; + constexpr std::array block_alignment{128, 4}; + const std::array expected{ + DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[0]), block_alignment[0]), + DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[1]), block_alignment[1])}; NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, - "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", + "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", t.columnwise_scale_inv.shape, ")"); } } } } -void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale_inv_shapes) { +void CheckInputTensor(const Tensor &t, const std::string_view &name, bool check_scale_inv_shapes) { const DType type = t.dtype(); if (is_fp8_dtype(type)) { // FP8 input needs to have scale_inv @@ -200,7 +196,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale } } -void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) { +void CheckOutputTensor(const Tensor &t, const std::string_view &name, bool allow_empty) { const DType type = t.dtype(); if (is_fp8_dtype(type)) { // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax @@ -262,7 +258,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt CheckScaleTensorShape(t, name); } -void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &name) { +void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string_view &name) { NVTE_CHECK(t.num_tensors > 0, "Grouped tensor ", name, " has no tensors!"); // Helper lambda to validate shape arrays @@ -332,7 +328,8 @@ void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &na } // Helper function to check scale_inv for both input and output -static void CheckGroupedScaleInv(const GroupedTensor &t, const std::string &name, bool is_output) { +static void CheckGroupedScaleInv(const GroupedTensor &t, const std::string_view &name, + bool is_output) { const char *tensor_type = is_output ? "output" : "input"; // Helper to check scale_inv for both rowwise and columnwise layouts @@ -369,14 +366,15 @@ static void CheckGroupedScaleInv(const GroupedTensor &t, const std::string &name } } -void CheckInputGroupedTensor(const GroupedTensor &t, const std::string &name) { +void CheckInputGroupedTensor(const GroupedTensor &t, const std::string_view &name) { NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input grouped tensor ", name, " not allocated"); CheckGroupedScaleInv(t, name, false); CheckGroupedTensorShapeArrays(t, name); } -void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string &name, bool allow_empty) { +void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string_view &name, + bool allow_empty) { if (!allow_empty) { NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output grouped tensor ", name, " not allocated"); @@ -437,6 +435,32 @@ class TensorAllocator { MAX_TENSOR_NUM, ". There is probably a memory leak in your application."); } + void Allocate(NVTEScalingMode mode, NVTETensor *out, size_t N) { + std::lock_guard lock(mutex); + const size_t available = free_list.size() + (memory.capacity() - memory.size()); + NVTE_CHECK(available >= N, "Cannot allocate ", N, + " new NVTETensors. Maximum number of tensors reached: ", MAX_TENSOR_NUM, + ". There is probably a memory leak in your application."); + for (size_t i = 0; i < N; ++i) { + uintptr_t index; + if (!free_list.empty()) { + index = free_list.back(); + free_list.pop_back(); + } else { + memory.emplace_back(); + index = memory.size(); + size = index; + memory[index - 1].nvte_tensor = reinterpret_cast(index); + } + memory[index - 1].scaling_mode = mode; + out[i] = reinterpret_cast(index); + } + if (debug) { + std::cout << "Allocated range of " << N << " tensors. Free list size: " << free_list.size() + << " and capacity " << free_list.capacity() << std::endl; + } + } + void Free(NVTETensor t) { uintptr_t index = reinterpret_cast(t); if (index == 0) return; @@ -599,6 +623,10 @@ NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) { return ret; } +void nvte_create_tensors(NVTEScalingMode scaling_mode, NVTETensor *tensors, size_t N) { + transformer_engine::TensorAllocator::instance().Allocate(scaling_mode, tensors, N); +} + void nvte_destroy_tensor(NVTETensor tensor) { transformer_engine::TensorAllocator::instance().Free(tensor); } @@ -636,11 +664,7 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { if (t == nullptr) { NVTE_ERROR("Invalid tensor: received null pointer in nvte_tensor_shape"); } - - // Determine tensor shape depending on tensor format - const std::vector &shape = t->shape(); - - return nvte_make_shape(shape.data(), shape.size()); + return t->shape(); } NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { @@ -648,8 +672,7 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { if (t == nullptr) { NVTE_ERROR("Invalid tensor: received null pointer in nvte_tensor_columnwise_shape"); } - const std::vector &shape = t->columnwise_data.shape; - return nvte_make_shape(shape.data(), shape.size()); + return t->columnwise_data.shape; } size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } @@ -952,10 +975,8 @@ NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) { } void nvte_tensor_pack_create(NVTETensorPack *pack) { - for (int i = 0; i < pack->MAX_SIZE; i++) { - pack->tensors[i] = - transformer_engine::TensorAllocator::instance().Allocate(NVTE_DELAYED_TENSOR_SCALING); - } + transformer_engine::TensorAllocator::instance().Allocate(NVTE_DELAYED_TENSOR_SCALING, + pack->tensors, pack->MAX_SIZE); } void nvte_tensor_pack_destroy(NVTETensorPack *pack) { diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index 77c1322e7d..ba004291f3 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -179,8 +179,7 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/ Tensor *workspace, const int nvec_out) { - const size_t row_length = cast_output.flat_last_dim(); - const size_t num_rows = cast_output.flat_first_dim(); + const auto [num_rows, row_length] = cast_output.flat_2d_dims(); const size_t tile_size_y = (nvec_out * THREADS_PER_WARP); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); diff --git a/transformer_engine/common/util/utils.cu b/transformer_engine/common/util/utils.cu index a183e6ec52..678a97ff30 100644 --- a/transformer_engine/common/util/utils.cu +++ b/transformer_engine/common/util/utils.cu @@ -7,45 +7,76 @@ #include #include +#include +#include + #include "../common.h" #include "../util/logging.h" +namespace transformer_engine { +namespace load_value_on_device { namespace { -constexpr int64_t kMaxKernelAddresses = 256; - -struct HostPointersArgs { - uint64_t ptrs[kMaxKernelAddresses]; +union Payload { + static constexpr size_t kMaxBytes = 2048; + static constexpr size_t kVectorSize = 4; + static constexpr size_t kMaxVectors = kMaxBytes / kVectorSize; + uint8_t bytes[kMaxBytes]; + uint32_t vectors[kMaxVectors]; }; -__global__ void write_pointers_kernel(HostPointersArgs args, uint64_t *out, int64_t count, - int64_t offset) { - const int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; - if (idx < count) { - out[offset + idx] = args.ptrs[idx]; +constexpr size_t block_size = 512; +constexpr size_t num_blocks = DIVUP(Payload::kMaxVectors, block_size); + +__global__ void __launch_bounds__(block_size) kernel(Payload payload, size_t num_bytes, void *out) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (Payload::kVectorSize * (tid + 1) <= num_bytes) { + reinterpret_cast(out)[tid] = payload.vectors[tid]; + } else { + for (size_t i = Payload::kVectorSize * tid; i < num_bytes; ++i) { + static_cast(out)[i] = payload.bytes[i]; + } } } } // namespace +} // namespace load_value_on_device +} // namespace transformer_engine + +void nvte_load_value_on_device(const void *host_ptr, void *device_ptr, size_t num_bytes, + cudaStream_t stream) { + NVTE_API_CALL(nvte_load_value_on_device); + using namespace transformer_engine::load_value_on_device; + + // Nothing to be done if size is zero + if (num_bytes == 0) { + return; + } + + // Check pointers + NVTE_CHECK(host_ptr != nullptr, "Attempting to read ", num_bytes, " bytes from a null pointer."); + NVTE_CHECK(device_ptr != nullptr, "Attempting to write ", num_bytes, + " bytes into a null pointer."); + NVTE_CHECK(reinterpret_cast(device_ptr) % Payload::kVectorSize == 0, + "Device pointer is not aligned to ", Payload::kVectorSize, " bytes."); + + // Chunk data to fit in kernel arguments and launch kernels + const uint8_t *src = static_cast(host_ptr); + uint8_t *dst = static_cast(device_ptr); + for (size_t offset = 0; offset < num_bytes; offset += Payload::kMaxBytes) { + const size_t chunk_size = std::min(num_bytes - offset, Payload::kMaxBytes); + Payload payload{}; + std::memcpy(payload.bytes, src + offset, chunk_size); + kernel<<>>(payload, chunk_size, dst + offset); + NVTE_CHECK_CUDA(cudaGetLastError()); + } +} void nvte_convert_pointers_to_tensor(const uint64_t *host_ptrs, NVTETensor output, int64_t count, cudaStream_t stream) { NVTE_API_CALL(nvte_convert_pointers_to_tensor); using namespace transformer_engine; Tensor *out_tensor = convertNVTETensorCheck(output); - uint64_t *out_ptr = static_cast(out_tensor->data.dptr); - NVTE_CHECK(out_ptr != nullptr, "Output tensor data pointer is null."); - - int64_t offset = 0; - while (offset < count) { - const int64_t chunk = std::min(kMaxKernelAddresses, count - offset); - HostPointersArgs args{}; - for (int64_t i = 0; i < chunk; ++i) { - args.ptrs[i] = host_ptrs[offset + i]; - } - constexpr int threads = kMaxKernelAddresses; - write_pointers_kernel<<<1, threads, 0, stream>>>(args, out_ptr, chunk, offset); - NVTE_CHECK_CUDA(cudaGetLastError()); - offset += chunk; - } + nvte_load_value_on_device(host_ptrs, out_tensor->data.dptr, + static_cast(count) * sizeof(uint64_t), stream); } diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 8082ff07ed..8878c77a0b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -484,14 +484,15 @@ size_t get_cublasLt_version(); size_t get_cudnn_version(); -std::vector convert_host_pointers_to_tensor( - std::vector> tensor_lists); - -std::tuple get_device_pointer_for_data_and_scales( - std::vector data_tensors, std::vector scale_tensors, bool swizzle, - bool rowwise, transformer_engine::DType data_dtype); at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim); +at::Tensor load_data_ptrs_on_device(const std::vector &tensors, + const c10::Device &device); + +std::tuple> transform_and_load_data_ptrs_on_device( + const std::string &transform_type, const std::vector &tensors, + const c10::Device &device); + /*************************************************************************************************** * Support THD format for Context Parallel **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a4571c64e2..c45082cf31 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -489,15 +489,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Get cublasLt version", py::call_guard()); m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version", py::call_guard()); - m.def("convert_host_pointers_to_tensor", - &transformer_engine::pytorch::convert_host_pointers_to_tensor, - "Copy host-side device pointers into device tensors", py::arg("tensor_lists"), - py::call_guard()); - m.def("get_device_pointer_for_data_and_scales", - &transformer_engine::pytorch::get_device_pointer_for_data_and_scales, - "Swizzle scales and collect data/scale device pointers into device tensors", - py::arg("data_tensors"), py::arg("scale_tensors"), py::arg("swizzle") = false, - py::arg("rowwise"), py::arg("data_dtype"), py::call_guard()); + m.def("load_data_ptrs_on_device", &transformer_engine::pytorch::load_data_ptrs_on_device, + py::arg("tensors"), py::arg("device"), py::call_guard()); + m.def("transform_and_load_data_ptrs_on_device", + &transformer_engine::pytorch::transform_and_load_data_ptrs_on_device, + py::arg("transform_type"), py::arg("tensors"), py::arg("device"), + py::call_guard()); m.def("splits_to_offsets", &transformer_engine::pytorch::splits_to_offsets, "Compute grouped tensor offsets from split sizes", py::arg("first_dims"), py::arg("logical_last_dim"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index c02d2ec616..1bf353fd77 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -35,35 +35,37 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio const std::string& amax_compute_algo, DType fp8_dtype, float margin) { size_t num_tensors = amax_histories.size(); - std::vector te_amax_histories; - std::vector te_scales; - te_amax_histories.reserve(num_tensors); - te_scales.reserve(num_tensors); + + // Helper to deallocate a batch of NVTETensors + struct DestroyGuard { + NVTETensor* data; + size_t n; + ~DestroyGuard() { nvte_destroy_tensors(data, n); } + }; + + // Allocate amax history and scale NVTETensors as batches + std::vector te_amax_histories(num_tensors); + nvte_create_tensors(NVTE_DELAYED_TENSOR_SCALING, te_amax_histories.data(), num_tensors); + DestroyGuard amax_guard{te_amax_histories.data(), num_tensors}; + std::vector te_scales(num_tensors); + nvte_create_tensors(NVTE_DELAYED_TENSOR_SCALING, te_scales.data(), num_tensors); + DestroyGuard scale_guard{te_scales.data(), num_tensors}; + for (size_t i = 0; i < num_tensors; i++) { - te_amax_histories.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING)); - NVTETensor& amax_history = te_amax_histories.back(); NVTEShape amax_shape = convertTorchShape(amax_histories[i].sizes()); NVTEBasicTensor amax_history_data = {amax_histories[i].data_ptr(), static_cast(DType::kFloat32), amax_shape}; - nvte_set_tensor_param(&amax_history, kNVTERowwiseData, &amax_history_data); + nvte_set_tensor_param(&te_amax_histories[i], kNVTERowwiseData, &amax_history_data); - te_scales.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING)); - NVTETensor& scale = te_scales.back(); NVTEShape scale_shape = convertTorchShape(scales[i].sizes()); NVTEBasicTensor scale_data = {scales[i].data_ptr(), static_cast(DType::kFloat32), scale_shape}; - nvte_set_tensor_param(&scale, kNVTERowwiseData, &scale_data); + nvte_set_tensor_param(&te_scales[i], kNVTERowwiseData, &scale_data); } nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales, amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, at::cuda::getCurrentCUDAStream()); - for (auto& t : te_amax_histories) { - nvte_destroy_tensor(t); - } - for (auto& t : te_scales) { - nvte_destroy_tensor(t); - } } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index 193aed29e6..3ef04ff232 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -173,6 +173,7 @@ std::optional multi_tensor_swizzle_scales_for_gemm_impl( // Filter out tensors that already have swizzled scales std::vector tensors_needing_swizzle; + tensors_needing_swizzle.reserve(tensors.size()); for (auto &tensor : tensors) { if (!tensor.get_with_gemm_swizzled_scales()) { tensors_needing_swizzle.push_back(&tensor); @@ -184,6 +185,7 @@ std::optional multi_tensor_swizzle_scales_for_gemm_impl( // Determine buffer size needed for swizzled scales std::vector output_scales_offsets; + output_scales_offsets.reserve(tensors_needing_swizzle.size()); size_t output_scales_bytes = 0; for (auto &tensor : tensors_needing_swizzle) { const auto scales_nvte = @@ -202,75 +204,86 @@ std::optional multi_tensor_swizzle_scales_for_gemm_impl( transformer_engine::DType::kByte, false); uint8_t *output_scales_dptr = reinterpret_cast(getDataPtr(output_scales_pyt)); - // Construct TE tensors with only scales - std::vector inputs_nvte, outputs_nvte; - for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) { + // Allocate input/output NVTETensors as a single batch. The first + // n_swizzle entries are inputs; the next n_swizzle are outputs. + const size_t n_swizzle = tensors_needing_swizzle.size(); + std::vector nvte_tensors(2 * n_swizzle); + nvte_create_tensors(scaling_mode, nvte_tensors.data(), nvte_tensors.size()); + struct DestroyGuard { + NVTETensor *data; + size_t n; + ~DestroyGuard() { nvte_destroy_tensors(data, n); } + } destroy_guard{nvte_tensors.data(), nvte_tensors.size()}; + NVTETensor *inputs_nvte = nvte_tensors.data(); + NVTETensor *outputs_nvte = nvte_tensors.data() + n_swizzle; + + auto set_param = [](NVTETensor t, NVTETensorParam param, void *dptr, + transformer_engine::DType dtype, const NVTEShape &shape) { + NVTEBasicTensor data{dptr, static_cast(dtype), shape}; + nvte_set_tensor_param_v2(t, param, &data, sizeof(data)); + }; + + // Cache output scale dtype/shape per tensor so we can update the + // source TensorWrappers without re-reading from the output NVTETensors. + std::vector output_scales_dtypes(n_swizzle); + std::vector output_scales_shapes(n_swizzle); + + for (size_t i = 0; i < n_swizzle; ++i) { auto &tensor = *tensors_needing_swizzle[i]; - inputs_nvte.emplace_back(scaling_mode); - outputs_nvte.emplace_back(scaling_mode); - auto &input_nvte = inputs_nvte.back(); - auto &output_nvte = outputs_nvte.back(); - output_nvte.set_with_gemm_swizzled_scales(true); + const uint8_t swizzled_flag = 1; + nvte_set_tensor_param_v2(outputs_nvte[i], kNVTEWithGEMMSwizzledScales, &swizzled_flag, + sizeof(swizzled_flag)); if (rowwise_usage) { const auto data_nvte = tensor.get_rowwise_data(); const auto scales_nvte = tensor.get_rowwise_scale_inv(); const auto data_dtype = static_cast(data_nvte.dtype); const auto scales_dtype = static_cast(scales_nvte.dtype); - input_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); - input_nvte.set_rowwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape); - output_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); - output_nvte.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype, - scales_nvte.shape); + output_scales_dtypes[i] = scales_dtype; + output_scales_shapes[i] = scales_nvte.shape; + set_param(inputs_nvte[i], kNVTERowwiseData, nullptr, data_dtype, data_nvte.shape); + set_param(inputs_nvte[i], kNVTERowwiseScaleInv, scales_nvte.data_ptr, scales_dtype, + scales_nvte.shape); + set_param(outputs_nvte[i], kNVTERowwiseData, nullptr, data_dtype, data_nvte.shape); + set_param(outputs_nvte[i], kNVTERowwiseScaleInv, + output_scales_dptr + output_scales_offsets[i], scales_dtype, scales_nvte.shape); } else { const auto data_nvte = tensor.get_columnwise_data(); const auto scales_nvte = tensor.get_columnwise_scale_inv(); const auto data_dtype = static_cast(data_nvte.dtype); const auto scales_dtype = static_cast(scales_nvte.dtype); - input_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); - input_nvte.set_columnwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape); - output_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); - output_nvte.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i], - scales_dtype, scales_nvte.shape); + output_scales_dtypes[i] = scales_dtype; + output_scales_shapes[i] = scales_nvte.shape; + set_param(inputs_nvte[i], kNVTEColumnwiseData, nullptr, data_dtype, data_nvte.shape); + set_param(inputs_nvte[i], kNVTEColumnwiseScaleInv, scales_nvte.data_ptr, scales_dtype, + scales_nvte.shape); + set_param(outputs_nvte[i], kNVTEColumnwiseData, nullptr, data_dtype, data_nvte.shape); + set_param(outputs_nvte[i], kNVTEColumnwiseScaleInv, + output_scales_dptr + output_scales_offsets[i], scales_dtype, scales_nvte.shape); } } - // Pack raw NVTETensors into vectors - std::vector inputs_nvte_raw, outputs_nvte_raw; - for (auto &tensor : inputs_nvte) { - inputs_nvte_raw.emplace_back(tensor.data()); - } - for (auto &tensor : outputs_nvte) { - outputs_nvte_raw.emplace_back(tensor.data()); - } - // Launch kernel NVTE_SCOPED_GIL_RELEASE({ if (check_scale_inv_shapes) { - nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte_raw.data(), outputs_nvte_raw.data(), - inputs_nvte_raw.size(), + nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte, outputs_nvte, n_swizzle, at::cuda::getCurrentCUDAStream()); } else { - nvte_multi_tensor_swizzle_scaling_factors_unchecked( - inputs_nvte_raw.data(), outputs_nvte_raw.data(), inputs_nvte_raw.size(), - at::cuda::getCurrentCUDAStream()); + nvte_multi_tensor_swizzle_scaling_factors_unchecked(inputs_nvte, outputs_nvte, n_swizzle, + at::cuda::getCurrentCUDAStream()); } }); // Update tensors with swizzled scales - for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) { + for (size_t i = 0; i < n_swizzle; ++i) { auto &tensor = *tensors_needing_swizzle[i]; reset_tensor_data(tensor, !rowwise_usage, !columnwise_usage); tensor.set_with_gemm_swizzled_scales(true); if (rowwise_usage) { - auto scales_nvte = outputs_nvte[i].get_rowwise_scale_inv(); - const auto scales_dtype = static_cast(scales_nvte.dtype); - tensor.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype, - scales_nvte.shape); + tensor.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], + output_scales_dtypes[i], output_scales_shapes[i]); } else { - auto scales_nvte = outputs_nvte[i].get_columnwise_scale_inv(); - const auto scales_dtype = static_cast(scales_nvte.dtype); - tensor.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype, - scales_nvte.shape); + tensor.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i], + output_scales_dtypes[i], output_scales_shapes[i]); } } diff --git a/transformer_engine/pytorch/csrc/extensions/utils.cpp b/transformer_engine/pytorch/csrc/extensions/utils.cpp index 9a093608d4..99b25b63ac 100644 --- a/transformer_engine/pytorch/csrc/extensions/utils.cpp +++ b/transformer_engine/pytorch/csrc/extensions/utils.cpp @@ -6,6 +6,10 @@ #include +#include +#include +#include +#include #include #include "common/common.h" @@ -13,153 +17,166 @@ namespace transformer_engine::pytorch { -namespace { - -at::Tensor collect_pointers_in_device_tensor(const std::vector& host_ptrs, - const at::Device& device, cudaStream_t stream) { - const int64_t count = static_cast(host_ptrs.size()); - auto out = at::empty({count}, at::TensorOptions().dtype(at::kLong).device(device)); - auto out_nvte = makeTransformerEngineTensor(out); - nvte_convert_pointers_to_tensor(host_ptrs.data(), out_nvte.data(), count, stream); - return out; -} +at::Tensor load_data_ptrs_on_device(const std::vector &tensors, + const c10::Device &device) { + // Collect data pointers + std::vector ptrs_host; + ptrs_host.reserve(tensors.size()); + for (const auto &tensor : tensors) { + ptrs_host.push_back(reinterpret_cast(tensor.data_ptr())); + } -} // namespace + // Allocate device buffer + auto ptrs_device = at::empty({static_cast(tensors.size())}, + at::TensorOptions().dtype(at::kLong).device(device)); -std::vector convert_host_pointers_to_tensor( - std::vector> tensor_lists) { - std::vector outputs; - outputs.reserve(tensor_lists.size()); - auto stream = at::cuda::getCurrentCUDAStream(); + // Load pointers on device + nvte_load_value_on_device(ptrs_host.data(), ptrs_device.data_ptr(), + tensors.size() * sizeof(uint64_t), at::cuda::getCurrentCUDAStream()); - for (const auto& tensor_list : tensor_lists) { - NVTE_CHECK(!tensor_list.empty(), "Tensor list is empty."); - const auto& first_tensor = tensor_list[0]; - NVTE_CHECK(first_tensor.is_cuda(), "Tensor list must be on CUDA."); - const auto device = first_tensor.device(); - const int64_t count = static_cast(tensor_list.size()); - std::vector host_ptrs(count); - for (int64_t i = 0; i < count; ++i) { - host_ptrs[i] = reinterpret_cast(tensor_list[static_cast(i)].data_ptr()); - } - outputs.push_back(collect_pointers_in_device_tensor(host_ptrs, device, stream)); - } - - return outputs; + return ptrs_device; } -std::tuple get_device_pointer_for_data_and_scales( - std::vector data_tensors, std::vector scale_tensors, bool swizzle, - bool rowwise, transformer_engine::DType data_dtype) { - const size_t num_tensors = data_tensors.size(); - NVTE_CHECK(num_tensors > 0, "data_tensors must not be empty."); - NVTE_CHECK(num_tensors == scale_tensors.size(), - "data_tensors and scale_tensors must have the same size."); - NVTE_CHECK(data_tensors[0].is_cuda(), "data_tensors must be on CUDA."); - const auto device = data_tensors[0].device(); - auto stream = at::cuda::getCurrentCUDAStream(); +std::tuple> transform_and_load_data_ptrs_on_device( + const std::string &transform_type, const std::vector &tensors, + const c10::Device &device) { + const size_t num_tensors = tensors.size(); - // Infer data shape from the first data tensor (expected 2D: n x k) - NVTE_CHECK(data_tensors[0].dim() == 2, - "data_tensors elements must be 2D, got dim=", data_tensors[0].dim()); - NVTEShape data_shape{}; - data_shape.ndim = 2; - data_shape.data[0] = static_cast(data_tensors[0].size(0)); - data_shape.data[1] = static_cast(data_tensors[0].size(1)); - - // Collect data device pointers - std::vector data_host_ptrs(num_tensors); - for (size_t i = 0; i < num_tensors; ++i) { - data_host_ptrs[i] = reinterpret_cast(data_tensors[i].data_ptr()); + // Trivial cases + if (transform_type.empty()) { + // No transform, just load pointers on device + return {load_data_ptrs_on_device(tensors, device), std::nullopt}; + } + if (num_tensors == 0) { + // No input tensors, return tensor with no elements + return {at::empty({int64_t{0}}, at::TensorOptions().dtype(at::kLong).device(device)), + std::nullopt}; } - // Swizzle scales and collect scale pointers - at::Tensor swizzled_scales_keepalive; - std::vector scale_host_ptrs(num_tensors); + // CUDA stream + auto stream = at::cuda::getCurrentCUDAStream(); - if (swizzle) { - NVTEScalingMode scaling_mode; - transformer_engine::DType scale_dtype; - if (is_fp8_dtype(data_dtype)) { + // Swizzle scales for GEMM, with uniform tensor sizes + const bool uniform_mxfp8_rowwise_swizzle = transform_type == "uniform_mxfp8_rowwise_swizzle"; + const bool uniform_mxfp8_colwise_swizzle = transform_type == "uniform_mxfp8_columnwise_swizzle"; + const bool uniform_nvfp4_swizzle = transform_type == "uniform_nvfp4_swizzle"; + if (uniform_mxfp8_rowwise_swizzle || uniform_mxfp8_colwise_swizzle || uniform_nvfp4_swizzle) { + // Tensor format + NVTEScalingMode scaling_mode = NVTE_INVALID_SCALING; + if (uniform_mxfp8_rowwise_swizzle || uniform_mxfp8_colwise_swizzle) { scaling_mode = NVTE_MXFP8_1D_SCALING; - scale_dtype = transformer_engine::DType::kFloat8E8M0; - } else if (is_fp4_dtype(data_dtype)) { + } else if (uniform_nvfp4_swizzle) { scaling_mode = NVTE_NVFP4_1D_SCALING; - scale_dtype = transformer_engine::DType::kFloat8E4M3; - } else { - NVTE_ERROR("data_dtype must be an FP8 or FP4 type for swizzling."); } - // Compute output buffer size for swizzled scales (16B aligned per tensor) - std::vector output_offsets; - size_t output_bytes = 0; - for (size_t i = 0; i < num_tensors; ++i) { - const size_t scale_numel = static_cast(scale_tensors[i].numel()); - const size_t dtype_bits = transformer_engine::pytorch::typeToNumBits(scale_dtype); - output_bytes = roundup(output_bytes, 16); - output_offsets.push_back(output_bytes); - output_bytes += ceildiv(scale_numel * dtype_bits, 8); + // Data types + transformer_engine::DType data_dtype, scale_dtype; + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + data_dtype = transformer_engine::DType::kFloat8E4M3; + scale_dtype = transformer_engine::DType::kFloat8E8M0; + break; + case NVTE_NVFP4_1D_SCALING: + data_dtype = transformer_engine::DType::kFloat4E2M1; + scale_dtype = transformer_engine::DType::kFloat8E4M3; + break; + default: + NVTE_ERROR("Unsupported case."); } - // Allocate single buffer for all swizzled scales - swizzled_scales_keepalive = - allocateSpace(std::vector{output_bytes}, transformer_engine::DType::kByte, false); - uint8_t* output_dptr = reinterpret_cast(getDataPtr(swizzled_scales_keepalive)); + // Scale shape + const NVTEShape scale_shape = convertTorchShape(tensors[0].sizes()); + NVTE_CHECK(scale_shape.ndim == 2, + "Expected 2D scale tensor, but got shape=", getTensorShape(tensors[0]), "."); + const size_t scale_numel = scale_shape.data[0] * scale_shape.data[1]; + const size_t scale_dtype_bits = transformer_engine::pytorch::typeToNumBits(scale_dtype); + const size_t scale_bytes = ceildiv(scale_numel * scale_dtype_bits, 8); + + // Expected data shape + // Note: May not match actual data shape since the scales are padded. + // This is fine since we're not actually touching the data. + NVTEShape data_shape; + data_shape.ndim = 2; + if (uniform_mxfp8_rowwise_swizzle) { + data_shape.data[0] = scale_shape.data[0]; + data_shape.data[1] = scale_shape.data[1] * 32; + } else if (uniform_mxfp8_colwise_swizzle) { + data_shape.data[0] = scale_shape.data[0] * 32; + data_shape.data[1] = scale_shape.data[1]; + } else if (uniform_nvfp4_swizzle) { + data_shape.data[0] = scale_shape.data[0]; + data_shape.data[1] = scale_shape.data[1] * 16; + } else { + NVTE_ERROR("Unsupported case."); + } + + // Allocate single buffer for swizzled scales. + // Uses a uniform stride since all tensors share the same scale shape. + const size_t swizzled_scales_stride = roundup(scale_bytes, 16); // Align to 16 bytes + auto swizzled_scales = at::empty({static_cast(swizzled_scales_stride * num_tensors)}, + at::TensorOptions().dtype(at::kByte).device(device)); + uint8_t *swizzled_scales_dptr = reinterpret_cast(swizzled_scales.data_ptr()); + + // Allocate input/output NVTETensors as a single batch. The first + // num_tensors entries are inputs; the next num_tensors are outputs. + std::vector nvte_tensors(2 * num_tensors); + nvte_create_tensors(scaling_mode, nvte_tensors.data(), nvte_tensors.size()); + struct DestroyGuard { + NVTETensor *data; + size_t n; + ~DestroyGuard() { nvte_destroy_tensors(data, n); } + } destroy_guard{nvte_tensors.data(), nvte_tensors.size()}; + NVTETensor *inputs_nvte = nvte_tensors.data(); + NVTETensor *outputs_nvte = nvte_tensors.data() + num_tensors; + + auto set_param = [](NVTETensor t, NVTETensorParam param, void *dptr, + transformer_engine::DType dtype, const NVTEShape &shape) { + NVTEBasicTensor data{dptr, static_cast(dtype), shape}; + nvte_set_tensor_param_v2(t, param, &data, sizeof(data)); + }; - // Build TensorWrapper input/output pairs and get scale shapes - std::vector inputs_nvte, outputs_nvte; - inputs_nvte.reserve(num_tensors); - outputs_nvte.reserve(num_tensors); for (size_t i = 0; i < num_tensors; ++i) { - inputs_nvte.emplace_back(scaling_mode); - outputs_nvte.emplace_back(scaling_mode); - auto& input_nvte = inputs_nvte.back(); - auto& output_nvte = outputs_nvte.back(); - output_nvte.set_with_gemm_swizzled_scales(true); - - NVTEShape scale_shape = convertTorchShape(scale_tensors[i].sizes()); - void* scale_ptr = scale_tensors[i].data_ptr(); - uint8_t* out_scale_ptr = output_dptr + output_offsets[i]; - - if (rowwise) { - input_nvte.set_rowwise_data(nullptr, data_dtype, data_shape); - input_nvte.set_rowwise_scale_inv(scale_ptr, scale_dtype, scale_shape); - output_nvte.set_rowwise_data(nullptr, data_dtype, data_shape); - output_nvte.set_rowwise_scale_inv(out_scale_ptr, scale_dtype, scale_shape); - } else { - input_nvte.set_columnwise_data(nullptr, data_dtype, data_shape); - input_nvte.set_columnwise_scale_inv(scale_ptr, scale_dtype, scale_shape); - output_nvte.set_columnwise_data(nullptr, data_dtype, data_shape); - output_nvte.set_columnwise_scale_inv(out_scale_ptr, scale_dtype, scale_shape); + const uint8_t swizzled_flag = 1; + nvte_set_tensor_param_v2(outputs_nvte[i], kNVTEWithGEMMSwizzledScales, &swizzled_flag, + sizeof(swizzled_flag)); + void *in_scale_ptr = tensors[i].data_ptr(); + void *out_scale_ptr = swizzled_scales_dptr + i * swizzled_scales_stride; + if (uniform_mxfp8_rowwise_swizzle || uniform_nvfp4_swizzle) { + set_param(inputs_nvte[i], kNVTERowwiseData, nullptr, data_dtype, data_shape); + set_param(inputs_nvte[i], kNVTERowwiseScaleInv, in_scale_ptr, scale_dtype, scale_shape); + set_param(outputs_nvte[i], kNVTERowwiseData, nullptr, data_dtype, data_shape); + set_param(outputs_nvte[i], kNVTERowwiseScaleInv, out_scale_ptr, scale_dtype, scale_shape); + } else if (uniform_mxfp8_colwise_swizzle) { + set_param(inputs_nvte[i], kNVTEColumnwiseData, nullptr, data_dtype, data_shape); + set_param(inputs_nvte[i], kNVTEColumnwiseScaleInv, in_scale_ptr, scale_dtype, scale_shape); + set_param(outputs_nvte[i], kNVTEColumnwiseData, nullptr, data_dtype, data_shape); + set_param(outputs_nvte[i], kNVTEColumnwiseScaleInv, out_scale_ptr, scale_dtype, + scale_shape); } } - // Pack raw NVTETensors and launch swizzle kernel - std::vector inputs_raw, outputs_raw; - inputs_raw.reserve(num_tensors); - outputs_raw.reserve(num_tensors); - for (auto& t : inputs_nvte) inputs_raw.push_back(t.data()); - for (auto& t : outputs_nvte) outputs_raw.push_back(t.data()); - - nvte_multi_tensor_swizzle_scaling_factors(inputs_raw.data(), outputs_raw.data(), num_tensors, - stream); + // Launch kernel + nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte, outputs_nvte, num_tensors, stream); - // Collect swizzled scale pointers + // Collect data pointers + std::vector ptrs_host; + ptrs_host.reserve(num_tensors); for (size_t i = 0; i < num_tensors; ++i) { - scale_host_ptrs[i] = reinterpret_cast(output_dptr + output_offsets[i]); + ptrs_host.push_back( + reinterpret_cast(swizzled_scales_dptr + i * swizzled_scales_stride)); } - } else { - swizzled_scales_keepalive = at::empty({0}, at::TensorOptions().dtype(at::kByte).device(device)); - for (size_t i = 0; i < num_tensors; ++i) { - scale_host_ptrs[i] = reinterpret_cast(scale_tensors[i].data_ptr()); - } - } - // Convert pointer arrays to device tensors - auto data_ptrs = collect_pointers_in_device_tensor(data_host_ptrs, device, stream); - auto scale_ptrs = collect_pointers_in_device_tensor(scale_host_ptrs, device, stream); + // Load pointers on device + auto ptrs_device = at::empty({static_cast(num_tensors)}, + at::TensorOptions().dtype(at::kLong).device(device)); + nvte_load_value_on_device(ptrs_host.data(), ptrs_device.data_ptr(), + num_tensors * sizeof(uint64_t), stream); + + return {std::move(ptrs_device), std::move(swizzled_scales)}; + } - return {std::move(data_ptrs), std::move(scale_ptrs), std::move(swizzled_scales_keepalive)}; + // Unsupported transform + NVTE_ERROR("Unsupported transform type (", transform_type, ")"); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index a11d0505c1..3a729874a1 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -115,7 +115,7 @@ def _cudnn_compute_wgrad( ) else: # Discrete mode: per-expert wgrad device pointers - (wgrad_ptrs,) = tex.convert_host_pointers_to_tensor([wgrad_output]) + wgrad_ptrs = tex.load_data_ptrs_on_device(wgrad_output, wgrad_output[0].device) wgrad_kernel_fn( a_tensor=a_tensor, b_tensor=b_tensor, @@ -498,12 +498,14 @@ def fuser_backward( fc2_dglu_kwargs["b_tensor"] = fc2_w_data fc2_dglu_kwargs["sfb_tensor"] = fc2_w_scales else: - fc2_b_ptrs, fc2_sfb_ptrs, _fc2_sw = tex.get_device_pointer_for_data_and_scales( + fc2_b_ptrs = tex.load_data_ptrs_on_device( [w._columnwise_data for w in grouped_fc2_weight], + device, + ) + fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_load_data_ptrs_on_device( + "uniform_mxfp8_columnwise_swizzle", [w._columnwise_scale_inv for w in grouped_fc2_weight], - swizzle=True, - rowwise=False, - data_dtype=grouped_fc2_weight[0]._fp8_dtype, + device, ) fc2_dglu_kwargs["b_ptrs"] = fc2_b_ptrs fc2_dglu_kwargs["sfb_ptrs"] = fc2_sfb_ptrs @@ -655,14 +657,15 @@ def fuser_backward( fc1_dgrad_kwargs["b_tensor"] = fc1_w_data fc1_dgrad_kwargs["sfb_tensor"] = fc1_w_scales else: - fc1_b_ptrs, fc1_sfb_ptrs, _ = tex.get_device_pointer_for_data_and_scales( + fc1_b_ptrs = tex.load_data_ptrs_on_device( [w._columnwise_data for w in grouped_fc1_weight], + device, + ) + fc1_sfb_ptrs, _fc1_sfb_buffer = tex.transform_and_load_data_ptrs_on_device( + "uniform_mxfp8_columnwise_swizzle", [w._columnwise_scale_inv for w in grouped_fc1_weight], - swizzle=True, - rowwise=False, - data_dtype=grouped_fc1_weight[0]._fp8_dtype, + device, ) - fc1_dgrad_kwargs["b_ptrs"] = fc1_b_ptrs fc1_dgrad_kwargs["sfb_ptrs"] = fc1_sfb_ptrs fc1_dgrad_kwargs["n"] = fc1_weight_shape[1] diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 91db2ff9b7..65a3c94aaf 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -333,12 +333,14 @@ def fuser_forward( fc1_glu_kwargs["sfb_tensor"] = fc1_w_scales else: # Discrete-weight kernel: per-expert data/scale pointers - fc1_b_ptrs, fc1_sfb_ptrs, _fc1_sw = tex.get_device_pointer_for_data_and_scales( + fc1_b_ptrs = tex.load_data_ptrs_on_device( [w._rowwise_data for w in grouped_fc1_weight], + device, + ) + fc1_sfb_ptrs, _fc1_sfb_buffer = tex.transform_and_load_data_ptrs_on_device( + "uniform_mxfp8_rowwise_swizzle", [w._rowwise_scale_inv for w in grouped_fc1_weight], - swizzle=True, - rowwise=True, - data_dtype=grouped_fc1_weight[0]._fp8_dtype, + device, ) fc1_glu_kwargs["b_ptrs"] = fc1_b_ptrs fc1_glu_kwargs["sfb_ptrs"] = fc1_sfb_ptrs @@ -432,12 +434,14 @@ def fuser_forward( fc2_quant_kwargs["b_tensor"] = fc2_w_data fc2_quant_kwargs["sfb_tensor"] = fc2_w_scales else: - fc2_b_ptrs, fc2_sfb_ptrs, _ = tex.get_device_pointer_for_data_and_scales( + fc2_b_ptrs = tex.load_data_ptrs_on_device( [w._rowwise_data for w in grouped_fc2_weight], + device, + ) + fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_load_data_ptrs_on_device( + "uniform_mxfp8_rowwise_swizzle", [w._rowwise_scale_inv for w in grouped_fc2_weight], - swizzle=True, - rowwise=True, - data_dtype=grouped_fc2_weight[0]._fp8_dtype, + device, ) fc2_quant_kwargs["b_ptrs"] = fc2_b_ptrs fc2_quant_kwargs["sfb_ptrs"] = fc2_sfb_ptrs