Skip to content

ck_tile grouped gemm: more padding#574

Open
matthiasdiener wants to merge 8 commits into
devfrom
mdiener/cktile-grouped-gemm-padding
Open

ck_tile grouped gemm: more padding#574
matthiasdiener wants to merge 8 commits into
devfrom
mdiener/cktile-grouped-gemm-padding

Conversation

@matthiasdiener
Copy link
Copy Markdown
Contributor

@matthiasdiener matthiasdiener commented May 5, 2026

Description

Enabling padding always causes a significant (~15%) reduction in speed, so only enable it when necessary.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@matthiasdiener matthiasdiener requested a review from sudhu2k May 5, 2026 00:15
@matthiasdiener matthiasdiener self-assigned this May 5, 2026
@matthiasdiener matthiasdiener requested review from aris134 May 5, 2026 15:36
Comment thread transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h Outdated
Comment thread tests/pytorch/test_numerics.py Outdated
Comment thread tests/pytorch/test_numerics.py Outdated
@matthiasdiener matthiasdiener added the ci-level 1 CI test level 1 label May 15, 2026
@matthiasdiener matthiasdiener marked this pull request as ready for review May 15, 2026 22:24
@matthiasdiener matthiasdiener requested a review from aris134 May 15, 2026 22:24
@matthiasdiener matthiasdiener requested a review from aris134 May 19, 2026 19:13
Comment thread tests/pytorch/test_numerics.py Outdated
reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TT?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added TT in aee2c4c.

Comment thread tests/pytorch/test_numerics.py Outdated
k_val = k_aligned
m_vals = [m_aligned] * z
n_val = unaligned_n

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would we want an MKN unaligned test? Would that cover something that isn't included in the current test sweep?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added an MKN test in aee2c4c.

Comment thread tests/pytorch/test_numerics.py Outdated
Comment on lines +3173 to +3188
if pad_dim == "K":
k_val = unaligned_k
m_vals = [m_aligned] * z
n_val = n_aligned
elif pad_dim == "M":
k_val = k_aligned
m_vals = unaligned_m
n_val = n_aligned
elif pad_dim == "MK":
k_val = unaligned_k
m_vals = unaligned_m
n_val = n_aligned
else: # N
k_val = k_aligned
m_vals = [m_aligned] * z
n_val = unaligned_n
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we factor out this if-elif-elif-else block that seems repeated for each layout?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restructured this in aee2c4c

}
return launch_grouped_gemm_kernel<Kernel>(descs, ctx, stream_cfg);
// Dispatch with B's columnwise buffer as RowMajor (transB=false).
GroupedGemmRunContext ctx_nn = ctx;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: ctx_nn seems a bit misleading since this only rewrites B as non-transposed via columnwise_data; A can still be T or N. Maybe rename to something like ctx_b_colwise?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed in aee2c4c.

Comment thread tests/pytorch/test_numerics.py Outdated
grad = True
single_output = True
else: # NT
# NT GEMM: out[i] = A[i]^T @ B[i], A[i]: (m_i, k), B[i]: (m_i, n), out[i]: (n, k)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this comment is a little confusing. For the grouped path, the user-facing NT inputs are A=(m_i,k), B=(m_i,n), out=(n,k), but normalization swaps operands/layouts before dispatch, so the actual dispatched gemm is B^T @ A = (n,m_i) @ (m_i,k).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I removed this comment in aee2c4c

@matthiasdiener matthiasdiener requested a review from aris134 May 20, 2026 22:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-level 1 CI test level 1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants