diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index 7be3d1bb4d..a7d78f7ac0 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -186,9 +186,9 @@ void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n } if (transb) { NVTE_CHECK(b1 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b1); - NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( - n, k, block_size(ctx, n), block_size(ctx, k), 0, 0, block_size(ctx, n), - get_cuda_dtype(b->dtype()), ctx->grid_row_major.get(), ctx->b_desc.get())); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(n, k, block_size(ctx, n), block_size(ctx, k), + 0, 0, n, get_cuda_dtype(b->dtype()), + ctx->grid_row_major.get(), ctx->b_desc.get())); } else { NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0); NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( @@ -200,6 +200,11 @@ void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, m, block_size(ctx, n), 0, 0, *ldd, get_cuda_dtype(d->dtype()), ctx->grid_row_major.get(), ctx->d_desc.get())); + + const cudaDataType_t comm_type = get_cuda_dtype(d->dtype()); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_TYPE, &comm_type, + sizeof comm_type)); } void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,