From de46ee2ff4a9b4f07a655e2c6d864f2fa30000d8 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 19 Mar 2026 00:01:31 +0000 Subject: [PATCH 1/2] Do not use fp8::cast_gated_tma for sm120. Instead use the fallback fp8::cast_gated_bwd kernel as sm120 shmem requirements are insufficient Signed-off-by: Kshitij Lakhani --- transformer_engine/common/cast/dispatch/gated.cuh | 10 ++++++++-- transformer_engine/common/common.cu | 7 +++++++ transformer_engine/common/common.h | 2 ++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index 06e8f0e306..3c13d7094f 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -46,7 +46,11 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + //const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // sm120 shared memory capabilities are much smaller than sm100, so we disable TMA kernels on sm120 + // KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated - + // are there any forward only tests we'd like to keep enabled on sm120? 
+ const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120(); if (use_tma_kernels) { Tensor dummy_grad_tensor; fp8::cast_gated_tma(input, dummy_grad_tensor, @@ -137,7 +141,9 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + //const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // sm120 shared memory capabilities are much smaller than sm100, so we disable TMA kernels on sm120 + const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120(); if (use_tma_kernels) { fp8::cast_gated_tma(gated_input, grad, output, p, stream); diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 1bdd80a369..20a2021e56 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -287,6 +287,13 @@ bool is_supported_by_CC_100() { return deviceComputeCapability >= 100; } +// KL: test function for CC 120 +bool is_supported_by_CC_120() { + int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); + + return deviceComputeCapability == 120; +} + std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size) { std::vector> ret; diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 6e207370dd..4f6ea21c84 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1029,6 +1029,8 @@ void create_2D_tensor_map( bool is_supported_by_CC_100(); +bool is_supported_by_CC_120(); + std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size); From 4f4126a3f8b935a5e9eaab32ca938939a96a769e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: 
Fri, 3 Apr 2026 17:30:53 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cast/dispatch/gated.cuh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index 3c13d7094f..bf4052b1b0 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -48,9 +48,10 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp case NVTE_DELAYED_TENSOR_SCALING: { //const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); // sm120 shared memory capabilities are much smaller than sm100, so we disable TMA kernels on sm120 - // KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated - + // KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated - // are there any forward only tests we'd like to keep enabled on sm120? 
- const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120(); + const bool use_tma_kernels = + (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120(); if (use_tma_kernels) { Tensor dummy_grad_tensor; fp8::cast_gated_tma(input, dummy_grad_tensor, @@ -143,7 +144,8 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte case NVTE_DELAYED_TENSOR_SCALING: { //const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); // sm120 shared memory capabilities are much smaller than sm100, so we disable TMA kernels on sm120 - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120(); + const bool use_tma_kernels = + (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120(); if (use_tma_kernels) { fp8::cast_gated_tma(gated_input, grad, output, p, stream);