Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
05922ba
Remove unnecessary heap allocations
timmoon10 May 15, 2026
2230422
Avoid heap allocation in Tensor::flat_first_dim/flat_last_dim
timmoon10 May 15, 2026
1fbbec8
Add Tensor::flat_2d_dims() to compute both matrix dims in one pass
timmoon10 May 15, 2026
dd3a3a7
Use flat_2d_dims() throughout common lib
timmoon10 May 15, 2026
7512074
Generalize API for CUDA-Graph-safe copy to GPU.
timmoon10 May 16, 2026
5e6e31f
Dedup swizzle logic in get_device_pointer_for_data_and_scales
timmoon10 May 16, 2026
76ba3a9
Make separate functions for load data_ptrs and swizzle + load data_ptrs.
timmoon10 May 16, 2026
452ce3a
Change function name to nvte_load_value_on_device
timmoon10 May 16, 2026
c2ef7d3
Fix code review issues before opening PR
timmoon10 May 16, 2026
24e9e7f
Merge branch 'main' into tmoon/optimize-get_device_pointer_for_data_a…
timmoon10 May 16, 2026
48cc585
Formatter and review suggestions from @greptile-apps
timmoon10 May 16, 2026
1518550
Add Shape class wrapping NVTEShape
timmoon10 May 20, 2026
9774571
Make SimpleTensor stack-allocatable
timmoon10 May 20, 2026
72c58b9
Merge branch 'main' into tmoon/optimize-get_device_pointer_for_data_a…
timmoon10 May 20, 2026
49094b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2026
02d12f2
Make Shape conversion constructors explicit
timmoon10 May 20, 2026
7663196
Make conversion from Shape to std::vector explicit
timmoon10 May 21, 2026
87cdfa1
Add batched NVTETensor create/destroy
timmoon10 May 21, 2026
1e38cff
Use batched NVTETensor allocator in transform_and_load_data_ptrs_on_d…
timmoon10 May 21, 2026
cb0a7e6
Expand usage of batched NVTETensor allocator
timmoon10 May 21, 2026
a4e20b2
Use string_view in tensor checking functions
timmoon10 May 21, 2026
ca98d46
Merge branch 'main' into tmoon/optimize-get_device_pointer_for_data_a…
timmoon10 May 21, 2026
b8d77f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions transformer_engine/common/cast/dispatch/quantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
CheckOutputTensor(*output_tensor, "output", false);

// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
const auto [rows, cols] = input_tensor->flat_2d_dims();
auto dtype = input_tensor->dtype();
const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4;
if (row_scaled_nvfp4) {
Expand Down Expand Up @@ -246,8 +245,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
CheckOutputTensor(*output_tensor, "output", false);

// Choose kernel
int32_t rows = grad_tensor->flat_first_dim();
int32_t cols = grad_tensor->flat_last_dim();
const auto [rows, cols] = grad_tensor->flat_2d_dims();
auto dtype = grad_tensor->dtype();
NVTE_CHECK(!output_tensor->row_scaled_nvfp4,
"Backward NVFP4 quantization does not support row-scaled outputs.");
Expand Down Expand Up @@ -368,8 +366,7 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou
// output list here is allowed to have empty tensor

// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
const auto [rows, cols] = input_tensor->flat_2d_dims();
auto dtype = input_tensor->dtype();

NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/common/cast/fp8/quantize_fp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -391,8 +391,7 @@ void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
using namespace quantize_2D_kernel;
checkCuDriverContext(stream);

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const auto [rows, cols] = input.flat_2d_dims();
const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y);
const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X);
const size_t blocks_Y = chunks_Y;
Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1;
const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1;

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const auto [rows, cols] = input.flat_2d_dims();
const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X);

Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT);

// Tensor dimensions
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const auto [rows, cols] = input.flat_2d_dims();

// Tensor chunk handled by each CUDA block
constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64;
Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
const bool row_scaled_nvfp4 = input.row_scaled_nvfp4;

constexpr int FP4_BLOCK_SIZE = 16;
const size_t N = input.flat_first_dim();
const size_t M = input.flat_last_dim();
const auto [N, M] = input.flat_2d_dims();

NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -783,8 +783,7 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop,

NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const auto [rows, cols] = input.flat_2d_dims();

NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ inline void compute_rowwise_amax(const Tensor &input, const Tensor *noop, Tensor
#if FP4_TYPE_SUPPORTED
using namespace rowwise_amax_kernel;

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const auto [rows, cols] = input.flat_2d_dims();
NVTE_CHECK(cols % ROWWISE_AMAX_SF_VEC_SIZE == 0,
"Row-scaled NVFP4 quantization requires last dim divisible by ",
ROWWISE_AMAX_SF_VEC_SIZE, ".");
Expand Down Expand Up @@ -1359,8 +1358,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
"Transposed scaling tensor must be allocated");
}

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const auto [rows, cols] = input.flat_2d_dims();

NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,8 +718,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop,
"Transposed scaling tensor must be allocated");
}

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const auto [rows, cols] = input.flat_2d_dims();

NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
Expand Down
27 changes: 9 additions & 18 deletions transformer_engine/common/comm_gemm/comm_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,9 @@ int64_t block_size(NVTECommGemmCtx* ctx, int64_t global_size) {
void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
bool transb) {
const auto a0 = a->flat_first_dim();
const auto a1 = a->flat_last_dim();
const auto b0 = b->flat_first_dim();
const auto b1 = b->flat_last_dim();
const auto d0 = d->flat_first_dim();
const auto d1 = d->flat_last_dim();
const auto [a0, a1] = a->flat_2d_dims();
const auto [b0, b1] = b->flat_2d_dims();
const auto [d0, d1] = d->flat_2d_dims();

if (transa) {
NVTE_CHECK(a1 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a1);
Expand Down Expand Up @@ -169,12 +166,9 @@ void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
bool transb) {
const auto a0 = a->flat_first_dim();
const auto a1 = a->flat_last_dim();
const auto b0 = b->flat_first_dim();
const auto b1 = b->flat_last_dim();
const auto d0 = d->flat_first_dim();
const auto d1 = d->flat_last_dim();
const auto [a0, a1] = a->flat_2d_dims();
const auto [b0, b1] = b->flat_2d_dims();
const auto [d0, d1] = d->flat_2d_dims();

if (transa) {
NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0);
Expand Down Expand Up @@ -213,12 +207,9 @@ void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
bool transb) {
const auto a0 = a->flat_first_dim();
const auto a1 = a->flat_last_dim();
const auto b0 = b->flat_first_dim();
const auto b1 = b->flat_last_dim();
const auto d0 = d->flat_first_dim();
const auto d1 = d->flat_last_dim();
const auto [a0, a1] = a->flat_2d_dims();
const auto [b0, b1] = b->flat_2d_dims();
const auto [d0, d1] = d->flat_2d_dims();

if (transa) {
NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0);
Expand Down
Loading
Loading