Greptile Summary
This PR fixes two bugs in
Confidence Score: 4/5
The two fixes in GemmRsInitMatrices appear correct; a question remains about whether GemmArInitMatrices needs an analogous COMMUNICATION_TYPE fix.
Both changes look like valid correctness fixes: the lld correction aligns with how all other row-major grid descriptors in this file compute their leading dimension, and setting COMMUNICATION_TYPE to d->dtype() is necessary for non-default output types. The only uncertainty is whether GemmArInitMatrices (AllReduce path) also needs the communication type fix.
transformer_engine/common/comm_gemm/comm_gemm.cpp, specifically GemmArInitMatrices (lines 210-246), which may need an analogous COMMUNICATION_TYPE attribute set.
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Caller
participant cublasmp_gemm
participant GemmRsInitMatrices
participant cuBLASMp
Caller->>cublasmp_gemm: nvte_gemm_reduce_scatter(ctx, m, n, k, a, b, d, ...)
cublasmp_gemm->>cuBLASMp: cublasMpMatmulDescriptorInit(matmul_desc, COMPUTE_32F)
cublasmp_gemm->>GemmRsInitMatrices: init_matrices_fn(ctx, &ldd, m, n, k, a, b, d, transa, transb)
alt transb == true
GemmRsInitMatrices->>cuBLASMp: cublasMpMatrixDescriptorInit(n, k, block_n, block_k, 0, 0, n [FIXED], type, row_major, b_desc)
Note over GemmRsInitMatrices,cuBLASMp: lld was block_size(n), now correctly n
else transb == false
GemmRsInitMatrices->>cuBLASMp: cublasMpMatrixDescriptorInit(k, n, block_k, block_n, 0, 0, block_k, type, col_major, b_desc)
end
GemmRsInitMatrices->>cuBLASMp: cublasMpMatrixDescriptorInit(D, ldd=m, row_major)
GemmRsInitMatrices->>cuBLASMp: cublasMpMatmulDescriptorSetAttribute(COMMUNICATION_TYPE, d->dtype()) [NEW]
cublasmp_gemm->>cuBLASMp: Set TRANSA, TRANSB, ALGO_TYPE, scale attrs...
cublasmp_gemm->>cuBLASMp: cublasMpMatmul_bufferSize(...)
cublasmp_gemm->>cuBLASMp: cublasMpMatmul(...) — GEMM + ReduceScatter
cuBLASMp-->>Caller: Reduce-scattered result in D
Reviews (2): Last reviewed commit: "Set GemmRs communication type to output ..."
fa9fef5 to 263c33f
With a row_major (1×P) grid, all rows are on a single process row, so the local leading dimension must be n (full row count), not block_size(n) which is n/P.
Signed-off-by: Almog Segal <asegal@nvidia.com>
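The leading-dimension argument in this commit message can be checked with a small plain C++ sketch that makes no cuBLASMp calls. The helpers `local_extent` and `min_local_ld` are hypothetical names and do not come from comm_gemm.cpp; they follow the ScaLAPACK NUMROC convention for block-cyclic distribution and assume column-major local storage, so the leading dimension must be at least the local row count.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper (not part of comm_gemm.cpp): how many rows (or columns)
// of a block-cyclically distributed dimension of global size n land on the
// process at proc_index, following the ScaLAPACK NUMROC convention.
std::int64_t local_extent(std::int64_t n, std::int64_t block,
                          std::int64_t proc_index, std::int64_t first_proc,
                          std::int64_t num_procs) {
  // Distance of this process from the one that owns the first block.
  const std::int64_t dist = (num_procs + proc_index - first_proc) % num_procs;
  const std::int64_t full_blocks = n / block;
  // Every process gets at least this many complete blocks.
  std::int64_t extent = (full_blocks / num_procs) * block;
  const std::int64_t extra_blocks = full_blocks % num_procs;
  if (dist < extra_blocks) {
    extent += block;      // one additional complete block
  } else if (dist == extra_blocks) {
    extent += n % block;  // the trailing partial block, if any
  }
  return extent;
}

// Hypothetical helper: smallest valid local leading dimension, assuming
// column-major local storage (leading dimension >= local row count).
std::int64_t min_local_ld(std::int64_t n_rows, std::int64_t row_block,
                          std::int64_t my_process_row,
                          std::int64_t num_process_rows) {
  return std::max<std::int64_t>(
      1, local_extent(n_rows, row_block, my_process_row, 0, num_process_rows));
}
```

With a single process row, `min_local_ld(n, block, 0, 1)` returns `n` for any block size, which is the value the fix passes as the leading dimension. A value near n/P is only the local row count when the rows themselves are split across P process rows.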
Match the UserBuffers behavior where the reduce-scatter operates in the output precision rather than FP32.
Signed-off-by: Almog Segal <asegal@nvidia.com>
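One visible effect of the communication type is the amount of data each rank feeds into the reduce-scatter. The sketch below is a back-of-the-envelope estimate only: it assumes each rank contributes a full m×n partial result in the communication type, and `CommType`, `bytes_per_element`, and `partial_result_bytes` are hypothetical names rather than Transformer Engine or cuBLASMp APIs.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for the element type used during communication.
enum class CommType { kFloat32, kBFloat16, kFloat16 };

// Size in bytes of one element of the given communication type.
constexpr std::size_t bytes_per_element(CommType t) {
  return t == CommType::kFloat32 ? 4 : 2;
}

// Bytes in one rank's m x n partial result (assumption: the whole partial
// result is exchanged in the communication type).
constexpr std::uint64_t partial_result_bytes(std::uint64_t m, std::uint64_t n,
                                             CommType t) {
  return m * n * bytes_per_element(t);
}
```

Under these assumptions, a 1024×1024 partial result takes 4 MiB in FP32 and 2 MiB in BF16 or FP16, so communicating in a 16-bit output type halves the payload. The trade-off is that the reduction then accumulates in the narrower type.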
/te-ci L1
/te-ci L1
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: