
Comm gemm fixes #2818

Open
almogsegal wants to merge 2 commits into NVIDIA:main from almogsegal:comm-gemm-fixes

Conversation

@almogsegal

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Mar 31, 2026

Greptile Summary

This PR fixes two bugs in GemmRsInitMatrices within the cuBLASMp-backed comm-GEMM implementation:

  1. Leading dimension (lld) correction for transposed-B matrix descriptor — When transb == true, the B matrix is laid out as n × k on a row-major grid (1 × nranks), so all n rows reside locally on each rank. The lld (leading dimension of the local column-major buffer) must therefore be n, not block_size(ctx, n). The old value was incorrect and would have caused cuBLASMp to use the wrong stride when reading the local B tile.

  2. Explicit COMMUNICATION_TYPE attribute for reduce-scatter — The reduce-scatter output communication type is now explicitly set to match d->dtype(). Without this, cuBLASMp would use its default communication dtype, which may not match the output tensor's actual dtype (e.g. FP8 or BF16), producing incorrect results.

  • Both changes are confined to GemmRsInitMatrices; the analogous AgGemmInitMatrices and GemmArInitMatrices paths are unchanged.
  • The PR description is empty and provides no context, motivation, or test coverage notes, making it harder to assess completeness.
  • No tests are added to validate the fixes.

Confidence Score: 4/5

The two fixes in GemmRsInitMatrices appear correct; a question remains about whether GemmArInitMatrices needs an analogous COMMUNICATION_TYPE fix.

Both changes look like valid correctness fixes: the lld correction aligns with how all other row-major grid descriptors in this file compute their leading dimension, and setting COMMUNICATION_TYPE to d->dtype() is necessary for non-default output types. The only uncertainty is whether GemmArInitMatrices (AllReduce path) also needs the communication type fix.

transformer_engine/common/comm_gemm/comm_gemm.cpp — specifically GemmArInitMatrices (lines 210-246), which may need an analogous COMMUNICATION_TYPE attribute set.

Important Files Changed

transformer_engine/common/comm_gemm/comm_gemm.cpp — Two bug fixes in GemmRsInitMatrices: (1) corrects the lld parameter from block_size(ctx, n) to n for the transposed-B matrix descriptor on a row-major grid; (2) adds an explicit COMMUNICATION_TYPE attribute to match the output tensor dtype for reduce-scatter ops. An analogous fix may be needed in GemmArInitMatrices.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant cublasmp_gemm
    participant GemmRsInitMatrices
    participant cuBLASMp

    Caller->>cublasmp_gemm: nvte_gemm_reduce_scatter(ctx, m, n, k, a, b, d, ...)
    cublasmp_gemm->>cuBLASMp: cublasMpMatmulDescriptorInit(matmul_desc, COMPUTE_32F)
    cublasmp_gemm->>GemmRsInitMatrices: init_matrices_fn(ctx, &ldd, m, n, k, a, b, d, transa, transb)

    alt transb == true
        GemmRsInitMatrices->>cuBLASMp: cublasMpMatrixDescriptorInit(n, k, block_n, block_k, 0, 0, n [FIXED], type, row_major, b_desc)
        Note over GemmRsInitMatrices,cuBLASMp: lld was block_size(n), now correctly n
    else transb == false
        GemmRsInitMatrices->>cuBLASMp: cublasMpMatrixDescriptorInit(k, n, block_k, block_n, 0, 0, block_k, type, col_major, b_desc)
    end

    GemmRsInitMatrices->>cuBLASMp: cublasMpMatrixDescriptorInit(D, ldd=m, row_major)
    GemmRsInitMatrices->>cuBLASMp: cublasMpMatmulDescriptorSetAttribute(COMMUNICATION_TYPE, d->dtype()) [NEW]

    cublasmp_gemm->>cuBLASMp: Set TRANSA, TRANSB, ALGO_TYPE, scale attrs...
    cublasmp_gemm->>cuBLASMp: cublasMpMatmul_bufferSize(...)
    cublasmp_gemm->>cuBLASMp: cublasMpMatmul(...) — GEMM + ReduceScatter
    cuBLASMp-->>Caller: Reduce-scattered result in D

Reviews (2): Last reviewed commit: "Set GemmRs communication type to output ..."

@almogsegal almogsegal force-pushed the comm-gemm-fixes branch 2 times, most recently from fa9fef5 to 263c33f on March 31, 2026 17:11
With a row_major (1×P) grid, all rows are on a single process row,
so the local leading dimension must be n (full row count), not
block_size(n) which is n/P.

Signed-off-by: Almog Segal <asegal@nvidia.com>
Match the UserBuffers behavior where the reduce-scatter operates
in the output precision rather than FP32.

Signed-off-by: Almog Segal <asegal@nvidia.com>
@ptrendx
Member

ptrendx commented Mar 31, 2026

/te-ci L1

@timmoon10 timmoon10 requested a review from vcherepanov-nv April 3, 2026 00:07
@ptrendx
Member

ptrendx commented Apr 3, 2026

/te-ci L1
