From e32f3b8c854e3a902b3b168ab5aa2ec3994f318d Mon Sep 17 00:00:00 2001 From: Almog Segal Date: Tue, 31 Mar 2026 18:18:17 +0300 Subject: [PATCH 1/2] Fix GemmRs B descriptor lld for transb=true MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With a row_major (1×P) grid, all rows are on a single process row, so the local leading dimension must be n (full row count), not block_size(n) which is n/P. Signed-off-by: Almog Segal --- transformer_engine/common/comm_gemm/comm_gemm.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index 7be3d1bb4d..1a0079956f 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -186,9 +186,9 @@ void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n } if (transb) { NVTE_CHECK(b1 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b1); - NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( - n, k, block_size(ctx, n), block_size(ctx, k), 0, 0, block_size(ctx, n), - get_cuda_dtype(b->dtype()), ctx->grid_row_major.get(), ctx->b_desc.get())); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(n, k, block_size(ctx, n), block_size(ctx, k), + 0, 0, n, get_cuda_dtype(b->dtype()), + ctx->grid_row_major.get(), ctx->b_desc.get())); } else { NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0); NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( From 263c33f6762d940525898461071ce65e57e7e2e0 Mon Sep 17 00:00:00 2001 From: Almog Segal Date: Tue, 31 Mar 2026 18:18:32 +0300 Subject: [PATCH 2/2] Set GemmRs communication type to output data type Match the UserBuffers behavior where the reduce-scatter operates in the output precision rather than FP32. Signed-off-by: Almog Segal --- transformer_engine/common/comm_gemm/comm_gemm.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index 1a0079956f..a7d78f7ac0 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -200,6 +200,11 @@ void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, m, block_size(ctx, n), 0, 0, *ldd, get_cuda_dtype(d->dtype()), ctx->grid_row_major.get(), ctx->d_desc.get())); + + const cudaDataType_t comm_type = get_cuda_dtype(d->dtype()); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_TYPE, &comm_type, + sizeof comm_type)); } void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,