diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh
index 06e8f0e306..bf4052b1b0 100644
--- a/transformer_engine/common/cast/dispatch/gated.cuh
+++ b/transformer_engine/common/cast/dispatch/gated.cuh
@@ -46,7 +46,12 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp
 
   switch (output->scaling_mode) {
     case NVTE_DELAYED_TENSOR_SCALING: {
-      const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
+      // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
+      // sm120 has a much smaller shared memory capacity than sm100, so TMA kernels are disabled on sm120.
+      // KL: It is possible that the forward pass does not exceed the sm120 limits. To be investigated:
+      // are there any forward-only tests we'd like to keep enabled on sm120?
+      const bool use_tma_kernels =
+          (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
       if (use_tma_kernels) {
         Tensor dummy_grad_tensor;
         fp8::cast_gated_tma(input, dummy_grad_tensor,
@@ -137,7 +142,10 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
 
   switch (output->scaling_mode) {
     case NVTE_DELAYED_TENSOR_SCALING: {
-      const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
+      // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
+      // sm120 has a much smaller shared memory capacity than sm100, so TMA kernels are disabled on sm120.
+      const bool use_tma_kernels =
+          (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
       if (use_tma_kernels) {
         fp8::cast_gated_tma(gated_input, grad, output, p, stream);
 
diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu
index 1bdd80a369..20a2021e56 100644
--- a/transformer_engine/common/common.cu
+++ b/transformer_engine/common/common.cu
@@ -287,6 +287,13 @@ bool is_supported_by_CC_100() {
   return deviceComputeCapability >= 100;
 }
 
+// KL: returns true only for devices with compute capability exactly 12.0 (sm120).
+bool is_supported_by_CC_120() {
+  int deviceComputeCapability = cuda::sm_arch(cuda::current_device());
+
+  return deviceComputeCapability == 120;
+}
+
 std::vector> convert_tensor_array(NVTETensor **nvte_tensors,
                                   size_t outer_size, size_t inner_size) {
   std::vector> ret;
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 6e207370dd..4f6ea21c84 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -1029,6 +1029,8 @@ void create_2D_tensor_map(
 
 bool is_supported_by_CC_100();
 
+bool is_supported_by_CC_120();
+
 std::vector> convert_tensor_array(NVTETensor **nvte_tensors,
                                   size_t outer_size, size_t inner_size);
 
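For reference, below is a minimal, self-contained host-side sketch (not part of the patch) of how the gate above behaves, plus an alternative that queries the device's opt-in shared-memory limit instead of hard-coding sm120. `cudaDeviceGetAttribute`, `cudaDevAttrComputeCapabilityMajor/Minor`, and `cudaDevAttrMaxSharedMemoryPerBlockOptin` are standard CUDA runtime API; `sm_arch` mimics the helper used in the patch, and `kTmaSmemBytes` is a made-up placeholder, not a value taken from the TMA kernels.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Placeholder mirroring cuda::sm_arch() from the patch: 10 * major + minor (e.g. 100, 120).
static int sm_arch(int device) {
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
  return 10 * major + minor;
}

int main() {
  int device = 0;
  cudaGetDevice(&device);

  const int arch = sm_arch(device);
  // Same gate as the patch: TMA path requires >= sm100 and is explicitly disabled on sm120.
  const bool tma_by_arch = (arch >= 100) && (arch != 120);

  // Alternative idea (not what the patch does): gate on the queried shared-memory budget.
  // kTmaSmemBytes is a hypothetical per-block requirement, for illustration only.
  int smem_optin_bytes = 0;
  cudaDeviceGetAttribute(&smem_optin_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  constexpr int kTmaSmemBytes = 160 * 1024;
  const bool tma_by_smem = (arch >= 100) && (smem_optin_bytes >= kTmaSmemBytes);

  std::printf("sm_%d: opt-in smem/block = %d bytes, tma_by_arch=%d, tma_by_smem=%d\n",
              arch, smem_optin_bytes, static_cast<int>(tma_by_arch), static_cast<int>(tma_by_smem));
  return 0;
}
```

Gating on the queried shared-memory budget rather than on the exact architecture would also address the KL TODO above: if the forward-pass kernels happen to fit within the sm120 limit, they would stay enabled automatically instead of requiring a per-architecture exception.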