Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions transformer_engine/common/cast/dispatch/gated.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp

switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
//const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
// sm120 shared memory capacity is much smaller than sm100, so we disable TMA kernels on sm120
// KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated -
// are there any forward only tests we'd like to keep enabled on sm120?
const bool use_tma_kernels =
(cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
if (use_tma_kernels) {
Tensor dummy_grad_tensor;
fp8::cast_gated_tma</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
Expand Down Expand Up @@ -137,7 +142,10 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte

switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
//const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
// sm120 shared memory capacity is much smaller than sm100, so we disable TMA kernels on sm120
const bool use_tma_kernels =
(cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
if (use_tma_kernels) {
fp8::cast_gated_tma</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(gated_input, grad, output, p,
stream);
Expand Down
7 changes: 7 additions & 0 deletions transformer_engine/common/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,13 @@ bool is_supported_by_CC_100() {
return deviceComputeCapability >= 100;
}

// Returns true only for devices whose compute capability is exactly 12.0 (sm120).
// Note the deliberate asymmetry with is_supported_by_CC_100, which checks >= 100:
// this exact match is used to carve sm120 out of code paths that are otherwise
// enabled for CC >= 100 (e.g. TMA gated kernels, where sm120's smaller shared
// memory capacity is a concern — see callers in cast/dispatch/gated.cuh).
bool is_supported_by_CC_120() {
  const int cc = cuda::sm_arch(cuda::current_device());
  return cc == 120;
}

std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size) {
std::vector<std::vector<Tensor *>> ret;
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,8 @@ void create_2D_tensor_map(

bool is_supported_by_CC_100();

bool is_supported_by_CC_120();

std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size);

Expand Down