From e0ae1074204267560237aab4407e4b8b7373da4c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 31 Jan 2026 13:42:45 -0800 Subject: [PATCH 001/172] initial implementation for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- tests/pytorch/attention/test_attention.py | 29 ++- .../common/fused_attn/fused_attn_fp8.cu | 191 ++++++++++++++---- .../dot_product_attention/backends.py | 2 + .../dot_product_attention.py | 8 +- .../attention/dot_product_attention/utils.py | 45 ++++- .../pytorch/cpp_extensions/fused_attn.py | 8 +- transformer_engine/pytorch/csrc/common.h | 8 + .../pytorch/csrc/extensions/attention.cpp | 20 ++ transformer_engine/pytorch/csrc/quantizer.cpp | 12 ++ .../pytorch/tensor/mxfp8_tensor.py | 1 + 11 files changed, 264 insertions(+), 62 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index b372d39879..209a25fe89 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit b372d39879d44c91a8d5b342022e74802b6a8da2 +Subproject commit 209a25fe898bb749c8605363a6431e26cb41b089 diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index bd0ac41974..dd133c840a 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2062,7 +2062,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode): """Test DotProductAttention module in FP8""" config = model_configs_fp8_vs_f16[model] @@ -2095,6 +2095,12 @@ 
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fp8_format=recipe.Format.HYBRID, fp8_dpa=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + fp8_mha=False, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( @@ -2107,6 +2113,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + print(f"flash_attn_supported: {flash_attn_supported}, fused_attn_supported: {fused_attn_supported}, unfused_attn_supported: {unfused_attn_supported}") if flash_attn_supported + fused_attn_supported < 1: pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: @@ -2133,21 +2140,22 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal dtype, config, True, qkv_layout, is_training, fp8_recipe ) - if unfused_attn_supported: - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "0" - os.environ["NVTE_UNFUSED_ATTN"] = "1" - _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") - unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training, fp8_recipe - ) + # if unfused_attn_supported: + # os.environ["NVTE_FLASH_ATTN"] = "0" + # os.environ["NVTE_FUSED_ATTN"] = "0" + # os.environ["NVTE_UNFUSED_ATTN"] = "1" + # _attention_backends["backend_selection_requires_update"] = True + # logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") + # unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + # dtype, config, True, qkv_layout, is_training, fp8_recipe + # ) os.environ["NVTE_FLASH_ATTN"] = "0" 
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") + print(f"Running fused attention") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( dtype, config, True, qkv_layout, is_training, fp8_recipe ) @@ -2158,6 +2166,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if config.dropout_p == 0.0: # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") + print(f"Running fused attention with fp8_dpa = False") fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( dtype, config, False, qkv_layout, is_training, fp8_recipe ) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f886ec77f4..400a11af6a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1677,9 +1677,31 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + // NVTE_CHECK(is_current_scaling || is_delayed_scaling, + // "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " + // "kFloat8E5M2!"); + is_current_scaling = false; + is_delayed_scaling = false; + bool is_mxfp8 = true; + printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); + printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); + printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); + printf(">>>>>> qkv_tensor_type: %d, %d, 
%d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf(">>>>>> o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); + printf(">>>>>> cudnn_frontend::DataType_t::UINT8: %d\n", cudnn_frontend::DataType_t::UINT8); + printf(">>>>>> cudnn_frontend::DataType_t::INT8: %d\n", cudnn_frontend::DataType_t::INT8); + printf(">>>>>> cudnn_frontend::DataType_t::HALF: %d\n", cudnn_frontend::DataType_t::HALF); + printf(">>>>>> cudnn_frontend::DataType_t::INT64: %d\n", cudnn_frontend::DataType_t::INT64); + printf(">>>>>> cudnn_frontend::DataType_t::DOUBLE: %d\n", cudnn_frontend::DataType_t::DOUBLE); + printf(">>>>>> bias_type: %d\n", bias_type); + printf(">>>>>> mask_type: %d\n", mask_type); + printf(">>>>>> scaling_factor: %f\n", scaling_factor); + printf(">>>>>> dropout_probability: %f\n", dropout_probability); + // qkv_tensor_type = cudnn_frontend::DataType_t::FP8_E8M0; + // o_tensor_type = cudnn_frontend::DataType_t::BFLOAT16; + // printf(">>>>>> after setting qkv_tensor_type and o_tensor_type\n"); + // printf(">>>>>> qkv_tensor_type: %d\n", qkv_tensor_type); + // printf(">>>>>> o_tensor_type: %d\n", o_tensor_type); try { FADescriptor_v1 descriptor{b, @@ -1770,18 +1792,55 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); + printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); + printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); + + int32_t block_size = 32; + int64_t d_scale = (d + block_size - 1) / block_size; + int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; + int64_t s_q_padded = ((s_q + 127) / 128) * 
128; + int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; + int64_t d_scale_padded = ((d_scale + 3) / 4) * 4; + int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; + int64_t d_padded = ((d + 3) / 4) * 4; // d dimension for SF_V (not scaled, but may need padding) + printf(">>>>>> d_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_scale_padded: %d, s_kv_scale_padded: %d, d_padded: %d\n", d_scale, s_kv_scale, s_q_padded, s_kv_padded, d_scale_padded, s_kv_scale_padded, d_padded); + std::vector q_scale_dims = {b, h, s_q_padded, d_scale_padded}; + std::vector k_scale_dims = {b, hg, s_kv_padded, d_scale_padded}; + std::vector v_scale_dims = {b, hg, s_kv_scale_padded, d_padded}; + std::vector q_scale_strides(4); + std::vector k_scale_strides(4); + std::vector v_scale_strides(4); + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_scale_padded, q_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_scale_padded, k_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); +// generateMatrixStrides(b, d_padded, s_q_padded, hg, s_kv_scale_padded, v_scale_strides.data(), layout, + // generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, d_padded, v_scale_strides.data(), layout, + // NVTE_QKV_Matrix::NVTE_V_Matrix); + v_scale_strides[0] = h*d_padded*s_kv_scale_padded; + v_scale_strides[1] = d_padded*s_kv_scale_padded; + v_scale_strides[2] = 1; + v_scale_strides[3] = s_kv_scale_padded; + printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); + printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); + printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); + Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d}) - 
.set_stride(q_stride)); + .set_stride(q_stride) + .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); + .set_stride(k_stride) + .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); + .set_stride(v_stride) + .set_data_type(qkv_tensor_type)); attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") @@ -1789,16 +1848,36 @@ void fused_attn_fp8_fwd_impl_v1( .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + if (!is_mxfp8) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); + } else { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim(q_scale_dims) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim(k_scale_dims) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() + 
.set_name("Descale_v") + .set_dim(v_scale_dims) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + } if (is_delayed_scaling) { scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); @@ -1851,8 +1930,24 @@ void fused_attn_fp8_fwd_impl_v1( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( + std::shared_ptr O, Stats, amax_s, amax_o; + if (is_mxfp8) { + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_o = outputs[2]; + } else { + auto outputs = mha_graph->sdpa_fp8( Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_s = outputs[2]; + amax_o = outputs[3]; + amax_s->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + } std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, @@ -1863,11 +1958,6 @@ void fused_attn_fp8_fwd_impl_v1( .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - amax_s->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); - Stats->set_output(true) .set_data_type(fe::DataType_t::FLOAT) .set_dim({b, h, s_q, 1}) @@ -1886,7 +1976,9 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // O std::shared_ptr, // amax_s std::shared_ptr> // amax_o - key_tensors_tuple = std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, + key_tensors_tuple = + is_mxfp8 ? 
std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, nullptr, attn_scale, O, nullptr, amax_o) : + std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, attn_scale, O, amax_s, amax_o); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); @@ -1896,11 +1988,16 @@ void fused_attn_fp8_fwd_impl_v1( : std::make_tuple(nullptr, nullptr); NVTE_CHECK_CUDNN_FE(mha_graph->validate()); + printf(">>>>>> mha_graph->validate()\n"); NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); + printf(">>>>>> mha_graph->build_operation_graph(handle)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); + printf(">>>>>> mha_graph->create_execution_plans({fe::HeurMode_t::A})\n"); NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); + printf(">>>>>> mha_graph->check_support(handle)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - + printf(">>>>>> mha_graph->build_plans(handle)\n"); + printf(">>>>>> mha_graph->get_workspace_size(): %zu\n", mha_graph->get_workspace_size()); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); @@ -1967,7 +2064,7 @@ void fused_attn_fp8_fwd_impl_v1( variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; } - + printf(">>>>>> mha_graph->execute(handle, variant_pack, workspace)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -2420,16 +2517,44 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - void* devPtrQ = input_Q->data.dptr; - void* 
devPtrK = input_K->data.dptr; - void* devPtrV = input_V->data.dptr; - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_Q->scale_inv.dptr; - void* devPtrDescaleV = input_Q->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; + void* devPtrQ = nullptr; + void* devPtrK = nullptr; + void* devPtrV = nullptr; + void* devPtrDescaleQ = nullptr; + void* devPtrDescaleK = nullptr; + void* devPtrDescaleV = nullptr; + void* devPtrO = nullptr; + void* devPtrAmaxO = nullptr; + void* devPtrScaleO = nullptr; + void* devPtrAmaxS = nullptr; + void* devPtrScaleS = nullptr; + void* devPtrDescaleS = nullptr; + printf(">>>>>> fused_attn_fp8_fwd\n"); + // if (input_Q->scaling_mode() == NVTE_MXFP8_1D_SCALING) { + // printf(">>>>>> input_Q is MXFP8\n"); + // devPtrQ = input_Q-> + // devPtrQ = input_Q->get_rowwise_data_ptr(); + // devPtrDescaleQ = input_Q->get_rowwise_scale_inv_ptr(); + // devPtrK = input_K->get_rowwise_data_ptr(); + // devPtrDescaleK = input_K->get_rowwise_scale_inv_ptr(); + // devPtrV = input_V->get_rowwise_data_ptr(); + // devPtrDescaleV = input_V->get_rowwise_scale_inv_ptr(); + // devPtrO = output_O->get_rowwise_data_ptr(); + // devPtrAmaxO = output_O->amax.dptr; + // } else { + devPtrQ = input_Q->data.dptr; + devPtrDescaleQ = input_Q->scale_inv.dptr; + devPtrK = input_K->data.dptr; + devPtrDescaleK = input_K->scale_inv.dptr; + devPtrV = input_V->data.dptr; + devPtrDescaleV = input_V->scale_inv.dptr; + devPtrO = output_O->data.dptr; + devPtrAmaxO = output_O->amax.dptr; + devPtrScaleO = output_O->scale.dptr; + devPtrAmaxS = input_output_S->amax.dptr; + devPtrScaleS = input_output_S->scale.dptr; + devPtrDescaleS = input_output_S->scale_inv.dptr; + // } void* devPtrM = nullptr; void* devPtrZInv = nullptr; @@ -2458,10 +2583,6 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou NVTE_ERROR("Unexpected 
Aux_CTX_Tensors->size."); } - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); void* devPtrcuSeqlensKV = diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index ef7fa0dcc0..5da38045e4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1180,6 +1180,8 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: + print(f">>>>>>> Combining and quantizing q, k, v <<<<<<<") + print(f"q: {q.shape}, k: {k.shape}, v: {v.shape}") q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) # print quantizers diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 5a554d86ec..8699f22cb9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -674,10 +674,10 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False - if not fp8_recipe_dpa.float8_per_tensor_scaling(): - assert not ( - fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha - ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" +# if not fp8_recipe_dpa.float8_per_tensor_scaling(): +# assert not ( +# fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha +# ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" # reduce 
over TP+CP groups; expect fp8_group to be set up so # assume attention uses the same fp8_group as GEMMs diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 56e6f093d1..2490b5ccd4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -40,6 +40,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -2089,14 +2090,16 @@ def get_attention_quantizers(fp8, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 + columnwise = True QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True - QKV_quantizer.set_usage(rowwise=True, columnwise=False) + QKV_quantizer.set_usage(rowwise=True, columnwise=columnwise) O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=False) - S_quantizer = quantizers["scaling_fwd"][META_S] - S_quantizer.internal = True - S_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer.set_usage(rowwise=True, columnwise=columnwise) + S_quantizer = None + # quantizers["scaling_fwd"][META_S] + # S_quantizer.internal = True + # S_quantizer.set_usage(rowwise=True, columnwise=columnwise) dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer.interal = True @@ -2107,6 +2110,7 @@ def get_attention_quantizers(fp8, quantizers): dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True + print(f"QKV_quantizer: {QKV_quantizer}, O_quantizer: {O_quantizer}, S_quantizer: {S_quantizer}, dQKV_quantizer: {dQKV_quantizer}, dO_quantizer: 
{dO_quantizer}, dP_quantizer: {dP_quantizer}") return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer @@ -2160,10 +2164,17 @@ def print_quantizers( type_str = "DS" elif isinstance(q, Float8CurrentScalingQuantizer): type_str = "CS" - print( - f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" - f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" - ) + elif isinstance(q, MXFP8Quantizer): + type_str = "MXFP8" + if type_str in ["DS", "CS"]: + print( + f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" + f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" + ) + else: + print( + f"{label} >> {names[i]:14s}: {type_str}" + ) def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): @@ -2172,6 +2183,22 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): qkv_layout = qkv_layout.replace("paged_kv_", "") qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype + print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") + if isinstance(qkv_quantizer, MXFP8Quantizer): + print(f"Using MXFP8Quantizer") + qkv_quantizer._internal = False + q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + v_permuted = v.permute(0, 2, 3, 1).contiguous() + v_fp8_permuted = qkv_quantizer(v_permuted) + print(f"v_fp8: {v_fp8._rowwise_scale_inv.shape}") + print(f"v_fp8_permuted: {v_fp8_permuted._rowwise_scale_inv.shape}") + # v_fp8_permuted_rowwise_shape = v_fp8._rowwise_scale_inv.permute(0, 2, 3, 1).shape + v_fp8._rowwise_scale_inv = v_fp8_permuted._rowwise_scale_inv.view(2,16,128,-1).permute(0, 3, 1, 2)#.view(-1,128) + print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}, v_fp8.stride(): {v_fp8._rowwise_data.stride()}") + print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}, v_fp8.stride(): {v_fp8._rowwise_scale_inv.stride()}") + 
print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + return q_fp8, k_fp8, v_fp8 match qkv_group: case 1: dim = qkv_layout.find("3") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 101e5b2525..6b2c21013a 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -16,6 +16,7 @@ NVTE_Fused_Attn_Backend, ) from ..quantized_tensor import Quantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer __all__ = [ @@ -293,9 +294,10 @@ def fused_attn_fwd( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention." + if not isinstance(o_quantizer, MXFP8Quantizer): + assert ( + s_quantizer is not None + ), "s_quantizer is required as an input for FP8 fused attention." assert ( o_quantizer is not None ), "o_quantizer is required as an input for FP8 fused attention." diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index bc22e03097..d97c72c31c 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -274,6 +274,14 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. 
+ */ + std::pair create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data = std::nullopt); + std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index bf62db8c33..094188b6c9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -91,6 +91,22 @@ std::pair quantizer_helper(py::handle quantizer, !data.has_value(), "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); } + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + // MXFP8 + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + if (create_hp_tensor_for_cs) { + if (data.has_value()) { + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + } + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + NVTE_CHECK( + !data.has_value(), + "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); + } } return {std::move(te_T), std::move(py_T)}; } @@ -116,6 +132,7 @@ std::vector fused_attn_fwd( auto none = py::none(); + printf(">>>>>>> Creating QKV tensor wrappers <<<<<<<\n"); // create QKV tensor wrappers TensorWrapper te_Q, te_K, te_V; te_Q = makeTransformerEngineTensor(Q, none); @@ -123,11 +140,13 @@ std::vector fused_attn_fwd( te_V = makeTransformerEngineTensor(V, none); const DType qkv_type = te_Q.dtype(); + printf(">>>>>> Creating S tensor wrapper <<<<<<<"); // create S tensor TensorWrapper te_S; py::object py_S; std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + printf(">>>>>> Creating O 
tensor wrapper <<<<<<<\n"); // create O tensor TensorWrapper te_O; py::object py_O; @@ -139,6 +158,7 @@ std::vector fused_attn_fwd( const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); + printf(">>>>>> Creating Bias tensor wrapper <<<<<<<"); // construct NVTE tensors TensorWrapper te_Bias; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1c968e276d..20820143b0 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -940,6 +940,18 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair MXFP8Quantizer::create_unquantized_tensor_with_amax(const std::vector& shape, + DType dtype, + std::optional data) { + at::Tensor amax_tensor = at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + auto out = data.has_value() ? 
NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) + : NoneQuantizer(py::none()).create_tensor(shape, dtype); + TensorWrapper out_cpp = std::move(out.first); + py::object out_py = std::move(out.second); + out_cpp.set_amax(amax_tensor.data_ptr(), DType::kFloat32, std::vector{1}); + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair MXFP8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 8dd2255d89..58d095a4f4 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -84,6 +84,7 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" + print(f"Quantizing tensor: {tensor.shape}") return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: From 23434b5b1d9b7438ab0d2aa862560f832679fdac Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:18:43 -0800 Subject: [PATCH 002/172] semi-working FP8; broken F16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 99 +++++++++++-------- .../common/fused_attn/fused_attn_fp8.cu | 49 +++++---- transformer_engine/common/fused_attn/utils.cu | 25 ++++- .../include/transformer_engine/fused_attn.h | 5 + .../common/util/pybind_helper.h | 6 +- .../attention/dot_product_attention/utils.py | 46 ++++++--- .../pytorch/cpp_extensions/fused_attn.py | 2 + 7 files changed, 157 insertions(+), 75 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 4f8367aac7..02ff448544 100644 --- 
a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -117,6 +117,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: + return NVTE_QKV_Layout_Group::NVTE_HD_HD_SD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -157,6 +159,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -176,6 +180,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2BSHD: case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -195,6 +201,8 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -226,45 +234,58 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); const bool supported_ragged_offset_size = (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - - if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && - sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - // 8.9: t3hd, max_s=512, d=64, padding - ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && - 
qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && - max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} - (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && - max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} - (cudnn_runtime_version >= 90700 && - // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // sm90: fwd d<=256, bwd d=128 only - // sm100: fwd d<=128, bwd d<=128 - ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || - (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || - (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && - head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && - // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000) && !return_max_logit) { - if (cudnn_runtime_version >= 8900) { - backend = NVTE_Fused_Attn_Backend::NVTE_FP8; - } else { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." 
- << std::endl; - } - } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { + printf(">>>>>> q_dtype: %d\n", q_dtype); + printf(">>>>>> qkv_format: %d\n", qkv_format); + printf(">>>>>> q_format: %d\n", q_format); + printf(">>>>>> kv_format: %d\n", kv_format); + printf(">>>>>> layout_group: %d\n", layout_group); + printf(">>>>>> cudnn_runtime_version: %d\n", cudnn_runtime_version); + printf(">>>>>> is_training: %d\n", is_training); + printf(">>>>>> bias_type: %d\n", bias_type); + printf(">>>>>> attn_mask_type: %d\n", attn_mask_type); + if (q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) { + backend = NVTE_Fused_Attn_Backend::NVTE_FP8; + // } + + // if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && + // sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // // 8.9: t3hd, max_s=512, d=64, padding + // ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && + // qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && + // max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && + // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + // // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + // (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && + // max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && + // (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + // (cudnn_runtime_version >= 90700 && + // // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // // sm90: fwd d<=256, bwd d=128 only + // // sm100: fwd d<=128, bwd d<=128 + // ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || + // (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) 
|| + // (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && + // head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + // (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + // (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHSD) && + // !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && + // // 9.10.0: known bugs with SDPA FP8 + // (cudnn_runtime_version != 91000) && !return_max_logit) { + // if (cudnn_runtime_version >= 8900) { + // backend = NVTE_Fused_Attn_Backend::NVTE_FP8; + // } else { + // backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + // std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." + // " Please upgrade your cuDNN version if possible." 
+ // << std::endl; + // } + // } else +} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { bool flag_m512 = false; bool flag_arb = false; if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 400a11af6a..ea5722a831 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1659,8 +1659,8 @@ void fused_attn_fp8_fwd_impl_v1( void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, - cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + cudnn_frontend::DataType_t o_tensor_type, NVTEScalingMode scaling_mode, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); @@ -1673,16 +1673,15 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || - o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - // 
NVTE_CHECK(is_current_scaling || is_delayed_scaling, - // "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - // "kFloat8E5M2!"); - is_current_scaling = false; - is_delayed_scaling = false; - bool is_mxfp8 = true; + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + printf(">>>>>> scaling_mode: %d\n", scaling_mode); printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); @@ -1697,6 +1696,7 @@ void fused_attn_fp8_fwd_impl_v1( printf(">>>>>> mask_type: %d\n", mask_type); printf(">>>>>> scaling_factor: %f\n", scaling_factor); printf(">>>>>> dropout_probability: %f\n", dropout_probability); + is_mxfp8 = true; // qkv_tensor_type = cudnn_frontend::DataType_t::FP8_E8M0; // o_tensor_type = cudnn_frontend::DataType_t::BFLOAT16; // printf(">>>>>> after setting qkv_tensor_type and o_tensor_type\n"); @@ -1783,6 +1783,7 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr bias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; + printf(">>>>>> layout: %d\n", layout); std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -2030,17 +2031,19 @@ void fused_attn_fp8_fwd_impl_v1( {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, - {descale_s, devPtrDescaleS}, - {scale_s, devPtrScaleS}, {attn_scale, &scaling_factor}, {O, devPtrO}, - {amax_s, devPtrAmaxS}, {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; if 
(is_delayed_scaling) { variant_pack[scale_o] = devPtrScaleO; } + if (!is_mxfp8) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[amax_s] = devPtrAmaxS; + } /* if (is_bias) { variant_pack[bias] = devPtrBias; @@ -2548,14 +2551,19 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrDescaleK = input_K->scale_inv.dptr; devPtrV = input_V->data.dptr; devPtrDescaleV = input_V->scale_inv.dptr; + // devPtrV = input_V->columnwise_data.dptr; + // devPtrDescaleV = input_V->columnwise_scale_inv.dptr; devPtrO = output_O->data.dptr; devPtrAmaxO = output_O->amax.dptr; - devPtrScaleO = output_O->scale.dptr; - devPtrAmaxS = input_output_S->amax.dptr; - devPtrScaleS = input_output_S->scale.dptr; - devPtrDescaleS = input_output_S->scale_inv.dptr; + // devPtrScaleO = output_O->scale.dptr; + // devPtrAmaxS = input_output_S->amax.dptr; + // devPtrScaleS = input_output_S->scale.dptr; + // devPtrDescaleS = input_output_S->scale_inv.dptr; // } - + printf(">>>>>> scaling_mode: %d\n", input_Q->scaling_mode); + printf(">>>>>> scaling_mode: %d\n", input_K->scaling_mode); + printf(">>>>>> scaling_mode: %d\n", input_V->scaling_mode); + printf(">>>>>> scaling_mode: %d\n", output_O->scaling_mode); void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { @@ -2604,7 +2612,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, workspace->data.dptr, &workspace_size, + stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( 
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 727aac447b..0309cf643d 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -293,7 +293,30 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; - } + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d * s_kv; + strideA[seqlen_dim_idx] = 1; + strideA[hidden_dim_idx] = s_kv; + // strideA[batch_dim_idx] = h * s_kv * d; + // strideA[head_dim_idx] = s_kv * d; + // strideA[seqlen_transpose_dim_idx] = d; + // strideA[hidden_transpose_dim_idx] = 1; + } + break; +} if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { strideA[seqlen_kv_dim_idx] = 1; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index cddd3d7506..7c54633989 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -52,6 +52,7 @@ enum NVTE_QKV_Layout { NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout 
*/ + NVTE_BSHD_BSHD_BHSD = 25, /*!< BSHD_BSHD_BHSD layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -70,6 +71,8 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_HD_HD = 4, /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ NVTE_Paged_KV_HD_HD_HD = 5, + /*! BSHD_BSHD_BHSD QKV layouts, e.g. BSHD_BSHD_BHSD */ + NVTE_HD_HD_SD = 6, }; /*! \enum NVTE_QKV_Format @@ -90,6 +93,8 @@ enum NVTE_QKV_Format { NVTE_THD_2BSHD = 5, /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ NVTE_THD_2SBHD = 6, + /*! BSHD_BSHD_BHSD QKV format, e.g. BSHD_BSHD_BHSD */ + NVTE_BHSD = 7, }; /*! \enum NVTE_Bias_Type diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 6adba23a8f..57d02bcd62 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -48,7 +48,8 @@ .value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ - .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \ + .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ + .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -74,7 +75,8 @@ .value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \ .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ - .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ + .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD) \ + .value("NVTE_BSHD_BSHD_BHSD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD); \ 
pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 2490b5ccd4..0dfc64a65c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2090,12 +2090,13 @@ def get_attention_quantizers(fp8, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 - columnwise = True + is_fwd = True + is_bwd = True QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True - QKV_quantizer.set_usage(rowwise=True, columnwise=columnwise) + QKV_quantizer.set_usage(rowwise=True, columnwise=True) O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=columnwise) + O_quantizer.set_usage(rowwise=True, columnwise=True) S_quantizer = None # quantizers["scaling_fwd"][META_S] # S_quantizer.internal = True @@ -2103,9 +2104,9 @@ def get_attention_quantizers(fp8, quantizers): dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer.set_usage(rowwise=True, columnwise=True) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer.set_usage(rowwise=True, columnwise=True) dO_quantizer.internal = True dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer.set_usage(rowwise=True, columnwise=False) @@ -2181,24 +2182,43 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" # 1: qkv 
packed, 2: kv packed, 3: qkv separate qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_format, _, _ = get_qkv_format(qkv_layout) qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") if isinstance(qkv_quantizer, MXFP8Quantizer): print(f"Using MXFP8Quantizer") qkv_quantizer._internal = False + dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") + dim_others = [i for i in range(len(v.shape)) if i != dim_s] + perm = [*dim_others, dim_s] + # perm = [*dim_others[:-1], dim_s, dim_others[-1]] + v = v.permute(*perm).contiguous() + qkv_layout = "bshd_bshd_bhsd" + # inv = [0] * len(perm) + # for i, p in enumerate(perm): + # inv[p] = i + # v = v.permute(*inv) q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - v_permuted = v.permute(0, 2, 3, 1).contiguous() - v_fp8_permuted = qkv_quantizer(v_permuted) - print(f"v_fp8: {v_fp8._rowwise_scale_inv.shape}") - print(f"v_fp8_permuted: {v_fp8_permuted._rowwise_scale_inv.shape}") - # v_fp8_permuted_rowwise_shape = v_fp8._rowwise_scale_inv.permute(0, 2, 3, 1).shape - v_fp8._rowwise_scale_inv = v_fp8_permuted._rowwise_scale_inv.view(2,16,128,-1).permute(0, 3, 1, 2)#.view(-1,128) - print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}, v_fp8.stride(): {v_fp8._rowwise_data.stride()}") - print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}, v_fp8.stride(): {v_fp8._rowwise_scale_inv.stride()}") + print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") + print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") 
print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + # v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv) return q_fp8, k_fp8, v_fp8 + + # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + # v_permuted = v.permute(0, 2, 3, 1).contiguous() + # v_fp8_permuted = qkv_quantizer(v_permuted) + # print(f"v_fp8: {v_fp8._rowwise_scale_inv.shape}") + # print(f"v_fp8_permuted: {v_fp8_permuted._rowwise_scale_inv.shape}") + # # v_fp8_permuted_rowwise_shape = v_fp8._rowwise_scale_inv.permute(0, 2, 3, 1).shape + # v_fp8._rowwise_scale_inv = v_fp8_permuted._rowwise_scale_inv.view(2,16,128,-1).permute(0, 3, 1, 2)#.view(-1,128) + # print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}, v_fp8.stride(): {v_fp8._rowwise_data.stride()}") + # print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}, v_fp8.stride(): {v_fp8._rowwise_scale_inv.stride()}") + # print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + # print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + # return q_fp8, k_fp8, v_fp8 match qkv_group: case 1: dim = qkv_layout.find("3") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 6b2c21013a..41007912c9 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -42,6 +42,7 @@ "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, + "bshd_bshd_bhsd": NVTE_QKV_Format.NVTE_BHSD, } QKVLayout = { @@ -70,6 +71,7 @@ "paged_kv_sbhd_sbhd_sbhd": 
NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, + "bshd_bshd_bhsd": NVTE_QKV_Layout.NVTE_BSHD_BSHD_BHSD, } AttnBiasType = { From dbb68b8c958735ffe89cca1881bc8d3cd5bfa871 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Feb 2026 19:52:55 -0800 Subject: [PATCH 003/172] clean up last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 5 +- .../common/fused_attn/fused_attn.cpp | 100 +++++++++--------- .../common/fused_attn/fused_attn_fp8.cu | 66 +++++------- transformer_engine/common/fused_attn/utils.cu | 16 ++- .../include/transformer_engine/fused_attn.h | 8 +- .../common/util/pybind_helper.h | 4 +- .../dot_product_attention/backends.py | 21 ++-- .../dot_product_attention.py | 5 - .../attention/dot_product_attention/utils.py | 46 +++----- .../pytorch/cpp_extensions/fused_attn.py | 18 +--- .../pytorch/csrc/extensions/attention.cpp | 4 - 11 files changed, 123 insertions(+), 170 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index dd133c840a..dad697e910 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2113,7 +2113,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - print(f"flash_attn_supported: {flash_attn_supported}, fused_attn_supported: {fused_attn_supported}, unfused_attn_supported: {unfused_attn_supported}") if flash_attn_supported + fused_attn_supported < 1: pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: @@ -2155,7 +2154,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, 
scal os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") - print(f"Running fused attention") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( dtype, config, True, qkv_layout, is_training, fp8_recipe ) @@ -2166,7 +2164,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if config.dropout_p == 0.0: # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") - print(f"Running fused attention with fp8_dpa = False") fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( dtype, config, False, qkv_layout, is_training, fp8_recipe ) @@ -2188,7 +2185,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal rmse_tol, True, ) - if unfused_attn_supported: + if False: #unfused_attn_supported: logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 02ff448544..61a8d61635 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -117,8 +117,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: - return NVTE_QKV_Layout_Group::NVTE_HD_HD_SD; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: + return NVTE_QKV_Layout_Group::NVTE_HD_HD_DS; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -159,8 +159,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout 
qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: - return NVTE_QKV_Format::NVTE_BHSD; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: + return NVTE_QKV_Format::NVTE_BHDS; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -180,8 +180,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2BSHD: case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; - case NVTE_QKV_Format::NVTE_BHSD: - return NVTE_QKV_Format::NVTE_BHSD; + case NVTE_QKV_Format::NVTE_BHDS: + return NVTE_QKV_Format::NVTE_BHDS; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -201,8 +201,8 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; - case NVTE_QKV_Format::NVTE_BHSD: - return NVTE_QKV_Format::NVTE_BHSD; + case NVTE_QKV_Format::NVTE_BHDS: + return NVTE_QKV_Format::NVTE_BHDS; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -234,6 +234,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); const bool supported_ragged_offset_size = (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); + printf(">>>>>> qkv_layout: %d\n", qkv_layout); printf(">>>>>> q_dtype: %d\n", q_dtype); printf(">>>>>> qkv_format: %d\n", qkv_format); printf(">>>>>> q_format: %d\n", q_format); @@ -243,49 +244,45 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( printf(">>>>>> is_training: %d\n", is_training); printf(">>>>>> bias_type: %d\n", bias_type); printf(">>>>>> attn_mask_type: %d\n", attn_mask_type); - if (q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) { - backend = NVTE_Fused_Attn_Backend::NVTE_FP8; - // } - - // if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == 
NVTEDType::kNVTEFloat8E5M2) && - // sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - // // 8.9: t3hd, max_s=512, d=64, padding - // ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && - // qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && - // max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} - // (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && - // max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && - // (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} - // (cudnn_runtime_version >= 90700 && - // // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // // sm90: fwd d<=256, bwd d=128 only - // // sm100: fwd d<=128, bwd d<=128 - // ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || - // (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || - // (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && - // head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - // (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - // (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHSD) && - // !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && - // // 9.10.0: known bugs with SDPA FP8 - // (cudnn_runtime_version != 91000) && !return_max_logit) { - // if (cudnn_runtime_version >= 8900) { - // backend = 
NVTE_Fused_Attn_Backend::NVTE_FP8; - // } else { - // backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - // std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." - // " Please upgrade your cuDNN version if possible." - // << std::endl; - // } - // } else -} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { + + if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && + sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // 8.9: t3hd, max_s=512, d=64, padding + // ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && + // qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && + // max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && + // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + // // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + // (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && + // max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && + // (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + // (cudnn_runtime_version >= 90700 && + // // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // // sm90: fwd d<=256, bwd d=128 only + // // sm100: fwd d<=128, bwd d<=128 + // ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || + // (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || + // (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && + // head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + // (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + // attn_mask_type == 
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHDS) && + !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && + // 9.10.0: known bugs with SDPA FP8 + (cudnn_runtime_version != 91000) && !return_max_logit) { + if (cudnn_runtime_version >= 8900) { + backend = NVTE_Fused_Attn_Backend::NVTE_FP8; + } else { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." + " Please upgrade your cuDNN version if possible." + << std::endl; + } + } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { bool flag_m512 = false; bool flag_arb = false; if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && @@ -1205,6 +1202,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, return_max_logit, cuda_graph, false); + printf(">>>>>> fused_attention_backend: %d\n", fused_attention_backend); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index ea5722a831..bf4f019a67 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1696,12 +1696,6 @@ void fused_attn_fp8_fwd_impl_v1( printf(">>>>>> mask_type: %d\n", mask_type); printf(">>>>>> scaling_factor: %f\n", scaling_factor); printf(">>>>>> dropout_probability: %f\n", dropout_probability); - is_mxfp8 = true; - // qkv_tensor_type = cudnn_frontend::DataType_t::FP8_E8M0; - 
// o_tensor_type = cudnn_frontend::DataType_t::BFLOAT16; - // printf(">>>>>> after setting qkv_tensor_type and o_tensor_type\n"); - // printf(">>>>>> qkv_tensor_type: %d\n", qkv_tensor_type); - // printf(">>>>>> o_tensor_type: %d\n", o_tensor_type); try { FADescriptor_v1 descriptor{b, @@ -1792,7 +1786,7 @@ void fused_attn_fp8_fwd_impl_v1( generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + NVTE_QKV_Matrix::NVTE_K_Matrix); // need to double check printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); @@ -1804,7 +1798,7 @@ void fused_attn_fp8_fwd_impl_v1( int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; int64_t d_scale_padded = ((d_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - int64_t d_padded = ((d + 3) / 4) * 4; // d dimension for SF_V (not scaled, but may need padding) + int64_t d_padded = ((d + 3) / 4) * 4; printf(">>>>>> d_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_scale_padded: %d, s_kv_scale_padded: %d, d_padded: %d\n", d_scale, s_kv_scale, s_q_padded, s_kv_padded, d_scale_padded, s_kv_scale_padded, d_padded); std::vector q_scale_dims = {b, h, s_q_padded, d_scale_padded}; std::vector k_scale_dims = {b, hg, s_kv_padded, d_scale_padded}; @@ -1816,13 +1810,8 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_scale_padded, k_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); -// generateMatrixStrides(b, d_padded, s_q_padded, hg, s_kv_scale_padded, v_scale_strides.data(), layout, - // generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, 
d_padded, v_scale_strides.data(), layout, - // NVTE_QKV_Matrix::NVTE_V_Matrix); - v_scale_strides[0] = h*d_padded*s_kv_scale_padded; - v_scale_strides[1] = d_padded*s_kv_scale_padded; - v_scale_strides[2] = 1; - v_scale_strides[3] = s_kv_scale_padded; + generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, d_padded, v_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); @@ -1977,8 +1966,8 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // O std::shared_ptr, // amax_s std::shared_ptr> // amax_o - key_tensors_tuple = - is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, nullptr, attn_scale, O, nullptr, amax_o) : + key_tensors_tuple = + is_mxfp8 ? 
std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, nullptr, attn_scale, O, nullptr, amax_o) : std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, attn_scale, O, amax_s, amax_o); auto Stats_tuple = std::make_tuple(Stats); @@ -1997,7 +1986,7 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); printf(">>>>>> mha_graph->check_support(handle)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - printf(">>>>>> mha_graph->build_plans(handle)\n"); + printf(">>>>>> mha_graph->build_plans(handle)\n"); printf(">>>>>> mha_graph->get_workspace_size(): %zu\n", mha_graph->get_workspace_size()); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); @@ -2532,19 +2521,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrAmaxS = nullptr; void* devPtrScaleS = nullptr; void* devPtrDescaleS = nullptr; - printf(">>>>>> fused_attn_fp8_fwd\n"); - // if (input_Q->scaling_mode() == NVTE_MXFP8_1D_SCALING) { - // printf(">>>>>> input_Q is MXFP8\n"); - // devPtrQ = input_Q-> - // devPtrQ = input_Q->get_rowwise_data_ptr(); - // devPtrDescaleQ = input_Q->get_rowwise_scale_inv_ptr(); - // devPtrK = input_K->get_rowwise_data_ptr(); - // devPtrDescaleK = input_K->get_rowwise_scale_inv_ptr(); - // devPtrV = input_V->get_rowwise_data_ptr(); - // devPtrDescaleV = input_V->get_rowwise_scale_inv_ptr(); - // devPtrO = output_O->get_rowwise_data_ptr(); - // devPtrAmaxO = output_O->amax.dptr; - // } else { + if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { + printf(">>>>>> input_Q is MXFP8_1D_SCALING\n"); devPtrQ = input_Q->data.dptr; devPtrDescaleQ = input_Q->scale_inv.dptr; devPtrK = input_K->data.dptr; @@ -2555,15 +2533,21 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou // devPtrDescaleV = input_V->columnwise_scale_inv.dptr; devPtrO 
= output_O->data.dptr; devPtrAmaxO = output_O->amax.dptr; - // devPtrScaleO = output_O->scale.dptr; - // devPtrAmaxS = input_output_S->amax.dptr; - // devPtrScaleS = input_output_S->scale.dptr; - // devPtrDescaleS = input_output_S->scale_inv.dptr; - // } - printf(">>>>>> scaling_mode: %d\n", input_Q->scaling_mode); - printf(">>>>>> scaling_mode: %d\n", input_K->scaling_mode); - printf(">>>>>> scaling_mode: %d\n", input_V->scaling_mode); - printf(">>>>>> scaling_mode: %d\n", output_O->scaling_mode); + } else { + printf(">>>>>> input_Q is not MXFP8_1D_SCALING\n"); + devPtrQ = input_Q->data.dptr; + devPtrDescaleQ = input_Q->scale_inv.dptr; + devPtrK = input_K->data.dptr; + devPtrDescaleK = input_K->scale_inv.dptr; + devPtrV = input_V->data.dptr; + devPtrDescaleV = input_V->scale_inv.dptr; + devPtrO = output_O->data.dptr; + devPtrAmaxO = output_O->amax.dptr; + devPtrScaleO = output_O->scale.dptr; + devPtrAmaxS = input_output_S->amax.dptr; + devPtrScaleS = input_output_S->scale.dptr; + devPtrDescaleS = input_output_S->scale_inv.dptr; + } void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { @@ -2605,7 +2589,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHDS)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 0309cf643d..94a495153e 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ 
b/transformer_engine/common/fused_attn/utils.cu @@ -293,7 +293,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { strideA[batch_dim_idx] = s_q * h * d; @@ -310,10 +310,16 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[head_dim_idx] = d * s_kv; strideA[seqlen_dim_idx] = 1; strideA[hidden_dim_idx] = s_kv; - // strideA[batch_dim_idx] = h * s_kv * d; - // strideA[head_dim_idx] = s_kv * d; - // strideA[seqlen_transpose_dim_idx] = d; - // strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = 1; + strideA[hidden_transpose_dim_idx] = h * d; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d * s_kv; + strideA[seqlen_transpose_dim_idx] = s_kv; + strideA[hidden_transpose_dim_idx] = 1; } break; } diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 7c54633989..bc97d2a853 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -52,7 +52,7 @@ enum NVTE_QKV_Layout { NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ - NVTE_BSHD_BSHD_BHSD = 25, /*!< BSHD_BSHD_BHSD layout */ + NVTE_BSHD_BSHD_BHDS = 25, /*!< BSHD_BSHD_BHDS layout */ }; /*! 
\enum NVTE_QKV_Layout_Group @@ -71,8 +71,8 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_HD_HD = 4, /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ NVTE_Paged_KV_HD_HD_HD = 5, - /*! BSHD_BSHD_BHSD QKV layouts, e.g. BSHD_BSHD_BHSD */ - NVTE_HD_HD_SD = 6, + /*! BSHD_BSHD_BHDS QKV layouts, e.g. BSHD_BSHD_BHDS */ + NVTE_HD_HD_DS = 6, }; /*! \enum NVTE_QKV_Format @@ -94,7 +94,7 @@ enum NVTE_QKV_Format { /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ NVTE_THD_2SBHD = 6, /*! BSHD_BSHD_BHSD QKV format, e.g. BSHD_BSHD_BHSD */ - NVTE_BHSD = 7, + NVTE_BHDS = 7, }; /*! \enum NVTE_Bias_Type diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 57d02bcd62..b81c488005 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -49,7 +49,7 @@ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ - .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD); \ + .value("NVTE_BHDS", NVTE_QKV_Format::NVTE_BHDS); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -76,7 +76,7 @@ .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD) \ - .value("NVTE_BSHD_BSHD_BHSD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD); \ + .value("NVTE_BSHD_BSHD_BHDS", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ 
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 5da38045e4..cce4159ea6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -35,6 +35,7 @@ restore_from_saved, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.constants import ( TE_DType, QKVLayouts, @@ -168,7 +169,7 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] ] - q_fp8, k_fp8, v_fp8 = combine_and_quantize( + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, query_layer, key_layer, value_layer, quantizer ) tensors = combine_and_dequantize( @@ -192,7 +193,7 @@ def backward(ctx, grad1, grad2, grad3): tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 elif ctx.quantizer_name == "dQKV_quantizer": query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] - dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize( + dq_fp8, dk_fp8, dv_fp8, ctx.qkv_layout = combine_and_quantize( ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer ) tensors = combine_and_dequantize( @@ -1180,9 +1181,7 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - print(f">>>>>>> Combining and quantizing q, k, v <<<<<<<") - print(f"q: {q.shape}, k: {k.shape}, v: {v.shape}") - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) # print quantizers print_quantizers( @@ -1237,11 +1236,17 @@ def 
forward( # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ - - if isinstance(out_, Float8Tensor): + print(f"out_: {type(out_)} {out_.shape}") + print(f"is_output_fp8: {is_output_fp8}") + print(f"is_bwd_fp8: {is_bwd_fp8}") + print(f"fp8_recipe.float8_current_scaling(): {fp8_recipe.float8_current_scaling()}") + print(f"_dpa_fp8_cs_o_in_f16: {_dpa_fp8_cs_o_in_f16}") + if isinstance(out_, Float8Tensor) or isinstance(out_, MXFP8Tensor): + print(f"dequantizing out_") if not is_output_fp8 or not is_bwd_fp8: out = out_.dequantize().view(out_.shape) else: + print(f"quantizing out_") if is_output_fp8 or ( is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) @@ -1562,7 +1567,7 @@ def backward(ctx, d_out, *_args): ) if not is_float8tensor and ctx.is_input_fp8: # return in FP8 - dq, dk, dv = combine_and_quantize( + dq, dk, dv, ctx.qkv_layout = combine_and_quantize( ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 8699f22cb9..5da35d157f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -673,11 +673,6 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes - # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False -# if not fp8_recipe_dpa.float8_per_tensor_scaling(): -# assert not ( -# fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha -# ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" # reduce over TP+CP groups; expect fp8_group to be set up so # assume attention uses the same fp8_group as GEMMs diff --git 
a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 0dfc64a65c..eaeecaca4b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2090,28 +2090,27 @@ def get_attention_quantizers(fp8, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 - is_fwd = True - is_bwd = True QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True QKV_quantizer.set_usage(rowwise=True, columnwise=True) O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=True) - S_quantizer = None - # quantizers["scaling_fwd"][META_S] - # S_quantizer.internal = True - # S_quantizer.set_usage(rowwise=True, columnwise=columnwise) + O_quantizer.set_usage(rowwise=True, columnwise=False) + if isinstance(QKV_quantizer, MXFP8Quantizer): + S_quantizer = None + else: + S_quantizer = quantizers["scaling_fwd"][META_S] + S_quantizer.internal = True + S_quantizer.set_usage(rowwise=True, columnwise=False) dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=True) + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.set_usage(rowwise=True, columnwise=True) + dO_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer.internal = True dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True - print(f"QKV_quantizer: {QKV_quantizer}, O_quantizer: {O_quantizer}, S_quantizer: {S_quantizer}, dQKV_quantizer: {dQKV_quantizer}, dO_quantizer: {dO_quantizer}, dP_quantizer: {dP_quantizer}") return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, 
dP_quantizer @@ -2187,38 +2186,25 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): src_nominal_dtype = q.dtype print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") if isinstance(qkv_quantizer, MXFP8Quantizer): - print(f"Using MXFP8Quantizer") qkv_quantizer._internal = False dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") dim_others = [i for i in range(len(v.shape)) if i != dim_s] perm = [*dim_others, dim_s] # perm = [*dim_others[:-1], dim_s, dim_others[-1]] v = v.permute(*perm).contiguous() - qkv_layout = "bshd_bshd_bhsd" - # inv = [0] * len(perm) - # for i, p in enumerate(perm): - # inv[p] = i + qkv_layout = "bshd_bshd_bhds" + inv = [0] * len(perm) + for i, p in enumerate(perm): + inv[p] = i # v = v.permute(*inv) q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv).contiguous() print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") - # v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv) - return q_fp8, k_fp8, v_fp8 - - # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - # v_permuted = v.permute(0, 2, 3, 1).contiguous() - # v_fp8_permuted = qkv_quantizer(v_permuted) - # print(f"v_fp8: {v_fp8._rowwise_scale_inv.shape}") - # print(f"v_fp8_permuted: {v_fp8_permuted._rowwise_scale_inv.shape}") - # # v_fp8_permuted_rowwise_shape = v_fp8._rowwise_scale_inv.permute(0, 2, 3, 1).shape - # v_fp8._rowwise_scale_inv = v_fp8_permuted._rowwise_scale_inv.view(2,16,128,-1).permute(0, 
3, 1, 2)#.view(-1,128) - # print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}, v_fp8.stride(): {v_fp8._rowwise_data.stride()}") - # print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}, v_fp8.stride(): {v_fp8._rowwise_scale_inv.stride()}") - # print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - # print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") - # return q_fp8, k_fp8, v_fp8 + return q_fp8, k_fp8, v_fp8, qkv_layout + match qkv_group: case 1: dim = qkv_layout.find("3") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 41007912c9..2748228b42 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -42,7 +42,7 @@ "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, - "bshd_bshd_bhsd": NVTE_QKV_Format.NVTE_BHSD, + "bshd_bshd_bhds": NVTE_QKV_Format.NVTE_BHDS, } QKVLayout = { @@ -71,7 +71,7 @@ "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, - "bshd_bshd_bhsd": NVTE_QKV_Layout.NVTE_BSHD_BSHD_BHSD, + "bshd_bshd_bhds": NVTE_QKV_Layout.NVTE_BSHD_BSHD_BHDS, } AttnBiasType = { @@ -295,14 +295,6 @@ def fused_attn_fwd( rng_elts_per_thread = ( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - if not isinstance(o_quantizer, MXFP8Quantizer): - assert ( - s_quantizer is not None - ), "s_quantizer is required as 
an input for FP8 fused attention." - assert ( - o_quantizer is not None - ), "o_quantizer is required as an input for FP8 fused attention." else: raise ValueError(f"Unsupported backend {fused_attention_backend}") @@ -501,12 +493,6 @@ def fused_attn_bwd( ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention backward." - assert ( - dp_quantizer is not None - ), "dp_quantizer is required as an input for FP8 fused attention backward." assert ( dqkv_dtype is not None ), "dqkv_dtype is required as an input for FP8 fused attention backward." diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 094188b6c9..0d7a842ce1 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -132,7 +132,6 @@ std::vector fused_attn_fwd( auto none = py::none(); - printf(">>>>>>> Creating QKV tensor wrappers <<<<<<<\n"); // create QKV tensor wrappers TensorWrapper te_Q, te_K, te_V; te_Q = makeTransformerEngineTensor(Q, none); @@ -140,13 +139,11 @@ std::vector fused_attn_fwd( te_V = makeTransformerEngineTensor(V, none); const DType qkv_type = te_Q.dtype(); - printf(">>>>>> Creating S tensor wrapper <<<<<<<"); // create S tensor TensorWrapper te_S; py::object py_S; std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); - printf(">>>>>> Creating O tensor wrapper <<<<<<<\n"); // create O tensor TensorWrapper te_O; py::object py_O; @@ -158,7 +155,6 @@ std::vector fused_attn_fwd( const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); - printf(">>>>>> Creating Bias tensor wrapper <<<<<<<"); // construct NVTE tensors 
TensorWrapper te_Bias; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; From c627231a4f094acefdc2cc42e3b38c636aa0f095 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Feb 2026 20:13:32 -0800 Subject: [PATCH 004/172] comment out F16 pass Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index dad697e910..0301f77ae8 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2158,15 +2158,15 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal dtype, config, True, qkv_layout, is_training, fp8_recipe ) - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_UNFUSED_ATTN"] = "0" - if config.dropout_p == 0.0: - # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") - fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, qkv_layout, is_training, fp8_recipe - ) + # os.environ["NVTE_FLASH_ATTN"] = "0" + # os.environ["NVTE_FUSED_ATTN"] = "1" + # os.environ["NVTE_UNFUSED_ATTN"] = "0" + # if config.dropout_p == 0.0: + # # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell + # logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") + # fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + # dtype, config, False, qkv_layout, is_training, fp8_recipe + # ) atol = 5e-1 rtol = 5e-2 From 3f3b9e64f09bd4028157ded1b0bfd66157af6a72 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 6 Feb 2026 14:46:37 -0800 Subject: [PATCH 005/172] pull in grouped_quantize for MXFP8 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 254 ++++++++++-------- transformer_engine/common/common.h | 5 +- 2 files changed, 148 insertions(+), 111 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 7801a2064d..df4317b547 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -21,6 +21,7 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" #include "../core/common.cuh" +#include "swizzle.cuh" namespace transformer_engine { namespace dispatch { @@ -231,7 +232,7 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso template + bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES> __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( const __grid_constant__ CUtensorMap tensor_map_input_static, const __grid_constant__ CUtensorMap tensor_map_act_input_static, @@ -250,6 +251,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; + using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx; + if constexpr (NO_ACTIVATIONS) { if (noop != nullptr && noop[0] == 1.0f) { return; @@ -475,8 +478,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; const size_t global_scales_offset_X = scales_offset_X_colwise; - const size_t scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + // const size_t scale_idx = + // global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + size_t scale_idx; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = 
gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y, + DIVUP(rows, static_cast(128))); + } else { + scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + } scales_colwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); @@ -602,7 +612,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const int stage_scales_offset_X = scales_offset_X_rowwise; - const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + // const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + size_t scale_idx; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, + DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } scales_rowwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); @@ -738,7 +755,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations using namespace group_quantize_kernel; checkCuDriverContext(stream); - CheckNoopTensor(*noop, "cast_noop"); + // CheckNoopTensor(*noop, "cast_noop"); const bool use_rowwise_scaling = output->has_data(); const bool use_colwise_scaling = output->has_columnwise_data(); @@ -751,6 +768,13 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations } else if (!use_rowwise_scaling) { scaling_type = ScalingType::COLWISE; } + // if (use_rowwise_scaling && (!use_colwise_scaling)) { + // scaling_type = ScalingType::ROWWISE; + // } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + // scaling_type = ScalingType::COLWISE; + // } else if (use_rowwise_scaling && 
use_colwise_scaling) { + // scaling_type = ScalingType::BIDIMENSIONAL; + // } ShapeRepresentation shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; if (output->all_same_shape()) { @@ -827,6 +851,12 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); } + const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); const size_t dbias_cols = last_logical_dim; if constexpr (IS_DBIAS) { @@ -848,111 +878,115 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations input->dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim, - BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - output_type_bit_size); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, - 
first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, - last_logical_dim, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - auto kernel = group_quantize_mxfp8_kernel; - switch (scaling_type) { - case ScalingType::ROWWISE: { - kernel = group_quantize_mxfp8_kernel; - break; - } - case ScalingType::COLWISE: { - kernel = group_quantize_mxfp8_kernel; - break; - } - case ScalingType::BIDIMENSIONAL: { - kernel = group_quantize_mxfp8_kernel; - break; - } - } - - // Update tensor descriptors before launching the kernel - if (!is_single_tensor) { - const IType *const input_dptr = reinterpret_cast(input->data.dptr); - - const IType *const act_input_dptr = - IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; - - OType *const output_rowwise_dptr = - use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; - - OType *const output_colwise_dptr = - use_colwise_scaling ? 
reinterpret_cast(output->columnwise_data.dptr) - : nullptr; - update_tma_descriptors<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, - output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim, - offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, - use_colwise_scaling, IS_DACT); - } - - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, - last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, - scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); - - if constexpr (IS_DBIAS) { - common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - } - - NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_SWITCH_CONDITION( + with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim, + BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + 
output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, output_type_bit_size); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + auto kernel = group_quantize_mxfp8_kernel; + switch (scaling_type) { + case ScalingType::ROWWISE: { + kernel = group_quantize_mxfp8_kernel; + break; + } + case ScalingType::COLWISE: { + kernel = group_quantize_mxfp8_kernel; + break; + } + case ScalingType::BIDIMENSIONAL: { + kernel = group_quantize_mxfp8_kernel; + break; + } + } + + // Update tensor descriptors before launching the kernel + if (!is_single_tensor) { + const IType *const input_dptr = reinterpret_cast(input->data.dptr); + + const IType *const act_input_dptr = + IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; + + OType *const output_rowwise_dptr = + use_rowwise_scaling ? 
reinterpret_cast(output->data.dptr) : nullptr; + + OType *const output_colwise_dptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) + : nullptr; + update_tma_descriptors<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, + output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, + use_colwise_scaling, IS_DACT); + } + + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, + last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, + scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + } + + NVTE_CHECK_CUDA(cudaGetLastError()); + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) } } // namespace mxfp8 diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 970b7aef6c..66b7e30187 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -333,6 +333,7 @@ struct GroupedTensor { NVTEScalingMode scaling_mode; size_t num_tensors; NVTEGroupedTensor nvte_tensor; + bool with_gemm_swizzled_scales = false; GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors) : data(), @@ -348,7 +349,8 @@ struct GroupedTensor { tensor_offsets(nullptr, std::vector{0}, DType::kInt64), logical_shape(nvte_make_shape(nullptr, 1)), scaling_mode(scaling_mode), - nvte_tensor(0) {} + nvte_tensor(0), + with_gemm_swizzled_scales(false) {} explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; } @@ -400,6 +402,7 @@ struct GroupedTensor { num_tensors = 
0; scaling_mode = NVTE_DELAYED_TENSOR_SCALING; nvte_tensor = 0; + with_gemm_swizzled_scales = false; } }; From 850b16e72eef8d5e2cd6a0ef7378e100b165903c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 6 Feb 2026 18:13:52 -0800 Subject: [PATCH 006/172] grouped tensor - pytorch Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_grouped_tensor.py | 450 ++++++++ transformer_engine/common/cast/cast.cu | 2 +- .../common/cast/dispatch/quantize.cuh | 1 + .../cast/mxfp8/group_quantize_mxfp8.cuh | 2 + .../common/cast/mxfp8/quantize_mxfp8.cuh | 1 + .../transformer_engine/transformer_engine.h | 236 +++++ transformer_engine/common/recipe/__init__.py | 34 +- .../common/transformer_engine.cpp | 114 +++ .../attention/dot_product_attention/utils.py | 5 + transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/cast.cpp | 13 + .../pytorch/csrc/extensions/pybind.cpp | 4 + transformer_engine/pytorch/csrc/pybind.h | 2 + .../pytorch/csrc/type_converters.cpp | 121 +++ .../pytorch/tensor/mxfp8_tensor.py | 45 +- .../pytorch/tensor/storage/__init__.py | 1 + .../pytorch/tensor/storage/grouped_tensor.py | 964 ++++++++++++++++++ 17 files changed, 1981 insertions(+), 16 deletions(-) create mode 100644 tests/pytorch/test_grouped_tensor.py create mode 100644 transformer_engine/pytorch/tensor/storage/grouped_tensor.py diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py new file mode 100644 index 0000000000..964c2d8e97 --- /dev/null +++ b/tests/pytorch/test_grouped_tensor.py @@ -0,0 +1,450 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Tests for GroupedTensor class""" + +from typing import List, Tuple +import pytest +import torch +import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor +from transformer_engine.pytorch import ( + Quantizer, + Float8Quantizer, + Float8CurrentScalingQuantizer, + Float8BlockQuantizer, + MXFP8Quantizer, + NVFP4Quantizer, +) +from transformer_engine.pytorch.constants import TE_DType_To_Torch +import transformer_engine_torch as tex + +# Check available recipes +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +_quantization_params = [ + pytest.param( + "fp8_delayed_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_blockwise", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + ), +] + + +def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer: + """Create quantizers for given quantization scheme""" + + if quantization == "fp8_delayed_scaling": + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device="cuda"), + amax=torch.zeros(1, dtype=torch.float32, device="cuda"), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + elif quantization == "fp8_current_scaling": + quantizer = 
Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + quantizer.set_usage(rowwise=True, columnwise=False) + elif quantization == "fp8_blockwise": + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=False, + force_pow_2_scales=True, + amax_epsilon=0.0, + block_scaling_dim=1, + ) + elif quantization == "mxfp8": + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + elif quantization == "nvfp4": + quantizer = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + else: + raise ValueError(f"Unknown quantization scheme: {quantization}") + + quantizer.internal = False + + return quantizer + + +def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor: + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"): + return qtensor._data + if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"): + return qtensor._rowwise_data + raise ValueError(f"Unknown quantization scheme: {quantization}") + + +def _rowwise_offset_bytes(numel: int, quantization: str) -> int: + if quantization == "nvfp4": + return numel // 2 + return numel + + +class TestGroupedTensor: + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_basic_construction_all_same_shape(self) -> None: + """Test GroupedTensor construction with all tensors having same shape""" + num_tensors = 4 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert grouped_tensor.all_same_shape() + assert grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert 
grouped_tensor.logical_shape == (num_tensors * 256, 512) + assert grouped_tensor.get_common_first_dim() == 256 + assert grouped_tensor.get_common_last_dim() == 512 + assert grouped_tensor.has_data() + + def test_basic_construction_varying_first_dim(self) -> None: + """Test GroupedTensor construction with varying first dimension""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert not grouped_tensor.all_same_shape() + assert not grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert grouped_tensor.get_common_last_dim() == shape[0][1] + assert grouped_tensor.logical_shape == ( + sum(v for v, _ in shape), + shape[0][1], + ) # sum of first dims + + def test_split_into_quantized_tensors_no_quantization(self) -> None: + """Test split_into_quantized_tensors for unquantized tensors""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + # Get the original data pointer + original_data_ptr = grouped_tensor.data.data_ptr() + + # Split into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify each tensor has correct shape and shares storage + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + assert isinstance(tensor, torch.Tensor) + assert not hasattr(tensor, "_data") # Not a quantized tensor + + # Verify data pointer is within the original grouped tensor storage + # The tensor should be a view of the original data + assert tensor.data_ptr() >= original_data_ptr + + # Calculate expected offset + expected_offset = i * 
(shape[i][0] * shape[i][1]) * tensor.element_size() + assert tensor.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None: + """Test split_into_quantized_tensors for quantized tensors""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get the original data pointer + original_data_ptr = grouped_tensor.data.data_ptr() + + # Split into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify each tensor shares storage with the grouped tensor + for i, tensor in enumerate(tensors): + rowwise_data = _get_rowwise_data_tensor(tensor, quantization) + assert rowwise_data is not None + assert rowwise_data.data_ptr() >= original_data_ptr + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + def test_split_varying_shapes(self) -> None: + """Test split_into_quantized_tensors with varying shapes""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + original_data_ptr = grouped_tensor.data.data_ptr() + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify shapes and storage + cumulative_offset = 0 + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + expected_offset = cumulative_offset * tensor.element_size() + assert tensor.data_ptr() == 
original_data_ptr + expected_offset + cumulative_offset += shape[i][0] * shape[i][1] + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_inplace(self, quantization: str) -> None: + """Test that quantize is done in-place for all recipes""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get original data pointers before quantization + original_data_ptr = grouped_tensor.data.data_ptr() + original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr() + original_scale_ptr = ( + grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None + ) + + # Create input tensors + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize in place + quantized_tensors = grouped_tensor.quantize(input_tensors) + + # Verify data pointers haven't changed (in-place operation) + assert grouped_tensor.data.data_ptr() == original_data_ptr + assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr + if original_scale_ptr is not None: + assert grouped_tensor.scale.data_ptr() == original_scale_ptr + + # Verify returned tensors point to the same storage + for i, qtensor in enumerate(quantized_tensors): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_varying_shapes(self, quantization: str) -> None: + """Test quantize with varying shapes""" + num_tensors = 3 + shape = [(256, 512), (512, 512), (768, 512)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + 
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get original data pointers + original_data_ptr = grouped_tensor.data.data_ptr() + + # Create input tensors with varying shapes + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize in place + quantized_tensors = grouped_tensor.quantize(input_tensors) + + # Verify data pointer hasn't changed + assert grouped_tensor.data.data_ptr() == original_data_ptr + + # Verify each tensor points to correct location + cumulative_numel = 0 + for qtensor, tensor_shape in zip(quantized_tensors, shape): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + cumulative_numel += tensor_shape[0] * tensor_shape[1] + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_static_quantize_method(self, quantization: str) -> None: + """Test the static quantize method""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + # Create input tensors + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Use static quantize method + grouped_tensor = GroupedTensor.create_and_quantize( + tensors=input_tensors, + quantizer=quantizers, + device="cuda", + ) + + # Verify the grouped tensor was created correctly + assert grouped_tensor.num_tensors == num_tensors + assert grouped_tensor.has_data() + + # Verify quantized_tensors were created and point to same storage + assert grouped_tensor.quantized_tensors is not None + assert len(grouped_tensor.quantized_tensors) == num_tensors + + original_data_ptr = grouped_tensor.data.data_ptr() + for i, qtensor in enumerate(grouped_tensor.quantized_tensors): + rowwise_data 
= _get_rowwise_data_tensor(qtensor, quantization) + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize( + "shape", + [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], + ) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: + """Test grouped quantization for MXFP8 against per-tensor quantization.""" + # Test wont pass until the grouped quantization PR from Oleg is merged. + num_tensors = 2 + shape = [(512, 1024) for _ in range(num_tensors)] + + # Create BF16 input tensors and pack into a grouped tensor + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + quantizers = [MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) for _ in range(num_tensors)] + for q in quantizers: + q.optimize_for_gemm=True + quantized_tensors = [q(tensor) for q, tensor in zip(quantizers, input_tensors)] + grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizers=None, + device="cuda", + dtype=torch.bfloat16, + ) + + offset = 0 + for tensor in input_tensors: + numel = tensor.numel() + grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + # Create MXFP8 output grouped tensor (rowwise only for easier validation) + # quantizers = [MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) for _ in range(num_tensors)] + + grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizers=quantizers, + device="cuda", + ) + print(f">>>>>>>>>>>> tex.quantize_grouped") + print(f">>>>>>>>>>>> grouped_input: {grouped_input.data.shape if grouped_input.data is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.data.shape if 
grouped_output.data is not None else None}") + print(f">>>>>>>>>>>> grouped_input: {grouped_input.scale_inv.shape if grouped_input.scale_inv is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.scale_inv.shape if grouped_output.scale_inv is not None else None}") + print(f">>>>>>>>>>>> grouped_input: {grouped_input.columnwise_data.shape if grouped_input.columnwise_data is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_data.shape if grouped_output.columnwise_data is not None else None}") + print(f">>>>>>>>>>>> grouped_input: {grouped_input.columnwise_scale_inv.shape if grouped_input.columnwise_scale_inv is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_scale_inv.shape if grouped_output.columnwise_scale_inv is not None else None}") + # Quantize using grouped API (handle both 2-arg and 3-arg bindings) + _ = tex.quantize_grouped(grouped_input, grouped_output) + # Build expected output by quantizing each tensor independently + expected_data = [] + expected_scale_inv = [] + for tensor, quantizer in zip(input_tensors, quantizers): + qtensor = quantizer(tensor) + expected_data.append(qtensor._rowwise_data.reshape(-1)) + expected_scale_inv.append(qtensor._rowwise_scale_inv.reshape(-1)) + + expected_data = torch.cat(expected_data) + expected_scale_inv = torch.cat(expected_scale_inv) + + assert torch.equal(grouped_output.data, expected_data) + assert torch.equal(grouped_output.scale_inv, expected_scale_inv) + + def test_clear(self) -> None: + """Test clear method""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.has_data() + assert grouped_tensor.num_tensors == num_tensors + + grouped_tensor.clear() + + assert not grouped_tensor.has_data() + assert 
grouped_tensor.num_tensors == 0 + assert grouped_tensor.data is None + assert grouped_tensor.logical_shape == (0, 0) diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 582172a88e..624b0bfc7c 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -30,7 +30,7 @@ void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize); using namespace transformer_engine; - + printf(">>>>>>>>>>>> nvte_group_quantize\n"); constexpr bool IS_ACT = false; dispatch::group_quantize_fwd_helper(input, output, nullptr, stream); } diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index b83df1dedf..9a6e9b01d6 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -375,6 +375,7 @@ void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs, template void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + printf(">>>>>>>>>>>> group_quantize_fwd_helper\n"); using namespace detail; NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index df4317b547..35e605067d 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -244,6 +244,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, float *const __restrict__ dbias_workspace, 
float *const __restrict__ amax_ptr) { +printf(">>>>>>>>>>>> WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -852,6 +853,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations } const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; + printf(">>>>>>>>>>>> with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 70a68132ad..a3e7db94d1 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -55,6 +55,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float *noop, float *const dbias_workspace, float *const amax_ptr, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { +printf(">>>>>>>>>>>> non-grouped: WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ae41f238a4..8f3025a86a 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -449,6 +449,7 @@ enum NVTEGroupedTensorParam { kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */ 
kNVTEGroupedTensorOffsets = 9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */ + kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in format expected by GEMM */ kNVTENumGroupedTensorParams }; @@ -499,6 +500,25 @@ void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorP NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name); +/*! \brief Set a parameter of the grouped tensor. + * + * \param[in/out] tensor Grouped tensor. + * \param[in] param_name The parameter to be set. + * \param[in] param The value to be set (NVTEBasicTensor). + */ +void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name, + const void *buf, size_t size_in_bytes); + +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Get a value of the parameter of the grouped tensor. + * + * \param[in] tensor Grouped tensor. + * \param[in] param_name The parameter to be queried. + * + * \return NVTEBasicTensor containing the parameter data. + */ +void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name, void *buf, size_t size_in_bytes, size_t *size_written); + /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Get the number of tensors in a grouped tensor. * @@ -957,6 +977,222 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; +/*! \struct GroupedTensorWrapper + * \brief C++ wrapper for the NVTEGroupedTensor class. + */ + + class GroupedTensorWrapper { + public: + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * TE grouped tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. 
+ */ + GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} + + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : GroupedTensorWrapper(num_tensors, + nvte_make_shape(logical_shape.data(), logical_shape.size()), + scaling_mode) {} + + /*! \brief GroupedTensorWrapper destructor. */ + ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } + + GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; + GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; + + /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ + GroupedTensorWrapper(GroupedTensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing GroupedTensorWrapper. 
*/ + GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_grouped_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_grouped_tensor_param(&tensor_, param, &data); + return *this; + } + + template + GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedScale, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) 
noexcept { + return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); + } + + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) { + const auto val = static_cast(with_gemm_swizzled_scales); + nvte_set_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val)); + } + + // Parameter getters + NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { + return nvte_get_grouped_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { + return get_parameter(kNVTEGroupedRowwiseData); + } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedRowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseAmax); + } + + NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } + + NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } + + NVTEBasicTensor get_tensor_offsets() const noexcept { + return get_parameter(kNVTEGroupedTensorOffsets); + } + + bool get_with_gemm_swizzled_scales() const { + uint8_t val = 
0; + nvte_get_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val), nullptr); + return static_cast(val); + } + + /*! \brief Get an underlying NVTEGroupedTensor. + * + * \return NVTEGroupedTensor held by this GroupedTensorWrapper. + */ + NVTEGroupedTensor data() const noexcept { return tensor_; } + + /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ + size_t num_tensors() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_grouped_tensor_num_tensors(tensor_); + } + + /*! \brief Get the data type of this GroupedTensorWrapper. */ + DType dtype() const noexcept { + if (tensor_ == nullptr) return DType::kNumTypes; + return static_cast(nvte_grouped_tensor_type(tensor_)); + } + + /*! \brief Get a scaling mode of the grouped tensor. */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_grouped_tensor_scaling_mode(tensor_); + } + + /*! \brief Get the logical shape of this GroupedTensorWrapper. */ + const NVTEShape logical_shape() const noexcept { + if (tensor_ == nullptr) { + return emptyShape; + } + return nvte_get_grouped_tensor_logical_shape(tensor_); + } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = { + {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } + + /*! \brief Wrapped NVTEGroupedTensor. */ + NVTEGroupedTensor tensor_ = nullptr; + }; + /*! 
\warning Deprecated */ enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 64ee2a5a16..da1bf03b02 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -87,32 +87,38 @@ class Recipe: """ Base recipe class. """ - - def nvfp4(self): + @classmethod + def nvfp4(cls): """Whether the given recipe is NVFP4 1D block scaling.""" - return isinstance(self, NVFP4BlockScaling) + return issubclass(cls, NVFP4BlockScaling) - def mxfp8(self): + @classmethod + def mxfp8(cls): """Whether the given recipe is MXFP8 block scaling.""" - return isinstance(self, MXFP8BlockScaling) + return issubclass(cls, MXFP8BlockScaling) - def delayed(self): + @classmethod + def delayed(cls): """Whether the given recipe is delayed scaling.""" - return isinstance(self, DelayedScaling) + return issubclass(cls, DelayedScaling) - def float8_current_scaling(self): + @classmethod + def float8_current_scaling(cls): """Whether the given recipe is (per-tensor) current scaling.""" - return isinstance(self, Float8CurrentScaling) + return issubclass(cls, Float8CurrentScaling) - def float8_per_tensor_scaling(self): + @classmethod + def float8_per_tensor_scaling(cls): """Whether the given recipe is per-tensor scaling.""" - return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + return issubclass(cls, (DelayedScaling, Float8CurrentScaling)) - def float8_block_scaling(self): + @classmethod + def float8_block_scaling(cls): """Whether the given recipe is float8 blockwise scaling.""" - return isinstance(self, Float8BlockScaling) + return issubclass(cls, Float8BlockScaling) - def custom(self): + @classmethod + def custom(cls): """Whether the given recipe is custom.""" - return isinstance(self, CustomRecipe) + return issubclass(cls, CustomRecipe) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index 06971443dd..d0d6b533c8 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1268,3 +1268,117 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); return t.logical_shape; } + +void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, const void *buf, + size_t size_in_bytes) { +// Check attribute and buffer +NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", static_cast(param), +")"); +NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL."); +auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); + +// Read from buffer +switch (param) { +case kNVTEGroupedRowwiseData: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.data = *basic_tensor; +break; +} +case kNVTEGroupedColumnwiseData: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.columnwise_data = *basic_tensor; +break; +} +case kNVTEGroupedScale: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.scale = *basic_tensor; +break; +} +case kNVTEGroupedAmax: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.amax = *basic_tensor; +break; +} +case kNVTEGroupedRowwiseScaleInv: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.scale_inv = *basic_tensor; +break; +} +case kNVTEGroupedColumnwiseScaleInv: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.columnwise_scale_inv = *basic_tensor; +break; +} +case kNVTEGroupedColumnwiseAmax: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.columnwise_amax = *basic_tensor; +break; +} +case kNVTEGroupedWithGEMMSwizzledScales: +t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); +break; +default: +NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); +} 
+} + +void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, void *buf, + size_t size_in_bytes, size_t *size_written) { +using namespace transformer_engine; + +// Check param +NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", static_cast(param), +")"); + +// Return immediately if buffer is not provided +if (buf == nullptr) { +return; +} + +// Get C++ tensor +const GroupedTensor *t = convertNVTEGroupedTensor(tensor); + +// Write to buffer +switch (param) { +case kNVTEGroupedRowwiseData: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->data); +break; +} +case kNVTEGroupedColumnwiseData: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->columnwise_data); +break; +} +case kNVTEGroupedScale: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->scale); +break; +} +case kNVTEGroupedAmax: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->amax); +break; +} +case kNVTEGroupedRowwiseScaleInv: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->scale_inv); +break; +} +case kNVTEGroupedColumnwiseScaleInv: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->columnwise_scale_inv); +break; +} +case kNVTEGroupedColumnwiseAmax: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->columnwise_amax); +break; +} +case kNVTEGroupedWithGEMMSwizzledScales: +*reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); +break; +default: +NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); +} +} diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index a957976235..78083c0b0b 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2188,6 +2188,11 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): src_nominal_dtype = q.dtype print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") if isinstance(qkv_quantizer, MXFP8Quantizer): + # bs3hd -> bshd_bshd_bhsd + q,k,v = [x.contiguous() for x in [q, k, v]] + + # bshd_bshd_bhsd -> bhsd_bhsd_bhsd + qkv_quantizer.optimize_for_gemm = True qkv_quantizer._internal = False dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") dim_others = [i for i in range(len(v.shape)) if i != dim_s] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f7cf32eaf6..d91ec308fa 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -250,6 +250,8 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob py::object dequantize(const py::handle &input, DType otype); +py::object quantize_grouped(const py::handle &input, py::handle &output); + std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 5c9d0f5b07..34565bcf44 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -80,6 +80,19 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob return output_py; } +py::object quantize_grouped(const py::handle &input, py::handle &output) { + using namespace transformer_engine::pytorch::detail; + init_extension(); + printf(">>>>>>>>>>>> quantize_grouped\n"); + const auto &grouped_input_tensor = GroupedTensorFromPyTorchGroupedTensor(input); + const auto &grouped_output_tensor = 
GroupedTensorFromPyTorchGroupedTensor(output); + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_quantize(grouped_input_tensor.data(), grouped_output_tensor.data(), at::cuda::getCurrentCUDAStream()); + }); + + return py::reinterpret_borrow(output); +} + py::object dequantize(const py::handle &input, transformer_engine::DType otype) { init_extension(); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 79dd9ea5ce..12abd503cf 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -122,6 +122,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); + m.def("quantize_grouped", transformer_engine::pytorch::quantize_grouped, "Quantize grouped tensor", + py::arg("input"), + py::arg("output")); + m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 25ffef0588..9541409c0c 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -95,6 +95,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 3f998bb66f..8ab8dc1d48 100644 --- 
a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -170,6 +170,127 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +NVTEScalingMode ScalingModeFromQuantizer(py::handle quantizer) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return NVTE_MXFP8_1D_SCALING; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return NVTE_NVFP4_1D_SCALING; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + const int block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); + return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; + } + return NVTE_DELAYED_TENSOR_SCALING; +} + +DType GetTransformerEngineDTypeForScaleInv(py::handle quantizer, at::Tensor scale_inv) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return DType::kFloat8E8M0; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + return DType::kFloat32; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return DType::kFloat8E4M3; + } + return GetTransformerEngineDType(scale_inv.scalar_type()); +} + +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { + // Returns a GroupedTensorWrapper from a PyTorch GroupedTensor. 
+ const auto num_tensors = tensor.attr("num_tensors").cast(); + const auto logical_shape = tensor.attr("logical_shape").cast>(); + py::handle quantizer = py::none(); + DType quantizer_dtype = DType::kNumTypes; + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + bool with_gemm_swizzled_scales = false; + if (!tensor.attr("quantizers").is_none()) { + const auto quantizers = tensor.attr("quantizers").cast(); + if (!quantizers.empty()) { + quantizer = quantizers[0]; + } + if (!quantizer.is_none()) { + scaling_mode = ScalingModeFromQuantizer(quantizer); + quantizer_dtype = quantizer.attr("dtype").cast(); + with_gemm_swizzled_scales = quantizer.attr("optimize_for_gemm").cast(); + printf(">>>>>>>>>>>> GroupedTensorFromPyTorchGroupedTensor with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); + } + } + auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); + + // Rowwise data + if (!tensor.attr("data").is_none()) { + const auto &data = tensor.attr("data").cast(); + DType data_dtype = + quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_rowwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Columnwise data + if (!tensor.attr("columnwise_data").is_none()) { + const auto &data = tensor.attr("columnwise_data").cast(); + DType data_dtype = + quantizer.is_none() ?
GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_columnwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Scale + if (!tensor.attr("scale").is_none()) { + const auto &scale = tensor.attr("scale").cast(); + ret.set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + } + + // Amax + if (!tensor.attr("amax").is_none()) { + const auto &amax = tensor.attr("amax").cast(); + ret.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + if (!tensor.attr("columnwise_amax").is_none()) { + const auto &amax = tensor.attr("columnwise_amax").cast(); + ret.set_columnwise_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + + // Scale inverse + if (!tensor.attr("scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("scale_inv").cast(); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + if (!tensor.attr("columnwise_scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("columnwise_scale_inv").cast(); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + + // Shape metadata + if (!tensor.attr("first_dims").is_none()) { + const auto &first_dims = tensor.attr("first_dims").cast(); + ret.set_first_dims(first_dims.data_ptr(), GetTransformerEngineDType(first_dims.scalar_type()), + getTensorShape(first_dims)); + } + if (!tensor.attr("last_dims").is_none()) { + const auto &last_dims = tensor.attr("last_dims").cast(); + ret.set_last_dims(last_dims.data_ptr(), GetTransformerEngineDType(last_dims.scalar_type()), + getTensorShape(last_dims)); + } + if (!tensor.attr("tensor_offsets").is_none()) { + const auto &tensor_offsets = tensor.attr("tensor_offsets").cast(); + 
ret.set_tensor_offsets(tensor_offsets.data_ptr(), + GetTransformerEngineDType(tensor_offsets.scalar_type()), + getTensorShape(tensor_offsets)); + } + + ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 58d095a4f4..a283b43908 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -165,6 +165,49 @@ def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass + def get_scale_shape( + self, + shape: Iterable[int], + columnwise: bool, + ) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. + + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For MXFP8 1D blockwise quantization, blocksize is 32 + Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. 
+ CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + if columnwise: + # Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]] + # with padding to multiples of [4, 128] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + ) + # Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE] + # with padding to multiples of [128, 4] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + ) + + def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for MXFP8 1D blockwise quantization.""" + return rowwise_data_shape + def create_tensor_from_data( self, data: torch.Tensor, @@ -705,7 +748,7 @@ def fsdp_post_all_gather( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=fp8_dtype, dtype=param_dtype, - shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, + shape=(rowwise_data.shape if rowwise_data is not None else columnwise_data.shape), quantizer=self._quantizer, with_gemm_swizzled_scales=False, ) diff --git a/transformer_engine/pytorch/tensor/storage/__init__.py b/transformer_engine/pytorch/tensor/storage/__init__.py index d7a2719200..54ed5caa60 100644 --- a/transformer_engine/pytorch/tensor/storage/__init__.py +++ b/transformer_engine/pytorch/tensor/storage/__init__.py @@ -7,3 +7,4 @@ from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401 from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401 from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401 +from .grouped_tensor import GroupedTensor # noqa: F401 \ No newline at end of file diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py 
b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py new file mode 100644 index 0000000000..ad85a448e6 --- /dev/null +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -0,0 +1,964 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Grouped tensor class for handling collections of tensors with different shapes""" +from __future__ import annotations +from typing import Optional, Tuple, List, Union +import math + +import torch + +from ...quantized_tensor import QuantizedTensorStorage, Quantizer + +from ..mxfp8_tensor import MXFP8Tensor +from ..nvfp4_tensor import NVFP4Tensor +from ..float8_tensor import Float8Tensor +from ..float8_blockwise_tensor import Float8BlockwiseQTensor +from .float8_tensor_storage import Float8TensorStorage +from .mxfp8_tensor_storage import MXFP8TensorStorage +from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from .nvfp4_tensor_storage import NVFP4TensorStorage + + +class GroupedTensor: + """ + EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. + + Grouped tensor is a collection of tensors with different shapes but the same dtype and scaling mode. + + Shape Representation: + - logical_shape: 2D shape representing the conceptual layout, i.e. 
the shape when member tensors + are flattened to 2D and stacked together (REQUIRED) + + When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N) + + When varying_first_dim(): [~sum_of_first_dims, N] where N is common + + When varying_last_dim(): [M, ~sum_of_last_dims] where M is common + + When varying_both_dims(): [1, total_elements] (fully flattened) + + - first_dims and last_dims are OPTIONAL (None if dimension is uniform) + + None first_dims: all tensors have the same first dimension + + None last_dims: all tensors have the same last dimension + + Both None: all tensors have identical shapes + + Both set: each tensor has unique shape (first_dims[i], last_dims[i]) + + Data Layout: + - ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.) + - logical_shape provides the conceptual 2D interpretation + - All data is stored on device in contiguous layout + + Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode. + """ + + def __init__( + self, + num_tensors: int, + shape: List[Tuple[int, int]], + quantizers: Optional[List[Quantizer]] = None, + dtype: Optional[torch.dtype] = None, + data: Optional[torch.Tensor] = None, + columnwise_data: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + columnwise_scale_inv: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + columnwise_amax: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + first_dims: Optional[torch.Tensor] = None, + last_dims: Optional[torch.Tensor] = None, + tensor_offsets: Optional[torch.Tensor] = None, + offsets: Optional[List[int]] = None, + scale_inv_offsets: Optional[List[int]] = None, + columnwise_scale_inv_offsets: Optional[List[int]] = None, + logical_shape: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Initialize a GroupedTensor. 
+ + Args: + num_tensors: Number of tensors in the group + shape: 2D shape of each tensor (len num_tensors) + quantizers: List of Quantizers for the grouped tensor + data: Row-wise data buffer (1D flattened) + columnwise_data: Column-wise data buffer (1D flattened) + scale_inv: Row-wise scale inverse buffer + columnwise_scale_inv: Column-wise scale inverse buffer + amax: Row-wise amax buffer + columnwise_amax: Column-wise amax buffer + scale: Scale buffer (for FP8-DS only) + first_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform) + offsets: Vector of integer offsets for each tensor. + logical_shape: 2D tuple representing conceptual shape + """ + print(f">>>>>>>>>>>> GroupedTensor init: {quantizers}") + print(f">>>>>>>>>>>> shape: {shape}") + print(f">>>>>>>>>>>> dtype: {dtype}") + print(f">>>>>>>>>>>> data: {data.shape if data is not None else None}") + print(f">>>>>>>>>>>> columnwise_data: {columnwise_data.shape if columnwise_data is not None else None}") + print(f">>>>>>>>>>>> scale_inv: {scale_inv.shape if scale_inv is not None else None}") + print(f">>>>>>>>>>>> columnwise_scale_inv: {columnwise_scale_inv.shape if columnwise_scale_inv is not None else None}") + print(f">>>>>>>>>>>> amax: {amax.shape if amax is not None else None}") + print(f">>>>>>>>>>>> columnwise_amax: {columnwise_amax.shape if columnwise_amax is not None else None}") + print(f">>>>>>>>>>>> scale: {scale.shape if scale is not None else None}") + print(f">>>>>>>>>>>> first_dims: {first_dims}") + print(f">>>>>>>>>>>> last_dims: {last_dims}") + print(f">>>>>>>>>>>> tensor_offsets: {tensor_offsets}") + print(f">>>>>>>>>>>> offsets: {offsets}") + print(f">>>>>>>>>>>> scale_inv_offsets: {scale_inv_offsets}") + print(f">>>>>>>>>>>> columnwise_scale_inv_offsets: {columnwise_scale_inv_offsets}") + print(f">>>>>>>>>>>> logical_shape:
{logical_shape}") + print(f">>>>>>>>>>>> num_tensors: {num_tensors}") + + self.num_tensors = num_tensors + self.quantizers = quantizers + self.shape = shape + self.dtype = ( + dtype if dtype is not None else torch.float32 + ) # Default to float32 if not provided + + # Data buffers + self.data = data + self.columnwise_data = columnwise_data + self.scale_inv = scale_inv + self.columnwise_scale_inv = columnwise_scale_inv + self.amax = amax + self.columnwise_amax = columnwise_amax + self.scale = scale + + # For convenient indexing for python GroupedTensor API. + self.scale_inv_offsets = scale_inv_offsets + self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets + + # Shape information (OPTIONAL - None if dimension is uniform across all tensors) + # first_dims[i] = first dimension of tensor i (None if all tensors have same first dim) + # last_dims[i] = last dimension of tensor i (None if all tensors have same last dim) + self.first_dims = ( + first_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + self.last_dims = ( + last_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + + # Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape()) + # tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1) + # Usage: tensor_i_ptr = data.data_ptr() + tensor_offsets[i] * element_size + # If None and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions) + self.tensor_offsets = ( + tensor_offsets # Device pointer to int64_t array of length num_tensors (or None) + ) + self.offsets = offsets # Vector of integer offsets for each tensor. 
+ + # Logical shape: conceptual 2D shape of the grouped data (REQUIRED) + # Represents how the 1D flattened data should be interpreted as 2D + # Always 2D with positive dimensions + self.logical_shape = logical_shape if logical_shape is not None else (0, 0) + + # Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor. + # Used as a convenience. + self.quantized_tensors = None + + def has_data(self) -> bool: + """ + Check if the tensor has row-wise data. + + Returns: + True if data buffer is initialized, False otherwise + """ + return self.data is not None + + def has_columnwise_data(self) -> bool: + """ + Check if the tensor has column-wise data. + + Returns: + True if columnwise_data buffer is initialized, False otherwise + """ + return self.columnwise_data is not None + + def all_same_first_dim(self) -> bool: + """ + Check if all tensors in the group have the same first dimension. + + Returns: + True if first dimension is uniform across all tensors + """ + return self.first_dims is None + + def all_same_last_dim(self) -> bool: + """ + Check if all tensors in the group have the same last dimension. + + Returns: + True if last dimension is uniform across all tensors + """ + return self.last_dims is None + + def all_same_shape(self) -> bool: + """ + Check if all tensors in the group have identical shapes. + + Returns: + True if all tensors have the same shape + """ + return self.first_dims is None and self.last_dims is None + + def varying_both_dims(self) -> bool: + """ + Check if both dimensions vary across tensors. + + Returns: + True if both first and last dimensions vary + """ + return self.first_dims is not None and self.last_dims is not None + + def get_common_first_dim(self) -> int: + """ + Get the common first dimension when all tensors share it. 
+ + Returns: + The common first dimension + + Raises: + RuntimeError: If first dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_first_dim(): + raise RuntimeError("First dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + if self.all_same_shape(): + # When both dims are uniform: logical_shape = [num_tensors * M, N] + return self.logical_shape[0] // self.num_tensors + # When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims] + return self.logical_shape[0] + + def get_common_last_dim(self) -> int: + """ + Get the common last dimension when all tensors share it. + + Returns: + The common last dimension + + Raises: + RuntimeError: If last dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_last_dim(): + raise RuntimeError("Last dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + # For both uniform and varying first dim cases: logical_shape[1] is the common last dim + return self.logical_shape[1] + + def get_dtype(self) -> torch.dtype: + """ + Get the high precision data type of the tensor. + + Returns: + The high precision dtype of the data buffer + """ + + return self.dtype + + def clear(self) -> None: + """ + Reset tensor data and clear all buffers. 
+ """ + self.data = None + self.columnwise_data = None + self.scale_inv = None + self.columnwise_scale_inv = None + self.amax = None + self.columnwise_amax = None + self.scale = None + self.first_dims = None + self.last_dims = None + self.tensor_offsets = None + self.logical_shape = (0, 0) + self.num_tensors = 0 + self.quantizers = None + self.quantized_tensors = None + self.offsets = None + self.scale_inv_offsets = None + self.columnwise_scale_inv_offsets = None + + def __repr__(self) -> str: + """String representation of the GroupedTensor.""" + return ( + f"GroupedTensor(num_tensors={self.num_tensors}, " + f"shape={self.shape}, " + f"logical_shape={self.logical_shape}, " + f"dtype={self.get_dtype()})" + ) + + def __str__(self) -> str: + """User-friendly string representation.""" + shape_info = [] + if self.all_same_shape(): + shape_info.append("uniform shape") + else: + if not self.all_same_first_dim(): + shape_info.append("varying first dim") + if not self.all_same_last_dim(): + shape_info.append("varying last dim") + + return ( + f"GroupedTensor with {self.num_tensors} tensors " + f"({', '.join(shape_info) if shape_info else 'uniform'}), " + f"logical_shape={self.logical_shape}, " + f"dtype={self.get_dtype()}" + ) + + @staticmethod + def make_grouped_tensor_with_shapes( + num_tensors: int, + shape: List[Tuple[int, int]], + quantizers: Optional[List[Quantizer]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> GroupedTensor: + """ + Create a GroupedTensor for storing multiple weight tensors of the same shape. + + Args: + num_tensors: Number of tensors + shape: 2D shape of each tensor (len num_tensors) + quantizers: List of Quantizers for each tensor + device: Device to allocate tensors on, defaults to current cuda device + dtype: Data type of the tensor (for high precision case) + + Returns: + A GroupedTensor. 
+ """ + + # First dim + first_dim_list = [s[0] for s in shape] + uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list) + logical_first_dim = sum(first_dim_list) + if uniform_first_dim: + first_dims = None + else: + first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device) + + # Last dim + last_dim_list = [s[1] for s in shape] + logical_last_dim = last_dim_list[0] + assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform" + + return GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=first_dims, + last_dims=None, + logical_first_dim=logical_first_dim, + logical_last_dim=logical_last_dim, + quantizers=quantizers, + device=device, + dtype=dtype, + ) + + @staticmethod + def make_grouped_tensor( + num_tensors: int, + first_dims: Optional[torch.Tensor], + last_dims: Optional[torch.tensor], + logical_first_dim: int, + logical_last_dim: int, + quantizers: Optional[List[Quantizer]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> GroupedTensor: + """ + Create a GroupedTensor for storing multiple weight tensors of the same shape. + + Args: + num_tensors: Number of tensors + first_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + logical_first_dim: Logical first dimension + logical_last_dim: Logical last dimension + quantizers: List of Quantizers for each tensor + Used to figure out the recipe and what to allocate. + device: Device to allocate tensors on, defaults to current cuda device + dtype: Data type of the tensor (for high precision case) + + Returns: + A GroupedTensor. + """ + + # Set device + if device is None: + device = torch.cuda.current_device() + + # Shape patterns and validation. 
+ all_same_first = first_dims is None + all_same_last = last_dims is None + + assert all_same_last, "Last dim must be uniform for GroupedTensor" + assert logical_first_dim > 0, "Logical first dim must be positive for GroupedTensor" + assert logical_last_dim > 0, "Logical last dim must be positive for GroupedTensor" + + # assert ( + # logical_first_dim % 128 == 0 + # ), "Logical first dim must be divisible by 128" + # assert logical_last_dim % 128 == 0, "Logical last dim must be divisible by 128" + + # Calculate tensor offsets (cumulative element offsets) + tensor_offsets = None + offsets = None + shape = [] + if not all_same_first: + # Need explicit offsets for non-uniform shapes + # Offsets are based on number of elements and not pointers. + # Kernels need to calculate precise pointers based on size of elements. + + # TODO(ksivaman): Single kernel + remove the host offset calculation. + tensor_offsets = torch.cat( + [ + torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype), + torch.cumsum(first_dims * logical_last_dim, dim=0), + ] + ) + offsets = tensor_offsets.tolist() + first_dims_list = first_dims.tolist() + for i in range(num_tensors): + shape.append((first_dims_list[i], logical_last_dim)) + else: + offsets = [ + i * logical_first_dim * logical_last_dim // num_tensors + for i in range(num_tensors + 1) + ] + for i in range(num_tensors): + shape.append((logical_first_dim // num_tensors, logical_last_dim)) + + # Calculate logical shape based + logical_shape = (logical_first_dim, logical_last_dim) + + quantizer = quantizers[0] if isinstance(quantizers, list) else quantizers + print(f">>>>>>>>>>>>> quantizers: {quantizers}") + print(f">>>>>>>>>>>>> quantizer: {quantizer}") + no_quantization = quantizer is None + + rowwise_usage = quantizer.rowwise_usage if not no_quantization else True + columnwise_usage = quantizer.columnwise_usage if not no_quantization else False + + # Calculate total elements across all tensors + total_elements = logical_first_dim * 
logical_last_dim + + data = None + columnwise_data = None + scale_inv = None + columnwise_scale_inv = None + amax = None + columnwise_amax = None + scale = None + scale_inv_offsets = None + columnwise_scale_inv_offsets = None + if no_quantization: + assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=dtype, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=dtype, device=device) + elif quantizer._get_compatible_recipe().mxfp8(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse buffer for MXFP8 - complex shape based on block scaling + # For grouped tensors, we need to calculate scale_inv size for all tensors + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_elements = math.prod(scale_inv_shape) + total_scale_elements += scale_elements + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse buffer + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + columnwise_scale_elements = math.prod(scale_inv_shape) + total_columnwise_scale_elements += columnwise_scale_elements + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + 
total_columnwise_scale_elements, dtype=torch.uint8, device=device + ) + elif quantizer._get_compatible_recipe().delayed(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - one per tensor + scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + scale_inv_offsets = list(range(num_tensors)) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse - one per tensor + columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + columnwise_scale_inv_offsets = list(range(num_tensors)) + + # Amax buffer for delayed scaling - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizer._get_compatible_recipe().nvfp4(): + + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte) + data = torch.empty((total_elements) // 2, dtype=torch.uint8, device=device) + # Scale inverse buffer for NVFP4 - complex shape based on block scaling + # For simplicity, calculate total scale elements needed + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) + # Amax buffer - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed) + columnwise_data = 
torch.empty( + (total_elements) // 2, dtype=torch.uint8, device=device + ) + # Columnwise scale inverse buffer + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.uint8, device=device + ) + # Columnwise amax buffer - one per tensor + columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizer._get_compatible_recipe().float8_block_scaling(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - size depends on block configuration + # For simplicity, calculate total scale elements needed + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, 
dtype=torch.float32, device=device + ) + elif quantizer._get_compatible_recipe().float8_current_scaling(): + # Current scaling - per-tensor scaling computed on the fly + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - one per tensor + scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + scale_inv_offsets = list(range(num_tensors)) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse - one per tensor + columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + columnwise_scale_inv_offsets = list(range(num_tensors)) + + # Scale and amax buffers for current scaling - one per tensor + scale = torch.empty(num_tensors, dtype=torch.float32, device=device) + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + else: + raise ValueError(f"Unsupported quantizer for GroupedTensor: {quantizer}") + + grouped_tensor = GroupedTensor( + num_tensors=num_tensors, + shape=shape, + dtype=dtype, + quantizers=quantizers, + data=data, + columnwise_data=columnwise_data, + scale_inv=scale_inv, + columnwise_scale_inv=columnwise_scale_inv, + amax=amax, + columnwise_amax=columnwise_amax, + scale=scale, + first_dims=first_dims, + last_dims=last_dims, + tensor_offsets=tensor_offsets, + offsets=offsets, + scale_inv_offsets=scale_inv_offsets, + columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, + logical_shape=logical_shape, + ) + + # grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() + return grouped_tensor + + def split_into_quantized_tensors( + self, + ) -> List[Union[QuantizedTensorStorage, 
torch.Tensor]]: + """ + Split the GroupedTensor into a list of `num_tensors` + quantized tensors based on the quantizer. No additional memory allocation is performed, + so the tensors returned are the same as the ones used to create the GroupedTensor. + + If quantizer is None, returns normal torch tensors. + If quantizer.internal is True, returns QuantizedTensorStorage. + Otherwise, returns QuantizedTensor. + + TODO(ksivaman): Block cases where any dims are varying. This is needed only + to expose the weights as separate parameters. + """ + + result = [] + + no_quantization = self.quantizers is None + + # Case 1: No quantization - return regular torch tensors + if no_quantization: + for i in range(self.num_tensors): + # Get tensor shape + tensor_shape = self.shape[i] + + # Get tensor data slice + if self.offsets is not None: + start_offset = self.offsets[i] + numel = tensor_shape[0] * tensor_shape[1] + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + else: + # All same shape case + numel = tensor_shape[0] * tensor_shape[1] + start_offset = i * numel + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + + return result + + # Case 2: Quantized tensors + recipe = self.quantizers[0]._get_compatible_recipe() + + for i in range(self.num_tensors): + # Get tensor shape + tensor_shape = self.shape[i] + numel = tensor_shape[0] * 
tensor_shape[1] + + # Get data offsets + if self.offsets is not None: + data_start = self.offsets[i] + data_end = data_start + numel + else: + # All same shape + data_start = i * numel + data_end = data_start + numel + + # Special shape handling for NVFP4. + nvfp4 = self.quantizers[0]._get_compatible_recipe().nvfp4() + if nvfp4: + data_start = data_start // 2 + data_end = data_end // 2 + + # Extract rowwise and columnwise data + rowwise_data = None + columnwise_data = None + + if self.has_data(): + if nvfp4: + rowwise_tensor_shape = self.quantizers[0].convert_shape_for_fp4(tensor_shape) + else: + rowwise_tensor_shape = tensor_shape + rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape) + + if self.has_columnwise_data(): + columnwise_tensor_shape = self.quantizers[0].get_columnwise_shape(tensor_shape) + if nvfp4: + columnwise_tensor_shape = self.quantizers[0].convert_shape_for_fp4( + columnwise_tensor_shape + ) + columnwise_data = self.columnwise_data[data_start:data_end].view( + columnwise_tensor_shape + ) + + # MXFP8 format + if recipe.mxfp8(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Calculate expected scale shape for MXFP8 + scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, True) + columnwise_scale_inv 
= self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + if self.quantizers[0].internal: + mxfp8_tensor_class = MXFP8TensorStorage + else: + mxfp8_tensor_class = MXFP8Tensor + tensor = mxfp8_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=self.quantizers[0].dtype, + quantizer=self.quantizers[0], + with_gemm_swizzled_scales=self.quantizers[0].optimize_for_gemm, + ) + result.append(tensor) + + # Delayed scaling or current scaling (both use Float8TensorStorage) + elif recipe.delayed() or recipe.float8_current_scaling(): + # Scale inverse - one per tensor + scale_inv = None + if self.scale_inv is not None: + scale_inv = self.scale_inv[i : i + 1] + + if self.quantizers[0].internal: + float8_tensor_class = Float8TensorStorage + else: + float8_tensor_class = Float8Tensor + + tensor = float8_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + data=rowwise_data, + fp8_scale_inv=scale_inv, + fp8_dtype=self.quantizers[0].dtype, + quantizer=self.quantizers[0], + data_transpose=columnwise_data, + ) + result.append(tensor) + + # Float8 block scaling + elif recipe.float8_block_scaling(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Get scale shape from quantizer + scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end 
= self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + # Get columnwise scale shape from quantizer + cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + # Compute is_2D_scaled and data_format from quantizer attributes + is_2D_scaled = self.quantizers[0].block_scaling_dim == 2 + + if self.quantizers[0].internal: + float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage + else: + float8_blockwise_q_tensor_class = Float8BlockwiseQTensor + + tensor = float8_blockwise_q_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=self.quantizers[0].dtype, + quantizer=self.quantizers[0], + is_2D_scaled=is_2D_scaled, + ) + result.append(tensor) + + # NVFP4 format + elif recipe.nvfp4(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + amax_rowwise = None + amax_columnwise = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Get scale shape from quantizer + scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + # Get columnwise scale shape from quantizer + cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, 
True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + # Extract amax - one per tensor + if self.amax is not None: + amax_rowwise = self.amax[i : i + 1] + + if self.columnwise_amax is not None: + amax_columnwise = self.columnwise_amax[i : i + 1] + + if self.quantizers[0].internal: + nvfp4_tensor_class = NVFP4TensorStorage + else: + nvfp4_tensor_class = NVFP4Tensor + + tensor = nvfp4_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + fp4_dtype=self.quantizers[0].dtype, + quantizer=self.quantizers[0], + with_gemm_swizzled_scales=self.quantizers[0].optimize_for_gemm, + ) + result.append(tensor) + + else: + raise ValueError(f"Unsupported quantization recipe: {recipe}") + + return result + + @staticmethod + def create_and_quantize( + tensors: int, + quantizers: None | List[Quantizer], + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + noop_flag: Optional[torch.Tensor] = None, + ) -> Tuple[QuantizedTensorStorage, ...]: + """ + Quantize given tensors into quantized tensors with underlying + storage allocated in a GroupedTensor. + """ + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=len(tensors), + shape=[t.shape for t in tensors], + quantizers=quantizers, + device=device, + dtype=dtype, + ) + + grouped_tensor.quantize(tensors, noop_flag=noop_flag) + + return grouped_tensor + + def quantize( + self, + tensors: List[torch.Tensor], + noop_flag: Optional[torch.Tensor] = None, + ) -> Tuple[QuantizedTensorStorage, ...]: + """ + Quantize the GroupedTensor inplace. 
+ """ + + quantized_tensors = self.split_into_quantized_tensors() + for i in range(self.num_tensors): + self.quantizers[0].update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag) + return quantized_tensors \ No newline at end of file From 46f2eb10fe8780b7d1de524758c58eda07a0e06b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 6 Feb 2026 19:44:22 -0800 Subject: [PATCH 007/172] quantize mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention.py | 5 +- .../attention/dot_product_attention/utils.py | 132 ++++++++++++++++-- .../pytorch/tensor/storage/grouped_tensor.py | 36 ++--- 3 files changed, 141 insertions(+), 32 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 55553d30be..eb905d7b93 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -583,8 +583,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # global recipe set in autocast() fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_recipe.custom(): - return + print(f"fp8_recipe: {fp8_recipe}") + # if fp8_recipe.custom(): + # return # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to # a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well. 
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 78083c0b0b..76f28f449d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -40,7 +40,8 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -2192,24 +2193,131 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): q,k,v = [x.contiguous() for x in [q, k, v]] # bshd_bshd_bhsd -> bhsd_bhsd_bhsd + # thd_thd_thd -> htd_htd_htd qkv_quantizer.optimize_for_gemm = True qkv_quantizer._internal = False dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") - dim_others = [i for i in range(len(v.shape)) if i != dim_s] - perm = [*dim_others, dim_s] - # perm = [*dim_others[:-1], dim_s, dim_others[-1]] - v = v.permute(*perm).contiguous() - qkv_layout = "bshd_bshd_bhds" - inv = [0] * len(perm) - for i, p in enumerate(perm): - inv[p] = i - # v = v.permute(*inv) - q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv).contiguous() + def permute_x(x): + dim_others = [i for i in range(len(x.shape)) if i != dim_s] + perm = [*dim_others[:-1], dim_s, dim_others[-1]] + x = x.permute(*perm).contiguous() + return x + q, k, v = [permute_x(x) for x in [q, k, v]] + # consider bhsd for now + batch_size, num_heads = q.shape[0], q.shape[1] + seq_len, head_dim = q.shape[-2], q.shape[-1] + num_tensors = 3 * batch_size * num_heads + # qkv = torch.cat([q, k, v], 
dim=0).reshape(num_tensors, seq_len, head_dim) + # qkv_list = [qkv[i] for i in range(num_tensors)] + # print(f">>>>>>>>>>>> num_tensors: {num_tensors}") + shapes = [(seq_len, head_dim) for _ in range(num_tensors)] + quantizers = [qkv_quantizer] * num_tensors + grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shapes, + quantizers=None, + device="cuda", + dtype=src_nominal_dtype, + ) + offset = 0 + for x in [q, k, v]: + numel = x.numel() + grouped_input.data[offset : offset + numel].copy_(x.reshape(-1)) + offset += numel + grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shapes, + quantizers=quantizers, + device="cuda", + ) + _ = tex.quantize_grouped(grouped_input, grouped_output) + print(f">>>>>>>>>>>> grouped_output: {grouped_output.data.shape if grouped_output.data is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.scale_inv.shape if grouped_output.scale_inv is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_data.shape if grouped_output.columnwise_data is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_scale_inv.shape if grouped_output.columnwise_scale_inv is not None else None}") + # grouped_output_list = [grouped_output[i] for i in range(num_tensors)] + # q_fp8, k_fp8, v_fp8 = grouped_output_list[0:32], grouped_output_list[32:64], grouped_output_list[64:] + # grouped_output.num_tensors = 3 + # grouped_output.quantizers = [qkv_quantizer] * 3 + # grouped_output.shape = [(batch_size * num_heads * seq_len, head_dim) for _ in range(3)] + # grouped_output.dtype = src_nominal_dtype + # grouped_output.data = torch.cat([x.reshape(-1) for x in [q, k, v]], dim=0) + # q_fp8, k_fp8, v_fp8 = grouped_output.split_into_quantized_tensors() + + def split_qkv(grouped_tensor, num_tensors): + rowwise_shape = q.shape + rowwise_scale_inv_shape = (*q.shape[:-1], q.shape[-1]//32) + 
columnwise_shape = q.shape + columnwise_scale_inv_shape = (*q.shape[:-2], q.shape[-2]//32, q.shape[-1]) + rowwise_data = grouped_tensor.data.view(num_tensors, *rowwise_shape).split([1] * num_tensors) + rowwise_scale_inv = grouped_tensor.scale_inv.view(num_tensors, *rowwise_scale_inv_shape).split([1] * num_tensors) + columnwise_data = grouped_tensor.columnwise_data.view(num_tensors, *columnwise_shape).split([1] * num_tensors) + columnwise_scale_inv = grouped_tensor.columnwise_scale_inv.view(num_tensors, *columnwise_scale_inv_shape).split([1] * num_tensors) + print(f">>>>>>>>>>>> rowwise_data: {len(rowwise_data)}, rowwise_scale_inv: {len(rowwise_scale_inv)}, columnwise_data: {len(columnwise_data)}, columnwise_scale_inv: {len(columnwise_scale_inv)}") + return [MXFP8Tensor( + shape=q.shape, + dtype=q.dtype, + rowwise_data=rowwise_data[i].squeeze(0), + rowwise_scale_inv=rowwise_scale_inv[i].squeeze(0), + columnwise_data=columnwise_data[i].squeeze(0), + columnwise_scale_inv=columnwise_scale_inv[i].squeeze(0), + fp8_dtype=qkv_quantizer.dtype, + quantizer=qkv_quantizer, + with_gemm_swizzled_scales=qkv_quantizer.optimize_for_gemm, + ) for i in range(num_tensors)] + q_fp8, k_fp8, v_fp8 = split_qkv(grouped_output, 3) + + print(f">>>>>>>>>>>> q_fp8: {q_fp8.shape}, k_fp8: {k_fp8.shape}, v_fp8: {v_fp8.shape}") + print(f">>>>>>>>>>>> rowwise_data: q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") + print(f">>>>>>>>>>>> rowwise_scale_inv: q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") + print(f">>>>>>>>>>>> columnwise_data: q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + print(f">>>>>>>>>>>> columnwise_scale_inv: q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + + # print(f">>>>>>>>>>>> 
grouped_output: {len(grouped_output) if grouped_output is not None else None}") + + # qkv_mxfp8 = grouped_tensor.quantize(qkv_list) + # print(f">>>>>>>>>>>> qkv_mxfp8: {type(qkv_mxfp8)}") + # qkv_mxfp8_list = [qkv_mxfp8[i] for i in range(num_tensors)] + # print(f">>>>>>>>>>>> qkv_mxfp8: {qkv_mxfp8}") + # print(f">>>>>>>>>>>> qkv_mxfp8: {len(qkv_mxfp8_list)}") + # print(f">>>>>>>>>>>> qkv_mxfp8.shape: {qkv_mxfp8.shape}") + # print(f">>>>>>>>>>>> qkv_mxfp8._rowwise_data.shape: {qkv_mxfp8._rowwise_data.shape}") + # print(f">>>>>>>>>>>> qkv_mxfp8._rowwise_scale_inv.shape: {qkv_mxfp8._rowwise_scale_inv.shape}") + # print(f">>>>>>>>>>>> qkv_mxfp8._columnwise_data.shape: {qkv_mxfp8._columnwise_data.shape}") + # print(f">>>>>>>>>>>> qkv_mxfp8._columnwise_scale_inv.shape: {qkv_mxfp8._columnwise_scale_inv.shape}") + # q_fp8, k_fp8, v_fp8 = qkv_mxfp8[0::batch_size * num_heads], qkv_mxfp8[batch_size:2*batch_size], qkv_mxfp8[2*batch_size:] + + # q_fp8, k_fp8, v_fp8 = qkv_mxfp8.split_into_quantized_tensors() + + print(f"q_fp8: {q_fp8.shape}, k_fp8: {k_fp8.shape}, v_fp8: {v_fp8.shape}") print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + + # q_fp8_rowwise, k_fp8_rowwise, v_fp8_rowwise = [x._rowwise_data for x in qkv_mxfp8] + # q_fp8_columnwise, k_fp8_columnwise, v_fp8_columnwise = [x._columnwise_data for x in qkv_mxfp8] + # q_fp8, k_fp8, v_fp8 = q_fp8_rowwise, k_fp8_rowwise, v_fp8_columnwise + + # dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") + # dim_others = [i for i in range(len(v.shape)) if 
i != dim_s] + # perm = [*dim_others, dim_s] + # # perm = [*dim_others[:-1], dim_s, dim_others[-1]] + # v = v.permute(*perm).contiguous() + + qkv_layout = "bhsd_bhsd_bhsd" + + # inv = [0] * len(perm) + # for i, p in enumerate(perm): + # inv[p] = i + # # v = v.permute(*inv) + + # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + # # v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv).contiguous() + # print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") + # print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") + # print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + # print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") return q_fp8, k_fp8, v_fp8, qkv_layout match qkv_group: diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index ad85a448e6..522c25f370 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -90,23 +90,23 @@ def __init__( offsets: Vector of integer offsets for each tensor. 
logical_shape: 2D tuple representing conceptual shape """ - print(f">>>>>>>>>>>> GroupedTensor init: {quantizers}") - print(f">>>>>>>>>>>> shape: {shape}") - print(f">>>>>>>>>>>> dtype: {dtype}") - print(f">>>>>>>>>>>> data: {data.shape}") - print(f">>>>>>>>>>>> columnwise_data: {columnwise_data.shape if columnwise_data is not None else None}") - print(f">>>>>>>>>>>> scale_inv: {scale_inv.shape if scale_inv is not None else None}") - print(f">>>>>>>>>>>> columnwise_scale_inv: {columnwise_scale_inv.shape if columnwise_scale_inv is not None else None}") - print(f">>>>>>>>>>>> amax: {amax.shape if amax is not None else None}") - print(f">>>>>>>>>>>> columnwise_amax: {columnwise_amax.shape if columnwise_amax is not None else None}") - print(f">>>>>>>>>>>> scale: {scale.shape if scale is not None else None}") - print(f">>>>>>>>>>>> first_dims: {first_dims}") - print(f">>>>>>>>>>>> last_dims: {last_dims}") - print(f">>>>>>>>>>>> tensor_offsets: {tensor_offsets}") - print(f">>>>>>>>>>>> offsets: {offsets}") - print(f">>>>>>>>>>>> scale_inv_offsets: {scale_inv_offsets}") - print(f">>>>>>>>>>>> columnwise_scale_inv_offsets: {columnwise_scale_inv_offsets}") - print(f">>>>>>>>>>>> logical_shape: {logical_shape}") + # print(f">>>>>>>>>>>> GroupedTensor init: {quantizers}") + # print(f">>>>>>>>>>>> shape: {shape}") + # print(f">>>>>>>>>>>> dtype: {dtype}") + # print(f">>>>>>>>>>>> data: {data.shape}") + # print(f">>>>>>>>>>>> columnwise_data: {columnwise_data.shape if columnwise_data is not None else None}") + # print(f">>>>>>>>>>>> scale_inv: {scale_inv.shape if scale_inv is not None else None}") + # print(f">>>>>>>>>>>> columnwise_scale_inv: {columnwise_scale_inv.shape if columnwise_scale_inv is not None else None}") + # print(f">>>>>>>>>>>> amax: {amax.shape if amax is not None else None}") + # print(f">>>>>>>>>>>> columnwise_amax: {columnwise_amax.shape if columnwise_amax is not None else None}") + # print(f">>>>>>>>>>>> scale: {scale.shape if scale is not None else None}") 
+ # print(f">>>>>>>>>>>> first_dims: {first_dims.shape}") + # print(f">>>>>>>>>>>> last_dims: {last_dims.shape}") + # print(f">>>>>>>>>>>> tensor_offsets: {tensor_offsets.shape}") + # print(f">>>>>>>>>>>> offsets: {offsets.shape}") + # print(f">>>>>>>>>>>> scale_inv_offsets: {scale_inv_offsets.shape}") + # print(f">>>>>>>>>>>> columnwise_scale_inv_offsets: {columnwise_scale_inv_offsets.shape}") + # print(f">>>>>>>>>>>> logical_shape: {logical_shape.shape}") print(f">>>>>>>>>>>> num_tensors: {num_tensors}") self.num_tensors = num_tensors @@ -434,7 +434,7 @@ def make_grouped_tensor( logical_shape = (logical_first_dim, logical_last_dim) quantizer = quantizers[0] if isinstance(quantizers, list) else quantizers - print(f">>>>>>>>>>>>> quantizers: {quantizers}") + # print(f">>>>>>>>>>>>> quantizers: {quantizers}") print(f">>>>>>>>>>>>> quantizer: {quantizer}") no_quantization = quantizer is None From e86207c7d7afe0c982b33633bb84c55c5fb01899 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 9 Feb 2026 19:53:25 -0800 Subject: [PATCH 008/172] fix shapes/strides Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 21 ++- .../cast/mxfp8/group_quantize_mxfp8.cuh | 5 +- .../common/cast/mxfp8/quantize_mxfp8.cuh | 5 +- .../common/fused_attn/fused_attn.cpp | 70 ++++---- .../common/fused_attn/fused_attn_fp8.cu | 13 +- transformer_engine/common/fused_attn/utils.cu | 36 ++-- .../include/transformer_engine/fused_attn.h | 10 +- .../common/util/pybind_helper.h | 4 +- .../dot_product_attention/backends.py | 14 ++ .../attention/dot_product_attention/utils.py | 130 ++++---------- .../pytorch/cpp_extensions/fused_attn.py | 4 +- .../pytorch/csrc/type_converters.cpp | 15 +- .../pytorch/tensor/mxfp8_tensor.py | 2 + .../pytorch/tensor/storage/grouped_tensor.py | 162 ++++++++++-------- 14 files changed, 229 insertions(+), 262 deletions(-) diff --git 
a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 0301f77ae8..5602114143 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2157,16 +2157,19 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( dtype, config, True, qkv_layout, is_training, fp8_recipe ) + # print(f">>>>>> fused_attn_bwd_fp8: {fused_attn_bwd_fp8} {is_training}") + # torch.save(fused_attn_fwd_fp8, "fused_attn_fwd_fp8.pt") - # os.environ["NVTE_FLASH_ATTN"] = "0" - # os.environ["NVTE_FUSED_ATTN"] = "1" - # os.environ["NVTE_UNFUSED_ATTN"] = "0" - # if config.dropout_p == 0.0: - # # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell - # logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") - # fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - # dtype, config, False, qkv_layout, is_training, fp8_recipe - # ) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + if config.dropout_p == 0.0: + # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") + fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + dtype, config, False, qkv_layout, is_training, fp8_recipe + ) + # torch.save(fused_attn_fwd_f16, "fused_attn_fwd_f16.pt") atol = 5e-1 rtol = 5e-2 diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 35e605067d..ea81e6c516 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -244,7 +244,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const int64_t *const __restrict__ last_dims_ptr, 
e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { -printf(">>>>>>>>>>>> WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); +// printf(">>>>>>>>>>>> WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -935,16 +935,19 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations OType, true, true, WITH_GEMM_SWIZZLED_SCALES>; switch (scaling_type) { case ScalingType::ROWWISE: { + printf(">>>>>>>>>>>> grouped: ScalingType::ROWWISE\n"); kernel = group_quantize_mxfp8_kernel; break; } case ScalingType::COLWISE: { + printf(">>>>>>>>>>>> grouped: ScalingType::COLWISE\n"); kernel = group_quantize_mxfp8_kernel; break; } case ScalingType::BIDIMENSIONAL: { + printf(">>>>>>>>>>>> grouped: ScalingType::BIDIMENSIONAL\n"); kernel = group_quantize_mxfp8_kernel; break; diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a3e7db94d1..82bf497a3b 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -55,7 +55,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float *noop, float *const dbias_workspace, float *const amax_ptr, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { -printf(">>>>>>>>>>>> non-grouped: WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); +// printf(">>>>>>>>>>>> non-grouped: WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool 
NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -777,6 +777,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, switch (scaling_type) { case ScalingType::ROWWISE: { + printf(">>>>>>>>>>>> non-grouped: ScalingType::ROWWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -792,6 +793,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::COLWISE: { + printf(">>>>>>>>>>>> non-grouped: ScalingType::COLWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -807,6 +809,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::BIDIMENSIONAL: { + printf(">>>>>>>>>>>> non-grouped: ScalingType::BIDIMENSIONAL\n"); auto kernel = quantize_mxfp8_kernel; diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 61a8d61635..3ed540b8f2 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -117,8 +117,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: - return NVTE_QKV_Layout_Group::NVTE_HD_HD_DS; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Layout_Group::NVTE_SD_SD_SD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -159,8 +159,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: - return NVTE_QKV_Format::NVTE_BHDS; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -180,8 +180,8 @@ 
NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2BSHD: case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; - case NVTE_QKV_Format::NVTE_BHDS: - return NVTE_QKV_Format::NVTE_BHDS; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -201,8 +201,8 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; - case NVTE_QKV_Format::NVTE_BHDS: - return NVTE_QKV_Format::NVTE_BHDS; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -248,29 +248,29 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && // 8.9: t3hd, max_s=512, d=64, padding - // ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && - // qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && - // max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} - // (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && - // max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && - // (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} - // (cudnn_runtime_version >= 90700 && - // // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // // sm90: fwd d<=256, bwd d=128 only - // // sm100: fwd d<=128, bwd d<=128 - // ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v 
<= 256) || - // (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || - // (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && - // head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - // (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHDS) && + ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && + qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && + max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && + max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + (cudnn_runtime_version >= 90700 && + // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // sm90: fwd d<=256, bwd d=128 only + // sm100: fwd d<=128, bwd d<=128 + ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || + (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || + (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && + head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + (qkv_format == 
NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHSD) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { @@ -1151,8 +1151,6 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto ndim = input_Q->data.shape.size(); auto ndim_kv = input_K->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; @@ -1165,6 +1163,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } + size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim-3] : input_Q->data.shape[ndim - 2]; + size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? input_K->data.shape[ndim_kv-3] : input_K->data.shape[ndim_kv - 2]; int64_t num_pages_k = 0; int64_t num_pages_v = 0; int64_t page_size_k = 0; @@ -1277,8 +1277,6 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto ndim = input_Q->data.shape.size(); auto ndim_kv = input_K->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; @@ -1291,6 +1289,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } + size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? 
input_Q->data.shape[ndim-3] : input_Q->data.shape[ndim - 2]; + size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? input_K->data.shape[ndim_kv-3] : input_K->data.shape[ndim_kv - 2]; auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index bf4f019a67..f7698da5c3 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1781,12 +1781,13 @@ void fused_attn_fp8_fwd_impl_v1( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); + printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d: %d\n", b, h, hg, s_q, s_kv, d); generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); // need to double check + NVTE_QKV_Matrix::NVTE_V_Matrix); // need to double check printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); @@ -2527,10 +2528,10 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrDescaleQ = input_Q->scale_inv.dptr; devPtrK = input_K->data.dptr; devPtrDescaleK = input_K->scale_inv.dptr; - devPtrV = input_V->data.dptr; - devPtrDescaleV = input_V->scale_inv.dptr; - // devPtrV = input_V->columnwise_data.dptr; - // devPtrDescaleV = input_V->columnwise_scale_inv.dptr; + // devPtrV = input_V->data.dptr; + // devPtrDescaleV = input_V->scale_inv.dptr; + devPtrV = 
input_V->columnwise_data.dptr; + devPtrDescaleV = input_V->columnwise_scale_inv.dptr; devPtrO = output_O->data.dptr; devPtrAmaxO = output_O->amax.dptr; } else { @@ -2589,7 +2590,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHDS)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 94a495153e..3ea40126cc 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -293,32 +293,24 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { - strideA[batch_dim_idx] = s_q * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; + strideA[batch_dim_idx] = h * s_q * d; + strideA[head_dim_idx] = s_q * d; + strideA[seqlen_dim_idx] = d; strideA[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strideA[batch_dim_idx] = s_kv * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = h * s_kv * d; + 
strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_dim_idx] = d; strideA[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strideA[batch_dim_idx] = s_kv * h * d; - strideA[head_dim_idx] = d * s_kv; - strideA[seqlen_dim_idx] = 1; - strideA[hidden_dim_idx] = s_kv; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { - strideA[batch_dim_idx] = s_kv * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_transpose_dim_idx] = 1; - strideA[hidden_transpose_dim_idx] = h * d; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { - strideA[batch_dim_idx] = s_kv * h * d; - strideA[head_dim_idx] = d * s_kv; - strideA[seqlen_transpose_dim_idx] = s_kv; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = h * s_kv * d; + strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_transpose_dim_idx] = d; strideA[hidden_transpose_dim_idx] = 1; } break; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index bc97d2a853..204d8f3d5a 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -52,7 +52,7 @@ enum NVTE_QKV_Layout { NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ - NVTE_BSHD_BSHD_BHDS = 25, /*!< BSHD_BSHD_BHDS layout */ + NVTE_BHSD_BHSD_BHSD = 25, /*!< BHSD_BHSD_BHSD layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -71,8 +71,8 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_HD_HD = 4, /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ NVTE_Paged_KV_HD_HD_HD = 5, - /*! BSHD_BSHD_BHDS QKV layouts, e.g. 
BSHD_BSHD_BHDS */ - NVTE_HD_HD_DS = 6, + /*! SD_SD_SD QKV layouts, e.g. BHSD_BHSD_BHSD */ + NVTE_SD_SD_SD = 6, }; /*! \enum NVTE_QKV_Format @@ -93,8 +93,8 @@ enum NVTE_QKV_Format { NVTE_THD_2BSHD = 5, /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ NVTE_THD_2SBHD = 6, - /*! BSHD_BSHD_BHSD QKV format, e.g. BSHD_BSHD_BHSD */ - NVTE_BHDS = 7, + /*! BHSD QKV format, e.g. BHSD_BHSD_BHSD */ + NVTE_BHSD = 7, }; /*! \enum NVTE_Bias_Type diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index b81c488005..96e6803ec5 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -49,7 +49,7 @@ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ - .value("NVTE_BHDS", NVTE_QKV_Format::NVTE_BHDS); \ + .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -76,7 +76,7 @@ .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD) \ - .value("NVTE_BSHD_BSHD_BHDS", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS); \ + .value("NVTE_BHSD_BHSD_BHSD", NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py 
b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 01c26a9728..5cc23eabd8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1217,6 +1217,9 @@ def forward( # FP8 attention: torch.float16 or torch.bfloat16 out_nominal_dtype = q.dtype + # save original qkv_layout + original_qkv_layout = qkv_layout + max_logit = None if fp8: fused_attention_backend = FusedAttnBackend["FP8"] @@ -1276,6 +1279,17 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) + if original_qkv_layout != qkv_layout: + print(f">>>>>>>>>>>> original_qkv_layout: {original_qkv_layout}") + print(f">>>>>>>>>>>> qkv_layout: {qkv_layout}") + print(f">>>>>>>>>>>> out_.shape: {out_.shape}") + original_qkv_format = original_qkv_layout.split("_")[0] + new_qkv_format = qkv_layout.split("_")[0] + perm = [] + for i in new_qkv_format: + perm.append(original_qkv_format.find(i)) + out_ = out_.permute(*perm).contiguous() + print(f">>>>>>>>>>>> out_.shape permuted: {out_.shape}") # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 76f28f449d..c5ba652c28 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2191,11 +2191,13 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): if isinstance(qkv_quantizer, MXFP8Quantizer): # bs3hd -> bshd_bshd_bhsd q,k,v = [x.contiguous() for x in [q, k, v]] + print(f">>>>>>>>>>>> Contiguous shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") # bshd_bshd_bhsd -> bhsd_bhsd_bhsd # thd_thd_thd -> htd_htd_htd qkv_quantizer.optimize_for_gemm = True - qkv_quantizer._internal = False + qkv_quantizer.internal = False + 
print(f">>>>>>>>>>>> qkv_quantizer.internal: {qkv_quantizer.internal}") dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") def permute_x(x): dim_others = [i for i in range(len(x.shape)) if i != dim_s] @@ -2203,102 +2205,34 @@ def permute_x(x): x = x.permute(*perm).contiguous() return x q, k, v = [permute_x(x) for x in [q, k, v]] - # consider bhsd for now - batch_size, num_heads = q.shape[0], q.shape[1] - seq_len, head_dim = q.shape[-2], q.shape[-1] - num_tensors = 3 * batch_size * num_heads - # qkv = torch.cat([q, k, v], dim=0).reshape(num_tensors, seq_len, head_dim) - # qkv_list = [qkv[i] for i in range(num_tensors)] - # print(f">>>>>>>>>>>> num_tensors: {num_tensors}") - shapes = [(seq_len, head_dim) for _ in range(num_tensors)] - quantizers = [qkv_quantizer] * num_tensors - grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shape=shapes, - quantizers=None, - device="cuda", - dtype=src_nominal_dtype, - ) - offset = 0 - for x in [q, k, v]: - numel = x.numel() - grouped_input.data[offset : offset + numel].copy_(x.reshape(-1)) - offset += numel - grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shape=shapes, - quantizers=quantizers, - device="cuda", - ) - _ = tex.quantize_grouped(grouped_input, grouped_output) - print(f">>>>>>>>>>>> grouped_output: {grouped_output.data.shape if grouped_output.data is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.scale_inv.shape if grouped_output.scale_inv is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_data.shape if grouped_output.columnwise_data is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_scale_inv.shape if grouped_output.columnwise_scale_inv is not None else None}") - # grouped_output_list = [grouped_output[i] for i in range(num_tensors)] - # q_fp8, k_fp8, v_fp8 = grouped_output_list[0:32], 
grouped_output_list[32:64], grouped_output_list[64:] - # grouped_output.num_tensors = 3 - # grouped_output.quantizers = [qkv_quantizer] * 3 - # grouped_output.shape = [(batch_size * num_heads * seq_len, head_dim) for _ in range(3)] - # grouped_output.dtype = src_nominal_dtype - # grouped_output.data = torch.cat([x.reshape(-1) for x in [q, k, v]], dim=0) - # q_fp8, k_fp8, v_fp8 = grouped_output.split_into_quantized_tensors() - - def split_qkv(grouped_tensor, num_tensors): - rowwise_shape = q.shape - rowwise_scale_inv_shape = (*q.shape[:-1], q.shape[-1]//32) - columnwise_shape = q.shape - columnwise_scale_inv_shape = (*q.shape[:-2], q.shape[-2]//32, q.shape[-1]) - rowwise_data = grouped_tensor.data.view(num_tensors, *rowwise_shape).split([1] * num_tensors) - rowwise_scale_inv = grouped_tensor.scale_inv.view(num_tensors, *rowwise_scale_inv_shape).split([1] * num_tensors) - columnwise_data = grouped_tensor.columnwise_data.view(num_tensors, *columnwise_shape).split([1] * num_tensors) - columnwise_scale_inv = grouped_tensor.columnwise_scale_inv.view(num_tensors, *columnwise_scale_inv_shape).split([1] * num_tensors) - print(f">>>>>>>>>>>> rowwise_data: {len(rowwise_data)}, rowwise_scale_inv: {len(rowwise_scale_inv)}, columnwise_data: {len(columnwise_data)}, columnwise_scale_inv: {len(columnwise_scale_inv)}") - return [MXFP8Tensor( - shape=q.shape, - dtype=q.dtype, - rowwise_data=rowwise_data[i].squeeze(0), - rowwise_scale_inv=rowwise_scale_inv[i].squeeze(0), - columnwise_data=columnwise_data[i].squeeze(0), - columnwise_scale_inv=columnwise_scale_inv[i].squeeze(0), - fp8_dtype=qkv_quantizer.dtype, - quantizer=qkv_quantizer, - with_gemm_swizzled_scales=qkv_quantizer.optimize_for_gemm, - ) for i in range(num_tensors)] - q_fp8, k_fp8, v_fp8 = split_qkv(grouped_output, 3) + print(f">>>>>>>>>>>> Permuted shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") + + original_shapes = [q.shape, k.shape, v.shape] + b, h_q, s_q, d_qk = q.shape + _, h_kv, s_kv, d_kv = v.shape + assert 
k.shape == (b, h_kv, s_kv, d_qk) + assert s_q % 128 == 0 + assert s_kv % 128 == 0 + assert d_qk % 32 == 0 + assert d_kv % 32 == 0 + q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] + print(f">>>>>>>>>>>> Flattened shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") + # consider bhsd for now + grouped_tensor = GroupedTensor.create_and_quantize(tensors=[q, k, v], quantizer=qkv_quantizer) + print(f">>>>>>>>>>>> grouped_tensor: {type(grouped_tensor)}") + print(f">>>>>>>>>>>> grouped_tensor.quantized_tensors: {type(grouped_tensor.quantized_tensors)}") + q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors + print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + q_fp8, k_fp8, v_fp8 = [x.view(*original_shapes[i]) for i, x in enumerate([q_fp8, k_fp8, v_fp8])] print(f">>>>>>>>>>>> q_fp8: {q_fp8.shape}, k_fp8: {k_fp8.shape}, v_fp8: {v_fp8.shape}") - print(f">>>>>>>>>>>> rowwise_data: q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") - print(f">>>>>>>>>>>> rowwise_scale_inv: q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") - print(f">>>>>>>>>>>> columnwise_data: q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - print(f">>>>>>>>>>>> columnwise_scale_inv: q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: 
{v_fp8._columnwise_scale_inv.shape}") - - # print(f">>>>>>>>>>>> grouped_output: {len(grouped_output) if grouped_output is not None else None}") - - # qkv_mxfp8 = grouped_tensor.quantize(qkv_list) - # print(f">>>>>>>>>>>> qkv_mxfp8: {type(qkv_mxfp8)}") - # qkv_mxfp8_list = [qkv_mxfp8[i] for i in range(num_tensors)] - # print(f">>>>>>>>>>>> qkv_mxfp8: {qkv_mxfp8}") - # print(f">>>>>>>>>>>> qkv_mxfp8: {len(qkv_mxfp8_list)}") - # print(f">>>>>>>>>>>> qkv_mxfp8.shape: {qkv_mxfp8.shape}") - # print(f">>>>>>>>>>>> qkv_mxfp8._rowwise_data.shape: {qkv_mxfp8._rowwise_data.shape}") - # print(f">>>>>>>>>>>> qkv_mxfp8._rowwise_scale_inv.shape: {qkv_mxfp8._rowwise_scale_inv.shape}") - # print(f">>>>>>>>>>>> qkv_mxfp8._columnwise_data.shape: {qkv_mxfp8._columnwise_data.shape}") - # print(f">>>>>>>>>>>> qkv_mxfp8._columnwise_scale_inv.shape: {qkv_mxfp8._columnwise_scale_inv.shape}") - # q_fp8, k_fp8, v_fp8 = qkv_mxfp8[0::batch_size * num_heads], qkv_mxfp8[batch_size:2*batch_size], qkv_mxfp8[2*batch_size:] - - # q_fp8, k_fp8, v_fp8 = qkv_mxfp8.split_into_quantized_tensors() - - print(f"q_fp8: {q_fp8.shape}, k_fp8: {k_fp8.shape}, v_fp8: {v_fp8.shape}") - print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") - print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") - print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") - - # q_fp8_rowwise, k_fp8_rowwise, v_fp8_rowwise = [x._rowwise_data for x in qkv_mxfp8] - # q_fp8_columnwise, k_fp8_columnwise, v_fp8_columnwise = [x._columnwise_data for x in qkv_mxfp8] - # q_fp8, k_fp8, v_fp8 = q_fp8_rowwise, k_fp8_rowwise, v_fp8_columnwise - + print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, 
k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") # dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") # dim_others = [i for i in range(len(v.shape)) if i != dim_s] # perm = [*dim_others, dim_s] @@ -2312,12 +2246,6 @@ def split_qkv(grouped_tensor, num_tensors): # inv[p] = i # # v = v.permute(*inv) - # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - # # v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv).contiguous() - # print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") - # print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") - # print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - # print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") return q_fp8, k_fp8, v_fp8, qkv_layout match qkv_group: diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 2748228b42..b4811eb4f6 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -42,7 +42,7 @@ "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, - "bshd_bshd_bhds": 
NVTE_QKV_Format.NVTE_BHDS, + "bhsd": NVTE_QKV_Format.NVTE_BHSD, } QKVLayout = { @@ -71,7 +71,7 @@ "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, - "bshd_bshd_bhds": NVTE_QKV_Layout.NVTE_BSHD_BSHD_BHDS, + "bhsd_bhsd_bhsd": NVTE_QKV_Layout.NVTE_BHSD_BHSD_BHSD, } AttnBiasType = { diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 8ab8dc1d48..c17be6c855 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -207,15 +207,12 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { DType quantizer_dtype = DType::kNumTypes; NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; bool with_gemm_swizzled_scales = false; - if (!tensor.attr("quantizers").is_none()) { - const auto quantizers = tensor.attr("quantizers").cast(); - quantizer = quantizers[0]; - if (!quantizers.empty() && !quantizer.is_none()) { - scaling_mode = ScalingModeFromQuantizer(quantizer); - quantizer_dtype = quantizer.attr("dtype").cast(); - with_gemm_swizzled_scales = quantizer.attr("optimize_for_gemm").cast(); - printf(">>>>>>>>>>>> GroupedTensorFromPyTorchGroupedTensor with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); - } + if (!tensor.attr("quantizer").is_none()) { + quantizer = tensor.attr("quantizer").cast(); + scaling_mode = ScalingModeFromQuantizer(quantizer); + quantizer_dtype = quantizer.attr("dtype").cast(); + with_gemm_swizzled_scales = quantizer.attr("optimize_for_gemm").cast(); + printf(">>>>>>>>>>>> GroupedTensorFromPyTorchGroupedTensor with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); } auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py 
b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index a283b43908..e4c658ed58 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -75,6 +75,8 @@ def update_quantized( src = src.contiguous() # Launch cast kernel + print(f">>>>>>>>>>>> src: {src.shape}") + print(f">>>>>>>>>>>> dst: {dst.shape}") tex.quantize(src, self, dst, noop_flag) # Update FP8 dtype diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index 522c25f370..9771d61df8 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -8,7 +8,8 @@ import math import torch - +import transformer_engine +import transformer_engine_torch as tex from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ..mxfp8_tensor import MXFP8Tensor @@ -52,8 +53,8 @@ class GroupedTensor: def __init__( self, num_tensors: int, - shape: List[Tuple[int, int]], - quantizers: Optional[List[Quantizer]] = None, + shapes: List[Tuple[int, int]], + quantizer: Optional[Quantizer] = None, dtype: Optional[torch.dtype] = None, data: Optional[torch.Tensor] = None, columnwise_data: Optional[torch.Tensor] = None, @@ -75,8 +76,8 @@ def __init__( Args: num_tensors: Number of tensors in the group - shape: 2D shape of each tensor (len num_tensors) - quantizers: List of Quantizers for the grouped tensor + shapes: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer for the grouped tensor data: Row-wise data buffer (1D flattened) columnwise_data: Column-wise data buffer (1D flattened) scale_inv: Row-wise scale inverse buffer @@ -110,8 +111,8 @@ def __init__( print(f">>>>>>>>>>>> num_tensors: {num_tensors}") self.num_tensors = num_tensors - self.quantizers = quantizers - self.shape = shape + self.quantizer = quantizer + self.shapes = shapes self.dtype = ( dtype if dtype is not None else 
torch.float32 ) # Default to float32 if not provided @@ -276,7 +277,7 @@ def clear(self) -> None: self.tensor_offsets = None self.logical_shape = (0, 0) self.num_tensors = 0 - self.quantizers = None + self.quantizer = None self.quantized_tensors = None self.offsets = None self.scale_inv_offsets = None @@ -286,7 +287,7 @@ def __repr__(self) -> str: """String representation of the GroupedTensor.""" return ( f"GroupedTensor(num_tensors={self.num_tensors}, " - f"shape={self.shape}, " + f"shapes={self.shapes}, " f"logical_shape={self.logical_shape}, " f"dtype={self.get_dtype()})" ) @@ -312,8 +313,8 @@ def __str__(self) -> str: @staticmethod def make_grouped_tensor_with_shapes( num_tensors: int, - shape: List[Tuple[int, int]], - quantizers: Optional[List[Quantizer]] = None, + shapes: List[Tuple[int, int]], + quantizer: Optional[Quantizer] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> GroupedTensor: @@ -322,8 +323,8 @@ def make_grouped_tensor_with_shapes( Args: num_tensors: Number of tensors - shape: 2D shape of each tensor (len num_tensors) - quantizers: List of Quantizers for each tensor + shapes: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer for the grouped tensor device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -332,16 +333,16 @@ def make_grouped_tensor_with_shapes( """ # First dim - first_dim_list = [s[0] for s in shape] + first_dim_list = [s[0] for s in shapes] uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list) logical_first_dim = sum(first_dim_list) if uniform_first_dim: first_dims = None else: - first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device) + first_dims = torch.tensor([s[0] for s in shapes], dtype=torch.int64, device=device) # Last dim - last_dim_list = [s[1] for s in shape] + last_dim_list = [s[1] for s in shapes] logical_last_dim = last_dim_list[0] assert 
all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform" @@ -351,7 +352,7 @@ def make_grouped_tensor_with_shapes( last_dims=None, logical_first_dim=logical_first_dim, logical_last_dim=logical_last_dim, - quantizers=quantizers, + quantizer=quantizer, device=device, dtype=dtype, ) @@ -363,7 +364,7 @@ def make_grouped_tensor( last_dims: Optional[torch.tensor], logical_first_dim: int, logical_last_dim: int, - quantizers: Optional[List[Quantizer]] = None, + quantizer: Optional[Quantizer] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> GroupedTensor: @@ -376,7 +377,7 @@ def make_grouped_tensor( last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) logical_first_dim: Logical first dimension logical_last_dim: Logical last dimension - quantizers: List of Quantizers for each tensor + quantizer: Quantizer for the grouped tensor Used to figure out the recipe and what to allocate. device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -405,7 +406,7 @@ def make_grouped_tensor( # Calculate tensor offsets (cumulative element offsets) tensor_offsets = None offsets = None - shape = [] + shapes = [] if not all_same_first: # Need explicit offsets for non-uniform shapes # Offsets are based on number of elements and not pointers. 
@@ -421,21 +422,18 @@ def make_grouped_tensor( offsets = tensor_offsets.tolist() first_dims_list = first_dims.tolist() for i in range(num_tensors): - shape.append((first_dims_list[i], logical_last_dim)) + shapes.append((first_dims_list[i], logical_last_dim)) else: offsets = [ i * logical_first_dim * logical_last_dim // num_tensors for i in range(num_tensors + 1) ] for i in range(num_tensors): - shape.append((logical_first_dim // num_tensors, logical_last_dim)) + shapes.append((logical_first_dim // num_tensors, logical_last_dim)) # Calculate logical shape based logical_shape = (logical_first_dim, logical_last_dim) - quantizer = quantizers[0] if isinstance(quantizers, list) else quantizers - # print(f">>>>>>>>>>>>> quantizers: {quantizers}") - print(f">>>>>>>>>>>>> quantizer: {quantizer}") no_quantization = quantizer is None rowwise_usage = quantizer.rowwise_usage if not no_quantization else True @@ -470,7 +468,7 @@ def make_grouped_tensor( # For grouped tensors, we need to calculate scale_inv size for all tensors total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): scale_inv_shape = quantizer.get_scale_shape(s, False) scale_elements = math.prod(scale_inv_shape) total_scale_elements += scale_elements @@ -484,7 +482,7 @@ def make_grouped_tensor( # Columnwise scale inverse buffer total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): scale_inv_shape = quantizer.get_scale_shape(s, False) columnwise_scale_elements = math.prod(scale_inv_shape) total_columnwise_scale_elements += columnwise_scale_elements @@ -538,7 +536,7 @@ def make_grouped_tensor( # Columnwise scale inverse buffer total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += 
math.prod(columnwise_scale_inv_shape) if i < num_tensors - 1: @@ -569,7 +567,7 @@ def make_grouped_tensor( # Columnwise scale inverse total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) if i < num_tensors - 1: @@ -603,9 +601,9 @@ def make_grouped_tensor( grouped_tensor = GroupedTensor( num_tensors=num_tensors, - shape=shape, + shapes=shapes, dtype=dtype, - quantizers=quantizers, + quantizer=quantizer, data=data, columnwise_data=columnwise_data, scale_inv=scale_inv, @@ -643,13 +641,13 @@ def split_into_quantized_tensors( result = [] - no_quantization = self.quantizers is None + no_quantization = self.quantizer is None # Case 1: No quantization - return regular torch tensors if no_quantization: for i in range(self.num_tensors): # Get tensor shape - tensor_shape = self.shape[i] + tensor_shape = self.shapes[i] # Get tensor data slice if self.offsets is not None: @@ -687,11 +685,11 @@ def split_into_quantized_tensors( return result # Case 2: Quantized tensors - recipe = self.quantizers[0]._get_compatible_recipe() + recipe = self.quantizer._get_compatible_recipe() for i in range(self.num_tensors): # Get tensor shape - tensor_shape = self.shape[i] + tensor_shape = self.shapes[i] numel = tensor_shape[0] * tensor_shape[1] # Get data offsets @@ -704,7 +702,7 @@ def split_into_quantized_tensors( data_end = data_start + numel # Special shape handling for NVFP4. 
- nvfp4 = self.quantizers[0]._get_compatible_recipe().nvfp4() + nvfp4 = self.quantizer._get_compatible_recipe().nvfp4() if nvfp4: data_start = data_start // 2 data_end = data_end // 2 @@ -715,15 +713,15 @@ def split_into_quantized_tensors( if self.has_data(): if nvfp4: - rowwise_tensor_shape = self.quantizers[0].convert_shape_for_fp4(tensor_shape) + rowwise_tensor_shape = self.quantizer.convert_shape_for_fp4(tensor_shape) else: rowwise_tensor_shape = tensor_shape rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape) if self.has_columnwise_data(): - columnwise_tensor_shape = self.quantizers[0].get_columnwise_shape(tensor_shape) + columnwise_tensor_shape = self.quantizer.get_columnwise_shape(tensor_shape) if nvfp4: - columnwise_tensor_shape = self.quantizers[0].convert_shape_for_fp4( + columnwise_tensor_shape = self.quantizer.convert_shape_for_fp4( columnwise_tensor_shape ) columnwise_data = self.columnwise_data[data_start:data_end].view( @@ -744,7 +742,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv.numel() # Calculate expected scale shape for MXFP8 - scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -757,12 +755,12 @@ def split_into_quantized_tensors( else: cscale_end = self.columnwise_scale_inv.numel() - cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, True) + cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( cscale_shape ) - if self.quantizers[0].internal: + if self.quantizer.internal: mxfp8_tensor_class = MXFP8TensorStorage else: mxfp8_tensor_class = MXFP8Tensor @@ -773,9 +771,9 @@ def split_into_quantized_tensors( rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, - 
fp8_dtype=self.quantizers[0].dtype, - quantizer=self.quantizers[0], - with_gemm_swizzled_scales=self.quantizers[0].optimize_for_gemm, + fp8_dtype=self.quantizer.dtype, + quantizer=self.quantizer, + with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm, ) result.append(tensor) @@ -786,7 +784,7 @@ def split_into_quantized_tensors( if self.scale_inv is not None: scale_inv = self.scale_inv[i : i + 1] - if self.quantizers[0].internal: + if self.quantizer.internal: float8_tensor_class = Float8TensorStorage else: float8_tensor_class = Float8Tensor @@ -796,8 +794,8 @@ def split_into_quantized_tensors( dtype=self.dtype, data=rowwise_data, fp8_scale_inv=scale_inv, - fp8_dtype=self.quantizers[0].dtype, - quantizer=self.quantizers[0], + fp8_dtype=self.quantizer.dtype, + quantizer=self.quantizer, data_transpose=columnwise_data, ) result.append(tensor) @@ -816,7 +814,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv.numel() # Get scale shape from quantizer - scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -830,15 +828,15 @@ def split_into_quantized_tensors( cscale_end = self.columnwise_scale_inv.numel() # Get columnwise scale shape from quantizer - cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, True) + cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( cscale_shape ) # Compute is_2D_scaled and data_format from quantizer attributes - is_2D_scaled = self.quantizers[0].block_scaling_dim == 2 + is_2D_scaled = self.quantizer.block_scaling_dim == 2 - if self.quantizers[0].internal: + if self.quantizer.internal: float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage else: float8_blockwise_q_tensor_class = Float8BlockwiseQTensor @@ -850,8 +848,8 @@ def 
split_into_quantized_tensors( rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, - fp8_dtype=self.quantizers[0].dtype, - quantizer=self.quantizers[0], + fp8_dtype=self.quantizer.dtype, + quantizer=self.quantizer, is_2D_scaled=is_2D_scaled, ) result.append(tensor) @@ -872,7 +870,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv.numel() # Get scale shape from quantizer - scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -886,7 +884,7 @@ def split_into_quantized_tensors( cscale_end = self.columnwise_scale_inv.numel() # Get columnwise scale shape from quantizer - cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, True) + cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( cscale_shape ) @@ -898,7 +896,7 @@ def split_into_quantized_tensors( if self.columnwise_amax is not None: amax_columnwise = self.columnwise_amax[i : i + 1] - if self.quantizers[0].internal: + if self.quantizer.internal: nvfp4_tensor_class = NVFP4TensorStorage else: nvfp4_tensor_class = NVFP4Tensor @@ -912,9 +910,9 @@ def split_into_quantized_tensors( columnwise_scale_inv=columnwise_scale_inv, amax_rowwise=amax_rowwise, amax_columnwise=amax_columnwise, - fp4_dtype=self.quantizers[0].dtype, - quantizer=self.quantizers[0], - with_gemm_swizzled_scales=self.quantizers[0].optimize_for_gemm, + fp4_dtype=self.quantizer.dtype, + quantizer=self.quantizer, + with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm, ) result.append(tensor) @@ -926,7 +924,7 @@ def split_into_quantized_tensors( @staticmethod def create_and_quantize( tensors: int, - quantizers: None | List[Quantizer], + quantizer: None | Quantizer, *, device: Optional[torch.device] = None, dtype: 
Optional[torch.dtype] = None, @@ -937,17 +935,41 @@ def create_and_quantize( storage allocated in a GroupedTensor. """ - grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=len(tensors), + shapes=[t.shape for t in tensors], + quantizer=None, + device=device, + dtype=tensors[0].dtype, + ) + + offset = 0 + for tensor in tensors: + numel = tensor.numel() + grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=len(tensors), - shape=[t.shape for t in tensors], - quantizers=quantizers, + shapes=[t.shape for t in tensors], + quantizer=quantizer, device=device, dtype=dtype, ) - grouped_tensor.quantize(tensors, noop_flag=noop_flag) + _ = tex.quantize_grouped(grouped_input, grouped_output) + grouped_output.quantized_tensors = grouped_output.split_into_quantized_tensors() - return grouped_tensor + # grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + # num_tensors=len(tensors), + # shapes=[t.shape for t in tensors], + # quantizer=None, + # device=device, + # dtype=tensors[0].dtype, + # ) + # grouped_tensor.quantize(tensors, noop_flag=noop_flag) + + return grouped_output def quantize( self, @@ -958,7 +980,9 @@ def quantize( Quantize the GroupedTensor inplace. 
""" - quantized_tensors = self.split_into_quantized_tensors() + self.quantized_tensors = self.split_into_quantized_tensors() + print(f">>>>>>>>>>>> tensors[0]: {type(tensors[0])}") + print(f">>>>>>>>>>>> quantized_tensors[0]: {type(self.quantized_tensors[0])}") for i in range(self.num_tensors): - self.quantizers[0].update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag) - return quantized_tensors \ No newline at end of file + self.quantizer.update_quantized(tensors[i], self.quantized_tensors[i], noop_flag=noop_flag) + return self.quantized_tensors \ No newline at end of file From 4e854d523d056e5348d4ec5d122c0936aa00eb8d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 11 Feb 2026 18:50:31 -0800 Subject: [PATCH 009/172] fix unfused; clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 23 ++-- tests/pytorch/test_grouped_tensor.py | 110 ++++++++--------- .../cast/mxfp8/group_quantize_mxfp8.cuh | 9 +- .../common/cast/mxfp8/quantize_mxfp8.cuh | 7 +- .../common/fused_attn/fused_attn.cpp | 23 ++-- .../common/fused_attn/fused_attn_fp8.cu | 18 +-- transformer_engine/common/recipe/__init__.py | 2 +- .../dot_product_attention/backends.py | 112 ++++++++++-------- .../dot_product_attention/context_parallel.py | 2 +- .../dot_product_attention.py | 5 +- .../attention/dot_product_attention/utils.py | 74 ++++++------ .../pytorch/cpp_extensions/fused_attn.py | 1 - .../pytorch/tensor/mxfp8_tensor.py | 5 +- .../pytorch/tensor/storage/grouped_tensor.py | 37 +----- 14 files changed, 183 insertions(+), 245 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 5602114143..74deeceed2 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2139,15 +2139,15 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, 
scal dtype, config, True, qkv_layout, is_training, fp8_recipe ) - # if unfused_attn_supported: - # os.environ["NVTE_FLASH_ATTN"] = "0" - # os.environ["NVTE_FUSED_ATTN"] = "0" - # os.environ["NVTE_UNFUSED_ATTN"] = "1" - # _attention_backends["backend_selection_requires_update"] = True - # logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") - # unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - # dtype, config, True, qkv_layout, is_training, fp8_recipe - # ) + if unfused_attn_supported: + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "1" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") + unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training, fp8_recipe + ) os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" @@ -2157,8 +2157,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( dtype, config, True, qkv_layout, is_training, fp8_recipe ) - # print(f">>>>>> fused_attn_bwd_fp8: {fused_attn_bwd_fp8} {is_training}") - # torch.save(fused_attn_fwd_fp8, "fused_attn_fwd_fp8.pt") os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" @@ -2169,7 +2167,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( dtype, config, False, qkv_layout, is_training, fp8_recipe ) - # torch.save(fused_attn_fwd_f16, "fused_attn_fwd_f16.pt") atol = 5e-1 rtol = 5e-2 @@ -2188,7 +2185,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal rmse_tol, True, ) - if False: #unfused_attn_supported: + if unfused_attn_supported: logging.debug("========== {:^25s} 
==========".format("unfused fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 964c2d8e97..f0b2c35c0a 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -121,11 +121,11 @@ def setup_class(cls) -> None: def test_basic_construction_all_same_shape(self) -> None: """Test GroupedTensor construction with all tensors having same shape""" num_tensors = 4 - shape = [(256, 512) for _ in range(num_tensors)] + shapes = [(256, 512) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shapes, quantizer=None, device="cuda", dtype=torch.float32, @@ -143,11 +143,11 @@ def test_basic_construction_all_same_shape(self) -> None: def test_basic_construction_varying_first_dim(self) -> None: """Test GroupedTensor construction with varying first dimension""" num_tensors = 3 - shape = [(128, 512), (256, 512), (384, 512)] + shapes = [(128, 512), (256, 512), (384, 512)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shapes, quantizer=None, device="cuda", dtype=torch.float32, @@ -157,20 +157,20 @@ def test_basic_construction_varying_first_dim(self) -> None: assert not grouped_tensor.all_same_shape() assert not grouped_tensor.all_same_first_dim() assert grouped_tensor.all_same_last_dim() - assert grouped_tensor.get_common_last_dim() == shape[0][1] + assert grouped_tensor.get_common_last_dim() == shapes[0][1] assert grouped_tensor.logical_shape == ( - sum(v for v, _ in shape), - shape[0][1], + sum(v for v, _ in shapes), + shapes[0][1], ) # sum of first dims def test_split_into_quantized_tensors_no_quantization(self) -> None: """Test split_into_quantized_tensors for unquantized tensors""" num_tensors = 3 - shape = [(256, 512) for _ in 
range(num_tensors)] + shapes = [(256, 512) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shapes, quantizer=None, device="cuda", dtype=torch.float32, @@ -186,7 +186,7 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None: # Verify each tensor has correct shape and shares storage for i, tensor in enumerate(tensors): - assert tensor.shape == shape[i] + assert tensor.shape == shapes[i] assert isinstance(tensor, torch.Tensor) assert not hasattr(tensor, "_data") # Not a quantized tensor @@ -195,20 +195,20 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None: assert tensor.data_ptr() >= original_data_ptr # Calculate expected offset - expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size() + expected_offset = i * (shapes[i][0] * shapes[i][1]) * tensor.element_size() assert tensor.data_ptr() == original_data_ptr + expected_offset @pytest.mark.parametrize("quantization", _quantization_params) def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None: """Test split_into_quantized_tensors for quantized tensors""" num_tensors = 3 - shape = [(512, 512) for _ in range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + shapes = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shapes) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=quantizers, + shapes=shapes, + quantizer=quantizer, device="cuda", ) @@ -225,18 +225,18 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None rowwise_data = _get_rowwise_data_tensor(tensor, quantization) assert rowwise_data is not None assert rowwise_data.data_ptr() >= original_data_ptr - numel = shape[i][0] * shape[i][1] + numel = shapes[i][0] * shapes[i][1] expected_offset = _rowwise_offset_bytes(i * numel, 
quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset def test_split_varying_shapes(self) -> None: """Test split_into_quantized_tensors with varying shapes""" num_tensors = 3 - shape = [(128, 512), (256, 512), (384, 512)] + shapes = [(128, 512), (256, 512), (384, 512)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shapes, quantizer=None, device="cuda", dtype=torch.float32, @@ -250,22 +250,22 @@ def test_split_varying_shapes(self) -> None: # Verify shapes and storage cumulative_offset = 0 for i, tensor in enumerate(tensors): - assert tensor.shape == shape[i] + assert tensor.shape == shapes[i] expected_offset = cumulative_offset * tensor.element_size() assert tensor.data_ptr() == original_data_ptr + expected_offset - cumulative_offset += shape[i][0] * shape[i][1] + cumulative_offset += shapes[i][0] * shapes[i][1] @pytest.mark.parametrize("quantization", _quantization_params) def test_quantize_inplace(self, quantization: str) -> None: """Test that quantize is done in-place for all recipes""" num_tensors = 3 - shape = [(512, 512) for _ in range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + shapes = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shapes) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=quantizers, + shapes=shapes, + quantizer=quantizer, device="cuda", ) @@ -277,7 +277,7 @@ def test_quantize_inplace(self, quantization: str) -> None: ) # Create input tensors - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] # Quantize in place quantized_tensors = grouped_tensor.quantize(input_tensors) @@ -291,7 +291,7 @@ def test_quantize_inplace(self, quantization: str) -> None: # Verify returned tensors 
point to the same storage for i, qtensor in enumerate(quantized_tensors): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) - numel = shape[i][0] * shape[i][1] + numel = shapes[i][0] * shapes[i][1] expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset @@ -299,13 +299,13 @@ def test_quantize_inplace(self, quantization: str) -> None: def test_quantize_varying_shapes(self, quantization: str) -> None: """Test quantize with varying shapes""" num_tensors = 3 - shape = [(256, 512), (512, 512), (768, 512)] - quantizers = make_quantizer(quantization, num_tensors, shape) + shapes = [(256, 512), (512, 512), (768, 512)] + quantizer = make_quantizer(quantization, num_tensors, shapes) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=quantizers, + shapes=shapes, + quantizer=quantizer, device="cuda", ) @@ -313,7 +313,7 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: original_data_ptr = grouped_tensor.data.data_ptr() # Create input tensors with varying shapes - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] # Quantize in place quantized_tensors = grouped_tensor.quantize(input_tensors) @@ -323,7 +323,7 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: # Verify each tensor points to correct location cumulative_numel = 0 - for qtensor, tensor_shape in zip(quantized_tensors, shape): + for qtensor, tensor_shape in zip(quantized_tensors, shapes): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset @@ -333,16 +333,16 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: def 
test_static_quantize_method(self, quantization: str) -> None: """Test the static quantize method""" num_tensors = 3 - shape = [(512, 512) for _ in range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + shapes = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shapes) # Create input tensors - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] # Use static quantize method grouped_tensor = GroupedTensor.create_and_quantize( tensors=input_tensors, - quantizer=quantizers, + quantizer=quantizer, device="cuda", ) @@ -357,7 +357,7 @@ def test_static_quantize_method(self, quantization: str) -> None: original_data_ptr = grouped_tensor.data.data_ptr() for i, qtensor in enumerate(grouped_tensor.quantized_tensors): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) - numel = shape[i][0] * shape[i][1] + numel = shapes[i][0] * shapes[i][1] expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset @@ -370,18 +370,16 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: """Test grouped quantization for MXFP8 against per-tensor quantization.""" # Test wont pass until the grouped quantization PR from Oleg is merged. 
num_tensors = 2 - shape = [(512, 1024) for _ in range(num_tensors)] + shapes = [(512, 1024) for _ in range(num_tensors)] # Create BF16 input tensors and pack into a grouped tensor - input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] - quantizers = [MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) for _ in range(num_tensors)] - for q in quantizers: - q.optimize_for_gemm=True - quantized_tensors = [q(tensor) for q, tensor in zip(quantizers, input_tensors)] + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shapes] + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.optimize_for_gemm=True grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizers=None, + shapes=shapes, + quantizer=None, device="cuda", dtype=torch.bfloat16, ) @@ -392,30 +390,18 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) offset += numel - # Create MXFP8 output grouped tensor (rowwise only for easier validation) - # quantizers = [MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) for _ in range(num_tensors)] - grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizers=quantizers, + shapes=shapes, + quantizer=quantizer, device="cuda", ) - print(f">>>>>>>>>>>> tex.quantize_grouped") - print(f">>>>>>>>>>>> grouped_input: {grouped_input.data.shape if grouped_input.data is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.data.shape if grouped_output.data is not None else None}") - print(f">>>>>>>>>>>> grouped_input: {grouped_input.scale_inv.shape if grouped_input.scale_inv is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.scale_inv.shape if grouped_output.scale_inv is not None else None}") - print(f">>>>>>>>>>>> grouped_input: 
{grouped_input.columnwise_data.shape if grouped_input.columnwise_data is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_data.shape if grouped_output.columnwise_data is not None else None}") - print(f">>>>>>>>>>>> grouped_input: {grouped_input.columnwise_scale_inv.shape if grouped_input.columnwise_scale_inv is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_scale_inv.shape if grouped_output.columnwise_scale_inv is not None else None}") # Quantize using grouped API (handle both 2-arg and 3-arg bindings) _ = tex.quantize_grouped(grouped_input, grouped_output) # Build expected output by quantizing each tensor independently expected_data = [] expected_scale_inv = [] - for tensor, quantizer in zip(input_tensors, quantizers): + for tensor in input_tensors: qtensor = quantizer(tensor) expected_data.append(qtensor._rowwise_data.reshape(-1)) expected_scale_inv.append(qtensor._rowwise_scale_inv.reshape(-1)) @@ -429,11 +415,11 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: def test_clear(self) -> None: """Test clear method""" num_tensors = 3 - shape = [(256, 512) for _ in range(num_tensors)] + shapes = [(256, 512) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shapes, quantizer=None, device="cuda", dtype=torch.float32, diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index ea81e6c516..6a6715bdcc 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -244,7 +244,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, const 
float *__restrict__ noop, float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { -// printf(">>>>>>>>>>>> WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -853,7 +852,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations } const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; - printf(">>>>>>>>>>>> with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); + printf(">>>>>>>>>>>> group_quantize_mxfp8 with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = @@ -935,19 +934,19 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations OType, true, true, WITH_GEMM_SWIZZLED_SCALES>; switch (scaling_type) { case ScalingType::ROWWISE: { - printf(">>>>>>>>>>>> grouped: ScalingType::ROWWISE\n"); + printf(">>>>>>>>>>>> group_quantize_mxfp8 ScalingType::ROWWISE\n"); kernel = group_quantize_mxfp8_kernel; break; } case ScalingType::COLWISE: { - printf(">>>>>>>>>>>> grouped: ScalingType::COLWISE\n"); + printf(">>>>>>>>>>>> group_quantize_mxfp8 ScalingType::COLWISE\n"); kernel = group_quantize_mxfp8_kernel; break; } case ScalingType::BIDIMENSIONAL: { - printf(">>>>>>>>>>>> grouped: ScalingType::BIDIMENSIONAL\n"); + printf(">>>>>>>>>>>> group_quantize_mxfp8 ScalingType::BIDIMENSIONAL\n"); kernel = group_quantize_mxfp8_kernel; break; diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 82bf497a3b..a8135391e3 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -55,7 +55,6 @@ __global__ void 
__launch_bounds__(THREADS_PER_CHUNK) const float *noop, float *const dbias_workspace, float *const amax_ptr, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { -// printf(">>>>>>>>>>>> non-grouped: WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -777,7 +776,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, switch (scaling_type) { case ScalingType::ROWWISE: { - printf(">>>>>>>>>>>> non-grouped: ScalingType::ROWWISE\n"); + printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::ROWWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -793,7 +792,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::COLWISE: { - printf(">>>>>>>>>>>> non-grouped: ScalingType::COLWISE\n"); + printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::COLWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -809,7 +808,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::BIDIMENSIONAL: { - printf(">>>>>>>>>>>> non-grouped: ScalingType::BIDIMENSIONAL\n"); + printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::BIDIMENSIONAL\n"); auto kernel = quantize_mxfp8_kernel; diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 3ed540b8f2..1adabcded2 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -234,16 +234,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); const bool supported_ragged_offset_size = (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - printf(">>>>>> qkv_layout: %d\n", qkv_layout); - printf(">>>>>> 
q_dtype: %d\n", q_dtype); - printf(">>>>>> qkv_format: %d\n", qkv_format); - printf(">>>>>> q_format: %d\n", q_format); - printf(">>>>>> kv_format: %d\n", kv_format); - printf(">>>>>> layout_group: %d\n", layout_group); - printf(">>>>>> cudnn_runtime_version: %d\n", cudnn_runtime_version); - printf(">>>>>> is_training: %d\n", is_training); - printf(">>>>>> bias_type: %d\n", bias_type); - printf(">>>>>> attn_mask_type: %d\n", attn_mask_type); + printf(">>>>>> nvte_get_fused_attn_backend qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + printf(">>>>>> nvte_get_fused_attn_backend q_dtype: %d, %d, %d\n", q_dtype, NVTEDType::kNVTEFloat8E4M3, NVTEDType::kNVTEFloat8E5M2); + printf(">>>>>> nvte_get_fused_attn_backend qkv_format: %d, %d, %d\n", qkv_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); + printf(">>>>>> nvte_get_fused_attn_backend q_format: %d, %d, %d\n", q_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); + printf(">>>>>> nvte_get_fused_attn_backend kv_format: %d, %d, %d\n", kv_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); + printf(">>>>>> nvte_get_fused_attn_backend layout_group: %d, %d, %d\n", layout_group, NVTE_QKV_Layout_Group::NVTE_SD_SD_SD, NVTE_QKV_Layout_Group::NVTE_HD_HD_HD); + printf(">>>>>> nvte_get_fused_attn_backend cudnn_runtime_version: %d\n", cudnn_runtime_version); + printf(">>>>>> nvte_get_fused_attn_backend is_training: %d\n", is_training); + printf(">>>>>> nvte_get_fused_attn_backend bias_type: %d\n", bias_type); + printf(">>>>>> nvte_get_fused_attn_backend attn_mask_type: %d, %d, %d\n", attn_mask_type, NVTE_Mask_Type::NVTE_NO_MASK, NVTE_Mask_Type::NVTE_CAUSAL_MASK); if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && @@ -270,7 +270,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == 
NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHSD) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { @@ -531,6 +530,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } else { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; } + printf(">>>>>> nvte_get_fused_attn_backend fused_attention_backend: %d\n", backend); return backend; } @@ -1202,7 +1202,6 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, return_max_logit, cuda_graph, false); - printf(">>>>>> fused_attention_backend: %d\n", fused_attention_backend); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f7698da5c3..71d86843b5 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1681,21 +1681,13 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d: %d\n", b, h, hg, s_q, s_kv, d); printf(">>>>>> scaling_mode: %d\n", scaling_mode); printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); printf(">>>>>> 
is_current_scaling: %d\n", is_current_scaling); printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); printf(">>>>>> o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); - printf(">>>>>> cudnn_frontend::DataType_t::UINT8: %d\n", cudnn_frontend::DataType_t::UINT8); - printf(">>>>>> cudnn_frontend::DataType_t::INT8: %d\n", cudnn_frontend::DataType_t::INT8); - printf(">>>>>> cudnn_frontend::DataType_t::HALF: %d\n", cudnn_frontend::DataType_t::HALF); - printf(">>>>>> cudnn_frontend::DataType_t::INT64: %d\n", cudnn_frontend::DataType_t::INT64); - printf(">>>>>> cudnn_frontend::DataType_t::DOUBLE: %d\n", cudnn_frontend::DataType_t::DOUBLE); - printf(">>>>>> bias_type: %d\n", bias_type); - printf(">>>>>> mask_type: %d\n", mask_type); - printf(">>>>>> scaling_factor: %f\n", scaling_factor); - printf(">>>>>> dropout_probability: %f\n", dropout_probability); try { FADescriptor_v1 descriptor{b, @@ -1777,11 +1769,9 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr bias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; - printf(">>>>>> layout: %d\n", layout); std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d: %d\n", b, h, hg, s_q, s_kv, d); generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, @@ -1979,16 +1969,10 @@ void fused_attn_fp8_fwd_impl_v1( : std::make_tuple(nullptr, nullptr); NVTE_CHECK_CUDNN_FE(mha_graph->validate()); - printf(">>>>>> mha_graph->validate()\n"); NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); - printf(">>>>>> mha_graph->build_operation_graph(handle)\n"); 
NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); - printf(">>>>>> mha_graph->create_execution_plans({fe::HeurMode_t::A})\n"); NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); - printf(">>>>>> mha_graph->check_support(handle)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - printf(">>>>>> mha_graph->build_plans(handle)\n"); - printf(">>>>>> mha_graph->get_workspace_size(): %zu\n", mha_graph->get_workspace_size()); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index da1bf03b02..950d67155b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -120,7 +120,7 @@ def float8_block_scaling(cls): @classmethod def custom(cls): """Whether the given recipe is custom.""" - return isinstance(self, CustomRecipe) + return issubclass(cls, CustomRecipe) @dataclass() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 5cc23eabd8..47f7e0f222 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -29,6 +29,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.quantized_tensor import ( QuantizedTensorStorage, prepare_for_saving, @@ -174,15 +175,23 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] ] + assert qkv_layout == "sbhd_sbhd_sbhd", "sbhd_sbhd_sbhd is assumed to be the shape always at this point in 
UnfusedDotProductAttention." q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, query_layer, key_layer, value_layer, quantizer ) tensors = combine_and_dequantize( - qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype + qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype, des_nominal_dtype=query_layer.dtype ) + if isinstance(quantizer, MXFP8Quantizer): + assert qkv_layout == "bhsd_bhsd_bhsd", "bhsd_bhsd_bhsd is assumed to be the shape always at this point in UnfusedDotProductAttention." + # permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: - t_fp8 = quantizer(tensor1) - tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + if quantizer is not None: + t_fp8 = quantizer(tensor1) + tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + else: + tensors = (tensor1, tensor2, tensor3) else: tensors = (tensor1, tensor2, tensor3) ctx.quantizer = quantizer @@ -376,6 +385,7 @@ def forward( query_layer.shape[0], key_layer.shape[0], ) + apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 if "padding" in attn_mask_type and attention_mask is None: attention_mask = dpa_utils.get_padding_mask( @@ -402,9 +412,6 @@ def forward( ) ) - batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] - apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 - # [b, np, sq, sk] output_size = ( query_layer.size(1), @@ -424,11 +431,6 @@ def forward( int(query_layer.shape[2] / value_layer.shape[2]), dim=2 ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) - # preallocting result tensor: [b * np, sq, sk] matmul_result = torch.empty( 
output_size[0] * output_size[1], @@ -447,6 +449,11 @@ def forward( QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, quantizers) ) + # disable swizzle for MXFP8Quantizer + for q in [QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer]: + if isinstance(q, MXFP8Quantizer): + q.optimize_for_gemm = False + q.internal = False # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: @@ -459,18 +466,21 @@ def forward( fp8_dtype=dP_quantizer.dtype, device="cuda" ) - if "2" in qkv_layout or "3" in qkv_layout: - qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout) - qkv_layout = "_".join([qkv_format] * 3) + # q, k, v are in sbhd after previous reshaping # quantize and dequantize QKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout + query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", "sbhd_sbhd_sbhd" ) # quantize and dequantize dQKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout + query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", "sbhd_sbhd_sbhd" ) + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) + # Raw attention scores. 
[b * np, sq, sk] if core_attention_bias_type == "no_bias": matmul_result = torch.baddbmm( @@ -600,14 +610,14 @@ def forward( context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - context_layer = context_layer.view(seqlen, batch_size, -1) + context_layer = context_layer.view(max_seqlen_q, batch_size, -1) if q_format == "bshd": # [b, np, sq, hn] --> [b, sq, np, hn] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # [b, sq, np, hn] --> [b, sq, hp] - context_layer = context_layer.view(batch_size, seqlen, -1) + context_layer = context_layer.view(batch_size, max_seqlen_q, -1) if q_format == "thd": # [b, np, sq, hn] --> [b, sq, np, hn] @@ -1207,6 +1217,9 @@ def forward( # whether bwd kernel in FP8: is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # save original qkv_layout + original_qkv_layout = qkv_layout + # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, quantizers) @@ -1217,20 +1230,18 @@ def forward( # FP8 attention: torch.float16 or torch.bfloat16 out_nominal_dtype = q.dtype - # save original qkv_layout - original_qkv_layout = qkv_layout - max_logit = None if fp8: fused_attention_backend = FusedAttnBackend["FP8"] # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E4M3 + # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; + # dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(original_qkv_layout, q, k, v, QKV_quantizer) # print quantizers print_quantizers( @@ -1248,6 +1259,7 @@ def forward( # DelayedScaling: Float8Tensor; 
dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_, aux_ctx_tensors, *_ = fused_attn_fwd( is_training, max_seqlen_q, @@ -1280,27 +1292,20 @@ def forward( cuda_graph=is_graph_capturing(), ) if original_qkv_layout != qkv_layout: - print(f">>>>>>>>>>>> original_qkv_layout: {original_qkv_layout}") - print(f">>>>>>>>>>>> qkv_layout: {qkv_layout}") - print(f">>>>>>>>>>>> out_.shape: {out_.shape}") original_qkv_format = original_qkv_layout.split("_")[0] new_qkv_format = qkv_layout.split("_")[0] perm = [] for i in new_qkv_format: perm.append(original_qkv_format.find(i)) out_ = out_.permute(*perm).contiguous() - print(f">>>>>>>>>>>> out_.shape permuted: {out_.shape}") - # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ print(f"out_: {type(out_)} {out_.shape}") - print(f"is_output_fp8: {is_output_fp8}") - print(f"is_bwd_fp8: {is_bwd_fp8}") - print(f"fp8_recipe.float8_current_scaling(): {fp8_recipe.float8_current_scaling()}") - print(f"_dpa_fp8_cs_o_in_f16: {_dpa_fp8_cs_o_in_f16}") + print(f"is_output_fp8: {is_output_fp8}, is_bwd_fp8: {is_bwd_fp8}, fp8_recipe.float8_current_scaling(): {fp8_recipe.float8_current_scaling()}, _dpa_fp8_cs_o_in_f16: {_dpa_fp8_cs_o_in_f16}") if isinstance(out_, Float8Tensor) or isinstance(out_, MXFP8Tensor): print(f"dequantizing out_") if not is_output_fp8 or not is_bwd_fp8: @@ -1451,6 +1456,7 @@ def forward( else: ctx.qkv_layout = qkv_layout else: + ctx.original_qkv_layout = original_qkv_layout ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type @@ -1539,6 +1545,14 @@ def backward(ctx, d_out, *_args): # FP8 attention: torch.float16 or 
torch.bfloat16 dqkv_nominal_dtype = ctx.nominal_dtype + if ctx.original_qkv_layout != ctx.qkv_layout: + original_qkv_format = ctx.original_qkv_layout.split("_")[0] + new_qkv_format = ctx.qkv_layout.split("_")[0] + perm = [] + for i in original_qkv_format: + perm.append(new_qkv_format.find(i)) + d_out = d_out.permute(*perm).contiguous() + if ctx.fp8: # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1567,20 +1581,20 @@ def backward(ctx, d_out, *_args): # fp8_dtype = tex.DType.kFloat8E4M3 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 - # out_: - # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # DelayedScaling: + # out_, dq_, dk_, dv_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # - # dq_, dk_, dv_: - # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E5M2 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - out_ = ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8 - ) + # Float8CurrentScaling: + # out_, dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: + # out_, dq_, dk_, dv_, d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_ = out_fp8 + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + out_ = out + if ctx.fp8_recipe.mxfp8_block_scaling(): + out_ = out + aux_ctx_tensors.append(d_out) + dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1615,8 +1629,8 @@ def backward(ctx, d_out, *_args): # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ - is_float8tensor = isinstance(dq_, Float8Tensor) - if 
is_float8tensor and not ctx.is_input_fp8: + is_quantized_tensor = isinstance(dq_, QuantizedTensor) + if is_quantized_tensor and not ctx.is_input_fp8: # return in F16 dq, dk, dv = combine_and_dequantize( ctx.qkv_layout, @@ -1627,7 +1641,7 @@ def backward(ctx, d_out, *_args): ) if not is_float8tensor and ctx.is_input_fp8: # return in FP8 - dq, dk, dv, ctx.qkv_layout = combine_and_quantize( + dq, dk, dv, _ = combine_and_quantize( ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index a5931188dc..244f24111d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1392,7 +1392,7 @@ def forward( # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_f16 = q - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] # print quantizers diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index eb905d7b93..55553d30be 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -583,9 +583,8 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # global recipe set in autocast() fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - print(f"fp8_recipe: {fp8_recipe}") - # if fp8_recipe.custom(): - # return + if fp8_recipe.custom(): + return # switch/append recipe: fp8_recipe stays unchanged, but 
DPA.fp8_meta["recipe"] may be set to # a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index c5ba652c28..d3c2e01814 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -35,12 +35,14 @@ META_DP, ) from transformer_engine.pytorch.attention.inference import InferenceParams +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer, ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor +from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -2099,6 +2101,8 @@ def get_attention_quantizers(fp8, quantizers): O_quantizer = quantizers["scaling_fwd"][META_O] O_quantizer.set_usage(rowwise=True, columnwise=False) if isinstance(QKV_quantizer, MXFP8Quantizer): + QKV_quantizer.optimize_for_gemm = True + # QKV_quantizer.internal = False S_quantizer = None else: S_quantizer = quantizers["scaling_fwd"][META_S] @@ -2184,67 +2188,49 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_layout = qkv_layout.replace("paged_kv_", "") - qkv_format, _, _ = get_qkv_format(qkv_layout) + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype - print(f"Combining and quantizing q, 
k, v: {q.shape}, {k.shape}, {v.shape}") if isinstance(qkv_quantizer, MXFP8Quantizer): - # bs3hd -> bshd_bshd_bhsd - q,k,v = [x.contiguous() for x in [q, k, v]] - print(f">>>>>>>>>>>> Contiguous shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") - - # bshd_bshd_bhsd -> bhsd_bhsd_bhsd - # thd_thd_thd -> htd_htd_htd - qkv_quantizer.optimize_for_gemm = True - qkv_quantizer.internal = False - print(f">>>>>>>>>>>> qkv_quantizer.internal: {qkv_quantizer.internal}") - dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") - def permute_x(x): - dim_others = [i for i in range(len(x.shape)) if i != dim_s] - perm = [*dim_others[:-1], dim_s, dim_others[-1]] + print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") + + def permute_x(f, x): + x = x.contiguous() if not x.is_contiguous() else x + dim_s_dim_t = f.find("s") if 's' in f else f.find("t") + dim_others = [i for i in range(len(x.shape)) if i != dim_s_dim_t] + perm = [*dim_others[:-1], dim_s_dim_t, dim_others[-1]] x = x.permute(*perm).contiguous() return x - q, k, v = [permute_x(x) for x in [q, k, v]] + + # bs3hd, sb3hd, etc -> bshd_bshd_bhsd -> bhsd_bhsd_bhsd + # t3hd, etc -> thd_thd_thd -> htd_htd_htd + if q_format not in ["bhsd", "htd"]: + q = permute_x(q_format, q) + if kv_format not in ["bhsd", "htd"]: + k = permute_x(kv_format, k) + v = permute_x(kv_format, v) print(f">>>>>>>>>>>> Permuted shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") + qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" - original_shapes = [q.shape, k.shape, v.shape] - b, h_q, s_q, d_qk = q.shape - _, h_kv, s_kv, d_kv = v.shape - assert k.shape == (b, h_kv, s_kv, d_qk) + original_shapes = [x.shape for x in [q, k, v]] + s_q, d_qk = q.shape[-2:] + s_kv, d_kv = v.shape[-2:] assert s_q % 128 == 0 assert s_kv % 128 == 0 assert d_qk % 32 == 0 assert d_kv % 32 == 0 + # need to check seqlens in THD % 128 == 0 q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] print(f">>>>>>>>>>>> 
Flattened shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") # consider bhsd for now grouped_tensor = GroupedTensor.create_and_quantize(tensors=[q, k, v], quantizer=qkv_quantizer) - print(f">>>>>>>>>>>> grouped_tensor: {type(grouped_tensor)}") - print(f">>>>>>>>>>>> grouped_tensor.quantized_tensors: {type(grouped_tensor.quantized_tensors)}") q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors + q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") - q_fp8, k_fp8, v_fp8 = [x.view(*original_shapes[i]) for i, x in enumerate([q_fp8, k_fp8, v_fp8])] - print(f">>>>>>>>>>>> q_fp8: {q_fp8.shape}, k_fp8: {k_fp8.shape}, v_fp8: {v_fp8.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") - # dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") - # dim_others = [i for i in range(len(v.shape)) if i != dim_s] - # perm = [*dim_others, dim_s] 
- # # perm = [*dim_others[:-1], dim_s, dim_others[-1]] - # v = v.permute(*perm).contiguous() - - qkv_layout = "bhsd_bhsd_bhsd" - - # inv = [0] * len(perm) - # for i, p in enumerate(perm): - # inv[p] = i - # # v = v.permute(*inv) return q_fp8, k_fp8, v_fp8, qkv_layout @@ -2296,14 +2282,20 @@ def combine_and_dequantize( """Combine q,k,v based on qkv_layout and dequantize them together""" # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) qkv_group = len(qkv_layout.split("_")) - if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]): + if all(isinstance(x, QuantizedTensor) for x in [q_fp8, k_fp8, v_fp8]): src_nominal_dtype = q_fp8.dtype else: assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!" if des_nominal_dtype is None: des_nominal_dtype = src_nominal_dtype + if all(isinstance(x, (MXFP8Tensor, MXFP8TensorStorage)) for x in [q_fp8, k_fp8, v_fp8]): + print(f"Combining and dequantizing q, k, v from MXFP8 to {des_nominal_dtype}") + q, k, v = [x.dequantize(dtype=des_nominal_dtype) for x in [q_fp8, k_fp8, v_fp8]] + return q, k, v + q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] match qkv_group: case 1: diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index b4811eb4f6..09953440e9 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -16,7 +16,6 @@ NVTE_Fused_Attn_Backend, ) from ..quantized_tensor import Quantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer __all__ = [ diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index e4c658ed58..6c72d74531 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ 
b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -75,8 +75,7 @@ def update_quantized( src = src.contiguous() # Launch cast kernel - print(f">>>>>>>>>>>> src: {src.shape}") - print(f">>>>>>>>>>>> dst: {dst.shape}") + print(f"MXFP8Quantizer.update_quantized: src: {src.shape}, dst: {dst.shape}") tex.quantize(src, self, dst, noop_flag) # Update FP8 dtype @@ -86,7 +85,7 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - print(f"Quantizing tensor: {tensor.shape}") + print(f"MXFP8Quantizer.quantize_impl: tensor: {tensor.shape}") return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index 9771d61df8..5a8d323983 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -91,24 +91,6 @@ def __init__( offsets: Vector of integer offsets for each tensor. 
logical_shape: 2D tuple representing conceptual shape """ - # print(f">>>>>>>>>>>> GroupedTensor init: {quantizers}") - # print(f">>>>>>>>>>>> shape: {shape}") - # print(f">>>>>>>>>>>> dtype: {dtype}") - # print(f">>>>>>>>>>>> data: {data.shape}") - # print(f">>>>>>>>>>>> columnwise_data: {columnwise_data.shape if columnwise_data is not None else None}") - # print(f">>>>>>>>>>>> scale_inv: {scale_inv.shape if scale_inv is not None else None}") - # print(f">>>>>>>>>>>> columnwise_scale_inv: {columnwise_scale_inv.shape if columnwise_scale_inv is not None else None}") - # print(f">>>>>>>>>>>> amax: {amax.shape if amax is not None else None}") - # print(f">>>>>>>>>>>> columnwise_amax: {columnwise_amax.shape if columnwise_amax is not None else None}") - # print(f">>>>>>>>>>>> scale: {scale.shape if scale is not None else None}") - # print(f">>>>>>>>>>>> first_dims: {first_dims.shape}") - # print(f">>>>>>>>>>>> last_dims: {last_dims.shape}") - # print(f">>>>>>>>>>>> tensor_offsets: {tensor_offsets.shape}") - # print(f">>>>>>>>>>>> offsets: {offsets.shape}") - # print(f">>>>>>>>>>>> scale_inv_offsets: {scale_inv_offsets.shape}") - # print(f">>>>>>>>>>>> columnwise_scale_inv_offsets: {columnwise_scale_inv_offsets.shape}") - # print(f">>>>>>>>>>>> logical_shape: {logical_shape.shape}") - print(f">>>>>>>>>>>> num_tensors: {num_tensors}") self.num_tensors = num_tensors self.quantizer = quantizer @@ -519,7 +501,7 @@ def make_grouped_tensor( # For simplicity, calculate total scale elements needed total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): scale_inv_shape = quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) if i < num_tensors - 1: @@ -554,7 +536,7 @@ def make_grouped_tensor( # For simplicity, calculate total scale elements needed total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): scale_inv_shape = 
quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) if i < num_tensors - 1: @@ -934,7 +916,7 @@ def create_and_quantize( Quantize given tensors into quantized tensors with underlying storage allocated in a GroupedTensor. """ - + print(f">>>>>>>>>>>> GroupedTensor create_and_quantize") grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=len(tensors), shapes=[t.shape for t in tensors], @@ -960,15 +942,6 @@ def create_and_quantize( _ = tex.quantize_grouped(grouped_input, grouped_output) grouped_output.quantized_tensors = grouped_output.split_into_quantized_tensors() - # grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( - # num_tensors=len(tensors), - # shapes=[t.shape for t in tensors], - # quantizer=None, - # device=device, - # dtype=tensors[0].dtype, - # ) - # grouped_tensor.quantize(tensors, noop_flag=noop_flag) - return grouped_output def quantize( @@ -979,10 +952,8 @@ def quantize( """ Quantize the GroupedTensor inplace. 
""" - + print(f">>>>>>>>>>>> GroupedTensor quantize") self.quantized_tensors = self.split_into_quantized_tensors() - print(f">>>>>>>>>>>> tensors[0]: {type(tensors[0])}") - print(f">>>>>>>>>>>> quantized_tensors[0]: {type(self.quantized_tensors[0])}") for i in range(self.num_tensors): self.quantizer.update_quantized(tensors[i], self.quantized_tensors[i], noop_flag=noop_flag) return self.quantized_tensors \ No newline at end of file From cd06398d2c57d021c31330318eb40ca8567578d4 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 12 Feb 2026 18:54:12 -0800 Subject: [PATCH 010/172] split d to d_qk/d_v; attempt at bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 30 +- .../common/fused_attn/fused_attn_fp8.cu | 434 +++++++++++++----- .../common/fused_attn/fused_attn_fp8.h | 6 +- 3 files changed, 354 insertions(+), 116 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 1adabcded2..98ff96b666 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -639,7 +639,7 @@ void nvte_fused_attn_fwd_qkvpacked( Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, d, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); @@ -772,6 +772,10 @@ void nvte_fused_attn_bwd_qkvpacked( const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = 
convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + const Tensor *input_dO_f16; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); + } // Unpack QKV and dQKV and call the non-packed function const auto QKV_type = input_QKV->data.dtype; @@ -787,8 +791,8 @@ void nvte_fused_attn_bwd_qkvpacked( Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, + fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); @@ -945,7 +949,7 @@ void nvte_fused_attn_fwd_kvpacked( Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1090,6 +1094,10 @@ void nvte_fused_attn_bwd_kvpacked( const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + const Tensor 
*input_dO_f16; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); + } // Unpack KV and dKV and call the non-packed function const auto Q_type = input_Q->data.dtype; @@ -1104,9 +1112,9 @@ void nvte_fused_attn_bwd_kvpacked( Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, + input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1228,7 +1236,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1340,9 +1348,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, 
attn_scale, dropout, + const Tensor *input_dO_f16; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); + } + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, + input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 71d86843b5..d9af04c628 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1652,7 +1652,7 @@ void fused_attn_fp8_bwd_impl( // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, @@ -1681,7 +1681,7 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); - printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d: %d\n", b, h, hg, s_q, s_kv, d); + printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); printf(">>>>>> scaling_mode: 
%d\n", scaling_mode); printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); @@ -1695,8 +1695,8 @@ void fused_attn_fp8_fwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -1772,36 +1772,39 @@ void fused_attn_fp8_fwd_impl_v1( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); // need to double check printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); int32_t block_size = 32; - int64_t d_scale = (d + block_size - 1) / block_size; + int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; + int64_t d_v_scale = (d_v + block_size - 1) / block_size; int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; int64_t s_q_padded = ((s_q + 127) / 128) * 128; int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; - int64_t d_scale_padded = ((d_scale + 3) / 4) * 4; + int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; + int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - int64_t d_padded = ((d + 3) / 4) * 4; - printf(">>>>>> d_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_scale_padded: %d, s_kv_scale_padded: %d, d_padded: %d\n", d_scale, s_kv_scale, s_q_padded, 
s_kv_padded, d_scale_padded, s_kv_scale_padded, d_padded); - std::vector q_scale_dims = {b, h, s_q_padded, d_scale_padded}; - std::vector k_scale_dims = {b, hg, s_kv_padded, d_scale_padded}; - std::vector v_scale_dims = {b, hg, s_kv_scale_padded, d_padded}; + int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; + int64_t d_v_padded = ((d_v + 3) / 4) * 4; + printf(">>>>>> d_qk_scale: %d, d_v_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_qk_scale_padded: %d, d_v_scale_padded: %d, s_kv_scale_padded: %d, d_qk_padded: %d, d_v_padded: %d\n", d_qk_scale, d_v_scale, s_kv_scale, s_q_padded, s_kv_padded, d_qk_scale_padded, d_v_scale_padded, s_kv_scale_padded, d_qk_padded, d_v_padded); + std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; + std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; + std::vector v_scale_dims = {b, hg, s_kv_scale_padded, d_v_padded}; std::vector q_scale_strides(4); std::vector k_scale_strides(4); std::vector v_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_scale_padded, q_scale_strides.data(), layout, + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_scale_padded, k_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, d_padded, v_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); @@ -1809,17 
+1812,17 @@ void fused_attn_fp8_fwd_impl_v1( Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_data_type(qkv_tensor_type)); @@ -1931,9 +1934,9 @@ void fused_attn_fp8_fwd_impl_v1( } std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type); + O->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2050,7 +2053,7 @@ void fused_attn_fp8_fwd_impl_v1( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, @@ -2058,10 +2061,11 @@ void fused_attn_fp8_bwd_impl_v1( void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* 
devPtrAmaxdK, void* devPtrAmaxdV, + void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size, + cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -2075,13 +2079,24 @@ void fused_attn_fp8_bwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || - dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || - dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == 
cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); + printf(">>>>>> scaling_mode: %d\n", scaling_mode); + printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); + printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); + printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); + printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf(">>>>>> o_tensor_type: %d, %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); + printf(">>>>>> do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); + printf(">>>>>> dqkv_tensor_type: %d, %d, %d, %d\n", dqkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); @@ -2091,8 +2106,8 @@ void fused_attn_fp8_bwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -2122,17 +2137,24 @@ void fused_attn_fp8_bwd_impl_v1( using graph_and_tensors = std::tuple, std::shared_ptr, // q + std::shared_ptr, // q_t std::shared_ptr, // k + std::shared_ptr, // k_t std::shared_ptr, // v std::shared_ptr, // o std::shared_ptr, // stats std::shared_ptr, // dO + std::shared_ptr, // dO_t + std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q + std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k + std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o 
std::shared_ptr, // descale_dO + std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2181,40 +2203,45 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr scale_dQ, scale_dK, scale_dV; std::shared_ptr bias, dBias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; + std::shared_ptr q_t, k_t, dO_t, dO_f16, descale_q_t, descale_k_t, descale_dO_t; std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride) + .set_data_type(qkv_tensor_type)); k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_stride) + .set_data_type(qkv_tensor_type)); v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_stride) + .set_data_type(qkv_tensor_type)); o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) 
.set_stride(o_stride) .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); + .set_dim({b, h, s_q, d_qk}) + .set_stride(o_stride) + .set_data_type(do_tensor_type)); stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("stats") .set_dim({b, h, s_q, 1}) @@ -2228,33 +2255,151 @@ void fused_attn_fp8_bwd_impl_v1( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); - if (is_O_in_F16) { - descale_o = mha_graph->tensor(1.0f); + if (!is_mxfp8) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); + descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + if (is_O_in_F16) { + descale_o = mha_graph->tensor(1.0f); + } else { + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + } + descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); + scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); + + if (is_delayed_scaling) { + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + } + if (is_current_scaling) { 
+ scale_dQ = mha_graph->tensor(1.0f); + scale_dK = mha_graph->tensor(1.0f); + scale_dV = mha_graph->tensor(1.0f); + } } else { - descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); - } - descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - - if (is_delayed_scaling) { - scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); - scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); - scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); - } - if (is_current_scaling) { - scale_dQ = mha_graph->tensor(1.0f); - scale_dK = mha_graph->tensor(1.0f); - scale_dV = mha_graph->tensor(1.0f); + std::vector q_t_stride(4); + std::vector k_t_stride(4); + std::vector dO_t_stride(4); + generateMatrixStrides(b, h, d_qk, s_kv, s_q, q_t_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_t_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); + generateMatrixStrides(b, h, d_qk, s_kv, s_q, dO_t_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix); + int32_t block_size = 32; + int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; + int64_t d_v_scale = (d_v + block_size - 1) / block_size; + int64_t s_q_scale = (s_q + block_size - 1) / block_size; + int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; + int64_t s_q_padded = ((s_q + 127) / 128) * 128; + int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; + int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; + int64_t d_v_padded = ((d_v + 3) / 4) * 4; + int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; + int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; + int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; + printf(">>>>>> d_qk_scale: %d, d_v_scale: %d, s_q_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_qk_scale_padded: %d, d_v_scale_padded: 
%d, s_q_scale_padded: %d, s_kv_scale_padded: %d, d_qk_padded: %d, d_v_padded: %d\n", d_qk_scale, d_v_scale, s_q_scale, s_kv_scale, s_q_padded, s_kv_padded, d_qk_scale_padded, d_v_scale_padded, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, d_v_padded); + // std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; + // std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; + // std::vector v_scale_dims = {b, hg, s_kv_padded, d_v_scale_padded}; + // std::vector q_t_scale_dims = {b, h, s_q_scale_padded, d_qk_padded}; + // std::vector k_t_scale_dims = {b, hg, s_kv_scale_padded, d_qk_padded}; + // // std::vector dO_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; + // // std::vector dO_t_scale_dims = {b, h, s_q_scale_padded, d_qk_padded}; + std::vector q_scale_strides(4); + std::vector k_scale_strides(4); + std::vector v_scale_strides(4); + std::vector q_t_scale_strides(4); + std::vector k_t_scale_strides(4); + // std::vector dO_scale_strides(4); + // std::vector dO_t_scale_strides(4); + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, h, d_qk_padded, s_kv_scale_padded, s_q_scale_padded, q_t_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, k_t_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); + printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], 
k_scale_strides[2], k_scale_strides[3]); + printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); + printf(">>>>>> q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], q_t_scale_strides[2], q_t_scale_strides[3]); + printf(">>>>>> k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], k_t_scale_strides[2], k_t_scale_strides[3]); + + q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_t_stride) + .set_data_type(qkv_tensor_type)); + k_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K_t") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_t_stride) + .set_data_type(qkv_tensor_type)); + dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(dO_t_stride) + .set_data_type(do_tensor_type)); + dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_f16") + .set_dim({b, h, s_q, d_qk}) + .set_stride(dO_t_stride) + .set_data_type(o_tensor_type)); + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({b, h, s_q_padded, d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q_t") + .set_dim({b, h, s_q_scale_padded, d_qk_padded}) + .set_stride(q_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, s_kv_padded, d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k_t = 
mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k_t") + .set_dim({b, hg, s_kv_scale_padded, d_qk_padded}) + .set_stride(k_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, s_kv_padded, d_v_scale_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO") + .set_dim({b, h, s_q_padded, d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO_t") + .set_dim({b, h, s_q_scale_padded, d_qk_padded}) + .set_stride(q_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; @@ -2312,14 +2457,18 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(fe::DataType_t::INT64)); sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - + // if (!is_mxfp8) { auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); - - dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); - dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); - dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); + // } else { + // auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV] = mha_graph->sdpa_fp8_backward( + // q, q_t, k, k_t, v, o, dO_f16, dO, dO_t, stats, descale_q, 
descale_q_t, descale_k, descale_k_t, descale_v, descale_dO, descale_dO_t, + // sdpa_backward_options); + // } + dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride).set_data_type(dqkv_tensor_type); + dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride).set_data_type(dqkv_tensor_type); + dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride).set_data_type(dqkv_tensor_type); amax_dQ->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2332,28 +2481,36 @@ void fused_attn_fp8_bwd_impl_v1( .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); + if (!is_mxfp8) { amax_dP->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - - dO->set_data_type(do_tensor_type); - dQ->set_data_type(dqkv_tensor_type); - dK->set_data_type(dqkv_tensor_type); - dV->set_data_type(dqkv_tensor_type); + } + // dO->set_data_type(do_tensor_type); + // dQ->set_data_type(dqkv_tensor_type); + // dK->set_data_type(dqkv_tensor_type); + // dV->set_data_type(dqkv_tensor_type); std::tuple, // q + // std::shared_ptr, // q_t std::shared_ptr, // k + // std::shared_ptr, // k_t std::shared_ptr, // v std::shared_ptr, // o std::shared_ptr, // stats std::shared_ptr, // dO + // std::shared_ptr, // dO_t + // std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q + // std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k + // std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO + // std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2372,6 +2529,8 @@ void fused_attn_fp8_bwd_impl_v1( q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, 
amax_dP); + auto mxfp8_tensors_tuple = is_mxfp8 ? std::make_tuple(q_t, k_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); + // key_tensors_tuple = std::tuple_cat(key_tensors_tuple, mxfp8_tensors_tuple); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); @@ -2385,17 +2544,64 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - padding_tuple, dropout_tuple); + padding_tuple, dropout_tuple, mxfp8_tensors_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - - auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, - descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, bias, dBias, seq_q, seq_kv, dropout_seed, - dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); - + auto bprop_tuple = get_graph(sdpa_fp8_bprop_cache, descriptor); + // if (!is_mxfp8) { + // auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + // descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, + // dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, bias, dBias, seq_q, seq_kv, dropout_seed, + // dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + auto mha_graph = std::get<0>(bprop_tuple); + auto q = std::get<1>(bprop_tuple); + auto k = std::get<2>(bprop_tuple); + auto v = std::get<3>(bprop_tuple); + auto o = std::get<4>(bprop_tuple); + auto stats = std::get<5>(bprop_tuple); + auto dO = std::get<6>(bprop_tuple); + auto attn_scale = std::get<7>(bprop_tuple); + auto descale_q = std::get<8>(bprop_tuple); + 
auto descale_k = std::get<9>(bprop_tuple); + auto descale_v = std::get<10>(bprop_tuple); + auto descale_o = std::get<11>(bprop_tuple); + auto descale_dO = std::get<12>(bprop_tuple); + auto descale_s = std::get<13>(bprop_tuple); + auto descale_dP = std::get<14>(bprop_tuple); + auto scale_s = std::get<15>(bprop_tuple); + auto scale_dQ = std::get<16>(bprop_tuple); + auto scale_dK = std::get<17>(bprop_tuple); + auto scale_dV = std::get<18>(bprop_tuple); + auto scale_dP = std::get<19>(bprop_tuple); + auto dQ = std::get<20>(bprop_tuple); + auto dK = std::get<21>(bprop_tuple); + auto dV = std::get<22>(bprop_tuple); + auto amax_dQ = std::get<23>(bprop_tuple); + auto amax_dK = std::get<24>(bprop_tuple); + auto amax_dV = std::get<25>(bprop_tuple); + auto amax_dP = std::get<26>(bprop_tuple); + auto bias = std::get<27>(bprop_tuple); + auto dBias = std::get<28>(bprop_tuple); + auto seq_q = std::get<29>(bprop_tuple); + auto seq_kv = std::get<30>(bprop_tuple); + auto dropout_seed = std::get<31>(bprop_tuple); + auto dropout_offset = std::get<32>(bprop_tuple); + // } else { + // if (is_mxfp8) { + // auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + // descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, + // dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, q_t, k_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t, bias, dBias, seq_q, seq_kv, dropout_seed, + // dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + auto q_t = std::get<33>(bprop_tuple); + auto k_t = std::get<34>(bprop_tuple); + auto dO_f16 = std::get<35>(bprop_tuple); + auto dO_t = std::get<36>(bprop_tuple); + auto descale_q_t = std::get<37>(bprop_tuple); + auto descale_k_t = std::get<38>(bprop_tuple); + auto descale_dO_t = std::get<39>(bprop_tuple); + // } auto plan_workspace_size = mha_graph->get_workspace_size(); // Exit to request upper level API to allocate memory if needed @@ -2422,25 +2628,36 @@ void 
fused_attn_fp8_bwd_impl_v1( {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, {descale_dO, devPtrDescaledO}, - {descale_s, devPtrDescaleS}, - {descale_dP, devPtrDescaledP}, - {scale_s, devPtrScaleS}, - {scale_dP, devPtrScaledP}, {dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV}, {amax_dQ, devPtrAmaxdQ}, {amax_dK, devPtrAmaxdK}, {amax_dV, devPtrAmaxdV}, - {amax_dP, devPtrAmaxdP}, }; + if (!is_mxfp8) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[descale_dP] = devPtrDescaledP; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[scale_dP] = devPtrScaledP; + variant_pack[amax_dP] = devPtrAmaxdP; + } else { + variant_pack[q_t] = devPtrQ_t; + variant_pack[k_t] = devPtrK_t; + variant_pack[dO_f16] = devPtrdO_f16; + variant_pack[dO_t] = devPtrdO_t; + variant_pack[descale_q_t] = devPtrDescaleQ_t; + variant_pack[descale_k_t] = devPtrDescaleK_t; + variant_pack[descale_dO] = devPtrDescaledO; + variant_pack[descale_dO_t] = devPtrDescaledO_t; + } if (is_delayed_scaling) { variant_pack[scale_dQ] = devPtrScaledQ; variant_pack[scale_dK] = devPtrScaledK; variant_pack[scale_dV] = devPtrScaledV; } - if (!is_O_in_F16) { + if (is_current_scaling && !is_O_in_F16) { variant_pack[descale_o] = devPtrDescaleO; } @@ -2485,7 +2702,7 @@ void fused_attn_fp8_bwd_impl_v1( #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, @@ -2576,7 +2793,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == 
NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, @@ -2585,7 +2802,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, is_training, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, @@ -2609,11 +2826,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, - const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv, + const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* 
input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, @@ -2626,6 +2843,10 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrDescaleQ = input_Q->scale_inv.dptr; void* devPtrDescaleK = input_Q->scale_inv.dptr; void* devPtrDescaleV = input_Q->scale_inv.dptr; + void* devPtrQ_t = input_Q->columnwise_data.dptr; + void* devPtrK_t = input_K->columnwise_data.dptr; + void* devPtrDescaleQ_t = input_Q->columnwise_scale_inv.dptr; + void* devPtrDescaleK_t = input_K->columnwise_scale_inv.dptr; void* devPtrO = input_O->data.dptr; const DType O_type = input_O->data.dtype; @@ -2635,6 +2856,9 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; + void* devPtrdO_t = input_dO->columnwise_data.dptr; + void* devPtrdO_f16 = input_dO_f16->data.dptr; + void* devPtrDescaledO_t = input_dO->columnwise_scale_inv.dptr; void* devPtrM = input_M->data.dptr; void* devPtrZInv = input_ZInv->data.dptr; @@ -2672,18 +2896,20 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, 
devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, + devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index a1a932fdf5..f335bc3d85 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -15,7 +15,7 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, @@ -26,11 +26,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t 
num_attn_heads, size_t num_gqa_grou // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv, + const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, From 730a472af86c2a47d996035f9a8bd5e4c409c0f0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:55:59 -0800 Subject: [PATCH 011/172] fix last merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | 3 --- transformer_engine/common/common.h | 4 +--- transformer_engine/pytorch/tensor/storage/grouped_tensor.py | 4 ---- 3 files changed, 1 insertion(+), 10 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 64896810f8..7a8ab8062c 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -852,9 +852,6 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); } - const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; - 
printf(">>>>>>>>>>>> group_quantize_mxfp8 with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); - const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index b14653aca7..2d7f0e7e8c 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -355,9 +355,7 @@ struct GroupedTensor { last_dims(nullptr, std::vector{0}, DType::kInt64), tensor_offsets(nullptr, std::vector{0}, DType::kInt64), logical_shape(nvte_make_shape(nullptr, 1)), - scaling_mode(scaling_mode), - nvte_tensor(0), - with_gemm_swizzled_scales(false) {} + nvte_tensor(0) {} explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; } diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index b6c8818ab8..123dfcf22a 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -8,12 +8,8 @@ import math import torch -<<<<<<< HEAD import transformer_engine import transformer_engine_torch as tex -======= - ->>>>>>> main from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ..mxfp8_tensor import MXFP8Tensor From d9ff5662aa4b4b6267c77baf614aada6602fa133 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:56:47 -0800 Subject: [PATCH 012/172] update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 ++- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..8c7646c00d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,8 @@ url = https://github.com/google/googletest.git 
[submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://github.com/NVIDIA/cudnn-frontend.git + url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git + branch = develop [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 209a25fe89..4b4df2edcf 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 209a25fe898bb749c8605363a6431e26cb41b089 +Subproject commit 4b4df2edcf80b7cc7a659da5734454577154aa6d From 2b264d72f663707ccb923d7259603c53872306d3 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:00:51 -0800 Subject: [PATCH 013/172] attempt at SWA/MLA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 13 +++++-- .../common/fused_attn/fused_attn.cpp | 18 +++++---- .../common/fused_attn/fused_attn_fp8.cu | 39 ++++++++++++------- .../common/fused_attn/fused_attn_fp8.h | 4 +- .../attention/dot_product_attention/utils.py | 22 +++++------ 5 files changed, 57 insertions(+), 39 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 74deeceed2..05d76d96fe 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1789,7 +1789,7 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig(2, 2048, 16, 128), - "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), + "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512)), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), @@ -2259,7 
+2259,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: with quantized_model_init(enabled=fp8_dpa): dpa = DotProductAttention( config.num_heads, - config.head_dim_qk, + (config.head_dim_qk, config.head_dim_v), num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, sequence_parallel=False, @@ -2304,7 +2304,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim_qk, + "dqk": config.head_dim_qk, + "dv": config.head_dim_v, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -2320,6 +2321,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layout = layout.replace("s", "skv") layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") + if i == 2: + layout = layout.replace("d", "dv") + else: + layout = layout.replace("d", "dqk") tensor_shape = [dim_to_num[j] for j in layout.split("_")] if config.dropout_p == 0.0: tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") @@ -2344,6 +2349,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: qkv_format_kv = "_".join(qkv_format) qkv_format_kv = qkv_format_kv.replace("s", "sq") + qkv_format_kv = qkv_format_kv.replace("d", "dv") out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") @@ -2354,6 +2360,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: inp[1], inp[2], qkv_format=qkv_format, + window_size=config.window_size, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=config.max_seqlen_q, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 98ff96b666..6f343d90b2 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ 
b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -269,7 +269,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.21: mxfp8, d_qk=128, d_v=192 + (cudnn_runtime_version >= 92100 && head_dim_qk <= 192 && head_dim_v <= 128)) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { @@ -425,7 +427,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 90200 && ((window_size_left == -1 && window_size_right == -1 && attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || - ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && + ((window_size_left == -1 || window_size_left >= 0) && (window_size_right == -1 || window_size_right >= 0) && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && @@ -640,7 +642,7 @@ void nvte_fused_attn_fwd_qkvpacked( Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else @@ -792,7 +794,7 @@ void nvte_fused_attn_bwd_qkvpacked( Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); fused_attn_fp8_bwd(b, h, h, max_seqlen, 
max_seqlen, d, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, + bias_type, attn_mask_type, window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); @@ -950,7 +952,7 @@ void nvte_fused_attn_fwd_kvpacked( Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1113,7 +1115,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1237,7 +1239,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + dropout, qkv_layout, bias_type, attn_mask_type, 
window_size_left, window_size_right, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1353,7 +1355,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f80bf933f7..fdf78fcef3 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1654,7 +1654,7 @@ void fused_attn_fp8_bwd_impl( void fused_attn_fp8_fwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, @@ -1682,6 +1682,7 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK(is_delayed_scaling || 
is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); + printf(">>>>>> window_size_left: %d, window_size_right: %d\n", window_size_left, window_size_right); printf(">>>>>> scaling_mode: %d\n", scaling_mode); printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); @@ -1712,8 +1713,8 @@ void fused_attn_fp8_fwd_impl_v1( bias_type, mask_type, NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, + window_size_left, + window_size_right, true, true, qkv_tensor_type, @@ -1786,10 +1787,12 @@ void fused_attn_fp8_fwd_impl_v1( int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; int64_t d_v_scale = (d_v + block_size - 1) / block_size; int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_padded = ((s_q + 127) / 128) * 128; - int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; + int64_t s_q_scale = (s_q + block_size - 1) / block_size; + int64_t s_q_padded = ((s_q + 3) / 4) * 4; + int64_t s_kv_padded = ((s_kv + 3) / 4) * 4; int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; int64_t d_v_padded = ((d_v + 3) / 4) * 4; @@ -1804,7 +1807,7 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); printf(">>>>>> q_scale_strides: 
%d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); @@ -1876,6 +1879,8 @@ void fused_attn_fp8_fwd_impl_v1( .set_generate_stats(true) .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); + sdpa_options.set_diagonal_band_right_bound(window_size_right); // sdpa_options.set_alibi_mask(is_alibi); // if (is_bias) { @@ -2055,7 +2060,7 @@ void fused_attn_fp8_fwd_impl_v1( void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, @@ -2088,6 +2093,7 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); + printf(">>>>>> window_size_left: %d, window_size_right: %d\n", window_size_left, window_size_right); printf(">>>>>> scaling_mode: %d\n", scaling_mode); printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); @@ -2123,8 +2129,8 @@ void fused_attn_fp8_bwd_impl_v1( bias_type, 
mask_type, NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, + window_size_left, + window_size_right, true, false, qkv_tensor_type, @@ -2299,8 +2305,8 @@ void fused_attn_fp8_bwd_impl_v1( int64_t d_v_scale = (d_v + block_size - 1) / block_size; int64_t s_q_scale = (s_q + block_size - 1) / block_size; int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_padded = ((s_q + 127) / 128) * 128; - int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; + int64_t s_q_padded = ((s_q + 3) / 4) * 4; + int64_t s_kv_padded = ((s_kv + 3) / 4) * 4; int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; int64_t d_v_padded = ((d_v + 3) / 4) * 4; int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; @@ -2408,6 +2414,9 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + // sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); + // sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + // sdpa_backward_options.set_alibi_mask(is_alibi); // if (is_bias) { @@ -2705,7 +2714,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, + NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, @@ -2794,7 +2803,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { 
fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), @@ -2828,7 +2837,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, @@ -2897,7 +2906,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, 
devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index f335bc3d85..22800b2aa2 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -18,7 +18,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, + NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, @@ -28,7 +28,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, diff --git 
a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index d3c2e01814..873c101521 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -876,12 +876,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention for FP8" - ) - use_fused_attention = False - elif attention_dropout != 0.0: + if attention_dropout != 0.0: logger.debug( "Disabling FusedAttention as it only supports sliding window attention " "without dropout" @@ -1016,7 +1011,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_fused_attention and window_size is not None and window_size[0] != -1 - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] + and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] ): logger.debug( "Disabling FusedAttention as only sub-backend %s does not support " @@ -2214,18 +2209,23 @@ def permute_x(f, x): original_shapes = [x.shape for x in [q, k, v]] s_q, d_qk = q.shape[-2:] - s_kv, d_kv = v.shape[-2:] + s_kv, d_v = v.shape[-2:] assert s_q % 128 == 0 assert s_kv % 128 == 0 assert d_qk % 32 == 0 - assert d_kv % 32 == 0 + assert d_v % 32 == 0 # need to check seqlens in THD % 128 == 0 q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] print(f">>>>>>>>>>>> Flattened shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") # consider bhsd for now - grouped_tensor = GroupedTensor.create_and_quantize(tensors=[q, k, v], quantizer=qkv_quantizer) - q_fp8, k_fp8, v_fp8 = 
grouped_tensor.quantized_tensors + if d_qk == d_v: + grouped_tensor = GroupedTensor.create_and_quantize(tensors=[q, k, v], quantizer=qkv_quantizer) + q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors + else: + q_fp8 = qkv_quantizer(q) + k_fp8 = qkv_quantizer(k) + v_fp8 = qkv_quantizer(v) q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") From 2008bed824b69eb21650d146e18916f0d7f872e0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:13:40 -0800 Subject: [PATCH 014/172] remove prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/cast/cast.cu | 1 - .../common/cast/dispatch/quantize.cuh | 1 - .../common/fused_attn/fused_attn.cpp | 11 ------ .../common/fused_attn/fused_attn_fp8.cu | 34 ------------------- .../dot_product_attention/backends.py | 4 --- .../attention/dot_product_attention/utils.py | 8 ----- .../pytorch/csrc/extensions/cast.cpp | 1 - .../pytorch/csrc/type_converters.cpp | 1 - .../pytorch/tensor/mxfp8_tensor.py | 2 -- 9 files changed, 63 deletions(-) diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 624b0bfc7c..12d816f708 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -30,7 +30,6 @@ void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize); using namespace transformer_engine; - printf(">>>>>>>>>>>> nvte_group_quantize\n"); constexpr bool IS_ACT = false; dispatch::group_quantize_fwd_helper(input, output, nullptr, stream); } diff --git 
a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 9a6e9b01d6..b83df1dedf 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -375,7 +375,6 @@ void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs, template void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - printf(">>>>>>>>>>>> group_quantize_fwd_helper\n"); using namespace detail; NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 6f343d90b2..d58fca70e9 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -234,16 +234,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); const bool supported_ragged_offset_size = (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - printf(">>>>>> nvte_get_fused_attn_backend qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); - printf(">>>>>> nvte_get_fused_attn_backend q_dtype: %d, %d, %d\n", q_dtype, NVTEDType::kNVTEFloat8E4M3, NVTEDType::kNVTEFloat8E5M2); - printf(">>>>>> nvte_get_fused_attn_backend qkv_format: %d, %d, %d\n", qkv_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); - printf(">>>>>> nvte_get_fused_attn_backend q_format: %d, %d, %d\n", q_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); - printf(">>>>>> nvte_get_fused_attn_backend kv_format: %d, %d, %d\n", kv_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); - printf(">>>>>> nvte_get_fused_attn_backend layout_group: %d, %d, %d\n", layout_group, 
NVTE_QKV_Layout_Group::NVTE_SD_SD_SD, NVTE_QKV_Layout_Group::NVTE_HD_HD_HD); - printf(">>>>>> nvte_get_fused_attn_backend cudnn_runtime_version: %d\n", cudnn_runtime_version); - printf(">>>>>> nvte_get_fused_attn_backend is_training: %d\n", is_training); - printf(">>>>>> nvte_get_fused_attn_backend bias_type: %d\n", bias_type); - printf(">>>>>> nvte_get_fused_attn_backend attn_mask_type: %d, %d, %d\n", attn_mask_type, NVTE_Mask_Type::NVTE_NO_MASK, NVTE_Mask_Type::NVTE_CAUSAL_MASK); if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && @@ -532,7 +522,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } else { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - printf(">>>>>> nvte_get_fused_attn_backend fused_attention_backend: %d\n", backend); return backend; } diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index fdf78fcef3..5c809c6050 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1681,14 +1681,6 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); - printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); - printf(">>>>>> window_size_left: %d, window_size_right: %d\n", window_size_left, window_size_right); - printf(">>>>>> scaling_mode: %d\n", scaling_mode); - printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); - printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); - printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); - printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, 
cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf(">>>>>> o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); try { FADescriptor_v1 descriptor{b, @@ -1779,9 +1771,6 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); // need to double check - printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); - printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); - printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); int32_t block_size = 32; int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; @@ -1796,7 +1785,6 @@ void fused_attn_fp8_fwd_impl_v1( int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; int64_t d_v_padded = ((d_v + 3) / 4) * 4; - printf(">>>>>> d_qk_scale: %d, d_v_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_qk_scale_padded: %d, d_v_scale_padded: %d, s_kv_scale_padded: %d, d_qk_padded: %d, d_v_padded: %d\n", d_qk_scale, d_v_scale, s_kv_scale, s_q_padded, s_kv_padded, d_qk_scale_padded, d_v_scale_padded, s_kv_scale_padded, d_qk_padded, d_v_padded); std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; std::vector v_scale_dims = {b, hg, s_kv_scale_padded, d_v_padded}; @@ -1809,9 +1797,6 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); - printf(">>>>>> k_scale_strides: %d, %d, %d, 
%d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); - printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") @@ -2049,7 +2034,6 @@ void fused_attn_fp8_fwd_impl_v1( variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; } - printf(">>>>>> mha_graph->execute(handle, variant_pack, workspace)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -2092,16 +2076,6 @@ void fused_attn_fp8_bwd_impl_v1( dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); - printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); - printf(">>>>>> window_size_left: %d, window_size_right: %d\n", window_size_left, window_size_right); - printf(">>>>>> scaling_mode: %d\n", scaling_mode); - printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); - printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); - printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); - printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf(">>>>>> o_tensor_type: %d, %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); - printf(">>>>>> do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); - printf(">>>>>> dqkv_tensor_type: %d, %d, %d, %d\n", dqkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, 
cudnn_frontend::DataType_t::FP8_E5M2); bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); @@ -2313,7 +2287,6 @@ void fused_attn_fp8_bwd_impl_v1( int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - printf(">>>>>> d_qk_scale: %d, d_v_scale: %d, s_q_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_qk_scale_padded: %d, d_v_scale_padded: %d, s_q_scale_padded: %d, s_kv_scale_padded: %d, d_qk_padded: %d, d_v_padded: %d\n", d_qk_scale, d_v_scale, s_q_scale, s_kv_scale, s_q_padded, s_kv_padded, d_qk_scale_padded, d_v_scale_padded, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, d_v_padded); // std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; // std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; // std::vector v_scale_dims = {b, hg, s_kv_padded, d_v_scale_padded}; @@ -2338,11 +2311,6 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); - printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); - printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); - printf(">>>>>> q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], q_t_scale_strides[2], q_t_scale_strides[3]); - printf(">>>>>> k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], k_t_scale_strides[2], k_t_scale_strides[3]); q_t = mha_graph->tensor(fe::graph::Tensor_attributes() 
.set_name("Q_t") @@ -2733,7 +2701,6 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrScaleS = nullptr; void* devPtrDescaleS = nullptr; if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { - printf(">>>>>> input_Q is MXFP8_1D_SCALING\n"); devPtrQ = input_Q->data.dptr; devPtrDescaleQ = input_Q->scale_inv.dptr; devPtrK = input_K->data.dptr; @@ -2745,7 +2712,6 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrO = output_O->data.dptr; devPtrAmaxO = output_O->amax.dptr; } else { - printf(">>>>>> input_Q is not MXFP8_1D_SCALING\n"); devPtrQ = input_Q->data.dptr; devPtrDescaleQ = input_Q->scale_inv.dptr; devPtrK = input_K->data.dptr; diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 47f7e0f222..c0dca1b330 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1304,14 +1304,10 @@ def forward( # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ - print(f"out_: {type(out_)} {out_.shape}") - print(f"is_output_fp8: {is_output_fp8}, is_bwd_fp8: {is_bwd_fp8}, fp8_recipe.float8_current_scaling(): {fp8_recipe.float8_current_scaling()}, _dpa_fp8_cs_o_in_f16: {_dpa_fp8_cs_o_in_f16}") if isinstance(out_, Float8Tensor) or isinstance(out_, MXFP8Tensor): - print(f"dequantizing out_") if not is_output_fp8 or not is_bwd_fp8: out = out_.dequantize().view(out_.shape) else: - print(f"quantizing out_") if is_output_fp8 or ( is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 873c101521..12a75131aa 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2187,7 +2187,6 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype if isinstance(qkv_quantizer, MXFP8Quantizer): - print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") def permute_x(f, x): x = x.contiguous() if not x.is_contiguous() else x @@ -2204,7 +2203,6 @@ def permute_x(f, x): if kv_format not in ["bhsd", "htd"]: k = permute_x(kv_format, k) v = permute_x(kv_format, v) - print(f">>>>>>>>>>>> Permuted shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" original_shapes = [x.shape for x in [q, k, v]] @@ -2216,7 +2214,6 @@ def permute_x(f, x): assert d_v % 32 == 0 # need to check seqlens in THD % 128 == 0 q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] - print(f">>>>>>>>>>>> Flattened shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") # consider bhsd for now if d_qk == d_v: @@ -2227,10 +2224,6 @@ def permute_x(f, x): k_fp8 = qkv_quantizer(k) v_fp8 = qkv_quantizer(v) q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] - print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") return q_fp8, k_fp8, v_fp8, qkv_layout @@ -2292,7 +2285,6 @@ def combine_and_dequantize( des_nominal_dtype = 
src_nominal_dtype if all(isinstance(x, (MXFP8Tensor, MXFP8TensorStorage)) for x in [q_fp8, k_fp8, v_fp8]): - print(f"Combining and dequantizing q, k, v from MXFP8 to {des_nominal_dtype}") q, k, v = [x.dequantize(dtype=des_nominal_dtype) for x in [q_fp8, k_fp8, v_fp8]] return q, k, v diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 34565bcf44..9d3d6b901d 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -83,7 +83,6 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob py::object quantize_grouped(const py::handle &input, py::handle &output) { using namespace transformer_engine::pytorch::detail; init_extension(); - printf(">>>>>>>>>>>> quantize_grouped\n"); const auto &grouped_input_tensor = GroupedTensorFromPyTorchGroupedTensor(input); const auto &grouped_output_tensor = GroupedTensorFromPyTorchGroupedTensor(output); NVTE_SCOPED_GIL_RELEASE({ diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index c17be6c855..8c9d7d7c16 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -212,7 +212,6 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { scaling_mode = ScalingModeFromQuantizer(quantizer); quantizer_dtype = quantizer.attr("dtype").cast(); with_gemm_swizzled_scales = quantizer.attr("optimize_for_gemm").cast(); - printf(">>>>>>>>>>>> GroupedTensorFromPyTorchGroupedTensor with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); } auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 6c72d74531..41d6c87f2b 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ 
b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -75,7 +75,6 @@ def update_quantized( src = src.contiguous() # Launch cast kernel - print(f"MXFP8Quantizer.update_quantized: src: {src.shape}, dst: {dst.shape}") tex.quantize(src, self, dst, noop_flag) # Update FP8 dtype @@ -85,7 +84,6 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - print(f"MXFP8Quantizer.quantize_impl: tensor: {tensor.shape}") return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: From 239f58aec1b5c33e6b6e97ca4043c754066f241a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:14:42 -0800 Subject: [PATCH 015/172] remove leftover prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 3 --- 1 file changed, 3 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a8135391e3..70a68132ad 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -776,7 +776,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, switch (scaling_type) { case ScalingType::ROWWISE: { - printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::ROWWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -792,7 +791,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::COLWISE: { - printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::COLWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -808,7 +806,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::BIDIMENSIONAL: { - printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::BIDIMENSIONAL\n"); auto kernel = quantize_mxfp8_kernel; From 
f44a775706a249cef801b162d34f5ff0c9e8c5eb Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:23:56 -0800 Subject: [PATCH 016/172] Revert "update FE" This reverts commit d9ff5662aa4b4b6267c77baf614aada6602fa133. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 +-- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index 8c7646c00d..4b188d6bb1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,8 +3,7 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git - branch = develop + url = https://github.com/NVIDIA/cudnn-frontend.git [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 4b4df2edcf..209a25fe89 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 4b4df2edcf80b7cc7a659da5734454577154aa6d +Subproject commit 209a25fe898bb749c8605363a6431e26cb41b089 From 965572bc571fe27d932f66ad74c026ee28d40adf Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:29:36 -0800 Subject: [PATCH 017/172] update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 ++- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..8c7646c00d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,8 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://github.com/NVIDIA/cudnn-frontend.git + url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git + branch = develop [submodule 
"3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 209a25fe89..4b4df2edcf 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 209a25fe898bb749c8605363a6431e26cb41b089 +Subproject commit 4b4df2edcf80b7cc7a659da5734454577154aa6d From 91025c74f8e5121bb9f195e562e3a18c3a00ba12 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Feb 2026 14:32:53 -0800 Subject: [PATCH 018/172] fix MLA O strides; add bottom_right_diagonal Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 2 +- .../common/fused_attn/fused_attn.cpp | 12 +- .../common/fused_attn/fused_attn_fp8.cu | 488 +++++++++--------- .../common/fused_attn/fused_attn_fp8.h | 4 +- transformer_engine/common/fused_attn/utils.h | 10 +- .../dot_product_attention/backends.py | 8 + 6 files changed, 273 insertions(+), 251 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 05d76d96fe..ff3c7506e9 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1788,7 +1788,7 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 16, 128), + "fp8_9": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512)), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index d58fca70e9..ebda62568a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ 
b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -631,7 +631,7 @@ void nvte_fused_attn_fwd_qkvpacked( Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, &Q_view, &K_view, &V_view, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else @@ -783,7 +783,7 @@ void nvte_fused_attn_bwd_qkvpacked( Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, + bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); @@ -941,7 +941,7 @@ void nvte_fused_attn_fwd_kvpacked( Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, &K_view, &V_view, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1104,7 +1104,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor dV_view = make_tensor_view(output_dKV, 
unpacked_kv_shape, stride); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, &K_view, &V_view, input_O, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1228,7 +1228,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1344,7 +1344,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_O, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); diff --git 
a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 5c809c6050..b4aebea25d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1654,7 +1654,7 @@ void fused_attn_fp8_bwd_impl( void fused_attn_fp8_fwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void* devPtrQ, void* devPtrK, void* devPtrV, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, @@ -1662,6 +1662,7 @@ void fused_attn_fp8_fwd_impl_v1( cudnn_frontend::DataType_t o_tensor_type, NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -1681,6 +1682,13 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, 
s_kv, d_qk, d_v); + printf(">>>>>> scaling_mode: %d\n", scaling_mode); + printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); + printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); + printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); + printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf(">>>>>> o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, cudnn_frontend::DataType_t::BFLOAT16); try { FADescriptor_v1 descriptor{b, @@ -1707,7 +1715,7 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, window_size_left, window_size_right, - true, + bottom_right_diagonal, true, qkv_tensor_type, o_tensor_type, @@ -1762,6 +1770,7 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr bias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; + // Q, K, V, attn_scale std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -1770,102 +1779,112 @@ void fused_attn_fp8_fwd_impl_v1( generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); // need to double check - - int32_t block_size = 32; - int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; - int64_t d_v_scale = (d_v + block_size - 1) / block_size; - int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_scale = (s_q + block_size - 1) / block_size; - int64_t s_q_padded = ((s_q + 3) / 4) * 4; - int64_t s_kv_padded = ((s_kv + 3) / 4) * 4; - int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; - int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; - int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; - int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; - int64_t d_v_padded = ((d_v + 
3) / 4) * 4; - std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; - std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; - std::vector v_scale_dims = {b, hg, s_kv_scale_padded, d_v_padded}; - std::vector q_scale_strides(4); - std::vector k_scale_strides(4); - std::vector v_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); - + NVTE_QKV_Matrix::NVTE_V_Matrix); Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) - .set_data_type(qkv_tensor_type)); + .set_name("Q") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride) + .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) - .set_data_type(qkv_tensor_type)); + .set_name("K") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_stride) + .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) - .set_data_type(qkv_tensor_type)); - + .set_name("V") + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_stride) + .set_data_type(qkv_tensor_type)); attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - if (!is_mxfp8) { + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + // Descale_q, 
Descale_k, Descale_v, Descale_s, Scale_s, Scale_o + if (is_delayed_scaling || is_current_scaling) { descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - } else { + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); + if (is_delayed_scaling) { + scale_o = mha_graph->tensor_like(descale_q, "Scale_o"); + } + if (is_current_scaling) { + scale_o = mha_graph->tensor(1.0f); + } + } + if (is_mxfp8) { + int32_t block_size = 32; + int64_t s_q_padded = ((s_q + 127) / 128) * 128; + int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; + int64_t s_q_scale = (s_q + block_size - 1) / block_size; + int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; + int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; + int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; + int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; + int64_t d_v_padded = ((d_v + 127) / 128) * 128; + int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; + int64_t d_v_scale = (d_v + block_size - 1) / block_size; + int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; + int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + std::vector q_scale_strides(4); + std::vector k_scale_strides(4); + std::vector v_scale_strides(4); + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, 
s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); + printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); + printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") - .set_dim(q_scale_dims) + .set_dim({b, h, s_q_padded, d_qk_scale_padded}) .set_stride(q_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k") - .set_dim(k_scale_dims) + .set_dim({b, hg, s_kv_padded, d_qk_scale_padded}) .set_stride(k_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_v") - .set_dim(v_scale_dims) + .set_dim({b, hg, s_kv_scale_padded, d_v_padded}) .set_stride(v_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } - if (is_delayed_scaling) { - scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); - } - if (is_current_scaling) { - scale_o = mha_graph->tensor(1.0f); - } - fe::graph::SDPA_fp8_attributes sdpa_options; sdpa_options = fe::graph::SDPA_fp8_attributes() .set_name("sdpa_fp8") .set_generate_stats(true) .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); - 
sdpa_options.set_diagonal_band_right_bound(window_size_right); + + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); + + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_options.set_diagonal_band_right_bound(window_size_right); + } // sdpa_options.set_alibi_mask(is_alibi); // if (is_bias) { @@ -1924,9 +1943,13 @@ void fused_attn_fp8_fwd_impl_v1( } std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(o_stride).set_data_type(o_tensor_type); + printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); + printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); + printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); + printf(">>>>>> o_stride: %d, %d, %d, %d\n", o_stride[0], o_stride[1], o_stride[2], o_stride[3]); + O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2044,7 +2067,7 @@ void fused_attn_fp8_fwd_impl_v1( void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + NVTE_Mask_Type mask_type, int64_t 
window_size_left, int64_t window_size_right, bool bottom_right_diagonal,void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, @@ -2057,6 +2080,7 @@ void fused_attn_fp8_bwd_impl_v1( cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -2105,7 +2129,7 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, window_size_left, window_size_right, - true, + bottom_right_diagonal, false, qkv_tensor_type, o_tensor_type, @@ -2116,25 +2140,25 @@ void fused_attn_fp8_bwd_impl_v1( namespace fe = cudnn_frontend; using graph_and_tensors = std::tuple, - std::shared_ptr, // q - std::shared_ptr, // q_t - std::shared_ptr, // k - std::shared_ptr, // k_t - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + std::shared_ptr, // Q + // std::shared_ptr, // Q_t + std::shared_ptr, // K + // std::shared_ptr, // K_t + std::shared_ptr, // V + std::shared_ptr, // O + std::shared_ptr, // Stats std::shared_ptr, // dO - std::shared_ptr, // dO_t - std::shared_ptr, // dO_f16 + // std::shared_ptr, // dO_t + // std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q - std::shared_ptr, // descale_q_t + // std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k - std::shared_ptr, // descale_k_t + // std::shared_ptr, // descale_k_t std::shared_ptr, // 
descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO - std::shared_ptr, // descale_dO_t + // std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2175,16 +2199,16 @@ void fused_attn_fp8_bwd_impl_v1( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr descale_q, descale_k, descale_v; + std::shared_ptr Q, Q_t, K, K_t, V, O, dO, dO_t, dO_f16, Stats, attn_scale; + std::shared_ptr descale_q, descale_q_t, descale_k, descale_k_t, descale_v; std::shared_ptr descale_s, descale_o; - std::shared_ptr descale_dP, descale_dO; + std::shared_ptr descale_dP, descale_dO, descale_dO_t; std::shared_ptr scale_s, scale_dP; std::shared_ptr scale_dQ, scale_dK, scale_dV; std::shared_ptr bias, dBias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; - std::shared_ptr q_t, k_t, dO_t, dO_f16, descale_q_t, descale_k_t, descale_dO_t; + // Q, K, V, O, dO, stats, attn_scale std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -2195,39 +2219,38 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - q = mha_graph->tensor(fe::graph::Tensor_attributes() + Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_data_type(qkv_tensor_type)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() + K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_data_type(qkv_tensor_type)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() + V = 
mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_data_type(qkv_tensor_type)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() + O = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d_qk}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d_qk}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_data_type(do_tensor_type)); - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") + Stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") .set_dim({b, h, s_q, 1}) .set_stride({h * s_q, s_q, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); - attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -2235,25 +2258,25 @@ void fused_attn_fp8_bwd_impl_v1( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - if (!is_mxfp8) { + // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Descale_dP, Scale_dP, Descale_o, Descale_dO, Scale_dQ, Scale_dK, Scale_dV + if (is_delayed_scaling || is_current_scaling) { descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); - if (is_O_in_F16) { + scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); + if (is_current_scaling && is_O_in_F16) { 
descale_o = mha_graph->tensor(1.0f); } else { descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); } descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - if (is_delayed_scaling) { scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); @@ -2264,74 +2287,73 @@ void fused_attn_fp8_bwd_impl_v1( scale_dK = mha_graph->tensor(1.0f); scale_dV = mha_graph->tensor(1.0f); } - } else { + } + if (is_mxfp8) { + // Q_t, K_t, dO_t, dO_f16 std::vector q_t_stride(4); std::vector k_t_stride(4); std::vector dO_t_stride(4); - generateMatrixStrides(b, h, d_qk, s_kv, s_q, q_t_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_t_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_t_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - generateMatrixStrides(b, h, d_qk, s_kv, s_q, dO_t_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); + generateMatrixStrides(b, h, s_q, s_kv, d_v, dO_t_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix_Transpose); + Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_t_stride) + .set_data_type(qkv_tensor_type)); + K_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K_t") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_t_stride) + .set_data_type(qkv_tensor_type)); + dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_t") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_t_stride) + .set_data_type(do_tensor_type)); + dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_f16") + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride) + .set_data_type(o_tensor_type)); + // Descale_q, Descale_q_t, Descale_k, 
Descale_k_t, Descale_v, Descale_dO, Descale_dO_t int32_t block_size = 32; - int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; - int64_t d_v_scale = (d_v + block_size - 1) / block_size; + int64_t s_q_padded = ((s_q + 127) / 128) * 128; + int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; int64_t s_q_scale = (s_q + block_size - 1) / block_size; int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_padded = ((s_q + 3) / 4) * 4; - int64_t s_kv_padded = ((s_kv + 3) / 4) * 4; - int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; - int64_t d_v_padded = ((d_v + 3) / 4) * 4; - int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; - int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - // std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; - // std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; - // std::vector v_scale_dims = {b, hg, s_kv_padded, d_v_scale_padded}; - // std::vector q_t_scale_dims = {b, h, s_q_scale_padded, d_qk_padded}; - // std::vector k_t_scale_dims = {b, hg, s_kv_scale_padded, d_qk_padded}; - // // std::vector dO_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; - // // std::vector dO_t_scale_dims = {b, h, s_q_scale_padded, d_qk_padded}; + int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; + int64_t d_v_padded = ((d_v + 127) / 128) * 128; + int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; + int64_t d_v_scale = (d_v + block_size - 1) / block_size; + int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; + int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; std::vector q_scale_strides(4); - std::vector k_scale_strides(4); - std::vector v_scale_strides(4); std::vector q_t_scale_strides(4); + std::vector k_scale_strides(4); std::vector k_t_scale_strides(4); - // std::vector dO_scale_strides(4); - // std::vector dO_t_scale_strides(4); + std::vector v_scale_strides(4); + std::vector 
dO_scale_strides(4); + std::vector dO_t_scale_strides(4); generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, h, d_qk_padded, s_kv_scale_padded, s_q_scale_padded, q_t_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, q_t_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, k_t_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - - q_t = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q_t") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_t_stride) - .set_data_type(qkv_tensor_type)); - k_t = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K_t") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_t_stride) - .set_data_type(qkv_tensor_type)); - dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO_t") - .set_dim({b, h, s_q, d_qk}) - .set_stride(dO_t_stride) - .set_data_type(do_tensor_type)); - dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO_f16") - .set_dim({b, h, s_q, d_qk}) - .set_stride(dO_t_stride) - .set_data_type(o_tensor_type)); + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_v_scale_padded, dO_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix); + generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_v_padded, dO_t_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix_Transpose); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() 
.set_name("Descale_q") .set_dim({b, h, s_q_padded, d_qk_scale_padded}) @@ -2364,14 +2386,14 @@ void fused_attn_fp8_bwd_impl_v1( .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO") - .set_dim({b, h, s_q_padded, d_qk_scale_padded}) - .set_stride(q_scale_strides) + .set_dim({b, h, s_q_padded, d_v_scale_padded}) + .set_stride(dO_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO_t") - .set_dim({b, h, s_q_scale_padded, d_qk_padded}) - .set_stride(q_t_scale_strides) + .set_dim({b, h, s_q_scale_padded, d_v_padded}) + .set_stride(dO_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } @@ -2382,8 +2404,17 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - // sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); - // sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + // fe::DiagonalAlignment_t const &diagonal_alignment = + // bottom_right_diagonal ? 
fe::DiagonalAlignment_t::BOTTOM_RIGHT + // : fe::DiagonalAlignment_t::TOP_LEFT; + // sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + + // if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + // sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); + // } + // if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + // sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + // } // sdpa_backward_options.set_alibi_mask(is_alibi); @@ -2434,14 +2465,15 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(fe::DataType_t::INT64)); sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - // if (!is_mxfp8) { - auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( - q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, - descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); - // } else { - // auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV] = mha_graph->sdpa_fp8_backward( - // q, q_t, k, k_t, v, o, dO_f16, dO, dO_t, stats, descale_q, descale_q_t, descale_k, descale_k_t, descale_v, descale_dO, descale_dO_t, - // sdpa_backward_options); + // if (is_delayed_scaling || is_current_scaling) { + auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( + Q, K, V, O, dO, Stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); + // } + // if (is_mxfp8) { + // auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV] = mha_graph->sdpa_fp8_backward( + // Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, descale_k_t, + // descale_v, descale_dO, descale_dO_t, sdpa_backward_options); // } dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride).set_data_type(dqkv_tensor_type); dK->set_output(true).set_dim({b, hg, s_kv, 
d_qk}).set_stride(k_stride).set_data_type(dqkv_tensor_type); @@ -2464,30 +2496,19 @@ void fused_attn_fp8_bwd_impl_v1( .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); } - // dO->set_data_type(do_tensor_type); - // dQ->set_data_type(dqkv_tensor_type); - // dK->set_data_type(dqkv_tensor_type); - // dV->set_data_type(dqkv_tensor_type); - - std::tuple, // q - // std::shared_ptr, // q_t - std::shared_ptr, // k - // std::shared_ptr, // k_t - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + + std::tuple, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // O + std::shared_ptr, // Stats std::shared_ptr, // dO - // std::shared_ptr, // dO_t - // std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q - // std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k - // std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO - // std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2503,11 +2524,11 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr, // amax_dV std::shared_ptr> // amax_dP key_tensors_tuple = std::make_tuple( - q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); - auto mxfp8_tensors_tuple = is_mxfp8 ? std::make_tuple(q_t, k_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); - // key_tensors_tuple = std::tuple_cat(key_tensors_tuple, mxfp8_tensors_tuple); + // auto mxfp8_tensors_tuple = is_mxfp8 ? 
std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) + // : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); @@ -2521,23 +2542,19 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - padding_tuple, dropout_tuple, mxfp8_tensors_tuple); + padding_tuple, dropout_tuple); + // padding_tuple, dropout_tuple, mxfp8_tensors_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto bprop_tuple = get_graph(sdpa_fp8_bprop_cache, descriptor); - // if (!is_mxfp8) { - // auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, - // descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - // dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, bias, dBias, seq_q, seq_kv, dropout_seed, - // dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); auto mha_graph = std::get<0>(bprop_tuple); - auto q = std::get<1>(bprop_tuple); - auto k = std::get<2>(bprop_tuple); - auto v = std::get<3>(bprop_tuple); - auto o = std::get<4>(bprop_tuple); - auto stats = std::get<5>(bprop_tuple); + auto Q = std::get<1>(bprop_tuple); + auto K = std::get<2>(bprop_tuple); + auto V = std::get<3>(bprop_tuple); + auto O = std::get<4>(bprop_tuple); + auto Stats = std::get<5>(bprop_tuple); auto dO = std::get<6>(bprop_tuple); auto attn_scale = std::get<7>(bprop_tuple); auto descale_q = std::get<8>(bprop_tuple); @@ -2565,19 +2582,14 @@ void fused_attn_fp8_bwd_impl_v1( auto seq_kv = std::get<30>(bprop_tuple); auto dropout_seed = std::get<31>(bprop_tuple); auto dropout_offset = std::get<32>(bprop_tuple); - // } else { // if (is_mxfp8) { - // auto 
[mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, - // descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - // dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, q_t, k_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t, bias, dBias, seq_q, seq_kv, dropout_seed, - // dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); - auto q_t = std::get<33>(bprop_tuple); - auto k_t = std::get<34>(bprop_tuple); - auto dO_f16 = std::get<35>(bprop_tuple); - auto dO_t = std::get<36>(bprop_tuple); - auto descale_q_t = std::get<37>(bprop_tuple); - auto descale_k_t = std::get<38>(bprop_tuple); - auto descale_dO_t = std::get<39>(bprop_tuple); + // auto Q_t = std::get<33>(bprop_tuple); + // auto K_t = std::get<34>(bprop_tuple); + // auto dO_f16 = std::get<35>(bprop_tuple); + // auto dO_t = std::get<36>(bprop_tuple); + // auto descale_q_t = std::get<37>(bprop_tuple); + // auto descale_k_t = std::get<38>(bprop_tuple); + // auto descale_dO_t = std::get<39>(bprop_tuple); // } auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -2594,11 +2606,11 @@ void fused_attn_fp8_bwd_impl_v1( // build variant pack std::unordered_map, void*> variant_pack = { - {q, devPtrQ}, - {k, devPtrK}, - {v, devPtrV}, - {o, devPtrO}, - {stats, devPtrM}, + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {O, devPtrO}, + {Stats, devPtrM}, {dO, devPtrdO}, {attn_scale, &scaling_factor}, {descale_q, devPtrDescaleQ}, @@ -2612,31 +2624,31 @@ void fused_attn_fp8_bwd_impl_v1( {amax_dK, devPtrAmaxdK}, {amax_dV, devPtrAmaxdV}, }; - if (!is_mxfp8) { - variant_pack[descale_s] = devPtrDescaleS; - variant_pack[descale_dP] = devPtrDescaledP; - variant_pack[scale_s] = devPtrScaleS; - variant_pack[scale_dP] = devPtrScaledP; - variant_pack[amax_dP] = devPtrAmaxdP; - } else { - variant_pack[q_t] = devPtrQ_t; - variant_pack[k_t] = devPtrK_t; - variant_pack[dO_f16] = devPtrdO_f16; - variant_pack[dO_t] = devPtrdO_t; - variant_pack[descale_q_t] = 
devPtrDescaleQ_t; - variant_pack[descale_k_t] = devPtrDescaleK_t; - variant_pack[descale_dO] = devPtrDescaledO; - variant_pack[descale_dO_t] = devPtrDescaledO_t; + if (is_delayed_scaling || is_current_scaling) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[descale_dP] = devPtrDescaledP; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[scale_dP] = devPtrScaledP; + variant_pack[amax_dP] = devPtrAmaxdP; + } + if (is_current_scaling && !is_O_in_F16) { + variant_pack[descale_o] = devPtrDescaleO; } - if (is_delayed_scaling) { variant_pack[scale_dQ] = devPtrScaledQ; variant_pack[scale_dK] = devPtrScaledK; variant_pack[scale_dV] = devPtrScaledV; } - if (is_current_scaling && !is_O_in_F16) { - variant_pack[descale_o] = devPtrDescaleO; - } + // if (is_mxfp8) { + // variant_pack[Q_t] = devPtrQ_t; + // variant_pack[K_t] = devPtrK_t; + // variant_pack[dO_f16] = devPtrdO_f16; + // variant_pack[dO_t] = devPtrdO_t; + // variant_pack[descale_q_t] = devPtrDescaleQ_t; + // variant_pack[descale_k_t] = devPtrDescaleK_t; + // variant_pack[descale_dO] = devPtrDescaledO; + // variant_pack[descale_dO_t] = devPtrDescaledO_t; + // } /* if (is_bias) { variant_pack[bias] = devPtrBias; @@ -2682,7 +2694,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor* input_Q, const Tensor* input_K, + NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, @@ -2769,7 +2781,7 @@ void 
fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrM, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), @@ -2803,7 +2815,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor* input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, @@ -2872,7 +2884,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( 
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 22800b2aa2..bfadc0e870 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -18,7 +18,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor *input_Q, const Tensor *input_K, + NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, @@ -28,7 +28,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, 
NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor *input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index fdfc4abe82..ea3428855c 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -24,12 +24,14 @@ using namespace transformer_engine; enum NVTE_QKV_Matrix { NVTE_Q_Matrix = 0, // queries - NVTE_K_Matrix = 1, // keys - NVTE_K_Matrix_Transpose = 2, // keys transposed - NVTE_V_Matrix = 3, // values - NVTE_V_Matrix_Transpose = 4, // value matrix transposed + NVTE_Q_Matrix_Transpose = 1, // queries transposed + NVTE_K_Matrix = 2, // keys + NVTE_K_Matrix_Transpose = 3, // keys transposed + NVTE_V_Matrix = 4, // values + NVTE_V_Matrix_Transpose = 5, // values transposed NVTE_S_Matrix = 5, // output of GEMM1 NVTE_O_Matrix = 6, // final output + NVTE_O_Matrix_Transpose = 7, // final output transposed }; void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index c0dca1b330..ac2c067ca8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1255,6 +1255,14 @@ def forward( dP_quantizer, ) + print(f"q_fp8._rowwise_data.shape: {q_fp8._rowwise_data.shape}, 
k_fp8._rowwise_data.shape: {k_fp8._rowwise_data.shape}, v_fp8._rowwise_data.shape: {v_fp8._rowwise_data.shape}") + print(f"q_fp8._rowwise_data.stride: {q_fp8._rowwise_data.stride()}, k_fp8._rowwise_data.stride: {k_fp8._rowwise_data.stride()}, v_fp8._rowwise_data.stride: {v_fp8._rowwise_data.stride()}") + print(f"q_fp8._rowwise_scale_inv.shape: {q_fp8._rowwise_scale_inv.shape}, k_fp8._rowwise_scale_inv.shape: {k_fp8._rowwise_scale_inv.shape}, v_fp8._rowwise_scale_inv.shape: {v_fp8._rowwise_scale_inv.shape}") + print(f"q_fp8._rowwise_scale_inv.stride: {q_fp8._rowwise_scale_inv.stride()}, k_fp8._rowwise_scale_inv.stride: {k_fp8._rowwise_scale_inv.stride()}, v_fp8._rowwise_scale_inv.stride: {v_fp8._rowwise_scale_inv.stride()}") + print(f"q_fp8._columnwise_data.shape: {q_fp8._columnwise_data.shape}, k_fp8._columnwise_data.shape: {k_fp8._columnwise_data.shape}, v_fp8._columnwise_data.shape: {v_fp8._columnwise_data.shape}") + print(f"q_fp8._columnwise_data.stride: {q_fp8._columnwise_data.stride()}, k_fp8._columnwise_data.stride: {k_fp8._columnwise_data.stride()}, v_fp8._columnwise_data.stride: {v_fp8._columnwise_data.stride()}") + print(f"q_fp8._columnwise_scale_inv.shape: {q_fp8._columnwise_scale_inv.shape}, k_fp8._columnwise_scale_inv.shape: {k_fp8._columnwise_scale_inv.shape}, v_fp8._columnwise_scale_inv.shape: {v_fp8._columnwise_scale_inv.shape}") + print(f"q_fp8._columnwise_scale_inv.stride: {q_fp8._columnwise_scale_inv.stride()}, k_fp8._columnwise_scale_inv.stride: {k_fp8._columnwise_scale_inv.stride()}, v_fp8._columnwise_scale_inv.stride: {v_fp8._columnwise_scale_inv.stride()}") # out_: # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 From d655e7e4e464585f11c3f68341ec71148f497537 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Feb 2026 17:11:05 -0800 Subject: [PATCH 019/172] attempt at bwd Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 140 +++++++----------- .../dot_product_attention/backends.py | 56 ++++--- .../attention/dot_product_attention/utils.py | 21 ++- .../pytorch/cpp_extensions/fused_attn.py | 2 +- 4 files changed, 103 insertions(+), 116 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index b4aebea25d..504c387d57 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1682,13 +1682,6 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); - printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); - printf(">>>>>> scaling_mode: %d\n", scaling_mode); - printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); - printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); - printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); - printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf(">>>>>> o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, cudnn_frontend::DataType_t::BFLOAT16); try { FADescriptor_v1 descriptor{b, @@ -1828,12 +1821,12 @@ void fused_attn_fp8_fwd_impl_v1( int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; + // int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; int64_t d_v_padded = ((d_v + 127) / 128) * 128; int64_t d_qk_scale = (d_qk + block_size - 1) / 
block_size; - int64_t d_v_scale = (d_v + block_size - 1) / block_size; + // int64_t d_v_scale = (d_v + block_size - 1) / block_size; int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; - int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + // int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; std::vector q_scale_strides(4); std::vector k_scale_strides(4); std::vector v_scale_strides(4); @@ -1843,10 +1836,6 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); - printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); - printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, s_q_padded, d_qk_scale_padded}) @@ -1945,10 +1934,6 @@ void fused_attn_fp8_fwd_impl_v1( std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); - printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); - printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); - printf(">>>>>> o_stride: %d, %d, %d, %d\n", o_stride[0], o_stride[1], o_stride[2], o_stride[3]); O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) @@ -2141,24 +2126,24 @@ void fused_attn_fp8_bwd_impl_v1( using graph_and_tensors = std::tuple, 
std::shared_ptr, // Q - // std::shared_ptr, // Q_t + std::shared_ptr, // Q_t std::shared_ptr, // K - // std::shared_ptr, // K_t + std::shared_ptr, // K_t std::shared_ptr, // V std::shared_ptr, // O std::shared_ptr, // Stats std::shared_ptr, // dO - // std::shared_ptr, // dO_t - // std::shared_ptr, // dO_f16 + std::shared_ptr, // dO_t + std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q - // std::shared_ptr, // descale_q_t + std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k - // std::shared_ptr, // descale_k_t + std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO - // std::shared_ptr, // descale_dO_t + std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2465,16 +2450,30 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(fe::DataType_t::INT64)); sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - // if (is_delayed_scaling || is_current_scaling) { - auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( + std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; + if (is_delayed_scaling || is_current_scaling) { + auto outputs = mha_graph->sdpa_fp8_backward( Q, K, V, O, dO, Stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); - // } - // if (is_mxfp8) { - // auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV] = mha_graph->sdpa_fp8_backward( - // Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, descale_k_t, - // descale_v, descale_dO, descale_dO_t, sdpa_backward_options); - // } + dQ = outputs[0]; + dK = outputs[1]; + dV = outputs[2]; + amax_dQ = outputs[3]; + amax_dK = outputs[4]; + amax_dV = outputs[5]; + amax_dP = outputs[6]; + } + if (is_mxfp8) { + auto outputs = 
mha_graph->sdpa_fp8_backward( + Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, descale_k_t, + descale_v, descale_dO, descale_dO_t, sdpa_backward_options); + dQ = outputs[0]; + dK = outputs[1]; + dV = outputs[2]; + amax_dQ = outputs[3]; + amax_dK = outputs[4]; + amax_dV = outputs[5]; + } dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride).set_data_type(dqkv_tensor_type); dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride).set_data_type(dqkv_tensor_type); dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride).set_data_type(dqkv_tensor_type); @@ -2527,8 +2526,8 @@ void fused_attn_fp8_bwd_impl_v1( Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); - // auto mxfp8_tensors_tuple = is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) - // : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); + auto mxfp8_tensors_tuple = is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? 
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); @@ -2541,56 +2540,17 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, + auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, bias_tuple, padding_tuple, dropout_tuple); - // padding_tuple, dropout_tuple, mxfp8_tensors_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto bprop_tuple = get_graph(sdpa_fp8_bprop_cache, descriptor); - auto mha_graph = std::get<0>(bprop_tuple); - auto Q = std::get<1>(bprop_tuple); - auto K = std::get<2>(bprop_tuple); - auto V = std::get<3>(bprop_tuple); - auto O = std::get<4>(bprop_tuple); - auto Stats = std::get<5>(bprop_tuple); - auto dO = std::get<6>(bprop_tuple); - auto attn_scale = std::get<7>(bprop_tuple); - auto descale_q = std::get<8>(bprop_tuple); - auto descale_k = std::get<9>(bprop_tuple); - auto descale_v = std::get<10>(bprop_tuple); - auto descale_o = std::get<11>(bprop_tuple); - auto descale_dO = std::get<12>(bprop_tuple); - auto descale_s = std::get<13>(bprop_tuple); - auto descale_dP = std::get<14>(bprop_tuple); - auto scale_s = std::get<15>(bprop_tuple); - auto scale_dQ = std::get<16>(bprop_tuple); - auto scale_dK = std::get<17>(bprop_tuple); - auto scale_dV = std::get<18>(bprop_tuple); - auto scale_dP = std::get<19>(bprop_tuple); - auto dQ = std::get<20>(bprop_tuple); - auto dK = std::get<21>(bprop_tuple); - auto dV = std::get<22>(bprop_tuple); - auto amax_dQ = std::get<23>(bprop_tuple); - auto amax_dK = std::get<24>(bprop_tuple); - auto amax_dV = std::get<25>(bprop_tuple); - auto amax_dP = std::get<26>(bprop_tuple); - auto bias = std::get<27>(bprop_tuple); - auto dBias = std::get<28>(bprop_tuple); - auto seq_q = std::get<29>(bprop_tuple); - auto seq_kv = std::get<30>(bprop_tuple); - auto 
dropout_seed = std::get<31>(bprop_tuple); - auto dropout_offset = std::get<32>(bprop_tuple); - // if (is_mxfp8) { - // auto Q_t = std::get<33>(bprop_tuple); - // auto K_t = std::get<34>(bprop_tuple); - // auto dO_f16 = std::get<35>(bprop_tuple); - // auto dO_t = std::get<36>(bprop_tuple); - // auto descale_q_t = std::get<37>(bprop_tuple); - // auto descale_k_t = std::get<38>(bprop_tuple); - // auto descale_dO_t = std::get<39>(bprop_tuple); - // } + auto [mha_graph, Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, + dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t, + bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + auto plan_workspace_size = mha_graph->get_workspace_size(); // Exit to request upper level API to allocate memory if needed @@ -2639,16 +2599,16 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[scale_dK] = devPtrScaledK; variant_pack[scale_dV] = devPtrScaledV; } - // if (is_mxfp8) { - // variant_pack[Q_t] = devPtrQ_t; - // variant_pack[K_t] = devPtrK_t; - // variant_pack[dO_f16] = devPtrdO_f16; - // variant_pack[dO_t] = devPtrdO_t; - // variant_pack[descale_q_t] = devPtrDescaleQ_t; - // variant_pack[descale_k_t] = devPtrDescaleK_t; - // variant_pack[descale_dO] = devPtrDescaledO; - // variant_pack[descale_dO_t] = devPtrDescaledO_t; - // } + if (is_mxfp8) { + variant_pack[Q_t] = devPtrQ_t; + variant_pack[K_t] = devPtrK_t; + variant_pack[dO_f16] = devPtrdO_f16; + variant_pack[dO_t] = devPtrdO_t; + variant_pack[descale_q_t] = devPtrDescaleQ_t; + variant_pack[descale_k_t] = devPtrDescaleK_t; + variant_pack[descale_dO] = devPtrDescaledO; + variant_pack[descale_dO_t] = devPtrDescaledO_t; + } /* if (is_bias) { variant_pack[bias] = devPtrBias; diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py 
b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index ac2c067ca8..98c48d0e3a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -203,8 +203,11 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou def backward(ctx, grad1, grad2, grad3): # pylint: disable=missing-function-docstring if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]: - dt_fp8 = ctx.quantizer(grad1) - tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + if ctx.quantizer is not None: + dt_fp8 = ctx.quantizer(grad1) + tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + else: + tensors = grad1, grad2, grad3 elif ctx.quantizer_name == "dQKV_quantizer": query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] dq_fp8, dk_fp8, dv_fp8, ctx.qkv_layout = combine_and_quantize( @@ -213,6 +216,10 @@ def backward(ctx, grad1, grad2, grad3): tensors = combine_and_dequantize( ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) + if isinstance(ctx.quantizer, MXFP8Quantizer): + assert ctx.qkv_layout == "bhsd_bhsd_bhsd", "bhsd_bhsd_bhsd is assumed to be the shape always at this point in UnfusedDotProductAttention." 
+ # permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None @@ -1341,7 +1348,7 @@ def forward( fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) if is_bwd_fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or isinstance(QKV_quantizer, MXFP8Quantizer): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) qkvo_tensors = (None, None, None, out) else: @@ -1481,13 +1488,30 @@ def forward( @staticmethod def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring - + print(f"ctx.fp8: {ctx.fp8}, ctx.is_output_fp8: {ctx.is_output_fp8}, isinstance(d_out, QuantizedTensorStorage): {isinstance(d_out, QuantizedTensorStorage)}") + print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") + if ctx.original_qkv_layout != ctx.qkv_layout: + original_qkv_format = ctx.original_qkv_layout.split("_")[0] + new_qkv_format = ctx.qkv_layout.split("_")[0] + perm = [] + for i in original_qkv_format: + perm.append(new_qkv_format.find(i)) + d_out = d_out.permute(*perm).contiguous() + print(f"d_out: {d_out.shape}, {type(d_out)}") # d_out is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage): + print(f"before dO_quantizer: {type(d_out)}, {d_out.shape}") + d_out_f16 = d_out + ctx.dO_quantizer.optimize_for_gemm = True d_out = ctx.dO_quantizer(d_out) + print(f"after dO_quantizer: {type(d_out)}, {d_out.shape}") if not ctx.use_FAv2_bwd: - d_out._data = d_out._data.contiguous() + if isinstance(ctx.dO_quantizer, MXFP8Quantizer): + d_out._rowwise_data = d_out._rowwise_data.contiguous() + d_out._columnwise_data = d_out._columnwise_data.contiguous() + else: + d_out._data 
= d_out._data.contiguous() elif not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( @@ -1549,14 +1573,6 @@ def backward(ctx, d_out, *_args): # FP8 attention: torch.float16 or torch.bfloat16 dqkv_nominal_dtype = ctx.nominal_dtype - if ctx.original_qkv_layout != ctx.qkv_layout: - original_qkv_format = ctx.original_qkv_layout.split("_")[0] - new_qkv_format = ctx.qkv_layout.split("_")[0] - perm = [] - for i in original_qkv_format: - perm.append(new_qkv_format.find(i)) - d_out = d_out.permute(*perm).contiguous() - if ctx.fp8: # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1595,10 +1611,16 @@ def backward(ctx, d_out, *_args): out_ = out_fp8 if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: out_ = out - if ctx.fp8_recipe.mxfp8_block_scaling(): + if ctx.fp8_recipe.mxfp8(): out_ = out - aux_ctx_tensors.append(d_out) - + aux_ctx_tensors.append(d_out_f16) + print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") + print(f"shapes: {q_fp8.shape}, {k_fp8.shape}, {v_fp8.shape}, {out_.shape}, {d_out_fp8.shape}, {[x.shape for x in aux_ctx_tensors]}") + for i in [q_fp8, k_fp8, v_fp8, out_, d_out_fp8, *aux_ctx_tensors]: + if isinstance(i, MXFP8Tensor): + print(f"xxxx: {i._with_gemm_swizzled_scales}") + else: + print(f"xxxx: {i.shape}") dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1621,7 +1643,7 @@ def backward(ctx, d_out, *_args): ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, + ctx.original_qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 12a75131aa..a9ce96c8c9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ 
b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2091,10 +2091,11 @@ def get_attention_quantizers(fp8, quantizers): if not fp8: return [None] * 6 QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - QKV_quantizer.internal = True + # QKV_quantizer.internal = True QKV_quantizer.set_usage(rowwise=True, columnwise=True) O_quantizer = quantizers["scaling_fwd"][META_O] O_quantizer.set_usage(rowwise=True, columnwise=False) + # O_quantizer.optimize_for_gemm = True if isinstance(QKV_quantizer, MXFP8Quantizer): QKV_quantizer.optimize_for_gemm = True # QKV_quantizer.internal = False @@ -2105,14 +2106,18 @@ def get_attention_quantizers(fp8, quantizers): S_quantizer.set_usage(rowwise=True, columnwise=False) dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer.interal = False + dQKV_quantizer.set_usage(rowwise=True, columnwise=True) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.set_usage(rowwise=True, columnwise=False) - dO_quantizer.internal = True - dP_quantizer = quantizers["scaling_bwd"][META_DP] - dP_quantizer.set_usage(rowwise=True, columnwise=False) - dP_quantizer.interal = True + dO_quantizer.set_usage(rowwise=True, columnwise=True) + dO_quantizer.internal = False + # dO_quantizer.optimize_for_gemm = True + if isinstance(dO_quantizer, MXFP8Quantizer): + dP_quantizer = None + else: + dP_quantizer = quantizers["scaling_bwd"][META_DP] + dP_quantizer.set_usage(rowwise=True, columnwise=False) + dP_quantizer.interal = True return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 09953440e9..8f77a8a7fd 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -496,7 +496,7 @@ def 
fused_attn_bwd( dqkv_dtype is not None ), "dqkv_dtype is required as an input for FP8 fused attention backward." assert ( - len(aux_ctx_tensors) == 3 + len(aux_ctx_tensors) >= 3 ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." output_tensors = tex.fused_attn_bwd( From a4ab691cc4eda08582d083ad0169ef2682adde44 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Feb 2026 18:04:31 -0800 Subject: [PATCH 020/172] fix get_quantizers; attempt at bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 88 +++++++++---------- .../dot_product_attention/context_parallel.py | 4 +- .../attention/dot_product_attention/utils.py | 47 +++++----- 3 files changed, 69 insertions(+), 70 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 98c48d0e3a..4d1481de79 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -452,19 +452,15 @@ def forward( scale /= self.layer_number if fp8: + # get fp8 recipe for DPA + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) - # disable swizzle for MXFP8Quantizer - for q in [QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer]: - if isinstance(q, MXFP8Quantizer): - q.optimize_for_gemm = False - q.internal = False # S/dP are forced to use DS quantizers in 
DPA.init_fp8_metadata; revert them here for true CS emulation - fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] if fp8_recipe.float8_current_scaling(): S_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=S_quantizer.dtype, device="cuda" @@ -472,6 +468,11 @@ def forward( dP_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=dP_quantizer.dtype, device="cuda" ) + # disable swizzle for MXFP8Quantizer + for q in [QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer]: + if isinstance(q, MXFP8Quantizer): + q.optimize_for_gemm = False + q.internal = False # q, k, v are in sbhd after previous reshaping # quantize and dequantize QKV to emulate FP8 @@ -1229,7 +1230,7 @@ def forward( # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) # get nominal data type for out @@ -1488,31 +1489,32 @@ def forward( @staticmethod def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring + print(f"ctx.fp8: {ctx.fp8}, ctx.is_output_fp8: {ctx.is_output_fp8}, isinstance(d_out, QuantizedTensorStorage): {isinstance(d_out, QuantizedTensorStorage)}") - print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") + # reshape d_out to ctx.qkv_layout; only happens with MXFP8BlockScaling if ctx.original_qkv_layout != ctx.qkv_layout: + print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") + print(f"d_out before reshape: {d_out.shape}, {type(d_out)}") original_qkv_format = ctx.original_qkv_layout.split("_")[0] new_qkv_format = ctx.qkv_layout.split("_")[0] perm = [] for i in original_qkv_format: perm.append(new_qkv_format.find(i)) d_out = 
d_out.permute(*perm).contiguous() - print(f"d_out: {d_out.shape}, {type(d_out)}") - # d_out is expected to be in FP8 if is_output_fp8=True, - # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage): - print(f"before dO_quantizer: {type(d_out)}, {d_out.shape}") - d_out_f16 = d_out - ctx.dO_quantizer.optimize_for_gemm = True - d_out = ctx.dO_quantizer(d_out) - print(f"after dO_quantizer: {type(d_out)}, {d_out.shape}") - if not ctx.use_FAv2_bwd: - if isinstance(ctx.dO_quantizer, MXFP8Quantizer): - d_out._rowwise_data = d_out._rowwise_data.contiguous() - d_out._columnwise_data = d_out._columnwise_data.contiguous() - else: - d_out._data = d_out._data.contiguous() - elif not ctx.use_FAv2_bwd: + print(f"d_out after reshape: {d_out.shape}, {type(d_out)}") + + # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + d_out_fp8 = None + if ctx.fp8: + print(f"d_out before quantizer: {d_out.shape}, {type(d_out)}") + if isinstance(d_out, QuantizedTensorStorage): + d_out_fp8 = d_out + else: + d_out_fp8 = ctx.dO_quantizer(d_out) + print(f"d_out after quantizer: {d_out_fp8.shape}, {type(d_out_fp8)}") + if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( q_fp8, @@ -1574,14 +1576,6 @@ def backward(ctx, d_out, *_args): dqkv_nominal_dtype = ctx.nominal_dtype if ctx.fp8: - # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E5M2 - if ctx.is_output_fp8: - d_out_fp8 = d_out - else: - d_out_fp8 = ctx.dO_quantizer(d_out) - # print quantizers print_quantizers( "FusedAttnFunc.backward >> before: ", @@ -1613,14 +1607,9 @@ def backward(ctx, d_out, *_args): out_ = out if ctx.fp8_recipe.mxfp8(): out_ = out - 
aux_ctx_tensors.append(d_out_f16) + aux_ctx_tensors.append(d_out) print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") print(f"shapes: {q_fp8.shape}, {k_fp8.shape}, {v_fp8.shape}, {out_.shape}, {d_out_fp8.shape}, {[x.shape for x in aux_ctx_tensors]}") - for i in [q_fp8, k_fp8, v_fp8, out_, d_out_fp8, *aux_ctx_tensors]: - if isinstance(i, MXFP8Tensor): - print(f"xxxx: {i._with_gemm_swizzled_scales}") - else: - print(f"xxxx: {i.shape}") dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1632,7 +1621,7 @@ def backward(ctx, d_out, *_args): out_, d_out_fp8, dqkv_nominal_dtype, - dqkv_te_dtype, + dqkv_te_dtype, # could we remove this? aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1643,7 +1632,7 @@ def backward(ctx, d_out, *_args): ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.original_qkv_layout, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1652,10 +1641,17 @@ def backward(ctx, d_out, *_args): ctx.deterministic, is_graph_capturing(), ) + if ctx.original_qkv_layout != ctx.qkv_layout: + original_qkv_format = ctx.original_qkv_layout.split("_")[0] + new_qkv_format = ctx.qkv_layout.split("_")[0] + perm = [] + for i in new_qkv_format: + perm.append(original_qkv_format.find(i)) + dq_, dk_, dv_ = [x.permute(*perm).contiguous() for x in (dq_, dk_, dv_)] # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ - is_quantized_tensor = isinstance(dq_, QuantizedTensor) + is_quantized_tensor = isinstance(dq_, QuantizedTensorStorage) if is_quantized_tensor and not ctx.is_input_fp8: # return in F16 dq, dk, dv = combine_and_dequantize( @@ -1665,7 +1661,7 @@ def backward(ctx, d_out, *_args): dv_, src_nominal_dtype=dq_.dtype, ) - if not is_float8tensor and ctx.is_input_fp8: + if not is_quantized_tensor and ctx.is_input_fp8: # return in FP8 dq, dk, dv, _ = combine_and_quantize( 
ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer @@ -1982,7 +1978,7 @@ def forward( " with FP8!" ) if fp8_recipe.float8_current_scaling() and context_parallel: - all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) + all_quantizers = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) for q in all_quantizers: if isinstance(q, Float8CurrentScalingQuantizer): q.with_amax_reduction = True diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 244f24111d..34af861604 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1356,7 +1356,7 @@ def forward( dQKV_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, quantizers) + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) q_f16 = None q_fp8, k_fp8, v_fp8 = (None, None, None) @@ -3394,7 +3394,7 @@ def forward( max_logit = None QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) q_fp8, k_fp8, v_fp8 = (None, None, None) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index a9ce96c8c9..26a2bda08b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2086,38 +2086,41 @@ def check_set_window_size( return window_size -def get_attention_quantizers(fp8, quantizers): +def get_attention_quantizers(fp8, fp8_recipe, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 + QKV_quantizer = 
quantizers["scaling_fwd"][META_QKV] - # QKV_quantizer.internal = True - QKV_quantizer.set_usage(rowwise=True, columnwise=True) + QKV_quantizer.internal = True + QKV_quantizer.set_usage(rowwise=True, columnwise=False) O_quantizer = quantizers["scaling_fwd"][META_O] O_quantizer.set_usage(rowwise=True, columnwise=False) - # O_quantizer.optimize_for_gemm = True - if isinstance(QKV_quantizer, MXFP8Quantizer): - QKV_quantizer.optimize_for_gemm = True - # QKV_quantizer.internal = False - S_quantizer = None - else: - S_quantizer = quantizers["scaling_fwd"][META_S] - S_quantizer.internal = True - S_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] + S_quantizer.internal = True + S_quantizer.set_usage(rowwise=True, columnwise=False) dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = False - dQKV_quantizer.set_usage(rowwise=True, columnwise=True) + dQKV_quantizer.interal = True + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.set_usage(rowwise=True, columnwise=True) - dO_quantizer.internal = False - # dO_quantizer.optimize_for_gemm = True - if isinstance(dO_quantizer, MXFP8Quantizer): + dO_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer.internal = True + dP_quantizer = quantizers["scaling_bwd"][META_DP] + dP_quantizer.set_usage(rowwise=True, columnwise=False) + dP_quantizer.interal = True + + if fp8_recipe.mxfp8(): + QKV_quantizer.columnwise = True + QKV_quantizer.optimize_for_gemm = True + O_quantizer.columnwise = True + O_quantizer.optimize_for_gemm = True + S_quantizer = None + dQKV_quantizer.columnwise = True + dQKV_quantizer.optimize_for_gemm = True + dO_quantizer.columnwise = True + dO_quantizer.optimize_for_gemm = True dP_quantizer = None - else: - dP_quantizer = quantizers["scaling_bwd"][META_DP] - dP_quantizer.set_usage(rowwise=True, columnwise=False) - dP_quantizer.interal = True return 
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer From a85070dbc38f4ce749d4d0f246a9fe29b928112d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 19 Feb 2026 17:03:04 -0800 Subject: [PATCH 021/172] fix fprop; add o_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 54 +++++++++++++++++-- .../common/fused_attn/fused_attn_fp8.cu | 23 ++++---- .../common/fused_attn/fused_attn_fp8.h | 2 +- transformer_engine/common/fused_attn/utils.cu | 31 +++++++++++ transformer_engine/common/fused_attn/utils.h | 3 ++ .../include/transformer_engine/fused_attn.h | 13 ++++- .../dot_product_attention/backends.py | 43 +++++++-------- .../attention/dot_product_attention/utils.py | 30 ++++++----- .../pytorch/cpp_extensions/fused_attn.py | 4 ++ transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/attention.cpp | 20 ++++--- 11 files changed, 162 insertions(+), 63 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index ebda62568a..79a8417bdc 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -208,6 +208,52 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { } } +// map one NVTE_QKV_Format to another +std::vector nvte_convert_qkv_format(std::vector src_shape, NVTE_QKV_Format src_format, NVTE_QKV_Format dst_format) { + std::vector dst_shape(src_shape.size()); + size_t b=0, h=0, s=0, d=0, t=0; + switch (src_format) { + case NVTE_QKV_Format::NVTE_BSHD: + b = src_shape[0]; + s = src_shape[1]; + h = src_shape[2]; + d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_SBHD: + s = src_shape[0]; + b = src_shape[1]; + h = src_shape[2]; + d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_BHSD: + b = src_shape[0]; + h = src_shape[1]; + s = 
src_shape[2]; + d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_THD: + t = src_shape[0]; + h = src_shape[1]; + d = src_shape[2]; + break; + } + switch (dst_format) { + case NVTE_QKV_Format::NVTE_BSHD: + dst_shape = {b, s, h, d}; + break; + case NVTE_QKV_Format::NVTE_SBHD: + dst_shape = {s, b, h, d}; + break; + case NVTE_QKV_Format::NVTE_BHSD: + dst_shape = {b, h, s, d}; + break; + case NVTE_QKV_Format::NVTE_THD: + dst_shape = {t, h, d}; + break; + } + return dst_shape; +} + // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, @@ -631,7 +677,7 @@ void nvte_fused_attn_fwd_qkvpacked( Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, + qkv_layout, qkv_format, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else @@ -941,7 +987,7 @@ void nvte_fused_attn_fwd_kvpacked( Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, + dropout, qkv_layout, q_format, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1125,7 +1171,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const 
NVTETensor K, const NVTETenso const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { @@ -1228,7 +1274,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, + dropout, qkv_layout, o_format, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 504c387d57..4a9af8b4b3 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1653,7 +1653,7 @@ void fused_attn_fp8_bwd_impl( // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, 
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, @@ -1702,7 +1702,7 @@ void fused_attn_fp8_fwd_impl_v1( scaling_factor, is_training, dropout_probability, - layout, + qkv_layout, bias_type, mask_type, NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, @@ -1767,11 +1767,11 @@ void fused_attn_fp8_fwd_impl_v1( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") @@ -1830,11 +1830,11 @@ void fused_attn_fp8_fwd_impl_v1( std::vector q_scale_strides(4); std::vector k_scale_strides(4); std::vector v_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, 
v_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -1932,8 +1932,7 @@ void fused_attn_fp8_fwd_impl_v1( } std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format); O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) @@ -2653,7 +2652,7 @@ void fused_attn_fp8_bwd_impl_v1( void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, @@ -2741,7 +2740,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, + attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, window_size_left, 
window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index bfadc0e870..548a41a561 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -17,7 +17,7 @@ namespace transformer_engine { void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index c6d6386fb7..0ea5d6aa7f 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -17,6 +17,37 @@ namespace fused_attn { using namespace transformer_engine; +// get matrix strides based on matrix type +void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, + int64_t *strideA, NVTE_QKV_Format format) { +constexpr int batch_dim_idx = 0; +constexpr int head_dim_idx = 1; +constexpr int seqlen_dim_idx = 2; +constexpr int hidden_dim_idx = 3; + + switch (format) { + case NVTE_QKV_Format::NVTE_BSHD: + case 
NVTE_QKV_Format::NVTE_THD: + strideA[batch_dim_idx] = s * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + break; + case NVTE_QKV_Format::NVTE_SBHD: + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + break; + case NVTE_QKV_Format::NVTE_BHSD: + strideA[batch_dim_idx] = h * s * d; + strideA[head_dim_idx] = s * d; + strideA[seqlen_dim_idx] = d; + strideA[hidden_dim_idx] = 1; + break; + } +} + // get matrix strides based on matrix type void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index ea3428855c..88535f61c9 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -34,6 +34,9 @@ enum NVTE_QKV_Matrix { NVTE_O_Matrix_Transpose = 7, // final output transposed }; +void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, + int64_t *strideA, NVTE_QKV_Format format); + void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 204d8f3d5a..883c5a6e61 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -193,6 +193,16 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); +/*! \brief Convert one NVTE_QKV_Format to another. + * + * \param[in] src_shape The source shape. + * \param[in] src_format The source format. 
+ * \param[in] dst_format The destination format. + * + * \return The destination shape. + */ + std::vector nvte_convert_qkv_format(std::vector src_shape, NVTE_QKV_Format src_format, NVTE_QKV_Format dst_format); + /*! \brief Get fused attention backend based on input parameters. * * \param[in] is_training Whether the model is in training mode. @@ -563,6 +573,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. + * \param[in] o_format Output format. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -581,7 +592,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 4d1481de79..2838eaa5eb 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1213,21 +1213,21 @@ def forward( if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + original_qkv_layout = qkv_layout + _, o_format, _ = dpa_utils.get_qkv_format(qkv_layout) + # input types are inferred from the real data while output types are 
controlled by fp8_output # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha) assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output - # whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) - # whether bwd kernel in FP8: + # whether fwd kernel will be run in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) + # whether bwd kernel will be run in FP8: is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - # save original qkv_layout - original_qkv_layout = qkv_layout - # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) @@ -1249,7 +1249,7 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(original_qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) # print quantizers print_quantizers( @@ -1263,14 +1263,6 @@ def forward( dP_quantizer, ) - print(f"q_fp8._rowwise_data.shape: {q_fp8._rowwise_data.shape}, k_fp8._rowwise_data.shape: {k_fp8._rowwise_data.shape}, v_fp8._rowwise_data.shape: {v_fp8._rowwise_data.shape}") - print(f"q_fp8._rowwise_data.stride: {q_fp8._rowwise_data.stride()}, k_fp8._rowwise_data.stride: {k_fp8._rowwise_data.stride()}, v_fp8._rowwise_data.stride: {v_fp8._rowwise_data.stride()}") - print(f"q_fp8._rowwise_scale_inv.shape: {q_fp8._rowwise_scale_inv.shape}, k_fp8._rowwise_scale_inv.shape: {k_fp8._rowwise_scale_inv.shape}, v_fp8._rowwise_scale_inv.shape: {v_fp8._rowwise_scale_inv.shape}") - 
print(f"q_fp8._rowwise_scale_inv.stride: {q_fp8._rowwise_scale_inv.stride()}, k_fp8._rowwise_scale_inv.stride: {k_fp8._rowwise_scale_inv.stride()}, v_fp8._rowwise_scale_inv.stride: {v_fp8._rowwise_scale_inv.stride()}") - print(f"q_fp8._columnwise_data.shape: {q_fp8._columnwise_data.shape}, k_fp8._columnwise_data.shape: {k_fp8._columnwise_data.shape}, v_fp8._columnwise_data.shape: {v_fp8._columnwise_data.shape}") - print(f"q_fp8._columnwise_data.stride: {q_fp8._columnwise_data.stride()}, k_fp8._columnwise_data.stride: {k_fp8._columnwise_data.stride()}, v_fp8._columnwise_data.stride: {v_fp8._columnwise_data.stride()}") - print(f"q_fp8._columnwise_scale_inv.shape: {q_fp8._columnwise_scale_inv.shape}, k_fp8._columnwise_scale_inv.shape: {k_fp8._columnwise_scale_inv.shape}, v_fp8._columnwise_scale_inv.shape: {v_fp8._columnwise_scale_inv.shape}") - print(f"q_fp8._columnwise_scale_inv.stride: {q_fp8._columnwise_scale_inv.stride()}, k_fp8._columnwise_scale_inv.stride: {k_fp8._columnwise_scale_inv.stride()}, v_fp8._columnwise_scale_inv.stride: {v_fp8._columnwise_scale_inv.stride()}") # out_: # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 @@ -1298,6 +1290,7 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, attn_bias_type, attn_mask_type, softmax_type, @@ -1307,20 +1300,21 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) - if original_qkv_layout != qkv_layout: - original_qkv_format = original_qkv_layout.split("_")[0] - new_qkv_format = qkv_layout.split("_")[0] - perm = [] - for i in new_qkv_format: - perm.append(original_qkv_format.find(i)) - out_ = out_.permute(*perm).contiguous() + print(f"out_.shape: {out_.shape}, type(out_): {type(out_)}") + # if original_qkv_layout != qkv_layout: + # original_qkv_format = original_qkv_layout.split("_")[0] + # new_qkv_format = qkv_layout.split("_")[0] + # perm = [] + # for i in new_qkv_format: + # perm.append(original_qkv_format.find(i)) + # 
out_ = out_.permute(*perm).contiguous() # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ - if isinstance(out_, Float8Tensor) or isinstance(out_, MXFP8Tensor): + if isinstance(out_, QuantizedTensorStorage): if not is_output_fp8 or not is_bwd_fp8: out = out_.dequantize().view(out_.shape) else: @@ -1382,6 +1376,7 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, attn_bias_type, attn_mask_type, softmax_type, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 26a2bda08b..176434d883 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2094,33 +2094,37 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True QKV_quantizer.set_usage(rowwise=True, columnwise=False) - O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer.internal = True S_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer = quantizers["scaling_fwd"][META_O] + O_quantizer.internal = False + O_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer.internal = True + dO_quantizer.set_usage(rowwise=True, columnwise=False) + dP_quantizer = quantizers["scaling_bwd"][META_DP] - dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True + 
dP_quantizer.set_usage(rowwise=True, columnwise=False) + + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] + dQKV_quantizer.interal = False + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) if fp8_recipe.mxfp8(): - QKV_quantizer.columnwise = True + QKV_quantizer.columnwise_usage = True QKV_quantizer.optimize_for_gemm = True - O_quantizer.columnwise = True - O_quantizer.optimize_for_gemm = True S_quantizer = None - dQKV_quantizer.columnwise = True - dQKV_quantizer.optimize_for_gemm = True - dO_quantizer.columnwise = True + O_quantizer.columnwise_usage = True + + dO_quantizer.columnwise_usage = True dO_quantizer.optimize_for_gemm = True dP_quantizer = None + dQKV_quantizer.columnwise_usage = True return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 8f77a8a7fd..629046aa1c 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -135,6 +135,7 @@ def fused_attn_fwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -204,6 +205,8 @@ def fused_attn_fwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -307,6 +310,7 @@ def fused_attn_fwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], diff 
--git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e9531573d6..7d2d002111 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -85,7 +85,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 0d7a842ce1..30415b4373 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -114,7 +114,7 @@ std::pair quantizer_helper(py::handle quantizer, // fused attention FWD with separate Q, K and V tensors std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, @@ -150,8 +150,9 @@ std::vector fused_attn_fwd( std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); std::vector q_shape = convertShape(te_Q.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; - 
o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape_tmp = std::vector{q_shape.begin(), q_shape.end()}; + o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape = nvte_convert_qkv_format(o_shape_tmp, nvte_get_q_format(qkv_layout), o_format); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -164,7 +165,7 @@ std::vector fused_attn_fwd( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && (o_format == NVTE_QKV_Format::NVTE_THD)) { if ((h * d) % block_size == 0) { mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { @@ -172,7 +173,7 @@ std::vector fused_attn_fwd( } } } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + if (o_format == NVTE_QKV_Format::NVTE_THD) { te_O.zero_(at::cuda::getCurrentCUDAStream()); } } else { @@ -251,7 +252,7 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -311,7 +312,7 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), 
te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -367,6 +368,11 @@ std::vector fused_attn_bwd( at::Tensor dQ, dK, dV, dQKV, dKV; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); std::vector tmp_shape; + // DType dqkv_type = DType::kNumTypes; + // if (!dqkv_quantizer.is_none()) { + // dqkv_type = dqkv_quantizer.attr("fp8_dtype").cast(); + // } + // printf(">>>>>> dQKV_type: %d\n", dqkv_type); auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { options = options.dtype(torch::kUInt8); From 8909b35da8ff35bd09bbec184afceaa749f068a6 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 19 Feb 2026 19:00:07 -0800 Subject: [PATCH 022/172] attempt at bwd with o_format/d_out_format/dqkv_layout Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 8 +-- .../common/fused_attn/fused_attn_fp8.cu | 61 +++++++++-------- .../common/fused_attn/fused_attn_fp8.h | 2 +- .../include/transformer_engine/fused_attn.h | 5 +- .../dot_product_attention/backends.py | 68 ++++++++++--------- .../attention/dot_product_attention/utils.py | 25 +++---- .../pytorch/cpp_extensions/fused_attn.py | 29 ++++++-- transformer_engine/pytorch/csrc/extensions.h | 4 +- .../pytorch/csrc/extensions/attention.cpp | 26 +++---- 9 files changed, 127 insertions(+), 101 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp 
b/transformer_engine/common/fused_attn/fused_attn.cpp index 79a8417bdc..1e9673fff7 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -828,7 +828,7 @@ void nvte_fused_attn_bwd_qkvpacked( Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, d, attn_scale, dropout, qkv_layout, + fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, d, attn_scale, dropout, qkv_layout, qkv_format, qkv_format, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, @@ -1150,7 +1150,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_O, + qkv_layout, q_format, q_format, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1293,7 +1293,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format 
o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, @@ -1390,7 +1390,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_O, + qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 4a9af8b4b3..2afe979f04 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2050,7 +2050,7 @@ void fused_attn_fp8_fwd_impl_v1( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, + float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* 
devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, @@ -2107,7 +2107,7 @@ void fused_attn_fp8_bwd_impl_v1( scaling_factor, true, dropout_probability, - layout, + qkv_layout, bias_type, mask_type, NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, @@ -2197,14 +2197,13 @@ void fused_attn_fp8_bwd_impl_v1( std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -2277,12 +2276,11 @@ void fused_attn_fp8_bwd_impl_v1( std::vector q_t_stride(4); std::vector k_t_stride(4); std::vector dO_t_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_t_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_t_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_t_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_t_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - generateMatrixStrides(b, h, s_q, s_kv, d_v, dO_t_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix_Transpose); + generateMatrixStridesWithFormat(b, h, d_v, 
s_q, dO_t_stride.data(), d_out_format); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2324,20 +2322,18 @@ void fused_attn_fp8_bwd_impl_v1( std::vector v_scale_strides(4); std::vector dO_scale_strides(4); std::vector dO_t_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, q_t_scale_strides.data(), layout, + generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, q_t_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, k_t_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, k_t_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_v_scale_padded, dO_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_v_padded, dO_t_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix_Transpose); + generateMatrixStridesWithFormat(b, h, s_q_padded, d_v_scale_padded, dO_scale_strides.data(), 
d_out_format); + generateMatrixStridesWithFormat(b, h, d_v_padded, s_q_scale_padded, dO_t_scale_strides.data(), d_out_format); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, s_q_padded, d_qk_scale_padded}) @@ -2473,9 +2469,18 @@ void fused_attn_fp8_bwd_impl_v1( amax_dK = outputs[4]; amax_dV = outputs[5]; } - dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride).set_data_type(dqkv_tensor_type); - dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride).set_data_type(dqkv_tensor_type); - dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride).set_data_type(dqkv_tensor_type); + std::vector dq_stride(4); + std::vector dk_stride(4); + std::vector dv_stride(4); + generateMatrixStrides(b, h, s_q, s_kv, d_qk, dq_stride.data(), dqkv_layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, dk_stride.data(), dqkv_layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d_v, dv_stride.data(), dqkv_layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(dq_stride).set_data_type(dqkv_tensor_type); + dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(dk_stride).set_data_type(dqkv_tensor_type); + dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(dv_stride).set_data_type(dqkv_tensor_type); amax_dQ->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2773,7 +2778,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, 
NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, @@ -2839,11 +2844,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); + if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, @@ -2853,7 +2858,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) 
{ + } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 548a41a561..215b5dd92a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -27,7 +27,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 883c5a6e61..d866cab702 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -647,6 +647,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. 
+ * \param[in] o_format Output format. + * \param[in] d_out_format Output gradient's format. + * \param[in] dqkv_layout QKV gradient tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -666,7 +669,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 2838eaa5eb..1a4b34f4fa 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1213,6 +1213,7 @@ def forward( if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + # save qkv_layout and get output format original_qkv_layout = qkv_layout _, o_format, _ = dpa_utils.get_qkv_format(qkv_layout) @@ -1300,14 +1301,6 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) - print(f"out_.shape: {out_.shape}, type(out_): {type(out_)}") - # if original_qkv_layout != qkv_layout: - # original_qkv_format = original_qkv_layout.split("_")[0] - # new_qkv_format = qkv_layout.split("_")[0] - # perm = [] - # for i in new_qkv_format: - # perm.append(original_qkv_format.find(i)) - # out_ = out_.permute(*perm).contiguous() # out_fp8: 
Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 @@ -1463,9 +1456,10 @@ def forward( else: ctx.qkv_layout = qkv_layout else: - ctx.original_qkv_layout = original_qkv_layout ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.dqkv_layout = original_qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.softmax_type = softmax_type @@ -1486,29 +1480,32 @@ def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring print(f"ctx.fp8: {ctx.fp8}, ctx.is_output_fp8: {ctx.is_output_fp8}, isinstance(d_out, QuantizedTensorStorage): {isinstance(d_out, QuantizedTensorStorage)}") - # reshape d_out to ctx.qkv_layout; only happens with MXFP8BlockScaling - if ctx.original_qkv_layout != ctx.qkv_layout: - print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") - print(f"d_out before reshape: {d_out.shape}, {type(d_out)}") - original_qkv_format = ctx.original_qkv_layout.split("_")[0] - new_qkv_format = ctx.qkv_layout.split("_")[0] - perm = [] - for i in original_qkv_format: - perm.append(new_qkv_format.find(i)) - d_out = d_out.permute(*perm).contiguous() - print(f"d_out after reshape: {d_out.shape}, {type(d_out)}") + # # reshape d_out to ctx.qkv_layout; only happens with MXFP8BlockScaling + # if ctx.original_qkv_layout != ctx.qkv_layout: + # print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") + # print(f"d_out before reshape: {d_out.shape}, {type(d_out)}") + # original_qkv_format = ctx.original_qkv_layout.split("_")[0] + # new_qkv_format = ctx.qkv_layout.split("_")[0] + # perm = [] + # for i in original_qkv_format: + # perm.append(new_qkv_format.find(i)) + # d_out = d_out.permute(*perm).contiguous() + # print(f"d_out after reshape: {d_out.shape}, {type(d_out)}") # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # 
fp8_dtype = tex.DType.kFloat8E5M2 d_out_fp8 = None + d_out_format = ctx.o_format if ctx.fp8: print(f"d_out before quantizer: {d_out.shape}, {type(d_out)}") + if ctx.fp8_recipe.mxfp8(): + d_out, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, d_out) if isinstance(d_out, QuantizedTensorStorage): d_out_fp8 = d_out else: d_out_fp8 = ctx.dO_quantizer(d_out) - print(f"d_out after quantizer: {d_out_fp8.shape}, {type(d_out_fp8)}") + print(f"d_out after quantizer: {d_out.shape}, {d_out_fp8.shape}, {type(d_out)}, {type(d_out_fp8)}") if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( @@ -1583,8 +1580,8 @@ def backward(ctx, d_out, *_args): ctx.dP_quantizer, ) - # get tex.DType for dq, dk, dv data - dqkv_te_dtype = d_out_fp8._fp8_dtype + # # get tex.DType for dq, dk, dv data + # dqkv_te_dtype = d_out_fp8._fp8_dtype # q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16, # fp8_dtype = tex.DType.kFloat8E4M3 @@ -1616,7 +1613,7 @@ def backward(ctx, d_out, *_args): out_, d_out_fp8, dqkv_nominal_dtype, - dqkv_te_dtype, # could we remove this? + # dqkv_te_dtype, # could we remove this? 
aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1628,6 +1625,9 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + d_out_format, + ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1636,13 +1636,15 @@ def backward(ctx, d_out, *_args): ctx.deterministic, is_graph_capturing(), ) - if ctx.original_qkv_layout != ctx.qkv_layout: - original_qkv_format = ctx.original_qkv_layout.split("_")[0] - new_qkv_format = ctx.qkv_layout.split("_")[0] - perm = [] - for i in new_qkv_format: - perm.append(original_qkv_format.find(i)) - dq_, dk_, dv_ = [x.permute(*perm).contiguous() for x in (dq_, dk_, dv_)] + print(f"dq_.shape: {dq_.shape}, dk_.shape: {dk_.shape}, dv_.shape: {dv_.shape}") + print(f"types: {type(dq_)}, {type(dk_)}, {type(dv_)}") + # if ctx.original_qkv_layout != ctx.qkv_layout: + # original_qkv_format = ctx.original_qkv_layout.split("_")[0] + # new_qkv_format = ctx.qkv_layout.split("_")[0] + # perm = [] + # for i in new_qkv_format: + # perm.append(original_qkv_format.find(i)) + # dq_, dk_, dv_ = [x.permute(*perm).contiguous() for x in (dq_, dk_, dv_)] # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ @@ -1676,7 +1678,7 @@ def backward(ctx, d_out, *_args): else: if isinstance(d_out, QuantizedTensorStorage): d_out = d_out.dequantize(dtype=ctx.nominal_dtype) - dqkv_te_dtype = TE_DType[d_out.dtype] + # dqkv_te_dtype = TE_DType[d_out.dtype] # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, @@ -1689,7 +1691,7 @@ def backward(ctx, d_out, *_args): out, d_out, dqkv_nominal_dtype, - dqkv_te_dtype, + # dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 
176434d883..379e056b54 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2190,6 +2190,16 @@ def print_quantizers( f"{label} >> {names[i]:14s}: {type_str}" ) +def permute_to_grouped_tensor(src_format, tensor): + """Permute tensor to bhsd or htd format for grouped quantization in MXFP8BlockScaling. src_format ={bshd, sbhd, thd}""" + if src_format in ["bhsd", "htd"]: + return tensor, src_format + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + dim_s_or_t = src_format.find("s") if 's' in src_format else src_format.find("t") + dim_others = [i for i in range(len(tensor.shape)) if i != dim_s_or_t] + perm = [*dim_others[:-1], dim_s_or_t, dim_others[-1]] + tensor = tensor.permute(*perm).contiguous() + return tensor, "bhsd" if src_format != "thd" else "htd" def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" @@ -2199,22 +2209,13 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype if isinstance(qkv_quantizer, MXFP8Quantizer): - - def permute_x(f, x): - x = x.contiguous() if not x.is_contiguous() else x - dim_s_dim_t = f.find("s") if 's' in f else f.find("t") - dim_others = [i for i in range(len(x.shape)) if i != dim_s_dim_t] - perm = [*dim_others[:-1], dim_s_dim_t, dim_others[-1]] - x = x.permute(*perm).contiguous() - return x - # bs3hd, sb3hd, etc -> bshd_bshd_bhsd -> bhsd_bhsd_bhsd # t3hd, etc -> thd_thd_thd -> htd_htd_htd if q_format not in ["bhsd", "htd"]: - q = permute_x(q_format, q) + q, _ = permute_to_grouped_tensor(q_format, q) if kv_format not in ["bhsd", "htd"]: - k = permute_x(kv_format, k) - v = permute_x(kv_format, v) + k, _ = permute_to_grouped_tensor(kv_format, k) + v, _ = permute_to_grouped_tensor(kv_format, v) qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else 
"htd_htd_htd" original_shapes = [x.shape for x in [q, k, v]] diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 629046aa1c..7a756ead1c 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -364,7 +364,7 @@ def fused_attn_bwd( o: torch.Tensor, d_o: torch.Tensor, fake_dtype: torch.dtype, - dqkv_dtype: tex.DType, + # dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, cu_seqlens_q_padded: torch.Tensor = None, @@ -376,6 +376,9 @@ def fused_attn_bwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", + d_out_format: str = "sbhd", + dqkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -414,8 +417,8 @@ def fused_attn_bwd( fake_dtype : tex.DType data type of Q, K and V - in case of high precision, fake dtype in case of FP8; in torch.dtype - dqkv_dtype : tex.DType - data type of dQ, dK and dV; in tex.DType, not torch.dtype + # dqkv_dtype : tex.DType + # data type of dQ, dK and dV; in tex.DType, not torch.dtype aux_ctx_tensors : List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. 
aux_ctx_tensors = [M, ZInv, rng_state] @@ -442,6 +445,15 @@ def fused_attn_bwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} + d_out_format : str, default = "sbhd" + format of dO; {"sbhd", "bshd", "thd"} + dqkv_layout : str, default = "sbh3d" + layout of dQ, dK and dV; + {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", + "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", + "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -496,9 +508,9 @@ def fused_attn_bwd( ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: - assert ( - dqkv_dtype is not None - ), "dqkv_dtype is required as an input for FP8 fused attention backward." + # assert ( + # dqkv_dtype is not None + # ), "dqkv_dtype is required as an input for FP8 fused attention backward." assert ( len(aux_ctx_tensors) >= 3 ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." 
@@ -510,6 +522,9 @@ def fused_attn_bwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], + QKVFormat[d_out_format], + QKVLayout[dqkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], @@ -524,7 +539,7 @@ def fused_attn_bwd( o, d_o, fake_dtype, - dqkv_dtype, + # dqkv_dtype, aux_ctx_tensors, cu_seqlens_q_padded, cu_seqlens_kv_padded, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 7d2d002111..795f50f672 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -99,11 +99,11 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, //const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 30415b4373..0cb0ae0a06 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -327,11 +327,11 @@ std::vector fused_attn_fwd( // fused 
attention BWD with separate Q, K and V std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, //const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -366,18 +366,18 @@ std::vector fused_attn_bwd( const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); at::Tensor dQ, dK, dV, dQKV, dKV; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); std::vector tmp_shape; - // DType dqkv_type = DType::kNumTypes; - // if (!dqkv_quantizer.is_none()) { - // dqkv_type = dqkv_quantizer.attr("fp8_dtype").cast(); - // } - // printf(">>>>>> dQKV_type: %d\n", dqkv_type); + DType dqkv_type = fake_dtype_te; + if (!dqkv_quantizer.is_none()) { + dqkv_type = dqkv_quantizer.attr("fp8_dtype").cast(); + } + printf(">>>>>> dQKV_type: %d\n", dqkv_type); auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { options = options.dtype(torch::kUInt8); } - if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) { + if 
(detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { options = options.dtype(fake_dtype); } @@ -460,7 +460,7 @@ std::vector fused_attn_bwd( // construct NVTE tensors if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD)) { if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -473,7 +473,7 @@ std::vector fused_attn_bwd( } } } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + if (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); dV.fill_(0); @@ -560,7 +560,7 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -577,7 +577,7 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, 
qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); From 90a636c9e132b4288644e7c1cc94cdc9d3c673dc Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:08:51 -0800 Subject: [PATCH 023/172] fix dtype/o_format/etc in bwd calls Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 4 ++-- transformer_engine/common/fused_attn/fused_attn.cpp | 7 +++++-- .../attention/dot_product_attention/backends.py | 13 +++++++++++-- .../pytorch/csrc/extensions/attention.cpp | 2 +- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index ff3c7506e9..7f353a9483 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1788,8 +1788,8 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), - "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512)), + "fp8_9": ModelConfig(2, 2048, 16, 128),#, attn_mask_type="causal"), + "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128), #, num_gqa_groups=12, window_size=(512, 512)), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 1e9673fff7..ed0971627e 100644 --- 
a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -447,12 +447,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || + qkv_format == NVTE_QKV_Format::NVTE_BHSD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || cudnn_runtime_version >= 90600)) || - ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || + ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || q_format == NVTE_QKV_Format::NVTE_BHSD || (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || - kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || + kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || kv_format == NVTE_QKV_Format::NVTE_BHSD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && cudnn_runtime_version >= 90700)) && // sliding window @@ -1345,6 +1346,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, cuda_graph, deterministic); + printf("Q_type: %d, KV_type: %d, qkv_layout: %d, bias_type: %d, attn_mask_type: %d, softmax_type: %d, dropout: %f, h_q: %d, h_kv: %d, max_seqlen_q: %d, max_seqlen_kv: %d, d_qk: %d, d_v: %d, window_size_left: %d, window_size_right: %d, deterministic: %d\n", Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, deterministic); + printf("fused_attention_backend: %d\n", fused_attention_backend); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if 
(CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 1a4b34f4fa..d49e7f2365 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1457,6 +1457,8 @@ def forward( ctx.qkv_layout = qkv_layout else: ctx.qkv_layout = qkv_layout + if fp8 and not ctx.fp8: + ctx.qkv_layout = original_qkv_layout ctx.o_format = o_format ctx.dqkv_layout = original_qkv_layout @@ -1505,7 +1507,7 @@ def backward(ctx, d_out, *_args): d_out_fp8 = d_out else: d_out_fp8 = ctx.dO_quantizer(d_out) - print(f"d_out after quantizer: {d_out.shape}, {d_out_fp8.shape}, {type(d_out)}, {type(d_out_fp8)}") + print(f"d_out after quantizer: {d_out.shape}, {d_out_fp8._rowwise_data.shape}, {type(d_out)}, {type(d_out_fp8)}") if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( @@ -1600,8 +1602,12 @@ def backward(ctx, d_out, *_args): if ctx.fp8_recipe.mxfp8(): out_ = out aux_ctx_tensors.append(d_out) + print(f"q_fp8._with_gemm_swizzled_scales: {q_fp8._with_gemm_swizzled_scales}") + print(f"k_fp8._with_gemm_swizzled_scales: {k_fp8._with_gemm_swizzled_scales}") + print(f"v_fp8._with_gemm_swizzled_scales: {v_fp8._with_gemm_swizzled_scales}") + print(f"d_out_fp8._with_gemm_swizzled_scales: {d_out_fp8._with_gemm_swizzled_scales}") print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") - print(f"shapes: {q_fp8.shape}, {k_fp8.shape}, {v_fp8.shape}, {out_.shape}, {d_out_fp8.shape}, {[x.shape for x in aux_ctx_tensors]}") + print(f"shapes: {q_fp8._rowwise_data.shape}, {k_fp8._rowwise_data.shape}, {v_fp8._rowwise_data.shape}, {out_.shape}, {d_out_fp8._rowwise_data.shape}, {[x.shape for x in 
aux_ctx_tensors]}") dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1703,6 +1709,9 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + d_out_format, + ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 0cb0ae0a06..fc870c4591 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -370,7 +370,7 @@ std::vector fused_attn_bwd( std::vector tmp_shape; DType dqkv_type = fake_dtype_te; if (!dqkv_quantizer.is_none()) { - dqkv_type = dqkv_quantizer.attr("fp8_dtype").cast(); + dqkv_type = dqkv_quantizer.attr("dtype").cast(); } printf(">>>>>> dQKV_type: %d\n", dqkv_type); auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); From 8c72deaa83ee8a2816fa4059c7f569e998e9e29e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:09:34 -0800 Subject: [PATCH 024/172] fix generateMatrixStridesWithFormats and _v1; fix padding for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 138 ++++---- transformer_engine/common/fused_attn/utils.cu | 36 +- transformer_engine/common/fused_attn/utils.h | 310 +++++++++++++++++- 3 files changed, 390 insertions(+), 94 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 2afe979f04..9de1fdeabc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1767,11 +1767,11 @@ void fused_attn_fp8_fwd_impl_v1( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, 
h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") @@ -1814,28 +1814,15 @@ void fused_attn_fp8_fwd_impl_v1( } } if (is_mxfp8) { - int32_t block_size = 32; - int64_t s_q_padded = ((s_q + 127) / 128) * 128; - int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; - int64_t s_q_scale = (s_q + block_size - 1) / block_size; - int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; - int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - // int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; - int64_t d_v_padded = ((d_v + 127) / 128) * 128; - int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; - // int64_t d_v_scale = (d_v + block_size - 1) / block_size; - int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; - // int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); std::vector q_scale_strides(4); std::vector k_scale_strides(4); std::vector v_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, 
d_v_padded, v_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, k_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, v_scale_strides.data(), kv_format, true); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, s_q_padded, d_qk_scale_padded}) @@ -1932,7 +1919,7 @@ void fused_attn_fp8_fwd_impl_v1( } std::vector o_stride(4); - generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) @@ -2197,13 +2184,13 @@ void fused_attn_fp8_bwd_impl_v1( std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); Q = mha_graph->tensor(fe::graph::Tensor_attributes() 
.set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -2272,15 +2259,41 @@ void fused_attn_fp8_bwd_impl_v1( } } if (is_mxfp8) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); // Q_t, K_t, dO_t, dO_f16 std::vector q_t_stride(4); std::vector k_t_stride(4); std::vector dO_t_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_t_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_t_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - generateMatrixStridesWithFormat(b, h, d_v, s_q, dO_t_stride.data(), d_out_format); + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, true); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, true); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, true); + printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], q_t_stride[3]); + printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], k_t_stride[3]); + printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], dO_t_stride[3]); + printf("qkv_tensor_type: %d\n", qkv_tensor_type); + printf("o_tensor_type: %d\n", o_tensor_type); + printf("do_tensor_type: %d\n", do_tensor_type); + printf("dqkv_tensor_type: %d\n", dqkv_tensor_type); + printf("qkv_layout: %d\n", qkv_layout); + printf("o_format: %d\n", o_format); + printf("d_out_format: %d\n", d_out_format); + printf("dqkv_layout: %d\n", dqkv_layout); + printf("b: %d\n", b); + printf("h: %d\n", h); + printf("hg: %d\n", hg); + printf("s_q: %d\n", s_q); + printf("s_kv: %d\n", s_kv); + printf("d_qk: %d\n", d_qk); + printf("d_v: %d\n", d_v); + printf("is_delayed_scaling: %d\n", is_delayed_scaling); + printf("is_current_scaling: %d\n", is_current_scaling); + printf("is_O_in_F16: %d\n", is_O_in_F16); + 
printf("is_mxfp8: %d\n", is_mxfp8); + printf("is_causal: %d\n", is_causal); + printf("is_padding: %d\n", is_padding); + printf("is_dropout: %d\n", is_dropout); + printf("is_bias: %d\n", is_bias); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2302,19 +2315,19 @@ void fused_attn_fp8_bwd_impl_v1( .set_stride(o_stride) .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t - int32_t block_size = 32; - int64_t s_q_padded = ((s_q + 127) / 128) * 128; - int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; - int64_t s_q_scale = (s_q + block_size - 1) / block_size; - int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; - int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; - int64_t d_v_padded = ((d_v + 127) / 128) * 128; - int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; - int64_t d_v_scale = (d_v + block_size - 1) / block_size; - int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; - int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + printf("s_q_padded: %d\n", padded.s_q_padded); + printf("s_kv_padded: %d\n", padded.s_kv_padded); + printf("s_q_scale: %d\n", padded.s_q_scale); + printf("s_kv_scale: %d\n", padded.s_kv_scale); + printf("s_q_scale_padded: %d\n", padded.s_q_scale_padded); + printf("s_kv_scale_padded: %d\n", padded.s_kv_scale_padded); + printf("d_qk_padded: %d\n", padded.d_qk_padded); + printf("d_v_padded: %d\n", padded.d_v_padded); + printf("d_qk_scale: %d\n", padded.d_qk_scale); + printf("d_v_scale: %d\n", padded.d_v_scale); + printf("d_qk_scale_padded: %d\n", padded.d_qk_scale_padded); + printf("d_v_scale_padded: %d\n", padded.d_v_scale_padded); std::vector q_scale_strides(4); std::vector q_t_scale_strides(4); std::vector k_scale_strides(4); @@ -2322,18 
+2335,20 @@ void fused_attn_fp8_bwd_impl_v1( std::vector v_scale_strides(4); std::vector dO_scale_strides(4); std::vector dO_t_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, q_t_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, k_t_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStridesWithFormat(b, h, s_q_padded, d_v_scale_padded, dO_scale_strides.data(), d_out_format); - generateMatrixStridesWithFormat(b, h, d_v_padded, s_q_scale_padded, dO_t_scale_strides.data(), d_out_format); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, q_t_scale_strides.data(), q_format, true); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, k_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, k_t_scale_strides.data(), kv_format, true); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, v_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, dO_scale_strides.data(), d_out_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, 
dO_t_scale_strides.data(), d_out_format, true); + printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); + printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], q_t_scale_strides[2], q_t_scale_strides[3]); + printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); + printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], k_t_scale_strides[2], k_t_scale_strides[3]); + printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); + printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], dO_scale_strides[2], dO_scale_strides[3]); + printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], dO_t_scale_strides[2], dO_t_scale_strides[3]); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, s_q_padded, d_qk_scale_padded}) @@ -2472,12 +2487,15 @@ void fused_attn_fp8_bwd_impl_v1( std::vector dq_stride(4); std::vector dk_stride(4); std::vector dv_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, dq_stride.data(), dqkv_layout, + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, dq_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, dk_stride.data(), dqkv_layout, + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dk_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, dv_stride.data(), dqkv_layout, + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + printf("dq_stride: %d, %d, %d, %d\n", dq_stride[0], dq_stride[1], dq_stride[2], dq_stride[3]); + printf("dk_stride: %d, %d, %d, %d\n", 
dk_stride[0], dk_stride[1], dk_stride[2], dk_stride[3]); + printf("dv_stride: %d, %d, %d, %d\n", dv_stride[0], dv_stride[1], dv_stride[2], dv_stride[3]); dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(dq_stride).set_data_type(dqkv_tensor_type); dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(dk_stride).set_data_type(dqkv_tensor_type); dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(dv_stride).set_data_type(dqkv_tensor_type); diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 0ea5d6aa7f..8a9399e830 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -17,37 +17,6 @@ namespace fused_attn { using namespace transformer_engine; -// get matrix strides based on matrix type -void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, - int64_t *strideA, NVTE_QKV_Format format) { -constexpr int batch_dim_idx = 0; -constexpr int head_dim_idx = 1; -constexpr int seqlen_dim_idx = 2; -constexpr int hidden_dim_idx = 3; - - switch (format) { - case NVTE_QKV_Format::NVTE_BSHD: - case NVTE_QKV_Format::NVTE_THD: - strideA[batch_dim_idx] = s * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; - strideA[hidden_dim_idx] = 1; - break; - case NVTE_QKV_Format::NVTE_SBHD: - strideA[batch_dim_idx] = h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = b * h * d; - strideA[hidden_dim_idx] = 1; - break; - case NVTE_QKV_Format::NVTE_BHSD: - strideA[batch_dim_idx] = h * s * d; - strideA[head_dim_idx] = s * d; - strideA[seqlen_dim_idx] = d; - strideA[hidden_dim_idx] = 1; - break; - } -} - // get matrix strides based on matrix type void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { @@ -343,6 +312,11 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 
strideA[head_dim_idx] = s_kv * d; strideA[seqlen_transpose_dim_idx] = d; strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { + strideA[batch_dim_idx] = h * s_q * d; + strideA[head_dim_idx] = s_q * d; + strideA[seqlen_transpose_dim_idx] = d; + strideA[hidden_transpose_dim_idx] = 1; } break; } diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 88535f61c9..f0b947c379 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -31,11 +31,315 @@ enum NVTE_QKV_Matrix { NVTE_V_Matrix_Transpose = 5, // values transposed NVTE_S_Matrix = 5, // output of GEMM1 NVTE_O_Matrix = 6, // final output - NVTE_O_Matrix_Transpose = 7, // final output transposed }; -void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, - int64_t *strideA, NVTE_QKV_Format format); +// Padded sizes for MXFP8 layout (s_q/s_kv/d_qk/d_v and their scaled dimensions) +struct MXFP8PaddedSizes { + int64_t s_q_padded; + int64_t s_kv_padded; + int64_t s_q_scale; + int64_t s_kv_scale; + int64_t s_q_scale_padded; + int64_t s_kv_scale_padded; + int64_t d_qk_padded; + int64_t d_v_padded; + int64_t d_qk_scale; + int64_t d_v_scale; + int64_t d_qk_scale_padded; + int64_t d_v_scale_padded; +}; + +// Pad s and d for MXFP8 layout +inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v) { + constexpr int64_t block_size = 32; + MXFP8PaddedSizes p; + p.s_q_padded = ((s_q + 127) / 128) * 128; + p.s_kv_padded = ((s_kv + 127) / 128) * 128; + p.s_q_scale = (s_q + block_size - 1) / block_size; + p.s_kv_scale = (s_kv + block_size - 1) / block_size; + p.s_q_scale_padded = ((p.s_q_scale + 3) / 4) * 4; + p.s_kv_scale_padded = ((p.s_kv_scale + 3) / 4) * 4; + p.d_qk_padded = ((d_qk + 127) / 128) * 128; + p.d_v_padded = ((d_v + 127) / 128) * 128; + p.d_qk_scale = (d_qk + block_size - 1) / block_size; + 
p.d_v_scale = (d_v + block_size - 1) / block_size; + p.d_qk_scale_padded = ((p.d_qk_scale + 3) / 4) * 4; + p.d_v_scale_padded = ((p.d_v_scale + 3) / 4) * 4; + return p; +} + +// Get matrix strides for a 4D tensor [batch, head, seqlen, hidden] given a QKV format. +// strideA must point to at least 4 int64_t elements. +inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, + int64_t *strides, NVTE_QKV_Format format, bool transpose) { + constexpr int batch_dim_idx = 0; + constexpr int head_dim_idx = 1; + int seqlen_dim_idx = transpose ? 3 : 2; + int hidden_dim_idx = transpose ? 2 : 3; + + switch (format) { + case NVTE_QKV_Format::NVTE_BSHD: + case NVTE_QKV_Format::NVTE_THD: + strides[batch_dim_idx] = s * h * d; + strides[head_dim_idx] = d; + strides[seqlen_dim_idx] = h * d; + strides[hidden_dim_idx] = 1; + break; + case NVTE_QKV_Format::NVTE_SBHD: + strides[batch_dim_idx] = h * d; + strides[head_dim_idx] = d; + strides[seqlen_dim_idx] = b * h * d; + strides[hidden_dim_idx] = 1; + break; + case NVTE_QKV_Format::NVTE_BHSD: + strides[batch_dim_idx] = h * s * d; + strides[head_dim_idx] = s * d; + strides[seqlen_dim_idx] = d; + strides[hidden_dim_idx] = 1; + break; + default: + NVTE_CHECK(false, "Invalid format."); + break; + } +} + +// get matrix strides based on matrix type +inline void generateMatrixStrides_v1( + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t *strides, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) +{ + constexpr int batch_dim_idx = 0; + constexpr int head_dim_idx = 1; + bool transpose = + (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); + int seqlen_dim_idx = transpose ? 3 : 2; + int hidden_dim_idx = transpose ? 
2 : 3; + constexpr int seqlen_q_dim_idx = 2; + constexpr int seqlen_kv_dim_idx = 3; + + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_Q_Matrix; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_K_Matrix; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_V_Matrix; + } + NVTE_CHECK(matrix != NVTE_QKV_Matrix::NVTE_O_Matrix, "Invalid matrix type. Expected Q, K, V, O, or their related transposes."); + + switch (layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 3 * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 3 * h * d_qk; + strides[head_dim_idx] = 3 * d_qk; + strides[seqlen_dim_idx] = b * 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 2 * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 
1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 2 * hg * d_qk; + strides[head_dim_idx] = 2 * d_qk; + strides[seqlen_dim_idx] = b * 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = b * hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_q * 3 * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_q * 3 * h * d_qk; + strides[head_dim_idx] = 3 * d_qk; + strides[seqlen_dim_idx] = 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + 
strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; + strides[head_dim_idx] = 2 * d_qk; + strides[seqlen_dim_idx] = 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + 
} else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = b * hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * s_q * d_qk; + strides[head_dim_idx] = s_q * d_qk; + strides[seqlen_dim_idx] = d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * s_kv * d_qk; + strides[head_dim_idx] = s_kv * d_qk; + strides[seqlen_dim_idx] = d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * s_kv * d_v; + strides[head_dim_idx] = s_kv * d_v; + strides[seqlen_dim_idx] = d_v; + strides[hidden_dim_idx] = 1; + } + break; + default: + NVTE_CHECK(false, "Invalid layout."); + break; + } + + if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { + 
strides[seqlen_kv_dim_idx] = 1; + strides[seqlen_q_dim_idx] = s_kv; + strides[head_dim_idx] = s_q * s_kv; + strides[batch_dim_idx] = h * s_q * s_kv; + } +} void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); From 5f23eddf1e7fdbd7c26526d45155a877012f0841 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:16:01 -0800 Subject: [PATCH 025/172] fix upon last commit for paddedsizes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 9de1fdeabc..f13eef3a66 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1825,19 +1825,19 @@ void fused_attn_fp8_fwd_impl_v1( generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, v_scale_strides.data(), kv_format, true); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") - .set_dim({b, h, s_q_padded, d_qk_scale_padded}) + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) .set_stride(q_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k") - .set_dim({b, hg, s_kv_padded, d_qk_scale_padded}) + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) .set_stride(k_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_v") - .set_dim({b, hg, s_kv_scale_padded, d_v_padded}) + 
.set_dim({b, hg, padded.s_kv_scale_padded, padded.d_v_padded}) .set_stride(v_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); @@ -2351,43 +2351,43 @@ void fused_attn_fp8_bwd_impl_v1( printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], dO_t_scale_strides[2], dO_t_scale_strides[3]); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") - .set_dim({b, h, s_q_padded, d_qk_scale_padded}) + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) .set_stride(q_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q_t") - .set_dim({b, h, s_q_scale_padded, d_qk_padded}) + .set_dim({b, h, padded.s_q_scale_padded, padded.d_qk_padded}) .set_stride(q_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k") - .set_dim({b, hg, s_kv_padded, d_qk_scale_padded}) + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) .set_stride(k_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_k_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k_t") - .set_dim({b, hg, s_kv_scale_padded, d_qk_padded}) + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_qk_padded}) .set_stride(k_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_v") - .set_dim({b, hg, s_kv_padded, d_v_scale_padded}) + .set_dim({b, hg, padded.s_kv_padded, padded.d_v_scale_padded}) .set_stride(v_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) 
.set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO") - .set_dim({b, h, s_q_padded, d_v_scale_padded}) + .set_dim({b, h, padded.s_q_padded, padded.d_v_scale_padded}) .set_stride(dO_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO_t") - .set_dim({b, h, s_q_scale_padded, d_v_padded}) + .set_dim({b, h, padded.s_q_scale_padded, padded.d_v_padded}) .set_stride(dO_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); From 18c55801b592d620a6a8c4ea02828f06b0f8d3fd Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:29:30 -0800 Subject: [PATCH 026/172] add mxfp8 env var Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/dot_product_attention.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 55553d30be..f7699340e6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -139,6 +139,11 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | MXFP8 | Pass NVFP4 to autocast(); | +| | | Attention MXFP8 reuses the fp8_dpa, fp8_mha values from linear NVFP4; | +| | | export NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling" # switch to MXFP8BS | +| | | export 
NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | ++-------------------+-----------+-----------------------------------------------------------------------------------+ """ _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2} @@ -673,6 +678,15 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "MXFP8BlockScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a MXFP8 recipe + fake_recipe = MXFP8BlockScaling( + fp8_format=_dpa_fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa # reduce over TP+CP groups; expect fp8_group to be set up so # assume attention uses the same fp8_group as GEMMs From 68476456981a2533b0e18f20340403ee0f50f08d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:30:03 -0800 Subject: [PATCH 027/172] disable FA for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/utils.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 379e056b54..e8a4170cf3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -475,6 +475,9 @@ def get_attention_backend( # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: + fp8_recipe = fp8_meta["recipe"] + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug("Disabling 
FlashAttention 2 for FP8 attention") use_flash_attention_2 = False @@ -482,6 +485,10 @@ def get_attention_backend( if FlashAttentionUtils.v3_is_installed: logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False + if use_flash_attention_3 and fp8_recipe.mxfp8(): + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for MXFP8") + use_flash_attention_3 = False if use_unfused_attention: allow_emulation = ( os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() @@ -489,9 +496,6 @@ def get_attention_backend( if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False - fp8_recipe = fp8_meta["recipe"] - if fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] if use_fused_attention and fp8_recipe.float8_current_scaling(): if device_compute_capability < (10, 0): logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") @@ -603,9 +607,9 @@ def get_attention_backend( # Filter: Head dimension if head_dim_qk != head_dim_v: - if use_flash_attention_2 and FlashAttentionUtils.is_installed: - logger.debug("Disabling FlashAttention 2 as it does not support MLA.") - use_flash_attention_2 = False + # if use_flash_attention_2 and FlashAttentionUtils.is_installed: + # logger.debug("Disabling FlashAttention 2 as it does not support MLA.") + # use_flash_attention_2 = False qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") if use_fused_attention and qkv_layout_group != "hd_hd_hd": From c5a98d5e9dcbba2f84a9328236b5a1a47616d97d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 19:21:29 -0800 Subject: [PATCH 028/172] add mha test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 8 +++++++- 1 file changed, 7 
insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 7f353a9483..f0e70280bf 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1816,7 +1816,7 @@ def get_model(dtype, config): @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_mha_fp8_vs_f16( dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode ): @@ -1841,6 +1841,12 @@ def test_mha_fp8_vs_f16( fp8_dpa=True, fp8_mha=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + fp8_mha=True, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( From 7e61ecdd2dd585fbb96476385f04ab660a361980 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 23 Feb 2026 16:26:05 -0800 Subject: [PATCH 029/172] attempt at bwd; force determinism; fix shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 24 +++--- .../common/fused_attn/fused_attn.cpp | 86 +++++++++++++------ .../common/fused_attn/fused_attn_fp8.cu | 42 +++++---- .../common/fused_attn/fused_attn_fp8.h | 2 +- .../include/transformer_engine/fused_attn.h | 10 ++- .../dot_product_attention/backends.py | 8 +- .../dot_product_attention/context_parallel.py | 8 +- .../attention/dot_product_attention/utils.py | 16 +++- .../pytorch/csrc/extensions/attention.cpp | 76 ++++++++++++---- transformer_engine/pytorch/csrc/quantizer.cpp | 3 + 10 files changed, 187 insertions(+), 88 deletions(-) diff --git 
a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index f0e70280bf..9922d93a77 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2084,7 +2084,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal # config.dropout_p = 0.1 os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" - os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" + # os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1" # Test backend availability @@ -2238,16 +2238,18 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if is_training: for i, _ in enumerate(fused_attn_bwd_f16): logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - compare_and_assert( - fused_attn_bwd_fp8[i], - fused_attn_bwd_f16[i], - f"fused_attn_bwd_fp8[{i}]", - f"fused_attn_bwd_f16[{i}]", - atol, - rtol, - rmse_tol, - True, - ) + print(f"fused_attn_bwd_fp8[{i}].max(): {fused_attn_bwd_fp8[i].max()}, fused_attn_bwd_f16[{i}].max(): {fused_attn_bwd_f16[i].max()}") + print(f"fused_attn_bwd_fp8[{i}].min(): {fused_attn_bwd_fp8[i].min()}, fused_attn_bwd_f16[{i}].min(): {fused_attn_bwd_f16[i].min()}") + # compare_and_assert( + # fused_attn_bwd_fp8[i], + # fused_attn_bwd_f16[i], + # f"fused_attn_bwd_fp8[{i}]", + # f"fused_attn_bwd_f16[{i}]", + # atol, + # rtol, + # rmse_tol, + # True, + # ) os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0" diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index ed0971627e..0886118451 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -209,49 +209,83 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { } // map one NVTE_QKV_Format to another -std::vector nvte_convert_qkv_format(std::vector src_shape, NVTE_QKV_Format src_format, 
NVTE_QKV_Format dst_format) { - std::vector dst_shape(src_shape.size()); - size_t b=0, h=0, s=0, d=0, t=0; +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, NVTE_QKV_Format dst_format, std::vector &dst_shape, + size_t *b, size_t *h, size_t *s, size_t *d, size_t *t) { + printf("src_format: %d, src_shape: %d, %d, %d, %d, %d\n", src_format, src_shape[0], src_shape[1], src_shape[2], src_shape[3], src_shape[4]); + size_t _b=0, _h=0, _s=0, _d=0, _t=0; switch (src_format) { case NVTE_QKV_Format::NVTE_BSHD: - b = src_shape[0]; - s = src_shape[1]; - h = src_shape[2]; - d = src_shape[3]; + _b = src_shape[0]; + _s = src_shape[1]; + _h = src_shape[2]; + _d = src_shape[3]; break; case NVTE_QKV_Format::NVTE_SBHD: - s = src_shape[0]; - b = src_shape[1]; - h = src_shape[2]; - d = src_shape[3]; + _s = src_shape[0]; + _b = src_shape[1]; + _h = src_shape[2]; + _d = src_shape[3]; break; case NVTE_QKV_Format::NVTE_BHSD: - b = src_shape[0]; - h = src_shape[1]; - s = src_shape[2]; - d = src_shape[3]; + _b = src_shape[0]; + _h = src_shape[1]; + _s = src_shape[2]; + _d = src_shape[3]; break; case NVTE_QKV_Format::NVTE_THD: - t = src_shape[0]; - h = src_shape[1]; - d = src_shape[2]; + _t = src_shape[0]; + _h = src_shape[1]; + _d = src_shape[2]; + break; + default: + NVTE_ERROR("src_format not supported!"); break; } switch (dst_format) { case NVTE_QKV_Format::NVTE_BSHD: - dst_shape = {b, s, h, d}; + dst_shape[0] = _b; + dst_shape[1] = _s; + dst_shape[2] = _h; + dst_shape[3] = _d; break; case NVTE_QKV_Format::NVTE_SBHD: - dst_shape = {s, b, h, d}; + dst_shape[0] = _s; + dst_shape[1] = _b; + dst_shape[2] = _h; + dst_shape[3] = _d; break; case NVTE_QKV_Format::NVTE_BHSD: - dst_shape = {b, h, s, d}; + dst_shape[0] = _b; + dst_shape[1] = _h; + dst_shape[2] = _s; + dst_shape[3] = _d; break; case NVTE_QKV_Format::NVTE_THD: - dst_shape = {t, h, d}; + dst_shape[0] = _t; + dst_shape[1] = _h; + dst_shape[2] = _d; break; + default: + NVTE_ERROR("dst_format not 
supported!"); + break; + } + printf("dst_format: %d, dst_shape: %d, %d, %d, %d, %d\n", dst_format, dst_shape[0], dst_shape[1], dst_shape[2], dst_shape[3], dst_shape[4]); + + if (b != nullptr) { + *b = _b; + } + if (h != nullptr) { + *h = _h; + } + if (s != nullptr) { + *s = _s; + } + if (d != nullptr) { + *d = _d; + } + if (t != nullptr) { + *t = _t; } - return dst_shape; } // select a backend for fused attention @@ -830,7 +864,7 @@ void nvte_fused_attn_bwd_qkvpacked( Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, d, attn_scale, dropout, qkv_layout, qkv_format, qkv_format, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, + bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); @@ -1151,7 +1185,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, attn_scale, dropout, - qkv_layout, q_format, q_format, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_O, + qkv_layout, q_format, q_format, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1393,7 +1427,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const 
NVTETenso input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_O, + qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f13eef3a66..57b250f6af 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1822,7 +1822,7 @@ void fused_attn_fp8_fwd_impl_v1( auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, k_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, v_scale_strides.data(), kv_format, true); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, v_scale_strides.data(), kv_format, false); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) @@ -2038,7 +2038,7 @@ void fused_attn_fp8_fwd_impl_v1( void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float dropout_probability, 
NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, @@ -2101,7 +2101,7 @@ void fused_attn_fp8_bwd_impl_v1( window_size_left, window_size_right, bottom_right_diagonal, - false, + deterministic, qkv_tensor_type, o_tensor_type, do_tensor_type, @@ -2265,20 +2265,20 @@ void fused_attn_fp8_bwd_impl_v1( std::vector q_t_stride(4); std::vector k_t_stride(4); std::vector dO_t_stride(4); - generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, true); - generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, true); - generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, true); + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], q_t_stride[3]); printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], k_t_stride[3]); printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], dO_t_stride[3]); - printf("qkv_tensor_type: %d\n", qkv_tensor_type); - 
printf("o_tensor_type: %d\n", o_tensor_type); - printf("do_tensor_type: %d\n", do_tensor_type); - printf("dqkv_tensor_type: %d\n", dqkv_tensor_type); - printf("qkv_layout: %d\n", qkv_layout); - printf("o_format: %d\n", o_format); - printf("d_out_format: %d\n", d_out_format); - printf("dqkv_layout: %d\n", dqkv_layout); + printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, cudnn_frontend::DataType_t::BFLOAT16); + printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); + printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, NVTE_QKV_Format::NVTE_BHSD); + printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, NVTE_QKV_Format::NVTE_BHSD); + printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); printf("b: %d\n", b); printf("h: %d\n", h); printf("hg: %d\n", hg); @@ -2336,12 +2336,12 @@ void fused_attn_fp8_bwd_impl_v1( std::vector dO_scale_strides(4); std::vector dO_t_scale_strides(4); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); - generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, q_t_scale_strides.data(), q_format, true); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, q_t_scale_strides.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, 
k_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, k_t_scale_strides.data(), kv_format, true); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, k_t_scale_strides.data(), kv_format, false); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, v_scale_strides.data(), kv_format, false); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, dO_scale_strides.data(), d_out_format, false); - generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, true); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, false); printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], q_t_scale_strides[2], q_t_scale_strides[3]); printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); @@ -2431,6 +2431,10 @@ void fused_attn_fp8_bwd_impl_v1( // } // } + if (cudnn_runtime_version >= 92100) { + sdpa_backward_options.set_deterministic_algorithm(deterministic); + } + if (is_padding) { seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("seq_q") @@ -2797,7 +2801,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t 
window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, @@ -2866,7 +2870,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, - p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 215b5dd92a..9683974a26 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -28,7 +28,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t 
head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index d866cab702..64eb385584 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -195,13 +195,19 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Convert one NVTE_QKV_Format to another. * - * \param[in] src_shape The source shape. * \param[in] src_format The source format. + * \param[in] src_shape The source shape. * \param[in] dst_format The destination format. + * \param[in,out] dst_shape The destination shape. + * \param[in,out] b The batch size. + * \param[in,out] h The number of heads. + * \param[in,out] s The sequence length. + * \param[in,out] d The head dimension. + * \param[in,out] t The time dimension. * * \return The destination shape. 
*/ - std::vector nvte_convert_qkv_format(std::vector src_shape, NVTE_QKV_Format src_format, NVTE_QKV_Format dst_format); + void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, size_t *h, size_t *s, size_t *d, size_t *t); /*! \brief Get fused attention backend based on input parameters. * diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index d49e7f2365..0ef2dede76 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1301,7 +1301,7 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) - + print(f"out_.shape: {out_.shape}, {type(out_)}, qkv_layout: {qkv_layout}, o_format: {o_format}") # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1606,8 +1606,10 @@ def backward(ctx, d_out, *_args): print(f"k_fp8._with_gemm_swizzled_scales: {k_fp8._with_gemm_swizzled_scales}") print(f"v_fp8._with_gemm_swizzled_scales: {v_fp8._with_gemm_swizzled_scales}") print(f"d_out_fp8._with_gemm_swizzled_scales: {d_out_fp8._with_gemm_swizzled_scales}") - print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") - print(f"shapes: {q_fp8._rowwise_data.shape}, {k_fp8._rowwise_data.shape}, {v_fp8._rowwise_data.shape}, {out_.shape}, {d_out_fp8._rowwise_data.shape}, {[x.shape for x in aux_ctx_tensors]}") + # print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") + # print(f"shapes: {q_fp8._rowwise_data.shape}, {k_fp8._rowwise_data.shape}, {v_fp8._rowwise_data.shape}, {out_.shape}, 
{d_out_fp8._rowwise_data.shape}, {[x.shape for x in aux_ctx_tensors]}") + print(f"out_.shape: {out_.shape}, d_out_fp8.shape: {d_out_fp8._rowwise_data.shape}, d_out_fp8.columnwise_data.shape: {d_out_fp8._columnwise_data.shape}, d_out.shape: {d_out.shape}") + print(f"out_.stride: {out_.stride()}, d_out_fp8.rowwise_data.stride: {d_out_fp8._rowwise_data.stride()}, d_out_fp8.columnwise_data.stride: {d_out_fp8._columnwise_data.stride()}, d_out.stride: {d_out.stride()}") dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 34af861604..22d8378598 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3403,10 +3403,10 @@ def forward( fused_attn_backend = FusedAttnBackend["FP8"] if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - else: - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + elif not isinstance(QKV_quantizer, MXFP8Quantizer): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["o_quantizer"] = O_quantizer diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index e8a4170cf3..16410e8e00 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1069,10 +1069,10 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) 
use_fused_attention = False fused_attention_backend = None - if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: - logger.debug("Disabling FusedAttention for determinism reasons with FP8") - use_fused_attention = False - fused_attention_backend = None + # if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: + # logger.debug("Disabling FusedAttention for determinism reasons with FP8") + # use_fused_attention = False + # fused_attention_backend = None if ( fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] and is_training @@ -2241,6 +2241,14 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): k_fp8 = qkv_quantizer(k) v_fp8 = qkv_quantizer(v) q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] + print(f"q_fp8._rowwise_data.shape: {q_fp8._rowwise_data.shape}, k_fp8._rowwise_data.shape: {k_fp8._rowwise_data.shape}, v_fp8._rowwise_data.shape: {v_fp8._rowwise_data.shape}") + print(f"q_fp8._rowwise_scale_inv.shape: {q_fp8._rowwise_scale_inv.shape}, k_fp8._rowwise_scale_inv.shape: {k_fp8._rowwise_scale_inv.shape}, v_fp8._rowwise_scale_inv.shape: {v_fp8._rowwise_scale_inv.shape}") + print(f"q_fp8._columnwise_data.shape: {q_fp8._columnwise_data.shape}, k_fp8._columnwise_data.shape: {k_fp8._columnwise_data.shape}, v_fp8._columnwise_data.shape: {v_fp8._columnwise_data.shape}") + print(f"q_fp8._columnwise_scale_inv.shape: {q_fp8._columnwise_scale_inv.shape}, k_fp8._columnwise_scale_inv.shape: {k_fp8._columnwise_scale_inv.shape}, v_fp8._columnwise_scale_inv.shape: {v_fp8._columnwise_scale_inv.shape}") + print(f"q_fp8._rowwise_data.stride: {q_fp8._rowwise_data.stride()}, k_fp8._rowwise_data.stride: {k_fp8._rowwise_data.stride()}, v_fp8._rowwise_data.stride: {v_fp8._rowwise_data.stride()}") + print(f"q_fp8._rowwise_scale_inv.stride: {q_fp8._rowwise_scale_inv.stride()}, k_fp8._rowwise_scale_inv.stride: {k_fp8._rowwise_scale_inv.stride()}, v_fp8._rowwise_scale_inv.stride: 
{v_fp8._rowwise_scale_inv.stride()}") + print(f"q_fp8._columnwise_data.stride: {q_fp8._columnwise_data.stride()}, k_fp8._columnwise_data.stride: {k_fp8._columnwise_data.stride()}, v_fp8._columnwise_data.stride: {v_fp8._columnwise_data.stride()}") + print(f"q_fp8._columnwise_scale_inv.stride: {q_fp8._columnwise_scale_inv.stride()}, k_fp8._columnwise_scale_inv.stride: {k_fp8._columnwise_scale_inv.stride()}, v_fp8._columnwise_scale_inv.stride: {v_fp8._columnwise_scale_inv.stride()}") return q_fp8, k_fp8, v_fp8, qkv_layout diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index fc870c4591..d1b1baed30 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -92,6 +92,13 @@ std::pair quantizer_helper(py::handle quantizer, "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); } } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + printf("in quantizer_helper\n"); + printf("create_hp_tensor_for_cs: %d\n", create_hp_tensor_for_cs); + printf("data.has_value(): %d\n", data.has_value()); + printf("shape: %d, %d, %d, %d, %d\n", shape[0], shape[1], shape[2], shape[3], shape[4]); + printf("dtype: %d\n", dtype); + printf("quantizer: %p\n", quantizer.ptr()); + // MXFP8 auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); if (create_hp_tensor_for_cs) { @@ -99,6 +106,7 @@ std::pair quantizer_helper(py::handle quantizer, std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); } else { + printf("in quantizer_helper, creating unquantized tensor with amax\n"); std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); } } else { @@ -152,7 +160,11 @@ std::vector fused_attn_fwd( std::vector v_shape = convertShape(te_V.shape()); auto o_shape_tmp = std::vector{q_shape.begin(), q_shape.end()}; 
o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; - auto o_shape = nvte_convert_qkv_format(o_shape_tmp, nvte_get_q_format(qkv_layout), o_format); + auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; + size_t b=0, h=0, s=0, d=0, t=0; + nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), o_shape_tmp, o_format, o_shape, &b, &h, &s, &d, &t); + printf("b: %d, h: %d, s: %d, d: %d, t: %d\n", b, h, s, d, t); + printf("o_shape: %d, %d, %d, %d, %d\n", o_shape[0], o_shape[1], o_shape[2], o_shape[3]); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -163,8 +175,8 @@ std::vector fused_attn_fwd( TensorWrapper te_page_table_k, te_page_table_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; + // auto h = q_shape[q_shape.size() - 2]; + // auto d = q_shape[q_shape.size() - 1]; if (set_zero && (o_format == NVTE_QKV_Format::NVTE_THD)) { if ((h * d) % block_size == 0) { mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -316,7 +328,14 @@ std::vector fused_attn_fwd( softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); - + printf("after nvte_fused_attn_fwd\n"); + float *amax_cpu; + amax_cpu = (float *)malloc(sizeof(float)); + *amax_cpu=0.0; + cudaMemcpy(amax_cpu, te_O.amax(), sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("O amax_cpu: %f\n", *amax_cpu); + // printf("py_O.amax(): %f\n", py_O.attr("amax").cast().cpu().item()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -360,14 +379,17 @@ std::vector fused_attn_bwd( std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector 
v_shape = convertShape(te_V.shape()); - auto h_q = q_shape[q_shape.size() - 2]; - auto h_kv = k_shape[k_shape.size() - 2]; - auto d_qk = q_shape[q_shape.size() - 1]; + // auto h_q = q_shape[q_shape.size() - 2]; + // auto h_kv = k_shape[k_shape.size() - 2]; + // auto d_qk = q_shape[q_shape.size() - 1]; const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + size_t b=0, h_q=0, h_kv=0, s_q=0, s_kv=0, d_qk=0, d_v=0, t_q=0, t_kv=0; + std::vector dQ_shape(4), dK_shape(4), dV_shape(4); + nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), q_shape, nvte_get_q_format(dqkv_layout), dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); + nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), k_shape, nvte_get_kv_format(dqkv_layout), dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); + nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), v_shape, nvte_get_kv_format(dqkv_layout), dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); at::Tensor dQ, dK, dV, dQKV, dKV; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); - std::vector tmp_shape; DType dqkv_type = fake_dtype_te; if (!dqkv_quantizer.is_none()) { dqkv_type = dqkv_quantizer.attr("dtype").cast(); @@ -380,10 +402,12 @@ std::vector fused_attn_bwd( if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { options = options.dtype(fake_dtype); } + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); + std::vector tmp_shape; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -400,7 +424,7 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_H3D: - tmp_shape = 
std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -414,9 +438,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -429,9 +453,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -442,11 +466,12 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + case NVTE_QKV_Layout_Group::NVTE_SD_SD_SD: + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector(k_shape.begin(), k_shape.end()); + tmp_shape = std::vector(dK_shape.begin(), dK_shape.end()); dK = torch::empty(tmp_shape, options); - tmp_shape = 
std::vector(v_shape.begin(), v_shape.end()); + tmp_shape = std::vector(dV_shape.begin(), dV_shape.end()); dV = torch::empty(tmp_shape, options); break; default: @@ -582,6 +607,21 @@ std::vector fused_attn_bwd( at::cuda::getCurrentCUDAStream()); }); + float *amax_cpu; + amax_cpu = (float *)malloc(sizeof(float)); + *amax_cpu=0.0; + cudaMemcpy(amax_cpu, te_dQ.amax(), sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("dQ amax_cpu: %f\n", *amax_cpu); + cudaMemcpy(amax_cpu, te_dK.amax(), sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("dK amax_cpu: %f\n", *amax_cpu); + cudaMemcpy(amax_cpu, te_dV.amax(), sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("dV amax_cpu: %f\n", *amax_cpu); + // printf("py_dQ.amax(): %f\n", py_dQ.attr("amax").cast().cpu().item()); + // printf("py_dK.amax(): %f\n", py_dK.attr("amax").cast().cpu().item()); + // printf("py_dV.amax(): %f\n", py_dV.attr("amax").cast().cpu().item()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 20820143b0..d5c3ea00b4 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -949,6 +949,9 @@ std::pair MXFP8Quantizer::create_unquantized_tensor_w TensorWrapper out_cpp = std::move(out.first); py::object out_py = std::move(out.second); out_cpp.set_amax(amax_tensor.data_ptr(), DType::kFloat32, std::vector{1}); + printf("after MXFP8Quantizer::create_unquantized_tensor_with_amax\n"); + printf("amax_ptr: %p\n", amax_tensor.data_ptr()); + printf("out_cpp.amax(): %f\n", amax_tensor.cpu().item()); return {std::move(out_cpp), std::move(out_py)}; } From 6d468da04cbc32aa06b6b22a7efc180a5f9159c4 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 10:21:36 -0800 Subject: [PATCH 030/172] remove 
prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 4 --- .../dot_product_attention/backends.py | 33 ------------------ .../attention/dot_product_attention/utils.py | 8 ----- .../pytorch/csrc/extensions/attention.cpp | 34 ------------------- transformer_engine/pytorch/csrc/quantizer.cpp | 3 -- 5 files changed, 82 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 0886118451..557cb3eea1 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -211,7 +211,6 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { // map one NVTE_QKV_Format to another void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, size_t *h, size_t *s, size_t *d, size_t *t) { - printf("src_format: %d, src_shape: %d, %d, %d, %d, %d\n", src_format, src_shape[0], src_shape[1], src_shape[2], src_shape[3], src_shape[4]); size_t _b=0, _h=0, _s=0, _d=0, _t=0; switch (src_format) { case NVTE_QKV_Format::NVTE_BSHD: @@ -269,7 +268,6 @@ void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src NVTE_ERROR("dst_format not supported!"); break; } - printf("dst_format: %d, dst_shape: %d, %d, %d, %d, %d\n", dst_format, dst_shape[0], dst_shape[1], dst_shape[2], dst_shape[3], dst_shape[4]); if (b != nullptr) { *b = _b; @@ -1380,8 +1378,6 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, cuda_graph, deterministic); - printf("Q_type: %d, KV_type: %d, qkv_layout: %d, bias_type: %d, attn_mask_type: %d, softmax_type: %d, dropout: %f, h_q: %d, h_kv: %d, max_seqlen_q: %d, max_seqlen_kv: %d, d_qk: %d, d_v: %d, window_size_left: %d, window_size_right: %d, 
deterministic: %d\n", Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, deterministic); - printf("fused_attention_backend: %d\n", fused_attention_backend); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 0ef2dede76..08b9dca6d7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1301,7 +1301,6 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) - print(f"out_.shape: {out_.shape}, {type(out_)}, qkv_layout: {qkv_layout}, o_format: {o_format}") # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1481,33 +1480,18 @@ def forward( def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring - print(f"ctx.fp8: {ctx.fp8}, ctx.is_output_fp8: {ctx.is_output_fp8}, isinstance(d_out, QuantizedTensorStorage): {isinstance(d_out, QuantizedTensorStorage)}") - # # reshape d_out to ctx.qkv_layout; only happens with MXFP8BlockScaling - # if ctx.original_qkv_layout != ctx.qkv_layout: - # print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") - # print(f"d_out before reshape: {d_out.shape}, {type(d_out)}") - # original_qkv_format = ctx.original_qkv_layout.split("_")[0] - # new_qkv_format = ctx.qkv_layout.split("_")[0] - # perm = [] - # for i in original_qkv_format: - # perm.append(new_qkv_format.find(i)) - # d_out = d_out.permute(*perm).contiguous() - # print(f"d_out after 
reshape: {d_out.shape}, {type(d_out)}") - # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 d_out_fp8 = None d_out_format = ctx.o_format if ctx.fp8: - print(f"d_out before quantizer: {d_out.shape}, {type(d_out)}") if ctx.fp8_recipe.mxfp8(): d_out, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, d_out) if isinstance(d_out, QuantizedTensorStorage): d_out_fp8 = d_out else: d_out_fp8 = ctx.dO_quantizer(d_out) - print(f"d_out after quantizer: {d_out.shape}, {d_out_fp8._rowwise_data.shape}, {type(d_out)}, {type(d_out_fp8)}") if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( @@ -1602,14 +1586,6 @@ def backward(ctx, d_out, *_args): if ctx.fp8_recipe.mxfp8(): out_ = out aux_ctx_tensors.append(d_out) - print(f"q_fp8._with_gemm_swizzled_scales: {q_fp8._with_gemm_swizzled_scales}") - print(f"k_fp8._with_gemm_swizzled_scales: {k_fp8._with_gemm_swizzled_scales}") - print(f"v_fp8._with_gemm_swizzled_scales: {v_fp8._with_gemm_swizzled_scales}") - print(f"d_out_fp8._with_gemm_swizzled_scales: {d_out_fp8._with_gemm_swizzled_scales}") - # print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") - # print(f"shapes: {q_fp8._rowwise_data.shape}, {k_fp8._rowwise_data.shape}, {v_fp8._rowwise_data.shape}, {out_.shape}, {d_out_fp8._rowwise_data.shape}, {[x.shape for x in aux_ctx_tensors]}") - print(f"out_.shape: {out_.shape}, d_out_fp8.shape: {d_out_fp8._rowwise_data.shape}, d_out_fp8.columnwise_data.shape: {d_out_fp8._columnwise_data.shape}, d_out.shape: {d_out.shape}") - print(f"out_.stride: {out_.stride()}, d_out_fp8.rowwise_data.stride: {d_out_fp8._rowwise_data.stride()}, d_out_fp8.columnwise_data.stride: {d_out_fp8._columnwise_data.stride()}, d_out.stride: {d_out.stride()}") dq_, dk_, dv_, *rest = fused_attn_bwd( 
ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1644,15 +1620,6 @@ def backward(ctx, d_out, *_args): ctx.deterministic, is_graph_capturing(), ) - print(f"dq_.shape: {dq_.shape}, dk_.shape: {dk_.shape}, dv_.shape: {dv_.shape}") - print(f"types: {type(dq_)}, {type(dk_)}, {type(dv_)}") - # if ctx.original_qkv_layout != ctx.qkv_layout: - # original_qkv_format = ctx.original_qkv_layout.split("_")[0] - # new_qkv_format = ctx.qkv_layout.split("_")[0] - # perm = [] - # for i in new_qkv_format: - # perm.append(original_qkv_format.find(i)) - # dq_, dk_, dv_ = [x.permute(*perm).contiguous() for x in (dq_, dk_, dv_)] # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 16410e8e00..03a52ab870 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2241,14 +2241,6 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): k_fp8 = qkv_quantizer(k) v_fp8 = qkv_quantizer(v) q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] - print(f"q_fp8._rowwise_data.shape: {q_fp8._rowwise_data.shape}, k_fp8._rowwise_data.shape: {k_fp8._rowwise_data.shape}, v_fp8._rowwise_data.shape: {v_fp8._rowwise_data.shape}") - print(f"q_fp8._rowwise_scale_inv.shape: {q_fp8._rowwise_scale_inv.shape}, k_fp8._rowwise_scale_inv.shape: {k_fp8._rowwise_scale_inv.shape}, v_fp8._rowwise_scale_inv.shape: {v_fp8._rowwise_scale_inv.shape}") - print(f"q_fp8._columnwise_data.shape: {q_fp8._columnwise_data.shape}, k_fp8._columnwise_data.shape: {k_fp8._columnwise_data.shape}, v_fp8._columnwise_data.shape: {v_fp8._columnwise_data.shape}") - print(f"q_fp8._columnwise_scale_inv.shape: {q_fp8._columnwise_scale_inv.shape}, k_fp8._columnwise_scale_inv.shape: {k_fp8._columnwise_scale_inv.shape}, 
v_fp8._columnwise_scale_inv.shape: {v_fp8._columnwise_scale_inv.shape}") - print(f"q_fp8._rowwise_data.stride: {q_fp8._rowwise_data.stride()}, k_fp8._rowwise_data.stride: {k_fp8._rowwise_data.stride()}, v_fp8._rowwise_data.stride: {v_fp8._rowwise_data.stride()}") - print(f"q_fp8._rowwise_scale_inv.stride: {q_fp8._rowwise_scale_inv.stride()}, k_fp8._rowwise_scale_inv.stride: {k_fp8._rowwise_scale_inv.stride()}, v_fp8._rowwise_scale_inv.stride: {v_fp8._rowwise_scale_inv.stride()}") - print(f"q_fp8._columnwise_data.stride: {q_fp8._columnwise_data.stride()}, k_fp8._columnwise_data.stride: {k_fp8._columnwise_data.stride()}, v_fp8._columnwise_data.stride: {v_fp8._columnwise_data.stride()}") - print(f"q_fp8._columnwise_scale_inv.stride: {q_fp8._columnwise_scale_inv.stride()}, k_fp8._columnwise_scale_inv.stride: {k_fp8._columnwise_scale_inv.stride()}, v_fp8._columnwise_scale_inv.stride: {v_fp8._columnwise_scale_inv.stride()}") return q_fp8, k_fp8, v_fp8, qkv_layout diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index d1b1baed30..fd193b0258 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -92,13 +92,6 @@ std::pair quantizer_helper(py::handle quantizer, "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); } } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { - printf("in quantizer_helper\n"); - printf("create_hp_tensor_for_cs: %d\n", create_hp_tensor_for_cs); - printf("data.has_value(): %d\n", data.has_value()); - printf("shape: %d, %d, %d, %d, %d\n", shape[0], shape[1], shape[2], shape[3], shape[4]); - printf("dtype: %d\n", dtype); - printf("quantizer: %p\n", quantizer.ptr()); - // MXFP8 auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); if (create_hp_tensor_for_cs) { @@ -106,7 +99,6 @@ std::pair quantizer_helper(py::handle quantizer, std::tie(te_T, py_T) = 
T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); } else { - printf("in quantizer_helper, creating unquantized tensor with amax\n"); std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); } } else { @@ -163,8 +155,6 @@ std::vector fused_attn_fwd( auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; size_t b=0, h=0, s=0, d=0, t=0; nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), o_shape_tmp, o_format, o_shape, &b, &h, &s, &d, &t); - printf("b: %d, h: %d, s: %d, d: %d, t: %d\n", b, h, s, d, t); - printf("o_shape: %d, %d, %d, %d, %d\n", o_shape[0], o_shape[1], o_shape[2], o_shape[3]); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -328,14 +318,6 @@ std::vector fused_attn_fwd( softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); - printf("after nvte_fused_attn_fwd\n"); - float *amax_cpu; - amax_cpu = (float *)malloc(sizeof(float)); - *amax_cpu=0.0; - cudaMemcpy(amax_cpu, te_O.amax(), sizeof(float), cudaMemcpyDeviceToHost); - cudaDeviceSynchronize(); - printf("O amax_cpu: %f\n", *amax_cpu); - // printf("py_O.amax(): %f\n", py_O.attr("amax").cast().cpu().item()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -394,7 +376,6 @@ std::vector fused_attn_bwd( if (!dqkv_quantizer.is_none()) { dqkv_type = dqkv_quantizer.attr("dtype").cast(); } - printf(">>>>>> dQKV_type: %d\n", dqkv_type); auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { options = options.dtype(torch::kUInt8); @@ -607,21 +588,6 @@ std::vector fused_attn_bwd( at::cuda::getCurrentCUDAStream()); }); - float *amax_cpu; - amax_cpu = (float *)malloc(sizeof(float)); - 
*amax_cpu=0.0; - cudaMemcpy(amax_cpu, te_dQ.amax(), sizeof(float), cudaMemcpyDeviceToHost); - cudaDeviceSynchronize(); - printf("dQ amax_cpu: %f\n", *amax_cpu); - cudaMemcpy(amax_cpu, te_dK.amax(), sizeof(float), cudaMemcpyDeviceToHost); - cudaDeviceSynchronize(); - printf("dK amax_cpu: %f\n", *amax_cpu); - cudaMemcpy(amax_cpu, te_dV.amax(), sizeof(float), cudaMemcpyDeviceToHost); - cudaDeviceSynchronize(); - printf("dV amax_cpu: %f\n", *amax_cpu); - // printf("py_dQ.amax(): %f\n", py_dQ.attr("amax").cast().cpu().item()); - // printf("py_dK.amax(): %f\n", py_dK.attr("amax").cast().cpu().item()); - // printf("py_dV.amax(): %f\n", py_dV.attr("amax").cast().cpu().item()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index d5c3ea00b4..20820143b0 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -949,9 +949,6 @@ std::pair MXFP8Quantizer::create_unquantized_tensor_w TensorWrapper out_cpp = std::move(out.first); py::object out_py = std::move(out.second); out_cpp.set_amax(amax_tensor.data_ptr(), DType::kFloat32, std::vector{1}); - printf("after MXFP8Quantizer::create_unquantized_tensor_with_amax\n"); - printf("amax_ptr: %p\n", amax_tensor.data_ptr()); - printf("out_cpp.amax(): %f\n", amax_tensor.cpu().item()); return {std::move(out_cpp), std::move(out_py)}; } From 9f8e856a3db99b3bbb3898f361345c862e3a1bf9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 10:22:17 -0800 Subject: [PATCH 031/172] update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 4b4df2edcf..ae385ad82e 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ 
-1 +1 @@ -Subproject commit 4b4df2edcf80b7cc7a659da5734454577154aa6d +Subproject commit ae385ad82e476bb75910d1ce92c6e25fdae42f40 From facef79b9dfc18fb04c12dcca63782ac50ecf222 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 11:42:42 -0800 Subject: [PATCH 032/172] update FE from pre-merge branch to post-merge develop Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index ae385ad82e..b4370f5198 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit ae385ad82e476bb75910d1ce92c6e25fdae42f40 +Subproject commit b4370f5198bd95ee758ebc2c6b76b887914b702d From fd33cca2dbe607ac2bed257d00eb65e90a30b896 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 13:07:36 -0800 Subject: [PATCH 033/172] allow MXFP8 linear + f16 attn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index f7699340e6..60b9812eb8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -98,19 +98,19 @@ +-------------------+-----------+-----------------------------------------------------------------------------------+ | Linear | Attention | Configuration | +===================+===========+===================================================================================+ -| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to 
autocast(); | -| | | export NVTE_DPA_FP8_RECIPE="F16" | +| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS, NVFP4 or MXFP8 to autocast(); | +| /MXFP8 | | export NVTE_DPA_FP8_RECIPE="F16" | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8DS | Pass FP8DS to autocast(); | +| FP8DS | FP8DS | Pass FP8DS to autocast(); | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8DS | Pass FP8CS to autocast(); | +| FP8CS | FP8DS | Pass FP8CS to autocast(); | | | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | +| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | | | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | @@ -118,19 +118,19 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8CS | Pass FP8DS to autocast(); | +| FP8DS | FP8CS | Pass FP8DS to autocast(); | | | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,| | | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | 
+-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8CS | Pass FP8CS to autocast(); | +| FP8CS | FP8CS | Pass FP8CS to autocast(); | | | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe | | | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | +| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | | | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | | | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | From 5079d5588be0016f2e2244e7fd459185340d5f27 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 17:58:21 -0800 Subject: [PATCH 034/172] test cp a2a Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/run_attention_with_cp.py | 27 +- tests/pytorch/attention/test_attention.py | 4 +- .../attention/test_attention_with_cp.py | 21 +- .../common/fused_attn/fused_attn_fp8.cu | 46 +++- transformer_engine/common/fused_attn/utils.h | 7 + .../dot_product_attention/context_parallel.py | 230 +++++++++++++----- .../attention/dot_product_attention/utils.py | 15 +- 7 files changed, 261 insertions(+), 89 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 3efb516b57..b019289846 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ 
-19,8 +19,9 @@ DotProductAttention, Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, ) -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling +from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, Format from utils import ModelConfig, compare_and_assert dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -189,7 +190,7 @@ def run_dpa_with_cp( os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" fp8_dpa = fp8_dpa == "True" and dtype == "fp8" fp8_mha = fp8_mha == "True" and dtype == "fp8" - f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True" + f16_O = dtype == "fp8" and scaling_mode in ["current", "mxfp8"] and f16_O == "True" os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0" os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" @@ -219,6 +220,7 @@ def run_dpa_with_cp( device_count = torch.cuda.device_count() device = rank % device_count torch.cuda.set_device(device) + print(f"rank: {rank}, world_size: {world_size}") logging.info(f"[Rank {rank}] Setup: world_size {world_size}") dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) @@ -244,6 +246,8 @@ def run_dpa_with_cp( fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "mxfp8": + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) # instantiate attention module core_attn = DotProductAttention( @@ -297,10 +301,25 @@ def run_dpa_with_cp( fp8_dtype=tex.DType.kFloat8E5M2, device="cuda", ) + if scaling_mode == "mxfp8": + qkv_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + ) + qkv_quantizer.optimize_for_gemm = True + qkv_quantizer.internal = False + dout_quantizer = MXFP8Quantizer( + 
fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + ) + dout_quantizer.optimize_for_gemm = True + dout_quantizer.internal = False qkv_layout = "_".join([qkv_format] * 3) q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] if fp8_mha: - q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + q, k, v, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) for x in [q, k, v]: x.requires_grad = True @@ -386,7 +405,7 @@ def run_dpa_with_cp( dout_quantizer.scale.fill_(1.0) dout_quantizer.amax.fill_(0.0) if fp8_mha: - q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) + q_, k_, v_, qkv_layout = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: bias_ = bias_.view( diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 9922d93a77..dc0d37f555 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1788,8 +1788,8 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 16, 128),#, attn_mask_type="causal"), - "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128), #, num_gqa_groups=12, window_size=(512, 512)), + "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),#, attn_mask_type="causal"), + "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512)), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 836598087b..668c2745c7 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ 
b/tests/pytorch/attention/test_attention_with_cp.py @@ -17,6 +17,8 @@ from transformer_engine.common.recipe import ( DelayedScaling, Float8CurrentScaling, + MXFP8BlockScaling, + Format, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils @@ -149,7 +151,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA + "cp_2_1": ModelConfig(2, 4096, 16, 128),#, num_gqa_groups=12), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -166,7 +168,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA - "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA + "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA "cp_3_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA @@ -192,14 +194,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_1", "cp_1_4", "cp_2_0", + "cp_2_1", "cp_2_2", "cp_2_4", + "cp_3_1", "cp_3_2", "cp_4_2", ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] - qkv_formats = ["sbhd", "thd"] + qkv_formats = ["bshd", "sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @@ -211,7 +215,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("fp8_bwd", [True, False]) @pytest.mark.parametrize("fp8_mha", [True, False]) 
@pytest.mark.parametrize("fp8_dpa", [True, False]) -@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O @@ -280,7 +284,7 @@ def test_cp_with_fused_attention( and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] ): pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") - if f16_O and (dtype != "fp8" or scaling_mode != "current"): + if f16_O and (dtype != "fp8" or scaling_mode not in ["current", "mxfp8"]): pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") @@ -301,6 +305,8 @@ def test_cp_with_fused_attention( "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for" " non-vanilla softmax types!" 
) + if scaling_mode == "mxfp8" and not f16_O: + pytest.skip("MXFP8 only works with f16_O=True!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -324,6 +330,11 @@ def test_cp_with_fused_attention( Float8CurrentScaling(fp8_dpa=True), DelayedScaling(fp8_dpa=True), ] + if fp8 and scaling_mode == "mxfp8": + fp8_meta["recipe"] = MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=True) + fp8_meta["local_recipes"] = [ + MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=True), + ] available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 57b250f6af..da826688be 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2399,17 +2399,17 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - // fe::DiagonalAlignment_t const &diagonal_alignment = - // bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT - // : fe::DiagonalAlignment_t::TOP_LEFT; - // sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? 
fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); - // if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - // sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); - // } - // if (cudnn_runtime_version >= 90600 && window_size_right != -1) { - // sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); - // } + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } // sdpa_backward_options.set_alibi_mask(is_alibi); @@ -2632,9 +2632,33 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[dO_t] = devPtrdO_t; variant_pack[descale_q_t] = devPtrDescaleQ_t; variant_pack[descale_k_t] = devPtrDescaleK_t; - variant_pack[descale_dO] = devPtrDescaledO; + // variant_pack[descale_dO] = devPtrDescaledO; variant_pack[descale_dO_t] = devPtrDescaledO_t; } + int64_t modulo = 16; + printf("devPtrQ: %p, is_aligned: %d\n", devPtrQ, is_aligned_modulo(devPtrQ, modulo)); + printf("devPtrK: %p, is_aligned: %d\n", devPtrK, is_aligned_modulo(devPtrK, modulo)); + printf("devPtrV: %p, is_aligned: %d\n", devPtrV, is_aligned_modulo(devPtrV, modulo)); + printf("devPtrO: %p, is_aligned: %d\n", devPtrO, is_aligned_modulo(devPtrO, modulo)); + printf("devPtrM: %p, is_aligned: %d\n", devPtrM, is_aligned_modulo(devPtrM, modulo)); + printf("devPtrdO: %p, is_aligned: %d\n", devPtrdO, is_aligned_modulo(devPtrdO, modulo)); + printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, is_aligned_modulo(devPtrDescaleQ, modulo)); + printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, is_aligned_modulo(devPtrDescaleK, modulo)); + printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, is_aligned_modulo(devPtrDescaleV, 
modulo)); + printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, is_aligned_modulo(devPtrDescaledO, modulo)); + printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, is_aligned_modulo(devPtrDescaledO_t, modulo)); + printf("devPtrdQ: %p, is_aligned: %d\n", devPtrdQ, is_aligned_modulo(devPtrdQ, modulo)); + printf("devPtrdK: %p, is_aligned: %d\n", devPtrdK, is_aligned_modulo(devPtrdK, modulo)); + printf("devPtrdV: %p, is_aligned: %d\n", devPtrdV, is_aligned_modulo(devPtrdV, modulo)); + printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, is_aligned_modulo(devPtrAmaxdQ, modulo)); + printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, is_aligned_modulo(devPtrAmaxdK, modulo)); + printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, is_aligned_modulo(devPtrAmaxdV, modulo)); + printf("devPtrQ_t: %p, is_aligned: %d\n", devPtrQ_t, is_aligned_modulo(devPtrQ_t, modulo)); + printf("devPtrK_t: %p, is_aligned: %d\n", devPtrK_t, is_aligned_modulo(devPtrK_t, modulo)); + printf("devPtrdO_f16: %p, is_aligned: %d\n", devPtrdO_f16, is_aligned_modulo(devPtrdO_f16, modulo)); + printf("devPtrdO_t: %p, is_aligned: %d\n", devPtrdO_t, is_aligned_modulo(devPtrdO_t, modulo)); + printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, is_aligned_modulo(devPtrDescaleQ_t, modulo)); + printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, is_aligned_modulo(devPtrDescaleK_t, modulo)); /* if (is_bias) { variant_pack[bias] = devPtrBias; diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index f0b947c379..43d460bfd1 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -49,6 +49,13 @@ struct MXFP8PaddedSizes { int64_t d_v_scale_padded; }; +inline bool is_aligned_modulo(void* ptr, int64_t modulo) { + // Cast the pointer to a large enough integer type (uintptr_t) + uintptr_t address = reinterpret_cast(ptr); + // Check if the 
address is perfectly divisible by 16 + return (address % modulo) == 0; +} + // Pad s and d for MXFP8 layout inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v) { constexpr int64_t block_size = 32; diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 22d8378598..b701c37fe6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -23,6 +23,8 @@ from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.common.recipe import MXFP8BlockScaling, Format from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.constants import ( @@ -58,6 +60,16 @@ # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" +def get_bsh_dims(tensor_format): + """Get batch dimension and sequence dimension from tensor format""" + if tensor_format in ["bshd", "sbhd", "bhsd"]: + batch_dim = tensor_format.index("b") + seq_dim = tensor_format.index("s") + head_dim = tensor_format.index("h") + else: # tensor_format == "thd" + batch_dim = seq_dim = tensor_format.index("t") + head_dim = tensor_format.index("h") + return batch_dim, seq_dim, head_dim def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm @@ -419,6 +431,7 @@ def flash_attn_a2a_communicate( ), "cu_seqlens_padded is required for THD format!" 
a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + batch_dim, _, head_dim = get_bsh_dims(qkv_format) if before_attn: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -430,13 +443,14 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - if qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd", "bhsd"]: # reorder the sequence chunks x = reorder_seq_chunks_for_a2a_before_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + # or [b, np//cp, cp*2, s//2, hn] -> [b, np//cp, cp*s, hn] a2a_outputs[i - 2] = x.view( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) @@ -452,12 +466,14 @@ def flash_attn_a2a_communicate( x = a2a_inputs[i] # [b, s, np, hn] -> [b, s, cp, np//cp, hn] # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + # or [b, np, s, hn] -> [b, cp, np//cp, s, hn] # or [t, np, hn] -> [t, cp, np//cp, hn] - x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) + x = x.view(*x.shape[:head_dim], cp_size, x.shape[head_dim] // cp_size, *x.shape[head_dim + 1:]) # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + # or [b, cp, np//cp, s, hn] -> [cp, b, np//cp, s, hn] # or [t, cp, np//cp, hn] -> [cp, t, np//cp, hn] - a2a_inputs[i] = x.movedim(-3, 0).contiguous() + a2a_inputs[i] = x.movedim(head_dim, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -467,9 +483,10 @@ def flash_attn_a2a_communicate( ) if i < len(a2a_inputs): x = a2a_inputs[i] - if qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd", "bhsd"]: # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # or [b, np//cp, cp*s, 
hn] -> [b, np//cp, cp*2, s//2, hn] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( @@ -486,10 +503,12 @@ def flash_attn_a2a_communicate( x = a2a_outputs[i - 2] # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + # or [cp, 2, b, np//cp, s//2, hn] -> [b, cp, np//cp, 2, s//2, hn] # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] - x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() + x = x.movedim(0, head_dim+1).movedim(0, seq_dim+1).contiguous() # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + # or [b, cp, np//cp, 2, s//2, hn] -> [b*np, s, hn] # or [t, cp, np//cp, hn] -> [t, np, hn] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) @@ -3367,21 +3386,19 @@ def forward( ), "The number of attention heads needs to be divisible by CP size!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - - if qkv_format in ["bshd", "sbhd"]: - batch_dim = qkv_format.index("b") - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - batch_dim = seq_dim = qkv_format.index("t") + original_qkv_layout = qkv_layout + o_format = qkv_format + batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + _, seq_dim_o, _ = get_bsh_dims(o_format) assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." 
+ is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; @@ -3392,6 +3409,9 @@ def forward( fwd_nominal_dtype = q.dtype fused_attn_backend = None max_logit = None + if torch.cuda.current_device() == 0: + print(f"is_input_fp8: {is_input_fp8}, is_output_fp8: {is_output_fp8}, is_bwd_fp8: {is_bwd_fp8}") + print(f"fp8: {fp8}, fp8_recipe: {fp8_recipe}") QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) @@ -3403,10 +3423,14 @@ def forward( fused_attn_backend = FusedAttnBackend["FP8"] if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v - elif not isinstance(QKV_quantizer, MXFP8Quantizer): + elif not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + # else: + # q, k, v = [q_fp8, k_fp8, v_fp8] + # qkv_format, _, _ = dpa_utils.get_qkv_format(qkv_layout) + # batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["o_quantizer"] = O_quantizer @@ -3417,11 +3441,15 @@ def forward( fp8_meta_kwargs = {} fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + if torch.cuda.current_device() == 0: + print(f"before flash_attn_a2a_communicate: q: {q.shape} {type(q)}, k: {k.shape} {type(k)}, v: {v.shape} {type(v)}") + print(f"qkv_format: {qkv_format}, o_format: {o_format}") + print(f"batch_dim_qkv: {batch_dim_qkv}, seq_dim_qkv: {seq_dim_qkv}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, - seq_dim, + seq_dim_qkv, cp_size, cp_group, cp_stream, @@ -3429,6 +3457,8 
@@ def forward( qkv_format=qkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) + if torch.cuda.current_device() == 0: + print(f"after flash_attn_a2a_communicate: q: {q.shape} {type(q)}, k: {k.shape} {type(k)}, v: {v.shape} {type(v)}") if softmax_type != "vanilla": softmax_offset = flash_attn_a2a_communicate_softmax_offset( softmax_offset, 1, cp_size, cp_group, cp_stream, True @@ -3436,15 +3466,20 @@ def forward( out_fp8 = None out_f16 = None - batch_size = q.shape[batch_dim] + batch_size = q.shape[batch_dim_qkv] q_part, k_part, v_part = q, k, v out_part = None if use_fused_attention: - if fp8: + if fp8 and not fp8_recipe.mxfp8(): q_part, k_part, v_part = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] + if fp8 and fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) + q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] + if torch.cuda.current_device() == 0: + print(f"before fused_attn_fwd: q_part: {q_part.shape} {type(q_part)}, k_part: {k_part.shape} {type(k_part)}, v_part: {v_part.shape} {type(v_part)}") out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, @@ -3459,6 +3494,7 @@ def forward( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -3471,7 +3507,9 @@ def forward( return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), ) - if isinstance(out_, Float8Tensor): + if torch.cuda.current_device() == 0: + print(f"after fused_attn_fwd: out_: {out_.shape} {type(out_)}") + if isinstance(out_, QuantizedTensorStorage): out_fp8 = out_ out_ = out_._data if is_bwd_fp8 and not ( @@ -3487,6 +3525,7 @@ def forward( fp8 and is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + and not fp8_recipe.mxfp8() ): out_part = O_quantizer(out_) 
else: @@ -3516,33 +3555,39 @@ def forward( aux_ctx_tensors = [softmax_lse, rng_state] out_part = out_ + if torch.cuda.current_device() == 0: + print(f"before flash_attn_a2a_communicate: out_: {out_.shape} {type(out_)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) out_ = flash_attn_a2a_communicate( out_, chunk_ids_for_a2a, - seq_dim, + seq_dim_o, cp_size, cp_group, cp_stream, before_attn=False, - qkv_format=qkv_format, + qkv_format=o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) + if torch.cuda.current_device() == 0: + print(f"after flash_attn_a2a_communicate: out_: {out_.shape} {type(out_)}") if return_max_logit: max_logit = flash_attn_a2a_communicate_softmax_offset( *max_logit, 0, cp_size, cp_group, cp_stream, False ) if use_fused_attention: - if qkv_format == "bshd": + if o_format == "bshd": # [b*s, h, d] -> [b, s, h, d] out_ = out_.view(batch_size, -1, *out_.shape[-2:]) - elif qkv_format == "sbhd": + elif o_format == "sbhd": # [s*b, h, d] -> [s, b, h, d] out_ = out_.view(-1, batch_size, *out_.shape[-2:]) + if torch.cuda.current_device() == 0: + print(f"after view: out_: {out_.shape} {type(out_)}") if fp8 and use_fused_attention: - if fp8_recipe.float8_current_scaling(): + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): out_f16 = out_ if is_output_fp8: out_fp8 = O_quantizer(out_) @@ -3556,19 +3601,28 @@ def forward( out_ret = out_fp8 if is_output_fp8 else out_f16 ctx.fp8 = fp8 and is_bwd_fp8 + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.dqkv_layout = original_qkv_layout fp8_tensors = (None, None, None, None) f16_tensors = (None, None, None, None) if ctx.fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8(): fp8_tensors = (q_part, k_part, v_part, None) f16_tensors = (None, None, None, out_part) else: fp8_tensors = (q_part, k_part, v_part, out_part) - elif fp8: + elif fp8 and not 
fp8_recipe.mxfp8(): q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) f16_tensors = (q_part, k_part, v_part, out_part) + elif fp8 and fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out_part) + ctx.qkv_layout = original_qkv_layout else: f16_tensors = (q_part, k_part, v_part, out_part) + if torch.cuda.current_device() == 0: + print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]}, type of fp8_tensors: {[type(x) for x in fp8_tensors]}") + print(f"f16_tensors: {[x.shape if x is not None else None for x in f16_tensors]}, type of f16_tensors: {[type(x) for x in f16_tensors]}") tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -3590,7 +3644,7 @@ def forward( ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format + # ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type ctx.deterministic = deterministic @@ -3612,11 +3666,13 @@ def forward( ctx.S_quantizer = S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") if return_max_logit: return out_ret, max_logit @@ -3644,27 +3700,28 @@ def backward(ctx, dout, *_args): *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - qkv_format = ctx.qkv_format - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + # qkv_format = 
ctx.qkv_format + # qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + # qkv_layout = ctx.qkv_layout causal = "causal" in ctx.attn_mask_type + dqkv_format, _, _ = dpa_utils.get_qkv_format(ctx.dqkv_layout) - if qkv_format in ["bshd", "sbhd"]: - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - seq_dim = qkv_format.index("t") + batch_dim_dqkv, seq_dim_dqkv, _ = get_bsh_dims(dqkv_format) + _, seq_dim_do, _ = get_bsh_dims(ctx.o_format) bwd_nominal_dtype = ctx.fwd_nominal_dtype - dqkv_te_dtype = None + # dqkv_te_dtype = None fused_attn_backend = None dout_fp8 = dout if ctx.fp8: if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - if not isinstance(dout, QuantizedTensorStorage): + if not isinstance(dout, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): dout = ctx.dO_quantizer(dout) dout_fp8 = dout - dqkv_te_dtype = dout._fp8_dtype - dout = dout._data + if not ctx.fp8_recipe.mxfp8(): + # dqkv_te_dtype = dout._fp8_dtype + dout = dout._data fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer @@ -3677,29 +3734,32 @@ def backward(ctx, dout, *_args): dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fp8_meta_kwargs = {} - dqkv_te_dtype = TE_DType[dout.dtype] + # dqkv_te_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if not ctx.use_fused_attention: - if qkv_format in ["bshd", "sbhd"]: + if ctx.o_format in ["bshd", "sbhd"]: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) else: dout = dout.view(*ctx.out_shape) + if torch.cuda.current_device() == 0: + print(f"before flash_attn_a2a_communicate: dout: {dout.shape} {type(dout)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) dout = flash_attn_a2a_communicate( dout, chunk_ids_for_a2a, - seq_dim, + seq_dim_do, cp_size, ctx.cp_group, ctx.cp_stream, 
before_attn=True, - qkv_format=qkv_format, + qkv_format=ctx.o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - + if torch.cuda.current_device() == 0: + print(f"after flash_attn_a2a_communicate: dout: {dout.shape} {type(dout)}") flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} @@ -3714,7 +3774,7 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["window_size"] = ctx.window_size fa_backward_kwargs["deterministic"] = ctx.deterministic else: - if qkv_format == "thd": + if ctx.o_format == "thd": from transformer_engine.pytorch.attention.dot_product_attention.backends import ( _flash_attn_varlen_bwd, ) @@ -3740,13 +3800,31 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["softcap"] = 0.0 dq_fp8, dk_fp8, dv_fp8 = None, None, None + d_out_format = ctx.o_format if ctx.use_fused_attention: q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or ctx.fp8_recipe.mxfp8(): out_part = out - dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + if not ctx.fp8_recipe.mxfp8(): + dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + else: + dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) + dout_part = ctx.dO_quantizer(dout) + print(f"dout.ptr: {hex(dout.data_ptr())}, {hex(dout_part._rowwise_data.data_ptr())}, {hex(dout_part._columnwise_data.data_ptr())}, {hex(dout_part._rowwise_scale_inv.data_ptr())}, {hex(dout_part._columnwise_scale_inv.data_ptr())}") + aux_ctx_tensors.append(dout) + if torch.cuda.current_device() == 0: + print(f"before fused_attn_bwd: q_part: {q_part.shape} {type(q_part)}, k_part: {k_part.shape} {type(k_part)}, v_part: {v_part.shape} {type(v_part)}, out_part: {out_part.shape} 
{type(out_part)}, dout_part: {dout_part.shape} {type(dout_part)}") + print(f"type of aux_ctx_tensors: {[type(x) for x in aux_ctx_tensors]} {[x.shape if x is not None else None for x in aux_ctx_tensors]}") + print(f"fused_attn_backend: {fused_attn_backend}") + # print(f"cu_seqlens_q: {cu_seqlens_q.shape} {type(cu_seqlens_q)}, cu_seqlens_kv: {cu_seqlens_kv.shape} {type(cu_seqlens_kv)}") + # print(f"cu_seqlens_q_padded: {cu_seqlens_q_padded.shape} {type(cu_seqlens_q_padded)}, cu_seqlens_kv_padded: {cu_seqlens_kv_padded.shape} {type(cu_seqlens_kv_padded)}") + # print(f"ctx.softmax_scale: {ctx.softmax_scale}, ctx.dropout_p: {ctx.dropout_p}, ctx.window_size: {ctx.window_size}, ctx.deterministic: {ctx.deterministic}") + print(f"ctx.qkv_layout: {ctx.qkv_layout}, ctx.o_format: {ctx.o_format}, ctx.dqkv_layout: {ctx.dqkv_layout}") + # print(f"ctx.attn_mask_type: {ctx.attn_mask_type}, ctx.attn_bias_type: {ctx.attn_bias_type}") + print(f"is contiguous: {q_part.is_contiguous()}, {k_part.is_contiguous()}, {v_part.is_contiguous()}, {out_part.is_contiguous()}, {dout_part.is_contiguous()}") + print(fp8_meta_kwargs) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3758,14 +3836,17 @@ def backward(ctx, dout, *_args): out_part, dout_part, bwd_nominal_dtype, - dqkv_te_dtype, + # dqkv_te_dtype, aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=ctx.qkv_layout, + o_format=ctx.o_format, + d_out_format=d_out_format, + dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=ctx.window_size, @@ -3774,7 +3855,7 @@ def backward(ctx, dout, *_args): **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) - if isinstance(dq, Float8Tensor): + if isinstance(dq, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv dq, dk, dv = 
[x._data for x in [dq, dk, dv]] else: @@ -3783,7 +3864,7 @@ def backward(ctx, dout, *_args): fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - qkv_format, + ctx.o_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, @@ -3806,22 +3887,27 @@ def backward(ctx, dout, *_args): **fa_backward_kwargs, ) + if torch.cuda.current_device() == 0: + print(f"after flash_attn_bwd: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") + print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, - seq_dim, + seq_dim_dqkv, cp_size, ctx.cp_group, ctx.cp_stream, before_attn=False, - qkv_format=qkv_format, + qkv_format=dqkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - - if qkv_format == "bshd": + if torch.cuda.current_device() == 0: + print(f"after flash_attn_a2a_communicate: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") + print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") + if dqkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif qkv_format == "sbhd": + elif dqkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] d_bias = None @@ -3836,8 +3922,8 @@ def backward(ctx, dout, *_args): ) if ctx.fp8: - if ctx.fp8_recipe.float8_current_scaling() and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + if (ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8()) and ctx.is_input_fp8: + dq, dk, dv = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) @@ -3845,13 +3931,15 @@ def backward(ctx, dout, *_args): ] if not ctx.is_input_fp8: dq, dk, dv = 
combine_and_dequantize( - qkv_layout, + ctx.dqkv_layout, dq, dk, dv, src_nominal_dtype=bwd_nominal_dtype, ) - + if torch.cuda.current_device() == 0: + print(f"after combine_and_dequantize: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") + print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( @@ -3982,7 +4070,19 @@ def attn_forward_func_with_cp( in Megatron-LM. """ - + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_comm_type=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {qkv_format=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {deterministic=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {use_fused_attention=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8_meta=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_group=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_global_ranks=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_stream=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {quantizers=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {pad_between_seqs=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8_output=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {layer_number=}") if cp_comm_type == "a2a+p2p": assert ( isinstance(cp_group, list) and len(cp_group) == 2 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 03a52ab870..eb2a7f8e94 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -515,6 +515,17 @@ def get_attention_backend( " with cuDNN < 9.18.0" ) use_fused_attention = False + if use_fused_attention and fp8_recipe.mxfp8(): + if device_compute_capability < (10, 0): + logger.debug("Disabling FusedAttention for MXFP8 on arch < sm100") + use_fused_attention = False + else: + if cudnn_version < (9, 21, 0): + logger.debug("Disabling FusedAttention for MXFP8 with cuDNN < 9.21.0") + use_fused_attention = False + elif qkv_format == "thd": + logger.debug("Disabling FusedAttention for MXFP8 with qkv_format = thd") + use_fused_attention = False if device_compute_capability == (12, 0): if use_flash_attention: @@ -2096,7 +2107,7 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): return [None] * 6 QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - QKV_quantizer.internal = True + QKV_quantizer.internal = False QKV_quantizer.set_usage(rowwise=True, columnwise=False) S_quantizer = quantizers["scaling_fwd"][META_S] @@ -2108,7 +2119,7 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): O_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.internal = True + dO_quantizer.internal = False dO_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer = quantizers["scaling_bwd"][META_DP] From 06b7d491c6a819b6977bf6a7721351ffcdfaeb31 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 18:04:34 -0800 Subject: [PATCH 035/172] remove prints temporarily Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py 
b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index b701c37fe6..78c19826c8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3409,9 +3409,6 @@ def forward( fwd_nominal_dtype = q.dtype fused_attn_backend = None max_logit = None - if torch.cuda.current_device() == 0: - print(f"is_input_fp8: {is_input_fp8}, is_output_fp8: {is_output_fp8}, is_bwd_fp8: {is_bwd_fp8}") - print(f"fp8: {fp8}, fp8_recipe: {fp8_recipe}") QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) @@ -3441,10 +3438,6 @@ def forward( fp8_meta_kwargs = {} fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - if torch.cuda.current_device() == 0: - print(f"before flash_attn_a2a_communicate: q: {q.shape} {type(q)}, k: {k.shape} {type(k)}, v: {v.shape} {type(v)}") - print(f"qkv_format: {qkv_format}, o_format: {o_format}") - print(f"batch_dim_qkv: {batch_dim_qkv}, seq_dim_qkv: {seq_dim_qkv}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], @@ -3457,8 +3450,6 @@ def forward( qkv_format=qkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - if torch.cuda.current_device() == 0: - print(f"after flash_attn_a2a_communicate: q: {q.shape} {type(q)}, k: {k.shape} {type(k)}, v: {v.shape} {type(v)}") if softmax_type != "vanilla": softmax_offset = flash_attn_a2a_communicate_softmax_offset( softmax_offset, 1, cp_size, cp_group, cp_stream, True @@ -3478,8 +3469,6 @@ def forward( if fp8 and fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] - if torch.cuda.current_device() == 0: - print(f"before fused_attn_fwd: q_part: {q_part.shape} {type(q_part)}, 
k_part: {k_part.shape} {type(k_part)}, v_part: {v_part.shape} {type(v_part)}") out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, @@ -3507,8 +3496,6 @@ def forward( return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), ) - if torch.cuda.current_device() == 0: - print(f"after fused_attn_fwd: out_: {out_.shape} {type(out_)}") if isinstance(out_, QuantizedTensorStorage): out_fp8 = out_ out_ = out_._data @@ -3555,8 +3542,6 @@ def forward( aux_ctx_tensors = [softmax_lse, rng_state] out_part = out_ - if torch.cuda.current_device() == 0: - print(f"before flash_attn_a2a_communicate: out_: {out_.shape} {type(out_)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) out_ = flash_attn_a2a_communicate( out_, @@ -3569,8 +3554,6 @@ def forward( qkv_format=o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - if torch.cuda.current_device() == 0: - print(f"after flash_attn_a2a_communicate: out_: {out_.shape} {type(out_)}") if return_max_logit: max_logit = flash_attn_a2a_communicate_softmax_offset( *max_logit, 0, cp_size, cp_group, cp_stream, False @@ -3583,8 +3566,6 @@ def forward( elif o_format == "sbhd": # [s*b, h, d] -> [s, b, h, d] out_ = out_.view(-1, batch_size, *out_.shape[-2:]) - if torch.cuda.current_device() == 0: - print(f"after view: out_: {out_.shape} {type(out_)}") if fp8 and use_fused_attention: if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): @@ -3620,9 +3601,6 @@ def forward( ctx.qkv_layout = original_qkv_layout else: f16_tensors = (q_part, k_part, v_part, out_part) - if torch.cuda.current_device() == 0: - print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]}, type of fp8_tensors: {[type(x) for x in fp8_tensors]}") - print(f"f16_tensors: {[x.shape if x is not None else None for x in f16_tensors]}, type of f16_tensors: {[type(x) for x in f16_tensors]}") tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -3744,8 +3722,6 @@ def 
backward(ctx, dout, *_args): else: dout = dout.view(*ctx.out_shape) - if torch.cuda.current_device() == 0: - print(f"before flash_attn_a2a_communicate: dout: {dout.shape} {type(dout)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) dout = flash_attn_a2a_communicate( dout, @@ -3758,8 +3734,6 @@ def backward(ctx, dout, *_args): qkv_format=ctx.o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - if torch.cuda.current_device() == 0: - print(f"after flash_attn_a2a_communicate: dout: {dout.shape} {type(dout)}") flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} @@ -3812,19 +3786,7 @@ def backward(ctx, dout, *_args): else: dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) dout_part = ctx.dO_quantizer(dout) - print(f"dout.ptr: {hex(dout.data_ptr())}, {hex(dout_part._rowwise_data.data_ptr())}, {hex(dout_part._columnwise_data.data_ptr())}, {hex(dout_part._rowwise_scale_inv.data_ptr())}, {hex(dout_part._columnwise_scale_inv.data_ptr())}") aux_ctx_tensors.append(dout) - if torch.cuda.current_device() == 0: - print(f"before fused_attn_bwd: q_part: {q_part.shape} {type(q_part)}, k_part: {k_part.shape} {type(k_part)}, v_part: {v_part.shape} {type(v_part)}, out_part: {out_part.shape} {type(out_part)}, dout_part: {dout_part.shape} {type(dout_part)}") - print(f"type of aux_ctx_tensors: {[type(x) for x in aux_ctx_tensors]} {[x.shape if x is not None else None for x in aux_ctx_tensors]}") - print(f"fused_attn_backend: {fused_attn_backend}") - # print(f"cu_seqlens_q: {cu_seqlens_q.shape} {type(cu_seqlens_q)}, cu_seqlens_kv: {cu_seqlens_kv.shape} {type(cu_seqlens_kv)}") - # print(f"cu_seqlens_q_padded: {cu_seqlens_q_padded.shape} {type(cu_seqlens_q_padded)}, cu_seqlens_kv_padded: {cu_seqlens_kv_padded.shape} {type(cu_seqlens_kv_padded)}") - # print(f"ctx.softmax_scale: {ctx.softmax_scale}, ctx.dropout_p: {ctx.dropout_p}, ctx.window_size: {ctx.window_size}, 
ctx.deterministic: {ctx.deterministic}") - print(f"ctx.qkv_layout: {ctx.qkv_layout}, ctx.o_format: {ctx.o_format}, ctx.dqkv_layout: {ctx.dqkv_layout}") - # print(f"ctx.attn_mask_type: {ctx.attn_mask_type}, ctx.attn_bias_type: {ctx.attn_bias_type}") - print(f"is contiguous: {q_part.is_contiguous()}, {k_part.is_contiguous()}, {v_part.is_contiguous()}, {out_part.is_contiguous()}, {dout_part.is_contiguous()}") - print(fp8_meta_kwargs) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3887,9 +3849,6 @@ def backward(ctx, dout, *_args): **fa_backward_kwargs, ) - if torch.cuda.current_device() == 0: - print(f"after flash_attn_bwd: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") - print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -3902,9 +3861,6 @@ def backward(ctx, dout, *_args): qkv_format=dqkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - if torch.cuda.current_device() == 0: - print(f"after flash_attn_a2a_communicate: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") - print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") if dqkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] elif dqkv_format == "sbhd": @@ -3937,9 +3893,6 @@ def backward(ctx, dout, *_args): dv, src_nominal_dtype=bwd_nominal_dtype, ) - if torch.cuda.current_device() == 0: - print(f"after combine_and_dequantize: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") - print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( From 7fbe399c80e5fa177a1cfaceb3422286c9773289 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 19:58:41 -0800 Subject: [PATCH 036/172] test cp p2p 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 208 ++++++++++++------ .../attention/dot_product_attention/utils.py | 1 + 2 files changed, 138 insertions(+), 71 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 78c19826c8..dda856f36a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -794,13 +794,16 @@ def cp_p2p_fwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step, O_quantizer_per_step, rank, @@ -875,11 +878,15 @@ def cp_p2p_fwd_fused_attn( cu_seqlens_kv_padded_ = cu_seqlens_kv_padded fp8_meta_kwargs = {} + new_qkv_layout = qkv_layout if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) - ] + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + else: + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step @@ -896,7 +903,8 @@ def cp_p2p_fwd_fused_attn( fused_attention_backend=fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs, @@ -915,7 +923,7 @@ def cp_p2p_fwd_fused_attn( if return_max_logit: return 
out_per_step, softmax_lse_per_step, rng_states, attn_bias, *max_logit - return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None #, new_qkv_layout def cp_p2p_fwd_flash_attn( @@ -1073,15 +1081,21 @@ def cp_p2p_bwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, + d_out_format, + dqkv_layout, attn_mask_type, attn_bias_type, deterministic, fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, + # bwd_output_te_dtype, S_quantizer, dP_quantizer_per_step, dQKV_quantizer_per_step, + # O_quantizer_per_step, + QKV_quantizer_per_step, + dO_quantizer_per_step, q_part, k_part, v_part, @@ -1131,16 +1145,26 @@ def cp_p2p_bwd_fused_attn( fp8_meta_kwargs = {} if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip( - [q_fp8, kv_fp8, kv_fp8], - [q_part, k_part, v_part], - ) - ] - if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): - out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) - dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip( + [q_fp8, kv_fp8, kv_fp8], + [q_part, k_part, v_part], + ) + ] + else: + q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step) + if not fp8_recipe.mxfp8(): + if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + else: + # out_part, o_format = dpa_utils.permute_to_grouped_tensor(o_format, out_part) + dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout_part) + # out_part = 
O_quantizer_per_step(out_part) + aux_tensors.append(dout_part) + dout_part = dO_quantizer_per_step(dout_part) fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step fp8_meta_kwargs["dqkv_quantizer"] = dQKV_quantizer_per_step @@ -1156,7 +1180,7 @@ def cp_p2p_bwd_fused_attn( out_part, dout_part, bwd_nominal_dtype, - bwd_output_te_dtype, + # bwd_output_te_dtype, aux_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded_, @@ -1164,6 +1188,9 @@ def cp_p2p_bwd_fused_attn( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, + d_out_format=d_out_format, + dqkv_layout=dqkv_layout, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, deterministic=deterministic, @@ -1405,13 +1432,13 @@ def forward( # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - else: + elif not fp8_recipe.mxfp8(): # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_f16 = q q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] # print quantizers @@ -1432,10 +1459,11 @@ def forward( # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + S_quantizer_per_step[i] = S_quantizer.copy() if S_quantizer is not None else None O_quantizer_per_step[i] = O_quantizer.copy() - O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not fp8_recipe.mxfp8(): + S_quantizer_per_step[i].amax = 
amax_per_step[0][i].reshape((1,)) + O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype @@ -1555,7 +1583,9 @@ def forward( # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8 out = None + o_format = qkv_format for i in range(cp_size + 1): + print(f">>>>>>>>>>>> {torch.cuda.current_device()}: i: {i}, cp_size: {cp_size}") if i < cp_size: with torch.cuda.stream(flash_attn_streams[i % 2]): # wait until KV is received @@ -1608,13 +1638,16 @@ def forward( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step[i], O_quantizer_per_step[i], rank, @@ -1666,6 +1699,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1693,6 +1727,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1720,6 +1755,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1748,6 +1784,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( @@ -1775,7 +1812,7 @@ def forward( out_per_step[i - 1] = out_per_step[i - 1].dequantize( dtype=torch.float32 ) - if fp8_recipe.float8_current_scaling(): + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): out_per_step[i - 1] = out_per_step[i - 1].to(dtype=torch.float32) if i == 1: @@ -1829,7 +1866,7 @@ def forward( # fwd output correction: out in 
torch.float32 for i in range(cp_size): if i <= rank or not causal: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: if i == 0: out = flash_attn_fwd_out_correction_init( out_per_step[0], @@ -1849,7 +1886,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1860,7 +1897,7 @@ def forward( softmax_lse_in_packed_format, ) else: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: flash_attn_fwd_second_half_out_correction( out, out_per_step[i], @@ -1868,7 +1905,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1879,10 +1916,10 @@ def forward( softmax_lse_in_packed_format, ) - if qkv_format == "bshd": + if o_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) ctx.batch_size = out.shape[0] - elif qkv_format == "sbhd": + elif o_format == "sbhd": out = out.view(-1, *out.shape[-3:]) ctx.batch_size = out.shape[1] @@ -1892,10 +1929,10 @@ def forward( out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False ) if use_fused_attention: - if qkv_format == "bshd": + if o_format == "bshd": # [b*s, h, d] -> [b, s, h, d] out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - elif qkv_format == "sbhd": + elif o_format == "sbhd": # [s*b, h, d] -> [s, b, h, d] out = out.view(-1, ctx.batch_size, *out.shape[-2:]) if return_max_logit: @@ -1906,7 +1943,7 @@ def forward( out = out.view(-1, *out.shape[-2:]) # update FP8 quantizers: amax across cp_size steps - if fp8 and use_fused_attention: + if fp8 and use_fused_attention and not fp8_recipe.mxfp8(): amax_cp_fwd = amax_per_step.amax(dim=1) S_quantizer.amax.copy_(amax_cp_fwd[0]) O_quantizer.amax.copy_(amax_cp_fwd[1]) @@ -1929,7 +1966,7 @@ def forward( out_f16 = out.to(fwd_nominal_dtype) if fp8 and ( is_output_fp8 - or (is_bwd_fp8 and not 
(fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)) + or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) and not fp8_recipe.mxfp8()) ): out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 @@ -1940,7 +1977,8 @@ def forward( kv_fp8 = None kv = p2p_comm_buffers[-1] - if fp8: + q_fp8, kv_fp8 = None, None + if fp8 and not fp8_recipe.mxfp8(): q_fp8, kv_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8], [q, kv]) @@ -1953,12 +1991,22 @@ def forward( fp8_tensors = (q_fp8, kv_fp8, out_fp8) if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: f16_tensors = (None, None, out_f16) - elif fp8 and is_input_fp8: + elif fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) + elif fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): # fwd: fp8, bwd: f16, save all f16 # dequantize fp8 inputs q_f16 = q_fp8.dequantize() kv_f16 = kv_fp8.dequantize() f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8 and is_input_fp8 and fp8_recipe.mxfp8(): + # fwd: fp8, bwd: f16, save all f16 + # there is already an F16 version of the inputs + q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q, k, v) + kv = torch.cat((k_f16.view(-1), v_f16.view(-1)), dim=-1) + f16_tensors = (q_f16, kv, out_f16) + elif fp8 and not is_input_fp8 and fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) elif fp8: # fwd: fp8, bwd: f16, save all f16 # inputs are already in f16 @@ -1971,6 +2019,9 @@ def forward( q_f16 = q_f16.view(q.shape) kv_f16 = kv f16_tensors = (q_f16, kv_f16, out_f16) + if torch.cuda.current_device() == 0: + print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}") + print(f"f16_tensors: {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}") tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -2023,11 +2074,12 @@ def forward( ctx.S_quantizer = 
S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop(f"{nvtx_label}") @@ -2045,7 +2097,7 @@ def backward(ctx, dout, *_args): # dout is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage): + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): dout = ctx.dO_quantizer(dout) if ctx.use_fused_attention: dout._data = dout._data.contiguous() @@ -2086,6 +2138,7 @@ def backward(ctx, dout, *_args): causal = "causal" in ctx.attn_mask_type seq_dim = None qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + o_format = ctx.qkv_format if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") @@ -2142,28 +2195,33 @@ def backward(ctx, dout, *_args): buffer_dtype = torch.uint8 dq_buffer = None dout_fp8 = None - bwd_output_te_dtype = None + # bwd_output_te_dtype = None dkv_buffer = None + d_out_format = o_format if ctx.fp8: assert ctx.use_fused_attention, "FP8 is only supported with Fused Attention!" 
fused_attn_backend = FusedAttnBackend["FP8"] - q, kv, out = ( - q_fp8._data, - kv_fp8._data, - ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8._data - ), - ) + if not ctx.fp8_recipe.mxfp8(): + q, kv, out = ( + q_fp8._data, + kv_fp8._data, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8._data + ), + ) # dout_fp8: Float8Tensor, dtype=bwd_nominal_dtype # dout: torch.Tensor, dtype=torch.uint8 - if ctx.is_output_fp8: + # if ctx.fp8_recipe.mxfp8(): + # dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) + if isinstance(dout, QuantizedTensorStorage): dout_fp8 = dout - else: + elif not ctx.fp8_recipe.mxfp8(): dout_fp8 = ctx.dO_quantizer(dout) - dout = dout_fp8._data + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data # print quantizers print_quantizers( @@ -2178,7 +2236,7 @@ def backward(ctx, dout, *_args): ) # dout_fp8._fp8_dtype - bwd_output_te_dtype = ctx.dO_quantizer.dtype + # bwd_output_te_dtype = ctx.dO_quantizer.dtype # create buffers for reduction in float32 if ctx.fp8_recipe.delayed(): @@ -2187,7 +2245,7 @@ def backward(ctx, dout, *_args): dtype=buffer_dtype, device=q.device, ) - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_buffer = torch.empty( q.shape, dtype=torch.float32, @@ -2201,7 +2259,7 @@ def backward(ctx, dout, *_args): ) dkv_recv_buffer = torch.empty_like(dkv_send_buffer) p2p_comm_buffers = [[kv, dkv_send_buffer], [kv_recv_buffer, dkv_recv_buffer]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dkv_buffer = torch.zeros( kv.shape, dtype=torch.float32, @@ -2214,10 +2272,11 @@ def backward(ctx, dout, *_args): # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in 
range(cp_size): - dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() - dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() if ctx.dP_quantizer is not None else None dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() - dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not ctx.fp8_recipe.mxfp8(): + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) @@ -2228,7 +2287,7 @@ def backward(ctx, dout, *_args): ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: - bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] + # bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] # communicate for the 'a2a' part of 'a2a+p2p' @@ -2352,10 +2411,10 @@ def backward(ctx, dout, *_args): kv_fp8, ( out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or ctx.fp8_recipe.mxfp8() else out_fp8 ), - dout_fp8, + dout_fp8 if not ctx.fp8_recipe.mxfp8() else dout, softmax_lse, softmax_lse_, rng_states, @@ -2373,15 +2432,21 @@ def backward(ctx, dout, *_args): ctx.softmax_scale, ctx.dropout_p, qkv_layout, + ctx.qkv_format, + ctx.qkv_format, + qkv_layout, ctx.attn_mask_type, ctx.attn_bias_type, ctx.deterministic, ctx.fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, + # bwd_output_te_dtype, ctx.S_quantizer, dP_quantizer_per_step[i], dQKV_quantizer_per_step[i], + # ctx.O_quantizer, + ctx.QKV_quantizer, + ctx.dO_quantizer, ] else: flash_attn_inputs = [ @@ -2455,7 +2520,7 @@ def backward(ctx, dout, *_args): if ctx.fp8 and ctx.use_fused_attention: if ctx.fp8_recipe.delayed(): dq_, dk_, dv_ = [x._data for x in [dq_, dk_, dv_]] - if 
ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_, dk_, dv_ = [x.to(torch.float32) for x in [dq_, dk_, dv_]] # copy dq_ into the right buffer position @@ -2539,7 +2604,7 @@ def backward(ctx, dout, *_args): # dkv correction if ctx.fp8 and ctx.fp8_recipe.delayed(): dkv = dkv_recv_buffer[(rank + i + 1) % cp_size] - elif ctx.fp8 and ctx.fp8_recipe.float8_current_scaling(): + elif ctx.fp8 and (ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8()): dkv = dkv_buffer else: dkv = p2p_comm_buffers[(i + 1) % 2][1] @@ -2629,9 +2694,10 @@ def backward(ctx, dout, *_args): # sum up all cp_size for dq, dk, dv if ctx.fp8 and ctx.use_fused_attention: - amax_cp_bwd = amax_per_step.amax(dim=1) - ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) - ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) + if not ctx.fp8_recipe.mxfp8(): + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) + ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) dq = dq_buffer if ctx.fp8_recipe.delayed(): @@ -2654,7 +2720,7 @@ def backward(ctx, dout, *_args): ) dq, dk, dv = [x.sum(dim=0).to(bwd_nominal_dtype) for x in [dq, dk, dv]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dk = dkv[: ctx.k_numel].view(ctx.k_shape) dv = dkv[ctx.k_numel :].view(ctx.v_shape) @@ -2670,7 +2736,7 @@ def backward(ctx, dout, *_args): dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _ = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8: # print quantizers diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index eb2a7f8e94..05301c186d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ 
b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2236,6 +2236,7 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): original_shapes = [x.shape for x in [q, k, v]] s_q, d_qk = q.shape[-2:] s_kv, d_v = v.shape[-2:] + print(f">>>>>>>>>>>> {torch.cuda.current_device()}: {qkv_layout} s_q: {s_q}, d_qk: {d_qk}, s_kv: {s_kv}, d_v: {d_v}, {q.shape}, {k.shape}, {v.shape}") assert s_q % 128 == 0 assert s_kv % 128 == 0 assert d_qk % 32 == 0 From aa05a2afa644505dbae63ea3bc7779f6ce948c30 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:34:45 -0800 Subject: [PATCH 037/172] minor fixes for mla Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/run_attention_with_cp.py | 3 +- .../attention/test_attention_with_cp.py | 6 ++-- tests/pytorch/utils.py | 1 + .../dot_product_attention/context_parallel.py | 34 +++++++++++++------ .../attention/dot_product_attention/utils.py | 12 +++---- 5 files changed, 35 insertions(+), 21 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index b019289846..a53d872302 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -465,7 +465,8 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[4] = tensors_to_deq - for tensor in tensors: + for i, tensor in enumerate(tensors): + print(f"========= {torch.cuda.current_device()}: tensors[{i}].shape: {tensor.shape} {tensor.dtype} {torch.isnan(tensor).any()} {torch.isinf(tensor).any()}") assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 668c2745c7..2ab64d2029 100644 --- 
a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -151,7 +151,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 16, 128),#, num_gqa_groups=12), # GQA + "cp_2_1": ModelConfig(2, 4096, 16, 192, head_dim_v=128, num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -288,8 +288,8 @@ def test_cp_with_fused_attention( pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") - if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently does not support FP8 attention!") + # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: + # pytest.skip("MLA CP currently does not support FP8 attention!") if dtype == "fp8" and config.softmax_type != "vanilla": pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!") if config.softmax_type != "vanilla" and cp_comm_type != "a2a": diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index c54295d478..ff8cb3e820 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -177,6 +177,7 @@ def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): rmse = torch.sqrt((a - b).square().mean()).item() logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + # rmse_tol = rmse_tol * 1.1 assert rmse < rmse_tol * rmse_range, ( name_a + " vs " diff --git 
a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index dda856f36a..864967d661 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1408,15 +1408,17 @@ def forward( q_fp8, k_fp8, v_fp8 = (None, None, None) # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: + print(f">>>>>>======================>>>>>> {torch.cuda.current_device()}: fp8: {fp8}, is_input_fp8: {is_input_fp8}, fp8_recipe.mxfp8(): {fp8_recipe.mxfp8()}") if fp8 and is_input_fp8: QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = (q._data, k._data, v._data) + if not fp8_recipe.mxfp8(): + q, k, v = (q._data, k._data, v._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True ) - if fp8 and is_input_fp8: + if fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8, v_fp8], [q, k, v]) @@ -1576,6 +1578,7 @@ def forward( k_shape = k.shape k_numel = k.numel() v_shape = v.shape + o_shape = q.shape if not enable_mla else q.shape[:-1] + v.shape[-1:] p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) send_recv_reqs = [[], []] @@ -1820,7 +1823,7 @@ def forward( if qkv_format == "thd": if enable_mla: out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( - v_shape + o_shape ) else: # MHA or GQA @@ -1874,8 +1877,9 @@ def forward( softmax_lse_per_step[0], seq_dim, ) + print(f"====o/v===== {torch.cuda.current_device()}: i: {i}, {enable_mla}, out.shape: {out.shape} {out_per_step[0].shape} {v_shape} {o_shape}") if enable_mla: - out = out.view(v_shape) + out = 
out.view(o_shape) else: out = out.view(q.shape) else: @@ -1922,6 +1926,9 @@ def forward( elif o_format == "sbhd": out = out.view(-1, *out.shape[-3:]) ctx.batch_size = out.shape[1] + print(f"========= {torch.cuda.current_device()}: out.shape: {out.shape} {out.dtype}") + out_part = out.to(fwd_nominal_dtype) + print(f"========= {torch.cuda.current_device()}: out_part.shape: {out_part.shape} {out_part.dtype}") if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) @@ -1986,6 +1993,7 @@ def forward( # q, kv, out fp8_tensors = (None, None, None) f16_tensors = (None, None, None) + out_f16 = out_part if ctx.fp8: # fwd: fp8, bwd: fp8, save all fp8 fp8_tensors = (q_fp8, kv_fp8, out_fp8) @@ -2064,6 +2072,7 @@ def forward( ctx.k_numel = k_numel ctx.k_shape = k_shape ctx.v_shape = v_shape + ctx.o_shape = o_shape ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.dQKV_quantizer = dQKV_quantizer @@ -2292,14 +2301,15 @@ def backward(ctx, dout, *_args): # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: + print(f"========= {torch.cuda.current_device()}: before a2a: out.shape: {out.shape} {out.dtype} dout.shape: {dout.shape} {dout.dtype}") if not ctx.use_fused_attention: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + # out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(*out.shape) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( cp_size_a2a, out.device ) - out, dout = flash_attn_a2a_communicate( - [out, dout], + dout = flash_attn_a2a_communicate( + [dout], chunk_ids_for_a2a, seq_dim, cp_size_a2a, @@ -2307,10 +2317,11 @@ def backward(ctx, dout, *_args): ctx.cp_stream, True, ) + print(f"========= {torch.cuda.current_device()}: after a2a: dout.shape: {dout.shape} {dout.dtype} {q.shape} {ctx.v_shape}") if ctx.enable_mla: - out = out.view(*ctx.v_shape) - dout = dout.view(*ctx.v_shape) + out = out.view(*ctx.o_shape) + dout = dout.view(*ctx.o_shape) else: # MHA or GQA out = 
out.view(*q.shape) @@ -2754,7 +2765,8 @@ def backward(ctx, dout, *_args): if cp_size_a2a > 1: if ctx.fp8 and ctx.is_input_fp8: dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv - dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) + if not ctx.fp8_recipe.mxfp8(): + dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -2765,7 +2777,7 @@ def backward(ctx, dout, *_args): ctx.cp_stream, False, ) - if ctx.fp8 and ctx.is_input_fp8: + if ctx.fp8 and ctx.is_input_fp8 and not ctx.fp8_recipe.mxfp8(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 05301c186d..6f36aee355 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -835,12 +835,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt " bias for THD format" ) use_fused_attention = False - elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with FP8" - " MLA attention" - ) - use_fused_attention = False + # elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: + # logger.debug( + # "Disabling FusedAttention as it does not support context parallelism with FP8" + # " MLA attention" + # ) + # use_fused_attention = False # Filter: Attention mask # attn_mask_type | attention_mask | supported backends From 00e6693f978dbc877af63efb069429d840f786ed Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:51:46 -0800 Subject: [PATCH 038/172] open up 
a2a for mla Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention_with_cp.py | 2 +- .../pytorch/attention/dot_product_attention/context_parallel.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 2ab64d2029..dc8c237d57 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -286,7 +286,7 @@ def test_cp_with_fused_attention( pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") if f16_O and (dtype != "fp8" or scaling_mode not in ["current", "mxfp8"]): pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") - if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: + if cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: # pytest.skip("MLA CP currently does not support FP8 attention!") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 864967d661..32f908993e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4163,6 +4163,7 @@ def attn_forward_func_with_cp( assert not enable_mla or cp_comm_type in [ "p2p", "a2a+p2p", + "a2a", ], f"Context parallelism does not support MLA with {cp_comm_type=}!" 
if fp8 and fp8_meta is not None: From b8d28ceb6b92a811385c0d377756b9fa6d19c750 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:33:44 -0800 Subject: [PATCH 039/172] test ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/test_attention_with_cp.py | 10 +- .../dot_product_attention/context_parallel.py | 244 +++++++++++++++--- 2 files changed, 213 insertions(+), 41 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index dc8c237d57..efed75925a 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -151,7 +151,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 16, 192, head_dim_v=128, num_gqa_groups=4, attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig(2, 4096, 16, 128), #192, head_dim_v=128, num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -249,10 +249,10 @@ def test_cp_with_fused_attention( "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" " yet!" ) - if dtype == "fp8" and cp_comm_type == "all_gather": - pytest.skip( - "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" - ) + # if dtype == "fp8" and cp_comm_type == "all_gather": + # pytest.skip( + # "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" 
+ # ) if dtype == "fp8" and qkv_format == "thd": pytest.skip("FP8 attention cannot work with THD format yet!") if dtype == "fp8" and config.attn_bias_type != "no_bias": diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 32f908993e..88c98ce041 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2883,6 +2883,10 @@ def forward( cp_group, cp_stream, use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") @@ -2892,7 +2896,11 @@ def forward( cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) - qkv_dtype = q.dtype + assert qkv_format != "thd", f"{qkv_format} format is not supported!" + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + q_shape = q.shape + k_shape = k.shape + v_shape = v.shape causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -2936,9 +2944,6 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert qkv_format != "thd", f"{qkv_format} format is not supported!" - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - seq_dim = qkv_format.index("s") assert ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 @@ -2953,6 +2958,42 @@ def forward( else: cu_seqlens_q_padded = None + # FP8 setup + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." 
+ is_input_fp8 = isinstance(q, QuantizedTensorStorage) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + ( + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + fwd_nominal_dtype = q.dtype + fp8_meta_kwargs = {} + q_fp8, k_fp8, v_fp8 = (None, None, None) + fused_attn_backend = None + if fp8: + assert use_fused_attention, "FP8 is only supported with Fused Attention!" + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + if is_input_fp8: + q_fp8, k_fp8, v_fp8 = q, k, v + else: + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer + elif use_fused_attention: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] @@ -2983,7 +3024,9 @@ def forward( out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] - out = torch.empty_like(q) + enable_mla = k.shape[-1] != v.shape[-1] + out_shape = q.shape if not enable_mla else q.shape[:-1] + v.shape[-1:] + out = torch.empty(out_shape, dtype=fwd_nominal_dtype, device=q.device) max_logit_per_step = [None, None] max_logit = None @@ -3016,6 +3059,14 @@ def forward( # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: + q_part, k_part, 
v_part = q_, k_, v_ + if fp8: + if not fp8_recipe.mxfp8(): + q_part = Float8Tensor.make_like(q_fp8, data=q_, dtype=fwd_nominal_dtype) + k_part = Float8Tensor.make_like(k_fp8, data=k_, dtype=fwd_nominal_dtype) + v_part = Float8Tensor.make_like(v_fp8, data=v_, dtype=fwd_nominal_dtype) + else: + q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) ( out_per_step[i], [softmax_lse_per_step[i], rng_states[i]], @@ -3026,14 +3077,15 @@ def forward( max_seqlen_kv_, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - qkv_dtype, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + q_part, + k_part, + v_part, + fwd_nominal_dtype, + fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=qkv_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -3042,9 +3094,12 @@ def forward( window_size=window_size_per_step[i], return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), + **fp8_meta_kwargs, ) if return_max_logit: max_logit_per_step[i] = max_logit_[0] + if fp8 and isinstance(out_per_step[i], QuantizedTensorStorage): + out_per_step[i] = out_per_step[i].dequantize(dtype=fwd_nominal_dtype) else: fa_forward_args_thd = get_fa_args( True, @@ -3104,10 +3159,38 @@ def forward( else: out = out.view(-1, *out.shape[-2:]) - ctx.save_for_backward( - q, - k, - v, + out_fp8 = None + out_ret = out + if fp8 and (is_output_fp8 or (is_bwd_fp8 and fp8_recipe.delayed())): + out_fp8 = O_quantizer(out) + out_ret = out_fp8 + ctx.fp8 = fp8 and is_bwd_fp8 + ctx.fp8_recipe = fp8_recipe + fp8_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) + if ctx.fp8: + if fp8_recipe.delayed(): + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + if fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + 
fp8_tensors = (q_fp8, k_fp8, v_fp8, None) + f16_tensors = (None, None, None, out) + if fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out) + elif fp8: + if is_input_fp8: + q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + f16_tensors = (q, k, v, out) + else: + f16_tensors = (q, k, v, out) + if torch.cuda.current_device() == 0: + print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}") + print(f"f16_tensors: {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}") + + tensors_to_save, tensor_objects = prepare_for_saving( + *fp8_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_q_padded, *cu_seqlens_kv_per_step, @@ -3115,8 +3198,14 @@ def forward( *softmax_lse_per_step, *rng_states, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects - ctx.qkv_dtype = qkv_dtype + ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.q_shape = q_shape + ctx.k_shape = k_shape + ctx.v_shape = v_shape + ctx.out_shape = out_shape ctx.kv_seq_range_per_step = kv_seq_range_per_step ctx.window_size_per_step = window_size_per_step ctx.cp_group = cp_group @@ -3130,10 +3219,24 @@ def forward( ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 + if fp8: + ctx.QKV_quantizer = QKV_quantizer.copy() + ctx.O_quantizer = O_quantizer.copy() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + ctx.dQKV_quantizer = dQKV_quantizer.copy() + ctx.dO_quantizer = dO_quantizer.copy() + ctx.dP_quantizer = dP_quantizer.copy() if dP_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() 
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if return_max_logit: - return out, max_logit - return out + return out_ret, max_logit + return out_ret @staticmethod def backward(ctx, dout, *_args): @@ -3142,22 +3245,41 @@ def backward(ctx, dout, *_args): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (*saved_tensors,) = ctx.saved_tensors - (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] - cu_seqlens_kv_per_step = saved_tensors[5:7] - out_per_step = saved_tensors[7:9] - softmax_lse_per_step = saved_tensors[9:11] - rng_states = saved_tensors[11:13] + ( + q_fp8, k_fp8, v_fp8, out_fp8, + q, k, v, out, + cu_seqlens_q, + cu_seqlens_q_padded, + cu_seqlens_kv_per_step, + out_per_step, + softmax_lse_per_step, + rng_states + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step seq_dim = ctx.qkv_format.index("s") qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - dout = dout.view(q.shape) - dq = torch.empty_like(q) - dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) - dv = torch.zeros_like(dk) + dout = dout.view(ctx.out_shape) + dout_fp8 = None + if ctx.fp8: + if ( + ctx.is_output_fp8 + and not isinstance(dout, QuantizedTensorStorage) + and not ctx.fp8_recipe.mxfp8() + ): + dout_fp8 = ctx.dO_quantizer(dout) + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data + + if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): + q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] + + dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) + dk = torch.zeros((ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=k.device) + dv = torch.zeros((ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=v.device) dq_per_step = [None, None] dk_per_step = [None, None] 
dv_per_step = [None, None] @@ -3233,31 +3355,65 @@ def backward(ctx, dout, *_args): dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] + q_part, k_part, v_part, out_part, dout_part = q_, k_, v_, out_, dout_ + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + fp8_meta_kwargs = {} + if ctx.fp8: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer + if not ctx.fp8_recipe.mxfp8(): + q_part = Float8Tensor.make_like( + q_fp8, data=q_, dtype=ctx.fwd_nominal_dtype + ) + k_part = Float8Tensor.make_like( + k_fp8, data=k_, dtype=ctx.fwd_nominal_dtype + ) + v_part = Float8Tensor.make_like( + v_fp8, data=v_, dtype=ctx.fwd_nominal_dtype + ) + if not (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = ctx.O_quantizer(out_part) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype) + else: + q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer) + aux_ctx_tensors.append(dout_part) + dout_part = ctx.dO_quantizer(dout_part) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - out_, - dout_, - ctx.qkv_dtype, - TE_DType[dout.dtype], + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.fwd_nominal_dtype, + # TE_DType[dout.dtype], aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, + o_format=ctx.qkv_format, + d_out_format=ctx.qkv_format, + 
dqkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], deterministic=ctx.deterministic, cuda_graph=is_graph_capturing(), + **fp8_meta_kwargs, ) + if ctx.fp8: + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + x.dequantize(dtype=ctx.fwd_nominal_dtype) if isinstance(x, QuantizedTensorStorage) else x + for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] + ] else: dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] @@ -3335,6 +3491,10 @@ def backward(ctx, dout, *_args): dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) dk = dk.movedim(0, seq_dim).contiguous() dv = dv.movedim(0, seq_dim).contiguous() + + if ctx.fp8 and ctx.is_input_fp8: + dq, dk, dv, _ = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( @@ -3359,6 +3519,9 @@ def backward(ctx, dout, *_args): None, None, None, + None, + None, + None, ) @@ -4222,7 +4385,16 @@ def attn_forward_func_with_cp( elif cp_comm_type == "all_gather": args.pop(5) args.pop(8) - args += [window_size, cp_group, cp_stream, use_flash_attn_3] + args += [ + window_size, + cp_group, + cp_stream, + use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, + ] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": args += [ From d6ecadc12c2192bc443167f7efad3e201c77e763 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Feb 2026 20:32:22 -0800 Subject: [PATCH 040/172] tweaks for last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 2 +- .../dot_product_attention/context_parallel.py | 58 ++++++++++++++----- 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu 
b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index da826688be..764e95c330 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2891,7 +2891,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); - if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 88c98ce041..9dea649633 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2898,9 +2898,6 @@ def forward( assert qkv_format != "thd", f"{qkv_format} format is not supported!" 
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - q_shape = q.shape - k_shape = k.shape - v_shape = v.shape causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -2985,7 +2982,7 @@ def forward( fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v - else: + elif not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] @@ -2996,8 +2993,11 @@ def forward( # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + q_shape = q.shape # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + k_shape = k.shape + v_shape = v.shape # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) @@ -3060,16 +3060,17 @@ def forward( k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: q_part, k_part, v_part = q_, k_, v_ + new_qkv_layout = qkv_layout if fp8: if not fp8_recipe.mxfp8(): q_part = Float8Tensor.make_like(q_fp8, data=q_, dtype=fwd_nominal_dtype) k_part = Float8Tensor.make_like(k_fp8, data=k_, dtype=fwd_nominal_dtype) v_part = Float8Tensor.make_like(v_fp8, data=v_, dtype=fwd_nominal_dtype) else: - q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) ( out_per_step[i], - [softmax_lse_per_step[i], rng_states[i]], + aux_ctx_tensors, *max_logit_, ) = fused_attn_fwd( is_training, @@ -3084,7 +3085,7 @@ def forward( fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, o_format=qkv_format, attn_mask_type=attn_mask_type, 
attn_bias_type=attn_bias_type, @@ -3096,6 +3097,10 @@ def forward( cuda_graph=is_graph_capturing(), **fp8_meta_kwargs, ) + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors if return_max_logit: max_logit_per_step[i] = max_logit_[0] if fp8 and isinstance(out_per_step[i], QuantizedTensorStorage): @@ -3245,15 +3250,23 @@ def backward(ctx, dout, *_args): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) + cu_seqlens_kv_per_step = [None, None] + out_per_step = [None, None] + softmax_lse_per_step = [None, None] + rng_states = [None, None] ( q_fp8, k_fp8, v_fp8, out_fp8, q, k, v, out, cu_seqlens_q, cu_seqlens_q_padded, - cu_seqlens_kv_per_step, - out_per_step, - softmax_lse_per_step, - rng_states + cu_seqlens_kv_per_step[0], + cu_seqlens_kv_per_step[1], + out_per_step[0], + out_per_step[1], + softmax_lse_per_step[0], + softmax_lse_per_step[1], + rng_states[0], + rng_states[1], ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) kv_seq_range_per_step = ctx.kv_seq_range_per_step @@ -3277,9 +3290,15 @@ def backward(ctx, dout, *_args): if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] + if torch.cuda.current_device() == 0: + print(f"ctx.q_shape: {ctx.q_shape} {ctx.k_shape} {ctx.v_shape}") dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) dk = torch.zeros((ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=k.device) dv = torch.zeros((ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=v.device) + if torch.cuda.current_device() == 0: + print(f"dq: {dq.shape} {dq.dtype} {dq.device}") + print(f"dk: {dk.shape} {dk.dtype} {dk.device}") + print(f"dv: {dv.shape} {dv.dtype} {dv.device}") dq_per_step = [None, None] dk_per_step = [None, None] dv_per_step = [None, None] @@ -3354,10 +3373,12 @@ def backward(ctx, dout, 
*_args): out_ = out_per_step[i] dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] + aux_ctx_tensors = [softmax_lse_per_step[i], softmax_lse_per_step[i], rng_states[i]] q_part, k_part, v_part, out_part, dout_part = q_, k_, v_, out_, dout_ fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8_meta_kwargs = {} + new_qkv_layout = qkv_layout + d_out_format = ctx.qkv_format if ctx.fp8: fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer @@ -3377,9 +3398,13 @@ def backward(ctx, dout, *_args): out_part = ctx.O_quantizer(out_part) dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype) else: - q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer) + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer) + print(f"aux_ctx_tensors: {len(aux_ctx_tensors)} {[x.shape if x is not None else None for x in aux_ctx_tensors]} {[type(x) for x in aux_ctx_tensors]}") + dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout_part) aux_ctx_tensors.append(dout_part) dout_part = ctx.dO_quantizer(dout_part) + print(f"q_part type: {type(q_part)} {type(k_part)} {type(v_part)} {type(out_part)} {type(dout_part)}") + print(f"q_part shape: {q_part.shape} {k_part.shape} {v_part.shape} {out_part.shape} {dout_part.shape}") dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, @@ -3398,9 +3423,9 @@ def backward(ctx, dout, *_args): cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, o_format=ctx.qkv_format, - d_out_format=ctx.qkv_format, + d_out_format=d_out_format, dqkv_layout=qkv_layout, 
attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, @@ -3453,6 +3478,8 @@ def backward(ctx, dout, *_args): if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): + if torch.cuda.current_device() == 0: + print(f"dq.shape: {dq.shape} dq_per_step[i - 1].shape: {dq_per_step[i - 1].shape}") if ctx.qkv_format == "bshd": dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": @@ -3522,6 +3549,7 @@ def backward(ctx, dout, *_args): None, None, None, + None, ) From 3ac48cd095799f79bd06fbe126edd3237bef267a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 1 Mar 2026 13:51:36 -0800 Subject: [PATCH 041/172] enable mla ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention_with_cp.py | 6 +++--- .../dot_product_attention/context_parallel.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index efed75925a..1ac9dc7398 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -151,7 +151,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 16, 128), #192, head_dim_v=128, num_gqa_groups=4, attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig(2, 4096, 16, 192, head_dim_v=128), #num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -286,8 +286,8 @@ def test_cp_with_fused_attention( pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") if f16_O and (dtype != 
"fp8" or scaling_mode not in ["current", "mxfp8"]): pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") - if cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently only support KV P2P!") + # if cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: + # pytest.skip("MLA CP currently only support KV P2P!") # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: # pytest.skip("MLA CP currently does not support FP8 attention!") if dtype == "fp8" and config.softmax_type != "vanilla": diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 9dea649633..b8a50e8b77 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4351,11 +4351,11 @@ def attn_forward_func_with_cp( ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] - assert not enable_mla or cp_comm_type in [ - "p2p", - "a2a+p2p", - "a2a", - ], f"Context parallelism does not support MLA with {cp_comm_type=}!" + # assert not enable_mla or cp_comm_type in [ + # "p2p", + # "a2a+p2p", + # "a2a", + # ], f"Context parallelism does not support MLA with {cp_comm_type=}!" 
if fp8 and fp8_meta is not None: if fp8_meta["recipe"].fp8_dpa: From 5d4fa5e2038bd5e21747ab0bc69dbff93fdab847 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:36:49 +0000 Subject: [PATCH 042/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/run_attention_with_cp.py | 11 +- tests/pytorch/attention/test_attention.py | 6 +- .../attention/test_attention_with_cp.py | 4 +- tests/pytorch/test_grouped_tensor.py | 3 +- .../common/fused_attn/fused_attn.cpp | 88 +-- .../common/fused_attn/fused_attn_fp8.cu | 510 ++++++++++-------- .../common/fused_attn/fused_attn_fp8.h | 41 +- transformer_engine/common/fused_attn/utils.cu | 2 +- transformer_engine/common/fused_attn/utils.h | 468 ++++++++-------- .../include/transformer_engine/fused_attn.h | 39 +- .../transformer_engine/transformer_engine.h | 385 ++++++------- transformer_engine/common/recipe/__init__.py | 1 + .../common/transformer_engine.cpp | 212 ++++---- .../dot_product_attention/backends.py | 53 +- .../dot_product_attention/context_parallel.py | 202 +++++-- .../attention/dot_product_attention/utils.py | 17 +- transformer_engine/pytorch/csrc/common.h | 2 +- transformer_engine/pytorch/csrc/extensions.h | 6 +- .../pytorch/csrc/extensions/attention.cpp | 51 +- transformer_engine/pytorch/csrc/quantizer.cpp | 8 +- .../pytorch/tensor/storage/grouped_tensor.py | 2 +- 21 files changed, 1197 insertions(+), 914 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index c9d6d9d64f..5cb43f277a 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -21,7 +21,12 @@ Float8CurrentScalingQuantizer, MXFP8Quantizer, ) -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, Format +from 
transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + MXFP8BlockScaling, + Format, +) from utils import ModelConfig, compare_and_assert dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -250,7 +255,9 @@ def run_dpa_with_cp( if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "mxfp8": - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + fp8_recipe = MXFP8BlockScaling( + fp8_format=Format.HYBRID, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha + ) # instantiate attention module core_attn = DotProductAttention( diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 5760ca2434..47abf1ebc6 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1804,8 +1804,10 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),#, attn_mask_type="causal"), - "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512)), + "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), # , attn_mask_type="causal"), + "fp8_10": ModelConfig( + 2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512) + ), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 9e166fa908..a5fe8f74f5 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -154,7 +154,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ), # MHA "cp_1_5": 
ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 16, 192, head_dim_v=128), #num_gqa_groups=4, attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig( + 2, 4096, 16, 192, head_dim_v=128 + ), # num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, 4096, diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index ab9ec28984..31d84933de 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -361,7 +361,6 @@ def test_static_quantize_method(self, quantization: str) -> None: expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset - @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) def test_quantize_grouped_mxfp8(self) -> None: """Test grouped quantization for MXFP8 against per-tensor quantization.""" @@ -372,7 +371,7 @@ def test_quantize_grouped_mxfp8(self) -> None: # Create BF16 input tensors and pack into a grouped tensor input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shapes] quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) - quantizer.optimize_for_gemm=True + quantizer.optimize_for_gemm = True grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, shapes=shapes, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 1f5db127a0..72c5273a78 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -141,9 +141,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { } // map one NVTE_QKV_Format to another -void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, 
NVTE_QKV_Format dst_format, std::vector &dst_shape, - size_t *b, size_t *h, size_t *s, size_t *d, size_t *t) { - size_t _b=0, _h=0, _s=0, _d=0, _t=0; +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, + NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, + size_t *h, size_t *s, size_t *d, size_t *t) { + size_t _b = 0, _h = 0, _s = 0, _d = 0, _t = 0; switch (src_format) { case NVTE_QKV_Format::NVTE_BSHD: _b = src_shape[0]; @@ -270,8 +271,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || - // 9.21: mxfp8, d_qk=128, d_v=192 - (cudnn_runtime_version >= 92100 && head_dim_qk <= 192 && head_dim_v <= 128)) && + // 9.21: mxfp8, d_qk=128, d_v=192 + (cudnn_runtime_version >= 92100 && head_dim_qk <= 192 && head_dim_v <= 128)) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { @@ -411,13 +412,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_BHSD || + qkv_format == NVTE_QKV_Format::NVTE_BHSD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || cudnn_runtime_version >= 90600)) || - ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || q_format == NVTE_QKV_Format::NVTE_BHSD || + ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || + q_format == NVTE_QKV_Format::NVTE_BHSD || (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || - kv_format == NVTE_QKV_Format::NVTE_SBHD || 
kv_format == NVTE_QKV_Format::NVTE_BSHD || kv_format == NVTE_QKV_Format::NVTE_BHSD || + kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || + kv_format == NVTE_QKV_Format::NVTE_BHSD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && cudnn_runtime_version >= 90700)) && // sliding window @@ -428,7 +431,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 90200 && ((window_size_left == -1 && window_size_right == -1 && attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || - ((window_size_left == -1 || window_size_left >= 0) && (window_size_right == -1 || window_size_right >= 0) && + ((window_size_left == -1 || window_size_left >= 0) && + (window_size_right == -1 || window_size_right >= 0) && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && @@ -537,19 +541,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd( + const NVTETensor Q, const 
NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -583,8 +585,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } - size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim-3] : input_Q->data.shape[ndim - 2]; - size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? input_K->data.shape[ndim_kv-3] : input_K->data.shape[ndim_kv - 2]; + size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim - 3] + : input_Q->data.shape[ndim - 2]; + size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? 
input_K->data.shape[ndim_kv - 3] + : input_K->data.shape[ndim_kv - 2]; int64_t num_pages_k = 0; int64_t num_pages_v = 0; int64_t page_size_k = 0; @@ -648,9 +652,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, - dropout, qkv_layout, o_format, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, + attn_scale, dropout, qkv_layout, o_format, bias_type, attn_mask_type, + window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, + input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); @@ -668,11 +673,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -708,8 +714,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } - size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim-3] : input_Q->data.shape[ndim - 2]; - size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? input_K->data.shape[ndim_kv-3] : input_K->data.shape[ndim_kv - 2]; + size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim - 3] + : input_Q->data.shape[ndim - 2]; + size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? 
input_K->data.shape[ndim_kv - 3] + : input_K->data.shape[ndim_kv - 2]; auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -762,13 +770,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); const Tensor *input_dO_f16; if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { - input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, - input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, + input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, + input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, + input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, + handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 003ca0051d..237f3bd66e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1652,13 +1652,15 @@ void fused_attn_fp8_bwd_impl( // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, - void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, - void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, + void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, + void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, + void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, + void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, + void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, 
NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1676,14 +1678,18 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::HALF || + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::HALF || - o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, - "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); try { FADescriptor_v1 descriptor{b, @@ -1772,40 +1778,40 @@ void fused_attn_fp8_fwd_impl_v1( std::vector k_stride(4); std::vector v_stride(4); generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + 
NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); + NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, v_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + NVTE_QKV_Matrix::NVTE_V_Matrix); Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) - .set_data_type(qkv_tensor_type)); + .set_name("Q") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride) + .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) - .set_data_type(qkv_tensor_type)); + .set_name("K") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_stride) + .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) - .set_data_type(qkv_tensor_type)); + .set_name("V") + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_stride) + .set_data_type(qkv_tensor_type)); attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Scale_o if (is_delayed_scaling || is_current_scaling) { descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); descale_v 
= mha_graph->tensor_like(descale_q, "Descale_v"); descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); @@ -1824,27 +1830,33 @@ void fused_attn_fp8_fwd_impl_v1( std::vector k_scale_strides(4); std::vector v_scale_strides(4); auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); - generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, k_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, v_scale_strides.data(), kv_format, false); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) - .set_stride(q_scale_strides) - .set_data_type(fe::DataType_t::FP8_E8M0) - .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_k") - .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) - .set_stride(k_scale_strides) - .set_data_type(fe::DataType_t::FP8_E8M0) - .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_v") - .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_v_padded}) - .set_stride(v_scale_strides) - .set_data_type(fe::DataType_t::FP8_E8M0) - .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, + v_scale_strides.data(), kv_format, false); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() + 
.set_name("Descale_q") + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_v_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_attributes sdpa_options; @@ -1854,7 +1866,7 @@ void fused_attn_fp8_fwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - fe::DiagonalAlignment_t const &diagonal_alignment = + fe::DiagonalAlignment_t const& diagonal_alignment = bottom_right_diagonal ? 
fe::DiagonalAlignment_t::BOTTOM_RIGHT : fe::DiagonalAlignment_t::TOP_LEFT; sdpa_options.set_diagonal_alignment(diagonal_alignment); @@ -1910,21 +1922,24 @@ void fused_attn_fp8_fwd_impl_v1( Stats = outputs[1]; amax_o = outputs[2]; } else { - auto outputs = mha_graph->sdpa_fp8( - Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, sdpa_options); O = outputs[0]; Stats = outputs[1]; amax_s = outputs[2]; amax_o = outputs[3]; amax_s->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); } std::vector o_stride(4); generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); - O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride).set_data_type(o_tensor_type); + O->set_output(true) + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride) + .set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -1949,9 +1964,10 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // amax_s std::shared_ptr> // amax_o key_tensors_tuple = - is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, nullptr, attn_scale, O, nullptr, amax_o) : - std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, - scale_s, scale_o, attn_scale, O, amax_s, amax_o); + is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, + nullptr, attn_scale, O, nullptr, amax_o) + : std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, attn_scale, O, amax_s, amax_o); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? 
std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = @@ -2040,20 +2056,23 @@ void fused_attn_fp8_fwd_impl_v1( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, - void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, - void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, - void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, - void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, - void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, - void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void* devPtrQ, + void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, + void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, + void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, + void* devPtrDescaledP, 
void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, + void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, + void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, + void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -2070,14 +2089,18 @@ void fused_attn_fp8_bwd_impl_v1( auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || - dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || - dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || - dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, - "FP8 fused attention only supports FP8DelayedScaling or 
FP8CurrentScaling or MXFP8 recipes!"); + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); @@ -2179,8 +2202,10 @@ void fused_attn_fp8_bwd_impl_v1( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr Q, Q_t, K, K_t, V, O, dO, dO_t, dO_f16, Stats, attn_scale; - std::shared_ptr descale_q, descale_q_t, descale_k, descale_k_t, descale_v; + std::shared_ptr Q, Q_t, K, K_t, V, O, dO, dO_t, dO_f16, Stats, + attn_scale; + std::shared_ptr descale_q, descale_q_t, descale_k, descale_k_t, + descale_v; std::shared_ptr descale_s, descale_o; std::shared_ptr descale_dP, descale_dO, descale_dO_t; std::shared_ptr scale_s, scale_dP; @@ -2194,11 +2219,11 @@ void fused_attn_fp8_bwd_impl_v1( std::vector v_stride(4); std::vector o_stride(4); generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); + NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, 
v_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + NVTE_QKV_Matrix::NVTE_V_Matrix); generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") @@ -2277,17 +2302,28 @@ void fused_attn_fp8_bwd_impl_v1( generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); - printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], q_t_stride[3]); - printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], k_t_stride[3]); - printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], dO_t_stride[3]); - printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, cudnn_frontend::DataType_t::BFLOAT16); - printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); - printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, NVTE_QKV_Format::NVTE_BHSD); - printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, NVTE_QKV_Format::NVTE_BHSD); - printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); + printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], + q_t_stride[3]); + 
printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], + k_t_stride[3]); + printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], + dO_t_stride[3]); + printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, + cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, + cudnn_frontend::DataType_t::BFLOAT16); + printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, + cudnn_frontend::DataType_t::FP8_E5M2); + printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, + cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, + NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); + printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, + NVTE_QKV_Format::NVTE_BHSD); + printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, + NVTE_QKV_Format::NVTE_BHSD); + printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, + NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); printf("b: %d\n", b); printf("h: %d\n", h); printf("hg: %d\n", hg); @@ -2304,25 +2340,25 @@ void fused_attn_fp8_bwd_impl_v1( printf("is_dropout: %d\n", is_dropout); printf("is_bias: %d\n", is_bias); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q_t") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_t_stride) - .set_data_type(qkv_tensor_type)); + .set_name("Q_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_t_stride) + .set_data_type(qkv_tensor_type)); K_t = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K_t") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_t_stride) - .set_data_type(qkv_tensor_type)); + .set_name("K_t") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_t_stride) + .set_data_type(qkv_tensor_type)); dO_t = 
mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO_t") - .set_dim({b, h, s_q, d_v}) - .set_stride(dO_t_stride) - .set_data_type(do_tensor_type)); + .set_name("dO_t") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_t_stride) + .set_data_type(do_tensor_type)); dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO_f16") - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) - .set_data_type(o_tensor_type)); + .set_name("dO_f16") + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride) + .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); printf("s_q_padded: %d\n", padded.s_q_padded); @@ -2344,57 +2380,78 @@ void fused_attn_fp8_bwd_impl_v1( std::vector v_scale_strides(4); std::vector dO_scale_strides(4); std::vector dO_t_scale_strides(4); - generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); - generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, q_t_scale_strides.data(), q_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, k_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, k_t_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, v_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, dO_scale_strides.data(), d_out_format, false); - generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, false); - printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); - printf("q_t_scale_strides: %d, %d, %d, %d\n", 
q_t_scale_strides[0], q_t_scale_strides[1], q_t_scale_strides[2], q_t_scale_strides[3]); - printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); - printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], k_t_scale_strides[2], k_t_scale_strides[3]); - printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); - printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], dO_scale_strides[2], dO_scale_strides[3]); - printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], dO_t_scale_strides[2], dO_t_scale_strides[3]); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, + q_t_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, + k_t_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, + v_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, + dO_scale_strides.data(), d_out_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, + dO_t_scale_strides.data(), d_out_format, false); + printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], + q_scale_strides[2], q_scale_strides[3]); + printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], + q_t_scale_strides[2], 
q_t_scale_strides[3]); + printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], + k_scale_strides[2], k_scale_strides[3]); + printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], + k_t_scale_strides[2], k_t_scale_strides[3]); + printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], + v_scale_strides[2], v_scale_strides[3]); + printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], + dO_scale_strides[2], dO_scale_strides[3]); + printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], + dO_t_scale_strides[2], dO_t_scale_strides[3]); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) .set_stride(q_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_q_t = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q_t") .set_dim({b, h, padded.s_q_scale_padded, padded.d_qk_padded}) .set_stride(q_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k") .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) .set_stride(k_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_k_t = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_k_t = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k_t") .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_qk_padded}) .set_stride(k_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) 
.set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_v = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_v") .set_dim({b, hg, padded.s_kv_padded, padded.d_v_scale_padded}) .set_stride(v_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_dO = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_dO = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO") .set_dim({b, h, padded.s_q_padded, padded.d_v_scale_padded}) .set_stride(dO_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_dO_t = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO_t") .set_dim({b, h, padded.s_q_scale_padded, padded.d_v_padded}) .set_stride(dO_t_scale_strides) @@ -2408,7 +2465,7 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - fe::DiagonalAlignment_t const &diagonal_alignment = + fe::DiagonalAlignment_t const& diagonal_alignment = bottom_right_diagonal ? 
fe::DiagonalAlignment_t::BOTTOM_RIGHT : fe::DiagonalAlignment_t::TOP_LEFT; sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); @@ -2474,9 +2531,10 @@ void fused_attn_fp8_bwd_impl_v1( } std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; if (is_delayed_scaling || is_current_scaling) { - auto outputs = mha_graph->sdpa_fp8_backward( - Q, K, V, O, dO, Stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, - descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); + auto outputs = mha_graph->sdpa_fp8_backward(Q, K, V, O, dO, Stats, descale_q, descale_k, + descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, + scale_dV, scale_dP, sdpa_backward_options); dQ = outputs[0]; dK = outputs[1]; dV = outputs[2]; @@ -2487,8 +2545,8 @@ void fused_attn_fp8_bwd_impl_v1( } if (is_mxfp8) { auto outputs = mha_graph->sdpa_fp8_backward( - Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, descale_k_t, - descale_v, descale_dO, descale_dO_t, sdpa_backward_options); + Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, + descale_k_t, descale_v, descale_dO, descale_dO_t, sdpa_backward_options); dQ = outputs[0]; dK = outputs[1]; dV = outputs[2]; @@ -2500,17 +2558,26 @@ void fused_attn_fp8_bwd_impl_v1( std::vector dk_stride(4); std::vector dv_stride(4); generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, dq_stride.data(), dqkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dk_stride.data(), dqkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); + NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + NVTE_QKV_Matrix::NVTE_V_Matrix); printf("dq_stride: %d, %d, %d, %d\n", dq_stride[0], dq_stride[1], dq_stride[2], dq_stride[3]); 
printf("dk_stride: %d, %d, %d, %d\n", dk_stride[0], dk_stride[1], dk_stride[2], dk_stride[3]); printf("dv_stride: %d, %d, %d, %d\n", dv_stride[0], dv_stride[1], dv_stride[2], dv_stride[3]); - dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(dq_stride).set_data_type(dqkv_tensor_type); - dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(dk_stride).set_data_type(dqkv_tensor_type); - dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(dv_stride).set_data_type(dqkv_tensor_type); + dQ->set_output(true) + .set_dim({b, h, s_q, d_qk}) + .set_stride(dq_stride) + .set_data_type(dqkv_tensor_type); + dK->set_output(true) + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(dk_stride) + .set_data_type(dqkv_tensor_type); + dV->set_output(true) + .set_dim({b, hg, s_kv, d_v}) + .set_stride(dv_stride) + .set_data_type(dqkv_tensor_type); amax_dQ->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2524,10 +2591,10 @@ void fused_attn_fp8_bwd_impl_v1( .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); if (!is_mxfp8) { - amax_dP->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); + amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); } std::tuple, // Q @@ -2560,8 +2627,9 @@ void fused_attn_fp8_bwd_impl_v1( Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); - auto mxfp8_tensors_tuple = is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); + auto mxfp8_tensors_tuple = + is_mxfp8 ? 
std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); @@ -2574,16 +2642,18 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, bias_tuple, - padding_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, + bias_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t, - bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, + descale_k_t, descale_dO_t, bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = + get_graph(sdpa_fp8_bprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -2650,23 +2720,34 @@ void fused_attn_fp8_bwd_impl_v1( printf("devPtrO: %p, is_aligned: %d\n", devPtrO, is_aligned_modulo(devPtrO, modulo)); printf("devPtrM: %p, is_aligned: %d\n", devPtrM, is_aligned_modulo(devPtrM, modulo)); printf("devPtrdO: %p, is_aligned: %d\n", devPtrdO, is_aligned_modulo(devPtrdO, modulo)); - printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, is_aligned_modulo(devPtrDescaleQ, modulo)); 
- printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, is_aligned_modulo(devPtrDescaleK, modulo)); - printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, is_aligned_modulo(devPtrDescaleV, modulo)); - printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, is_aligned_modulo(devPtrDescaledO, modulo)); - printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, is_aligned_modulo(devPtrDescaledO_t, modulo)); + printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, + is_aligned_modulo(devPtrDescaleQ, modulo)); + printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, + is_aligned_modulo(devPtrDescaleK, modulo)); + printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, + is_aligned_modulo(devPtrDescaleV, modulo)); + printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, + is_aligned_modulo(devPtrDescaledO, modulo)); + printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, + is_aligned_modulo(devPtrDescaledO_t, modulo)); printf("devPtrdQ: %p, is_aligned: %d\n", devPtrdQ, is_aligned_modulo(devPtrdQ, modulo)); printf("devPtrdK: %p, is_aligned: %d\n", devPtrdK, is_aligned_modulo(devPtrdK, modulo)); printf("devPtrdV: %p, is_aligned: %d\n", devPtrdV, is_aligned_modulo(devPtrdV, modulo)); - printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, is_aligned_modulo(devPtrAmaxdQ, modulo)); - printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, is_aligned_modulo(devPtrAmaxdK, modulo)); - printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, is_aligned_modulo(devPtrAmaxdV, modulo)); + printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, + is_aligned_modulo(devPtrAmaxdQ, modulo)); + printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, + is_aligned_modulo(devPtrAmaxdK, modulo)); + printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, + is_aligned_modulo(devPtrAmaxdV, modulo)); printf("devPtrQ_t: %p, is_aligned: %d\n", devPtrQ_t, is_aligned_modulo(devPtrQ_t, modulo)); 
printf("devPtrK_t: %p, is_aligned: %d\n", devPtrK_t, is_aligned_modulo(devPtrK_t, modulo)); - printf("devPtrdO_f16: %p, is_aligned: %d\n", devPtrdO_f16, is_aligned_modulo(devPtrdO_f16, modulo)); + printf("devPtrdO_f16: %p, is_aligned: %d\n", devPtrdO_f16, + is_aligned_modulo(devPtrdO_f16, modulo)); printf("devPtrdO_t: %p, is_aligned: %d\n", devPtrdO_t, is_aligned_modulo(devPtrdO_t, modulo)); - printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, is_aligned_modulo(devPtrDescaleQ_t, modulo)); - printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, is_aligned_modulo(devPtrDescaleK_t, modulo)); + printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, + is_aligned_modulo(devPtrDescaleQ_t, modulo)); + printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, + is_aligned_modulo(devPtrDescaleK_t, modulo)); /* if (is_bias) { variant_pack[bias] = devPtrBias; @@ -2709,14 +2790,16 @@ void fused_attn_fp8_bwd_impl_v1( #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, - const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, - cudaStream_t stream, cudnnHandle_t handle) { + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type 
mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, + const Tensor* input_K, const Tensor* input_V, Tensor* input_output_S, + Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, + const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, + cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = nullptr; void* devPtrK = nullptr; @@ -2753,7 +2836,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrScaleO = output_O->scale.dptr; devPtrAmaxS = input_output_S->amax.dptr; devPtrScaleS = input_output_S->scale.dptr; - devPtrDescaleS = input_output_S->scale_inv.dptr; + devPtrDescaleS = input_output_S->scale_inv.dptr; } void* devPtrM = nullptr; void* devPtrZInv = nullptr; @@ -2796,13 +2879,15 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, - attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, 
head_dim_v, + is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, + window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, + devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { @@ -2831,16 +2916,19 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, - const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, - const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, - const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, - const Tensor* output_dK, const Tensor* output_dV, - const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type 
mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, bool deterministic, + const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, + const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, + const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, + Tensor* input_output_dP, const Tensor* output_dQ, const Tensor* output_dK, + const Tensor* output_dV, const Tensor* cu_seqlens_q, + const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -2899,19 +2987,21 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); - if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { + if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, - p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, - devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, - 
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, + devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, + devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, + devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, + devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, + devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), + get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, workspace->data.dptr, + &workspace_size, stream, handle); } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 9683974a26..98d5876ec8 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -15,26 +15,31 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t 
head_dim_v, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - -// fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, - const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, - const Tensor *output_dK, const Tensor *output_dV, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, Tensor *input_output_S, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t 
stream, cudnnHandle_t handle); + +// fused attention BWD FP8 with separate Q, K, V +void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, bool deterministic, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, + const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, + Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dK, + const Tensor *output_dV, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); #endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 8a9399e830..e67ae5e206 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -319,7 +319,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_transpose_dim_idx] = 1; } break; -} + } if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { strideA[seqlen_kv_dim_idx] = 1; diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 2c03245560..3e4ca696e2 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -49,11 +49,11 @@ struct MXFP8PaddedSizes { int64_t d_v_scale_padded; }; -inline bool is_aligned_modulo(void* ptr, int64_t modulo) { - // Cast the pointer to a large enough 
integer type (uintptr_t) - uintptr_t address = reinterpret_cast(ptr); - // Check if the address is perfectly divisible by 16 - return (address % modulo) == 0; +inline bool is_aligned_modulo(void *ptr, int64_t modulo) { + // Cast the pointer to a large enough integer type (uintptr_t) + uintptr_t address = reinterpret_cast(ptr); + // Check if the address is perfectly divisible by 16 + return (address % modulo) == 0; } // Pad s and d for MXFP8 layout @@ -78,7 +78,8 @@ inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_q // Get matrix strides for a 4D tensor [batch, head, seqlen, hidden] given a QKV format. // strideA must point to at least 4 int64_t elements. inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, - int64_t *strides, NVTE_QKV_Format format, bool transpose) { + int64_t *strides, NVTE_QKV_Format format, + bool transpose) { constexpr int batch_dim_idx = 0; constexpr int head_dim_idx = 1; int seqlen_dim_idx = transpose ? 3 : 2; @@ -111,234 +112,233 @@ inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int } // get matrix strides based on matrix type -inline void generateMatrixStrides_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t *strides, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) -{ - constexpr int batch_dim_idx = 0; - constexpr int head_dim_idx = 1; - bool transpose = - (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); - int seqlen_dim_idx = transpose ? 3 : 2; - int hidden_dim_idx = transpose ? 
2 : 3; - constexpr int seqlen_q_dim_idx = 2; - constexpr int seqlen_kv_dim_idx = 3; - - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { - matrix = NVTE_QKV_Matrix::NVTE_Q_Matrix; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { - matrix = NVTE_QKV_Matrix::NVTE_K_Matrix; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { - matrix = NVTE_QKV_Matrix::NVTE_V_Matrix; - } - NVTE_CHECK(matrix != NVTE_QKV_Matrix::NVTE_O_Matrix, "Invalid matrix type. Expected Q, K, V, O, or their related transposes."); - - switch (layout) { - case NVTE_QKV_Layout::NVTE_SB3HD: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 3 * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * 3 * h * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBH3D: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 3 * h * d_qk; - strides[head_dim_idx] = 3 * d_qk; - strides[seqlen_dim_idx] = b * 3 * h * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 2 * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 
1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 2 * hg * d_qk; - strides[head_dim_idx] = 2 * d_qk; - strides[seqlen_dim_idx] = b * 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = b * hg * d_v; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BS3HD: - case NVTE_QKV_Layout::NVTE_T3HD: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_q * 3 * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = 3 * h * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSH3D: - case NVTE_QKV_Layout::NVTE_TH3D: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_q * 3 * h * d_qk; - strides[head_dim_idx] = 3 * d_qk; - strides[seqlen_dim_idx] = 3 * h * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: - case NVTE_QKV_Layout::NVTE_THD_T2HD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - 
strides[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: - case NVTE_QKV_Layout::NVTE_THD_TH2D: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - strides[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; - strides[head_dim_idx] = 2 * d_qk; - strides[seqlen_dim_idx] = 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: - case NVTE_QKV_Layout::NVTE_THD_THD_THD: - case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = hg * d_v; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 1; - 
} else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = hg * d_v; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = b * hg * d_v; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * s_q * d_qk; - strides[head_dim_idx] = s_q * d_qk; - strides[seqlen_dim_idx] = d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = hg * s_kv * d_qk; - strides[head_dim_idx] = s_kv * d_qk; - strides[seqlen_dim_idx] = d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = hg * s_kv * d_v; - strides[head_dim_idx] = s_kv * d_v; - strides[seqlen_dim_idx] = d_v; - strides[hidden_dim_idx] = 1; - } - break; - default: - NVTE_CHECK(false, "Invalid layout."); - break; - } +inline void generateMatrixStrides_v1(int64_t b, 
int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, + int64_t d_qk, int64_t d_v, int64_t *strides, + NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { + constexpr int batch_dim_idx = 0; + constexpr int head_dim_idx = 1; + bool transpose = (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); + int seqlen_dim_idx = transpose ? 3 : 2; + int hidden_dim_idx = transpose ? 2 : 3; + constexpr int seqlen_q_dim_idx = 2; + constexpr int seqlen_kv_dim_idx = 3; + + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_Q_Matrix; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_K_Matrix; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_V_Matrix; + } + NVTE_CHECK(matrix != NVTE_QKV_Matrix::NVTE_O_Matrix, + "Invalid matrix type. Expected Q, K, V, O, or their related transposes."); + + switch (layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 3 * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 3 * h * d_qk; + strides[head_dim_idx] = 3 * d_qk; + strides[seqlen_dim_idx] = b * 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if 
((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 2 * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 2 * hg * d_qk; + strides[head_dim_idx] = 2 * d_qk; + strides[seqlen_dim_idx] = b * 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = b * hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_q * 3 * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == 
NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_q * 3 * h * d_qk; + strides[head_dim_idx] = 3 * d_qk; + strides[seqlen_dim_idx] = 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; + strides[head_dim_idx] = 2 * d_qk; + strides[seqlen_dim_idx] = 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = hg * d_qk; + strides[hidden_dim_idx] = 1; + } else 
if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = b * hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * s_q * d_qk; + strides[head_dim_idx] = s_q * d_qk; + strides[seqlen_dim_idx] = d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == 
NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * s_kv * d_qk; + strides[head_dim_idx] = s_kv * d_qk; + strides[seqlen_dim_idx] = d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * s_kv * d_v; + strides[head_dim_idx] = s_kv * d_v; + strides[seqlen_dim_idx] = d_v; + strides[hidden_dim_idx] = 1; + } + break; + default: + NVTE_CHECK(false, "Invalid layout."); + break; + } if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { strides[seqlen_kv_dim_idx] = 1; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 6ee4d1a8ba..90393ce8c8 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -207,7 +207,9 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * * \return The destination shape. */ - void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, size_t *h, size_t *s, size_t *d, size_t *t); +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, + NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, + size_t *h, size_t *s, size_t *d, size_t *t); /*! \brief Get fused attention backend based on input parameters. * @@ -305,19 +307,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. 
* @@ -391,11 +391,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. 
* diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 68fe616a93..a6cb036a35 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -450,7 +450,8 @@ enum NVTEGroupedTensorParam { kNVTEGroupedTensorOffsets = 9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */ <<<<<<< HEAD - kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in format expected by GEMM */ + kNVTEGroupedWithGEMMSwizzledScales = + 10, /*!< Whether scaling factors are in format expected by GEMM */ ======= kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in format expected by GEMM */ @@ -517,8 +518,8 @@ void nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, NVTEGroupedTe * \param[in] param The value to be set (NVTEBasicTensor). */ void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name, - const void *buf, size_t size_in_bytes); - + const void *buf, size_t size_in_bytes); + /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Get a value of the parameter of the grouped tensor. * @@ -527,7 +528,9 @@ void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTenso * * \return NVTEBasicTensor containing the parameter data. */ -void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name, void *buf, size_t size_in_bytes, size_t *size_written); +void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, + NVTEGroupedTensorParam param_name, void *buf, + size_t size_in_bytes, size_t *size_written); /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Get the number of tensors in a grouped tensor. 
@@ -992,9 +995,9 @@ class TensorWrapper { */ <<<<<<< HEAD - class GroupedTensorWrapper { - public: - /*! \brief Constructs new GroupedTensorWrapper. +class GroupedTensorWrapper { + public: + /*! \brief Constructs new GroupedTensorWrapper. * * Create a new TE grouped tensor with a given logical shape. * TE grouped tensors are just wrappers on top of raw data and do not @@ -1004,11 +1007,11 @@ class TensorWrapper { * \param[in] logical_shape Logical 2D shape of the grouped data. * \param[in] scaling_mode Tensor data format. */ - GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, - const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} - - /*! \brief Constructs new GroupedTensorWrapper. + GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} + + /*! \brief Constructs new GroupedTensorWrapper. * * Create a new TE grouped tensor with a given logical shape. * @@ -1016,194 +1019,196 @@ class TensorWrapper { * \param[in] logical_shape Logical 2D shape of the grouped data. * \param[in] scaling_mode Tensor data format. */ - GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, - const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : GroupedTensorWrapper(num_tensors, - nvte_make_shape(logical_shape.data(), logical_shape.size()), - scaling_mode) {} - - /*! \brief GroupedTensorWrapper destructor. */ - ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } - - GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; - GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; - - /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. 
*/ - GroupedTensorWrapper(GroupedTensorWrapper &&other) { - tensor_ = other.tensor_; - other.tensor_ = nullptr; - } - - /*! \brief Assign the data from existing GroupedTensorWrapper. */ - GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { - if (this == &other) return *this; - nvte_destroy_grouped_tensor(tensor_); - tensor_ = other.tensor_; - other.tensor_ = nullptr; - return *this; - } - - // Parameter setters - template - GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, - const ShapeType &shape) noexcept { - NVTEShape nvte_shape = this->convertShape(shape); - NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; - nvte_set_grouped_tensor_param(&tensor_, param, &data); - return *this; - } - - template - GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedScale, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedAmax, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, - const ShapeType &shape) noexcept 
{ - return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : GroupedTensorWrapper(num_tensors, + nvte_make_shape(logical_shape.data(), logical_shape.size()), + scaling_mode) {} + + /*! \brief GroupedTensorWrapper destructor. */ + ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } + + GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; + GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; + + /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ + GroupedTensorWrapper(GroupedTensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing GroupedTensorWrapper. 
*/ + GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_grouped_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_grouped_tensor_param(&tensor_, param, &data); + return *this; + } + + template + GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedScale, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); - } + return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); + } + + template + 
GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); + } - void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) { + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) { const auto val = static_cast(with_gemm_swizzled_scales); - nvte_set_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val)); - } - - // Parameter getters - NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { - return nvte_get_grouped_tensor_param(tensor_, param); - } - - NVTEBasicTensor get_rowwise_data() const noexcept { - return get_parameter(kNVTEGroupedRowwiseData); - } - - NVTEBasicTensor get_columnwise_data() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseData); - } - - NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } - - NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } - - NVTEBasicTensor get_rowwise_scale_inv() const noexcept { - return get_parameter(kNVTEGroupedRowwiseScaleInv); - } - - NVTEBasicTensor get_columnwise_scale_inv() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseScaleInv); - } - - NVTEBasicTensor get_columnwise_amax() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseAmax); - } - - NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } - - NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } - - NVTEBasicTensor 
get_tensor_offsets() const noexcept { - return get_parameter(kNVTEGroupedTensorOffsets); - } - - bool get_with_gemm_swizzled_scales() const { + nvte_set_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, + sizeof(val)); + } + + // Parameter getters + NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { + return nvte_get_grouped_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { + return get_parameter(kNVTEGroupedRowwiseData); + } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedRowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseAmax); + } + + NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } + + NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } + + NVTEBasicTensor get_tensor_offsets() const noexcept { + return get_parameter(kNVTEGroupedTensorOffsets); + } + + bool get_with_gemm_swizzled_scales() const { uint8_t val = 0; - nvte_get_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val), nullptr); + nvte_get_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val), + nullptr); return static_cast(val); } - /*! \brief Get an underlying NVTEGroupedTensor. + /*! \brief Get an underlying NVTEGroupedTensor. * * \return NVTEGroupedTensor held by this GroupedTensorWrapper. 
*/ - NVTEGroupedTensor data() const noexcept { return tensor_; } - - /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ - size_t num_tensors() const noexcept { - if (tensor_ == nullptr) return 0; - return nvte_grouped_tensor_num_tensors(tensor_); - } - - /*! \brief Get the data type of this GroupedTensorWrapper. */ - DType dtype() const noexcept { - if (tensor_ == nullptr) return DType::kNumTypes; - return static_cast(nvte_grouped_tensor_type(tensor_)); - } - - /*! \brief Get a scaling mode of the grouped tensor. */ - NVTEScalingMode scaling_mode() const noexcept { - if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; - return nvte_grouped_tensor_scaling_mode(tensor_); - } - - /*! \brief Get the logical shape of this GroupedTensorWrapper. */ - const NVTEShape logical_shape() const noexcept { - if (tensor_ == nullptr) { - return emptyShape; - } - return nvte_get_grouped_tensor_logical_shape(tensor_); - } - - static constexpr size_t defaultData = 1; - static constexpr NVTEShape defaultShape = { - {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; - static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; - - private: - NVTEShape convertShape(const NVTEShape &s) { return s; } - - NVTEShape convertShape(const std::vector &s) { - return nvte_make_shape(s.data(), s.size()); - } - - /*! \brief Wrapped NVTEGroupedTensor. */ - NVTEGroupedTensor tensor_ = nullptr; - }; - + NVTEGroupedTensor data() const noexcept { return tensor_; } + + /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ + size_t num_tensors() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_grouped_tensor_num_tensors(tensor_); + } + + /*! \brief Get the data type of this GroupedTensorWrapper. */ + DType dtype() const noexcept { + if (tensor_ == nullptr) return DType::kNumTypes; + return static_cast(nvte_grouped_tensor_type(tensor_)); + } + + /*! \brief Get a scaling mode of the grouped tensor. 
*/ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_grouped_tensor_scaling_mode(tensor_); + } + + /*! \brief Get the logical shape of this GroupedTensorWrapper. */ + const NVTEShape logical_shape() const noexcept { + if (tensor_ == nullptr) { + return emptyShape; + } + return nvte_get_grouped_tensor_logical_shape(tensor_); + } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = { + {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } + + /*! \brief Wrapped NVTEGroupedTensor. */ + NVTEGroupedTensor tensor_ = nullptr; +}; + /*! \warning Deprecated */ enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; ======= diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 12be47b638..18577b0eb4 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -87,6 +87,7 @@ class Recipe: """ Base recipe class. 
""" + @classmethod def nvfp4(cls): """Whether the given recipe is NVFP4 1D block scaling.""" diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index c7bbe4d974..be7521ccd4 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1359,116 +1359,116 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) return t.logical_shape; } -void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, const void *buf, - size_t size_in_bytes) { -// Check attribute and buffer -NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", static_cast(param), -")"); -NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL."); -auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); - -// Read from buffer -switch (param) { -case kNVTEGroupedRowwiseData: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.data = *basic_tensor; -break; -} -case kNVTEGroupedColumnwiseData: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.columnwise_data = *basic_tensor; -break; -} -case kNVTEGroupedScale: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.scale = *basic_tensor; -break; -} -case kNVTEGroupedAmax: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.amax = *basic_tensor; -break; -} -case kNVTEGroupedRowwiseScaleInv: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.scale_inv = *basic_tensor; -break; -} -case kNVTEGroupedColumnwiseScaleInv: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.columnwise_scale_inv = *basic_tensor; -break; -} -case kNVTEGroupedColumnwiseAmax: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.columnwise_amax = *basic_tensor; -break; -} -case kNVTEGroupedWithGEMMSwizzledScales: -t.with_gemm_swizzled_scales = 
static_cast(*reinterpret_cast(buf)); -break; -default: -NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); -} +void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + const void *buf, size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", + static_cast(param), ")"); + NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL."); + auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); + + // Read from buffer + switch (param) { + case kNVTEGroupedRowwiseData: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.data = *basic_tensor; + break; + } + case kNVTEGroupedColumnwiseData: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_data = *basic_tensor; + break; + } + case kNVTEGroupedScale: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.scale = *basic_tensor; + break; + } + case kNVTEGroupedAmax: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.amax = *basic_tensor; + break; + } + case kNVTEGroupedRowwiseScaleInv: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.scale_inv = *basic_tensor; + break; + } + case kNVTEGroupedColumnwiseScaleInv: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_scale_inv = *basic_tensor; + break; + } + case kNVTEGroupedColumnwiseAmax: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_amax = *basic_tensor; + break; + } + case kNVTEGroupedWithGEMMSwizzledScales: + t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); + break; + default: + NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); + } } -void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, void *buf, - size_t size_in_bytes, size_t *size_written) { 
-using namespace transformer_engine; +void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + void *buf, size_t size_in_bytes, size_t *size_written) { + using namespace transformer_engine; -// Check param -NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", static_cast(param), -")"); + // Check param + NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", + static_cast(param), ")"); -// Return immediately if buffer is not provided -if (buf == nullptr) { -return; -} + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } -// Get C++ tensor -const GroupedTensor *t = convertNVTEGroupedTensor(tensor); + // Get C++ tensor + const GroupedTensor *t = convertNVTEGroupedTensor(tensor); -// Write to buffer -switch (param) { -case kNVTEGroupedRowwiseData: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->data); -break; -} -case kNVTEGroupedColumnwiseData: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->columnwise_data); -break; -} -case kNVTEGroupedScale: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->scale); -break; -} -case kNVTEGroupedAmax: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->amax); -break; -} -case kNVTEGroupedRowwiseScaleInv: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->scale_inv); -break; -} -case kNVTEGroupedColumnwiseScaleInv: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->columnwise_scale_inv); -break; -} -case kNVTEGroupedColumnwiseAmax: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->columnwise_amax); -break; -} -case kNVTEGroupedWithGEMMSwizzledScales: -*reinterpret_cast(buf) = 
static_cast(t->with_gemm_swizzled_scales); -break; -default: -NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); -} + // Write to buffer + switch (param) { + case kNVTEGroupedRowwiseData: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->data); + break; + } + case kNVTEGroupedColumnwiseData: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_data); + break; + } + case kNVTEGroupedScale: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->scale); + break; + } + case kNVTEGroupedAmax: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->amax); + break; + } + case kNVTEGroupedRowwiseScaleInv: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->scale_inv); + break; + } + case kNVTEGroupedColumnwiseScaleInv: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_scale_inv); + break; + } + case kNVTEGroupedColumnwiseAmax: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_amax); + break; + } + case kNVTEGroupedWithGEMMSwizzledScales: + *reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); + break; + default: + NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); + } } diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 08b9dca6d7..e3aacbf2e6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -175,15 +175,26 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, 
tensor3] ] - assert qkv_layout == "sbhd_sbhd_sbhd", "sbhd_sbhd_sbhd is assumed to be the shape always at this point in UnfusedDotProductAttention." + assert qkv_layout == "sbhd_sbhd_sbhd", ( + "sbhd_sbhd_sbhd is assumed to be the shape always at this point in" + " UnfusedDotProductAttention." + ) q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, query_layer, key_layer, value_layer, quantizer ) tensors = combine_and_dequantize( - qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype, des_nominal_dtype=query_layer.dtype + qkv_layout, + q_fp8, + k_fp8, + v_fp8, + src_nominal_dtype=query_layer.dtype, + des_nominal_dtype=query_layer.dtype, ) if isinstance(quantizer, MXFP8Quantizer): - assert qkv_layout == "bhsd_bhsd_bhsd", "bhsd_bhsd_bhsd is assumed to be the shape always at this point in UnfusedDotProductAttention." + assert qkv_layout == "bhsd_bhsd_bhsd", ( + "bhsd_bhsd_bhsd is assumed to be the shape always at this point in" + " UnfusedDotProductAttention." + ) # permute back to sbhd_sbhd_sbhd tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: @@ -217,7 +228,10 @@ def backward(ctx, grad1, grad2, grad3): ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) if isinstance(ctx.quantizer, MXFP8Quantizer): - assert ctx.qkv_layout == "bhsd_bhsd_bhsd", "bhsd_bhsd_bhsd is assumed to be the shape always at this point in UnfusedDotProductAttention." + assert ctx.qkv_layout == "bhsd_bhsd_bhsd", ( + "bhsd_bhsd_bhsd is assumed to be the shape always at this point in" + " UnfusedDotProductAttention." 
+ ) # permute back to sbhd_sbhd_sbhd tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: @@ -469,7 +483,14 @@ def forward( fp8_dtype=dP_quantizer.dtype, device="cuda" ) # disable swizzle for MXFP8Quantizer - for q in [QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer]: + for q in [ + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ]: if isinstance(q, MXFP8Quantizer): q.optimize_for_gemm = False q.internal = False @@ -477,11 +498,21 @@ def forward( # q, k, v are in sbhd after previous reshaping # quantize and dequantize QKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", "sbhd_sbhd_sbhd" + query_layer, + key_layer, + value_layer, + QKV_quantizer, + "QKV_quantizer", + "sbhd_sbhd_sbhd", ) # quantize and dequantize dQKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", "sbhd_sbhd_sbhd" + query_layer, + key_layer, + value_layer, + dQKV_quantizer, + "dQKV_quantizer", + "sbhd_sbhd_sbhd", ) # [sq, b, np, hn] -> [sq, b * np, hn] @@ -1250,7 +1281,9 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) # print quantizers print_quantizers( @@ -1335,7 +1368,9 @@ def forward( fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) if is_bwd_fp8: - if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or isinstance(QKV_quantizer, MXFP8Quantizer): + if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or isinstance( + QKV_quantizer, MXFP8Quantizer + ): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) qkvo_tensors = (None, None, None, out) 
else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 60b3f7fe71..de5563b7a6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -60,6 +60,7 @@ # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" + def get_bsh_dims(tensor_format): """Get batch dimension and sequence dimension from tensor format""" if tensor_format in ["bshd", "sbhd", "bhsd"]: @@ -71,6 +72,7 @@ def get_bsh_dims(tensor_format): head_dim = tensor_format.index("h") return batch_dim, seq_dim, head_dim + def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm ): @@ -468,7 +470,12 @@ def flash_attn_a2a_communicate( # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] # or [b, np, s, hn] -> [b, cp, np//cp, s, hn] # or [t, np, hn] -> [t, cp, np//cp, hn] - x = x.view(*x.shape[:head_dim], cp_size, x.shape[head_dim] // cp_size, *x.shape[head_dim + 1:]) + x = x.view( + *x.shape[:head_dim], + cp_size, + x.shape[head_dim] // cp_size, + *x.shape[head_dim + 1 :], + ) # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] # or [b, cp, np//cp, s, hn] -> [cp, b, np//cp, s, hn] @@ -505,7 +512,7 @@ def flash_attn_a2a_communicate( # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] # or [cp, 2, b, np//cp, s//2, hn] -> [b, cp, np//cp, 2, s//2, hn] # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] - x = x.movedim(0, head_dim+1).movedim(0, seq_dim+1).contiguous() + x = x.movedim(0, head_dim + 1).movedim(0, seq_dim + 1).contiguous() # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] # or [b, cp, np//cp, 2, 
s//2, hn] -> [b*np, s, hn] @@ -897,7 +904,9 @@ def cp_p2p_fwd_fused_attn( for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] else: - q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step @@ -934,7 +943,7 @@ def cp_p2p_fwd_fused_attn( if return_max_logit: return out_per_step, softmax_lse_per_step, rng_states, attn_bias, *max_logit - return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None #, new_qkv_layout + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None # , new_qkv_layout def cp_p2p_fwd_flash_attn( @@ -1165,7 +1174,9 @@ def cp_p2p_bwd_fused_attn( ) ] else: - q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step) + q_part, k_part, v_part, qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step + ) if not fp8_recipe.mxfp8(): if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) @@ -1419,7 +1430,10 @@ def forward( q_fp8, k_fp8, v_fp8 = (None, None, None) # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: - print(f">>>>>>======================>>>>>> {torch.cuda.current_device()}: fp8: {fp8}, is_input_fp8: {is_input_fp8}, fp8_recipe.mxfp8(): {fp8_recipe.mxfp8()}") + print( + f">>>>>>======================>>>>>> {torch.cuda.current_device()}: fp8: {fp8}," + f" is_input_fp8: {is_input_fp8}, fp8_recipe.mxfp8(): {fp8_recipe.mxfp8()}" + ) if fp8 and is_input_fp8: QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v @@ -1450,7 +1464,9 @@ def forward( # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: 
torch.Tensor, dtype=torch.uint8 q_f16 = q - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] @@ -1901,7 +1917,10 @@ def forward( softmax_lse_per_step[0], seq_dim, ) - print(f"====o/v===== {torch.cuda.current_device()}: i: {i}, {enable_mla}, out.shape: {out.shape} {out_per_step[0].shape} {v_shape} {o_shape}") + print( + f"====o/v===== {torch.cuda.current_device()}: i: {i}, {enable_mla}," + f" out.shape: {out.shape} {out_per_step[0].shape} {v_shape} {o_shape}" + ) if enable_mla: out = out.view(o_shape) else: @@ -1952,7 +1971,10 @@ def forward( ctx.batch_size = out.shape[1] print(f"========= {torch.cuda.current_device()}: out.shape: {out.shape} {out.dtype}") out_part = out.to(fwd_nominal_dtype) - print(f"========= {torch.cuda.current_device()}: out_part.shape: {out_part.shape} {out_part.dtype}") + print( + f"========= {torch.cuda.current_device()}: out_part.shape:" + f" {out_part.shape} {out_part.dtype}" + ) if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) @@ -1997,7 +2019,11 @@ def forward( out_f16 = out.to(fwd_nominal_dtype) if fp8 and ( is_output_fp8 - or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) and not fp8_recipe.mxfp8()) + or ( + is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + and not fp8_recipe.mxfp8() + ) ): out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 @@ -2052,8 +2078,14 @@ def forward( kv_f16 = kv f16_tensors = (q_f16, kv_f16, out_f16) if torch.cuda.current_device() == 0: - print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}") - print(f"f16_tensors: {[x.shape if x is not None else None for x in 
f16_tensors]} {[type(x) for x in f16_tensors]}") + print( + "fp8_tensors:" + f" {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}" + ) + print( + "f16_tensors:" + f" {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}" + ) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -2130,7 +2162,12 @@ def backward(ctx, dout, *_args): # dout is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): + if ( + ctx.fp8 + and ctx.is_output_fp8 + and not isinstance(dout, QuantizedTensorStorage) + and not ctx.fp8_recipe.mxfp8() + ): dout = ctx.dO_quantizer(dout) if ctx.use_fused_attention: dout._data = dout._data.contiguous() @@ -2308,7 +2345,9 @@ def backward(ctx, dout, *_args): # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() if ctx.dP_quantizer is not None else None + dP_quantizer_per_step[i] = ( + ctx.dP_quantizer.copy() if ctx.dP_quantizer is not None else None + ) dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() if not ctx.fp8_recipe.mxfp8(): dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) @@ -2328,7 +2367,10 @@ def backward(ctx, dout, *_args): # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: - print(f"========= {torch.cuda.current_device()}: before a2a: out.shape: {out.shape} {out.dtype} dout.shape: {dout.shape} {dout.dtype}") + print( + f"========= {torch.cuda.current_device()}: before a2a: out.shape:" + f" {out.shape} {out.dtype} dout.shape: {dout.shape} {dout.dtype}" + ) if not ctx.use_fused_attention: # out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = 
dout.view(*out.shape) @@ -2344,7 +2386,10 @@ def backward(ctx, dout, *_args): ctx.cp_stream, True, ) - print(f"========= {torch.cuda.current_device()}: after a2a: dout.shape: {dout.shape} {dout.dtype} {q.shape} {ctx.v_shape}") + print( + f"========= {torch.cuda.current_device()}: after a2a: dout.shape:" + f" {dout.shape} {dout.dtype} {q.shape} {ctx.v_shape}" + ) if ctx.enable_mla: out = out.view(*ctx.o_shape) @@ -2449,7 +2494,8 @@ def backward(ctx, dout, *_args): kv_fp8, ( out - if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or ctx.fp8_recipe.mxfp8() + if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or ctx.fp8_recipe.mxfp8() else out_fp8 ), dout_fp8 if not ctx.fp8_recipe.mxfp8() else dout, @@ -3010,7 +3056,9 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v elif not fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] fp8_meta_kwargs["s_quantizer"] = S_quantizer @@ -3090,11 +3138,19 @@ def forward( new_qkv_layout = qkv_layout if fp8: if not fp8_recipe.mxfp8(): - q_part = Float8Tensor.make_like(q_fp8, data=q_, dtype=fwd_nominal_dtype) - k_part = Float8Tensor.make_like(k_fp8, data=k_, dtype=fwd_nominal_dtype) - v_part = Float8Tensor.make_like(v_fp8, data=v_, dtype=fwd_nominal_dtype) + q_part = Float8Tensor.make_like( + q_fp8, data=q_, dtype=fwd_nominal_dtype + ) + k_part = Float8Tensor.make_like( + k_fp8, data=k_, dtype=fwd_nominal_dtype + ) + v_part = Float8Tensor.make_like( + v_fp8, data=v_, dtype=fwd_nominal_dtype + ) else: - q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) ( out_per_step[i], 
aux_ctx_tensors, @@ -3217,8 +3273,14 @@ def forward( else: f16_tensors = (q, k, v, out) if torch.cuda.current_device() == 0: - print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}") - print(f"f16_tensors: {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}") + print( + "fp8_tensors:" + f" {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}" + ) + print( + "f16_tensors:" + f" {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}" + ) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -3282,8 +3344,14 @@ def backward(ctx, dout, *_args): softmax_lse_per_step = [None, None] rng_states = [None, None] ( - q_fp8, k_fp8, v_fp8, out_fp8, - q, k, v, out, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + q, + k, + v, + out, cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv_per_step[0], @@ -3320,8 +3388,16 @@ def backward(ctx, dout, *_args): if torch.cuda.current_device() == 0: print(f"ctx.q_shape: {ctx.q_shape} {ctx.k_shape} {ctx.v_shape}") dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) - dk = torch.zeros((ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=k.device) - dv = torch.zeros((ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=v.device) + dk = torch.zeros( + (ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=k.device, + ) + dv = torch.zeros( + (ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=v.device, + ) if torch.cuda.current_device() == 0: print(f"dq: {dq.shape} {dq.dtype} {dq.device}") print(f"dk: {dk.shape} {dk.dtype} {dk.device}") @@ -3400,7 +3476,11 @@ def backward(ctx, dout, *_args): out_ = out_per_step[i] dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: - aux_ctx_tensors = 
[softmax_lse_per_step[i], softmax_lse_per_step[i], rng_states[i]] + aux_ctx_tensors = [ + softmax_lse_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + ] q_part, k_part, v_part, out_part, dout_part = q_, k_, v_, out_, dout_ fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8_meta_kwargs = {} @@ -3421,17 +3501,34 @@ def backward(ctx, dout, *_args): v_part = Float8Tensor.make_like( v_fp8, data=v_, dtype=ctx.fwd_nominal_dtype ) - if not (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + if not ( + ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ): out_part = ctx.O_quantizer(out_part) - dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype) + dout_part = Float8Tensor.make_like( + dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype + ) else: - q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer) - print(f"aux_ctx_tensors: {len(aux_ctx_tensors)} {[x.shape if x is not None else None for x in aux_ctx_tensors]} {[type(x) for x in aux_ctx_tensors]}") - dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout_part) + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer + ) + print( + "aux_ctx_tensors:" + f" {len(aux_ctx_tensors)} {[x.shape if x is not None else None for x in aux_ctx_tensors]} {[type(x) for x in aux_ctx_tensors]}" + ) + dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor( + d_out_format, dout_part + ) aux_ctx_tensors.append(dout_part) dout_part = ctx.dO_quantizer(dout_part) - print(f"q_part type: {type(q_part)} {type(k_part)} {type(v_part)} {type(out_part)} {type(dout_part)}") - print(f"q_part shape: {q_part.shape} {k_part.shape} {v_part.shape} {out_part.shape} {dout_part.shape}") + print( + "q_part type:" + f" {type(q_part)} {type(k_part)} {type(v_part)} {type(out_part)} {type(dout_part)}" 
+ ) + print( + "q_part shape:" + f" {q_part.shape} {k_part.shape} {v_part.shape} {out_part.shape} {dout_part.shape}" + ) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, @@ -3463,7 +3560,11 @@ def backward(ctx, dout, *_args): ) if ctx.fp8: dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ - x.dequantize(dtype=ctx.fwd_nominal_dtype) if isinstance(x, QuantizedTensorStorage) else x + ( + x.dequantize(dtype=ctx.fwd_nominal_dtype) + if isinstance(x, QuantizedTensorStorage) + else x + ) for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] ] else: @@ -3506,7 +3607,10 @@ def backward(ctx, dout, *_args): if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if torch.cuda.current_device() == 0: - print(f"dq.shape: {dq.shape} dq_per_step[i - 1].shape: {dq_per_step[i - 1].shape}") + print( + f"dq.shape: {dq.shape} dq_per_step[i - 1].shape:" + f" {dq_per_step[i - 1].shape}" + ) if ctx.qkv_format == "bshd": dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": @@ -3717,13 +3821,15 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v elif not fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] # else: # q, k, v = [q_fp8, k_fp8, v_fp8] - # qkv_format, _, _ = dpa_utils.get_qkv_format(qkv_layout) - # batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + # qkv_format, _, _ = dpa_utils.get_qkv_format(qkv_layout) + # batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["o_quantizer"] = O_quantizer @@ -3763,7 +3869,9 @@ def forward( for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] if fp8 and fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout = 
combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, @@ -3995,7 +4103,7 @@ def backward(ctx, dout, *_args): dout_fp8 = dout if not ctx.fp8_recipe.mxfp8(): # dqkv_te_dtype = dout._fp8_dtype - dout = dout._data + dout = dout._data fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer @@ -4075,7 +4183,9 @@ def backward(ctx, dout, *_args): q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 - if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or ctx.fp8_recipe.mxfp8(): + if ( + ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or ctx.fp8_recipe.mxfp8(): out_part = out if not ctx.fp8_recipe.mxfp8(): dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) @@ -4174,7 +4284,9 @@ def backward(ctx, dout, *_args): ) if ctx.fp8: - if (ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8()) and ctx.is_input_fp8: + if ( + ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() + ) and ctx.is_input_fp8: dq, dk, dv = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 5b32f35be0..9a8d38547e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2209,21 +2209,21 @@ def print_quantizers( f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" ) else: - print( - f"{label} >> 
{names[i]:14s}: {type_str}" - ) + print(f"{label} >> {names[i]:14s}: {type_str}") + def permute_to_grouped_tensor(src_format, tensor): """Permute tensor to bhsd or htd format for grouped quantization in MXFP8BlockScaling. src_format ={bshd, sbhd, thd}""" if src_format in ["bhsd", "htd"]: return tensor, src_format tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor - dim_s_or_t = src_format.find("s") if 's' in src_format else src_format.find("t") + dim_s_or_t = src_format.find("s") if "s" in src_format else src_format.find("t") dim_others = [i for i in range(len(tensor.shape)) if i != dim_s_or_t] perm = [*dim_others[:-1], dim_s_or_t, dim_others[-1]] tensor = tensor.permute(*perm).contiguous() return tensor, "bhsd" if src_format != "thd" else "htd" + def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" # 1: qkv packed, 2: kv packed, 3: qkv separate @@ -2244,7 +2244,10 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): original_shapes = [x.shape for x in [q, k, v]] s_q, d_qk = q.shape[-2:] s_kv, d_v = v.shape[-2:] - print(f">>>>>>>>>>>> {torch.cuda.current_device()}: {qkv_layout} s_q: {s_q}, d_qk: {d_qk}, s_kv: {s_kv}, d_v: {d_v}, {q.shape}, {k.shape}, {v.shape}") + print( + f">>>>>>>>>>>> {torch.cuda.current_device()}: {qkv_layout} s_q: {s_q}, d_qk: {d_qk}," + f" s_kv: {s_kv}, d_v: {d_v}, {q.shape}, {k.shape}, {v.shape}" + ) assert s_q % 128 == 0 assert s_kv % 128 == 0 assert d_qk % 32 == 0 @@ -2254,7 +2257,9 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): # consider bhsd for now if d_qk == d_v: - grouped_tensor = GroupedTensor.create_and_quantize(tensors=[q, k, v], quantizer=qkv_quantizer) + grouped_tensor = GroupedTensor.create_and_quantize( + tensors=[q, k, v], quantizer=qkv_quantizer + ) q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors else: q_fp8 = qkv_quantizer(q) diff --git a/transformer_engine/pytorch/csrc/common.h 
b/transformer_engine/pytorch/csrc/common.h index ebed38fc84..f757bbdfee 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -306,7 +306,7 @@ class MXFP8Quantizer : public Quantizer { * amax to be initialized to zero. */ std::pair create_unquantized_tensor_with_amax( - const std::vector& shape, DType dtype, std::optional data = std::nullopt); + const std::vector& shape, DType dtype, std::optional data = std::nullopt); std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 3d5ad2e598..95c985062a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -98,11 +98,13 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, //const DType dqkv_type, + const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, //const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp 
b/transformer_engine/pytorch/csrc/extensions/attention.cpp index fd193b0258..bd5a5a065f 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -153,8 +153,9 @@ std::vector fused_attn_fwd( auto o_shape_tmp = std::vector{q_shape.begin(), q_shape.end()}; o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; - size_t b=0, h=0, s=0, d=0, t=0; - nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), o_shape_tmp, o_format, o_shape, &b, &h, &s, &d, &t); + size_t b = 0, h = 0, s = 0, d = 0, t = 0; + nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), o_shape_tmp, o_format, o_shape, &b, &h, &s, + &d, &t); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -254,9 +255,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace and auxiliary output tensors @@ -314,9 +315,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), 
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -328,11 +329,13 @@ std::vector fused_attn_fwd( // fused attention BWD with separate Q, K and V std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, //const DType dqkv_type, + const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, //const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -365,11 +368,14 @@ std::vector fused_attn_bwd( // auto h_kv = k_shape[k_shape.size() - 2]; // auto d_qk = q_shape[q_shape.size() - 1]; const DType fake_dtype_te = 
GetTransformerEngineDType(fake_dtype); - size_t b=0, h_q=0, h_kv=0, s_q=0, s_kv=0, d_qk=0, d_v=0, t_q=0, t_kv=0; + size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; std::vector dQ_shape(4), dK_shape(4), dV_shape(4); - nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), q_shape, nvte_get_q_format(dqkv_layout), dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); - nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), k_shape, nvte_get_kv_format(dqkv_layout), dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); - nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), v_shape, nvte_get_kv_format(dqkv_layout), dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); + nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), q_shape, nvte_get_q_format(dqkv_layout), + dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); + nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), k_shape, nvte_get_kv_format(dqkv_layout), + dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); + nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), v_shape, nvte_get_kv_format(dqkv_layout), + dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); at::Tensor dQ, dK, dV, dQKV, dKV; DType dqkv_type = fake_dtype_te; @@ -380,7 +386,8 @@ std::vector fused_attn_bwd( if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { options = options.dtype(torch::kUInt8); } - if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { + if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || + detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { options = options.dtype(fake_dtype); } NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); @@ -447,7 +454,7 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - case NVTE_QKV_Layout_Group::NVTE_SD_SD_SD: + case NVTE_QKV_Layout_Group::NVTE_SD_SD_SD: tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = 
torch::empty(tmp_shape, options); tmp_shape = std::vector(dK_shape.begin(), dK_shape.end()); @@ -566,9 +573,9 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -583,9 +590,9 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 92239aafc0..b44640d006 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ 
b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1230,10 +1230,10 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } -std::pair MXFP8Quantizer::create_unquantized_tensor_with_amax(const std::vector& shape, - DType dtype, - std::optional data) { - at::Tensor amax_tensor = at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); +std::pair MXFP8Quantizer::create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data) { + at::Tensor amax_tensor = + at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) : NoneQuantizer(py::none()).create_tensor(shape, dtype); TensorWrapper out_cpp = std::move(out.first); diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index 255a9ecfd3..466429cf3f 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -975,4 +975,4 @@ def quantize( self.quantized_tensors = self.split_into_quantized_tensors() for i in range(self.num_tensors): self.quantizer.update_quantized(tensors[i], self.quantized_tensors[i], noop_flag=noop_flag) - return self.quantized_tensors \ No newline at end of file + return self.quantized_tensors From 81c18fa8fc854e5a977581f815e94dc4097d955e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 1 Mar 2026 15:05:21 -0800 Subject: [PATCH 043/172] fix merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_grouped_tensor.py | 2 +- .../cast/mxfp8/group_quantize_mxfp8.cuh | 12 -- .../common/fused_attn/fused_attn_fp8.cu | 1 - .../transformer_engine/transformer_engine.h | 29 +---- .../common/transformer_engine.cpp | 114 
------------------ .../pytorch/csrc/extensions/attention.cpp | 6 +- .../pytorch/tensor/storage/grouped_tensor.py | 6 +- 7 files changed, 9 insertions(+), 161 deletions(-) diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 31d84933de..de00d0cf35 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -393,7 +393,7 @@ def test_quantize_grouped_mxfp8(self) -> None: device="cuda", ) # Quantize using grouped API (handle both 2-arg and 3-arg bindings) - _ = tex.quantize_grouped(grouped_input, grouped_output) + _ = tex.group_quantize(grouped_input, grouped_output) # Build expected output by quantizing each tensor independently expected_data = [] expected_scale_inv = [] diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index c9816494bb..f454209409 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -787,7 +787,6 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations using namespace group_quantize_kernel; checkCuDriverContext(stream); - // CheckNoopTensor(*noop, "cast_noop"); const bool use_rowwise_scaling = output->has_data(); const bool use_colwise_scaling = output->has_columnwise_data(); @@ -800,13 +799,6 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations } else if (!use_rowwise_scaling) { scaling_type = ScalingType::COLWISE; } - // if (use_rowwise_scaling && (!use_colwise_scaling)) { - // scaling_type = ScalingType::ROWWISE; - // } else if ((!use_rowwise_scaling) && use_colwise_scaling) { - // scaling_type = ScalingType::COLWISE; - // } else if (use_rowwise_scaling && use_colwise_scaling) { - // scaling_type = ScalingType::BIDIMENSIONAL; - // } ShapeRepresentation shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; if (output->all_same_shape()) { @@ 
-886,10 +878,6 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); } - const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; - const size_t scale_stride_colwise = - use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); const size_t dbias_cols = last_logical_dim; if constexpr (IS_DBIAS) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 237f3bd66e..48c3975264 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2084,7 +2084,6 @@ void fused_attn_fp8_bwd_impl_v1( bool is_dropout = (dropout_probability != 0.0f); auto bias_b = b; auto bias_h = h; - const auto cudnn_runtime_version = cudnnGetVersion(); auto bias_sq = s_q; auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index a6cb036a35..3dacc596c8 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -449,13 +449,15 @@ enum NVTEGroupedTensorParam { kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */ kNVTEGroupedTensorOffsets = 9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */ +<<<<<<< HEAD <<<<<<< HEAD kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in format expected by GEMM */ ======= +======= +>>>>>>> 341cc3df (fix merge) kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in 
format expected by GEMM */ ->>>>>>> main kNVTENumGroupedTensorParams }; @@ -511,27 +513,6 @@ void nvte_set_grouped_tensor_param(NVTEGroupedTensor tensor, NVTEGroupedTensorPa void nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, void *buf, size_t size_in_bytes, size_t *size_written); -/*! \brief Set a parameter of the grouped tensor. - * - * \param[in/out] tensor Grouped tensor. - * \param[in] param_name The parameter to be set. - * \param[in] param The value to be set (NVTEBasicTensor). - */ -void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name, - const void *buf, size_t size_in_bytes); - -/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ -/*! \brief Get a value of the parameter of the grouped tensor. - * - * \param[in] tensor Grouped tensor. - * \param[in] param_name The parameter to be queried. - * - * \return NVTEBasicTensor containing the parameter data. - */ -void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, - NVTEGroupedTensorParam param_name, void *buf, - size_t size_in_bytes, size_t *size_written); - /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Get the number of tensors in a grouped tensor. * @@ -994,6 +975,7 @@ class TensorWrapper { * \brief C++ wrapper for the NVTEGroupedTensor class. */ +<<<<<<< HEAD <<<<<<< HEAD class GroupedTensorWrapper { public: @@ -1212,6 +1194,8 @@ class GroupedTensorWrapper { /*! \warning Deprecated */ enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; ======= +======= +>>>>>>> 341cc3df (fix merge) class GroupedTensorWrapper { public: /*! \brief Constructs new GroupedTensorWrapper. @@ -1437,7 +1421,6 @@ enum class Float8BlockScaleTensorFormat { COMPACT = 1, INVALID }; ->>>>>>> main /*! \struct QuantizationConfigWrapper * \brief C++ wrapper for NVTEQuantizationConfigWrapper. 
diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index be7521ccd4..cd02074fbd 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1358,117 +1358,3 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); return t.logical_shape; } - -void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, - const void *buf, size_t size_in_bytes) { - // Check attribute and buffer - NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", - static_cast(param), ")"); - NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL."); - auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); - - // Read from buffer - switch (param) { - case kNVTEGroupedRowwiseData: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.data = *basic_tensor; - break; - } - case kNVTEGroupedColumnwiseData: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.columnwise_data = *basic_tensor; - break; - } - case kNVTEGroupedScale: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.scale = *basic_tensor; - break; - } - case kNVTEGroupedAmax: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.amax = *basic_tensor; - break; - } - case kNVTEGroupedRowwiseScaleInv: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.scale_inv = *basic_tensor; - break; - } - case kNVTEGroupedColumnwiseScaleInv: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.columnwise_scale_inv = *basic_tensor; - break; - } - case kNVTEGroupedColumnwiseAmax: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.columnwise_amax = *basic_tensor; - break; - } - case kNVTEGroupedWithGEMMSwizzledScales: - 
t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); - break; - default: - NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); - } -} - -void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, - void *buf, size_t size_in_bytes, size_t *size_written) { - using namespace transformer_engine; - - // Check param - NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", - static_cast(param), ")"); - - // Return immediately if buffer is not provided - if (buf == nullptr) { - return; - } - - // Get C++ tensor - const GroupedTensor *t = convertNVTEGroupedTensor(tensor); - - // Write to buffer - switch (param) { - case kNVTEGroupedRowwiseData: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->data); - break; - } - case kNVTEGroupedColumnwiseData: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->columnwise_data); - break; - } - case kNVTEGroupedScale: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->scale); - break; - } - case kNVTEGroupedAmax: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->amax); - break; - } - case kNVTEGroupedRowwiseScaleInv: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->scale_inv); - break; - } - case kNVTEGroupedColumnwiseScaleInv: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->columnwise_scale_inv); - break; - } - case kNVTEGroupedColumnwiseAmax: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->columnwise_amax); - break; - } - case kNVTEGroupedWithGEMMSwizzledScales: - *reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); - break; - default: - NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); - 
} -} diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index bd5a5a065f..192a774ca0 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -166,8 +166,6 @@ std::vector fused_attn_fwd( TensorWrapper te_page_table_k, te_page_table_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - // auto h = q_shape[q_shape.size() - 2]; - // auto d = q_shape[q_shape.size() - 1]; if (set_zero && (o_format == NVTE_QKV_Format::NVTE_THD)) { if ((h * d) % block_size == 0) { mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -319,6 +317,7 @@ std::vector fused_attn_fwd( attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); + // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -364,9 +363,6 @@ std::vector fused_attn_bwd( std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - // auto h_q = q_shape[q_shape.size() - 2]; - // auto h_kv = k_shape[k_shape.size() - 2]; - // auto d_qk = q_shape[q_shape.size() - 1]; const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; std::vector dQ_shape(4), dK_shape(4), dV_shape(4); diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index 466429cf3f..ef91e58e7c 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -53,11 +53,7 @@ class GroupedTensor: def __init__( self, num_tensors: int, -<<<<<<< HEAD shapes: List[Tuple[int, 
int]], -======= - shape: Optional[List[Tuple[int, int]]] = None, ->>>>>>> main quantizer: Optional[Quantizer] = None, dtype: Optional[torch.dtype] = None, data: Optional[torch.Tensor] = None, @@ -959,7 +955,7 @@ def create_and_quantize( dtype=dtype, ) - _ = tex.quantize_grouped(grouped_input, grouped_output) + _ = tex.group_quantize(grouped_input, grouped_output) grouped_output.quantized_tensors = grouped_output.split_into_quantized_tensors() return grouped_output From 1f14f2fa6ba3493598d0386df5af7e01f74daa42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Mar 2026 23:08:21 +0000 Subject: [PATCH 044/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/storage/grouped_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index ef91e58e7c..c002067e11 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -970,5 +970,7 @@ def quantize( """ self.quantized_tensors = self.split_into_quantized_tensors() for i in range(self.num_tensors): - self.quantizer.update_quantized(tensors[i], self.quantized_tensors[i], noop_flag=noop_flag) + self.quantizer.update_quantized( + tensors[i], self.quantized_tensors[i], noop_flag=noop_flag + ) return self.quantized_tensors From ccebe771058024ac51cb335633136f6df779fb57 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 1 Mar 2026 15:17:17 -0800 Subject: [PATCH 045/172] fix merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../transformer_engine/transformer_engine.h | 229 ------------------ 1 file changed, 229 deletions(-) diff --git 
a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 3dacc596c8..635b9fdcce 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -449,13 +449,6 @@ enum NVTEGroupedTensorParam { kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */ kNVTEGroupedTensorOffsets = 9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */ -<<<<<<< HEAD -<<<<<<< HEAD - kNVTEGroupedWithGEMMSwizzledScales = - 10, /*!< Whether scaling factors are in format expected by GEMM */ -======= -======= ->>>>>>> 341cc3df (fix merge) kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in format expected by GEMM */ kNVTENumGroupedTensorParams @@ -974,228 +967,6 @@ class TensorWrapper { /*! \struct GroupedTensorWrapper * \brief C++ wrapper for the NVTEGroupedTensor class. */ - -<<<<<<< HEAD -<<<<<<< HEAD -class GroupedTensorWrapper { - public: - /*! \brief Constructs new GroupedTensorWrapper. - * - * Create a new TE grouped tensor with a given logical shape. - * TE grouped tensors are just wrappers on top of raw data and do not - * own memory. - * - * \param[in] num_tensors Number of tensors in the group (must be > 0). - * \param[in] logical_shape Logical 2D shape of the grouped data. - * \param[in] scaling_mode Tensor data format. - */ - GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, - const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} - - /*! \brief Constructs new GroupedTensorWrapper. - * - * Create a new TE grouped tensor with a given logical shape. - * - * \param[in] num_tensors Number of tensors in the group (must be > 0). 
- * \param[in] logical_shape Logical 2D shape of the grouped data. - * \param[in] scaling_mode Tensor data format. - */ - GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, - const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : GroupedTensorWrapper(num_tensors, - nvte_make_shape(logical_shape.data(), logical_shape.size()), - scaling_mode) {} - - /*! \brief GroupedTensorWrapper destructor. */ - ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } - - GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; - GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; - - /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ - GroupedTensorWrapper(GroupedTensorWrapper &&other) { - tensor_ = other.tensor_; - other.tensor_ = nullptr; - } - - /*! \brief Assign the data from existing GroupedTensorWrapper. */ - GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { - if (this == &other) return *this; - nvte_destroy_grouped_tensor(tensor_); - tensor_ = other.tensor_; - other.tensor_ = nullptr; - return *this; - } - - // Parameter setters - template - GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, - const ShapeType &shape) noexcept { - NVTEShape nvte_shape = this->convertShape(shape); - NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; - nvte_set_grouped_tensor_param(&tensor_, param, &data); - return *this; - } - - template - GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { - 
return set_parameter(kNVTEGroupedScale, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedAmax, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); - } - - void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) { - const auto val = static_cast(with_gemm_swizzled_scales); - nvte_set_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, - sizeof(val)); - } - - // Parameter getters - NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { - return nvte_get_grouped_tensor_param(tensor_, param); - } - - NVTEBasicTensor get_rowwise_data() const noexcept { - return get_parameter(kNVTEGroupedRowwiseData); - } - - NVTEBasicTensor get_columnwise_data() const noexcept { - return 
get_parameter(kNVTEGroupedColumnwiseData); - } - - NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } - - NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } - - NVTEBasicTensor get_rowwise_scale_inv() const noexcept { - return get_parameter(kNVTEGroupedRowwiseScaleInv); - } - - NVTEBasicTensor get_columnwise_scale_inv() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseScaleInv); - } - - NVTEBasicTensor get_columnwise_amax() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseAmax); - } - - NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } - - NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } - - NVTEBasicTensor get_tensor_offsets() const noexcept { - return get_parameter(kNVTEGroupedTensorOffsets); - } - - bool get_with_gemm_swizzled_scales() const { - uint8_t val = 0; - nvte_get_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val), - nullptr); - return static_cast(val); - } - - /*! \brief Get an underlying NVTEGroupedTensor. - * - * \return NVTEGroupedTensor held by this GroupedTensorWrapper. - */ - NVTEGroupedTensor data() const noexcept { return tensor_; } - - /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ - size_t num_tensors() const noexcept { - if (tensor_ == nullptr) return 0; - return nvte_grouped_tensor_num_tensors(tensor_); - } - - /*! \brief Get the data type of this GroupedTensorWrapper. */ - DType dtype() const noexcept { - if (tensor_ == nullptr) return DType::kNumTypes; - return static_cast(nvte_grouped_tensor_type(tensor_)); - } - - /*! \brief Get a scaling mode of the grouped tensor. */ - NVTEScalingMode scaling_mode() const noexcept { - if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; - return nvte_grouped_tensor_scaling_mode(tensor_); - } - - /*! 
\brief Get the logical shape of this GroupedTensorWrapper. */ - const NVTEShape logical_shape() const noexcept { - if (tensor_ == nullptr) { - return emptyShape; - } - return nvte_get_grouped_tensor_logical_shape(tensor_); - } - - static constexpr size_t defaultData = 1; - static constexpr NVTEShape defaultShape = { - {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; - static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; - - private: - NVTEShape convertShape(const NVTEShape &s) { return s; } - - NVTEShape convertShape(const std::vector &s) { - return nvte_make_shape(s.data(), s.size()); - } - - /*! \brief Wrapped NVTEGroupedTensor. */ - NVTEGroupedTensor tensor_ = nullptr; -}; - -/*! \warning Deprecated */ -enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; -======= -======= ->>>>>>> 341cc3df (fix merge) class GroupedTensorWrapper { public: /*! \brief Constructs new GroupedTensorWrapper. From c52c5f41aafba3dd642ce5449f79379406fad9e2 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 1 Mar 2026 15:33:26 -0800 Subject: [PATCH 046/172] revert to main grouped tensor impl Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_grouped_tensor.py | 117 +++++++++--------- .../pytorch/tensor/storage/grouped_tensor.py | 80 +++++------- 2 files changed, 91 insertions(+), 106 deletions(-) diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index de00d0cf35..ad08c0474d 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -121,11 +121,11 @@ def setup_class(cls) -> None: def test_basic_construction_all_same_shape(self) -> None: """Test GroupedTensor construction with all tensors having same shape""" num_tensors = 4 - shapes = [(256, 512) for _ in range(num_tensors)] + shape = [(256, 512) for _ in range(num_tensors)] grouped_tensor = 
GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=None, device="cuda", dtype=torch.float32, @@ -143,11 +143,11 @@ def test_basic_construction_all_same_shape(self) -> None: def test_basic_construction_varying_first_dim(self) -> None: """Test GroupedTensor construction with varying first dimension""" num_tensors = 3 - shapes = [(128, 512), (256, 512), (384, 512)] + shape = [(128, 512), (256, 512), (384, 512)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=None, device="cuda", dtype=torch.float32, @@ -157,20 +157,20 @@ def test_basic_construction_varying_first_dim(self) -> None: assert not grouped_tensor.all_same_shape() assert not grouped_tensor.all_same_first_dim() assert grouped_tensor.all_same_last_dim() - assert grouped_tensor.get_common_last_dim() == shapes[0][1] + assert grouped_tensor.get_common_last_dim() == shape[0][1] assert grouped_tensor.logical_shape == ( - sum(v for v, _ in shapes), - shapes[0][1], + sum(v for v, _ in shape), + shape[0][1], ) # sum of first dims def test_split_into_quantized_tensors_no_quantization(self) -> None: """Test split_into_quantized_tensors for unquantized tensors""" num_tensors = 3 - shapes = [(256, 512) for _ in range(num_tensors)] + shape = [(256, 512) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=None, device="cuda", dtype=torch.float32, @@ -186,7 +186,7 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None: # Verify each tensor has correct shape and shares storage for i, tensor in enumerate(tensors): - assert tensor.shape == shapes[i] + assert tensor.shape == shape[i] assert isinstance(tensor, torch.Tensor) assert not hasattr(tensor, "_data") # Not a quantized tensor @@ -195,19 +195,19 @@ def test_split_into_quantized_tensors_no_quantization(self) 
-> None: assert tensor.data_ptr() >= original_data_ptr # Calculate expected offset - expected_offset = i * (shapes[i][0] * shapes[i][1]) * tensor.element_size() + expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size() assert tensor.data_ptr() == original_data_ptr + expected_offset @pytest.mark.parametrize("quantization", _quantization_params) def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None: """Test split_into_quantized_tensors for quantized tensors""" num_tensors = 3 - shapes = [(512, 512) for _ in range(num_tensors)] - quantizer = make_quantizer(quantization, num_tensors, shapes) + shape = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=quantizer, device="cuda", ) @@ -225,18 +225,18 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None rowwise_data = _get_rowwise_data_tensor(tensor, quantization) assert rowwise_data is not None assert rowwise_data.data_ptr() >= original_data_ptr - numel = shapes[i][0] * shapes[i][1] + numel = shape[i][0] * shape[i][1] expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset def test_split_varying_shapes(self) -> None: """Test split_into_quantized_tensors with varying shapes""" num_tensors = 3 - shapes = [(128, 512), (256, 512), (384, 512)] + shape = [(128, 512), (256, 512), (384, 512)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=None, device="cuda", dtype=torch.float32, @@ -250,21 +250,21 @@ def test_split_varying_shapes(self) -> None: # Verify shapes and storage cumulative_offset = 0 for i, tensor in enumerate(tensors): - assert tensor.shape == shapes[i] + assert tensor.shape == shape[i] expected_offset 
= cumulative_offset * tensor.element_size() assert tensor.data_ptr() == original_data_ptr + expected_offset - cumulative_offset += shapes[i][0] * shapes[i][1] + cumulative_offset += shape[i][0] * shape[i][1] @pytest.mark.parametrize("quantization", _quantization_params) def test_quantize_inplace(self, quantization: str) -> None: """Test that quantize is done in-place for all recipes""" num_tensors = 3 - shapes = [(512, 512) for _ in range(num_tensors)] - quantizer = make_quantizer(quantization, num_tensors, shapes) + shape = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=quantizer, device="cuda", ) @@ -277,7 +277,7 @@ def test_quantize_inplace(self, quantization: str) -> None: ) # Create input tensors - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] # Quantize in place quantized_tensors = grouped_tensor.quantize(input_tensors) @@ -291,7 +291,7 @@ def test_quantize_inplace(self, quantization: str) -> None: # Verify returned tensors point to the same storage for i, qtensor in enumerate(quantized_tensors): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) - numel = shapes[i][0] * shapes[i][1] + numel = shape[i][0] * shape[i][1] expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset @@ -299,12 +299,12 @@ def test_quantize_inplace(self, quantization: str) -> None: def test_quantize_varying_shapes(self, quantization: str) -> None: """Test quantize with varying shapes""" num_tensors = 3 - shapes = [(256, 512), (512, 512), (768, 512)] - quantizer = make_quantizer(quantization, num_tensors, shapes) + shape = [(256, 512), (512, 512), (768, 512)] + quantizer = 
make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=quantizer, device="cuda", ) @@ -313,7 +313,7 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: original_data_ptr = grouped_tensor.data.data_ptr() # Create input tensors with varying shapes - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] # Quantize in place quantized_tensors = grouped_tensor.quantize(input_tensors) @@ -323,7 +323,7 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: # Verify each tensor points to correct location cumulative_numel = 0 - for qtensor, tensor_shape in zip(quantized_tensors, shapes): + for qtensor, tensor_shape in zip(quantized_tensors, shape): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset @@ -333,11 +333,11 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: def test_static_quantize_method(self, quantization: str) -> None: """Test the static quantize method""" num_tensors = 3 - shapes = [(512, 512) for _ in range(num_tensors)] - quantizer = make_quantizer(quantization, num_tensors, shapes) + shape = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shape) # Create input tensors - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] # Use static quantize method grouped_tensor = GroupedTensor.create_and_quantize( @@ -357,43 +357,44 @@ def test_static_quantize_method(self, quantization: str) -> None: original_data_ptr = grouped_tensor.data.data_ptr() for i, qtensor in 
enumerate(grouped_tensor.quantized_tensors): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) - numel = shapes[i][0] * shapes[i][1] + numel = shape[i][0] * shape[i][1] expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + @pytest.mark.parametrize( + "shape", + [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], + ) @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) - def test_quantize_grouped_mxfp8(self) -> None: + def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: """Test grouped quantization for MXFP8 against per-tensor quantization.""" # Test wont pass until the grouped quantization PR from Oleg is merged. num_tensors = 2 - shapes = [(512, 1024) for _ in range(num_tensors)] + shape = [(512, 1024) for _ in range(num_tensors)] + + # Create BF16 input tensors and pack into a 2D tensor + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + quantized_tensors = [ + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(tensor) for tensor in input_tensors + ] + grouped_input = torch.cat(input_tensors, dim=0) - # Create BF16 input tensors and pack into a grouped tensor - input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shapes] + # Create MXFP8 output grouped tensor (rowwise only for easier validation) quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) - quantizer.optimize_for_gemm = True - grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shapes=shapes, - quantizer=None, + quantizer.set_usage(rowwise=True, columnwise=False) + first_dims = torch.tensor( + [shape[0][0] for _ in range(num_tensors)], + dtype=torch.int64, device="cuda", - dtype=torch.bfloat16, ) - offset = 0 - for tensor in input_tensors: - numel = tensor.numel() - grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) - 
offset += numel - - grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shapes=shapes, - quantizer=quantizer, - device="cuda", + # Quantize using grouped API + grouped_output = tex.group_quantize( + grouped_input, + quantizer, + num_tensors, + first_dims, ) - # Quantize using grouped API (handle both 2-arg and 3-arg bindings) - _ = tex.group_quantize(grouped_input, grouped_output) # Build expected output by quantizing each tensor independently expected_data = [] expected_scale_inv = [] @@ -456,11 +457,11 @@ def test_group_quantize_cudagraph_capturable(self) -> None: def test_clear(self) -> None: """Test clear method""" num_tensors = 3 - shapes = [(256, 512) for _ in range(num_tensors)] + shape = [(256, 512) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=None, device="cuda", dtype=torch.float32, diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index c002067e11..bf5792ffc9 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -8,8 +8,7 @@ import math import torch -import transformer_engine -import transformer_engine_torch as tex + from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ..mxfp8_tensor import MXFP8Tensor @@ -53,7 +52,7 @@ class GroupedTensor: def __init__( self, num_tensors: int, - shapes: List[Tuple[int, int]], + shape: Optional[List[Tuple[int, int]]] = None, quantizer: Optional[Quantizer] = None, dtype: Optional[torch.dtype] = None, data: Optional[torch.Tensor] = None, @@ -76,7 +75,7 @@ def __init__( Args: num_tensors: Number of tensors in the group - shapes: 2D shape of each tensor (len num_tensors) + shape: 2D shape of each tensor (len num_tensors) quantizer: Quantizer for the grouped tensor data: Row-wise data 
buffer (1D flattened) columnwise_data: Column-wise data buffer (1D flattened) @@ -93,7 +92,7 @@ def __init__( """ self.num_tensors = num_tensors self.quantizer = quantizer - self.shapes = shapes + self.shape = shape self.dtype = ( dtype if dtype is not None else torch.float32 ) # Default to float32 if not provided @@ -269,7 +268,7 @@ def __repr__(self) -> str: """String representation of the GroupedTensor.""" return ( f"GroupedTensor(num_tensors={self.num_tensors}, " - f"shapes={self.shapes}, " + f"shape={self.shape}, " f"logical_shape={self.logical_shape}, " f"dtype={self.get_dtype()})" ) @@ -295,7 +294,7 @@ def __str__(self) -> str: @staticmethod def make_grouped_tensor_with_shapes( num_tensors: int, - shapes: List[Tuple[int, int]], + shape: List[Tuple[int, int]], quantizer: Optional[Quantizer] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, @@ -305,8 +304,8 @@ def make_grouped_tensor_with_shapes( Args: num_tensors: Number of tensors - shapes: 2D shape of each tensor (len num_tensors) - quantizer: Quantizer for the grouped tensor + shape: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer for each tensor device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -315,16 +314,16 @@ def make_grouped_tensor_with_shapes( """ # First dim - first_dim_list = [s[0] for s in shapes] + first_dim_list = [s[0] for s in shape] uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list) logical_first_dim = sum(first_dim_list) if uniform_first_dim: first_dims = None else: - first_dims = torch.tensor([s[0] for s in shapes], dtype=torch.int64, device=device) + first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device) # Last dim - last_dim_list = [s[1] for s in shapes] + last_dim_list = [s[1] for s in shape] logical_last_dim = last_dim_list[0] assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be 
uniform" @@ -359,7 +358,7 @@ def make_grouped_tensor( last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) logical_first_dim: Logical first dimension logical_last_dim: Logical last dimension - quantizer: Quantizer for the grouped tensor + quantizer: Quantizer for each tensor Used to figure out the recipe and what to allocate. device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -388,7 +387,7 @@ def make_grouped_tensor( # Calculate tensor offsets (cumulative element offsets) tensor_offsets = None offsets = None - shapes = [] + shape = [] if not all_same_first: # Need explicit offsets for non-uniform shapes # Offsets are based on number of elements and not pointers. @@ -404,14 +403,14 @@ def make_grouped_tensor( offsets = tensor_offsets.tolist() first_dims_list = first_dims.tolist() for i in range(num_tensors): - shapes.append((first_dims_list[i], logical_last_dim)) + shape.append((first_dims_list[i], logical_last_dim)) else: offsets = [ i * logical_first_dim * logical_last_dim // num_tensors for i in range(num_tensors + 1) ] for i in range(num_tensors): - shapes.append((logical_first_dim // num_tensors, logical_last_dim)) + shape.append((logical_first_dim // num_tensors, logical_last_dim)) # Calculate logical shape based logical_shape = (logical_first_dim, logical_last_dim) @@ -450,7 +449,7 @@ def make_grouped_tensor( # For grouped tensors, we need to calculate scale_inv size for all tensors total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): scale_inv_shape = quantizer.get_scale_shape(s, False) scale_elements = math.prod(scale_inv_shape) total_scale_elements += scale_elements @@ -463,7 +462,7 @@ def make_grouped_tensor( # Columnwise scale inverse buffer total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): 
scale_inv_shape = quantizer.get_scale_shape(s, False) columnwise_scale_elements = math.prod(scale_inv_shape) total_columnwise_scale_elements += columnwise_scale_elements @@ -499,7 +498,7 @@ def make_grouped_tensor( # For simplicity, calculate total scale elements needed total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): scale_inv_shape = quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) scale_inv_offsets.append(total_scale_elements) @@ -515,7 +514,7 @@ def make_grouped_tensor( # Columnwise scale inverse buffer total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) @@ -532,7 +531,7 @@ def make_grouped_tensor( # For simplicity, calculate total scale elements needed total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): scale_inv_shape = quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) scale_inv_offsets.append(total_scale_elements) @@ -544,7 +543,7 @@ def make_grouped_tensor( # Columnwise scale inverse total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) @@ -577,7 +576,7 @@ def make_grouped_tensor( grouped_tensor = GroupedTensor( num_tensors=num_tensors, - shapes=shapes, + shape=shape, dtype=dtype, quantizer=quantizer, data=data, @@ -646,7 +645,7 @@ def split_into_quantized_tensors( if no_quantization: for i 
in range(self.num_tensors): # Get tensor shape - tensor_shape = self.shapes[i] + tensor_shape = self.shape[i] # Get tensor data slice if self.offsets is not None: @@ -700,7 +699,7 @@ def split_into_quantized_tensors( for i in range(self.num_tensors): # Get tensor shape - tensor_shape = self.shapes[i] + tensor_shape = self.shape[i] numel = tensor_shape[0] * tensor_shape[1] # Get data offsets @@ -933,32 +932,18 @@ def create_and_quantize( Quantize given tensors into quantized tensors with underlying storage allocated in a GroupedTensor. """ - grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=len(tensors), - shapes=[t.shape for t in tensors], - quantizer=None, - device=device, - dtype=tensors[0].dtype, - ) - - offset = 0 - for tensor in tensors: - numel = tensor.numel() - grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) - offset += numel - grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=len(tensors), - shapes=[t.shape for t in tensors], + shape=[t.shape for t in tensors], quantizer=quantizer, device=device, dtype=dtype, ) - _ = tex.group_quantize(grouped_input, grouped_output) - grouped_output.quantized_tensors = grouped_output.split_into_quantized_tensors() + grouped_tensor.quantize(tensors, noop_flag=noop_flag) - return grouped_output + return grouped_tensor def quantize( self, @@ -968,9 +953,8 @@ def quantize( """ Quantize the GroupedTensor inplace. 
""" - self.quantized_tensors = self.split_into_quantized_tensors() + + quantized_tensors = self.split_into_quantized_tensors() for i in range(self.num_tensors): - self.quantizer.update_quantized( - tensors[i], self.quantized_tensors[i], noop_flag=noop_flag - ) - return self.quantized_tensors + self.quantizer.update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag) + return quantized_tensors From 5b776ec2489a69cb0db49fa8275010a97b5fa019 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 1 Mar 2026 15:39:19 -0800 Subject: [PATCH 047/172] minor tweaks to return to main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/cast/cast.cu | 1 + transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | 1 + .../common/include/transformer_engine/transformer_engine.h | 1 + 3 files changed, 3 insertions(+) diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index f4825970cb..57404ae8a5 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -30,6 +30,7 @@ void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize); using namespace transformer_engine; + constexpr bool IS_ACT = false; dispatch::group_quantize_fwd_helper(input, output, nullptr, stream); } diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index f454209409..6447fc4542 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -787,6 +787,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations using namespace group_quantize_kernel; checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); const bool use_rowwise_scaling = 
output->has_data(); const bool use_colwise_scaling = output->has_columnwise_data(); diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 635b9fdcce..e316f8be8c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -967,6 +967,7 @@ class TensorWrapper { /*! \struct GroupedTensorWrapper * \brief C++ wrapper for the NVTEGroupedTensor class. */ + class GroupedTensorWrapper { public: /*! \brief Constructs new GroupedTensorWrapper. From 4eee2bcceb1ff9b06b11aa2c6b0f67f9ad54bf20 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 12:13:20 -0800 Subject: [PATCH 048/172] remove prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 28 ++--- .../common/fused_attn/fused_attn_fp8.cu | 104 +----------------- .../dot_product_attention/backends.py | 2 +- .../dot_product_attention/context_parallel.py | 76 ------------- .../attention/dot_product_attention/utils.py | 4 - 5 files changed, 16 insertions(+), 198 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 47abf1ebc6..f39ed547cb 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1806,21 +1806,21 @@ def get_model(dtype, config): # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), # , attn_mask_type="causal"), "fp8_10": ModelConfig( - 2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512) - ), - "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), - "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, 
attn_mask_type="causal"), - "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), - "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), - "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), - "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), - "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), - "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + 2, 2048, 24, 192, head_dim_v=128, #num_gqa_groups=12, window_size=(512, 512) + ), + # "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), + # "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), + # "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), + # "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), + # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), + # "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + # "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), + # "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), + # "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), } -param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] +param_types_fp8_vs_f16 = [torch.bfloat16] #[torch.float16, torch.bfloat16] qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"] qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] @@ -2054,7 +2054,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: hidden_states.requires_grad = True tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") out_grad = 
tensor.view(*tensor.shape[:-2], -1) - + print(f"type(out_grad): {type(out_grad)} {out_grad.shape}") with autocast(enabled=fp8_mha, recipe=fp8_recipe): out = mha( hidden_states, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 48c3975264..9796e39ddc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2301,43 +2301,6 @@ void fused_attn_fp8_bwd_impl_v1( generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); - printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], - q_t_stride[3]); - printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], - k_t_stride[3]); - printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], - dO_t_stride[3]); - printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, - cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, - cudnn_frontend::DataType_t::BFLOAT16); - printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, - cudnn_frontend::DataType_t::FP8_E5M2); - printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, - cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, - NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); - printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, - NVTE_QKV_Format::NVTE_BHSD); - printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, - NVTE_QKV_Format::NVTE_BHSD); - 
printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, - NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); - printf("b: %d\n", b); - printf("h: %d\n", h); - printf("hg: %d\n", hg); - printf("s_q: %d\n", s_q); - printf("s_kv: %d\n", s_kv); - printf("d_qk: %d\n", d_qk); - printf("d_v: %d\n", d_v); - printf("is_delayed_scaling: %d\n", is_delayed_scaling); - printf("is_current_scaling: %d\n", is_current_scaling); - printf("is_O_in_F16: %d\n", is_O_in_F16); - printf("is_mxfp8: %d\n", is_mxfp8); - printf("is_causal: %d\n", is_causal); - printf("is_padding: %d\n", is_padding); - printf("is_dropout: %d\n", is_dropout); - printf("is_bias: %d\n", is_bias); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2360,18 +2323,6 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); - printf("s_q_padded: %d\n", padded.s_q_padded); - printf("s_kv_padded: %d\n", padded.s_kv_padded); - printf("s_q_scale: %d\n", padded.s_q_scale); - printf("s_kv_scale: %d\n", padded.s_kv_scale); - printf("s_q_scale_padded: %d\n", padded.s_q_scale_padded); - printf("s_kv_scale_padded: %d\n", padded.s_kv_scale_padded); - printf("d_qk_padded: %d\n", padded.d_qk_padded); - printf("d_v_padded: %d\n", padded.d_v_padded); - printf("d_qk_scale: %d\n", padded.d_qk_scale); - printf("d_v_scale: %d\n", padded.d_v_scale); - printf("d_qk_scale_padded: %d\n", padded.d_qk_scale_padded); - printf("d_v_scale_padded: %d\n", padded.d_v_scale_padded); std::vector q_scale_strides(4); std::vector q_t_scale_strides(4); std::vector k_scale_strides(4); @@ -2393,20 +2344,6 @@ void fused_attn_fp8_bwd_impl_v1( dO_scale_strides.data(), d_out_format, false); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, false); - 
printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], - q_scale_strides[2], q_scale_strides[3]); - printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], - q_t_scale_strides[2], q_t_scale_strides[3]); - printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], - k_scale_strides[2], k_scale_strides[3]); - printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], - k_t_scale_strides[2], k_t_scale_strides[3]); - printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], - v_scale_strides[2], v_scale_strides[3]); - printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], - dO_scale_strides[2], dO_scale_strides[3]); - printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], - dO_t_scale_strides[2], dO_t_scale_strides[3]); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -2562,9 +2499,6 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - printf("dq_stride: %d, %d, %d, %d\n", dq_stride[0], dq_stride[1], dq_stride[2], dq_stride[3]); - printf("dk_stride: %d, %d, %d, %d\n", dk_stride[0], dk_stride[1], dk_stride[2], dk_stride[3]); - printf("dv_stride: %d, %d, %d, %d\n", dv_stride[0], dv_stride[1], dv_stride[2], dv_stride[3]); dQ->set_output(true) .set_dim({b, h, s_q, d_qk}) .set_stride(dq_stride) @@ -2694,7 +2628,7 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[scale_dP] = devPtrScaledP; variant_pack[amax_dP] = devPtrAmaxdP; } - if (is_current_scaling && !is_O_in_F16) { + if (is_delayed_scaling || (is_current_scaling && !is_O_in_F16)) { variant_pack[descale_o] = devPtrDescaleO; } if (is_delayed_scaling) { @@ -2712,42 +2646,6 @@ void fused_attn_fp8_bwd_impl_v1( // variant_pack[descale_dO] 
= devPtrDescaledO; variant_pack[descale_dO_t] = devPtrDescaledO_t; } - int64_t modulo = 16; - printf("devPtrQ: %p, is_aligned: %d\n", devPtrQ, is_aligned_modulo(devPtrQ, modulo)); - printf("devPtrK: %p, is_aligned: %d\n", devPtrK, is_aligned_modulo(devPtrK, modulo)); - printf("devPtrV: %p, is_aligned: %d\n", devPtrV, is_aligned_modulo(devPtrV, modulo)); - printf("devPtrO: %p, is_aligned: %d\n", devPtrO, is_aligned_modulo(devPtrO, modulo)); - printf("devPtrM: %p, is_aligned: %d\n", devPtrM, is_aligned_modulo(devPtrM, modulo)); - printf("devPtrdO: %p, is_aligned: %d\n", devPtrdO, is_aligned_modulo(devPtrdO, modulo)); - printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, - is_aligned_modulo(devPtrDescaleQ, modulo)); - printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, - is_aligned_modulo(devPtrDescaleK, modulo)); - printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, - is_aligned_modulo(devPtrDescaleV, modulo)); - printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, - is_aligned_modulo(devPtrDescaledO, modulo)); - printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, - is_aligned_modulo(devPtrDescaledO_t, modulo)); - printf("devPtrdQ: %p, is_aligned: %d\n", devPtrdQ, is_aligned_modulo(devPtrdQ, modulo)); - printf("devPtrdK: %p, is_aligned: %d\n", devPtrdK, is_aligned_modulo(devPtrdK, modulo)); - printf("devPtrdV: %p, is_aligned: %d\n", devPtrdV, is_aligned_modulo(devPtrdV, modulo)); - printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, - is_aligned_modulo(devPtrAmaxdQ, modulo)); - printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, - is_aligned_modulo(devPtrAmaxdK, modulo)); - printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, - is_aligned_modulo(devPtrAmaxdV, modulo)); - printf("devPtrQ_t: %p, is_aligned: %d\n", devPtrQ_t, is_aligned_modulo(devPtrQ_t, modulo)); - printf("devPtrK_t: %p, is_aligned: %d\n", devPtrK_t, is_aligned_modulo(devPtrK_t, modulo)); - printf("devPtrdO_f16: %p, 
is_aligned: %d\n", devPtrdO_f16, - is_aligned_modulo(devPtrdO_f16, modulo)); - printf("devPtrdO_t: %p, is_aligned: %d\n", devPtrdO_t, is_aligned_modulo(devPtrdO_t, modulo)); - printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, - is_aligned_modulo(devPtrDescaleQ_t, modulo)); - printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, - is_aligned_modulo(devPtrDescaleK_t, modulo)); - /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index e3aacbf2e6..2aecd032c9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -36,7 +36,7 @@ restore_from_saved, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorStorage from transformer_engine.pytorch.constants import ( TE_DType, QKVLayouts, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index de5563b7a6..653fd8cfb0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1430,10 +1430,6 @@ def forward( q_fp8, k_fp8, v_fp8 = (None, None, None) # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: - print( - f">>>>>>======================>>>>>> {torch.cuda.current_device()}: fp8: {fp8}," - f" is_input_fp8: {is_input_fp8}, fp8_recipe.mxfp8(): {fp8_recipe.mxfp8()}" - ) if fp8 and is_input_fp8: QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v @@ -1628,7 +1624,6 @@ 
def forward( out = None o_format = qkv_format for i in range(cp_size + 1): - print(f">>>>>>>>>>>> {torch.cuda.current_device()}: i: {i}, cp_size: {cp_size}") if i < cp_size: with torch.cuda.stream(flash_attn_streams[i % 2]): # wait until KV is received @@ -1917,10 +1912,6 @@ def forward( softmax_lse_per_step[0], seq_dim, ) - print( - f"====o/v===== {torch.cuda.current_device()}: i: {i}, {enable_mla}," - f" out.shape: {out.shape} {out_per_step[0].shape} {v_shape} {o_shape}" - ) if enable_mla: out = out.view(o_shape) else: @@ -1969,12 +1960,7 @@ def forward( elif o_format == "sbhd": out = out.view(-1, *out.shape[-3:]) ctx.batch_size = out.shape[1] - print(f"========= {torch.cuda.current_device()}: out.shape: {out.shape} {out.dtype}") out_part = out.to(fwd_nominal_dtype) - print( - f"========= {torch.cuda.current_device()}: out_part.shape:" - f" {out_part.shape} {out_part.dtype}" - ) if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) @@ -2077,15 +2063,6 @@ def forward( q_f16 = q_f16.view(q.shape) kv_f16 = kv f16_tensors = (q_f16, kv_f16, out_f16) - if torch.cuda.current_device() == 0: - print( - "fp8_tensors:" - f" {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}" - ) - print( - "f16_tensors:" - f" {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}" - ) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -2367,10 +2344,6 @@ def backward(ctx, dout, *_args): # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: - print( - f"========= {torch.cuda.current_device()}: before a2a: out.shape:" - f" {out.shape} {out.dtype} dout.shape: {dout.shape} {dout.dtype}" - ) if not ctx.use_fused_attention: # out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(*out.shape) @@ -2386,10 +2359,6 @@ def backward(ctx, dout, *_args): ctx.cp_stream, True, ) - print( - f"========= 
{torch.cuda.current_device()}: after a2a: dout.shape:" - f" {dout.shape} {dout.dtype} {q.shape} {ctx.v_shape}" - ) if ctx.enable_mla: out = out.view(*ctx.o_shape) @@ -3272,15 +3241,6 @@ def forward( f16_tensors = (q, k, v, out) else: f16_tensors = (q, k, v, out) - if torch.cuda.current_device() == 0: - print( - "fp8_tensors:" - f" {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}" - ) - print( - "f16_tensors:" - f" {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}" - ) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -3385,8 +3345,6 @@ def backward(ctx, dout, *_args): if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] - if torch.cuda.current_device() == 0: - print(f"ctx.q_shape: {ctx.q_shape} {ctx.k_shape} {ctx.v_shape}") dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) dk = torch.zeros( (ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), @@ -3398,10 +3356,6 @@ def backward(ctx, dout, *_args): dtype=ctx.fwd_nominal_dtype, device=v.device, ) - if torch.cuda.current_device() == 0: - print(f"dq: {dq.shape} {dq.dtype} {dq.device}") - print(f"dk: {dk.shape} {dk.dtype} {dk.device}") - print(f"dv: {dv.shape} {dv.dtype} {dv.device}") dq_per_step = [None, None] dk_per_step = [None, None] dv_per_step = [None, None] @@ -3512,23 +3466,11 @@ def backward(ctx, dout, *_args): q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer ) - print( - "aux_ctx_tensors:" - f" {len(aux_ctx_tensors)} {[x.shape if x is not None else None for x in aux_ctx_tensors]} {[type(x) for x in aux_ctx_tensors]}" - ) dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor( d_out_format, dout_part ) aux_ctx_tensors.append(dout_part) dout_part = ctx.dO_quantizer(dout_part) - print( - "q_part type:" - f" {type(q_part)} {type(k_part)} {type(v_part)} {type(out_part)} 
{type(dout_part)}" - ) - print( - "q_part shape:" - f" {q_part.shape} {k_part.shape} {v_part.shape} {out_part.shape} {dout_part.shape}" - ) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, @@ -3606,11 +3548,6 @@ def backward(ctx, dout, *_args): if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - if torch.cuda.current_device() == 0: - print( - f"dq.shape: {dq.shape} dq_per_step[i - 1].shape:" - f" {dq_per_step[i - 1].shape}" - ) if ctx.qkv_format == "bshd": dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": @@ -4431,19 +4368,6 @@ def attn_forward_func_with_cp( in Megatron-LM. """ - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_comm_type=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {qkv_format=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {deterministic=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {use_fused_attention=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8_meta=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_group=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_global_ranks=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_stream=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {quantizers=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {pad_between_seqs=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8_output=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {layer_number=}") if cp_comm_type == "a2a+p2p": assert ( isinstance(cp_group, list) and len(cp_group) == 2 diff --git 
a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9a8d38547e..2f9929bffc 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2244,10 +2244,6 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): original_shapes = [x.shape for x in [q, k, v]] s_q, d_qk = q.shape[-2:] s_kv, d_v = v.shape[-2:] - print( - f">>>>>>>>>>>> {torch.cuda.current_device()}: {qkv_layout} s_q: {s_q}, d_qk: {d_qk}," - f" s_kv: {s_kv}, d_v: {d_v}, {q.shape}, {k.shape}, {v.shape}" - ) assert s_q % 128 == 0 assert s_kv % 128 == 0 assert d_qk % 32 == 0 From 8500121daf3d3668b59d1eec242460d6d6e0f631 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 12:13:41 -0800 Subject: [PATCH 049/172] fix combine_and_quantize for f16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/attention/dot_product_attention/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 2f9929bffc..b572f087b2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2304,7 +2304,7 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): for x in [q_data, k_data, v_data] ] - return q_fp8, k_fp8, v_fp8 + return q_fp8, k_fp8, v_fp8, qkv_layout def combine_and_dequantize( From 0c2c4668f394fd3179fece9f4fef47d6efcef13e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 20:14:32 +0000 Subject: [PATCH 050/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index f39ed547cb..1173d91a16 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1806,7 +1806,11 @@ def get_model(dtype, config): # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), # , attn_mask_type="causal"), "fp8_10": ModelConfig( - 2, 2048, 24, 192, head_dim_v=128, #num_gqa_groups=12, window_size=(512, 512) + 2, + 2048, + 24, + 192, + head_dim_v=128, # num_gqa_groups=12, window_size=(512, 512) ), # "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), # "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), @@ -1820,7 +1824,7 @@ def get_model(dtype, config): # "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), } -param_types_fp8_vs_f16 = [torch.bfloat16] #[torch.float16, torch.bfloat16] +param_types_fp8_vs_f16 = [torch.bfloat16] # [torch.float16, torch.bfloat16] qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"] qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] From 6744aeeb90a85900eebc335ba0fb77370cf9c35c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 15:01:58 -0800 Subject: [PATCH 051/172] minor tweaks Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/backends.py | 17 ++++++++--------- .../dot_product_attention/context_parallel.py | 2 +- .../attention/dot_product_attention/utils.py | 4 ++++ .../pytorch/attention/multi_head_attention.py | 10 ++++++---- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py 
b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 2aecd032c9..95085a0fca 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1339,14 +1339,15 @@ def forward( # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ + bwd_requires_o_f16 = is_training and (not is_bwd_fp8 or ( + is_bwd_fp8 and ((fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8()))) + bwd_requires_o_fp8 = is_training and is_bwd_fp8 and ( + fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16)) if isinstance(out_, QuantizedTensorStorage): - if not is_output_fp8 or not is_bwd_fp8: + if not is_output_fp8 or bwd_requires_o_f16: out = out_.dequantize().view(out_.shape) else: - if is_output_fp8 or ( - is_bwd_fp8 - and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) - ): + if is_output_fp8 or bwd_requires_o_fp8: out_fp8 = O_quantizer(out_) # print quantizers @@ -1368,12 +1369,10 @@ def forward( fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) if is_bwd_fp8: - if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or isinstance( - QKV_quantizer, MXFP8Quantizer - ): + if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8(): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) qkvo_tensors = (None, None, None, out) - else: + elif fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16): fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: if is_input_fp8: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 653fd8cfb0..cf8986c7f6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ 
b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4224,7 +4224,7 @@ def backward(ctx, dout, *_args): if ( ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() ) and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _ = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index b572f087b2..20ae4d135d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2258,6 +2258,10 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): ) q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors else: + # grouped_tensor = GroupedTensor.create_and_quantize( + # tensors=[q, k], quantizer=qkv_quantizer + # ) + # q_fp8, k_fp8 = grouped_tensor.quantized_tensors q_fp8 = qkv_quantizer(q) k_fp8 = qkv_quantizer(k) v_fp8 = qkv_quantizer(v) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 01c4955d78..801c2f525b 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -784,14 +784,16 @@ def forward( fp8_dpa = fp8_recipe.fp8_dpa fp8_mha = fp8_recipe.fp8_mha float8_current_scaling = fp8_recipe.float8_current_scaling() + mxfp8_scaling = fp8_recipe.mxfp8() else: fp8_dpa = _dpa_fp8_recipe_dpa fp8_mha = _dpa_fp8_recipe_mha float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" - # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe - qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not 
float8_current_scaling - # DPA: always produce FP8 output when fp8=True to take advantage of the O amax - dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) + mxfp8_scaling = _dpa_fp8_recipe == "MXFP8BlockScaling" + # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling or MXFP8BlockScaling recipe + qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling and not mxfp8_scaling + # DPA: produce FP8 output when fp8=True to take advantage of the O amax except for MXFP8BlockScaling + dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) and not mxfp8_scaling # Proj Gemm: match DPA output except for Float8CurrentScaling proj_fp8_grad = dpa_fp8_output and not float8_current_scaling From 4cec878b6e6eb40ff9346db658aeca125a6411d8 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 15:02:17 -0800 Subject: [PATCH 052/172] tweak tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 1173d91a16..92b8ade67e 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1804,13 +1804,14 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), # , attn_mask_type="causal"), + "fp8_9": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), "fp8_10": ModelConfig( 2, - 2048, - 24, + 4096, + 128, 192, - head_dim_v=128, # num_gqa_groups=12, window_size=(512, 512) + head_dim_v=128, + attn_mask_type="causal", ), # "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), # "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), @@ -1871,7 +1872,7 @@ def 
test_mha_fp8_vs_f16( ) elif scaling_mode == "mxfp8": fp8_recipe = recipe.MXFP8BlockScaling( - fp8_format=recipe.Format.HYBRID, + fp8_format=recipe.Format.E4M3, fp8_dpa=True, fp8_mha=True, ) @@ -2058,7 +2059,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: hidden_states.requires_grad = True tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") out_grad = tensor.view(*tensor.shape[:-2], -1) - print(f"type(out_grad): {type(out_grad)} {out_grad.shape}") with autocast(enabled=fp8_mha, recipe=fp8_recipe): out = mha( hidden_states, @@ -2128,7 +2128,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal ) elif scaling_mode == "mxfp8": fp8_recipe = recipe.MXFP8BlockScaling( - fp8_format=recipe.Format.HYBRID, + fp8_format=recipe.Format.E4M3, fp8_dpa=True, fp8_mha=False, ) @@ -2401,7 +2401,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, - fp8_output=fp8_dpa, ) if is_training: out.backward(out_grad) From 5c8e939ab2ea123941539573d1b06eee01d8aa94 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 23:03:11 +0000 Subject: [PATCH 053/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 4 ++- .../dot_product_attention/backends.py | 30 +++++++++++++++---- .../dot_product_attention/context_parallel.py | 4 ++- .../pytorch/attention/multi_head_attention.py | 8 ++++- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 92b8ade67e..7ae73a753a 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1804,7 +1804,9 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = 
{ # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), + "fp8_9": ModelConfig( + 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0) + ), "fp8_10": ModelConfig( 2, 4096, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 95085a0fca..906f3ade45 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1339,10 +1339,24 @@ def forward( # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ - bwd_requires_o_f16 = is_training and (not is_bwd_fp8 or ( - is_bwd_fp8 and ((fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8()))) - bwd_requires_o_fp8 = is_training and is_bwd_fp8 and ( - fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16)) + bwd_requires_o_f16 = is_training and ( + not is_bwd_fp8 + or ( + is_bwd_fp8 + and ( + (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or fp8_recipe.mxfp8() + ) + ) + ) + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) if isinstance(out_, QuantizedTensorStorage): if not is_output_fp8 or bwd_requires_o_f16: out = out_.dequantize().view(out_.shape) @@ -1369,10 +1383,14 @@ def forward( fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) if is_bwd_fp8: - if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8(): + if ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or fp8_recipe.mxfp8(): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) qkvo_tensors = (None, None, None, out) - elif fp8_recipe.delayed() or 
(fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16): + elif fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: if is_input_fp8: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index cf8986c7f6..56c36aef8a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4224,7 +4224,9 @@ def backward(ctx, dout, *_args): if ( ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() ) and ctx.is_input_fp8: - dq, dk, dv, _ = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _ = combine_and_quantize( + ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer + ) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 801c2f525b..0a276bdc8a 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -791,7 +791,13 @@ def forward( float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" mxfp8_scaling = _dpa_fp8_recipe == "MXFP8BlockScaling" # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling or MXFP8BlockScaling recipe - qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling and not mxfp8_scaling + qkv_fp8_output = ( + fp8 + and fp8_mha + and rotary_pos_emb is None + and not float8_current_scaling + and not mxfp8_scaling + ) # DPA: produce FP8 output when fp8=True to take advantage of the O amax except for MXFP8BlockScaling dpa_fp8_output = fp8 and (fp8_dpa or 
fp8_mha) and not mxfp8_scaling # Proj Gemm: match DPA output except for Float8CurrentScaling From 7b6b364499701a5bb4c7e56f0a96aadb9315fa09 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 15:56:03 -0800 Subject: [PATCH 054/172] fix ds descale_o Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 9796e39ddc..f3557eeb68 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2301,6 +2301,43 @@ void fused_attn_fp8_bwd_impl_v1( generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); + printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], + q_t_stride[3]); + printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], + k_t_stride[3]); + printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], + dO_t_stride[3]); + printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, + cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, + cudnn_frontend::DataType_t::BFLOAT16); + printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, + cudnn_frontend::DataType_t::FP8_E5M2); + printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, + cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("qkv_layout: %d, %d, %d\n", qkv_layout, 
NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, + NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); + printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, + NVTE_QKV_Format::NVTE_BHSD); + printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, + NVTE_QKV_Format::NVTE_BHSD); + printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, + NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); + printf("b: %d\n", b); + printf("h: %d\n", h); + printf("hg: %d\n", hg); + printf("s_q: %d\n", s_q); + printf("s_kv: %d\n", s_kv); + printf("d_qk: %d\n", d_qk); + printf("d_v: %d\n", d_v); + printf("is_delayed_scaling: %d\n", is_delayed_scaling); + printf("is_current_scaling: %d\n", is_current_scaling); + printf("is_O_in_F16: %d\n", is_O_in_F16); + printf("is_mxfp8: %d\n", is_mxfp8); + printf("is_causal: %d\n", is_causal); + printf("is_padding: %d\n", is_padding); + printf("is_dropout: %d\n", is_dropout); + printf("is_bias: %d\n", is_bias); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2323,6 +2360,18 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + printf("s_q_padded: %d\n", padded.s_q_padded); + printf("s_kv_padded: %d\n", padded.s_kv_padded); + printf("s_q_scale: %d\n", padded.s_q_scale); + printf("s_kv_scale: %d\n", padded.s_kv_scale); + printf("s_q_scale_padded: %d\n", padded.s_q_scale_padded); + printf("s_kv_scale_padded: %d\n", padded.s_kv_scale_padded); + printf("d_qk_padded: %d\n", padded.d_qk_padded); + printf("d_v_padded: %d\n", padded.d_v_padded); + printf("d_qk_scale: %d\n", padded.d_qk_scale); + printf("d_v_scale: %d\n", padded.d_v_scale); + printf("d_qk_scale_padded: %d\n", padded.d_qk_scale_padded); + printf("d_v_scale_padded: %d\n", padded.d_v_scale_padded); std::vector q_scale_strides(4); std::vector 
q_t_scale_strides(4); std::vector k_scale_strides(4); @@ -2344,6 +2393,20 @@ void fused_attn_fp8_bwd_impl_v1( dO_scale_strides.data(), d_out_format, false); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, false); + printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], + q_scale_strides[2], q_scale_strides[3]); + printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], + q_t_scale_strides[2], q_t_scale_strides[3]); + printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], + k_scale_strides[2], k_scale_strides[3]); + printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], + k_t_scale_strides[2], k_t_scale_strides[3]); + printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], + v_scale_strides[2], v_scale_strides[3]); + printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], + dO_scale_strides[2], dO_scale_strides[3]); + printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], + dO_t_scale_strides[2], dO_t_scale_strides[3]); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -2499,6 +2562,9 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + printf("dq_stride: %d, %d, %d, %d\n", dq_stride[0], dq_stride[1], dq_stride[2], dq_stride[3]); + printf("dk_stride: %d, %d, %d, %d\n", dk_stride[0], dk_stride[1], dk_stride[2], dk_stride[3]); + printf("dv_stride: %d, %d, %d, %d\n", dv_stride[0], dv_stride[1], dv_stride[2], dv_stride[3]); dQ->set_output(true) .set_dim({b, h, s_q, d_qk}) .set_stride(dq_stride) @@ -2646,6 +2712,42 @@ void fused_attn_fp8_bwd_impl_v1( // variant_pack[descale_dO] = devPtrDescaledO; 
variant_pack[descale_dO_t] = devPtrDescaledO_t; } + int64_t modulo = 16; + printf("devPtrQ: %p, is_aligned: %d\n", devPtrQ, is_aligned_modulo(devPtrQ, modulo)); + printf("devPtrK: %p, is_aligned: %d\n", devPtrK, is_aligned_modulo(devPtrK, modulo)); + printf("devPtrV: %p, is_aligned: %d\n", devPtrV, is_aligned_modulo(devPtrV, modulo)); + printf("devPtrO: %p, is_aligned: %d\n", devPtrO, is_aligned_modulo(devPtrO, modulo)); + printf("devPtrM: %p, is_aligned: %d\n", devPtrM, is_aligned_modulo(devPtrM, modulo)); + printf("devPtrdO: %p, is_aligned: %d\n", devPtrdO, is_aligned_modulo(devPtrdO, modulo)); + printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, + is_aligned_modulo(devPtrDescaleQ, modulo)); + printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, + is_aligned_modulo(devPtrDescaleK, modulo)); + printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, + is_aligned_modulo(devPtrDescaleV, modulo)); + printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, + is_aligned_modulo(devPtrDescaledO, modulo)); + printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, + is_aligned_modulo(devPtrDescaledO_t, modulo)); + printf("devPtrdQ: %p, is_aligned: %d\n", devPtrdQ, is_aligned_modulo(devPtrdQ, modulo)); + printf("devPtrdK: %p, is_aligned: %d\n", devPtrdK, is_aligned_modulo(devPtrdK, modulo)); + printf("devPtrdV: %p, is_aligned: %d\n", devPtrdV, is_aligned_modulo(devPtrdV, modulo)); + printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, + is_aligned_modulo(devPtrAmaxdQ, modulo)); + printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, + is_aligned_modulo(devPtrAmaxdK, modulo)); + printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, + is_aligned_modulo(devPtrAmaxdV, modulo)); + printf("devPtrQ_t: %p, is_aligned: %d\n", devPtrQ_t, is_aligned_modulo(devPtrQ_t, modulo)); + printf("devPtrK_t: %p, is_aligned: %d\n", devPtrK_t, is_aligned_modulo(devPtrK_t, modulo)); + printf("devPtrdO_f16: %p, is_aligned: %d\n", 
devPtrdO_f16, + is_aligned_modulo(devPtrdO_f16, modulo)); + printf("devPtrdO_t: %p, is_aligned: %d\n", devPtrdO_t, is_aligned_modulo(devPtrdO_t, modulo)); + printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, + is_aligned_modulo(devPtrDescaleQ_t, modulo)); + printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, + is_aligned_modulo(devPtrDescaleK_t, modulo)); + /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { From 462eb4f5ce7da5cae5c8d11d32d334a642242ecc Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 15:58:32 -0800 Subject: [PATCH 055/172] Revert "fix ds descale_o" This reverts commit cd0bd82e239ff01210338b4e34cb8784109d22ec. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 102 ------------------ 1 file changed, 102 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f3557eeb68..9796e39ddc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2301,43 +2301,6 @@ void fused_attn_fp8_bwd_impl_v1( generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); - printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], - q_t_stride[3]); - printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], - k_t_stride[3]); - printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], - dO_t_stride[3]); - printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, - cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); 
- printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, - cudnn_frontend::DataType_t::BFLOAT16); - printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, - cudnn_frontend::DataType_t::FP8_E5M2); - printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, - cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, - NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); - printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, - NVTE_QKV_Format::NVTE_BHSD); - printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, - NVTE_QKV_Format::NVTE_BHSD); - printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, - NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); - printf("b: %d\n", b); - printf("h: %d\n", h); - printf("hg: %d\n", hg); - printf("s_q: %d\n", s_q); - printf("s_kv: %d\n", s_kv); - printf("d_qk: %d\n", d_qk); - printf("d_v: %d\n", d_v); - printf("is_delayed_scaling: %d\n", is_delayed_scaling); - printf("is_current_scaling: %d\n", is_current_scaling); - printf("is_O_in_F16: %d\n", is_O_in_F16); - printf("is_mxfp8: %d\n", is_mxfp8); - printf("is_causal: %d\n", is_causal); - printf("is_padding: %d\n", is_padding); - printf("is_dropout: %d\n", is_dropout); - printf("is_bias: %d\n", is_bias); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2360,18 +2323,6 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); - printf("s_q_padded: %d\n", padded.s_q_padded); - printf("s_kv_padded: %d\n", padded.s_kv_padded); - printf("s_q_scale: %d\n", padded.s_q_scale); - printf("s_kv_scale: %d\n", padded.s_kv_scale); - printf("s_q_scale_padded: %d\n", 
padded.s_q_scale_padded); - printf("s_kv_scale_padded: %d\n", padded.s_kv_scale_padded); - printf("d_qk_padded: %d\n", padded.d_qk_padded); - printf("d_v_padded: %d\n", padded.d_v_padded); - printf("d_qk_scale: %d\n", padded.d_qk_scale); - printf("d_v_scale: %d\n", padded.d_v_scale); - printf("d_qk_scale_padded: %d\n", padded.d_qk_scale_padded); - printf("d_v_scale_padded: %d\n", padded.d_v_scale_padded); std::vector q_scale_strides(4); std::vector q_t_scale_strides(4); std::vector k_scale_strides(4); @@ -2393,20 +2344,6 @@ void fused_attn_fp8_bwd_impl_v1( dO_scale_strides.data(), d_out_format, false); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, false); - printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], - q_scale_strides[2], q_scale_strides[3]); - printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], - q_t_scale_strides[2], q_t_scale_strides[3]); - printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], - k_scale_strides[2], k_scale_strides[3]); - printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], - k_t_scale_strides[2], k_t_scale_strides[3]); - printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], - v_scale_strides[2], v_scale_strides[3]); - printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], - dO_scale_strides[2], dO_scale_strides[3]); - printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], - dO_t_scale_strides[2], dO_t_scale_strides[3]); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -2562,9 +2499,6 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - printf("dq_stride: 
%d, %d, %d, %d\n", dq_stride[0], dq_stride[1], dq_stride[2], dq_stride[3]); - printf("dk_stride: %d, %d, %d, %d\n", dk_stride[0], dk_stride[1], dk_stride[2], dk_stride[3]); - printf("dv_stride: %d, %d, %d, %d\n", dv_stride[0], dv_stride[1], dv_stride[2], dv_stride[3]); dQ->set_output(true) .set_dim({b, h, s_q, d_qk}) .set_stride(dq_stride) @@ -2712,42 +2646,6 @@ void fused_attn_fp8_bwd_impl_v1( // variant_pack[descale_dO] = devPtrDescaledO; variant_pack[descale_dO_t] = devPtrDescaledO_t; } - int64_t modulo = 16; - printf("devPtrQ: %p, is_aligned: %d\n", devPtrQ, is_aligned_modulo(devPtrQ, modulo)); - printf("devPtrK: %p, is_aligned: %d\n", devPtrK, is_aligned_modulo(devPtrK, modulo)); - printf("devPtrV: %p, is_aligned: %d\n", devPtrV, is_aligned_modulo(devPtrV, modulo)); - printf("devPtrO: %p, is_aligned: %d\n", devPtrO, is_aligned_modulo(devPtrO, modulo)); - printf("devPtrM: %p, is_aligned: %d\n", devPtrM, is_aligned_modulo(devPtrM, modulo)); - printf("devPtrdO: %p, is_aligned: %d\n", devPtrdO, is_aligned_modulo(devPtrdO, modulo)); - printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, - is_aligned_modulo(devPtrDescaleQ, modulo)); - printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, - is_aligned_modulo(devPtrDescaleK, modulo)); - printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, - is_aligned_modulo(devPtrDescaleV, modulo)); - printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, - is_aligned_modulo(devPtrDescaledO, modulo)); - printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, - is_aligned_modulo(devPtrDescaledO_t, modulo)); - printf("devPtrdQ: %p, is_aligned: %d\n", devPtrdQ, is_aligned_modulo(devPtrdQ, modulo)); - printf("devPtrdK: %p, is_aligned: %d\n", devPtrdK, is_aligned_modulo(devPtrdK, modulo)); - printf("devPtrdV: %p, is_aligned: %d\n", devPtrdV, is_aligned_modulo(devPtrdV, modulo)); - printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, - is_aligned_modulo(devPtrAmaxdQ, modulo)); - 
printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, - is_aligned_modulo(devPtrAmaxdK, modulo)); - printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, - is_aligned_modulo(devPtrAmaxdV, modulo)); - printf("devPtrQ_t: %p, is_aligned: %d\n", devPtrQ_t, is_aligned_modulo(devPtrQ_t, modulo)); - printf("devPtrK_t: %p, is_aligned: %d\n", devPtrK_t, is_aligned_modulo(devPtrK_t, modulo)); - printf("devPtrdO_f16: %p, is_aligned: %d\n", devPtrdO_f16, - is_aligned_modulo(devPtrdO_f16, modulo)); - printf("devPtrdO_t: %p, is_aligned: %d\n", devPtrdO_t, is_aligned_modulo(devPtrdO_t, modulo)); - printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, - is_aligned_modulo(devPtrDescaleQ_t, modulo)); - printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, - is_aligned_modulo(devPtrDescaleK_t, modulo)); - /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { From 77995d2d949c667dafe19da1ba9406b2ed7a117e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 6 Mar 2026 17:43:49 -0800 Subject: [PATCH 056/172] minor fixes for p2p and ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/run_attention_with_cp.py | 22 ++++--- .../attention/test_attention_with_cp.py | 64 +++++++++++-------- .../dot_product_attention/context_parallel.py | 44 +++++++++---- 3 files changed, 81 insertions(+), 49 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 5cb43f277a..949dbf3d1c 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -186,13 +186,17 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + deterministic="False", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" logging.root.setLevel(log_level) # When is_training is False, 
gradient outputs are None. is_training = is_training == "True" - + if deterministic == "True": + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + else: + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" # set up environment variables and config fp8_bwd = fp8_bwd == "True" and dtype == "fp8" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" @@ -228,7 +232,6 @@ def run_dpa_with_cp( device_count = torch.cuda.device_count() device = rank % device_count torch.cuda.set_device(device) - print(f"rank: {rank}, world_size: {world_size}") logging.info(f"[Rank {rank}] Setup: world_size {world_size}") dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) @@ -330,7 +333,7 @@ def run_dpa_with_cp( dout_quantizer.internal = False qkv_layout = "_".join([qkv_format] * 3) q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] - if fp8_mha: + if fp8_mha and scaling_mode != "mxfp8": q, k, v, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) for x in [q, k, v]: x.requires_grad = True @@ -377,12 +380,12 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, + # fp8_output=fp8_mha, ) if config.return_max_logit: out, max_logit = out if is_training: - if fp8_bwd and fp8_mha: + if fp8_bwd and fp8_mha and scaling_mode != "mxfp8": dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: @@ -438,7 +441,7 @@ def run_dpa_with_cp( qkv_quantizer.amax.fill_(0.0) dout_quantizer.scale.fill_(1.0) dout_quantizer.amax.fill_(0.0) - if fp8_mha: + if fp8_mha and scaling_mode != "mxfp8": q_, k_, v_, qkv_layout = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) if is_training: q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] @@ -494,12 +497,12 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, + # 
fp8_output=fp8_mha, ) if config.return_max_logit: out_, max_logit_ = out_ if is_training: - if fp8_bwd and fp8_mha: + if fp8_bwd and fp8_mha and scaling_mode != "mxfp8": dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: @@ -528,9 +531,10 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - for tensor in tensors: + for i, tensor in enumerate(tensors): # dbias/dbias_ could be None, so skip check for it if tensor is not None: + print(f"========= {torch.cuda.current_device()}: tensors[{i}].shape: {tensor.shape} {tensor.dtype} {torch.isnan(tensor).any()} {torch.isinf(tensor).any()}") assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index a5fe8f74f5..dc079a7193 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -28,6 +28,12 @@ pytest_logging_level = logging.getLevelName(logging.root.level) +# Get determinism +_deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() +) + # Initialize RNG state seed = 1234 torch.manual_seed(seed) @@ -153,9 +159,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" ), # MHA "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA + "cp_2_0": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), # GQA "cp_2_1": ModelConfig( - 2, 4096, 16, 192, head_dim_v=128 + 2, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal" ), # num_gqa_groups=4, 
attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, @@ -219,23 +225,23 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: configs = [ - "cp_1_0", - "cp_1_1", - "cp_1_4", - "cp_1_5", + # "cp_1_0", + # "cp_1_1", + # "cp_1_4", + # "cp_1_5", "cp_2_0", "cp_2_1", - "cp_2_2", - "cp_2_3", - "cp_2_4", - "cp_3_1", - "cp_3_2", - "cp_3_4", - "cp_4_2", + # "cp_2_2", + # "cp_2_3", + # "cp_2_4", + # "cp_3_1", + # "cp_3_2", + # "cp_3_4", + # "cp_4_2", ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} - dtypes = ["bf16", "fp8"] - qkv_formats = ["bshd", "sbhd", "thd"] + dtypes = ["fp8"] #["bf16", "fp8"] + qkv_formats = ["bshd"]#, "sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @@ -247,11 +253,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("fp8_bwd", [True, False]) @pytest.mark.parametrize("fp8_mha", [True, False]) @pytest.mark.parametrize("fp8_dpa", [True, False]) -@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) -@pytest.mark.parametrize("f16_O", [True, False]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) +@pytest.mark.parametrize("f16_O", [True]) def test_cp_with_fused_attention( dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O ): + # # TODO: Remove this once MXFP8 is supported with fp8_bwd=True! 
+ # if scaling_mode == "mxfp8" and fp8_bwd: + # pytest.skip("MXFP8 only works with fp8_bwd=False!") + num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") @@ -289,8 +299,8 @@ def test_cp_with_fused_attention( pytest.skip("FP8 attention cannot work with THD format yet!") if dtype == "fp8" and config.attn_bias_type != "no_bias": pytest.skip("FP8 attention cannot work with bias yet!") - if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("FP8 attention cannot work with sliding window yet!") + # if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + # pytest.skip("FP8 attention cannot work with sliding window yet!") if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": @@ -310,14 +320,14 @@ def test_cp_with_fused_attention( pytest.skip("Only fp8 works with scaling_mode != None!") if dtype == "fp8" and scaling_mode is None: pytest.skip("fp8 only works with scaling_mode != None!") - if ( - dtype == "fp8" - and scaling_mode == "current" - and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] - ): - pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") + # if ( + # dtype == "fp8" + # and scaling_mode == "current" + # and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] + # ): + # pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") if f16_O and (dtype != "fp8" or scaling_mode not in ["current", "mxfp8"]): - pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") + pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode in [current, mxfp8]!") # if cp_comm_type not in ["p2p", 
"a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: # pytest.skip("MLA CP currently only support KV P2P!") # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: @@ -377,6 +387,7 @@ def test_cp_with_fused_attention( fp8=fp8, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -396,6 +407,7 @@ def test_cp_with_fused_attention( scaling_mode=scaling_mode, f16_O=f16_O, is_training=is_training, + deterministic=_deterministic, log_level=pytest_logging_level, ), check=True, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 56c36aef8a..7886c625b4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2020,7 +2020,6 @@ def forward( kv_fp8 = None kv = p2p_comm_buffers[-1] - q_fp8, kv_fp8 = None, None if fp8 and not fp8_recipe.mxfp8(): q_fp8, kv_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) @@ -3018,6 +3017,7 @@ def forward( fwd_nominal_dtype = q.dtype fp8_meta_kwargs = {} q_fp8, k_fp8, v_fp8 = (None, None, None) + q_f16, k_f16, v_f16 = (None, None, None) fused_attn_backend = None if fp8: assert use_fused_attention, "FP8 is only supported with Fused Attention!" 
@@ -3025,9 +3025,12 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v elif not fp8_recipe.mxfp8(): + q_f16, k_f16, v_f16 = q, k, v q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, q, k, v, QKV_quantizer ) + else: + q_f16, k_f16, v_f16 = q, k, v if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] fp8_meta_kwargs["s_quantizer"] = S_quantizer @@ -3149,6 +3152,7 @@ def forward( cuda_graph=is_graph_capturing(), **fp8_meta_kwargs, ) + if fp8: softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors else: @@ -3225,20 +3229,27 @@ def forward( ctx.fp8_recipe = fp8_recipe fp8_tensors = (None, None, None, None) f16_tensors = (None, None, None, None) + ctx.qkv_reshaped = True if ctx.fp8: + q_fp8_save, k_fp8_save, v_fp8_save = None, None, None + if fp8_recipe.delayed() or fp8_recipe.float8_current_scaling(): + q_fp8_save = Float8Tensor.make_like(q_fp8, data=q, dtype=fwd_nominal_dtype) + k_fp8_save = Float8Tensor.make_like(k_fp8, data=k, dtype=fwd_nominal_dtype) + v_fp8_save = Float8Tensor.make_like(v_fp8, data=v, dtype=fwd_nominal_dtype) if fp8_recipe.delayed(): - fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) if fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16: - fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: - fp8_tensors = (q_fp8, k_fp8, v_fp8, None) + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, None) f16_tensors = (None, None, None, out) if fp8_recipe.mxfp8(): f16_tensors = (q, k, v, out) elif fp8: if is_input_fp8: q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) - f16_tensors = (q, k, v, out) + f16_tensors = (q_f16, k_f16, v_f16, out) + ctx.qkv_reshaped = False else: f16_tensors = (q, k, v, out) @@ -3276,13 +3287,17 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 
ctx.is_output_fp8 = is_output_fp8 - if fp8: + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() ctx.O_quantizer = O_quantizer.copy() ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None - ctx.dQKV_quantizer = dQKV_quantizer.copy() - ctx.dO_quantizer = dO_quantizer.copy() - ctx.dP_quantizer = dP_quantizer.copy() if dP_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer.scale = O_quantizer.scale.clone() @@ -3333,17 +3348,18 @@ def backward(ctx, dout, *_args): dout = dout.view(ctx.out_shape) dout_fp8 = None if ctx.fp8: - if ( - ctx.is_output_fp8 - and not isinstance(dout, QuantizedTensorStorage) - and not ctx.fp8_recipe.mxfp8() - ): + if isinstance(dout, QuantizedTensorStorage): + dout_fp8 = dout + elif not ctx.fp8_recipe.mxfp8(): dout_fp8 = ctx.dO_quantizer(dout) if not ctx.fp8_recipe.mxfp8(): dout = dout_fp8._data if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] + if not ctx.qkv_reshaped: + q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) dk = torch.zeros( From 586b698bcb513d536709fd47824df92bfc0e185c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 7 Mar 2026 01:44:38 +0000 Subject: [PATCH 057/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/run_attention_with_cp.py | 5 ++++- tests/pytorch/attention/test_attention_with_cp.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) 
diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 949dbf3d1c..242d6b9e7a 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -534,7 +534,10 @@ def run_dpa_with_cp( for i, tensor in enumerate(tensors): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - print(f"========= {torch.cuda.current_device()}: tensors[{i}].shape: {tensor.shape} {tensor.dtype} {torch.isnan(tensor).any()} {torch.isinf(tensor).any()}") + print( + f"========= {torch.cuda.current_device()}: tensors[{i}].shape:" + f" {tensor.shape} {tensor.dtype} {torch.isnan(tensor).any()} {torch.isinf(tensor).any()}" + ) assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index dc079a7193..10ab2dffe7 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -159,7 +159,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" ), # MHA "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), # GQA + "cp_2_0": ModelConfig( + 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0) + ), # GQA "cp_2_1": ModelConfig( 2, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal" ), # num_gqa_groups=4, attn_mask_type="causal"), # GQA @@ -240,8 +242,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): # "cp_4_2", ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} - dtypes = ["fp8"] #["bf16", 
"fp8"] - qkv_formats = ["bshd"]#, "sbhd", "thd"] + dtypes = ["fp8"] # ["bf16", "fp8"] + qkv_formats = ["bshd"] # , "sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @@ -327,7 +329,9 @@ def test_cp_with_fused_attention( # ): # pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") if f16_O and (dtype != "fp8" or scaling_mode not in ["current", "mxfp8"]): - pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode in [current, mxfp8]!") + pytest.skip( + "f16_O only needs to be tested for dtype = fp8 and scaling_mode in [current, mxfp8]!" + ) # if cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: # pytest.skip("MLA CP currently only support KV P2P!") # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: From 1e7cd70b4f8a34685cdaa951f8dfe97a7be0ed9b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 6 Mar 2026 18:42:35 -0800 Subject: [PATCH 058/172] tweak cp test skips Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/test_attention_with_cp.py | 164 ++++++++---------- .../attention/dot_product_attention/utils.py | 59 ++++++- 2 files changed, 121 insertions(+), 102 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 10ab2dffe7..c9cdc6baf8 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -102,25 +102,29 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config.context_parallel = True config.cp_comm_type = cp_comm_type - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type 
!= "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" - " yet!" - ) - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No support for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type in [ + "p2p", + "a2a+p2p", + ]: + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 + ): pytest.skip( - f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" 
) + + # FlashAttention / CP implementation specific: MLA only with KV P2P if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} @@ -260,99 +264,67 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): def test_cp_with_fused_attention( dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O ): - # # TODO: Remove this once MXFP8 is supported with fp8_bwd=True! - # if scaling_mode == "mxfp8" and fp8_bwd: - # pytest.skip("MXFP8 only works with fp8_bwd=False!") + config = model_configs_fused_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()} GPUs.") - if qkv_format == "thd" and get_device_compute_capability() < (9, 0): - pytest.skip("THD format is only supported on sm90+!") - if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): - pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") - if dtype == "fp8" and get_device_compute_capability() < (9, 0): - pytest.skip("FP8 attention is only supported on sm90+!") + if get_device_compute_capability() < (9, 0) and qkv_format == "thd": + pytest.skip("Only sm90+ architectures support THD format!") + if get_device_compute_capability() < (9, 0) and dtype == "fp8": + pytest.skip("Only sm90+ architectures support FP8 attention!") + + if dtype == "fp8" and not (fp8_mha or fp8_dpa): + pytest.skip("dtype=fp8 requires fp8_dpa=True or fp8_mha=True!") if dtype == "fp8" and not fp8_dpa and fp8_mha: pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") if dtype != "fp8" and fp8_bwd: - pytest.skip("Only fp8 works 
with fp8_bwd=True!") - - config = model_configs_fused_attn[model] - config.context_parallel = True - config.cp_comm_type = cp_comm_type + pytest.skip("fp8_bwd=True requires dtype=fp8!") + if dtype != "fp8" and (fp8_mha or fp8_dpa): + pytest.skip("dtype!=fp8 requires fp8_dpa=False and fp8_mha=False!") - if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": - pytest.skip("THD format does not support post_scale_bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" - " yet!" - ) - # if dtype == "fp8" and cp_comm_type == "all_gather": - # pytest.skip( - # "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" - # ) if dtype == "fp8" and qkv_format == "thd": - pytest.skip("FP8 attention cannot work with THD format yet!") + pytest.skip("No support for FP8 attention with THD format!") if dtype == "fp8" and config.attn_bias_type != "no_bias": - pytest.skip("FP8 attention cannot work with bias yet!") - # if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): - # pytest.skip("FP8 attention cannot work with sliding window yet!") - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): - pytest.skip( - f"CP implementation with QKVO A2A requires num_heads 
({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" - ) - if dtype != "fp8" and (fp8_mha or fp8_dpa): - pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!") - if dtype == "fp8" and not (fp8_mha or fp8_dpa): - pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!") - if dtype != "fp8" and scaling_mode is not None: - pytest.skip("Only fp8 works with scaling_mode != None!") - if dtype == "fp8" and scaling_mode is None: - pytest.skip("fp8 only works with scaling_mode != None!") - # if ( - # dtype == "fp8" - # and scaling_mode == "current" - # and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] - # ): - # pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") - if f16_O and (dtype != "fp8" or scaling_mode not in ["current", "mxfp8"]): + pytest.skip("No support for FP8 attention with bias!") + + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No supprt for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if (config.window_size != (-1, 0) or config.window_size != (-1, -1)) and cp_comm_type in ["p2p", "a2a+p2p"]: + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): pytest.skip( - "f16_O only needs to be tested for dtype = fp8 and scaling_mode in [current, mxfp8]!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" 
) - # if cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: - # pytest.skip("MLA CP currently only support KV P2P!") - # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: - # pytest.skip("MLA CP currently does not support FP8 attention!") - if dtype == "fp8" and config.softmax_type != "vanilla": - pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!") + + if config.softmax_type != "vanilla" and dtype == "fp8": + pytest.skip("No support for non-vanilla softmax with FP8 attention!") if config.softmax_type != "vanilla" and cp_comm_type != "a2a": - pytest.skip( - "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" - ) - if ( - get_cudnn_version() < (9, 18, 0) - and config.softmax_type != "vanilla" - and qkv_format == "thd" - ): - pytest.skip( - "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for" - " non-vanilla softmax types!" - ) + pytest.skip(f"No support for non-vanilla softmax with cp_comm_type={cp_comm_type}!") + if config.softmax_type != "vanilla" and qkv_format == "thd" and get_cudnn_version() < (9, 18, 0): + pytest.skip("No support for non-vanilla softmax with THD format and cuDNN < 9.18.0!") + + if dtype == "fp8" and scaling_mode is None: + pytest.skip("dtype=fp8 requires scaling_mode != None!") + if dtype != "fp8" and scaling_mode is not None: + pytest.skip("dtype!=fp8 requires scaling_mode = None!") + if dtype != "fp8" and not f16_O: + pytest.skip("dtype!=fp8 requires f16_O=True!") + if scaling_mode == "delayed" and f16_O: + pytest.skip("scaling_mode=delayed requires f16_O=False!") if scaling_mode == "mxfp8" and not f16_O: - pytest.skip("MXFP8 only works with f16_O=True!") + pytest.skip("scaling_mode=mxfp8 requires f16_O=True!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py 
b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 20ae4d135d..69a9ee9e03 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -835,12 +835,59 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt " bias for THD format" ) use_fused_attention = False - # elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: - # logger.debug( - # "Disabling FusedAttention as it does not support context parallelism with FP8" - # " MLA attention" - # ) - # use_fused_attention = False + elif fp8 and qkv_format == "thd": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with FP8" + " attention and THD format" + ) + use_fused_attention = False + elif fp8 and core_attention_bias_type != "no_bias": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with FP8" + " attention and bias" + ) + use_fused_attention = False + + elif core_attention_bias_type != "no_bias" and cp_comm_type in [ + "all_gather", + "a2a", + "a2a+p2p", + ]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with bias" + " and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with THD" + " format and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif ( + window_size is not None + and (window_size != (-1, 0) or window_size != (-1, -1)) + and cp_comm_type in ["p2p", "a2a+p2p"] + ): + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with sliding" + " window attention and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif cp_comm_type in ["a2a", "a2a+p2p"] and ( + num_heads % 2 != 0 
or num_gqa_groups % 2 != 0 + ): + logger.debug( + "Disabling FusedAttention as cp_comm_type = %s requires num_heads and" + " num_gqa_groups divisible by 2 (got num_heads = %s, num_gqa_groups = %s)", + cp_comm_type, + num_heads, + num_gqa_groups, + ) + use_fused_attention = False # Filter: Attention mask # attn_mask_type | attention_mask | supported backends From 6d7766a4730f9dab47ede2d41249e2fb6c618fff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 7 Mar 2026 02:44:45 +0000 Subject: [PATCH 059/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/test_attention_with_cp.py | 28 ++++++++++++++----- .../attention/dot_product_attention/utils.py | 4 +-- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index c9cdc6baf8..116d4dcc41 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -110,10 +110,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") - if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type in [ - "p2p", - "a2a+p2p", - ]: + if ( + config.window_size != (-1, 0) + and config.window_size != (-1, -1) + and cp_comm_type + in [ + "p2p", + "a2a+p2p", + ] + ): pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") if cp_comm_type in ["a2a", "a2a+p2p"] and ( @@ -299,10 +304,15 @@ def test_cp_with_fused_attention( if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") - if (config.window_size != (-1, 0) or config.window_size != (-1, 
-1)) and cp_comm_type in ["p2p", "a2a+p2p"]: + if (config.window_size != (-1, 0) or config.window_size != (-1, -1)) and cp_comm_type in [ + "p2p", + "a2a+p2p", + ]: pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") - if cp_comm_type in ["a2a", "a2a+p2p"] and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 + ): pytest.skip( f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" @@ -312,7 +322,11 @@ def test_cp_with_fused_attention( pytest.skip("No support for non-vanilla softmax with FP8 attention!") if config.softmax_type != "vanilla" and cp_comm_type != "a2a": pytest.skip(f"No support for non-vanilla softmax with cp_comm_type={cp_comm_type}!") - if config.softmax_type != "vanilla" and qkv_format == "thd" and get_cudnn_version() < (9, 18, 0): + if ( + config.softmax_type != "vanilla" + and qkv_format == "thd" + and get_cudnn_version() < (9, 18, 0) + ): pytest.skip("No support for non-vanilla softmax with THD format and cuDNN < 9.18.0!") if dtype == "fp8" and scaling_mode is None: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 69a9ee9e03..84f676539b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -877,9 +877,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt cp_comm_type, ) use_fused_attention = False - elif cp_comm_type in ["a2a", "a2a+p2p"] and ( - num_heads % 2 != 0 or num_gqa_groups % 2 != 0 - ): + elif cp_comm_type in ["a2a", "a2a+p2p"] and (num_heads % 2 != 0 or num_gqa_groups % 2 != 0): logger.debug( "Disabling FusedAttention as cp_comm_type = %s requires num_heads and" " num_gqa_groups divisible by 2 
(got num_heads = %s, num_gqa_groups = %s)", From 6d33db80c058fc1b53ba4a30c378486877204922 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 11 Mar 2026 10:35:00 -0700 Subject: [PATCH 060/172] update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index b4370f5198..fb3f58b2f8 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit b4370f5198bd95ee758ebc2c6b76b887914b702d +Subproject commit fb3f58b2f8b47c2b305586bb4d7fcd007eb33839 From 92e6aaca36db8547979da565fceb9eee09225eb0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 12 Mar 2026 11:11:39 -0700 Subject: [PATCH 061/172] fix bwd KV tensors Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn_fp8.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 9796e39ddc..be491e944e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2831,8 +2831,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrK = input_K->data.dptr; void* devPtrV = input_V->data.dptr; void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_Q->scale_inv.dptr; - void* devPtrDescaleV = input_Q->scale_inv.dptr; + void* devPtrDescaleK = input_K->scale_inv.dptr; + void* devPtrDescaleV = input_V->scale_inv.dptr; void* devPtrQ_t = input_Q->columnwise_data.dptr; void* devPtrK_t = input_K->columnwise_data.dptr; void* devPtrDescaleQ_t = input_Q->columnwise_scale_inv.dptr; From 
3cb6f0e11c1f2ee4dddcd6d201c5fa70dbb64661 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 12 Mar 2026 15:08:46 -0700 Subject: [PATCH 062/172] tweak recipe control and backend selection Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/test_attention_with_cp.py | 2 +- .../dot_product_attention.py | 57 ++++++++- .../attention/dot_product_attention/utils.py | 110 +++++++++--------- 3 files changed, 112 insertions(+), 57 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 116d4dcc41..d83bb62ac4 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -304,7 +304,7 @@ def test_cp_with_fused_attention( if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") - if (config.window_size != (-1, 0) or config.window_size != (-1, -1)) and cp_comm_type in [ + if (config.window_size[0] != -1 and config.window_size[1] not in [-1, 0]) and cp_comm_type in [ "p2p", "a2a+p2p", ]: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index ba24aa658e..369e4f03f0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -30,7 +30,7 @@ Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.export import 
is_in_onnx_export_mode from transformer_engine.pytorch.constants import ( @@ -110,6 +110,13 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ +| MXFP8 | FP8DS | Pass MXFP8 to autocast(); | +| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear MXFP8; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ | NVFP4 | FP8DS | Pass NVFP4 to autocast(); | | | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | @@ -130,6 +137,14 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ +| MXFP8 | FP8CS | Pass MXFP8 to autocast(); | +| | | Attention creates a new FP8CS recipe based on fp8_format, fp8_dpa, fp8_mha from | +| | | linear MXFP8, and: | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ | NVFP4 | FP8CS | Pass NVFP4 to autocast(); | | | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | | | | for S, dP, based on 
the fp8_dpa, fp8_mha values from linear NVFP4 and: | @@ -139,6 +154,13 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS/FP8CS | MXFP8 | Pass FP8DS/FP8CS to autocast(); | +| | | Attention creates a new MXFP8 recipe based on fp8_format, fp8_dpa, fp8_mha from | +| | | linear FP8DS/FP8CS | +| | | export NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling" # switch to MXFP8BS | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| MXFP8 | MXFP8 | Pass MXFP8 to autocast(); | ++-------------------+-----------+-----------------------------------------------------------------------------------+ | NVFP4 | MXFP8 | Pass NVFP4 to autocast(); | | | | Attention MXFP8 reuses the fp8_dpa, fp8_mha values from linear NVFP4; | | | | export NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling" # switch to MXFP8BS | @@ -605,7 +627,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # ignore the recipe from autocast, set fp8_dpa = False, fp8_mha = False fp8_recipe.fp8_dpa = False fp8_recipe.fp8_mha = False - elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling": + elif (fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8()) and _dpa_fp8_recipe == "DelayedScaling": # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe fake_recipe = DelayedScaling( fp8_format=fp8_recipe.fp8_format, @@ -658,6 +680,25 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ) fp8_recipe_dpa = fake_recipe fp8_recipes = [fp8_recipe, fp8_recipe_dpa] + elif fp8_recipe.mxfp8() and _dpa_fp8_recipe == "Float8CurrentScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe + fake_recipes = [ + Float8CurrentScaling( + fp8_format=fp8_recipe.fp8_format, + 
fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling": # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format # construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP @@ -678,6 +719,15 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes + elif (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()) and _dpa_fp8_recipe == "MXFP8BlockScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a MXFP8 recipe + fake_recipe = MXFP8BlockScaling( + fp8_format=fp8_recipe.fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "MXFP8BlockScaling": # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a MXFP8 recipe fake_recipe = MXFP8BlockScaling( @@ -1212,7 +1262,7 @@ def forward( cu_seqlens_kv_padded = None # get qkv's memory layout - if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): + if all(isinstance(x, Float8TensorStorage) for x in [query_layer, key_layer, value_layer]): ( qkv_layout, query_layer._data, @@ -1374,6 +1424,7 @@ def forward( attention_dropout=self.attention_dropout, context_parallel=context_parallel, cp_comm_type=self.cp_comm_type, + cp_size=cp_size, deterministic=self.deterministic, is_training=self.training, fp8=self.fp8, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 
84f676539b..af4160ebb6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -35,7 +35,7 @@ META_DP, ) from transformer_engine.pytorch.attention.inference import InferenceParams -from transformer_engine.pytorch.quantized_tensor import QuantizedTensor +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Tensor, Float8Quantizer, @@ -224,6 +224,8 @@ class AttentionParams: Whether context parallelism is used or not. cp_comm_type : str, default = "p2p" The communication type of context parallelism. + cp_size : int, default = 1 + The group size of context parallelism. deterministic : bool, default = False Whether to run `DotProductAttention` with determinism or not. is_training : bool, default = True @@ -265,6 +267,7 @@ class AttentionParams: attention_dropout: float = 0.0 context_parallel: bool = False cp_comm_type: str = "p2p" + cp_size: int = 1 deterministic: bool = False is_training: bool = True fp8: bool = False @@ -342,6 +345,7 @@ def get_attention_backend( attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel cp_comm_type = attention_params.cp_comm_type + cp_size = attention_params.cp_size deterministic = attention_params.deterministic is_training = attention_params.is_training fp8 = attention_params.fp8 @@ -363,6 +367,7 @@ def get_attention_backend( "transformer_engine_version": te.__version__, "compute_capability": "sm" + str(10 * device_compute_capability[0] + device_compute_capability[1]), + "cuda_version": torch.version.cuda, "flash_attn_version": ( str(FlashAttentionUtils.version) if FlashAttentionUtils.is_installed @@ -450,15 +455,15 @@ def get_attention_backend( qkv_dtype, ) use_flash_attention_2 = False - if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in [ + if 
qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or issubclass(qkv_type, ( torch.Tensor, - Float8Tensor, - ]: + QuantizedTensorStorage, + )): if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: logger.debug( "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s. " "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, Float8Tensor}. ", + "qkv_type = {torch.Tensor, QuantizedTensorStorage}. ", qkv_dtype, qkv_type, ) @@ -467,17 +472,17 @@ def get_attention_backend( logger.debug( "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. " "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, Float8Tensor}. ", + "qkv_type = {torch.Tensor, QuantizedTensorStorage}. ", qkv_dtype, qkv_type, ) use_fused_attention = False # Filter: Execution type - if fp8 and fp8_meta["recipe"].fp8_dpa: - fp8_recipe = fp8_meta["recipe"] - if fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] + fp8_recipe = fp8_meta["recipe"] + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + if fp8 and fp8_recipe.fp8_dpa: if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug("Disabling FlashAttention 2 for FP8 attention") use_flash_attention_2 = False @@ -485,9 +490,9 @@ def get_attention_backend( if FlashAttentionUtils.v3_is_installed: logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False - if use_flash_attention_3 and fp8_recipe.mxfp8(): + if use_flash_attention_3 and not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()): if FlashAttentionUtils.v3_is_installed: - logger.debug("Disabling FlashAttention 3 for MXFP8") + logger.debug(f"Disabling FlashAttention 3 for {fp8_recipe.__class__.__name__}") use_flash_attention_3 = False if use_unfused_attention: allow_emulation = ( @@ 
-496,12 +501,14 @@ def get_attention_backend( if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False + if use_fused_attention and fp8_recipe.delayed(): + if device_compute_capability >= (10, 0) and deterministic and cudnn_version < (9, 18, 0): + logger.debug("Disabling FusedAttention for FP8 delayed scaling on arch >= sm100 with determinism for cuDNN < 9.18.0") + use_fused_attention = False if use_fused_attention and fp8_recipe.float8_current_scaling(): if device_compute_capability < (10, 0): logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") use_fused_attention = False - # TODO(cyanguwa): Modify the min cuDNN version supporting FP8 current scaling - # determinism for Blackwell else: if cudnn_version < (9, 14, 0): logger.debug( @@ -511,8 +518,8 @@ def get_attention_backend( else: if deterministic and cudnn_version < (9, 18, 0): logger.debug( - "Disabling FusedAttention for FP8 current scaling requiring determinism" - " with cuDNN < 9.18.0" + "Disabling FusedAttention for FP8 current scaling with determinism" + " for cuDNN < 9.18.0" ) use_fused_attention = False if use_fused_attention and fp8_recipe.mxfp8(): @@ -526,6 +533,9 @@ def get_attention_backend( elif qkv_format == "thd": logger.debug("Disabling FusedAttention for MXFP8 with qkv_format = thd") use_fused_attention = False + if use_fused_attention and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()): + logger.debug(f"Disabling FusedAttention for {fp8_recipe.__class__.__name__}") + use_fused_attention = False if device_compute_capability == (12, 0): if use_flash_attention: @@ -558,7 +568,7 @@ def get_attention_backend( if use_flash_attention: use_flash_attention = False logger.debug("Disabling FlashAttention for max_logit") - if fp8 and fp8_meta["recipe"].fp8_dpa: + if fp8 and fp8_recipe.fp8_dpa: use_flash_attention = False use_fused_attention = False use_unfused_attention = False @@ -583,8 +593,8 @@ def 
get_attention_backend( use_flash_attention = False use_fused_attention = False use_unfused_attention = False - if fp8 and fp8_meta["recipe"].fp8_dpa: - if fp8_meta["recipe"].fp8_mha: + if fp8 and fp8_recipe.fp8_dpa: + if fp8_recipe.fp8_mha: logger.debug("Disabling all backends for KV caching with FP8 MHA") use_flash_attention = False use_fused_attention = False @@ -618,9 +628,9 @@ def get_attention_backend( # Filter: Head dimension if head_dim_qk != head_dim_v: - # if use_flash_attention_2 and FlashAttentionUtils.is_installed: - # logger.debug("Disabling FlashAttention 2 as it does not support MLA.") - # use_flash_attention_2 = False + if use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 as it does not support MLA.") + use_flash_attention_2 = False qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") if use_fused_attention and qkv_layout_group != "hd_hd_hd": @@ -734,30 +744,28 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if softmax_type != "vanilla": logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) use_flash_attention = False - if fp8 and fp8_meta["recipe"].fp8_dpa: + if fp8 and fp8_recipe.fp8_dpa: logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type) use_fused_attention = False logger.debug( "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type ) use_unfused_attention = False - if qkv_format == "thd": - if cudnn_version < (9, 18, 0): - logger.debug( - "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" - " version < 9.18", - softmax_type, - ) - use_fused_attention = False - if context_parallel: - if cp_comm_type != "a2a": - logger.debug( - "Disabling FusedAttention for context parallelism with softmax_type = %s and" - " cp_comm_type = %s", - softmax_type, - cp_comm_type, - ) - use_fused_attention = False + if qkv_format == "thd" and 
cudnn_version < (9, 18, 0): + logger.debug( + "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" + " version < 9.18", + softmax_type, + ) + use_fused_attention = False + if context_parallel and cp_comm_type != "a2a": + logger.debug( + "Disabling FusedAttention for context parallelism with softmax_type = %s and" + " cp_comm_type = %s", + softmax_type, + cp_comm_type, + ) + use_fused_attention = False # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends @@ -778,7 +786,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_unfused_attention = False if context_parallel and (use_flash_attention_2 or use_flash_attention_3): if FlashAttentionUtils.is_installed or FlashAttentionUtils.v3_is_installed: - if fp8 and fp8_meta["recipe"].fp8_dpa: + if fp8 and fp8_recipe.fp8_dpa: logger.debug( "Disabling FlashAttention as it does not support context parallelism with FP8" ) @@ -835,24 +843,19 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt " bias for THD format" ) use_fused_attention = False - elif fp8 and qkv_format == "thd": + elif fp8 and fp8_recipe.fp8_dpa and qkv_format == "thd": logger.debug( "Disabling FusedAttention as it does not support context parallelism with FP8" " attention and THD format" ) use_fused_attention = False - elif fp8 and core_attention_bias_type != "no_bias": + elif fp8 and fp8_recipe.fp8_dpa and core_attention_bias_type != "no_bias": logger.debug( "Disabling FusedAttention as it does not support context parallelism with FP8" " attention and bias" ) use_fused_attention = False - - elif core_attention_bias_type != "no_bias" and cp_comm_type in [ - "all_gather", - "a2a", - "a2a+p2p", - ]: + elif core_attention_bias_type != "no_bias" and cp_comm_type != "p2p": logger.debug( "Disabling FusedAttention as it does not support context parallelism with bias" " and cp_comm_type = %s", @@ -868,7 +871,7 @@ def 
_is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_fused_attention = False elif ( window_size is not None - and (window_size != (-1, 0) or window_size != (-1, -1)) + and window_size[0] != -1 and window_size[1] not in [-1, 0] and cp_comm_type in ["p2p", "a2a+p2p"] ): logger.debug( @@ -1040,8 +1043,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if use_fused_attention: q_type = TE_DType[qkv_dtype] kv_type = q_type - if fp8 and fp8_meta["recipe"].fp8_dpa: - q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + if fp8 and fp8_recipe.fp8_dpa: + q_type = get_fp8_te_dtype(fp8_recipe, fprop_tensor=True) kv_type = q_type fused_attention_backend = tex.get_fused_attn_backend( is_training, @@ -1072,6 +1075,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_fused_attention and window_size is not None and window_size[0] != -1 + and window_size[1] not in [-1, 0] and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] ): logger.debug( @@ -2364,7 +2368,7 @@ def combine_and_dequantize( qkv_layout = qkv_layout.replace("paged_kv_", "") qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) qkv_group = len(qkv_layout.split("_")) - if all(isinstance(x, QuantizedTensor) for x in [q_fp8, k_fp8, v_fp8]): + if all(isinstance(x, QuantizedTensorStorage) for x in [q_fp8, k_fp8, v_fp8]): src_nominal_dtype = q_fp8.dtype else: assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!" 
From c57ece4318b5851d67938f2e572b0725136e980c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 12 Mar 2026 15:09:25 -0700 Subject: [PATCH 063/172] tweak quantizer logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/utils.py | 62 +++++++++---------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index af4160ebb6..7be6e653d6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2164,7 +2164,7 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): return [None] * 6 QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - QKV_quantizer.internal = False + QKV_quantizer.internal = True QKV_quantizer.set_usage(rowwise=True, columnwise=False) S_quantizer = quantizers["scaling_fwd"][META_S] @@ -2176,7 +2176,7 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): O_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.internal = False + dO_quantizer.internal = True dO_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer = quantizers["scaling_bwd"][META_DP] @@ -2184,7 +2184,7 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): dP_quantizer.set_usage(rowwise=True, columnwise=False) dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = False + dQKV_quantizer.interal = True dQKV_quantizer.set_usage(rowwise=True, columnwise=False) if fp8_recipe.mxfp8(): @@ -2262,62 +2262,60 @@ def print_quantizers( def permute_to_grouped_tensor(src_format, tensor): - """Permute tensor to bhsd or htd format for grouped quantization in MXFP8BlockScaling. 
src_format ={bshd, sbhd, thd}""" + """Permute tensor from src_format = {bshd, sbhd, thd} to des_format = {bhsd, htd} for MXFP8 quantization.""" if src_format in ["bhsd", "htd"]: return tensor, src_format + des_format = "bhsd" if src_format != "thd" else "htd" + # make tensor contiguous bshd/sbhd/thd tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + # permute bshd/sbhd to bhsd, and thd to htd dim_s_or_t = src_format.find("s") if "s" in src_format else src_format.find("t") dim_others = [i for i in range(len(tensor.shape)) if i != dim_s_or_t] - perm = [*dim_others[:-1], dim_s_or_t, dim_others[-1]] - tensor = tensor.permute(*perm).contiguous() - return tensor, "bhsd" if src_format != "thd" else "htd" + new_dims = [*dim_others[:-1], dim_s_or_t, dim_others[-1]] + tensor = tensor.permute(*new_dims).contiguous() + return tensor, des_format def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_layout = qkv_layout.replace("paged_kv_", "") - qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) - qkv_group = len(qkv_layout.split("_")) - src_nominal_dtype = q.dtype if isinstance(qkv_quantizer, MXFP8Quantizer): - # bs3hd, sb3hd, etc -> bshd_bshd_bhsd -> bhsd_bhsd_bhsd - # t3hd, etc -> thd_thd_thd -> htd_htd_htd + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) + # permute q, k, v to bhsd/htd format if q_format not in ["bhsd", "htd"]: q, _ = permute_to_grouped_tensor(q_format, q) if kv_format not in ["bhsd", "htd"]: k, _ = permute_to_grouped_tensor(kv_format, k) v, _ = permute_to_grouped_tensor(kv_format, v) qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" - + # check shapes original_shapes = [x.shape for x in [q, k, v]] s_q, d_qk = q.shape[-2:] s_kv, d_v = v.shape[-2:] - assert s_q % 128 == 0 - assert s_kv % 128 == 0 - assert d_qk % 32 == 0 - assert d_v % 32 == 0 - # need to check 
seqlens in THD % 128 == 0 + assert ( + s_q % 128 == 0 and s_kv % 128 == 0 and d_qk % 32 == 0 and d_v % 32 == 0 + ), f"MXFP8 quantization requires s_q % 128 == 0, s_kv % 128 == 0, d_qk % 32 == 0, d_v % 32 == 0. Found {s_q=}, {s_kv=}, {d_qk=}, {d_v=}." q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] - - # consider bhsd for now + # quantize q, k, v if d_qk == d_v: grouped_tensor = GroupedTensor.create_and_quantize( tensors=[q, k, v], quantizer=qkv_quantizer ) q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors else: - # grouped_tensor = GroupedTensor.create_and_quantize( - # tensors=[q, k], quantizer=qkv_quantizer - # ) - # q_fp8, k_fp8 = grouped_tensor.quantized_tensors - q_fp8 = qkv_quantizer(q) - k_fp8 = qkv_quantizer(k) + grouped_tensor = GroupedTensor.create_and_quantize( + tensors=[q, k], quantizer=qkv_quantizer + ) + q_fp8, k_fp8 = grouped_tensor.quantized_tensors v_fp8 = qkv_quantizer(v) + # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] return q_fp8, k_fp8, v_fp8, qkv_layout + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_group = len(qkv_layout.split("_")) + src_nominal_dtype = q.dtype + # 1: qkv packed, 2: kv packed, 3: qkv separate match qkv_group: case 1: dim = qkv_layout.find("3") @@ -2364,10 +2362,6 @@ def combine_and_dequantize( qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None ): """Combine q,k,v based on qkv_layout and dequantize them together""" - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_layout = qkv_layout.replace("paged_kv_", "") - qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) - qkv_group = len(qkv_layout.split("_")) if all(isinstance(x, QuantizedTensorStorage) for x in [q_fp8, k_fp8, v_fp8]): src_nominal_dtype = q_fp8.dtype else: @@ -2379,7 +2373,11 @@ def combine_and_dequantize( q, k, v = [x.dequantize(dtype=des_nominal_dtype) for x 
in [q_fp8, k_fp8, v_fp8]] return q, k, v + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) + qkv_group = len(qkv_layout.split("_")) q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] + # 1: qkv packed, 2: kv packed, 3: qkv separate match qkv_group: case 1: dim = qkv_layout.find("3") From 87a7e1e887078629844852a646dbc0a6d587ca09 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 12 Mar 2026 18:25:26 -0700 Subject: [PATCH 064/172] minor fixes after last two commits Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../include/transformer_engine/fused_attn.h | 14 ++++------ .../attention/dot_product_attention/utils.py | 28 +++++++++++-------- .../pytorch/cpp_extensions/fused_attn.py | 12 -------- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/attention.cpp | 4 +-- 5 files changed, 26 insertions(+), 34 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 90393ce8c8..04f7ec4a6c 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -198,14 +198,12 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] src_format The source format. * \param[in] src_shape The source shape. * \param[in] dst_format The destination format. - * \param[in,out] dst_shape The destination shape. - * \param[in,out] b The batch size. - * \param[in,out] h The number of heads. - * \param[in,out] s The sequence length. - * \param[in,out] d The head dimension. - * \param[in,out] t The time dimension. - * - * \return The destination shape. + * \param[out] dst_shape The destination shape. + * \param[out] b The batch size. + * \param[out] h The number of heads. 
+ * \param[out] s The sequence length. + * \param[out] d The head dimension. + * \param[out] t The number of tokens. */ void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7be6e653d6..c8b336e9a3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -41,6 +41,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.float8_tensor import Float8TensorStorage from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor @@ -455,15 +456,18 @@ def get_attention_backend( qkv_dtype, ) use_flash_attention_2 = False - if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or issubclass(qkv_type, ( + if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in ( torch.Tensor, - QuantizedTensorStorage, - )): + Float8Tensor, + Float8TensorStorage, + MXFP8Tensor, + MXFP8TensorStorage, + ): if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: logger.debug( "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s. " "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, QuantizedTensorStorage}. ", + "qkv_type = {torch.Tensor, Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage}. ", qkv_dtype, qkv_type, ) @@ -472,16 +476,18 @@ def get_attention_backend( logger.debug( "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. 
" "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, QuantizedTensorStorage}. ", + "qkv_type = {torch.Tensor, Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage}. ", qkv_dtype, qkv_type, ) use_fused_attention = False # Filter: Execution type - fp8_recipe = fp8_meta["recipe"] - if fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] + fp8_recipe = None + if fp8: + fp8_recipe = fp8_meta["recipe"] if fp8_meta is not None else None + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] if fp8 and fp8_recipe.fp8_dpa: if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug("Disabling FlashAttention 2 for FP8 attention") @@ -2164,7 +2170,7 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): return [None] * 6 QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - QKV_quantizer.internal = True + QKV_quantizer.internal = False QKV_quantizer.set_usage(rowwise=True, columnwise=False) S_quantizer = quantizers["scaling_fwd"][META_S] @@ -2176,7 +2182,7 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): O_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.internal = True + dO_quantizer.internal = False dO_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer = quantizers["scaling_bwd"][META_DP] @@ -2184,7 +2190,7 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): dP_quantizer.set_usage(rowwise=True, columnwise=False) dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = True + dQKV_quantizer.interal = False dQKV_quantizer.set_usage(rowwise=True, columnwise=False) if fp8_recipe.mxfp8(): diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 7a756ead1c..16cc55fcd1 100644 --- 
a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -364,7 +364,6 @@ def fused_attn_bwd( o: torch.Tensor, d_o: torch.Tensor, fake_dtype: torch.dtype, - # dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, cu_seqlens_q_padded: torch.Tensor = None, @@ -417,8 +416,6 @@ def fused_attn_bwd( fake_dtype : tex.DType data type of Q, K and V - in case of high precision, fake dtype in case of FP8; in torch.dtype - # dqkv_dtype : tex.DType - # data type of dQ, dK and dV; in tex.DType, not torch.dtype aux_ctx_tensors : List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. aux_ctx_tensors = [M, ZInv, rng_state] @@ -507,14 +504,6 @@ def fused_attn_bwd( len(aux_ctx_tensors) >= 1 ), "aux_ctx_tensors must contain rng_state as its last element." - if fused_attention_backend == FusedAttnBackend["FP8"]: - # assert ( - # dqkv_dtype is not None - # ), "dqkv_dtype is required as an input for FP8 fused attention backward." - assert ( - len(aux_ctx_tensors) >= 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." 
- output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, @@ -539,7 +528,6 @@ def fused_attn_bwd( o, d_o, fake_dtype, - # dqkv_dtype, aux_ctx_tensors, cu_seqlens_q_padded, cu_seqlens_kv_padded, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 95c985062a..67aae997f8 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -104,7 +104,7 @@ std::vector fused_attn_bwd( bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, //const DType dqkv_type, + const at::ScalarType fake_dtype, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 192a774ca0..a115312bf5 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -105,7 +105,7 @@ std::pair quantizer_helper(py::handle quantizer, std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); NVTE_CHECK( !data.has_value(), - "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); + "MXFP8Quantizer::create_tensor() does not take data tensor as input!"); } } return {std::move(te_T), std::move(py_T)}; @@ -334,7 +334,7 @@ std::vector fused_attn_bwd( bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, //const DType dqkv_type, + const at::ScalarType fake_dtype, const std::vector 
Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, From 3b015f365a255f62b2fae03bae67f68bb36d3b94 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 12 Mar 2026 18:25:48 -0700 Subject: [PATCH 065/172] improve generate strides Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 45 +-- .../common/fused_attn/fused_attn_fp8.cu | 174 ++++---- transformer_engine/common/fused_attn/utils.h | 374 ++++++++---------- 3 files changed, 242 insertions(+), 351 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 72c5273a78..046b18aae9 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -570,25 +570,14 @@ void nvte_fused_attn_fwd( Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; + size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; - } - size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim - 3] - : input_Q->data.shape[ndim - 2]; - size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? 
input_K->data.shape[ndim_kv - 3] - : input_K->data.shape[ndim_kv - 2]; + std::vector tmp_shape(4); + nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, &s_q, &d_qk, &t_q); + nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); + nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); + int64_t num_pages_k = 0; int64_t num_pages_v = 0; int64_t page_size_k = 0; @@ -699,25 +688,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; + size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; - } - size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim - 3] - : input_Q->data.shape[ndim - 2]; - size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? 
input_K->data.shape[ndim_kv - 3] - : input_K->data.shape[ndim_kv - 2]; + std::vector tmp_shape(4); + nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, &s_q, &d_qk, &t_q); + nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); + nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index be491e944e..8b395e3deb 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1774,29 +1774,24 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr dropout_seed, dropout_offset; // Q, K, V, attn_scale - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, v_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + std::vector q_strides(4); + std::vector k_strides(4); + std::vector v_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), k_strides.data(), v_strides.data(), qkv_layout); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) + .set_stride(q_strides) .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) + .set_stride(k_strides) .set_data_type(qkv_tensor_type)); V = 
mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) + .set_stride(v_strides) .set_data_type(qkv_tensor_type)); attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") @@ -1831,11 +1826,11 @@ void fused_attn_fp8_fwd_impl_v1( std::vector v_scale_strides(4); auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, - q_scale_strides.data(), q_format, false); + q_scale_strides.data(), q_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, - k_scale_strides.data(), kv_format, false); + k_scale_strides.data(), kv_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, - v_scale_strides.data(), kv_format, false); + v_scale_strides.data(), kv_format); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -1916,12 +1911,7 @@ void fused_attn_fp8_fwd_impl_v1( } std::shared_ptr O, Stats, amax_s, amax_o; - if (is_mxfp8) { - auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, sdpa_options); - O = outputs[0]; - Stats = outputs[1]; - amax_o = outputs[2]; - } else { + if (is_delayed_scaling || is_current_scaling) { auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); O = outputs[0]; @@ -1932,13 +1922,18 @@ void fused_attn_fp8_fwd_impl_v1( .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); + } else if (is_mxfp8) { + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_o = outputs[2]; } - std::vector o_stride(4); - generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); + std::vector o_strides(4); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), 
o_format); O->set_output(true) .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) + .set_stride(o_strides) .set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) @@ -2048,6 +2043,7 @@ void fused_attn_fp8_fwd_impl_v1( variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -2213,41 +2209,36 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr dropout_seed, dropout_offset; // Q, K, V, O, dO, stats, attn_scale - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - std::vector o_stride(4); - generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, v_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); + std::vector q_strides(4); + std::vector k_strides(4); + std::vector v_strides(4); + std::vector o_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), k_strides.data(), v_strides.data(), qkv_layout); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), o_format); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) + .set_stride(q_strides) .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) + .set_stride(k_strides) .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) + 
.set_stride(v_strides) .set_data_type(qkv_tensor_type)); O = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) + .set_stride(o_strides) .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) + .set_stride(o_strides) .set_data_type(do_tensor_type)); Stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Stats") @@ -2298,9 +2289,9 @@ void fused_attn_fp8_bwd_impl_v1( std::vector q_t_stride(4); std::vector k_t_stride(4); std::vector dO_t_stride(4); - generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); - generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); - generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2319,7 +2310,7 @@ void fused_attn_fp8_bwd_impl_v1( dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO_f16") .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) + .set_stride(o_strides) .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); @@ -2331,19 +2322,19 @@ void fused_attn_fp8_bwd_impl_v1( std::vector dO_scale_strides(4); std::vector dO_t_scale_strides(4); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, - q_scale_strides.data(), q_format, false); + q_scale_strides.data(), q_format); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, 
padded.d_qk_padded, - q_t_scale_strides.data(), q_format, false); + q_t_scale_strides.data(), q_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, - k_scale_strides.data(), kv_format, false); + k_scale_strides.data(), kv_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, - k_t_scale_strides.data(), kv_format, false); + k_t_scale_strides.data(), kv_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, - v_scale_strides.data(), kv_format, false); + v_scale_strides.data(), kv_format); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, - dO_scale_strides.data(), d_out_format, false); + dO_scale_strides.data(), d_out_format); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, - dO_t_scale_strides.data(), d_out_format, false); + dO_t_scale_strides.data(), d_out_format); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -2490,26 +2481,21 @@ void fused_attn_fp8_bwd_impl_v1( amax_dK = outputs[4]; amax_dV = outputs[5]; } - std::vector dq_stride(4); - std::vector dk_stride(4); - std::vector dv_stride(4); - generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, dq_stride.data(), dqkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dk_stride.data(), dqkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + std::vector dq_strides(4); + std::vector dk_strides(4); + std::vector dv_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, dq_strides.data(), dk_strides.data(), dv_strides.data(), dqkv_layout); dQ->set_output(true) .set_dim({b, h, s_q, d_qk}) - .set_stride(dq_stride) + .set_stride(dq_strides) .set_data_type(dqkv_tensor_type); dK->set_output(true) 
.set_dim({b, hg, s_kv, d_qk}) - .set_stride(dk_stride) + .set_stride(dk_strides) .set_data_type(dqkv_tensor_type); dV->set_output(true) .set_dim({b, hg, s_kv, d_v}) - .set_stride(dv_stride) + .set_stride(dv_strides) .set_data_type(dqkv_tensor_type); amax_dQ->set_output(true) .set_dim({1, 1, 1, 1}) @@ -2523,7 +2509,7 @@ void fused_attn_fp8_bwd_impl_v1( .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - if (!is_mxfp8) { + if (is_delayed_scaling || is_current_scaling) { amax_dP->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2643,9 +2629,9 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[dO_t] = devPtrdO_t; variant_pack[descale_q_t] = devPtrDescaleQ_t; variant_pack[descale_k_t] = devPtrDescaleK_t; - // variant_pack[descale_dO] = devPtrDescaledO; variant_pack[descale_dO_t] = devPtrDescaledO_t; } + /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { @@ -2698,50 +2684,35 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - void* devPtrQ = nullptr; - void* devPtrK = nullptr; - void* devPtrV = nullptr; - void* devPtrDescaleQ = nullptr; - void* devPtrDescaleK = nullptr; - void* devPtrDescaleV = nullptr; - void* devPtrO = nullptr; - void* devPtrAmaxO = nullptr; - void* devPtrScaleO = nullptr; - void* devPtrAmaxS = nullptr; - void* devPtrScaleS = nullptr; - void* devPtrDescaleS = nullptr; - if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { - devPtrQ = input_Q->data.dptr; - devPtrDescaleQ = input_Q->scale_inv.dptr; - devPtrK = input_K->data.dptr; - devPtrDescaleK = input_K->scale_inv.dptr; - // devPtrV = input_V->data.dptr; - // devPtrDescaleV = input_V->scale_inv.dptr; - devPtrV = input_V->columnwise_data.dptr; - devPtrDescaleV = input_V->columnwise_scale_inv.dptr; - devPtrO = output_O->data.dptr; - devPtrAmaxO = 
output_O->amax.dptr; - } else { - devPtrQ = input_Q->data.dptr; - devPtrDescaleQ = input_Q->scale_inv.dptr; - devPtrK = input_K->data.dptr; - devPtrDescaleK = input_K->scale_inv.dptr; + void* devPtrQ = nullptr, *devPtrK = nullptr, *devPtrV = nullptr; + void* devPtrDescaleQ = nullptr, *devPtrDescaleK = nullptr, *devPtrDescaleV = nullptr; + void* devPtrO = nullptr, *devPtrAmaxO = nullptr, *devPtrScaleO = nullptr; + void* devPtrAmaxS = nullptr, *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr; + devPtrQ = input_Q->data.dptr; + devPtrDescaleQ = input_Q->scale_inv.dptr; + devPtrK = input_K->data.dptr; + devPtrDescaleK = input_K->scale_inv.dptr; + devPtrO = output_O->data.dptr; + devPtrAmaxO = output_O->amax.dptr; + if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { devPtrV = input_V->data.dptr; devPtrDescaleV = input_V->scale_inv.dptr; - devPtrO = output_O->data.dptr; - devPtrAmaxO = output_O->amax.dptr; devPtrScaleO = output_O->scale.dptr; devPtrAmaxS = input_output_S->amax.dptr; devPtrScaleS = input_output_S->scale.dptr; devPtrDescaleS = input_output_S->scale_inv.dptr; + } else if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrV = input_V->columnwise_data.dptr; + devPtrDescaleV = input_V->columnwise_scale_inv.dptr; } void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + int i=0; + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; output_M->data.dtype = DType::kFloat32; @@ 
-2752,9 +2723,10 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + int i=0; + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrM = output_M->data.dptr; devPtrZInv = output_ZInv->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr; diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 3e4ca696e2..79a6d5f0b7 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -24,13 +24,13 @@ using namespace transformer_engine; enum NVTE_QKV_Matrix { NVTE_Q_Matrix = 0, // queries - NVTE_Q_Matrix_Transpose = 1, // queries transposed - NVTE_K_Matrix = 2, // keys - NVTE_K_Matrix_Transpose = 3, // keys transposed - NVTE_V_Matrix = 4, // values - NVTE_V_Matrix_Transpose = 5, // values transposed + NVTE_K_Matrix = 1, // keys + NVTE_K_Matrix_Transpose = 2, // keys transposed + NVTE_V_Matrix = 3, // values + NVTE_V_Matrix_Transpose = 4, // values transposed NVTE_S_Matrix = 5, // output of GEMM1 NVTE_O_Matrix = 6, // final output + NVTE_Q_Matrix_Transpose = 7, // queries transposed }; // Padded sizes for MXFP8 layout (s_q/s_kv/d_qk/d_v and their scaled dimensions) @@ -49,14 +49,7 @@ struct MXFP8PaddedSizes { int64_t d_v_scale_padded; }; -inline bool is_aligned_modulo(void *ptr, int64_t modulo) { - // Cast the pointer to a large enough integer type (uintptr_t) - uintptr_t address = 
reinterpret_cast(ptr); - // Check if the address is perfectly divisible by 16 - return (address % modulo) == 0; -} - -// Pad s and d for MXFP8 layout +// Pad s and d for MXFP8 quantization inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v) { constexpr int64_t block_size = 32; MXFP8PaddedSizes p; @@ -75,35 +68,34 @@ inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_q return p; } -// Get matrix strides for a 4D tensor [batch, head, seqlen, hidden] given a QKV format. -// strideA must point to at least 4 int64_t elements. +// Get matrix strides for a 4D tensor [batch_size, num_heads, sequence_len, head_dim] given a QKV format. +// strides must point to at least 4 int64_t elements. inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, - int64_t *strides, NVTE_QKV_Format format, - bool transpose) { - constexpr int batch_dim_idx = 0; - constexpr int head_dim_idx = 1; - int seqlen_dim_idx = transpose ? 3 : 2; - int hidden_dim_idx = transpose ? 
2 : 3; + int64_t *strides, NVTE_QKV_Format format) { + constexpr int b_dim = 0; + constexpr int h_dim = 1; + constexpr int s_dim = 2; + constexpr int d_dim = 3; switch (format) { case NVTE_QKV_Format::NVTE_BSHD: case NVTE_QKV_Format::NVTE_THD: - strides[batch_dim_idx] = s * h * d; - strides[head_dim_idx] = d; - strides[seqlen_dim_idx] = h * d; - strides[hidden_dim_idx] = 1; + strides[b_dim] = s * h * d; + strides[h_dim] = d; + strides[s_dim] = h * d; + strides[d_dim] = 1; break; case NVTE_QKV_Format::NVTE_SBHD: - strides[batch_dim_idx] = h * d; - strides[head_dim_idx] = d; - strides[seqlen_dim_idx] = b * h * d; - strides[hidden_dim_idx] = 1; + strides[b_dim] = h * d; + strides[h_dim] = d; + strides[s_dim] = b * h * d; + strides[d_dim] = 1; break; case NVTE_QKV_Format::NVTE_BHSD: - strides[batch_dim_idx] = h * s * d; - strides[head_dim_idx] = s * d; - strides[seqlen_dim_idx] = d; - strides[hidden_dim_idx] = 1; + strides[b_dim] = h * s * d; + strides[h_dim] = s * d; + strides[s_dim] = d; + strides[d_dim] = 1; break; default: NVTE_CHECK(false, "Invalid format."); @@ -111,148 +103,121 @@ inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int } } -// get matrix strides based on matrix type -inline void generateMatrixStrides_v1(int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, - int64_t d_qk, int64_t d_v, int64_t *strides, - NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { - constexpr int batch_dim_idx = 0; - constexpr int head_dim_idx = 1; - bool transpose = (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); - int seqlen_dim_idx = transpose ? 3 : 2; - int hidden_dim_idx = transpose ? 
2 : 3; - constexpr int seqlen_q_dim_idx = 2; - constexpr int seqlen_kv_dim_idx = 3; - - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { - matrix = NVTE_QKV_Matrix::NVTE_Q_Matrix; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { - matrix = NVTE_QKV_Matrix::NVTE_K_Matrix; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { - matrix = NVTE_QKV_Matrix::NVTE_V_Matrix; - } - NVTE_CHECK(matrix != NVTE_QKV_Matrix::NVTE_O_Matrix, - "Invalid matrix type. Expected Q, K, V, O, or their related transposes."); +// get matrix strides based on layout and matrix type +inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, + int64_t d_qk, int64_t d_v, int64_t *q_strides, int64_t *k_strides, int64_t *v_strides, + NVTE_QKV_Layout layout) { + constexpr int b_dim = 0; + constexpr int h_dim = 1; + constexpr int s_dim = 2; + constexpr int d_dim = 3; switch (layout) { case NVTE_QKV_Layout::NVTE_SB3HD: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 3 * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * 3 * h * d_qk; - strides[hidden_dim_idx] = 1; + q_strides[b_dim] = 3 * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; ++i) { + k_strides[i] = v_strides[i] = q_strides[i]; } break; case NVTE_QKV_Layout::NVTE_SBH3D: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 3 * h * d_qk; - strides[head_dim_idx] = 3 * d_qk; - strides[seqlen_dim_idx] = b * 3 * h * d_qk; - strides[hidden_dim_idx] = 1; + q_strides[b_dim] = 3 * h * d_qk; + q_strides[h_dim] = 3 * d_qk; + q_strides[s_dim] = b * 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; ++i) { + 
k_strides[i] = v_strides[i] = q_strides[i]; } break; case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 2 * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; - } + q_strides[b_dim] = h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = 2 * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = b * 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; ++i) { + v_strides[i] = k_strides[i]; + } break; case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 2 * hg * d_qk; - strides[head_dim_idx] = 2 * d_qk; - strides[seqlen_dim_idx] = b * 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; + q_strides[b_dim] = h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = 2 * hg * d_qk; + k_strides[h_dim] = 2 * d_qk; + k_strides[s_dim] = b * 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; ++i) { + v_strides[i] = k_strides[i]; } break; case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == 
NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = b * hg * d_v; - strides[hidden_dim_idx] = 1; - } + q_strides[b_dim] = h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = b * hg * d_qk; + k_strides[d_dim] = 1; + v_strides[b_dim] = hg * d_v; + v_strides[h_dim] = d_v; + v_strides[s_dim] = b * hg * d_v; + v_strides[d_dim] = 1; break; case NVTE_QKV_Layout::NVTE_BS3HD: case NVTE_QKV_Layout::NVTE_T3HD: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_q * 3 * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = 3 * h * d_qk; - strides[hidden_dim_idx] = 1; + q_strides[b_dim] = s_q * 3 * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; ++i) { + k_strides[i] = v_strides[i] = q_strides[i]; } break; case NVTE_QKV_Layout::NVTE_BSH3D: case NVTE_QKV_Layout::NVTE_TH3D: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_q * 3 * h * d_qk; - strides[head_dim_idx] = 3 * d_qk; - strides[seqlen_dim_idx] = 3 * h * d_qk; - strides[hidden_dim_idx] = 1; + q_strides[b_dim] = s_q * 3 * h * d_qk; + q_strides[h_dim] = 3 * d_qk; + q_strides[s_dim] = 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; ++i) { + k_strides[i] = v_strides[i] = q_strides[i]; } break; case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: case NVTE_QKV_Layout::NVTE_THD_T2HD: - if (matrix 
== NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - strides[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; + q_strides[b_dim] = s_q * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = s_kv * 2 * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; ++i) { + v_strides[i] = k_strides[i]; } break; case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: case NVTE_QKV_Layout::NVTE_THD_TH2D: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - strides[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; - strides[head_dim_idx] = 2 * d_qk; - strides[seqlen_dim_idx] = 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; + q_strides[b_dim] = s_q * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = s_kv * 2 * hg * d_qk; + k_strides[h_dim] = 2 * d_qk; + k_strides[s_dim] = 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; ++i) { + v_strides[i] = k_strides[i]; } break; case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: @@ -260,92 +225,69 @@ inline void generateMatrixStrides_v1(int64_t b, int64_t h, int64_t hg, int64_t s case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * 
h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = hg * d_v; - strides[hidden_dim_idx] = 1; - } + q_strides[b_dim] = s_q * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = s_kv * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = hg * d_qk; + k_strides[d_dim] = 1; + v_strides[b_dim] = s_kv * hg * d_v; + v_strides[h_dim] = d_v; + v_strides[s_dim] = hg * d_v; + v_strides[d_dim] = 1; break; case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = hg * d_v; - strides[hidden_dim_idx] = 1; - } + q_strides[b_dim] = h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = s_kv * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = hg * d_qk; + k_strides[d_dim] = 1; + v_strides[b_dim] = s_kv * hg * d_v; + v_strides[h_dim] = d_v; + v_strides[s_dim] = hg * d_v; + v_strides[d_dim] = 1; break; case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: case 
NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = b * hg * d_v; - strides[hidden_dim_idx] = 1; - } + q_strides[b_dim] = s_q * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = b * hg * d_qk; + k_strides[d_dim] = 1; + v_strides[b_dim] = hg * d_v; + v_strides[h_dim] = d_v; + v_strides[s_dim] = b * hg * d_v; + v_strides[d_dim] = 1; break; case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * s_q * d_qk; - strides[head_dim_idx] = s_q * d_qk; - strides[seqlen_dim_idx] = d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = hg * s_kv * d_qk; - strides[head_dim_idx] = s_kv * d_qk; - strides[seqlen_dim_idx] = d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = hg * s_kv * d_v; - strides[head_dim_idx] = s_kv * d_v; - strides[seqlen_dim_idx] = d_v; - strides[hidden_dim_idx] = 1; - } + q_strides[b_dim] = h * s_q * d_qk; + q_strides[h_dim] = s_q * d_qk; + q_strides[s_dim] = d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = hg * s_kv * d_qk; + k_strides[h_dim] = s_kv * d_qk; + k_strides[s_dim] = d_qk; + k_strides[d_dim] = 1; + 
v_strides[b_dim] = hg * s_kv * d_v; + v_strides[h_dim] = s_kv * d_v; + v_strides[s_dim] = d_v; + v_strides[d_dim] = 1; break; default: NVTE_CHECK(false, "Invalid layout."); break; } - - if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { - strides[seqlen_kv_dim_idx] = 1; - strides[seqlen_q_dim_idx] = s_kv; - strides[head_dim_idx] = s_q * s_kv; - strides[batch_dim_idx] = h * s_q * s_kv; - } } void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, From 6717e1a8bdf8aa7042040dd14ea438efcca8b856 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 12 Mar 2026 20:12:08 -0700 Subject: [PATCH 066/172] minor fixes for previous commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/utils.cu | 5 ----- transformer_engine/common/fused_attn/utils.h | 17 ++++++++--------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index e67ae5e206..f37eeb0c68 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -312,11 +312,6 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[head_dim_idx] = s_kv * d; strideA[seqlen_transpose_dim_idx] = d; strideA[hidden_transpose_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { - strideA[batch_dim_idx] = h * s_q * d; - strideA[head_dim_idx] = s_q * d; - strideA[seqlen_transpose_dim_idx] = d; - strideA[hidden_transpose_dim_idx] = 1; } break; } diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 79a6d5f0b7..8c60f85cf5 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -30,7 +30,6 @@ enum NVTE_QKV_Matrix { NVTE_V_Matrix_Transpose = 4, // values transposed 
NVTE_S_Matrix = 5, // output of GEMM1 NVTE_O_Matrix = 6, // final output - NVTE_Q_Matrix_Transpose = 7, // queries transposed }; // Padded sizes for MXFP8 layout (s_q/s_kv/d_qk/d_v and their scaled dimensions) @@ -118,7 +117,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in q_strides[h_dim] = d_qk; q_strides[s_dim] = b * 3 * h * d_qk; q_strides[d_dim] = 1; - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < 4; i++) { k_strides[i] = v_strides[i] = q_strides[i]; } break; @@ -127,7 +126,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in q_strides[h_dim] = 3 * d_qk; q_strides[s_dim] = b * 3 * h * d_qk; q_strides[d_dim] = 1; - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < 4; i++) { k_strides[i] = v_strides[i] = q_strides[i]; } break; @@ -140,7 +139,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in k_strides[h_dim] = d_qk; k_strides[s_dim] = b * 2 * hg * d_qk; k_strides[d_dim] = 1; - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < 4; i++) { v_strides[i] = k_strides[i]; } break; @@ -153,7 +152,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in k_strides[h_dim] = 2 * d_qk; k_strides[s_dim] = b * 2 * hg * d_qk; k_strides[d_dim] = 1; - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < 4; i++) { v_strides[i] = k_strides[i]; } break; @@ -178,7 +177,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in q_strides[h_dim] = d_qk; q_strides[s_dim] = 3 * h * d_qk; q_strides[d_dim] = 1; - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < 4; i++) { k_strides[i] = v_strides[i] = q_strides[i]; } break; @@ -188,7 +187,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in q_strides[h_dim] = 3 * d_qk; q_strides[s_dim] = 3 * h * d_qk; q_strides[d_dim] = 1; - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < 4; i++) { k_strides[i] = v_strides[i] = q_strides[i]; } break; 
@@ -202,7 +201,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in k_strides[h_dim] = d_qk; k_strides[s_dim] = 2 * hg * d_qk; k_strides[d_dim] = 1; - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < 4; i++) { v_strides[i] = k_strides[i]; } break; @@ -216,7 +215,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in k_strides[h_dim] = 2 * d_qk; k_strides[s_dim] = 2 * hg * d_qk; k_strides[d_dim] = 1; - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < 4; i++) { v_strides[i] = k_strides[i]; } break; From c918b9d8d2f9f3f21e79795c71a803b71c4c9902 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 12 Mar 2026 20:12:29 -0700 Subject: [PATCH 067/172] fix bwd for current/delayed Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 8b395e3deb..8f8de69394 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2805,10 +2805,13 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrDescaleQ = input_Q->scale_inv.dptr; void* devPtrDescaleK = input_K->scale_inv.dptr; void* devPtrDescaleV = input_V->scale_inv.dptr; - void* devPtrQ_t = input_Q->columnwise_data.dptr; - void* devPtrK_t = input_K->columnwise_data.dptr; - void* devPtrDescaleQ_t = input_Q->columnwise_scale_inv.dptr; - void* devPtrDescaleK_t = input_K->columnwise_scale_inv.dptr; + void* devPtrQ_t = nullptr, *devPtrK_t = nullptr, *devPtrDescaleQ_t = nullptr, *devPtrDescaleK_t = nullptr; + if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrQ_t = input_Q->columnwise_data.dptr; + devPtrDescaleQ_t = 
input_Q->columnwise_scale_inv.dptr; + devPtrK_t = input_K->columnwise_data.dptr; + devPtrDescaleK_t = input_K->columnwise_scale_inv.dptr; + } void* devPtrO = input_O->data.dptr; const DType O_type = input_O->data.dtype; @@ -2818,9 +2821,12 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; - void* devPtrdO_t = input_dO->columnwise_data.dptr; - void* devPtrdO_f16 = input_dO_f16->data.dptr; - void* devPtrDescaledO_t = input_dO->columnwise_scale_inv.dptr; + void* devPtrdO_t = nullptr, *devPtrdO_f16 = nullptr, *devPtrDescaledO_t = nullptr; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrdO_t = input_dO->columnwise_data.dptr; + devPtrdO_f16 = input_dO_f16->data.dptr; + devPtrDescaledO_t = input_dO->columnwise_scale_inv.dptr; + } void* devPtrM = input_M->data.dptr; void* devPtrZInv = input_ZInv->data.dptr; From af60216c3fd0124af2915be208cc37525a304404 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 12 Mar 2026 20:18:55 -0700 Subject: [PATCH 068/172] tweak test configs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 7ae73a753a..23a656d35e 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1805,7 +1805,7 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig( - 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0) + 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal" ), "fp8_10": ModelConfig( 2, @@ -1815,7 +1815,7 @@ def get_model(dtype, config): head_dim_v=128, attn_mask_type="causal", ), - # "fp8_11": 
ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), + "fp8_11": ModelConfig(2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0)), # "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), # "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), # "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), From 6ac41d21c957199bdb3ce9230ac4bfea603709f4 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Mar 2026 12:57:51 -0700 Subject: [PATCH 069/172] fix dO/dO_f16 strides Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 8f8de69394..e3b86908e7 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2213,8 +2213,10 @@ void fused_attn_fp8_bwd_impl_v1( std::vector k_strides(4); std::vector v_strides(4); std::vector o_strides(4); + std::vector dO_strides(4); generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), k_strides.data(), v_strides.data(), qkv_layout); generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), o_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_strides.data(), d_out_format); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -2238,7 +2240,7 @@ void fused_attn_fp8_bwd_impl_v1( dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") .set_dim({b, h, s_q, d_v}) - .set_stride(o_strides) + .set_stride(dO_strides) .set_data_type(do_tensor_type)); Stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Stats") @@ -2286,31 +2288,31 @@ void fused_attn_fp8_bwd_impl_v1( 
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); // Q_t, K_t, dO_t, dO_f16 - std::vector q_t_stride(4); - std::vector k_t_stride(4); - std::vector dO_t_stride(4); - generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format); - generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format); - generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format); + std::vector q_t_strides(4); + std::vector k_t_strides(4); + std::vector dO_t_strides(4); + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_strides.data(), q_format); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_strides.data(), d_out_format); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) - .set_stride(q_t_stride) + .set_stride(q_t_strides) .set_data_type(qkv_tensor_type)); K_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K_t") .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_t_stride) + .set_stride(k_t_strides) .set_data_type(qkv_tensor_type)); dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO_t") .set_dim({b, h, s_q, d_v}) - .set_stride(dO_t_stride) + .set_stride(dO_t_strides) .set_data_type(do_tensor_type)); dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO_f16") .set_dim({b, h, s_q, d_v}) - .set_stride(o_strides) + .set_stride(dO_strides) .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); From 0a0722f1bf3ee95c55ef7b3591cc7b91ffefe756 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:35:28 -0700 Subject: [PATCH 070/172] fix tests: SWA logic/test configs Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 20 ++++++++++--------- .../attention/test_attention_with_cp.py | 18 ++++++++++------- .../attention/dot_product_attention/utils.py | 5 ++--- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 23a656d35e..16c1be0e2e 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1801,24 +1801,26 @@ def get_model(dtype, config): return outputs - +attn_mask_type = "causal" +# attn_mask_type = "no_mask" +# attn_mask_type = "causal_bottom_right" model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig( - 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal" - ), - "fp8_10": ModelConfig( 2, 4096, 128, 192, head_dim_v=128, - attn_mask_type="causal", + attn_mask_type=attn_mask_type, + ), + "fp8_10": ModelConfig( + 2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type ), - "fp8_11": ModelConfig(2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0)), - # "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), - # "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), - # "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + "fp8_11": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type), + "fp8_12": ModelConfig(2, 8192, 64, 64, attn_mask_type=attn_mask_type, window_size=(128, 0)), + # "fp8_13": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"), + # "fp8_14": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"), # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), # "fp8_17": ModelConfig(1, 
8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index d83bb62ac4..5f2b37d79e 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -167,9 +167,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_4": ModelConfig( 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" ), # MHA - "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA + "cp_1_5": ModelConfig(2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0)), # MHA "cp_2_0": ModelConfig( - 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0) + 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", ), # GQA "cp_2_1": ModelConfig( 2, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal" @@ -239,7 +239,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): # "cp_1_0", # "cp_1_1", # "cp_1_4", - # "cp_1_5", + "cp_1_5", "cp_2_0", "cp_2_1", # "cp_2_2", @@ -251,7 +251,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): # "cp_4_2", ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} - dtypes = ["fp8"] # ["bf16", "fp8"] + dtypes = ["bf16", "fp8"] qkv_formats = ["bshd"] # , "sbhd", "thd"] @@ -264,8 +264,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("fp8_bwd", [True, False]) @pytest.mark.parametrize("fp8_mha", [True, False]) @pytest.mark.parametrize("fp8_dpa", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) -@pytest.mark.parametrize("f16_O", [True]) +@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) +@pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( dtype, model, 
qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O ): @@ -304,12 +304,16 @@ def test_cp_with_fused_attention( if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") - if (config.window_size[0] != -1 and config.window_size[1] not in [-1, 0]) and cp_comm_type in [ + if (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) and cp_comm_type in [ "p2p", "a2a+p2p", ]: pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + # TODO: Remove this once the issue is fixed! + if dtype == "fp8" and (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) and cp_comm_type == "all_gather": + pytest.skip("No support for SWA with FP8 attention and cp_comm_type=all_gather!") + if cp_comm_type in ["a2a", "a2a+p2p"] and ( config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 ): diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index c8b336e9a3..f6c7ae48ee 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -877,7 +877,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_fused_attention = False elif ( window_size is not None - and window_size[0] != -1 and window_size[1] not in [-1, 0] + and (window_size[0] != -1 or window_size[1] not in [-1, 0]) and cp_comm_type in ["p2p", "a2a+p2p"] ): logger.debug( @@ -1080,8 +1080,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( use_fused_attention and window_size is not None - and window_size[0] != -1 - and window_size[1] not in [-1, 0] + and (window_size[0] != -1 or window_size[1] not in [-1, 0]) and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] ): logger.debug( From 
89b44f8caefaa8bc8a783fb850bab1e675aa8aa8 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:35:57 -0700 Subject: [PATCH 071/172] fix ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 7886c625b4..a0ed65cdd4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3016,21 +3016,16 @@ def forward( ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) fwd_nominal_dtype = q.dtype fp8_meta_kwargs = {} - q_fp8, k_fp8, v_fp8 = (None, None, None) - q_f16, k_f16, v_f16 = (None, None, None) + q_fp8, k_fp8, v_fp8 = (q, k, v) if is_input_fp8 else (None, None, None) + q_f16, k_f16, v_f16 = (None, None, None) if is_input_fp8 else (q, k, v) fused_attn_backend = None if fp8: assert use_fused_attention, "FP8 is only supported with Fused Attention!" 
fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 - if is_input_fp8: - q_fp8, k_fp8, v_fp8 = q, k, v - elif not fp8_recipe.mxfp8(): - q_f16, k_f16, v_f16 = q, k, v + if not is_input_fp8 and not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, q, k, v, QKV_quantizer ) - else: - q_f16, k_f16, v_f16 = q, k, v if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] fp8_meta_kwargs["s_quantizer"] = S_quantizer @@ -3222,9 +3217,19 @@ def forward( out_fp8 = None out_ret = out - if fp8 and (is_output_fp8 or (is_bwd_fp8 and fp8_recipe.delayed())): + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) + if fp8 and (is_output_fp8 or bwd_requires_o_fp8): out_fp8 = O_quantizer(out) + if is_output_fp8: out_ret = out_fp8 + ctx.fp8 = fp8 and is_bwd_fp8 ctx.fp8_recipe = fp8_recipe fp8_tensors = (None, None, None, None) @@ -3236,9 +3241,7 @@ def forward( q_fp8_save = Float8Tensor.make_like(q_fp8, data=q, dtype=fwd_nominal_dtype) k_fp8_save = Float8Tensor.make_like(k_fp8, data=k, dtype=fwd_nominal_dtype) v_fp8_save = Float8Tensor.make_like(v_fp8, data=v, dtype=fwd_nominal_dtype) - if fp8_recipe.delayed(): - fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) - if fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16: + if fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16): fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, None) @@ -3247,9 +3250,12 @@ def forward( f16_tensors = (q, k, v, out) elif fp8: if is_input_fp8: - q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) - f16_tensors = (q_f16, k_f16, v_f16, out) - ctx.qkv_reshaped = False + q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, 
v_fp8) + if fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out) + else: + f16_tensors = (q_f16, k_f16, v_f16, out) + ctx.qkv_reshaped = False else: f16_tensors = (q, k, v, out) From 7c0ba7f879e28d68afec70a2b0062bc4c6b3a4ec Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Mar 2026 16:58:30 -0700 Subject: [PATCH 072/172] add fp8 sink attn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 32 +-- .../common/fused_attn/fused_attn_fp8.cu | 189 ++++++++++++------ .../common/fused_attn/fused_attn_fp8.h | 15 +- 3 files changed, 160 insertions(+), 76 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 046b18aae9..d9701fc28d 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -643,9 +643,10 @@ void nvte_fused_attn_fwd( #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, qkv_layout, o_format, bias_type, attn_mask_type, - window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, - input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + input_Q, input_K, input_V, input_SoftmaxOffset, input_output_S, output_O, + Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif @@ -742,20 +743,25 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - const Tensor *input_dO_f16; + size_t i = 0; + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + const Tensor *input_dO_f16 = nullptr; if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { - input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + const Tensor *input_SoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, - window_size_left, window_size_right, bottom_right_diagonal, deterministic, - input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, - input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, - handle); + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, + input_M, input_ZInv, input_S, input_SoftmaxOffset, input_output_dP, + output_dQ, output_dK, output_dV, output_dSoftmaxOffset, input_cu_seqlens_q, + input_cu_seqlens_kv, 
input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index e3b86908e7..3e7a30022a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1655,14 +1655,15 @@ void fused_attn_fp8_fwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, - void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, - void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, - void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, - void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, - void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, - cudnn_frontend::DataType_t o_tensor_type, NVTEScalingMode scaling_mode, void* workspace, - size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrSoftmaxOffset, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, + void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, + void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, + void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, + 
NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, + cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -1672,6 +1673,7 @@ void fused_attn_fp8_fwd_impl_v1( bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); auto bias_b = b; auto bias_h = h; auto bias_sq = s_q; @@ -1715,7 +1717,7 @@ void fused_attn_fp8_fwd_impl_v1( qkv_layout, bias_type, mask_type, - NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, @@ -1744,6 +1746,7 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // amax_o std::shared_ptr, // Stats std::shared_ptr, // bias + std::shared_ptr, // softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // dropout_seed @@ -1770,7 +1773,7 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr Q, K, V, attn_scale; std::shared_ptr descale_q, descale_k, descale_v; std::shared_ptr descale_s, scale_s, scale_o; - std::shared_ptr bias, seq_q, seq_kv; + std::shared_ptr bias, softmax_offset, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; // Q, K, V, attn_scale @@ -1910,6 +1913,15 @@ void fused_attn_fp8_fwd_impl_v1( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + // sdpa_options.set_sink_token(softmax_offset); + } + std::shared_ptr O, Stats, amax_s, amax_o; if (is_delayed_scaling || is_current_scaling) { auto outputs = 
mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, @@ -1965,6 +1977,8 @@ void fused_attn_fp8_fwd_impl_v1( scale_s, scale_o, attn_scale, O, amax_s, amax_o); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); + auto softmax_offset_tuple = + is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) @@ -1976,15 +1990,16 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, - bias_tuple, padding_tuple, dropout_tuple); + bias_tuple, softmax_offset_tuple, padding_tuple, + dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, - attn_scale, O, amax_s, amax_o, Stats, bias, seq_q, seq_kv, dropout_seed, dropout_offset] = - get_graph(sdpa_fp8_fprop_cache, descriptor); + attn_scale, O, amax_s, amax_o, Stats, bias, softmax_offset, seq_q, seq_kv, + dropout_seed, dropout_offset] = get_graph(sdpa_fp8_fprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -2044,6 +2059,10 @@ void fused_attn_fp8_fwd_impl_v1( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -2055,20 +2074,23 @@ void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float 
dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void* devPtrQ, - void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, - void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, - void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, - void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, - void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, - void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, - void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, - void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, - void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, - cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, - size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrM, void* devPtrZInv, + void* devPtrO, void* devPtrdO, void* devPtrSoftmaxOffset, void* devPtrdQ, void* devPtrdK, void* devPtrdV, + void* devPtrdSoftmaxOffset, void* devPtrDescaleQ, + void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, + void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, + void* devPtrScaledQ, void* 
devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, + void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, + void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, + void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, + cudnn_frontend::DataType_t do_tensor_type, cudnn_frontend::DataType_t dqkv_tensor_type, + NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, + cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -2078,6 +2100,7 @@ void fused_attn_fp8_bwd_impl_v1( bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (dropout_probability != 0.0f); + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); auto bias_b = b; auto bias_h = h; auto bias_sq = s_q; @@ -2124,7 +2147,7 @@ void fused_attn_fp8_bwd_impl_v1( qkv_layout, bias_type, mask_type, - NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, @@ -2173,6 +2196,8 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr, // amax_dP std::shared_ptr, // bias std::shared_ptr, // dBias + std::shared_ptr, // softmax_offset + std::shared_ptr, // d_softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // dropout_seed @@ -2205,7 +2230,8 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr descale_dP, descale_dO, descale_dO_t; std::shared_ptr scale_s, scale_dP; std::shared_ptr scale_dQ, scale_dK, scale_dV; - std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr bias, dBias, softmax_offset, 
d_softmax_offset; + std::shared_ptr seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; // Q, K, V, O, dO, stats, attn_scale @@ -2458,6 +2484,22 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(fe::DataType_t::INT64)); sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } + + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + // sdpa_backward_options.set_sink_token(softmax_offset); + d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("d_softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + // sdpa_backward_options.set_dsink_token(d_softmax_offset); + } + std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; if (is_delayed_scaling || is_current_scaling) { auto outputs = mha_graph->sdpa_fp8_backward(Q, K, V, O, dO, Stats, descale_q, descale_k, @@ -2552,6 +2594,9 @@ void fused_attn_fp8_bwd_impl_v1( is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); + auto softmax_offset_tuple = is_softmax_offset + ? std::make_tuple(softmax_offset, d_softmax_offset) + : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) @@ -2565,7 +2610,7 @@ void fused_attn_fp8_bwd_impl_v1( auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, - bias_tuple, padding_tuple, dropout_tuple); + bias_tuple, softmax_offset_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; @@ -2573,8 +2618,8 @@ void fused_attn_fp8_bwd_impl_v1( auto [mha_graph, Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, - descale_k_t, descale_dO_t, bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = - get_graph(sdpa_fp8_bprop_cache, descriptor); + descale_k_t, descale_dO_t, bias, dBias, softmax_offset, d_softmax_offset, seq_q, seq_kv, + dropout_seed, dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -2662,6 +2707,11 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -2678,9 +2728,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, 
const Tensor* input_Q, - const Tensor* input_K, const Tensor* input_V, Tensor* input_output_S, + const Tensor* input_K, const Tensor* input_V, + const Tensor* input_SoftmaxOffset, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, @@ -2707,11 +2759,14 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrV = input_V->columnwise_data.dptr; devPtrDescaleV = input_V->columnwise_scale_inv.dptr; } + void* devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - int i=0; + int i = 0; Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); @@ -2724,14 +2779,25 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - int i=0; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor* output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 3) { + int i = 0; Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); 
devPtrM = output_M->data.dptr; devPtrZInv = output_ZInv->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor* output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -2755,12 +2821,12 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, - window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, - devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, workspace->data.dptr, &workspace_size, - stream, handle); + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, + devPtrK, devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, + devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, + devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, is_training, attn_scale, @@ -2791,13 +2857,15 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t head_dim_v, float attn_scale, float p_dropout, 
NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, - Tensor* input_output_dP, const Tensor* output_dQ, const Tensor* output_dK, - const Tensor* output_dV, const Tensor* cu_seqlens_q, + const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, + const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, + Tensor* output_dSoftmaxOffset, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -2839,6 +2907,13 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrScaledP = input_output_dP->scale.dptr; void* devPtrDescaledP = input_output_dP->scale_inv.dptr; + void* devPtrSoftmaxOffset = nullptr; + void* devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } + void* devPtrdQ = output_dQ->data.dptr; void* devPtrdK = output_dK->data.dptr; void* devPtrdV = output_dV->data.dptr; @@ -2869,16 +2944,16 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, - mask_type, 
window_size_left, window_size_right, bottom_right_diagonal, deterministic, - devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, - devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, - devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, - devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), - get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, workspace->data.dptr, - &workspace_size, stream, handle); + mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, + devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, + devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, + devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, + devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, + devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), + input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h 
index 98d5876ec8..47777279b9 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -18,10 +18,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, Tensor *input_output_S, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_SoftmaxOffset, + Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -32,13 +33,15 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, - Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dK, - const Tensor *output_dV, 
const Tensor *cu_seqlens_q, + const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, + const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // end of CUDNN>=8900 From e68f785a6c979fa312fd71b07ec19a2e04d66569 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Mar 2026 19:00:58 -0700 Subject: [PATCH 073/172] fix a2a comm for F16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index a0ed65cdd4..407c20ca9a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -433,7 +433,7 @@ def flash_attn_a2a_communicate( ), "cu_seqlens_padded is required for THD format!" 
a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) - batch_dim, _, head_dim = get_bsh_dims(qkv_format) + _, _, head_dim = get_bsh_dims(qkv_format) if before_attn: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -512,7 +512,22 @@ def flash_attn_a2a_communicate( # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] # or [cp, 2, b, np//cp, s//2, hn] -> [b, cp, np//cp, 2, s//2, hn] # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] - x = x.movedim(0, head_dim + 1).movedim(0, seq_dim + 1).contiguous() + tmp_list = [x for x in qkv_format] + if "t" not in qkv_format: + tmp_list.insert(0, "2") + tmp_list.insert(0, "c") + tmp_format = "".join(tmp_list) + h_index = tmp_format.index("h") + tmp_list.insert(h_index - 1, tmp_list.pop(0)) + tmp_format = "".join(tmp_list) + if "t" not in qkv_format: + s_index = tmp_format.index("s") + tmp_list.insert(s_index - 1, tmp_list.pop(0)) + tmp_format = "".join(tmp_list) + seq_dim_ = tmp_format.index("s")-1 + else: + seq_dim_ = tmp_format.index("t") + x = x.movedim(0, head_dim + 1).movedim(0, seq_dim_).contiguous() # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] # or [b, cp, np//cp, 2, s//2, hn] -> [b*np, s, hn] From ae53980c172a5f185bc797223804a48bc60bb64d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Mar 2026 19:04:34 -0700 Subject: [PATCH 074/172] remove nan/inf print in test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/run_attention_with_cp.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 242d6b9e7a..a5804c6888 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ 
-534,10 +534,6 @@ def run_dpa_with_cp( for i, tensor in enumerate(tensors): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - print( - f"========= {torch.cuda.current_device()}: tensors[{i}].shape:" - f" {tensor.shape} {tensor.dtype} {torch.isnan(tensor).any()} {torch.isinf(tensor).any()}" - ) assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors From 4b314e7875d56f64391c450906bb9dcd3cdca805 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Mar 2026 22:17:52 -0700 Subject: [PATCH 075/172] fix fa a2a Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 407c20ca9a..3878dedfb6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -508,26 +508,30 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] + # [cp, 2, b, s//2, np//cp, hn] -> [2, b, s//2, cp, np//cp, hn] # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] - # or [cp, 2, b, np//cp, s//2, hn] -> [b, cp, np//cp, 2, s//2, hn] + # or [cp, 2, b, np//cp, s//2, hn] -> [2, b, cp, np//cp, s//2, hn] # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] tmp_list = [x for x in qkv_format] if "t" not in qkv_format: tmp_list.insert(0, "2") tmp_list.insert(0, "c") tmp_format = "".join(tmp_list) - h_index = tmp_format.index("h") - tmp_list.insert(h_index - 1, tmp_list.pop(0)) + head_dim_ = 
tmp_format.index("h")-1 + tmp_list.insert(head_dim_, tmp_list.pop(0)) tmp_format = "".join(tmp_list) + x = x.movedim(0, head_dim_) + # [2, b, s//2, cp, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] + # or [2, s//2, b, cp, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + # or [2, b, cp, np//cp, s//2, hn] -> [b, cp, np//cp, 2, s//2, hn] + # or [t, cp, np//cp, hn] -> [t, cp, np//cp, hn] if "t" not in qkv_format: s_index = tmp_format.index("s") tmp_list.insert(s_index - 1, tmp_list.pop(0)) tmp_format = "".join(tmp_list) seq_dim_ = tmp_format.index("s")-1 - else: - seq_dim_ = tmp_format.index("t") - x = x.movedim(0, head_dim + 1).movedim(0, seq_dim_).contiguous() + x = x.movedim(0, seq_dim_) + x = x.contiguous() # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] # or [b, cp, np//cp, 2, s//2, hn] -> [b*np, s, hn] From 4b5d62317d775c1837f41ba0257e7bd825223d19 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 14 Mar 2026 12:07:28 -0700 Subject: [PATCH 076/172] fix fa a2a+p2p f16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/context_parallel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 3878dedfb6..fd25d0baa8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1392,6 +1392,7 @@ def forward( if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + orig_o_shape = q.shape[:-1] + v.shape[-1:] batch_dim = None seq_dim = None cu_seqlens_q_half, cu_seqlens_kv_half = None, None @@ -2124,6 +2125,7 @@ def forward( ctx.k_shape = k_shape ctx.v_shape = v_shape ctx.o_shape = o_shape + ctx.orig_o_shape = 
orig_o_shape ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.dQKV_quantizer = dQKV_quantizer @@ -2364,7 +2366,7 @@ def backward(ctx, dout, *_args): if cp_size_a2a > 1: if not ctx.use_fused_attention: # out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(*out.shape) + dout = dout.view(ctx.orig_o_shape) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( cp_size_a2a, out.device ) From fdab7dbd1ece481e673df0d8e2f552b74796113d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 16 Mar 2026 14:07:29 -0700 Subject: [PATCH 077/172] update FE to include new fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index fb3f58b2f8..b449099e98 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit fb3f58b2f8b47c2b305586bb4d7fcd007eb33839 +Subproject commit b449099e98fbe13aacc7cd6c1cb48cc11914a210 From 39b57e991797a1ad603a9d5af95a40f36890df20 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Mar 2026 00:33:43 -0700 Subject: [PATCH 078/172] fix thd for bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 10 ++++++ .../pytorch/csrc/extensions/attention.cpp | 32 +++++++++++++------ 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index d9701fc28d..08cecab944 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -577,6 +577,11 @@ void nvte_fused_attn_fwd( nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, &s_q, &d_qk, &t_q); 
nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); + if (q_format == NVTE_QKV_Format::NVTE_THD) { + b = input_cu_seqlens_q->data.shape[0] -1; + } else if (kv_format == NVTE_QKV_Format::NVTE_THD) { + b = input_cu_seqlens_kv->data.shape[0] -1; + } int64_t num_pages_k = 0; int64_t num_pages_v = 0; @@ -696,6 +701,11 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, &s_q, &d_qk, &t_q); nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); + if (q_format == NVTE_QKV_Format::NVTE_THD) { + b = input_cu_seqlens_q->data.shape[0] -1; + } else if (kv_format == NVTE_QKV_Format::NVTE_THD) { + b = input_cu_seqlens_kv->data.shape[0] -1; + } auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index a115312bf5..619c8f92c6 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -154,8 +154,12 @@ std::vector fused_attn_fwd( o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; size_t b = 0, h = 0, s = 0, d = 0, t = 0; - nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), o_shape_tmp, o_format, o_shape, &b, &h, &s, + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + nvte_convert_qkv_format(q_format, o_shape_tmp, o_format, o_shape, &b, &h, &s, &d, &t); + if (q_format == 
NVTE_QKV_Format::NVTE_THD) { + b = cu_seqlens_q.size(0) - 1; + } const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -365,14 +369,23 @@ std::vector fused_attn_bwd( std::vector v_shape = convertShape(te_V.shape()); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; - std::vector dQ_shape(4), dK_shape(4), dV_shape(4); - nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), q_shape, nvte_get_q_format(dqkv_layout), + size_t ndim = q_shape.size(); + std::vector dQ_shape(ndim), dK_shape(ndim), dV_shape(ndim); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + NVTE_QKV_Format dq_format = nvte_get_q_format(dqkv_layout); + NVTE_QKV_Format dkv_format = nvte_get_kv_format(dqkv_layout); + nvte_convert_qkv_format(q_format, q_shape, dq_format, dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); - nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), k_shape, nvte_get_kv_format(dqkv_layout), + nvte_convert_qkv_format(kv_format, k_shape, dkv_format, dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); - nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), v_shape, nvte_get_kv_format(dqkv_layout), + nvte_convert_qkv_format(kv_format, v_shape, dkv_format, dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); - + if (dq_format == NVTE_QKV_Format::NVTE_THD) { + b = cu_seqlens_q.size(0) - 1; + } else if (dkv_format == NVTE_QKV_Format::NVTE_THD) { + b = cu_seqlens_kv.size(0) - 1; + } at::Tensor dQ, dK, dV, dQKV, dKV; DType dqkv_type = fake_dtype_te; if (!dqkv_quantizer.is_none()) { @@ -388,7 +401,6 @@ std::vector fused_attn_bwd( } NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); std::vector tmp_shape; - switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: tmp_shape = 
std::vector{dQ_shape.begin(), dQ_shape.end()}; @@ -462,9 +474,9 @@ std::vector fused_attn_bwd( NVTE_ERROR("QKV layout not supported!"); } - std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ); - std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK); - std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV); + std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, dQ_shape, fake_dtype_te, true, dQ); + std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, dK_shape, fake_dtype_te, true, dK); + std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, dV_shape, fake_dtype_te, true, dV); // construct NVTE tensors if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { From dc494792800eec6f6e51b79881f4a99e4077744d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Mar 2026 00:34:11 -0700 Subject: [PATCH 079/172] refactor a2a for fu/fa Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 417 ++++++++++-------- 1 file changed, 223 insertions(+), 194 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index fd25d0baa8..94a64bda13 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -450,14 +450,14 @@ def flash_attn_a2a_communicate( x = reorder_seq_chunks_for_a2a_before_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) - # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] - # or [b, np//cp, cp*2, s//2, hn] -> [b, np//cp, cp*s, hn] + # [b, cp*2, s//2, h//cp, d] -> [b, cp*s, h//cp, d] + # or [cp*2, s//2, 
b, h//cp, d] -> [cp*s, b, h//cp, d] + # or [b, h//cp, cp*2, s//2, d] -> [b, h//cp, cp*s, d] a2a_outputs[i - 2] = x.view( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) else: # qkv_format == "thd" - # [cp, t, np//cp, hn] -> [cp*t, np//cp, hn] + # [cp, t, h//cp, d] -> [cp*t, h//cp, d] x = x.view(-1, *x.shape[2:]) # reorder the sequence chunks a2a_outputs[i - 2] = reorder_seq_chunks_after_a2a_before_attn_thd( @@ -466,20 +466,20 @@ def flash_attn_a2a_communicate( if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, s, np, hn] -> [b, s, cp, np//cp, hn] - # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] - # or [b, np, s, hn] -> [b, cp, np//cp, s, hn] - # or [t, np, hn] -> [t, cp, np//cp, hn] + # [b, s, h, d] -> [b, s, cp, h//cp, d] + # or [s, b, h, d] -> [s, b, cp, h//cp, d] + # or [b, h, s, d] -> [b, cp, h//cp, s, d] + # or [t, h, d] -> [t, cp, h//cp, d] x = x.view( *x.shape[:head_dim], cp_size, x.shape[head_dim] // cp_size, *x.shape[head_dim + 1 :], ) - # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] - # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] - # or [b, cp, np//cp, s, hn] -> [cp, b, np//cp, s, hn] - # or [t, cp, np//cp, hn] -> [cp, t, np//cp, hn] + # [b, s, cp, h//cp, d] -> [cp, b, s, h//cp, d] + # or [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d] + # or [b, cp, h//cp, s, d] -> [cp, b, h//cp, s, d] + # or [t, cp, h//cp, d] -> [cp, t, h//cp, d] a2a_inputs[i] = x.movedim(head_dim, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): @@ -491,9 +491,9 @@ def flash_attn_a2a_communicate( if i < len(a2a_inputs): x = a2a_inputs[i] if qkv_format in ["bshd", "sbhd", "bhsd"]: - # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] - # or [b, np//cp, cp*s, hn] -> [b, np//cp, cp*2, s//2, hn] + # [b, cp*s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # or [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # or [b, h//cp, cp*s, d] -> [b, h//cp, cp*2, s//2, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, 
*x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( @@ -502,16 +502,16 @@ def flash_attn_a2a_communicate( else: # qkv_format == "thd" # reorder the sequence chunks x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens_padded, cp_size) - # [cp*t, np//cp, hn] -> [cp, t, np//cp, hn] + # [cp*t, h//cp, d] -> [cp, t, h//cp, d] a2a_inputs[i] = x.view(cp_size, -1, *x.shape[-2:]) if i > 1: with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, np//cp, hn] -> [2, b, s//2, cp, np//cp, hn] - # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] - # or [cp, 2, b, np//cp, s//2, hn] -> [2, b, cp, np//cp, s//2, hn] - # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] + # [cp, 2, b, s//2, h//cp, d] -> [2, b, s//2, cp, h//cp, d] + # or [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] + # or [cp, 2, b, h//cp, s//2, d] -> [2, b, cp, h//cp, s//2, d] + # or [cp, t, h//cp, d] -> [t, cp, h//cp, d] tmp_list = [x for x in qkv_format] if "t" not in qkv_format: tmp_list.insert(0, "2") @@ -519,23 +519,23 @@ def flash_attn_a2a_communicate( tmp_format = "".join(tmp_list) head_dim_ = tmp_format.index("h")-1 tmp_list.insert(head_dim_, tmp_list.pop(0)) - tmp_format = "".join(tmp_list) x = x.movedim(0, head_dim_) - # [2, b, s//2, cp, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] - # or [2, s//2, b, cp, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] - # or [2, b, cp, np//cp, s//2, hn] -> [b, cp, np//cp, 2, s//2, hn] - # or [t, cp, np//cp, hn] -> [t, cp, np//cp, hn] + # [2, b, s//2, cp, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] + # or [2, s//2, b, cp, h//cp, d] -> [2, s//2, b, cp, h//cp, d] + # or [2, b, cp, h//cp, s//2, d] -> [b, cp, h//cp, 2, s//2, d] + # or [t, cp, h//cp, d] -> [t, cp, h//cp, d] if "t" not in qkv_format: - s_index = tmp_format.index("s") - tmp_list.insert(s_index - 1, tmp_list.pop(0)) tmp_format = "".join(tmp_list) seq_dim_ = tmp_format.index("s")-1 + 
tmp_list.insert(seq_dim_, tmp_list.pop(0)) x = x.movedim(0, seq_dim_) + else: + seq_dim_ = 0 x = x.contiguous() - # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] - # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] - # or [b, cp, np//cp, 2, s//2, hn] -> [b*np, s, hn] - # or [t, cp, np//cp, hn] -> [t, np, hn] + # [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d] + # or [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] + # or [b, cp, h//cp, 2, s//2, d] -> [b*h, s, d] + # or [t, cp, h//cp, d] -> [t, h, d] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs @@ -1858,7 +1858,7 @@ def forward( with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if use_fused_attention: # [b, h, sq, 1] -> [b, h, sq] or - # [t, h, 1] -> [t, np] + # [t, h, 1] -> [t, h] softmax_lse_per_step[i - 1].squeeze_(-1) if softmax_lse_in_packed_format: softmax_lse_per_step[i - 1] = ( @@ -2244,13 +2244,13 @@ def backward(ctx, dout, *_args): if ctx.softmax_lse_in_packed_format: softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() # [b, h, sq//2] -> [b, h, sq//2, 1] or - # [t//2, np] -> [t//2, h, 1] + # [t//2, h] -> [t//2, h, 1] softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse = softmax_lse.transpose(0, 1).contiguous() # [b, h, sq] -> [b, h, sq, 1] or - # [t, np] -> [t, h, 1] + # [t, h] -> [t, h, 1] softmax_lse.unsqueeze_(-1) # assume fwd and bwd always use the same high precision, i.e. 
torch.float16 or torch.bfloat16 @@ -3705,24 +3705,31 @@ def forward( ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + original_qkv_layout = qkv_layout + o_format = qkv_format + batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + _, seq_dim_o, _ = get_bsh_dims(o_format) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) - causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type - assert ( - not padding or qkv_format == "thd" - ), f"{attn_mask_type} mask type is not supported for BSHD and SBHD!" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" - assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + if qkv_format in ["bshd", "sbhd"]: + assert "padding" not in attn_mask_type, f"No support for cp_comm_type='a2a', {attn_mask_type=} and {qkv_format=}." + assert attn_bias_type == "no_bias", f"No support for cp_comm_type='a2a' and {attn_bias_type=}." assert ( window_size == (-1, 0) or window_size == (-1, -1) or use_fused_attention or fa_utils.v2_3_plus - ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + ), f"cp_comm_type='a2a' only supports SWA through FusedAttention or FlashAttention >= 2.3. Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." + assert ( + q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0 + ), f"cp_comm_type='a2a' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q = {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}, cp_size = {cp_size}." + assert ( + q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 + ), f"cp_comm_type='a2a' requires num_heads % cp_size == 0 for Q, K, V. 
Found num_heads_q = {q.shape[-2]}, num_heads_kv = {k.shape[-2]}, cp_size = {cp_size}." flash_attn_fwd = None if not use_fused_attention: @@ -3761,20 +3768,6 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert ( - q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 - ), "The number of attention heads needs to be divisible by CP size!" - - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - original_qkv_layout = qkv_layout - o_format = qkv_format - batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) - _, seq_dim_o, _ = get_bsh_dims(o_format) - - assert ( - q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0 - ), "Sequence length per GPU needs to be divisible by 2!" - assert isinstance(k, q.__class__) and isinstance( v, q.__class__ ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." @@ -3786,6 +3779,7 @@ def forward( fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + fwd_nominal_dtype = q.dtype fused_attn_backend = None max_logit = None @@ -3795,31 +3789,32 @@ def forward( ) q_fp8, k_fp8, v_fp8 = (None, None, None) + fp8_meta_kwargs = {} if fp8: - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] - if is_input_fp8: - q_fp8, k_fp8, v_fp8 = q, k, v - elif not fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( - qkv_layout, q, k, v, QKV_quantizer - ) - if not fp8_recipe.mxfp8(): - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - # else: - # q, k, v = [q_fp8, k_fp8, v_fp8] - # qkv_format, _, _ = dpa_utils.get_qkv_format(qkv_layout) - # batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = S_quantizer - fp8_meta_kwargs["o_quantizer"] = O_quantizer - else: - assert False, "FP8 is only supported with Fused Attention!" 
+ assert use_fused_attention, "FP8 is only supported with FusedAttention backend!" + fused_attn_backend = FusedAttnBackend["FP8"] + if is_input_fp8: + q_fp8, k_fp8, v_fp8 = q, k, v + elif not fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer else: if use_fused_attention: - fp8_meta_kwargs = {} fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + # q, k, v: + # FP8DS/FP8CS: torch.uint8 + # MXFP8: torch.float16 or torch.bfloat16 + # F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # [t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], @@ -3832,27 +3827,49 @@ def forward( qkv_format=qkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) + + # softmax_offset: split h + # [1, h, 1, 1] -> [1, h//cp, 1, 1] if softmax_type != "vanilla": softmax_offset = flash_attn_a2a_communicate_softmax_offset( softmax_offset, 1, cp_size, cp_group, cp_stream, True ) - out_fp8 = None - out_f16 = None + # _part: inputs to attention kernel and saved for backward + # note: they have post a2a shapes batch_size = q.shape[batch_dim_qkv] q_part, k_part, v_part = q, k, v - out_part = None - if use_fused_attention: - if fp8 and not fp8_recipe.mxfp8(): - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) - ] - if fp8 and fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( - qkv_layout, q_part, k_part, v_part, QKV_quantizer + out_part, out_fp8, out_f16 = None, None, None + bwd_requires_o_f16 = is_training and ( + 
not is_bwd_fp8 + or ( + is_bwd_fp8 + and ( + (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or fp8_recipe.mxfp8() ) - q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] + ) + ) + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) + if use_fused_attention: + if fp8: + if fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) + q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] + else: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, @@ -3880,25 +3897,17 @@ def forward( return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), ) - if isinstance(out_, QuantizedTensorStorage): - out_fp8 = out_ - out_ = out_._data - if is_bwd_fp8 and not ( - fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - ): - out_part = out_fp8 - else: - out_part = out_fp8.dequantize(dtype=fwd_nominal_dtype) - else: - out_f16 = out_ - out_part = out_ - if ( - fp8 - and is_bwd_fp8 - and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) - and not fp8_recipe.mxfp8() - ): - out_part = O_quantizer(out_) + # construct out_part for backward + out_fp8 = out_ + out_f16 = out_ + if bwd_requires_o_fp8: + if not isinstance(out_, QuantizedTensorStorage): + out_fp8 = O_quantizer(out_) + out_part = out_fp8 + if bwd_requires_o_f16: + if isinstance(out_, QuantizedTensorStorage): + out_f16 = out_.dequantize(dtype=fwd_nominal_dtype) + out_part = out_f16 else: fa_forward_args_thd = get_fa_args( True, @@ -3914,7 +3923,7 @@ def forward( k_part, v_part, *fa_forward_args_thd, - causal=causal, + causal="causal" in attn_mask_type, **fa_forward_kwargs, ) if not fa_utils.v2_7_0_plus: @@ -3926,6 
+3935,12 @@ def forward( aux_ctx_tensors = [softmax_lse, rng_state] out_part = out_ + # a2a: split s and gather h + # [b, s, h//cp, d] -> [b*s//cp, h, d] + # [s, b, h//cp, d] -> [s//cp*b, h, d] + # [t, h//cp, d] -> [t//cp, h, d] + if isinstance(out_, QuantizedTensorStorage): + out_ = out_._data chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) out_ = flash_attn_a2a_communicate( out_, @@ -3938,54 +3953,71 @@ def forward( qkv_format=o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - if return_max_logit: - max_logit = flash_attn_a2a_communicate_softmax_offset( - *max_logit, 0, cp_size, cp_group, cp_stream, False - ) + # [b*s//cp, h, d] -> [b, s//cp, h, d] + # [s//cp*b, h, d] -> [s//cp, b, h, d] + # [t//cp, h, d] -> [t//cp, h, d] + if o_format == "bshd": + out_ = out_.view(batch_size, -1, *out_.shape[-2:]) + elif o_format == "sbhd": + out_ = out_.view(-1, batch_size, *out_.shape[-2:]) - if use_fused_attention: - if o_format == "bshd": - # [b*s, h, d] -> [b, s, h, d] - out_ = out_.view(batch_size, -1, *out_.shape[-2:]) - elif o_format == "sbhd": - # [s*b, h, d] -> [s, b, h, d] - out_ = out_.view(-1, batch_size, *out_.shape[-2:]) - - if fp8 and use_fused_attention: - if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): - out_f16 = out_ - if is_output_fp8: - out_fp8 = O_quantizer(out_) + # out_ret: output tensor for forward pass + if fp8: if fp8_recipe.delayed(): out_fp8 = Float8Tensor.make_like(out_fp8, data=out_, dtype=fwd_nominal_dtype) - if not is_output_fp8: + if is_output_fp8: + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): + out_fp8 = O_quantizer(out_) + out_f16 = out_ + else: + if fp8_recipe.delayed(): out_f16 = out_fp8.dequantize(dtype=fwd_nominal_dtype) + else: + out_f16 = out_ else: out_f16 = out_ - out_ret = out_fp8 if is_output_fp8 else out_f16 - ctx.fp8 = fp8 and is_bwd_fp8 + # all gather max logit + if return_max_logit: + max_logit = flash_attn_a2a_communicate_softmax_offset( + *max_logit, 0, 
cp_size, cp_group, cp_stream, False + ) + ctx.qkv_layout = qkv_layout ctx.o_format = o_format ctx.dqkv_layout = original_qkv_layout + ctx.dqkv_format = qkv_format + ctx.batch_size = batch_size + ctx.out_part_shape = out_part.shape + ctx.out_ret_shape = out_ret.shape + + # save tensors for backward + ctx.fp8 = fp8 and is_bwd_fp8 fp8_tensors = (None, None, None, None) f16_tensors = (None, None, None, None) - if ctx.fp8: - if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8(): - fp8_tensors = (q_part, k_part, v_part, None) - f16_tensors = (None, None, None, out_part) + if is_training: + if ctx.fp8: + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o all in FP8 + # (FP8CS+_dpa_fp8_cs_o_in_f16) or MXFP8: q/k/v in FP8, o in F16 + if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8(): + fp8_tensors = (q_part, k_part, v_part, None) + f16_tensors = (None, None, None, out_part) + elif fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16): + fp8_tensors = (q_part, k_part, v_part, out_part) + elif fp8: + # FP8DS/CS: convert post-a2a FP8 q/k/v to F16 + # MXFP8: save post-a2a pre-quantization F16 q/k/v + # out_part is already converted to the right precision + if fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out_part) + ctx.qkv_layout = original_qkv_layout + else: + q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) + f16_tensors = (q_part, k_part, v_part, out_part) else: - fp8_tensors = (q_part, k_part, v_part, out_part) - elif fp8 and not fp8_recipe.mxfp8(): - q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) - f16_tensors = (q_part, k_part, v_part, out_part) - elif fp8 and fp8_recipe.mxfp8(): - f16_tensors = (q, k, v, out_part) - ctx.qkv_layout = original_qkv_layout - else: - f16_tensors = (q_part, k_part, v_part, out_part) - + # all tensors are already in F16 + f16_tensors = (q_part, k_part, v_part, 
out_part) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, *f16_tensors, @@ -3997,16 +4029,13 @@ def forward( ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects - ctx.out_shape = out_ret.shape - ctx.batch_size = batch_size ctx.cp_group = cp_group ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale - # ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type ctx.deterministic = deterministic @@ -4030,11 +4059,11 @@ def forward( ctx.QKV_quantizer = QKV_quantizer.copy() ctx.O_quantizer = O_quantizer.copy() ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None - if not ctx.fp8_recipe.mxfp8(): ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer.scale = O_quantizer.scale.clone() ctx.S_quantizer.scale = S_quantizer.scale.clone() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") if return_max_logit: return out_ret, max_logit @@ -4062,50 +4091,40 @@ def backward(ctx, dout, *_args): *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - # qkv_format = ctx.qkv_format - # qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - # qkv_layout = ctx.qkv_layout - causal = "causal" in ctx.attn_mask_type - dqkv_format, _, _ = dpa_utils.get_qkv_format(ctx.dqkv_layout) - - batch_dim_dqkv, seq_dim_dqkv, _ = get_bsh_dims(dqkv_format) + batch_dim_dqkv, seq_dim_dqkv, _ = get_bsh_dims(ctx.dqkv_format) _, seq_dim_do, _ = get_bsh_dims(ctx.o_format) - bwd_nominal_dtype = ctx.fwd_nominal_dtype - # dqkv_te_dtype = None fused_attn_backend = None - dout_fp8 = dout - if ctx.fp8: - if ctx.use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] - if not isinstance(dout, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): - dout = ctx.dO_quantizer(dout) - dout_fp8 = dout - if not ctx.fp8_recipe.mxfp8(): - # 
dqkv_te_dtype = dout._fp8_dtype - dout = dout._data - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer - fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer - fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer - else: - assert False, "FP8 is only supported with Fused Attention!" + dout_fp8 = None + fp8_meta_kwargs = {} + if ctx.fp8: + assert ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" + fused_attn_backend = FusedAttnBackend["FP8"] + if isinstance(dout, QuantizedTensorStorage): + dout_fp8 = dout + elif not ctx.fp8_recipe.mxfp8(): + dout = ctx.dO_quantizer(dout) + dout_fp8 = dout + if not ctx.fp8_recipe.mxfp8(): + dout = dout._data + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer else: if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: - fp8_meta_kwargs = {} - # dqkv_te_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - - if not ctx.use_fused_attention: - if ctx.o_format in ["bshd", "sbhd"]: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) - else: - dout = dout.view(*ctx.out_shape) - + dout = dout.view(*ctx.out_ret_shape) + + # dout: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # [t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) dout = flash_attn_a2a_communicate( dout, @@ -4118,6 +4137,7 @@ def backward(ctx, dout, *_args): qkv_format=ctx.o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) + flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} @@ -4158,8 +4178,8 @@ 
def backward(ctx, dout, *_args): fa_backward_kwargs["softcap"] = 0.0 dq_fp8, dk_fp8, dv_fp8 = None, None, None - d_out_format = ctx.o_format if ctx.use_fused_attention: + d_out_format = ctx.o_format q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 @@ -4170,9 +4190,10 @@ def backward(ctx, dout, *_args): if not ctx.fp8_recipe.mxfp8(): dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) else: + # d_out_format = bhsd for both dout (F16) and dout_part (MXFP8) dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) - dout_part = ctx.dO_quantizer(dout) aux_ctx_tensors.append(dout) + dout_part = ctx.dO_quantizer(dout) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -4184,7 +4205,6 @@ def backward(ctx, dout, *_args): out_part, dout_part, bwd_nominal_dtype, - # dqkv_te_dtype, aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, @@ -4203,7 +4223,7 @@ def backward(ctx, dout, *_args): **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) - if isinstance(dq, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): + if all(isinstance(x, QuantizedTensorStorage) for x in [dq, dk, dv]): dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv dq, dk, dv = [x._data for x in [dq, dk, dv]] else: @@ -4212,7 +4232,7 @@ def backward(ctx, dout, *_args): fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - ctx.o_format, + ctx.dqkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, @@ -4231,10 +4251,17 @@ def backward(ctx, dout, *_args): out, softmax_lse, *fa_backward_args_thd, - causal=causal, + causal="causal" in ctx.attn_mask_type, **fa_backward_kwargs, ) + # dq, dk, dv: + # FP8DS: torch.uint8 + # FP8CS/MXFP8/F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # 
[t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -4244,14 +4271,15 @@ def backward(ctx, dout, *_args): ctx.cp_group, ctx.cp_stream, before_attn=False, - qkv_format=dqkv_format, + qkv_format=ctx.dqkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - if dqkv_format == "bshd": + if ctx.dqkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif dqkv_format == "sbhd": + elif ctx.dqkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + # d_bias, d_softmax_offset d_bias = None d_softmax_offset = None if ctx.use_fused_attention: @@ -4263,6 +4291,7 @@ def backward(ctx, dout, *_args): d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False ) + # convert dq, dk, dv to appropriate types if ctx.fp8: if ( ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() @@ -4283,8 +4312,8 @@ def backward(ctx, dout, *_args): dv, src_nominal_dtype=bwd_nominal_dtype, ) - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( None, dq, From dea59e4f461e5e43c70a6b3400c0cc937ef1fce1 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Mar 2026 00:42:08 -0700 Subject: [PATCH 080/172] update FE to fix d64 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index b449099e98..f87e5f42ac 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit b449099e98fbe13aacc7cd6c1cb48cc11914a210 +Subproject commit f87e5f42ac2f6617cc645bb32873d934b8d8bdd7 From 9da8ec98971032e6fab1eb71481eb6bfad64f552 Mon Sep 17 00:00:00 2001 
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Mar 2026 14:46:19 -0700 Subject: [PATCH 081/172] refactor ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 352 +++++++++++------- 1 file changed, 214 insertions(+), 138 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 94a64bda13..8c5d974a48 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2952,25 +2952,28 @@ def forward( ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) - - assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + o_format = qkv_format + _, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) - causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type - assert not padding, f"{attn_mask_type} mask type is not supported!" - if use_fused_attention and causal and "bottom_right" not in attn_mask_type: - attn_mask_type = attn_mask_type + "_bottom_right" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" - assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + assert qkv_format != "thd", f"No support for cp_comm_type='all_gather' and {qkv_format=}." + assert "padding" not in attn_mask_type, f"No support for cp_comm_type='all_gather' and {attn_mask_type=}." 
+ assert attn_bias_type == "no_bias", f"No support for cp_comm_type='all_gather' and {attn_bias_type=}." assert ( - use_fused_attention or fa_utils.v2_3_plus - ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + window_size == (-1, 0) + or window_size == (-1, -1) + or use_fused_attention + or fa_utils.v2_3_plus + ), f"cp_comm_type='all_gather' only supports SWA through FusedAttention or FlashAttention >= 2.3. Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." + assert ( + q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0 + ), f"cp_comm_type='all_gather' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q = {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}, cp_size = {cp_size}." + flash_attn_fwd = None if not use_fused_attention: @@ -3003,11 +3006,6 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - seq_dim = qkv_format.index("s") - assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 - ), "Sequence length per GPU needs to be divisible by 2!" - max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) if use_fused_attention or qkv_format == "thd": @@ -3016,6 +3014,8 @@ def forward( cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) else: cu_seqlens_q_padded = None + if use_fused_attention and attn_mask_type == "causal": + attn_mask_type = attn_mask_type + "_bottom_right" # FP8 setup assert isinstance(k, q.__class__) and isinstance( @@ -3036,12 +3036,12 @@ def forward( dP_quantizer, ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) fwd_nominal_dtype = q.dtype - fp8_meta_kwargs = {} q_fp8, k_fp8, v_fp8 = (q, k, v) if is_input_fp8 else (None, None, None) q_f16, k_f16, v_f16 = (None, None, None) if is_input_fp8 else (q, k, v) fused_attn_backend = None + fp8_meta_kwargs = {} if fp8: - assert use_fused_attention, "FP8 is only supported with Fused Attention!" 
+ assert use_fused_attention, "FP8 is only supported with FusedAttention backend!" fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 if not is_input_fp8 and not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( @@ -3053,33 +3053,48 @@ def forward( fp8_meta_kwargs["o_quantizer"] = O_quantizer elif use_fused_attention: fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape + orig_o_shape = orig_q_shape[:-1] + orig_v_shape[-1:] - # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] - q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - q_shape = q.shape - # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] - k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] - k_shape = k.shape - v_shape = v.shape - - # [s, b, h, d] -> [cp, s, b, h, d] + # q, k, v: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # reshape: split s + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + q = q.view(*q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :]) + # s dim first for all-gather + # [b, s, h, d]/[s, b, h, d] -> [s, b, h, d] + k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] + + # gather along s: [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s:[cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + # pick out specific chunks for each rank chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, 
s//2, b, h, d] -> [cp*s, b, h, d] + # reshape/flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) cp_stream.wait_stream(torch.cuda.current_stream()) + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # k_ag: [cp*s, b, h, d] + # v_ag: [cp*s, b, h, d] + # out: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + q_shape, k_shape, v_shape = q.shape, k.shape, v.shape + o_shape = q.shape[:-1] + v.shape[-1:] + out = torch.empty(o_shape, dtype=fwd_nominal_dtype, device=q.device) + # create two streams to resolve wave quantization issue of Flash Attn in each step flash_attn_streams = [torch.cuda.current_stream(), cp_stream] - + # prepare per-step tensors local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] kv_seq_range_per_step = [None, None] window_size_per_step = [None, None] @@ -3087,18 +3102,15 @@ def forward( out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] - enable_mla = k.shape[-1] != v.shape[-1] - out_shape = q.shape if not enable_mla else q.shape[:-1] + v.shape[-1:] - out = torch.empty(out_shape, dtype=fwd_nominal_dtype, device=q.device) max_logit_per_step = [None, None] max_logit = None for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + q_part = q.select(seq_dim_qkv, i).contiguous() kv_seq_range_per_step[i], window_size_per_step[i] = ( get_kv_seq_info_after_all_gather( local_seq_chunk_ids[i], @@ -3106,7 +3118,7 @@ def forward( max_seqlen_q, max_seqlen_kv, window_size, - causal, + "causal" in attn_mask_type, ) ) seq_start_idx, seq_end_idx = ( @@ -3118,22 +3130,22 @@ def forward( cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( 
k.shape[1], max_seqlen_kv_, k.device ) - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + # select range: [s_range, b, h, d] + k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] + k_part, v_part = [x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part]] if use_fused_attention: - q_part, k_part, v_part = q_, k_, v_ new_qkv_layout = qkv_layout if fp8: if not fp8_recipe.mxfp8(): q_part = Float8Tensor.make_like( - q_fp8, data=q_, dtype=fwd_nominal_dtype + q_fp8, data=q_part, dtype=fwd_nominal_dtype ) k_part = Float8Tensor.make_like( - k_fp8, data=k_, dtype=fwd_nominal_dtype + k_fp8, data=k_part, dtype=fwd_nominal_dtype ) v_part = Float8Tensor.make_like( - v_fp8, data=v_, dtype=fwd_nominal_dtype + v_fp8, data=v_part, dtype=fwd_nominal_dtype ) else: q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( @@ -3157,7 +3169,7 @@ def forward( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=new_qkv_layout, - o_format=qkv_format, + o_format=o_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -3168,7 +3180,6 @@ def forward( cuda_graph=is_graph_capturing(), **fp8_meta_kwargs, ) - if fp8: softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors else: @@ -3193,11 +3204,11 @@ def forward( fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] fa_outputs = flash_attn_fwd( - q_, - k_, - v_, + q_part, + k_part, + v_part, *fa_forward_args_thd, - causal=causal, + causal="causal" in attn_mask_type, **fa_forward_kwargs, ) if not fa_utils.v2_7_0_plus: @@ -3211,33 +3222,36 @@ def forward( if not use_flash_attn_3: rng_states[i] = fa_outputs[3] + # out_per_step[i]: fwd_nominal_dtype, [b, s//2, h, d] or 
[s//2, b, h, d] + # out: fwd_nominal_dtype, [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # max_logit_per_step[i]: torch.float32, [h] + # max_logit: torch.float32, [h] if return_max_logit and i == 0: max_logit = torch.clone(max_logit_per_step[0]) if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - if qkv_format == "bshd": + if o_format == "bshd": out[:, i - 1].copy_(out_per_step[i - 1]) - elif qkv_format == "sbhd": + elif o_format == "sbhd": out[i - 1].copy_(out_per_step[i - 1]) if return_max_logit: max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) + + # all reduce max_logit across ranks if return_max_logit: torch.distributed.all_reduce( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group ) - if use_fused_attention: - if qkv_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) - else: - out = out.view(-1, *out.shape[-2:]) + # out: fwd_nominal_dtype + # [b, 2, s//2, h, d] -> [b, s, h, d] + # [2, s//2, b, h, d] -> [s, b, h, d] + out = out.view(orig_o_shape) + # prepare for forward output and backward saves of out out_fp8 = None - out_ret = out bwd_requires_o_fp8 = ( is_training and is_bwd_fp8 @@ -3248,64 +3262,84 @@ def forward( ) if fp8 and (is_output_fp8 or bwd_requires_o_fp8): out_fp8 = O_quantizer(out) - if is_output_fp8: - out_ret = out_fp8 + out_ret = out_fp8 if is_output_fp8 else out + # save tensors for backward ctx.fp8 = fp8 and is_bwd_fp8 ctx.fp8_recipe = fp8_recipe fp8_tensors = (None, None, None, None) f16_tensors = (None, None, None, None) + # True: q split along s; k/v with s first, i.e. 
[s, b, h, d] + # False: original [b, s, h, d] or [s, b, h, d] ctx.qkv_reshaped = True + # no load-balance related token shuffling; original token order in q/k/v/out + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # out/out_fp8: [b, s, h, d] or [s, b, h, d] if ctx.fp8: + # q_fp8_save: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k_fp8_save: [s, b, h, d] + # v_fp8_save: [s, b, h, d] q_fp8_save, k_fp8_save, v_fp8_save = None, None, None if fp8_recipe.delayed() or fp8_recipe.float8_current_scaling(): q_fp8_save = Float8Tensor.make_like(q_fp8, data=q, dtype=fwd_nominal_dtype) k_fp8_save = Float8Tensor.make_like(k_fp8, data=k, dtype=fwd_nominal_dtype) v_fp8_save = Float8Tensor.make_like(v_fp8, data=v, dtype=fwd_nominal_dtype) + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o all in FP8 + # FP8CS+_dpa_fp8_cs_o_in_f16: q/k/v in FP8, o in f16 + # MXFP8: q/k/v/o all in f16 if fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16): fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + elif fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, None) f16_tensors = (None, None, None, out) - if fp8_recipe.mxfp8(): + elif fp8_recipe.mxfp8(): f16_tensors = (q, k, v, out) elif fp8: + # convert q/k/v to F16 if necessary, and save q/k/v/o all in F16 and original format if is_input_fp8: q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) - if fp8_recipe.mxfp8(): - f16_tensors = (q, k, v, out) - else: - f16_tensors = (q_f16, k_f16, v_f16, out) - ctx.qkv_reshaped = False + f16_tensors = (q_f16, k_f16, v_f16, out) + ctx.qkv_reshaped = False else: + # save all in F16 + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # out: [b, s, h, d] or [s, b, h, d] f16_tensors = (q, k, v, out) - tensors_to_save, 
tensor_objects = prepare_for_saving( *fp8_tensors, *f16_tensors, cu_seqlens_q, cu_seqlens_q_padded, *cu_seqlens_kv_per_step, - *out_per_step, *softmax_lse_per_step, *rng_states, ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects + ctx.qkv_format = qkv_format + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.dqkv_format = qkv_format + ctx.dqkv_layout = qkv_layout ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.orig_o_shape = orig_o_shape + ctx.o_shape = o_shape ctx.q_shape = q_shape ctx.k_shape = k_shape ctx.v_shape = v_shape - ctx.out_shape = out_shape ctx.kv_seq_range_per_step = kv_seq_range_per_step ctx.window_size_per_step = window_size_per_step + ctx.cp_group = cp_group ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic @@ -3314,6 +3348,7 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer @@ -3324,11 +3359,11 @@ def forward( ctx.QKV_quantizer = QKV_quantizer.copy() ctx.O_quantizer = O_quantizer.copy() ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None - if not ctx.fp8_recipe.mxfp8(): ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer.scale = O_quantizer.scale.clone() ctx.S_quantizer.scale = S_quantizer.scale.clone() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if return_max_logit: return out_ret, max_logit @@ -3342,7 +3377,6 @@ def backward(ctx, dout, *_args): rank = get_distributed_rank(ctx.cp_group) cu_seqlens_kv_per_step = [None, None] - out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] ( @@ -3358,36 +3392,61 @@ def backward(ctx, dout, *_args): 
cu_seqlens_q_padded, cu_seqlens_kv_per_step[0], cu_seqlens_kv_per_step[1], - out_per_step[0], - out_per_step[1], softmax_lse_per_step[0], softmax_lse_per_step[1], rng_states[0], rng_states[1], ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step - seq_dim = ctx.qkv_format.index("s") - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + _, seq_dim_qkv, _ = get_bsh_dims(ctx.qkv_format) + _, seq_dim_dqkv, _ = get_bsh_dims(ctx.dqkv_format) + _, seq_dim_o, _ = get_bsh_dims(ctx.o_format) - dout = dout.view(ctx.out_shape) + # set up dout: + # FP8DS/CS: torch.uint8, [b, s, h, d] or [s, b, h, d] + # MXFP8/F16: torch.float16 or torch.bfloat16, [b, s, h, d] or [s, b, h, d] dout_fp8 = None if ctx.fp8: + assert ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" if isinstance(dout, QuantizedTensorStorage): dout_fp8 = dout elif not ctx.fp8_recipe.mxfp8(): - dout_fp8 = ctx.dO_quantizer(dout) + dout = ctx.dO_quantizer(dout) + dout_fp8 = dout if not ctx.fp8_recipe.mxfp8(): dout = dout_fp8._data + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + dout = dout.view(ctx.o_shape) + # set up q, k, v: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] if not ctx.qkv_reshaped: - q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + q = q.view(*q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :]) + k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] + # set up out: + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): torch.uint8 + # FP8CS+_dpa_fp8_cs_o_in_f16: torch.float16 or 
torch.bfloat16 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + if ctx.fp8 and (ctx.fp8_recipe.delayed() or (ctx.fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16)): + out = out_fp8._data + out = out.view(ctx.o_shape) + + # set up dq, dk, dv: + # dq: fwd_nominal_dtype, [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # dk: fwd_nominal_dtype, [cp*s, b, h, d] + # dv: fwd_nominal_dtype, [cp*s, b, h, d] dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) dk = torch.zeros( (ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), @@ -3408,23 +3467,22 @@ def backward(ctx, dout, *_args): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() - # [s, b, h, d] -> [cp, s, b, h, d] + # gather k and v along s: [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s: [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + # select appropriate chunks for each rank chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) - local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] - + # set up flash_attn_bwd flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} @@ -3456,57 +3514,65 @@ def backward(ctx, dout, *_args): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + 
local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + q_part = q.select(seq_dim_qkv, i).contiguous() seq_start_idx, seq_end_idx = ( kv_seq_range_per_step[i][0], kv_seq_range_per_step[i][1], ) max_seqlen_kv = seq_end_idx - seq_start_idx - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] - out_ = out_per_step[i] - dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) + # select range: [s_range, b, h, d] + k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] + k_part, v_part = [x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part]] + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + out_part = out.select(seq_dim_o, i).contiguous() + dout_part = dout.select(seq_dim_o, i).contiguous() if ctx.use_fused_attention: aux_ctx_tensors = [ softmax_lse_per_step[i], softmax_lse_per_step[i], rng_states[i], ] - q_part, k_part, v_part, out_part, dout_part = q_, k_, v_, out_, dout_ fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8_meta_kwargs = {} - new_qkv_layout = qkv_layout - d_out_format = ctx.qkv_format + qkv_layout = ctx.qkv_layout + d_out_format = ctx.o_format if ctx.fp8: fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): 
q/k/v/o/do all in FP8 + # FP8CS+_dpa_fp8_cs_o_in_f16: q/k/v/do in FP8, o in f16 + # MXFP8: q/k/v/do all in MXFP8, o/do_f16 in F16 if not ctx.fp8_recipe.mxfp8(): q_part = Float8Tensor.make_like( - q_fp8, data=q_, dtype=ctx.fwd_nominal_dtype + q_fp8, data=q_part, dtype=ctx.fwd_nominal_dtype ) k_part = Float8Tensor.make_like( - k_fp8, data=k_, dtype=ctx.fwd_nominal_dtype + k_fp8, data=k_part, dtype=ctx.fwd_nominal_dtype ) v_part = Float8Tensor.make_like( - v_fp8, data=v_, dtype=ctx.fwd_nominal_dtype + v_fp8, data=v_part, dtype=ctx.fwd_nominal_dtype ) - if not ( - ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + if ctx.fp8_recipe.delayed() or ( + ctx.fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 ): - out_part = ctx.O_quantizer(out_part) + out_part = Float8Tensor.make_like( + out_fp8, data=out_part, dtype=ctx.fwd_nominal_dtype + ) dout_part = Float8Tensor.make_like( dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype ) else: - q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + q_part, k_part, v_part, qkv_layout = combine_and_quantize( qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer ) dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor( @@ -3525,17 +3591,16 @@ def backward(ctx, dout, *_args): out_part, dout_part, ctx.fwd_nominal_dtype, - # TE_DType[dout.dtype], aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=new_qkv_layout, - o_format=ctx.qkv_format, + qkv_layout=qkv_layout, + o_format=ctx.o_format, d_out_format=d_out_format, - dqkv_layout=qkv_layout, + dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], @@ -3543,23 +3608,22 @@ def backward(ctx, dout, *_args): cuda_graph=is_graph_capturing(), **fp8_meta_kwargs, ) - if ctx.fp8: + if ctx.fp8 and all( + isinstance(x, 
QuantizedTensorStorage) + for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] + ): dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ - ( - x.dequantize(dtype=ctx.fwd_nominal_dtype) - if isinstance(x, QuantizedTensorStorage) - else x - ) + x.dequantize(dtype=ctx.fwd_nominal_dtype) for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] ] else: dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ - torch.empty_like(x) for x in [q_, k_, v_] + torch.empty_like(x) for x in [q_part, k_part, v_part] ] fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - ctx.qkv_format, + ctx.dqkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=ctx.max_seqlen_q, @@ -3578,11 +3642,11 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] flash_attn_bwd( - dout_, - q_, - k_, - v_, - out_, + dout_part, + q_part, + k_part, + v_part, + out_part, softmax_lse_per_step[i], *fa_backward_args_thd, causal="causal" in ctx.attn_mask_type, @@ -3590,14 +3654,19 @@ def backward(ctx, dout, *_args): ) if i > 0: + # dq/dk/dv, dq_per_step/dk_per_step/dv_per_step: ctx.fwd_nominal_dtype with torch.cuda.stream(flash_attn_streams[i - 1]): - if ctx.qkv_format == "bshd": + # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # dq_per_step[i]: [b, s//2, h, d] or [s//2, b, h, d] + if ctx.dqkv_format == "bshd": dq[:, i - 1].copy_(dq_per_step[i - 1]) - elif ctx.qkv_format == "sbhd": + elif ctx.dqkv_format == "sbhd": dq[i - 1].copy_(dq_per_step[i - 1]) - # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] + # dk/dv: [cp*s, b, h, d] + # dk_per_step[i - 1]/dv_per_step[i - 1]: [s_range, b, h, d] or [b, s_range, h, d] + # move s to first dim: [s_range, b, h, d] dk_per_step[i - 1], dv_per_step[i - 1] = [ - x.movedim(seq_dim, 0).contiguous() + x.movedim(seq_dim_dqkv, 0).contiguous() for x in [dk_per_step[i - 1], dv_per_step[i - 1]] ] # wait 
until dkv update of last step is done @@ -3607,6 +3676,7 @@ def backward(ctx, dout, *_args): kv_seq_range_per_step[i - 1][0], kv_seq_range_per_step[i - 1][1], ) + # add to dk/dv: [cp*s, b, h, d] dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) if i < len(local_seq_chunk_ids): @@ -3614,27 +3684,33 @@ def backward(ctx, dout, *_args): torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s:[cp*s, b, h, d] -> [cp*2, s//2, b, h, d] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + # put back together the right chunks for each rank chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] dk = dk.view(-1, *dk.shape[-3:]) dv = dv.view(-1, *dv.shape[-3:]) + # reduce scatter: [cp*s, b, h, d] -> [s, b, h, d] dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) - dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) - dk = dk.movedim(0, seq_dim).contiguous() - dv = dv.movedim(0, seq_dim).contiguous() + # reshape to original format: + # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] -> [b, s, h, d] or [s, b, h, d] + # dk: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] + # dv: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] + dq = dq.view(*dq.shape[:seq_dim_dqkv], -1, *dq.shape[(seq_dim_dqkv + 2) :]) + dk = dk.movedim(0, seq_dim_dqkv).contiguous() + dv = dv.movedim(0, seq_dim_dqkv).contiguous() + # quantize if necessary if ctx.fp8 and ctx.is_input_fp8: dq, dk, dv, _ = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") - 
return ( None, dq, From a250b20250452dce7fef72f10054a190218e8a48 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:43:10 -0700 Subject: [PATCH 082/172] refactor p2p/a2a+p2p; mostly regarding shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 148 ++++++------------ 1 file changed, 47 insertions(+), 101 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 8c5d974a48..61fdb2c013 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -962,7 +962,7 @@ def cp_p2p_fwd_fused_attn( if return_max_logit: return out_per_step, softmax_lse_per_step, rng_states, attn_bias, *max_logit - return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None # , new_qkv_layout + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None def cp_p2p_fwd_flash_attn( @@ -1128,11 +1128,9 @@ def cp_p2p_bwd_fused_attn( deterministic, fwd_nominal_dtype, bwd_nominal_dtype, - # bwd_output_te_dtype, S_quantizer, dP_quantizer_per_step, dQKV_quantizer_per_step, - # O_quantizer_per_step, QKV_quantizer_per_step, dO_quantizer_per_step, q_part, @@ -1201,9 +1199,7 @@ def cp_p2p_bwd_fused_attn( out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) else: - # out_part, o_format = dpa_utils.permute_to_grouped_tensor(o_format, out_part) dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout_part) - # out_part = O_quantizer_per_step(out_part) aux_tensors.append(dout_part) dout_part = dO_quantizer_per_step(dout_part) fp8_meta_kwargs["s_quantizer"] = 
S_quantizer @@ -1386,17 +1382,15 @@ def forward( ) # set up attention args - enable_mla = k.shape[-1] != v.shape[-1] - causal = "causal" in attn_mask_type - if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - + causal = "causal" in attn_mask_type + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape orig_o_shape = q.shape[:-1] + v.shape[-1:] batch_dim = None seq_dim = None cu_seqlens_q_half, cu_seqlens_kv_half = None, None - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None @@ -1411,13 +1405,10 @@ def forward( else: cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size - max_seqlen_q = max_seqlen_q // cp_size max_seqlen_kv = max_seqlen_kv // cp_size cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] - - fused_attn_backend = None amax_per_step = None S_quantizer_per_step = [None for _ in range(cp_size)] O_quantizer_per_step = [None for _ in range(cp_size)] @@ -1426,17 +1417,14 @@ def forward( assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." 
fwd_nominal_dtype = q.dtype - is_input_fp8 = isinstance(q, Float8Tensor) + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; - # may be different from fp8_meta["recipe"] fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] - ( QKV_quantizer, O_quantizer, @@ -1446,15 +1434,15 @@ def forward( dP_quantizer, ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) - q_f16 = None + # q, k, v a2a: gather s and split h + # FP8DS/CS: Float8Tensor -> torch.uint8 -> Float8Tensor + # MXFP8/F16: fwd_nominal_dtype q_fp8, k_fp8, v_fp8 = (None, None, None) - # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: if fp8 and is_input_fp8: - QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v if not fp8_recipe.mxfp8(): - q, k, v = (q._data, k._data, v._data) + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True @@ -1465,12 +1453,14 @@ def forward( for x, y in zip([q_fp8, k_fp8, v_fp8], [q, k, v]) ] q, k, v = q_fp8, k_fp8, v_fp8 + post_a2a_o_shape = q.shape[:-1] + v.shape[-1:] # convert qkv to the right type + q_f16 = None + fused_attn_backend = None if fp8: assert use_fused_attention, "FP8 is only supported with Fused Attention!" 
fused_attn_backend = FusedAttnBackend["FP8"] - if is_input_fp8: # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 @@ -1560,7 +1550,6 @@ def forward( attn_bias_ = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) ) - # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] attn_bias = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) @@ -1630,17 +1619,20 @@ def forward( # synchronize fwd results correction across steps fwd_results_correction_done = torch.cuda.Event() + # q, k, v, o: + # causal: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # non-causal: [b, s, h, d] or [s, b, h, d] p2p_comm_buffers = [None for _ in range(cp_size)] k_shape = k.shape k_numel = k.numel() v_shape = v.shape - o_shape = q.shape if not enable_mla else q.shape[:-1] + v.shape[-1:] + o_shape = q.shape[:-1] + v.shape[-1:] p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) send_recv_reqs = [[], []] # P2P communication and compute: each rank has cp_size steps - # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype - # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8 + # MXFP8/F16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype + # FP8DS/CS attention: q, k, v: torch.Tensor, dtype=torch.uint8 out = None o_format = qkv_format for i in range(cp_size + 1): @@ -1757,7 +1749,6 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], - # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1785,7 +1776,6 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], - # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1813,7 +1803,6 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], - # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1842,7 +1831,6 @@ def forward( rng_states[i], 
attn_biases[i], max_logit_per_step[i], - # qkv_layout, ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( @@ -1857,7 +1845,7 @@ def forward( with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if use_fused_attention: - # [b, h, sq, 1] -> [b, h, sq] or + # [b, h, sq, 1] -> [b, h, sq] # [t, h, 1] -> [t, h] softmax_lse_per_step[i - 1].squeeze_(-1) if softmax_lse_in_packed_format: @@ -1876,15 +1864,10 @@ def forward( if i == 1: softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": - if enable_mla: - out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( - o_shape - ) + if fp8: + out = torch.zeros_like(out_per_step[0]).view(o_shape) else: - # MHA or GQA - out = torch.zeros_like(q if not fp8 else out_per_step[0]).view( - q.shape - ) + out = torch.zeros(o_shape, dtype=q.dtype, device=q.device) elif (i - 1) <= rank or not causal: flash_attn_fwd_softmax_lse_correction( softmax_lse, softmax_lse_per_step[i - 1] @@ -1932,10 +1915,7 @@ def forward( softmax_lse_per_step[0], seq_dim, ) - if enable_mla: - out = out.view(o_shape) - else: - out = out.view(q.shape) + out = out.view(o_shape) else: flash_attn_fwd_out_correction( out.view(*out_per_step[i].shape), @@ -1973,13 +1953,7 @@ def forward( True, softmax_lse_in_packed_format, ) - - if o_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - ctx.batch_size = out.shape[0] - elif o_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) - ctx.batch_size = out.shape[1] + out = out.view(post_a2a_o_shape) out_part = out.to(fwd_nominal_dtype) if cp_size_a2a > 1: @@ -1987,19 +1961,11 @@ def forward( out = flash_attn_a2a_communicate( out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False ) - if use_fused_attention: - if o_format == "bshd": - # [b*s, h, d] -> [b, s, h, d] - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - elif o_format == "sbhd": - # [s*b, h, d] -> [s, b, h, 
d] - out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + out = out.view(orig_o_shape) if return_max_logit: max_logit = flash_attn_a2a_communicate_softmax_offset( max_logit, 0, cp_size_a2a, cp_group_a2a, cp_stream, False ) - elif not use_fused_attention: - out = out.view(-1, *out.shape[-2:]) # update FP8 quantizers: amax across cp_size steps if fp8 and use_fused_attention and not fp8_recipe.mxfp8(): @@ -2107,7 +2073,6 @@ def forward( ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape @@ -2120,14 +2085,19 @@ def forward( ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 - ctx.enable_mla = enable_mla ctx.k_numel = k_numel ctx.k_shape = k_shape ctx.v_shape = v_shape ctx.o_shape = o_shape + ctx.orig_q_shape = orig_q_shape + ctx.orig_k_shape = orig_k_shape + ctx.orig_v_shape = orig_v_shape ctx.orig_o_shape = orig_o_shape - + ctx.post_a2a_o_shape = post_a2a_o_shape + ctx.qkv_format = qkv_format + ctx.qkv_layout = qkv_layout ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer @@ -2144,7 +2114,6 @@ def forward( ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop(f"{nvtx_label}") - if return_max_logit: return out_ret, max_logit return out_ret @@ -2204,8 +2173,6 @@ def backward(ctx, dout, *_args): # set up attention args causal = "causal" in ctx.attn_mask_type seq_dim = None - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - o_format = ctx.qkv_format if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") @@ -2265,11 +2232,9 @@ def backward(ctx, dout, *_args): buffer_dtype = torch.uint8 dq_buffer = None dout_fp8 = None - # bwd_output_te_dtype = None dkv_buffer = None - d_out_format = 
o_format if ctx.fp8: - assert ctx.use_fused_attention, "FP8 is only supported with Fused Attention!" + assert ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" fused_attn_backend = FusedAttnBackend["FP8"] if not ctx.fp8_recipe.mxfp8(): q, kv, out = ( @@ -2284,8 +2249,6 @@ def backward(ctx, dout, *_args): # dout_fp8: Float8Tensor, dtype=bwd_nominal_dtype # dout: torch.Tensor, dtype=torch.uint8 - # if ctx.fp8_recipe.mxfp8(): - # dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) if isinstance(dout, QuantizedTensorStorage): dout_fp8 = dout elif not ctx.fp8_recipe.mxfp8(): @@ -2305,9 +2268,6 @@ def backward(ctx, dout, *_args): ctx.dP_quantizer, ) - # dout_fp8._fp8_dtype - # bwd_output_te_dtype = ctx.dO_quantizer.dtype - # create buffers for reduction in float32 if ctx.fp8_recipe.delayed(): dq_buffer = torch.empty( @@ -2359,14 +2319,11 @@ def backward(ctx, dout, *_args): ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: - # bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] # communicate for the 'a2a' part of 'a2a+p2p' + dout = dout.view(*ctx.orig_o_shape) if cp_size_a2a > 1: - if not ctx.use_fused_attention: - # out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(ctx.orig_o_shape) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( cp_size_a2a, out.device ) @@ -2379,14 +2336,8 @@ def backward(ctx, dout, *_args): ctx.cp_stream, True, ) - - if ctx.enable_mla: - out = out.view(*ctx.o_shape) - dout = dout.view(*ctx.o_shape) - else: - # MHA or GQA - out = out.view(*q.shape) - dout = dout.view(*q.shape) + out = out.view(*ctx.o_shape) + dout = dout.view(*ctx.o_shape) flash_attn_bwd = None if not ctx.use_fused_attention: @@ -2504,20 +2455,18 @@ def backward(ctx, dout, *_args): fused_attn_backend, ctx.softmax_scale, ctx.dropout_p, - qkv_layout, + ctx.qkv_layout, ctx.qkv_format, ctx.qkv_format, - qkv_layout, + 
ctx.qkv_layout, ctx.attn_mask_type, ctx.attn_bias_type, ctx.deterministic, ctx.fwd_nominal_dtype, bwd_nominal_dtype, - # bwd_output_te_dtype, ctx.S_quantizer, dP_quantizer_per_step[i], dQKV_quantizer_per_step[i], - # ctx.O_quantizer, ctx.QKV_quantizer, ctx.dO_quantizer, ] @@ -2784,7 +2733,7 @@ def backward(ctx, dout, *_args): for x in [dq, dk, dv] ] dq, dk, dv = combine_and_dequantize( - qkv_layout, + ctx.qkv_layout, dq, dk, dv, @@ -2809,7 +2758,7 @@ def backward(ctx, dout, *_args): dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv, _ = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _ = combine_and_quantize(ctx.qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8: # print quantizers @@ -2844,10 +2793,7 @@ def backward(ctx, dout, *_args): Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) ] - if ctx.qkv_format == "bshd": - dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif ctx.qkv_format == "sbhd": - dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + dq, dk, dv = [x.view(y) for x,y in zip([dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape])] if attn_dbias is not None: # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, sq, sk] @@ -3785,6 +3731,7 @@ def forward( cp_size = get_distributed_world_size(cp_group) qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format original_qkv_layout = qkv_layout + orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape o_format = qkv_format batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) _, seq_dim_o, _ = get_bsh_dims(o_format) @@ -4064,7 +4011,9 @@ def forward( ctx.o_format = o_format ctx.dqkv_layout = original_qkv_layout ctx.dqkv_format = qkv_format - ctx.batch_size = batch_size + ctx.orig_q_shape = orig_q_shape + ctx.orig_k_shape = orig_k_shape + ctx.orig_v_shape = orig_v_shape ctx.out_part_shape = 
out_part.shape ctx.out_ret_shape = out_ret.shape @@ -4350,10 +4299,7 @@ def backward(ctx, dout, *_args): qkv_format=ctx.dqkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - if ctx.dqkv_format == "bshd": - dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif ctx.dqkv_format == "sbhd": - dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + dq, dk, dv = [x.view(y) for x,y in zip([dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape])] # d_bias, d_softmax_offset d_bias = None From 630545e9b0625384d78dbb2304ef97728599fe36 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Mar 2026 18:11:07 -0700 Subject: [PATCH 083/172] add shadow f16 fwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 906f3ade45..fb5e870099 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -154,7 +154,9 @@ # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" - +_run_shadow_f16_fwd = os.getenv("NVTE_RUN_SHADOW_F16_FWD", "0") == "1" +_replace_out_with_shadow = os.getenv("NVTE_REPLACE_OUT_WITH_SHADOW", "0") == "1" +_replace_aux_with_shadow = os.getenv("NVTE_REPLACE_AUX_WITH_SHADOW", "0") == "1" class FP8EmulationFunc(torch.autograd.Function): """ @@ -1271,6 +1273,8 @@ def forward( out_nominal_dtype = q.dtype max_logit = None + orig_q, orig_k, orig_v = q, k, v + orig_qkv_layout = qkv_layout if fp8: fused_attention_backend = FusedAttnBackend["FP8"] @@ -1334,6 
+1338,51 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) + + if _run_shadow_f16_fwd: + # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + assert all(x.dtype in [torch.float16, torch.bfloat16] for x in [q, k, v]), "q, k, v must be torch.float16 or torch.bfloat16" + out_f16_, aux_ctx_tensors_f16, *_ = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + orig_q, + orig_k, + orig_v, + out_nominal_dtype, + FusedAttnBackend["F16_arbitrary_seqlen"], + attn_bias, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + None, # s_quantizer + None, # o_quantizer + attn_scale, + dropout_p, + fast_zero_fill, + orig_qkv_layout, + o_format, + attn_bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + rng_gen, + softmax_offset, + return_max_logit, + is_graph_capturing(), + ) + if torch.cuda.current_device() == 0: + print(f"L{layer_number}: real/shadow out min: {out_.min():.4f}/{out_f16_.min():.4f}, max: {out_.max():.4f}/{out_f16_.max():.4f}") + print(f"L{layer_number}: real/shadow stats min: {aux_ctx_tensors[0].min():.4f}/{aux_ctx_tensors_f16[0].min():.4f}, max: {aux_ctx_tensors[0].max():.4f}/{aux_ctx_tensors_f16[0].max():.4f}") + if _replace_out_with_shadow: + out_ = out_f16_ + if _replace_aux_with_shadow: + aux_ctx_tensors[0] = aux_ctx_tensors_f16[0] + # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 From a78ea9aa96022ebed59a28ac85eff928a51949c0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Mar 2026 10:10:31 -0700 Subject: [PATCH 084/172] update FE to fix SWA/BRCM Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend 
b/3rdparty/cudnn-frontend index f87e5f42ac..e6472db459 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit f87e5f42ac2f6617cc645bb32873d934b8d8bdd7 +Subproject commit e6472db459b760e033593ae2aaf0046f0a1a4a1d From 59eff74bf70bf539a2679f0f4a38d9d32afacea8 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Mar 2026 18:26:45 -0700 Subject: [PATCH 085/172] switch to GH FE temporarily Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 +-- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index 8c7646c00d..4b188d6bb1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,8 +3,7 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git - branch = develop + url = https://github.com/NVIDIA/cudnn-frontend.git [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index e6472db459..d33027a41a 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit e6472db459b760e033593ae2aaf0046f0a1a4a1d +Subproject commit d33027a41a93af9c85f089c6364ab415fce98982 From 1691747443c6ba0d2dca640b260dee2534895c4b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Mar 2026 18:30:38 -0700 Subject: [PATCH 086/172] switch back to GL FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 ++- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..8c7646c00d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,8 @@ url = https://github.com/google/googletest.git 
[submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://github.com/NVIDIA/cudnn-frontend.git + url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git + branch = develop [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index d33027a41a..e6472db459 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit d33027a41a93af9c85f089c6364ab415fce98982 +Subproject commit e6472db459b760e033593ae2aaf0046f0a1a4a1d From d41eca30d6a7cfef284659e278e59b9d66c7f227 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Mar 2026 18:31:45 -0700 Subject: [PATCH 087/172] update FE to latest commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index e6472db459..69432369f3 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit e6472db459b760e033593ae2aaf0046f0a1a4a1d +Subproject commit 69432369f3060467c72d01bd08cfeb9271178c22 From e0b65a5c53c3ff8c48d2af0ed7acb940eb9341f6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Mar 2026 01:32:43 +0000 Subject: [PATCH 088/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 5 +- .../attention/test_attention_with_cp.py | 13 ++- .../common/fused_attn/fused_attn.cpp | 26 +++-- .../common/fused_attn/fused_attn_fp8.cu | 98 ++++++++--------- .../common/fused_attn/fused_attn_fp8.h | 39 +++---- transformer_engine/common/fused_attn/utils.h | 53 ++++----- .../dot_product_attention/backends.py | 17 ++- 
.../dot_product_attention/context_parallel.py | 104 +++++++++++++----- .../dot_product_attention.py | 14 ++- .../attention/dot_product_attention/utils.py | 36 ++++-- transformer_engine/pytorch/csrc/extensions.h | 3 +- .../pytorch/csrc/extensions/attention.cpp | 20 ++-- 12 files changed, 252 insertions(+), 176 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index f194758351..7c46bc750f 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1802,6 +1802,7 @@ def get_model(dtype, config): return outputs + attn_mask_type = "causal" # attn_mask_type = "no_mask" # attn_mask_type = "causal_bottom_right" @@ -1815,9 +1816,7 @@ def get_model(dtype, config): head_dim_v=128, attn_mask_type=attn_mask_type, ), - "fp8_10": ModelConfig( - 2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type - ), + "fp8_10": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type), "fp8_11": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type), "fp8_12": ModelConfig(2, 8192, 64, 64, attn_mask_type=attn_mask_type, window_size=(128, 0)), # "fp8_13": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"), diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 5f2b37d79e..2795c24998 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -169,7 +169,12 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ), # MHA "cp_1_5": ModelConfig(2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0)), # MHA "cp_2_0": ModelConfig( - 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", + 2, + 4096, + 32, + 128, + num_gqa_groups=4, + attn_mask_type="causal", ), # GQA "cp_2_1": ModelConfig( 2, 4096, 128, 192, head_dim_v=128, 
attn_mask_type="causal" @@ -311,7 +316,11 @@ def test_cp_with_fused_attention( pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") # TODO: Remove this once the issue is fixed! - if dtype == "fp8" and (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) and cp_comm_type == "all_gather": + if ( + dtype == "fp8" + and (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) + and cp_comm_type == "all_gather" + ): pytest.skip("No support for SWA with FP8 attention and cp_comm_type=all_gather!") if cp_comm_type in ["a2a", "a2a+p2p"] and ( diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 33d99167fa..0c4cbbabbf 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -660,13 +660,16 @@ void nvte_fused_attn_fwd( NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); std::vector tmp_shape(4); - nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, &s_q, &d_qk, &t_q); - nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); - nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); + nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, &s_q, &d_qk, + &t_q); + nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_qk, &t_kv); + nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_v, &t_kv); if (q_format == NVTE_QKV_Format::NVTE_THD) { - b = input_cu_seqlens_q->data.shape[0] -1; + b = input_cu_seqlens_q->data.shape[0] - 1; } else if (kv_format == NVTE_QKV_Format::NVTE_THD) { - b = input_cu_seqlens_kv->data.shape[0] -1; + b = input_cu_seqlens_kv->data.shape[0] - 1; 
} int64_t num_pages_k = 0; @@ -784,13 +787,16 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); std::vector tmp_shape(4); - nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, &s_q, &d_qk, &t_q); - nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); - nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); + nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, &s_q, &d_qk, + &t_q); + nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_qk, &t_kv); + nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_v, &t_kv); if (q_format == NVTE_QKV_Format::NVTE_THD) { - b = input_cu_seqlens_q->data.shape[0] -1; + b = input_cu_seqlens_q->data.shape[0] - 1; } else if (kv_format == NVTE_QKV_Format::NVTE_THD) { - b = input_cu_seqlens_kv->data.shape[0] -1; + b = input_cu_seqlens_kv->data.shape[0] - 1; } auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 3e7a30022a..6c3d9a8161 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1780,7 +1780,8 @@ void fused_attn_fp8_fwd_impl_v1( std::vector q_strides(4); std::vector k_strides(4); std::vector v_strides(4); - generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), k_strides.data(), v_strides.data(), qkv_layout); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), + k_strides.data(), v_strides.data(), qkv_layout); Q = 
mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -1989,17 +1990,17 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, - bias_tuple, softmax_offset_tuple, padding_tuple, - dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, - attn_scale, O, amax_s, amax_o, Stats, bias, softmax_offset, seq_q, seq_kv, - dropout_seed, dropout_offset] = get_graph(sdpa_fp8_fprop_cache, descriptor); + attn_scale, O, amax_s, amax_o, Stats, bias, softmax_offset, seq_q, seq_kv, dropout_seed, + dropout_offset] = get_graph(sdpa_fp8_fprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -2076,21 +2077,19 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, - void* devPtrO, void* devPtrdO, void* devPtrSoftmaxOffset, void* devPtrdQ, void* devPtrdK, void* devPtrdV, - void* devPtrdSoftmaxOffset, void* devPtrDescaleQ, + bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrSoftmaxOffset, void* devPtrdQ, + void* devPtrdK, void* devPtrdV, void* devPtrdSoftmaxOffset, 
void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, - cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, - cudnn_frontend::DataType_t do_tensor_type, cudnn_frontend::DataType_t dqkv_tensor_type, - NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, - cudnnHandle_t handle) { + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, + cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -2240,7 +2239,8 @@ void fused_attn_fp8_bwd_impl_v1( std::vector v_strides(4); std::vector o_strides(4); std::vector dO_strides(4); - generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), k_strides.data(), v_strides.data(), qkv_layout); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), + k_strides.data(), v_strides.data(), qkv_layout); generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), o_format); generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_strides.data(), d_out_format); Q = 
mha_graph->tensor(fe::graph::Tensor_attributes() @@ -2528,7 +2528,8 @@ void fused_attn_fp8_bwd_impl_v1( std::vector dq_strides(4); std::vector dk_strides(4); std::vector dv_strides(4); - generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, dq_strides.data(), dk_strides.data(), dv_strides.data(), dqkv_layout); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, dq_strides.data(), + dk_strides.data(), dv_strides.data(), dqkv_layout); dQ->set_output(true) .set_dim({b, h, s_q, d_qk}) .set_stride(dq_strides) @@ -2732,16 +2733,15 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, - const Tensor* input_SoftmaxOffset, Tensor* input_output_S, - Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, - const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { + const Tensor* input_SoftmaxOffset, Tensor* input_output_S, Tensor* output_O, + NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, + const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - void* devPtrQ = nullptr, *devPtrK = nullptr, *devPtrV = nullptr; - void* devPtrDescaleQ = nullptr, *devPtrDescaleK = nullptr, *devPtrDescaleV = nullptr; - void* devPtrO = nullptr, *devPtrAmaxO = nullptr, *devPtrScaleO = nullptr; - void* devPtrAmaxS = nullptr, *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr; + void *devPtrQ = nullptr, *devPtrK = nullptr, *devPtrV = nullptr; + void *devPtrDescaleQ = nullptr, *devPtrDescaleK = nullptr, *devPtrDescaleV = nullptr; + void *devPtrO = nullptr, *devPtrAmaxO = nullptr, *devPtrScaleO = nullptr; + void *devPtrAmaxS = nullptr, *devPtrScaleS = nullptr, 
*devPtrDescaleS = nullptr; devPtrQ = input_Q->data.dptr; devPtrDescaleQ = input_Q->scale_inv.dptr; devPtrK = input_K->data.dptr; @@ -2821,10 +2821,10 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, - devPtrK, devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, - devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, - devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, + devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, + devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { @@ -2852,22 +2852,19 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } } // fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, size_t window_size_left, - size_t window_size_right, bool bottom_right_diagonal, 
bool deterministic, - const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, - const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, - const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, - const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, - const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, - Tensor* output_dSoftmaxOffset, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, - cudaStream_t stream, cudnnHandle_t handle) { +void fused_attn_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, + bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, + const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, + const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, + const Tensor* input_S, const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, + const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, + Tensor* output_dSoftmaxOffset, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -2875,7 +2872,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrDescaleQ = input_Q->scale_inv.dptr; void* devPtrDescaleK = input_K->scale_inv.dptr; void* devPtrDescaleV = input_V->scale_inv.dptr; - void* devPtrQ_t = nullptr, 
*devPtrK_t = nullptr, *devPtrDescaleQ_t = nullptr, *devPtrDescaleK_t = nullptr; + void *devPtrQ_t = nullptr, *devPtrK_t = nullptr, *devPtrDescaleQ_t = nullptr, + *devPtrDescaleK_t = nullptr; if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { devPtrQ_t = input_Q->columnwise_data.dptr; devPtrDescaleQ_t = input_Q->columnwise_scale_inv.dptr; @@ -2891,7 +2889,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; - void* devPtrdO_t = nullptr, *devPtrdO_f16 = nullptr, *devPtrDescaledO_t = nullptr; + void *devPtrdO_t = nullptr, *devPtrdO_f16 = nullptr, *devPtrDescaledO_t = nullptr; if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { devPtrdO_t = input_dO->columnwise_data.dptr; devPtrdO_f16 = input_dO_f16->data.dptr; @@ -2945,8 +2943,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, - deterministic, devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, + devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 47777279b9..617efa8f42 100644 --- 
a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -21,28 +21,25 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_SoftmaxOffset, - Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); - -// fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, size_t window_size_left, - size_t window_size_right, bool bottom_right_diagonal, bool deterministic, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, - const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, - const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, + const Tensor *input_K, const Tensor *input_V, + const Tensor *input_SoftmaxOffset, Tensor *input_output_S, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t 
handle); + +// fused attention BWD FP8 with separate Q, K, V +void fused_attn_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, + bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, + const Tensor *input_S, const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, + const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 8c60f85cf5..fdda4dfe9c 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -103,9 +103,10 @@ inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int } // get matrix strides based on layout and matrix type -inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, - int64_t d_qk, int64_t d_v, int64_t *q_strides, int64_t *k_strides, int64_t *v_strides, - NVTE_QKV_Layout layout) { +inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, int64_t s_q, + int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t *q_strides, int64_t *k_strides, + int64_t *v_strides, 
NVTE_QKV_Layout layout) { constexpr int b_dim = 0; constexpr int h_dim = 1; constexpr int s_dim = 2; @@ -131,17 +132,17 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in } break; case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: - q_strides[b_dim] = h * d_qk; - q_strides[h_dim] = d_qk; - q_strides[s_dim] = b * h * d_qk; - q_strides[d_dim] = 1; - k_strides[b_dim] = 2 * hg * d_qk; - k_strides[h_dim] = d_qk; - k_strides[s_dim] = b * 2 * hg * d_qk; - k_strides[d_dim] = 1; - for (int i = 0; i < 4; i++) { - v_strides[i] = k_strides[i]; - } + q_strides[b_dim] = h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = 2 * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = b * 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } break; case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: q_strides[b_dim] = h * d_qk; @@ -158,18 +159,18 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in break; case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: - q_strides[b_dim] = h * d_qk; - q_strides[h_dim] = d_qk; - q_strides[s_dim] = b * h * d_qk; - q_strides[d_dim] = 1; - k_strides[b_dim] = hg * d_qk; - k_strides[h_dim] = d_qk; - k_strides[s_dim] = b * hg * d_qk; - k_strides[d_dim] = 1; - v_strides[b_dim] = hg * d_v; - v_strides[h_dim] = d_v; - v_strides[s_dim] = b * hg * d_v; - v_strides[d_dim] = 1; + q_strides[b_dim] = h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = b * hg * d_qk; + k_strides[d_dim] = 1; + v_strides[b_dim] = hg * d_v; + v_strides[h_dim] = d_v; + v_strides[s_dim] = b * hg * d_v; + v_strides[d_dim] = 1; break; case NVTE_QKV_Layout::NVTE_BS3HD: case NVTE_QKV_Layout::NVTE_T3HD: diff --git 
a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index f40576425d..2835458449 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -158,6 +158,7 @@ _replace_out_with_shadow = os.getenv("NVTE_REPLACE_OUT_WITH_SHADOW", "0") == "1" _replace_aux_with_shadow = os.getenv("NVTE_REPLACE_AUX_WITH_SHADOW", "0") == "1" + class FP8EmulationFunc(torch.autograd.Function): """ Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows: @@ -1345,7 +1346,9 @@ def forward( if _run_shadow_f16_fwd: # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - assert all(x.dtype in [torch.float16, torch.bfloat16] for x in [q, k, v]), "q, k, v must be torch.float16 or torch.bfloat16" + assert all( + x.dtype in [torch.float16, torch.bfloat16] for x in [q, k, v] + ), "q, k, v must be torch.float16 or torch.bfloat16" out_f16_, aux_ctx_tensors_f16, *_ = fused_attn_fwd( is_training, max_seqlen_q, @@ -1380,8 +1383,16 @@ def forward( is_graph_capturing(), ) if torch.cuda.current_device() == 0: - print(f"L{layer_number}: real/shadow out min: {out_.min():.4f}/{out_f16_.min():.4f}, max: {out_.max():.4f}/{out_f16_.max():.4f}") - print(f"L{layer_number}: real/shadow stats min: {aux_ctx_tensors[0].min():.4f}/{aux_ctx_tensors_f16[0].min():.4f}, max: {aux_ctx_tensors[0].max():.4f}/{aux_ctx_tensors_f16[0].max():.4f}") + print( + f"L{layer_number}: real/shadow out min:" + f" {out_.min():.4f}/{out_f16_.min():.4f}, max:" + f" {out_.max():.4f}/{out_f16_.max():.4f}" + ) + print( + f"L{layer_number}: real/shadow stats min:" + f" {aux_ctx_tensors[0].min():.4f}/{aux_ctx_tensors_f16[0].min():.4f}, max:" + f" {aux_ctx_tensors[0].max():.4f}/{aux_ctx_tensors_f16[0].max():.4f}" + ) if _replace_out_with_shadow: out_ = out_f16_ if _replace_aux_with_shadow: diff 
--git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 92b87fb57f..116a6ba011 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -517,7 +517,7 @@ def flash_attn_a2a_communicate( tmp_list.insert(0, "2") tmp_list.insert(0, "c") tmp_format = "".join(tmp_list) - head_dim_ = tmp_format.index("h")-1 + head_dim_ = tmp_format.index("h") - 1 tmp_list.insert(head_dim_, tmp_list.pop(0)) x = x.movedim(0, head_dim_) # [2, b, s//2, cp, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] @@ -526,7 +526,7 @@ def flash_attn_a2a_communicate( # or [t, cp, h//cp, d] -> [t, cp, h//cp, d] if "t" not in qkv_format: tmp_format = "".join(tmp_list) - seq_dim_ = tmp_format.index("s")-1 + seq_dim_ = tmp_format.index("s") - 1 tmp_list.insert(seq_dim_, tmp_list.pop(0)) x = x.movedim(0, seq_dim_) else: @@ -2797,7 +2797,12 @@ def backward(ctx, dout, *_args): Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) ] - dq, dk, dv = [x.view(y) for x,y in zip([dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape])] + dq, dk, dv = [ + x.view(y) + for x, y in zip( + [dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape] + ) + ] if attn_dbias is not None: # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, sq, sk] @@ -2912,18 +2917,25 @@ def forward( softmax_scale = q.shape[-1] ** (-0.5) assert qkv_format != "thd", f"No support for cp_comm_type='all_gather' and {qkv_format=}." - assert "padding" not in attn_mask_type, f"No support for cp_comm_type='all_gather' and {attn_mask_type=}." - assert attn_bias_type == "no_bias", f"No support for cp_comm_type='all_gather' and {attn_bias_type=}." 
+ assert ( + "padding" not in attn_mask_type + ), f"No support for cp_comm_type='all_gather' and {attn_mask_type=}." + assert ( + attn_bias_type == "no_bias" + ), f"No support for cp_comm_type='all_gather' and {attn_bias_type=}." assert ( window_size == (-1, 0) or window_size == (-1, -1) or use_fused_attention or fa_utils.v2_3_plus - ), f"cp_comm_type='all_gather' only supports SWA through FusedAttention or FlashAttention >= 2.3. Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." - assert ( - q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0 - ), f"cp_comm_type='all_gather' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q = {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}, cp_size = {cp_size}." - + ), ( + "cp_comm_type='all_gather' only supports SWA through FusedAttention or FlashAttention" + f" >= 2.3. Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." + ) + assert q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0, ( + "cp_comm_type='all_gather' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q =" + f" {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}, cp_size = {cp_size}." 
+ ) flash_attn_fwd = None if not use_fused_attention: @@ -3012,7 +3024,9 @@ def forward( # reshape: split s # [b, s, h, d] -> [b, 2, s//2, h, d] # [s, b, h, d] -> [2, s//2, b, h, d] - q = q.view(*q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :]) + q = q.view( + *q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :] + ) # s dim first for all-gather # [b, s, h, d]/[s, b, h, d] -> [s, b, h, d] k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] @@ -3083,7 +3097,9 @@ def forward( # select range: [s_range, b, h, d] k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] - k_part, v_part = [x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part]] + k_part, v_part = [ + x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] + ] if use_fused_attention: new_qkv_layout = qkv_layout if fp8: @@ -3239,7 +3255,9 @@ def forward( # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o all in FP8 # FP8CS+_dpa_fp8_cs_o_in_f16: q/k/v in FP8, o in f16 # MXFP8: q/k/v/o all in f16 - if fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16): + if fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) elif fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, None) @@ -3380,7 +3398,9 @@ def backward(ctx, dout, *_args): if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] if not ctx.qkv_reshaped: - q = q.view(*q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :]) + q = q.view( + *q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :] + ) k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] # set up out: @@ -3389,7 
+3409,10 @@ def backward(ctx, dout, *_args): # MXFP8/F16: torch.float16 or torch.bfloat16 # [b, s, h, d] -> [b, 2, s//2, h, d] # [s, b, h, d] -> [2, s//2, b, h, d] - if ctx.fp8 and (ctx.fp8_recipe.delayed() or (ctx.fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16)): + if ctx.fp8 and ( + ctx.fp8_recipe.delayed() + or (ctx.fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ): out = out_fp8._data out = out.view(ctx.o_shape) @@ -3479,7 +3502,9 @@ def backward(ctx, dout, *_args): # select range: [s_range, b, h, d] k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] - k_part, v_part = [x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part]] + k_part, v_part = [ + x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] + ] # [b, 2, s//2, h, d] -> [b, s//2, h, d] # [2, s//2, b, h, d] -> [s//2, b, h, d] out_part = out.select(seq_dim_o, i).contiguous() @@ -3513,7 +3538,8 @@ def backward(ctx, dout, *_args): v_fp8, data=v_part, dtype=ctx.fwd_nominal_dtype ) if ctx.fp8_recipe.delayed() or ( - ctx.fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ctx.fp8_recipe.float8_current_scaling() + and not _dpa_fp8_cs_o_in_f16 ): out_part = Float8Tensor.make_like( out_fp8, data=out_part, dtype=ctx.fwd_nominal_dtype @@ -3744,20 +3770,29 @@ def forward( softmax_scale = q.shape[-1] ** (-0.5) if qkv_format in ["bshd", "sbhd"]: - assert "padding" not in attn_mask_type, f"No support for cp_comm_type='a2a', {attn_mask_type=} and {qkv_format=}." - assert attn_bias_type == "no_bias", f"No support for cp_comm_type='a2a' and {attn_bias_type=}." + assert ( + "padding" not in attn_mask_type + ), f"No support for cp_comm_type='a2a', {attn_mask_type=} and {qkv_format=}." + assert ( + attn_bias_type == "no_bias" + ), f"No support for cp_comm_type='a2a' and {attn_bias_type=}." 
assert ( window_size == (-1, 0) or window_size == (-1, -1) or use_fused_attention or fa_utils.v2_3_plus - ), f"cp_comm_type='a2a' only supports SWA through FusedAttention or FlashAttention >= 2.3. Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." - assert ( - q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0 - ), f"cp_comm_type='a2a' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q = {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}, cp_size = {cp_size}." - assert ( - q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 - ), f"cp_comm_type='a2a' requires num_heads % cp_size == 0 for Q, K, V. Found num_heads_q = {q.shape[-2]}, num_heads_kv = {k.shape[-2]}, cp_size = {cp_size}." + ), ( + "cp_comm_type='a2a' only supports SWA through FusedAttention or FlashAttention >= 2.3." + f" Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." + ) + assert q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0, ( + "cp_comm_type='a2a' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q =" + f" {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}, cp_size = {cp_size}." + ) + assert q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0, ( + "cp_comm_type='a2a' requires num_heads % cp_size == 0 for Q, K, V. Found num_heads_q =" + f" {q.shape[-2]}, num_heads_kv = {k.shape[-2]}, cp_size = {cp_size}." 
+ ) flash_attn_fwd = None if not use_fused_attention: @@ -4031,10 +4066,14 @@ def forward( if ctx.fp8: # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o all in FP8 # (FP8CS+_dpa_fp8_cs_o_in_f16) or MXFP8: q/k/v in FP8, o in F16 - if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8(): + if ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or fp8_recipe.mxfp8(): fp8_tensors = (q_part, k_part, v_part, None) f16_tensors = (None, None, None, out_part) - elif fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16): + elif fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): fp8_tensors = (q_part, k_part, v_part, out_part) elif fp8: # FP8DS/CS: convert post-a2a FP8 q/k/v to F16 @@ -4044,7 +4083,9 @@ def forward( f16_tensors = (q, k, v, out_part) ctx.qkv_layout = original_qkv_layout else: - q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) + q_part, k_part, v_part = combine_and_dequantize( + qkv_layout, q_part, k_part, v_part + ) f16_tensors = (q_part, k_part, v_part, out_part) else: # all tensors are already in F16 @@ -4310,7 +4351,10 @@ def backward(ctx, dout, *_args): qkv_format=ctx.dqkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - dq, dk, dv = [x.view(y) for x,y in zip([dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape])] + dq, dk, dv = [ + x.view(y) + for x, y in zip([dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape]) + ] # d_bias, d_softmax_offset d_bias = None diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index d65b3cb30f..d1fd0b0ed0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ 
-627,7 +627,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # ignore the recipe from autocast, set fp8_dpa = False, fp8_mha = False fp8_recipe.fp8_dpa = False fp8_recipe.fp8_mha = False - elif (fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8()) and _dpa_fp8_recipe == "DelayedScaling": + elif ( + fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8() + ) and _dpa_fp8_recipe == "DelayedScaling": # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe fake_recipe = DelayedScaling( fp8_format=fp8_recipe.fp8_format, @@ -695,7 +697,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: fp8_dpa=fp8_recipe.fp8_dpa, fp8_mha=fp8_recipe.fp8_mha, reduce_amax=_dpa_fp8ds_reduce_amax, - ) + ), ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes @@ -719,7 +721,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes - elif (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()) and _dpa_fp8_recipe == "MXFP8BlockScaling": + elif ( + fp8_recipe.delayed() or fp8_recipe.float8_current_scaling() + ) and _dpa_fp8_recipe == "MXFP8BlockScaling": # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a MXFP8 recipe fake_recipe = MXFP8BlockScaling( fp8_format=fp8_recipe.fp8_format, @@ -1262,7 +1266,9 @@ def forward( cu_seqlens_kv_padded = None # get qkv's memory layout - if all(isinstance(x, Float8TensorStorage) for x in [query_layer, key_layer, value_layer]): + if all( + isinstance(x, Float8TensorStorage) for x in [query_layer, key_layer, value_layer] + ): ( qkv_layout, query_layer._data, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index f6c7ae48ee..1e312a8662 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -465,18 +465,20 @@ 
def get_attention_backend( ): if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: logger.debug( - "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage}. ", + "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s." + " Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}," + " qkv_type = {torch.Tensor, Float8Tensor, Float8TensorStorage, MXFP8Tensor," + " MXFP8TensorStorage}. ", qkv_dtype, qkv_type, ) use_flash_attention_3 = False if use_fused_attention: logger.debug( - "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage}. ", + "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. Supported:" + " qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, qkv_type =" + " {torch.Tensor, Float8Tensor, Float8TensorStorage, MXFP8Tensor," + " MXFP8TensorStorage}. 
", qkv_dtype, qkv_type, ) @@ -496,7 +498,9 @@ def get_attention_backend( if FlashAttentionUtils.v3_is_installed: logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False - if use_flash_attention_3 and not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()): + if use_flash_attention_3 and not ( + fp8_recipe.delayed() or fp8_recipe.float8_current_scaling() + ): if FlashAttentionUtils.v3_is_installed: logger.debug(f"Disabling FlashAttention 3 for {fp8_recipe.__class__.__name__}") use_flash_attention_3 = False @@ -508,8 +512,15 @@ def get_attention_backend( logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False if use_fused_attention and fp8_recipe.delayed(): - if device_compute_capability >= (10, 0) and deterministic and cudnn_version < (9, 18, 0): - logger.debug("Disabling FusedAttention for FP8 delayed scaling on arch >= sm100 with determinism for cuDNN < 9.18.0") + if ( + device_compute_capability >= (10, 0) + and deterministic + and cudnn_version < (9, 18, 0) + ): + logger.debug( + "Disabling FusedAttention for FP8 delayed scaling on arch >= sm100 with" + " determinism for cuDNN < 9.18.0" + ) use_fused_attention = False if use_fused_attention and fp8_recipe.float8_current_scaling(): if device_compute_capability < (10, 0): @@ -2296,9 +2307,10 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): original_shapes = [x.shape for x in [q, k, v]] s_q, d_qk = q.shape[-2:] s_kv, d_v = v.shape[-2:] - assert ( - s_q % 128 == 0 and s_kv % 128 == 0 and d_qk % 32 == 0 and d_v % 32 == 0 - ), f"MXFP8 quantization requires s_q % 128 == 0, s_kv % 128 == 0, d_qk % 32 == 0, d_v % 32 == 0. Found {s_q=}, {s_kv=}, {d_qk=}, {d_v=}." + assert s_q % 128 == 0 and s_kv % 128 == 0 and d_qk % 32 == 0 and d_v % 32 == 0, ( + "MXFP8 quantization requires s_q % 128 == 0, s_kv % 128 == 0, d_qk % 32 == 0, d_v % 32" + f" == 0. Found {s_q=}, {s_kv=}, {d_qk=}, {d_v=}." 
+ ) q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] # quantize q, k, v if d_qk == d_v: diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 6fe8a61b01..602a09f54b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -103,8 +103,7 @@ std::vector fused_attn_bwd( NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 619c8f92c6..f203d02c4b 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -103,9 +103,8 @@ std::pair quantizer_helper(py::handle quantizer, } } else { std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); - NVTE_CHECK( - !data.has_value(), - "MXFP8Quantizer::create_tensor() does not take data tensor as input!"); + NVTE_CHECK(!data.has_value(), + "MXFP8Quantizer::create_tensor() does not take data tensor as input!"); } } return {std::move(te_T), std::move(py_T)}; @@ -155,8 +154,7 @@ std::vector fused_attn_fwd( auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; size_t b = 0, h = 0, s = 0, d = 0, t = 0; NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - nvte_convert_qkv_format(q_format, o_shape_tmp, o_format, o_shape, &b, &h, &s, - &d, &t); + nvte_convert_qkv_format(q_format, o_shape_tmp, o_format, o_shape, &b, &h, 
&s, &d, &t); if (q_format == NVTE_QKV_Format::NVTE_THD) { b = cu_seqlens_q.size(0) - 1; } @@ -337,8 +335,7 @@ std::vector fused_attn_bwd( NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -375,12 +372,9 @@ std::vector fused_attn_bwd( NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); NVTE_QKV_Format dq_format = nvte_get_q_format(dqkv_layout); NVTE_QKV_Format dkv_format = nvte_get_kv_format(dqkv_layout); - nvte_convert_qkv_format(q_format, q_shape, dq_format, - dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); - nvte_convert_qkv_format(kv_format, k_shape, dkv_format, - dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); - nvte_convert_qkv_format(kv_format, v_shape, dkv_format, - dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); + nvte_convert_qkv_format(q_format, q_shape, dq_format, dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); + nvte_convert_qkv_format(kv_format, k_shape, dkv_format, dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); + nvte_convert_qkv_format(kv_format, v_shape, dkv_format, dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); if (dq_format == NVTE_QKV_Format::NVTE_THD) { b = cu_seqlens_q.size(0) - 1; } else if (dkv_format == NVTE_QKV_Format::NVTE_THD) { From e51ec9fcbfc7a43e5eb60a6da0d3b9560e842bf6 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Mar 2026 19:21:51 -0700 Subject: [PATCH 089/172] update group tensor usage after merge main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/utils.py | 35 
++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 1e312a8662..5eb7178e9d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -44,7 +44,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import Float8TensorStorage from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage -from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -2313,17 +2313,28 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): ) q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] # quantize q, k, v - if d_qk == d_v: - grouped_tensor = GroupedTensor.create_and_quantize( - tensors=[q, k, v], quantizer=qkv_quantizer - ) - q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors - else: - grouped_tensor = GroupedTensor.create_and_quantize( - tensors=[q, k], quantizer=qkv_quantizer - ) - q_fp8, k_fp8 = grouped_tensor.quantized_tensors - v_fp8 = qkv_quantizer(v) + # if d_qk == d_v: + # grouped_tensor = GroupedTensor.create_and_quantize( + # tensors=[q, k, v], quantizer=qkv_quantizer + # ) + # q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors + # else: + # grouped_tensor = GroupedTensor.create_and_quantize( + # tensors=[q, k], quantizer=qkv_quantizer + # ) + # q_fp8, k_fp8 = grouped_tensor.quantized_tensors + # v_fp8 = qkv_quantizer(v) + input_tensors = [q, k, v] + num_tensors = len(input_tensors) + shapes = [x.shape for x in input_tensors] + 
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shapes, + quantizer=qkv_quantizer, + device="cuda", + dtype=q.dtype, + ) + q_fp8, k_fp8, v_fp8 = grouped_tensor.quantize(input_tensors) # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] From 7bb40d53c63eb9e9c411d65eb7468313e7f3891f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Mar 2026 19:22:41 -0700 Subject: [PATCH 090/172] env vars for qdq(q,k), o_f16 tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 2835458449..4727bdefa0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -155,8 +155,9 @@ # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" _run_shadow_f16_fwd = os.getenv("NVTE_RUN_SHADOW_F16_FWD", "0") == "1" -_replace_out_with_shadow = os.getenv("NVTE_REPLACE_OUT_WITH_SHADOW", "0") == "1" -_replace_aux_with_shadow = os.getenv("NVTE_REPLACE_AUX_WITH_SHADOW", "0") == "1" +_replace_out_return_with_shadow_f16 = os.getenv("NVTE_REPLACE_OUT_RETURN_WITH_SHADOW_F16", "0") == "1" +_replace_out_save_with_shadow_f16 = os.getenv("NVTE_REPLACE_OUT_SAVE_WITH_SHADOW_F16", "0") == "1" +_replace_aux_with_shadow_f16 = os.getenv("NVTE_REPLACE_AUX_WITH_SHADOW_F16", "0") == "1" class FP8EmulationFunc(torch.autograd.Function): @@ -1383,20 +1384,8 @@ def forward( 
is_graph_capturing(), ) if torch.cuda.current_device() == 0: - print( - f"L{layer_number}: real/shadow out min:" - f" {out_.min():.4f}/{out_f16_.min():.4f}, max:" - f" {out_.max():.4f}/{out_f16_.max():.4f}" - ) - print( - f"L{layer_number}: real/shadow stats min:" - f" {aux_ctx_tensors[0].min():.4f}/{aux_ctx_tensors_f16[0].min():.4f}, max:" - f" {aux_ctx_tensors[0].max():.4f}/{aux_ctx_tensors_f16[0].max():.4f}" - ) - if _replace_out_with_shadow: - out_ = out_f16_ - if _replace_aux_with_shadow: - aux_ctx_tensors[0] = aux_ctx_tensors_f16[0] + print(f"L{layer_number}: real/shadow out min: {out_.min():.4f}/{out_f16_.min():.4f}, max: {out_.max():.4f}/{out_f16_.max():.4f}") + print(f"L{layer_number}: real/shadow stats min: {aux_ctx_tensors[0].min():.4f}/{aux_ctx_tensors_f16[0].min():.4f}, max: {aux_ctx_tensors[0].max():.4f}/{aux_ctx_tensors_f16[0].max():.4f}") # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 @@ -1442,6 +1431,10 @@ def forward( # return appropriate tensors out_ret = out_fp8 if is_output_fp8 else out + if _run_shadow_f16_fwd and _replace_out_return_with_shadow_f16: + out_ret = out_f16_ + if _run_shadow_f16_fwd and _replace_aux_with_shadow_f16: + aux_ctx_tensors[0] = aux_ctx_tensors_f16[0] # save appropriate tensors fp8_tensors = (None, None, None, None) @@ -1459,6 +1452,21 @@ def forward( else: if is_input_fp8: q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + if _run_shadow_f16_fwd and not _replace_aux_with_shadow_f16: + tmp_quantizer = QKV_quantizer.copy() + tmp_quantizer.optimize_for_gemm = False + q_fp8_, k_fp8_, _, _ = combine_and_quantize(original_qkv_layout, q, k, v, tmp_quantizer) + q_ = q_fp8_.dequantize(dtype=out_nominal_dtype) + k_ = k_fp8_.dequantize(dtype=out_nominal_dtype) + qkv_format, *_ = dpa_utils.get_qkv_format(original_qkv_layout) + if qkv_format == "bshd": + q = q_.permute(0, 2, 1, 3).contiguous() + k = k_.permute(0, 2, 1, 3).contiguous() + elif 
qkv_format == "sbhd": + q = q_.permute(2, 0, 1, 3).contiguous() + k = k_.permute(2, 0, 1, 3).contiguous() + if _run_shadow_f16_fwd and _replace_out_save_with_shadow_f16: + out = out_f16_ qkvo_tensors = (q, k, v, out) else: # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 From 29c2f4bbd198c53b08006bedae4694f1c9f6d2b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Mar 2026 02:25:28 +0000 Subject: [PATCH 091/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../dot_product_attention/backends.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 4727bdefa0..215e05f9bb 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -155,7 +155,9 @@ # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" _run_shadow_f16_fwd = os.getenv("NVTE_RUN_SHADOW_F16_FWD", "0") == "1" -_replace_out_return_with_shadow_f16 = os.getenv("NVTE_REPLACE_OUT_RETURN_WITH_SHADOW_F16", "0") == "1" +_replace_out_return_with_shadow_f16 = ( + os.getenv("NVTE_REPLACE_OUT_RETURN_WITH_SHADOW_F16", "0") == "1" +) _replace_out_save_with_shadow_f16 = os.getenv("NVTE_REPLACE_OUT_SAVE_WITH_SHADOW_F16", "0") == "1" _replace_aux_with_shadow_f16 = os.getenv("NVTE_REPLACE_AUX_WITH_SHADOW_F16", "0") == "1" @@ -1384,8 +1386,16 @@ def forward( is_graph_capturing(), ) if torch.cuda.current_device() == 0: - print(f"L{layer_number}: real/shadow out min: {out_.min():.4f}/{out_f16_.min():.4f}, max: {out_.max():.4f}/{out_f16_.max():.4f}") - print(f"L{layer_number}: 
real/shadow stats min: {aux_ctx_tensors[0].min():.4f}/{aux_ctx_tensors_f16[0].min():.4f}, max: {aux_ctx_tensors[0].max():.4f}/{aux_ctx_tensors_f16[0].max():.4f}") + print( + f"L{layer_number}: real/shadow out min:" + f" {out_.min():.4f}/{out_f16_.min():.4f}, max:" + f" {out_.max():.4f}/{out_f16_.max():.4f}" + ) + print( + f"L{layer_number}: real/shadow stats min:" + f" {aux_ctx_tensors[0].min():.4f}/{aux_ctx_tensors_f16[0].min():.4f}, max:" + f" {aux_ctx_tensors[0].max():.4f}/{aux_ctx_tensors_f16[0].max():.4f}" + ) # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 @@ -1455,7 +1465,9 @@ def forward( if _run_shadow_f16_fwd and not _replace_aux_with_shadow_f16: tmp_quantizer = QKV_quantizer.copy() tmp_quantizer.optimize_for_gemm = False - q_fp8_, k_fp8_, _, _ = combine_and_quantize(original_qkv_layout, q, k, v, tmp_quantizer) + q_fp8_, k_fp8_, _, _ = combine_and_quantize( + original_qkv_layout, q, k, v, tmp_quantizer + ) q_ = q_fp8_.dequantize(dtype=out_nominal_dtype) k_ = k_fp8_.dequantize(dtype=out_nominal_dtype) qkv_format, *_ = dpa_utils.get_qkv_format(original_qkv_layout) From c10f05cdaf75fe49604ad6ec2fc6d73c8276f9b2 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Mar 2026 19:36:07 -0700 Subject: [PATCH 092/172] allow other recipes than mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 215e05f9bb..bfb180c6d4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1464,19 +1464,21 @@ def forward( q, k, v = 
combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) if _run_shadow_f16_fwd and not _replace_aux_with_shadow_f16: tmp_quantizer = QKV_quantizer.copy() - tmp_quantizer.optimize_for_gemm = False - q_fp8_, k_fp8_, _, _ = combine_and_quantize( - original_qkv_layout, q, k, v, tmp_quantizer - ) + if isinstance(tmp_quantizer, MXFP8Quantizer): + tmp_quantizer.optimize_for_gemm = False + q_fp8_, k_fp8_, _, _ = combine_and_quantize(original_qkv_layout, q, k, v, tmp_quantizer) q_ = q_fp8_.dequantize(dtype=out_nominal_dtype) k_ = k_fp8_.dequantize(dtype=out_nominal_dtype) - qkv_format, *_ = dpa_utils.get_qkv_format(original_qkv_layout) - if qkv_format == "bshd": - q = q_.permute(0, 2, 1, 3).contiguous() - k = k_.permute(0, 2, 1, 3).contiguous() - elif qkv_format == "sbhd": - q = q_.permute(2, 0, 1, 3).contiguous() - k = k_.permute(2, 0, 1, 3).contiguous() + if isinstance(tmp_quantizer, MXFP8Quantizer): + qkv_format, *_ = dpa_utils.get_qkv_format(original_qkv_layout) + if qkv_format == "bshd": + q = q_.permute(0, 2, 1, 3).contiguous() + k = k_.permute(0, 2, 1, 3).contiguous() + elif qkv_format == "sbhd": + q = q_.permute(2, 0, 1, 3).contiguous() + k = k_.permute(2, 0, 1, 3).contiguous() + else: + q, k = q_, k_ if _run_shadow_f16_fwd and _replace_out_save_with_shadow_f16: out = out_f16_ qkvo_tensors = (q, k, v, out) From 773c6782a8a6b4e052071405b85d4749ca66b2b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Mar 2026 02:38:01 +0000 Subject: [PATCH 093/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/dot_product_attention/backends.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index bfb180c6d4..38ad4bc74d 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1466,7 +1466,9 @@ def forward( tmp_quantizer = QKV_quantizer.copy() if isinstance(tmp_quantizer, MXFP8Quantizer): tmp_quantizer.optimize_for_gemm = False - q_fp8_, k_fp8_, _, _ = combine_and_quantize(original_qkv_layout, q, k, v, tmp_quantizer) + q_fp8_, k_fp8_, _, _ = combine_and_quantize( + original_qkv_layout, q, k, v, tmp_quantizer + ) q_ = q_fp8_.dequantize(dtype=out_nominal_dtype) k_ = k_fp8_.dequantize(dtype=out_nominal_dtype) if isinstance(tmp_quantizer, MXFP8Quantizer): From 0ef408b9ca29ac93982dc3e66cbf1f14cfc13daf Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:08:36 -0700 Subject: [PATCH 094/172] fix grouped tensor for MLA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/utils.py | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 5eb7178e9d..bd4a7c3dc2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2313,28 +2313,31 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): ) q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] # quantize q, k, v - # if d_qk == d_v: - # grouped_tensor = GroupedTensor.create_and_quantize( - # tensors=[q, k, v], quantizer=qkv_quantizer - # ) - # q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors - # else: - # grouped_tensor = GroupedTensor.create_and_quantize( - # tensors=[q, k], quantizer=qkv_quantizer - # ) - # q_fp8, k_fp8 = grouped_tensor.quantized_tensors - # v_fp8 = qkv_quantizer(v) - input_tensors = [q, k, v] - num_tensors = 
len(input_tensors) - shapes = [x.shape for x in input_tensors] - grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shapes=shapes, - quantizer=qkv_quantizer, - device="cuda", - dtype=q.dtype, - ) - q_fp8, k_fp8, v_fp8 = grouped_tensor.quantize(input_tensors) + if d_qk == d_v: + input_tensors = [q, k, v] + num_tensors = len(input_tensors) + shapes = [x.shape for x in input_tensors] + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shapes, + quantizer=qkv_quantizer, + device="cuda", + dtype=q.dtype, + ) + q_fp8, k_fp8, v_fp8 = grouped_tensor.quantize(input_tensors) + else: + input_tensors = [q, k] + num_tensors = len(input_tensors) + shapes = [x.shape for x in input_tensors] + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shapes, + quantizer=qkv_quantizer, + device="cuda", + dtype=q.dtype, + ) + q_fp8, k_fp8 = grouped_tensor.quantize(input_tensors) + v_fp8 = qkv_quantizer(v) # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] From 4429e5853bdb26c5b6a5f4ac8ccaf66008d401cb Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:52:29 -0700 Subject: [PATCH 095/172] change cp test configs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 12 +++++++----- tests/pytorch/attention/test_attention_with_cp.py | 12 ++++++------ 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 7c46bc750f..1d24379a9f 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1816,11 +1816,13 @@ def get_model(dtype, config): 
head_dim_v=128, attn_mask_type=attn_mask_type, ), - "fp8_10": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type), - "fp8_11": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type), - "fp8_12": ModelConfig(2, 8192, 64, 64, attn_mask_type=attn_mask_type, window_size=(128, 0)), - # "fp8_13": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"), - # "fp8_14": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"), + "fp8_10": ModelConfig( + 2, 8192, 32, 128, attn_mask_type=attn_mask_type, window_size=(128, 0) + ), + "fp8_11": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type), + "fp8_12": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type, window_size=(128, 0)), + "fp8_13": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"), + "fp8_14": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"), # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), # "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 2795c24998..45a214e364 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -47,13 +47,13 @@ "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA + "cp_2_0": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, 
attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA "cp_2_2": ModelConfig( - 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) + 2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0) ), # GQA "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA - "cp_3_0": ModelConfig(2, 4096, 12, 192, attn_mask_type="causal", head_dim_v=128), # MLA + "cp_3_0": ModelConfig(2, 4096, 128, 192, attn_mask_type="causal", head_dim_v=128), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA "cp_3_2": ModelConfig( 2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128 @@ -81,10 +81,10 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_1_2", "cp_2_1", "cp_3_2", "cp_3_3"] + configs = ["cp_2_0", "cp_3_0", "cp_2_2"] #, "cp_1_2", "cp_2_1"]#, "cp_1_1", "cp_3_3"] model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} dtypes = ["bf16"] - qkv_formats = ["sbhd", "thd"] + qkv_formats = ["bshd","sbhd", "thd"] @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") @@ -257,7 +257,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] - qkv_formats = ["bshd"] # , "sbhd", "thd"] + qkv_formats = ["bshd", "sbhd","thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") From 08af36b64df480bf33c23148cc7a251c018319e6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Mar 2026 17:53:22 +0000 Subject: [PATCH 096/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci 
--- tests/pytorch/attention/test_attention.py | 16 ++++++++++------ .../pytorch/attention/test_attention_with_cp.py | 10 ++++------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 1d24379a9f..b949c629d7 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1816,13 +1816,17 @@ def get_model(dtype, config): head_dim_v=128, attn_mask_type=attn_mask_type, ), - "fp8_10": ModelConfig( - 2, 8192, 32, 128, attn_mask_type=attn_mask_type, window_size=(128, 0) - ), + "fp8_10": ModelConfig(2, 8192, 32, 128, attn_mask_type=attn_mask_type, window_size=(128, 0)), "fp8_11": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type), - "fp8_12": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type, window_size=(128, 0)), - "fp8_13": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"), - "fp8_14": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"), + "fp8_12": ModelConfig( + 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type, window_size=(128, 0) + ), + "fp8_13": ModelConfig( + 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), + "fp8_14": ModelConfig( + 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + ), # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), # "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 45a214e364..94e009e377 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ 
b/tests/pytorch/attention/test_attention_with_cp.py @@ -49,9 +49,7 @@ "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA - "cp_2_2": ModelConfig( - 2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0) - ), # GQA + "cp_2_2": ModelConfig(2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0)), # GQA "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA "cp_3_0": ModelConfig(2, 4096, 128, 192, attn_mask_type="causal", head_dim_v=128), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA @@ -81,10 +79,10 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_2_0", "cp_3_0", "cp_2_2"] #, "cp_1_2", "cp_2_1"]#, "cp_1_1", "cp_3_3"] + configs = ["cp_2_0", "cp_3_0", "cp_2_2"] # , "cp_1_2", "cp_2_1"]#, "cp_1_1", "cp_3_3"] model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} dtypes = ["bf16"] - qkv_formats = ["bshd","sbhd", "thd"] + qkv_formats = ["bshd", "sbhd", "thd"] @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") @@ -257,7 +255,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] - qkv_formats = ["bshd", "sbhd","thd"] + qkv_formats = ["bshd", "sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") From 4dd1418e1f0976f23248bb6ff7ad28b37a5ef5c6 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:44:59 -0700 Subject: [PATCH 097/172] add shadow f16 bwd Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 83 ++++++++++++++++++- 1 file changed, 82 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 38ad4bc74d..88a23237dd 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -160,7 +160,10 @@ ) _replace_out_save_with_shadow_f16 = os.getenv("NVTE_REPLACE_OUT_SAVE_WITH_SHADOW_F16", "0") == "1" _replace_aux_with_shadow_f16 = os.getenv("NVTE_REPLACE_AUX_WITH_SHADOW_F16", "0") == "1" - +_run_shadow_f16_bwd = os.getenv("NVTE_RUN_SHADOW_F16_BWD", "0") == "1" +_replace_dq_with_shadow_f16 = os.getenv("NVTE_REPLACE_DQ_WITH_SHADOW_F16", "0") == "1" +_replace_dk_with_shadow_f16 = os.getenv("NVTE_REPLACE_DK_WITH_SHADOW_F16", "0") == "1" +_replace_dv_with_shadow_f16 = os.getenv("NVTE_REPLACE_DV_WITH_SHADOW_F16", "0") == "1" class FP8EmulationFunc(torch.autograd.Function): """ @@ -1459,6 +1462,8 @@ def forward( fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 ): fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + if _run_shadow_f16_bwd: + qkvo_tensors = (q, k, v, out) else: if is_input_fp8: q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) @@ -1619,6 +1624,7 @@ def forward( @staticmethod def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring + d_out_shadow_f16 = d_out # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1651,6 +1657,10 @@ def backward(ctx, d_out, *_args): ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) aux_ctx_tensors = other_tensors + aux_ctx_tensors_shadow_f16 = aux_ctx_tensors + out_shadow_f16 = out + original_qkv_layout = ctx.dqkv_layout + original_qkv_format, *_ = 
dpa_utils.get_qkv_format(original_qkv_layout) if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() @@ -1760,6 +1770,77 @@ def backward(ctx, d_out, *_args): ctx.deterministic, is_graph_capturing(), ) + if _run_shadow_f16_bwd: + original_qkv_layout = ctx.dqkv_layout + tmp_quantizer = ctx.QKV_quantizer.copy() + if isinstance(tmp_quantizer, MXFP8Quantizer): + tmp_quantizer.optimize_for_gemm = False + q_fp8_, k_fp8_, v_fp8_, _ = combine_and_quantize( + original_qkv_layout, q, k, v, tmp_quantizer + ) + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [x.dequantize(dtype=dqkv_nominal_dtype) for x in (q_fp8_, k_fp8_, v_fp8_)] + if isinstance(tmp_quantizer, MXFP8Quantizer): + if original_qkv_format == "bshd": + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [x.permute(0, 2, 1, 3).contiguous() for x in (q_shadow_f16, k_shadow_f16, v_shadow_f16)] + elif original_qkv_format == "sbhd": + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [x.permute(2, 0, 1, 3).contiguous() for x in (q_shadow_f16, k_shadow_f16, v_shadow_f16)] + dq_shadow_f16, dk_shadow_f16, dv_shadow_f16, *rest = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q_shadow_f16, + k_shadow_f16, + v_shadow_f16, + out_shadow_f16, + d_out_shadow_f16, + dqkv_nominal_dtype, + aux_ctx_tensors_shadow_f16, + FusedAttnBackend["F16_arbitrary_seqlen"], + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + None, + None, + None, + ctx.attn_scale, + ctx.dropout_p, + ctx.fast_zero_fill, + original_qkv_layout, + original_qkv_format, + original_qkv_format, + original_qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ctx.softmax_type, + ctx.window_size, + ctx.bottom_right_diagonal, + ctx.deterministic, + is_graph_capturing(), + ) + if _replace_dq_with_shadow_f16: + dq_ = dq_shadow_f16 + if _replace_dk_with_shadow_f16: + dk_ = dk_shadow_f16 + if _replace_dv_with_shadow_f16: + dv_ = dv_shadow_f16 + if torch.cuda.current_device() == 0: + print( + 
f"L{ctx.layer_number}: real/shadow dq min:" + f" {dq_.min():.4f}/{dq_shadow_f16.min():.4f}, max:" + f" {dq_.max():.4f}/{dq_shadow_f16.max():.4f}" + ) + print( + f"L{ctx.layer_number}: real/shadow dk min:" + f" {dk_.min():.4f}/{dk_shadow_f16.min():.4f}, max:" + f" {dk_.max():.4f}/{dk_shadow_f16.max():.4f}" + ) + print( + f"L{ctx.layer_number}: real/shadow dv min:" + f" {dv_.min():.4f}/{dv_shadow_f16.min():.4f}, max:" + f" {dv_.max():.4f}/{dv_shadow_f16.max():.4f}" + ) + + # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ From ad4d4dab8c63acabe0ca61ef58f178fb7b93f7a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Mar 2026 22:46:26 +0000 Subject: [PATCH 098/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/dot_product_attention/backends.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 88a23237dd..ed0bf38101 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -165,6 +165,7 @@ _replace_dk_with_shadow_f16 = os.getenv("NVTE_REPLACE_DK_WITH_SHADOW_F16", "0") == "1" _replace_dv_with_shadow_f16 = os.getenv("NVTE_REPLACE_DV_WITH_SHADOW_F16", "0") == "1" + class FP8EmulationFunc(torch.autograd.Function): """ Emulate the effects of FP8 quantization on tensors. 
Used in UnfusedDotProductAttention as follows: @@ -1778,12 +1779,20 @@ def backward(ctx, d_out, *_args): q_fp8_, k_fp8_, v_fp8_, _ = combine_and_quantize( original_qkv_layout, q, k, v, tmp_quantizer ) - q_shadow_f16, k_shadow_f16, v_shadow_f16 = [x.dequantize(dtype=dqkv_nominal_dtype) for x in (q_fp8_, k_fp8_, v_fp8_)] + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [ + x.dequantize(dtype=dqkv_nominal_dtype) for x in (q_fp8_, k_fp8_, v_fp8_) + ] if isinstance(tmp_quantizer, MXFP8Quantizer): if original_qkv_format == "bshd": - q_shadow_f16, k_shadow_f16, v_shadow_f16 = [x.permute(0, 2, 1, 3).contiguous() for x in (q_shadow_f16, k_shadow_f16, v_shadow_f16)] + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [ + x.permute(0, 2, 1, 3).contiguous() + for x in (q_shadow_f16, k_shadow_f16, v_shadow_f16) + ] elif original_qkv_format == "sbhd": - q_shadow_f16, k_shadow_f16, v_shadow_f16 = [x.permute(2, 0, 1, 3).contiguous() for x in (q_shadow_f16, k_shadow_f16, v_shadow_f16)] + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [ + x.permute(2, 0, 1, 3).contiguous() + for x in (q_shadow_f16, k_shadow_f16, v_shadow_f16) + ] dq_shadow_f16, dk_shadow_f16, dv_shadow_f16, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1840,8 +1849,6 @@ def backward(ctx, d_out, *_args): f" {dv_.max():.4f}/{dv_shadow_f16.max():.4f}" ) - - # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ is_quantized_tensor = isinstance(dq_, QuantizedTensorStorage) From f2266f48193fdc78390c788ce4693e6e8db775f8 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Mar 2026 14:58:04 -0700 Subject: [PATCH 099/172] fix a2a+p2p for sbhd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 39 +++++++++++++------ 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py 
b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 116a6ba011..a46bff095e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -424,13 +424,15 @@ def flash_attn_a2a_communicate( cp_stream: torch.cuda.Stream, before_attn: bool, qkv_format: str = "bshd", - cu_seqlens_padded: torch.Tensor = None, + cu_seqlens_q_padded: torch.Tensor = None, + cu_seqlens_kv_padded: torch.Tensor = None, + a2a_input_names: List[str] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """A2A communication for context parallelism.""" - assert ( - qkv_format != "thd" or cu_seqlens_padded is not None - ), "cu_seqlens_padded is required for THD format!" + qkv_format != "thd" or cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None + ), "cu_seqlens_q_padded and cu_seqlens_kv_padded are required for THD format!" + assert a2a_input_names in [["q", "k", "v"], ["out"], ["dout"], ["dq", "dk", "dv"]], "a2a_input_names must be one of ['q', 'k', 'v'], ['out'], ['dout'], ['dq', 'dk', 'dv']!" 
a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) _, _, head_dim = get_bsh_dims(qkv_format) @@ -457,6 +459,7 @@ def flash_attn_a2a_communicate( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) else: # qkv_format == "thd" + cu_seqlens_padded = cu_seqlens_q_padded if a2a_input_names[i - 1] in ["q", "out", "dout", "dq"] else cu_seqlens_kv_padded # [cp, t, h//cp, d] -> [cp*t, h//cp, d] x = x.view(-1, *x.shape[2:]) # reorder the sequence chunks @@ -500,6 +503,7 @@ def flash_attn_a2a_communicate( x, chunk_ids_for_a2a, seq_dim, cp_size ) else: # qkv_format == "thd" + cu_seqlens_padded = cu_seqlens_q_padded if a2a_input_names[i - 1] in ["q", "out", "dout", "dq"] else cu_seqlens_kv_padded # reorder the sequence chunks x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens_padded, cp_size) # [cp*t, h//cp, d] -> [cp, t, h//cp, d] @@ -1448,7 +1452,7 @@ def forward( q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) q, k, v = flash_attn_a2a_communicate( - [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True, qkv_format=qkv_format, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, a2a_input_names=["q", "k", "v"] ) if fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8 = [ @@ -1963,7 +1967,7 @@ def forward( if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) out = flash_attn_a2a_communicate( - out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False + out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False, qkv_format=o_format, cu_seqlens_q_padded=cu_seqlens_q_padded, a2a_input_names=["out"] ) out = out.view(orig_o_shape) if 
return_max_logit: @@ -2332,13 +2336,16 @@ def backward(ctx, dout, *_args): cp_size_a2a, out.device ) dout = flash_attn_a2a_communicate( - [dout], + dout, chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, True, + qkv_format=ctx.qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["dout"], ) out = out.view(*ctx.o_shape) dout = dout.view(*ctx.o_shape) @@ -2791,6 +2798,10 @@ def backward(ctx, dout, *_args): ctx.cp_group_a2a, ctx.cp_stream, False, + qkv_format=ctx.qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["dq", "dk", "dv"], ) if ctx.fp8 and ctx.is_input_fp8 and not ctx.fp8_recipe.mxfp8(): dq, dk, dv = [ @@ -3889,7 +3900,9 @@ def forward( cp_stream, before_attn=True, qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["q", "k", "v"], ) # softmax_offset: split h @@ -4015,7 +4028,8 @@ def forward( cp_stream, before_attn=False, qkv_format=o_format, - cu_seqlens_padded=cu_seqlens_q_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["out"], ) # [b*s//cp, h, d] -> [b, s//cp, h, d] # [s//cp*b, h, d] -> [s//cp, b, h, d] @@ -4207,7 +4221,8 @@ def backward(ctx, dout, *_args): ctx.cp_stream, before_attn=True, qkv_format=ctx.o_format, - cu_seqlens_padded=cu_seqlens_q_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["dout"], ) flash_attn_bwd = None @@ -4349,7 +4364,9 @@ def backward(ctx, dout, *_args): ctx.cp_stream, before_attn=False, qkv_format=ctx.dqkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["dq", "dk", "dv"], ) dq, dk, dv = [ x.view(y) From 1674b0fafc83bd1435723a9f69ef49a1df4c857d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> 
Date: Fri, 20 Mar 2026 21:58:58 +0000 Subject: [PATCH 100/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../dot_product_attention/context_parallel.py | 42 ++++++++++++++++--- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index a46bff095e..577365e68e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -432,7 +432,12 @@ def flash_attn_a2a_communicate( assert ( qkv_format != "thd" or cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None ), "cu_seqlens_q_padded and cu_seqlens_kv_padded are required for THD format!" - assert a2a_input_names in [["q", "k", "v"], ["out"], ["dout"], ["dq", "dk", "dv"]], "a2a_input_names must be one of ['q', 'k', 'v'], ['out'], ['dout'], ['dq', 'dk', 'dv']!" + assert a2a_input_names in [ + ["q", "k", "v"], + ["out"], + ["dout"], + ["dq", "dk", "dv"], + ], "a2a_input_names must be one of ['q', 'k', 'v'], ['out'], ['dout'], ['dq', 'dk', 'dv']!" 
a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) _, _, head_dim = get_bsh_dims(qkv_format) @@ -459,7 +464,11 @@ def flash_attn_a2a_communicate( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) else: # qkv_format == "thd" - cu_seqlens_padded = cu_seqlens_q_padded if a2a_input_names[i - 1] in ["q", "out", "dout", "dq"] else cu_seqlens_kv_padded + cu_seqlens_padded = ( + cu_seqlens_q_padded + if a2a_input_names[i - 1] in ["q", "out", "dout", "dq"] + else cu_seqlens_kv_padded + ) # [cp, t, h//cp, d] -> [cp*t, h//cp, d] x = x.view(-1, *x.shape[2:]) # reorder the sequence chunks @@ -503,7 +512,11 @@ def flash_attn_a2a_communicate( x, chunk_ids_for_a2a, seq_dim, cp_size ) else: # qkv_format == "thd" - cu_seqlens_padded = cu_seqlens_q_padded if a2a_input_names[i - 1] in ["q", "out", "dout", "dq"] else cu_seqlens_kv_padded + cu_seqlens_padded = ( + cu_seqlens_q_padded + if a2a_input_names[i - 1] in ["q", "out", "dout", "dq"] + else cu_seqlens_kv_padded + ) # reorder the sequence chunks x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens_padded, cp_size) # [cp*t, h//cp, d] -> [cp, t, h//cp, d] @@ -1452,7 +1465,17 @@ def forward( q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) q, k, v = flash_attn_a2a_communicate( - [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True, qkv_format=qkv_format, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, a2a_input_names=["q", "k", "v"] + [q, k, v], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + cp_group_a2a, + cp_stream, + True, + qkv_format=qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["q", "k", "v"], ) if fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8 = [ @@ -1967,7 +1990,16 
@@ def forward( if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) out = flash_attn_a2a_communicate( - out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False, qkv_format=o_format, cu_seqlens_q_padded=cu_seqlens_q_padded, a2a_input_names=["out"] + out, + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + cp_group_a2a, + cp_stream, + False, + qkv_format=o_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["out"], ) out = out.view(orig_o_shape) if return_max_logit: From 712d4f9357a10773cb61f65a345d6a37ea6e29e3 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Mar 2026 15:45:43 -0700 Subject: [PATCH 101/172] fix last commit and causal flag for fa Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 45 +++++++++---------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 577365e68e..98f6938e5b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -429,15 +429,15 @@ def flash_attn_a2a_communicate( a2a_input_names: List[str] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """A2A communication for context parallelism.""" - assert ( - qkv_format != "thd" or cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None - ), "cu_seqlens_q_padded and cu_seqlens_kv_padded are required for THD format!" - assert a2a_input_names in [ - ["q", "k", "v"], - ["out"], - ["dout"], - ["dq", "dk", "dv"], - ], "a2a_input_names must be one of ['q', 'k', 'v'], ['out'], ['dout'], ['dq', 'dk', 'dv']!" 
+ assert a2a_input_names in [["q", "k", "v"], ["out"], ["dout"], ["dq", "dk", "dv"]], "a2a_input_names must be one of ['q', 'k', 'v'], ['out'], ['dout'], ['dq', 'dk', 'dv']!" + if a2a_input_names in [["out"], ["dout"]]: + assert ( + qkv_format != "thd" or cu_seqlens_q_padded is not None + ), f"flash_attn_a2a_communicate requires cu_seqlens_q_padded for {a2a_input_names} with THD format!" + if a2a_input_names in [["q", "k", "v"], ["dq", "dk", "dv"]]: + assert ( + qkv_format != "thd" or (cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None) + ), f"flash_attn_a2a_communicate requires cu_seqlens_q_padded and cu_seqlens_kv_padded for {a2a_input_names} with THD format!" a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) _, _, head_dim = get_bsh_dims(qkv_format) @@ -464,11 +464,7 @@ def flash_attn_a2a_communicate( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) else: # qkv_format == "thd" - cu_seqlens_padded = ( - cu_seqlens_q_padded - if a2a_input_names[i - 1] in ["q", "out", "dout", "dq"] - else cu_seqlens_kv_padded - ) + cu_seqlens_padded = cu_seqlens_q_padded if a2a_input_names[i-2] in ["q", "out", "dout", "dq"] else cu_seqlens_kv_padded # [cp, t, h//cp, d] -> [cp*t, h//cp, d] x = x.view(-1, *x.shape[2:]) # reorder the sequence chunks @@ -512,11 +508,7 @@ def flash_attn_a2a_communicate( x, chunk_ids_for_a2a, seq_dim, cp_size ) else: # qkv_format == "thd" - cu_seqlens_padded = ( - cu_seqlens_q_padded - if a2a_input_names[i - 1] in ["q", "out", "dout", "dq"] - else cu_seqlens_kv_padded - ) + cu_seqlens_padded = cu_seqlens_q_padded if a2a_input_names[i] in ["q", "out", "dout", "dq"] else cu_seqlens_kv_padded # reorder the sequence chunks x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens_padded, cp_size) # [cp*t, h//cp, d] -> [cp, t, h//cp, d] @@ -3021,6 +3013,7 @@ def forward( cu_seqlens_q_padded = None if use_fused_attention and 
attn_mask_type == "causal": attn_mask_type = attn_mask_type + "_bottom_right" + causal = "causal" in attn_mask_type # FP8 setup assert isinstance(k, q.__class__) and isinstance( @@ -3125,7 +3118,7 @@ def forward( max_seqlen_q, max_seqlen_kv, window_size, - "causal" in attn_mask_type, + causal, ) ) seq_start_idx, seq_end_idx = ( @@ -3217,7 +3210,7 @@ def forward( k_part, v_part, *fa_forward_args_thd, - causal="causal" in attn_mask_type, + causal=causal, **fa_forward_kwargs, ) if not fa_utils.v2_7_0_plus: @@ -3414,6 +3407,7 @@ def backward(ctx, dout, *_args): _, seq_dim_qkv, _ = get_bsh_dims(ctx.qkv_format) _, seq_dim_dqkv, _ = get_bsh_dims(ctx.dqkv_format) _, seq_dim_o, _ = get_bsh_dims(ctx.o_format) + causal = "causal" in ctx.attn_mask_type # set up dout: # FP8DS/CS: torch.uint8, [b, s, h, d] or [s, b, h, d] @@ -3659,9 +3653,9 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] if ctx.use_flash_attn_3: - fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type + fa_backward_kwargs["is_causal"] = causal else: - fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type + fa_backward_kwargs["causal"] = causal flash_attn_bwd( dout_part, q_part, @@ -3811,6 +3805,7 @@ def forward( _, seq_dim_o, _ = get_bsh_dims(o_format) if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + causal = "causal" in attn_mask_type if qkv_format in ["bshd", "sbhd"]: assert ( @@ -4032,7 +4027,7 @@ def forward( k_part, v_part, *fa_forward_args_thd, - causal="causal" in attn_mask_type, + causal=causal, **fa_forward_kwargs, ) if not fa_utils.v2_7_0_plus: @@ -4213,6 +4208,7 @@ def backward(ctx, dout, *_args): _, seq_dim_do, _ = get_bsh_dims(ctx.o_format) bwd_nominal_dtype = ctx.fwd_nominal_dtype fused_attn_backend = None + causal = "causal" in ctx.attn_mask_type dout_fp8 = None fp8_meta_kwargs = {} @@ -4375,7 +4371,6 @@ def backward(ctx, dout, *_args): 
out, softmax_lse, *fa_backward_args_thd, - causal="causal" in ctx.attn_mask_type, **fa_backward_kwargs, ) From f9463e2a3be61c94fbc4178b991ec0669a671302 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Mar 2026 17:40:15 -0700 Subject: [PATCH 102/172] enable fp8 sink and disable fp8_mha Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 20 +++++++++------- .../common/fused_attn/fused_attn.cpp | 4 +++- .../attention/dot_product_attention/utils.py | 23 ++++++++++--------- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index b949c629d7..5c946d289d 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1816,17 +1816,20 @@ def get_model(dtype, config): head_dim_v=128, attn_mask_type=attn_mask_type, ), - "fp8_10": ModelConfig(2, 8192, 32, 128, attn_mask_type=attn_mask_type, window_size=(128, 0)), - "fp8_11": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type), + "fp8_10": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type), + "fp8_11": ModelConfig(2, 8192, 32, 128, attn_mask_type=attn_mask_type, window_size=(128, 0)), "fp8_12": ModelConfig( - 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type, window_size=(128, 0) + 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type ), "fp8_13": ModelConfig( - 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" - ), - "fp8_14": ModelConfig( - 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" - ), + 2, 8192, 64, 64, attn_mask_type=attn_mask_type, window_size=(128, 0) + ), + # "fp8_14": ModelConfig( + # 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + # ), + # "fp8_15": 
ModelConfig( + # 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + # ), # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), # "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), @@ -2312,6 +2315,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layer_number=1, attention_type="self", qkv_format=qkv_format, + softmax_type=config.softmax_type, ).to(dtype=dtype, device="cuda") if not is_training: dpa = dpa.eval() diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 0c4cbbabbf..e8b97b60d1 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -359,7 +359,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || // 9.21: mxfp8, d_qk=128, d_v=192 (cudnn_runtime_version >= 92100 && head_dim_qk <= 192 && head_dim_v <= 128)) && - !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && + !requires_64bit_ragged_offset && + // pre-9.21: softmax_type=vanilla, 9.21+: softmax_type={vanilla, off-by-one, learnable} + ((cudnn_runtime_version < 92100 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || cudnn_runtime_version >= 92100) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { if (cudnn_runtime_version >= 8900) { diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index bd4a7c3dc2..a003226651 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -543,6 +543,9 @@ def get_attention_backend( if 
device_compute_capability < (10, 0): logger.debug("Disabling FusedAttention for MXFP8 on arch < sm100") use_fused_attention = False + elif fp8_recipe.fp8_mha: + logger.debug("Disabling FusedAttention for MXFP8 with fp8_mha=True") + use_fused_attention = False else: if cudnn_version < (9, 21, 0): logger.debug("Disabling FusedAttention for MXFP8 with cuDNN < 9.21.0") @@ -762,8 +765,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) use_flash_attention = False if fp8 and fp8_recipe.fp8_dpa: - logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type) - use_fused_attention = False logger.debug( "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type ) @@ -1137,15 +1138,15 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_flash_attention_2 = False if use_fused_attention and deterministic: - if softmax_type != "vanilla": - logger.debug( - "Disabling FusedAttention for determinism reasons with softmax_type = %s. " - "Sink attention (off-by-one and learnable softmax) requires " - "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1", - softmax_type, - ) - use_fused_attention = False - fused_attention_backend = None + # if softmax_type != "vanilla": + # logger.debug( + # "Disabling FusedAttention for determinism reasons with softmax_type = %s. 
" + # "Sink attention (off-by-one and learnable softmax) requires " + # "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1", + # softmax_type, + # ) + # use_fused_attention = False + # fused_attention_backend = None if ( fused_attention_backend == FusedAttnBackend["FP8"] and is_training From 299bc6389be5e5802fb6f1cc91df9b0658666e43 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Mar 2026 17:40:56 -0700 Subject: [PATCH 103/172] minor cleanup for cp/non-cp Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 101 ++++++++---------- .../dot_product_attention/context_parallel.py | 51 +++++---- 2 files changed, 69 insertions(+), 83 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index ed0bf38101..5ab3ff2507 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -36,7 +36,7 @@ restore_from_saved, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorStorage +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8TensorStorage from transformer_engine.pytorch.constants import ( TE_DType, QKVLayouts, @@ -185,26 +185,15 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] ] - assert qkv_layout == "sbhd_sbhd_sbhd", ( - "sbhd_sbhd_sbhd is assumed to be the shape always at this point in" - " UnfusedDotProductAttention." 
- ) + # sbhd_sbhd_sbhd should always be the shape at this point q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, query_layer, key_layer, value_layer, quantizer ) tensors = combine_and_dequantize( - qkv_layout, - q_fp8, - k_fp8, - v_fp8, - src_nominal_dtype=query_layer.dtype, - des_nominal_dtype=query_layer.dtype, + qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype ) if isinstance(quantizer, MXFP8Quantizer): - assert qkv_layout == "bhsd_bhsd_bhsd", ( - "bhsd_bhsd_bhsd is assumed to be the shape always at this point in" - " UnfusedDotProductAttention." - ) + # bhsd_bhsd_bhsd should always be the shape at this point # permute back to sbhd_sbhd_sbhd tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: @@ -238,10 +227,7 @@ def backward(ctx, grad1, grad2, grad3): ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) if isinstance(ctx.quantizer, MXFP8Quantizer): - assert ctx.qkv_layout == "bhsd_bhsd_bhsd", ( - "bhsd_bhsd_bhsd is assumed to be the shape always at this point in" - " UnfusedDotProductAttention." 
- ) + # bhsd_bhsd_bhsd should always be the shape at this point # permute back to sbhd_sbhd_sbhd tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: @@ -447,7 +433,7 @@ def forward( ) ) - # [b, np, sq, sk] + # [b, h, sq, sk] output_size = ( query_layer.size(1), query_layer.size(2), @@ -466,7 +452,7 @@ def forward( int(query_layer.shape[2] / value_layer.shape[2]), dim=2 ) - # preallocting result tensor: [b * np, sq, sk] + # preallocting result tensor: [b * h, sq, sk] matmul_result = torch.empty( output_size[0] * output_size[1], output_size[2], @@ -529,17 +515,17 @@ def forward( "sbhd_sbhd_sbhd", ) - # [sq, b, np, hn] -> [sq, b * np, hn] + # [sq, b, h, d] -> [sq, b * h, d] query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] + # [sk, b, h, d] -> [sk, b * h, d] key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) - # Raw attention scores. [b * np, sq, sk] + # Raw attention scores. [b * h, sq, sk] if core_attention_bias_type == "no_bias": matmul_result = torch.baddbmm( matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] beta=0.0, alpha=scale, ).view(*output_size) @@ -547,8 +533,8 @@ def forward( elif core_attention_bias_type == "pre_scale_bias": assert core_attention_bias is not None, "core_attention_bias should not be None!" 
matmul_result = torch.bmm( - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] ) matmul_result = matmul_result.view(*output_size) + core_attention_bias matmul_result *= scale @@ -573,8 +559,8 @@ def forward( ) matmul_result = torch.baddbmm( matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] beta=0.0, alpha=scale, ) @@ -591,13 +577,13 @@ def forward( # max attention score max_logit = None if self.return_max_logit: - # matmul_result [b, np, sq, dk], max_logit [np] + # matmul_result [b, h, sq, dk], max_logit [h] max_logit = matmul_result if attn_mask_type != "no_mask": max_logit = self.mask_func(matmul_result, attention_mask) max_logit = torch.amax(max_logit, dim=(0, 2, 3)) - # add attention sink to the last column: [b, np, sq, sk+1] + # add attention sink to the last column: [b, h, sq, sk+1] if self.softmax_type != "vanilla": matmul_result = torch.cat( [ @@ -622,7 +608,7 @@ def forward( if "padding" in attn_mask_type: attention_probs = attention_probs.masked_fill(attention_mask, 0) - # remove attention sink: [b, np, sq, sk] + # remove attention sink: [b, h, sq, sk] if self.softmax_type != "vanilla": attention_probs = attention_probs[..., :-1] @@ -632,7 +618,7 @@ def forward( attention_probs = self.attention_dropout(attention_probs) # value_layer -> context layer. 
- # [sk, b, np, hn] --> [b, np, sq, hn] + # [sk, b, h, d] --> [b, h, sq, d] output_size = ( value_layer.size(1), value_layer.size(2), @@ -640,10 +626,10 @@ def forward( value_layer.size(3), ) - # change view [sk, b * np, hn] + # change view [sk, b * h, d] value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] + # change view [b * h, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) if fp8: @@ -652,37 +638,37 @@ def forward( attention_probs, None, None, S_quantizer, "S_quantizer", None ) - # matmul: [b * np, sq, hn] + # matmul: [b * h, sq, d] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] + # change view [b, h, sq, d] context_layer = context_layer.view(*output_size) if q_format == "sbhd": - # [b, np, sq, hn] --> [sq, b, np, hn] + # [b, h, sq, d] --> [sq, b, h, d] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] + # [sq, b, h, d] --> [sq, b, hd] context_layer = context_layer.view(max_seqlen_q, batch_size, -1) if q_format == "bshd": - # [b, np, sq, hn] --> [b, sq, np, hn] + # [b, h, sq, d] --> [b, sq, h, d] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - # [b, sq, np, hn] --> [b, sq, hp] + # [b, sq, h, d] --> [b, sq, hd] context_layer = context_layer.view(batch_size, max_seqlen_q, -1) if q_format == "thd": - # [b, np, sq, hn] --> [b, sq, np, hn] + # [b, h, sq, d] --> [b, sq, h, d] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - # [b, sq, np, hn] --> [tq, np, hn] + # [b, sq, h, d] --> [tq, h, d] context_layer = ConvertBSHDtoTHD.apply( context_layer, cu_seqlens_q, ) - # [tq, np, hn] --> [tq, hp] + # [tq, h, d] --> [tq, hd] context_layer = context_layer.view(context_layer.shape[0], -1) if fp8: @@ -1258,7 +1244,8 @@ def forward( if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe 
= fp8_meta["local_recipes"][0] - # save qkv_layout and get output format + # qkv_layout may change due to MXFP8 quantization + # o_format should stay the same as original qkv_format original_qkv_layout = qkv_layout _, o_format, _ = dpa_utils.get_qkv_format(qkv_layout) @@ -1600,12 +1587,15 @@ def forward( ctx.qkv_layout = reload_layout[:-1] else: ctx.qkv_layout = qkv_layout + if fp8 and not ctx.fp8: + ctx.qkv_layout = original_qkv_layout else: ctx.qkv_layout = qkv_layout if fp8 and not ctx.fp8: ctx.qkv_layout = original_qkv_layout ctx.o_format = o_format + # dqkv should have the same layout as the original qkv ctx.dqkv_layout = original_qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type @@ -1717,18 +1707,18 @@ def backward(ctx, d_out, *_args): ctx.dP_quantizer, ) - # # get tex.DType for dq, dk, dv data - # dqkv_te_dtype = d_out_fp8._fp8_dtype - - # q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16, + # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16, # fp8_dtype = tex.DType.kFloat8E4M3 - # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 # DelayedScaling: # out_, dq_, dk_, dv_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 - # Float8CurrentScaling: + # Float8CurrentScaling + NVTE_DPA_FP8CS_O_in_F16=1: # out_, dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # Float8CurrentScaling + NVTE_DPA_FP8CS_O_in_F16=0: + # out_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # MXFP8BlockScaling: # out_, dq_, dk_, dv_, d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_ = out_fp8 @@ -1748,7 +1738,6 @@ def backward(ctx, d_out, *_args): out_, d_out_fp8, dqkv_nominal_dtype, - # 
dqkv_te_dtype, # could we remove this? aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1881,7 +1870,6 @@ def backward(ctx, d_out, *_args): else: if isinstance(d_out, QuantizedTensorStorage): d_out = d_out.dequantize(dtype=ctx.nominal_dtype) - # dqkv_te_dtype = TE_DType[d_out.dtype] # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, @@ -1894,7 +1882,6 @@ def backward(ctx, d_out, *_args): out, d_out, dqkv_nominal_dtype, - # dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -2073,9 +2060,9 @@ def forward( fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend ), "No fused attention backend supports this input combination!" assert all( - x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) + x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, QuantizedTensorStorage) for x in [query_layer, key_layer, value_layer] - ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors." + ), "FusedAttention only supports FP16 and BF16 data types, or QuantizedTensors." assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), "FusedAttention only supports CUDA tensors." 
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 98f6938e5b..194fafe417 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -23,8 +23,6 @@ from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer -from transformer_engine.common.recipe import MXFP8BlockScaling, Format from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.constants import ( @@ -251,10 +249,10 @@ def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication before attention compute.""" # [cp, b, s, h//cp, d] -> [b, cp, s, h//cp, d] - # or [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] + # [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] x = x.movedim(0, seq_dim).contiguous() # [b, cp, s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] - # or [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) # reorder the sequence chunks x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) @@ -265,12 +263,12 @@ def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_siz def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication after attention compute.""" # [b, cp*2, s//2, h//cp, d] -> 
[cp*2, b, s//2, h//cp, d] - # or [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.movedim(seq_dim, 0).contiguous() # reorder the sequence chunks x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) # [cp*2, b, s//2, h//cp, d] -> [cp, 2, b, s//2, h//cp, d] - # or [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] + # [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] x = x.view(cp_size, 2, *x.shape[1:]) return x @@ -458,8 +456,8 @@ def flash_attn_a2a_communicate( x, chunk_ids_for_a2a, seq_dim, cp_size ) # [b, cp*2, s//2, h//cp, d] -> [b, cp*s, h//cp, d] - # or [cp*2, s//2, b, h//cp, d] -> [cp*s, b, h//cp, d] - # or [b, h//cp, cp*2, s//2, d] -> [b, h//cp, cp*s, d] + # [cp*2, s//2, b, h//cp, d] -> [cp*s, b, h//cp, d] + # [b, h//cp, cp*2, s//2, d] -> [b, h//cp, cp*s, d] a2a_outputs[i - 2] = x.view( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) @@ -475,9 +473,9 @@ def flash_attn_a2a_communicate( if i < len(a2a_inputs): x = a2a_inputs[i] # [b, s, h, d] -> [b, s, cp, h//cp, d] - # or [s, b, h, d] -> [s, b, cp, h//cp, d] - # or [b, h, s, d] -> [b, cp, h//cp, s, d] - # or [t, h, d] -> [t, cp, h//cp, d] + # [s, b, h, d] -> [s, b, cp, h//cp, d] + # [b, h, s, d] -> [b, cp, h//cp, s, d] + # [t, h, d] -> [t, cp, h//cp, d] x = x.view( *x.shape[:head_dim], cp_size, @@ -485,9 +483,9 @@ def flash_attn_a2a_communicate( *x.shape[head_dim + 1 :], ) # [b, s, cp, h//cp, d] -> [cp, b, s, h//cp, d] - # or [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d] - # or [b, cp, h//cp, s, d] -> [cp, b, h//cp, s, d] - # or [t, cp, h//cp, d] -> [cp, t, h//cp, d] + # [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d] + # [b, cp, h//cp, s, d] -> [cp, b, h//cp, s, d] + # [t, cp, h//cp, d] -> [cp, t, h//cp, d] a2a_inputs[i] = x.movedim(head_dim, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): @@ -500,8 +498,8 @@ def flash_attn_a2a_communicate( x = a2a_inputs[i] if qkv_format in ["bshd", "sbhd", "bhsd"]: # [b, cp*s, 
h//cp, d] -> [b, cp*2, s//2, h//cp, d] - # or [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] - # or [b, h//cp, cp*s, d] -> [b, h//cp, cp*2, s//2, d] + # [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [b, h//cp, cp*s, d] -> [b, h//cp, cp*2, s//2, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( @@ -518,9 +516,9 @@ def flash_attn_a2a_communicate( a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] # [cp, 2, b, s//2, h//cp, d] -> [2, b, s//2, cp, h//cp, d] - # or [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] - # or [cp, 2, b, h//cp, s//2, d] -> [2, b, cp, h//cp, s//2, d] - # or [cp, t, h//cp, d] -> [t, cp, h//cp, d] + # [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] + # [cp, 2, b, h//cp, s//2, d] -> [2, b, cp, h//cp, s//2, d] + # [cp, t, h//cp, d] -> [t, cp, h//cp, d] tmp_list = [x for x in qkv_format] if "t" not in qkv_format: tmp_list.insert(0, "2") @@ -530,9 +528,9 @@ def flash_attn_a2a_communicate( tmp_list.insert(head_dim_, tmp_list.pop(0)) x = x.movedim(0, head_dim_) # [2, b, s//2, cp, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] - # or [2, s//2, b, cp, h//cp, d] -> [2, s//2, b, cp, h//cp, d] - # or [2, b, cp, h//cp, s//2, d] -> [b, cp, h//cp, 2, s//2, d] - # or [t, cp, h//cp, d] -> [t, cp, h//cp, d] + # [2, s//2, b, cp, h//cp, d] -> [2, s//2, b, cp, h//cp, d] + # [2, b, cp, h//cp, s//2, d] -> [b, cp, h//cp, 2, s//2, d] + # [t, cp, h//cp, d] -> [t, cp, h//cp, d] if "t" not in qkv_format: tmp_format = "".join(tmp_list) seq_dim_ = tmp_format.index("s") - 1 @@ -542,9 +540,9 @@ def flash_attn_a2a_communicate( seq_dim_ = 0 x = x.contiguous() # [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d] - # or [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] - # or [b, cp, h//cp, 2, s//2, d] -> [b*h, s, d] - # or [t, cp, h//cp, d] -> [t, h, d] + # [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] + # [b, cp, h//cp, 2, s//2, d] -> [b*h, s, d] + # [t, cp, h//cp, d] -> [t, 
h, d] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs @@ -1226,7 +1224,6 @@ def cp_p2p_bwd_fused_attn( out_part, dout_part, bwd_nominal_dtype, - # bwd_output_te_dtype, aux_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded_, @@ -1434,6 +1431,8 @@ def forward( is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] From ed629036b4f303ac8bc9ff1deebf1d074c3c8adc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:43:49 +0000 Subject: [PATCH 104/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 8 ++--- .../common/fused_attn/fused_attn.cpp | 3 +- .../dot_product_attention/context_parallel.py | 35 ++++++++++++++----- 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 5c946d289d..e73e16eb3d 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1818,12 +1818,8 @@ def get_model(dtype, config): ), "fp8_10": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type), "fp8_11": ModelConfig(2, 8192, 32, 128, attn_mask_type=attn_mask_type, window_size=(128, 0)), - "fp8_12": ModelConfig( - 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type - ), - "fp8_13": ModelConfig( - 2, 8192, 64, 64, 
attn_mask_type=attn_mask_type, window_size=(128, 0) - ), + "fp8_12": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type), + "fp8_13": ModelConfig(2, 8192, 64, 64, attn_mask_type=attn_mask_type, window_size=(128, 0)), # "fp8_14": ModelConfig( # 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" # ), diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index e8b97b60d1..b5ecde6750 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -361,7 +361,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 92100 && head_dim_qk <= 192 && head_dim_v <= 128)) && !requires_64bit_ragged_offset && // pre-9.21: softmax_type=vanilla, 9.21+: softmax_type={vanilla, off-by-one, learnable} - ((cudnn_runtime_version < 92100 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || cudnn_runtime_version >= 92100) && + ((cudnn_runtime_version < 92100 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || + cudnn_runtime_version >= 92100) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { if (cudnn_runtime_version >= 8900) { diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 194fafe417..90df998e09 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -427,15 +427,24 @@ def flash_attn_a2a_communicate( a2a_input_names: List[str] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """A2A communication for context parallelism.""" - assert a2a_input_names in [["q", "k", "v"], ["out"], ["dout"], ["dq", "dk", "dv"]], "a2a_input_names must be one of ['q', 'k', 'v'], 
['out'], ['dout'], ['dq', 'dk', 'dv']!" + assert a2a_input_names in [ + ["q", "k", "v"], + ["out"], + ["dout"], + ["dq", "dk", "dv"], + ], "a2a_input_names must be one of ['q', 'k', 'v'], ['out'], ['dout'], ['dq', 'dk', 'dv']!" if a2a_input_names in [["out"], ["dout"]]: - assert ( - qkv_format != "thd" or cu_seqlens_q_padded is not None - ), f"flash_attn_a2a_communicate requires cu_seqlens_q_padded for {a2a_input_names} with THD format!" + assert qkv_format != "thd" or cu_seqlens_q_padded is not None, ( + f"flash_attn_a2a_communicate requires cu_seqlens_q_padded for {a2a_input_names} with" + " THD format!" + ) if a2a_input_names in [["q", "k", "v"], ["dq", "dk", "dv"]]: - assert ( - qkv_format != "thd" or (cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None) - ), f"flash_attn_a2a_communicate requires cu_seqlens_q_padded and cu_seqlens_kv_padded for {a2a_input_names} with THD format!" + assert qkv_format != "thd" or ( + cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None + ), ( + "flash_attn_a2a_communicate requires cu_seqlens_q_padded and cu_seqlens_kv_padded for" + f" {a2a_input_names} with THD format!" 
+ ) a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) _, _, head_dim = get_bsh_dims(qkv_format) @@ -462,7 +471,11 @@ def flash_attn_a2a_communicate( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) else: # qkv_format == "thd" - cu_seqlens_padded = cu_seqlens_q_padded if a2a_input_names[i-2] in ["q", "out", "dout", "dq"] else cu_seqlens_kv_padded + cu_seqlens_padded = ( + cu_seqlens_q_padded + if a2a_input_names[i - 2] in ["q", "out", "dout", "dq"] + else cu_seqlens_kv_padded + ) # [cp, t, h//cp, d] -> [cp*t, h//cp, d] x = x.view(-1, *x.shape[2:]) # reorder the sequence chunks @@ -506,7 +519,11 @@ def flash_attn_a2a_communicate( x, chunk_ids_for_a2a, seq_dim, cp_size ) else: # qkv_format == "thd" - cu_seqlens_padded = cu_seqlens_q_padded if a2a_input_names[i] in ["q", "out", "dout", "dq"] else cu_seqlens_kv_padded + cu_seqlens_padded = ( + cu_seqlens_q_padded + if a2a_input_names[i] in ["q", "out", "dout", "dq"] + else cu_seqlens_kv_padded + ) # reorder the sequence chunks x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens_padded, cp_size) # [cp*t, h//cp, d] -> [cp, t, h//cp, d] From 94ae209f8ec842ec31553726a504340b59d34190 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 23 Mar 2026 14:46:55 -0700 Subject: [PATCH 105/172] update FE for FP8 sink Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 69432369f3..562c25b493 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 69432369f3060467c72d01bd08cfeb9271178c22 +Subproject commit 562c25b493ea6965d6997d8620b490c8d9ef2fcb From a9028b204d0cab525dbfb2a61adfaa3076aca9d8 Mon Sep 17 00:00:00 2001 From: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> Date: Mon, 23 Mar 2026 14:47:48 -0700 Subject: [PATCH 106/172] fix TE for FP8 sink Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 29 ++++++++++++------- .../common/fused_attn/fused_attn.cpp | 8 ++--- .../common/fused_attn/fused_attn_fp8.cu | 6 ++-- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index e73e16eb3d..134f8e4db0 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1818,14 +1818,18 @@ def get_model(dtype, config): ), "fp8_10": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type), "fp8_11": ModelConfig(2, 8192, 32, 128, attn_mask_type=attn_mask_type, window_size=(128, 0)), - "fp8_12": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type), - "fp8_13": ModelConfig(2, 8192, 64, 64, attn_mask_type=attn_mask_type, window_size=(128, 0)), - # "fp8_14": ModelConfig( - # 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" - # ), - # "fp8_15": ModelConfig( - # 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" - # ), + "fp8_12": ModelConfig( + 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type + ), + "fp8_13": ModelConfig( + 2, 8192, 64, 64, attn_mask_type=attn_mask_type, window_size=(128, 0) + ), + "fp8_14": ModelConfig( + 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), + "fp8_15": ModelConfig( + 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + ), # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), # "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, 
attn_mask_type="padding"), @@ -2212,7 +2216,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal atol = 5e-1 rtol = 5e-2 rmse_tol = 0.11 - bwd_names = ["dq", "dk", "dv"] + bwd_names = ["dq", "dk", "dv", "d_softmax_offset"] if flash_attn_supported and fused_attn_supported_f16: logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("forward output")) @@ -2414,10 +2418,13 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: ) if is_training: out.backward(out_grad) + d_softmax_offset = None + if is_training and config.softmax_type != "vanilla": + d_softmax_offset = dpa.softmax_offset.grad if is_training: - return out, (inp[0].grad, inp[1].grad, inp[2].grad) - return out, (None, None, None) + return out, (inp[0].grad, inp[1].grad, inp[2].grad, d_softmax_offset) + return out, (None, None, None, d_softmax_offset) model_configs_fp8 = { diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index b5ecde6750..546165337c 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -852,14 +852,14 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - const Tensor *input_dO_f16 = nullptr; - if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { - input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } const Tensor *input_SoftmaxOffset = nullptr; if (softmax_type != NVTE_VANILLA_SOFTMAX) { input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } + const Tensor *input_dO_f16 = nullptr; + if (input_dO->scaling_mode == 
NVTE_MXFP8_1D_SCALING) { + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 6c3d9a8161..9638c917bd 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1920,7 +1920,7 @@ void fused_attn_fp8_fwd_impl_v1( .set_dim({1, h, 1, 1}) .set_stride({h, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); - // sdpa_options.set_sink_token(softmax_offset); + sdpa_options.set_sink_token(softmax_offset); } std::shared_ptr O, Stats, amax_s, amax_o; @@ -2491,13 +2491,13 @@ void fused_attn_fp8_bwd_impl_v1( .set_dim({1, h, 1, 1}) .set_stride({h, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); - // sdpa_backward_options.set_sink_token(softmax_offset); + sdpa_backward_options.set_sink_token(softmax_offset); d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("d_softmax_offset") .set_dim({1, h, 1, 1}) .set_stride({h, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); - // sdpa_backward_options.set_dsink_token(d_softmax_offset); + sdpa_backward_options.set_dsink_token(d_softmax_offset); } std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; From a6f56e83f26cc1bb39b3d952af37681a10419cba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:49:52 +0000 Subject: [PATCH 107/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git 
a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 134f8e4db0..6b92c2e3ed 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1818,12 +1818,8 @@ def get_model(dtype, config): ), "fp8_10": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type), "fp8_11": ModelConfig(2, 8192, 32, 128, attn_mask_type=attn_mask_type, window_size=(128, 0)), - "fp8_12": ModelConfig( - 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type - ), - "fp8_13": ModelConfig( - 2, 8192, 64, 64, attn_mask_type=attn_mask_type, window_size=(128, 0) - ), + "fp8_12": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type), + "fp8_13": ModelConfig(2, 8192, 64, 64, attn_mask_type=attn_mask_type, window_size=(128, 0)), "fp8_14": ModelConfig( 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" ), From 706095f802e04cbdd5d88ee53849cc5ec938203f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 23 Mar 2026 14:57:50 -0700 Subject: [PATCH 108/172] temporary: random sink/print sink Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 2 ++ .../attention/dot_product_attention/dot_product_attention.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 6b92c2e3ed..4bac6572e0 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2416,6 +2416,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: out.backward(out_grad) d_softmax_offset = None if is_training and config.softmax_type != "vanilla": + print(f"softmax_offset: {dpa.softmax_offset}") + print(f"d_softmax_offset: {dpa.softmax_offset.grad}") d_softmax_offset = dpa.softmax_offset.grad if 
is_training: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index d1fd0b0ed0..b610864416 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -466,7 +466,7 @@ def __init__( if self.softmax_type == "learnable": self.register_parameter( "softmax_offset", - Parameter(torch.zeros(self.num_attention_heads // self.tp_size, device="cuda")), + Parameter(torch.randn(self.num_attention_heads // self.tp_size, device="cuda")), get_rng_state_tracker=get_rng_state_tracker, ) From 4c004ee88080ac247d4302b9dc9f3bd00f1b6b0c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 23 Mar 2026 14:58:12 -0700 Subject: [PATCH 109/172] Revert "temporary: random sink/print sink" This reverts commit 706095f802e04cbdd5d88ee53849cc5ec938203f. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 2 -- .../attention/dot_product_attention/dot_product_attention.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 4bac6572e0..6b92c2e3ed 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2416,8 +2416,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: out.backward(out_grad) d_softmax_offset = None if is_training and config.softmax_type != "vanilla": - print(f"softmax_offset: {dpa.softmax_offset}") - print(f"d_softmax_offset: {dpa.softmax_offset.grad}") d_softmax_offset = dpa.softmax_offset.grad if is_training: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index b610864416..d1fd0b0ed0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -466,7 +466,7 @@ def __init__( if self.softmax_type == "learnable": self.register_parameter( "softmax_offset", - Parameter(torch.randn(self.num_attention_heads // self.tp_size, device="cuda")), + Parameter(torch.zeros(self.num_attention_heads // self.tp_size, device="cuda")), get_rng_state_tracker=get_rng_state_tracker, ) From e023d3bc68fdfed8fc57ed02b89794c24419e2be Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:10:16 -0700 Subject: [PATCH 110/172] replace d_out_format with do_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 4 ++-- .../common/fused_attn/fused_attn_fp8.cu | 14 +++++++------- .../common/fused_attn/fused_attn_fp8.h 
| 2 +- .../common/include/transformer_engine/fused_attn.h | 4 ++-- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/attention.cpp | 6 +++--- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 546165337c..945bd8bd2e 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -761,7 +761,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, @@ -861,7 +861,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, + qkv_layout, o_format, do_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_SoftmaxOffset, input_output_dP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 9638c917bd..3ec5a01441 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ 
-2074,7 +2074,7 @@ void fused_attn_fp8_fwd_impl_v1( void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, - NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, @@ -2242,7 +2242,7 @@ void fused_attn_fp8_bwd_impl_v1( generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), k_strides.data(), v_strides.data(), qkv_layout); generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), o_format); - generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_strides.data(), d_out_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_strides.data(), do_format); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -2319,7 +2319,7 @@ void fused_attn_fp8_bwd_impl_v1( std::vector dO_t_strides(4); generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_strides.data(), q_format); generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_strides.data(), kv_format); - generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_strides.data(), d_out_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_strides.data(), do_format); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2360,9 +2360,9 @@ void fused_attn_fp8_bwd_impl_v1( generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, v_scale_strides.data(), kv_format); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, - dO_scale_strides.data(), 
d_out_format); + dO_scale_strides.data(), do_format); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, - dO_t_scale_strides.data(), d_out_format); + dO_t_scale_strides.data(), do_format); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -2855,7 +2855,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, @@ -2941,7 +2941,7 @@ void fused_attn_fp8_bwd( (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 617efa8f42..2f6c1105bd 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -31,7 +31,7 @@ void 
fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 04f7ec4a6c..012492dab7 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -368,7 +368,7 @@ void nvte_fused_attn_fwd( * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. * \param[in] o_format Output format. - * \param[in] d_out_format Output gradient's format. + * \param[in] do_format Output gradient's format. * \param[in] dqkv_layout QKV gradient tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. 
@@ -390,7 +390,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 602a09f54b..e8b588f0ee 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -98,7 +98,7 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index f203d02c4b..3644e5e46b 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -330,7 +330,7 @@ std::vector fused_attn_fwd( // fused attention BWD with separate Q, K and V std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout 
qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, @@ -575,7 +575,7 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -592,7 +592,7 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); From 7577919182d2e0dc41b5ed56675fd1fb6ffd91ff Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:10:48 -0700 Subject: [PATCH 111/172] fix compare_and_assert for None cases Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py 
index 1747d75676..795b3c3441 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -163,6 +163,10 @@ def reset_rng_states() -> None: def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): + if a is None and b is None: + logging.debug(f"{name_a} vs {name_b}: both are None") + return + if not is_fp8: torch.testing.assert_close(a, b, atol=atol, rtol=rtol) return From f0b1e2a9ab3eb9dda1a78e0ab3ae912168ec6aff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 01:13:47 +0000 Subject: [PATCH 112/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn_fp8.cu | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 3ec5a01441..23b15e1dfc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2941,15 +2941,15 @@ void fused_attn_fp8_bwd( (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, - devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, - devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, - devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, - devPtrdO_t, devPtrDescaleQ_t, 
devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, - devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, + devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrSoftmaxOffset, + devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, devPtrDescaleK, + devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, + devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { From ee388e55ac67105033d78327bb34f8ed965b71a8 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 23 Mar 2026 19:10:23 -0700 Subject: [PATCH 113/172] remove logic for b and simplify logic for dqkv types Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/csrc/extensions/attention.cpp | 51 ++++++++----------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 3644e5e46b..2fea2fc307 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -152,12 +152,9 @@ std::vector fused_attn_fwd( auto o_shape_tmp = std::vector{q_shape.begin(), 
q_shape.end()}; o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; - size_t b = 0, h = 0, s = 0, d = 0, t = 0; + size_t h=0, d=0; NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - nvte_convert_qkv_format(q_format, o_shape_tmp, o_format, o_shape, &b, &h, &s, &d, &t); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - b = cu_seqlens_q.size(0) - 1; - } + nvte_convert_qkv_format(q_format, o_shape_tmp, o_format, o_shape, nullptr, &h, nullptr, &d, nullptr); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -364,35 +361,26 @@ std::vector fused_attn_bwd( std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; + const DType dqkv_fake_dtype = GetTransformerEngineDType(fake_dtype); + size_t h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; size_t ndim = q_shape.size(); std::vector dQ_shape(ndim), dK_shape(ndim), dV_shape(ndim); NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); NVTE_QKV_Format dq_format = nvte_get_q_format(dqkv_layout); NVTE_QKV_Format dkv_format = nvte_get_kv_format(dqkv_layout); - nvte_convert_qkv_format(q_format, q_shape, dq_format, dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); - nvte_convert_qkv_format(kv_format, k_shape, dkv_format, dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); - nvte_convert_qkv_format(kv_format, v_shape, dkv_format, dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); - if (dq_format == NVTE_QKV_Format::NVTE_THD) { - b = cu_seqlens_q.size(0) - 1; - } else if (dkv_format == NVTE_QKV_Format::NVTE_THD) { - b = cu_seqlens_kv.size(0) - 1; - } + 
nvte_convert_qkv_format(q_format, q_shape, dq_format, dQ_shape, nullptr, &h_q, nullptr, &d_qk, nullptr); + nvte_convert_qkv_format(kv_format, k_shape, dkv_format, dK_shape, nullptr, &h_kv, nullptr, nullptr, nullptr); + nvte_convert_qkv_format(kv_format, v_shape, dkv_format, dV_shape, nullptr, nullptr, nullptr, &d_v, nullptr); at::Tensor dQ, dK, dV, dQKV, dKV; - DType dqkv_type = fake_dtype_te; - if (!dqkv_quantizer.is_none()) { - dqkv_type = dqkv_quantizer.attr("dtype").cast(); - } - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { + // FP16/BF16: dqkv_fake_dtype = kFloat16/kBFloat16, dQ/dK/dV.dtype = torch.float16/torch.bfloat16 + // FP8DS: dqkv_fake_dtype = kFloat16/kBFloat16, dQ/dK/dV.dtype = torch.uint8 + // FP8CS/MXFP8: dqkv_fake_dtype = kFloat16/kBFloat16, dQ/dK/dV.dtype = torch.float16/torch.bfloat16 + auto options = torch::TensorOptions().dtype(fake_dtype).device(torch::kCUDA); + if (detail::IsFloat8Quantizers(dqkv_quantizer.ptr())) { options = options.dtype(torch::kUInt8); } - if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || - detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { - options = options.dtype(fake_dtype); - } + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); std::vector tmp_shape; switch (layout_group) { @@ -468,15 +456,15 @@ std::vector fused_attn_bwd( NVTE_ERROR("QKV layout not supported!"); } - std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, dQ_shape, fake_dtype_te, true, dQ); - std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, dK_shape, fake_dtype_te, true, dK); - std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, dV_shape, fake_dtype_te, true, dV); + std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, dQ_shape, dqkv_fake_dtype, true, dQ); + std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, dK_shape, dqkv_fake_dtype, true, dK); + 
std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, dV_shape, dqkv_fake_dtype, true, dV); // construct NVTE tensors - if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { + if (detail::IsFloat8Quantizers(dqkv_quantizer.ptr())) { // FP8 if (set_zero && (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && + if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && ((h_kv * d_v) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -487,7 +475,8 @@ std::vector fused_attn_bwd( dV.fill_(0); } } - } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { + } else if (dqkv_quantizer.is_none() || detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || + detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { if (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); From cacc59dd3e190110f4ca7be67f96a933a6d3ce0d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 16:09:01 +0000 Subject: [PATCH 114/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/csrc/extensions/attention.cpp | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 2fea2fc307..925cdcf81c 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -152,9 +152,10 @@ std::vector fused_attn_fwd( auto o_shape_tmp = 
std::vector{q_shape.begin(), q_shape.end()}; o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; - size_t h=0, d=0; + size_t h = 0, d = 0; NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - nvte_convert_qkv_format(q_format, o_shape_tmp, o_format, o_shape, nullptr, &h, nullptr, &d, nullptr); + nvte_convert_qkv_format(q_format, o_shape_tmp, o_format, o_shape, nullptr, &h, nullptr, &d, + nullptr); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -369,9 +370,12 @@ std::vector fused_attn_bwd( NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); NVTE_QKV_Format dq_format = nvte_get_q_format(dqkv_layout); NVTE_QKV_Format dkv_format = nvte_get_kv_format(dqkv_layout); - nvte_convert_qkv_format(q_format, q_shape, dq_format, dQ_shape, nullptr, &h_q, nullptr, &d_qk, nullptr); - nvte_convert_qkv_format(kv_format, k_shape, dkv_format, dK_shape, nullptr, &h_kv, nullptr, nullptr, nullptr); - nvte_convert_qkv_format(kv_format, v_shape, dkv_format, dV_shape, nullptr, nullptr, nullptr, &d_v, nullptr); + nvte_convert_qkv_format(q_format, q_shape, dq_format, dQ_shape, nullptr, &h_q, nullptr, &d_qk, + nullptr); + nvte_convert_qkv_format(kv_format, k_shape, dkv_format, dK_shape, nullptr, &h_kv, nullptr, + nullptr, nullptr); + nvte_convert_qkv_format(kv_format, v_shape, dkv_format, dV_shape, nullptr, nullptr, nullptr, &d_v, + nullptr); at::Tensor dQ, dK, dV, dQKV, dKV; // FP16/BF16: dqkv_fake_dtype = kFloat16/kBFloat16, dQ/dK/dV.dtype = torch.float16/torch.bfloat16 // FP8DS: dqkv_fake_dtype = kFloat16/kBFloat16, dQ/dK/dV.dtype = torch.uint8 @@ -464,8 +468,9 @@ std::vector fused_attn_bwd( if (detail::IsFloat8Quantizers(dqkv_quantizer.ptr())) { // FP8 if (set_zero && (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - if (((h_q * d_qk) % block_size == 0) && 
((h_kv * d_qk) % block_size == 0) && ((h_kv * d_v) % block_size == 0) && - dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { + if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && + ((h_kv * d_v) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && + dV.is_contiguous()) { mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -475,8 +480,9 @@ std::vector fused_attn_bwd( dV.fill_(0); } } - } else if (dqkv_quantizer.is_none() || detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || - detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { + } else if (dqkv_quantizer.is_none() || + detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || + detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { if (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); From de82fe115fc7b08afa2a3fb4476681503e943aa6 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:24:03 -0700 Subject: [PATCH 115/172] minor fix for ndim_q/kv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/csrc/extensions/attention.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 925cdcf81c..d76e29964a 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -364,8 +364,9 @@ std::vector fused_attn_bwd( std::vector v_shape = convertShape(te_V.shape()); const DType dqkv_fake_dtype = GetTransformerEngineDType(fake_dtype); size_t h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; - 
size_t ndim = q_shape.size(); - std::vector dQ_shape(ndim), dK_shape(ndim), dV_shape(ndim); + size_t ndim_q = q_shape.size(); + size_t ndim_kv = k_shape.size(); + std::vector dQ_shape(ndim_q), dK_shape(ndim_kv), dV_shape(ndim_kv); NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); NVTE_QKV_Format dq_format = nvte_get_q_format(dqkv_layout); From 706012a7ab888d467c6bb0c6cd00411a1244865a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 10:33:50 -0700 Subject: [PATCH 116/172] add explanation of fp8_output/grad in MHA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/attention/multi_head_attention.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 9972ddd994..afc4622b22 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -801,7 +801,12 @@ def forward( fp8_mha = _dpa_fp8_recipe_mha float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" mxfp8_scaling = _dpa_fp8_recipe == "MXFP8BlockScaling" - # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling or MXFP8BlockScaling recipe + + # QKV Gemm: do not produce FP8 output when fp8_mha = True if + # 1. RoPE is on: RoPE is only implemented in F16 currently + # 2. FP8CS recipe: due to cuBLAS limitation, FP8CS Gemms can not produce FP8 output + # 3. MXFP8 recipe: QKV Gemm produces QKV in bs(hd), sb(hd), t(hd) shapes, quantization of which would be along + # s/b/t and (hd) dimensions, whereas MXFP8 attention requires quantization along s and d, e.g. 
bhsd, sbhd, thd qkv_fp8_output = ( fp8 and fp8_mha @@ -809,9 +814,12 @@ def forward( and not float8_current_scaling and not mxfp8_scaling ) - # DPA: produce FP8 output when fp8=True to take advantage of the O amax except for MXFP8BlockScaling + # DPA: produce FP8 output to take advantage of O amax from DPA; Projection Gemm can take FP8 or F16 inputs + # 1. FP8DS/FP8CS recipe: produce FP8 output + # 2. MXFP8 recipe: produce F16 output; again, due to quantization dimensions mismatch dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) and not mxfp8_scaling - # Proj Gemm: match DPA output except for Float8CurrentScaling + # Projection Gemm: match DPA output except + # 1. FP8CS recipe: produce F16 grads; again, due to cuBLAS limitation proj_fp8_grad = dpa_fp8_output and not float8_current_scaling layernorm_output = None From 746010e0921075e144e0d2ca765918fe9d4db106 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 11:14:02 -0700 Subject: [PATCH 117/172] tidy up FP8 checks for bhsd/learnable Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 945bd8bd2e..e601a94368 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -357,12 +357,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || - // 9.21: mxfp8, d_qk=128, d_v=192 - (cudnn_runtime_version >= 92100 && head_dim_qk <= 192 && head_dim_v <= 128)) && + // 9.21: d_qk=192, d_v=128 + (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && + head_dim_qk <= 192 && 
head_dim_v <= 128 && + head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && + // pre-9.21: {bshd, sbhd}, {vanilla} + // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable} + ((cudnn_runtime_version < 92100 && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || + (cudnn_runtime_version >= 92100 && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHSD))) && !requires_64bit_ragged_offset && - // pre-9.21: softmax_type=vanilla, 9.21+: softmax_type={vanilla, off-by-one, learnable} - ((cudnn_runtime_version < 92100 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || - cudnn_runtime_version >= 92100) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { if (cudnn_runtime_version >= 8900) { @@ -520,8 +529,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 90200 && ((window_size_left == -1 && window_size_right == -1 && attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || - ((window_size_left == -1 || window_size_left >= 0) && - (window_size_right == -1 || window_size_right >= 0) && + ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && From 2283081a316e89691e76a0a1d46ee7265506ced9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 11:27:11 -0700 Subject: [PATCH 118/172] remove leading underscores in nvte_convert_qkv_format Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 72 +++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index e601a94368..22c02e7664 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -230,30 +230,30 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, size_t *h, size_t *s, size_t *d, size_t *t) { - size_t _b = 0, _h = 0, _s = 0, _d = 0, _t = 0; + size_t b_tmp = 0, h_tmp = 0, s_tmp = 0, d_tmp = 0, t_tmp = 0; switch (src_format) { case NVTE_QKV_Format::NVTE_BSHD: - _b = src_shape[0]; - _s = src_shape[1]; - _h = src_shape[2]; - _d = src_shape[3]; + b_tmp = src_shape[0]; + s_tmp = src_shape[1]; + h_tmp = src_shape[2]; + d_tmp = src_shape[3]; break; case NVTE_QKV_Format::NVTE_SBHD: - _s = src_shape[0]; - _b = src_shape[1]; - _h = src_shape[2]; - _d = src_shape[3]; + s_tmp = src_shape[0]; + b_tmp = src_shape[1]; + h_tmp = src_shape[2]; + d_tmp = src_shape[3]; break; case NVTE_QKV_Format::NVTE_BHSD: - _b = src_shape[0]; - _h = src_shape[1]; - _s = src_shape[2]; - _d = src_shape[3]; + b_tmp = src_shape[0]; + h_tmp = src_shape[1]; + s_tmp = src_shape[2]; + d_tmp = src_shape[3]; break; case NVTE_QKV_Format::NVTE_THD: - _t = src_shape[0]; - _h = src_shape[1]; - _d = src_shape[2]; + t_tmp = src_shape[0]; + h_tmp = src_shape[1]; + d_tmp = src_shape[2]; break; default: NVTE_ERROR("src_format not supported!"); @@ -261,27 +261,27 @@ void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src } switch (dst_format) { case NVTE_QKV_Format::NVTE_BSHD: - dst_shape[0] = _b; - dst_shape[1] = _s; - dst_shape[2] = _h; - dst_shape[3] = _d; + dst_shape[0] = b_tmp; + 
dst_shape[1] = s_tmp; + dst_shape[2] = h_tmp; + dst_shape[3] = d_tmp; break; case NVTE_QKV_Format::NVTE_SBHD: - dst_shape[0] = _s; - dst_shape[1] = _b; - dst_shape[2] = _h; - dst_shape[3] = _d; + dst_shape[0] = s_tmp; + dst_shape[1] = b_tmp; + dst_shape[2] = h_tmp; + dst_shape[3] = d_tmp; break; case NVTE_QKV_Format::NVTE_BHSD: - dst_shape[0] = _b; - dst_shape[1] = _h; - dst_shape[2] = _s; - dst_shape[3] = _d; + dst_shape[0] = b_tmp; + dst_shape[1] = h_tmp; + dst_shape[2] = s_tmp; + dst_shape[3] = d_tmp; break; case NVTE_QKV_Format::NVTE_THD: - dst_shape[0] = _t; - dst_shape[1] = _h; - dst_shape[2] = _d; + dst_shape[0] = t_tmp; + dst_shape[1] = h_tmp; + dst_shape[2] = d_tmp; break; default: NVTE_ERROR("dst_format not supported!"); @@ -289,19 +289,19 @@ void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src } if (b != nullptr) { - *b = _b; + *b = b_tmp; } if (h != nullptr) { - *h = _h; + *h = h_tmp; } if (s != nullptr) { - *s = _s; + *s = s_tmp; } if (d != nullptr) { - *d = _d; + *d = d_tmp; } if (t != nullptr) { - *t = _t; + *t = t_tmp; } } From e693e6fa6c943bca3cac39d8160993067f79708f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 12:15:56 -0700 Subject: [PATCH 119/172] simplify logic in generateMatrixStridesWithLayout Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/utils.h | 93 +++----------------- 1 file changed, 11 insertions(+), 82 deletions(-) diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index fdda4dfe9c..a26e72f2ae 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -111,6 +111,8 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in constexpr int h_dim = 1; constexpr int s_dim = 2; constexpr int d_dim = 3; + const NVTE_QKV_Format q_format = 
nvte_get_q_format(layout); + const NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); switch (layout) { case NVTE_QKV_Layout::NVTE_SB3HD: @@ -132,10 +134,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in } break; case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: - q_strides[b_dim] = h * d_qk; - q_strides[h_dim] = d_qk; - q_strides[s_dim] = b * h * d_qk; - q_strides[d_dim] = 1; + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); k_strides[b_dim] = 2 * hg * d_qk; k_strides[h_dim] = d_qk; k_strides[s_dim] = b * 2 * hg * d_qk; @@ -145,10 +144,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in } break; case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: - q_strides[b_dim] = h * d_qk; - q_strides[h_dim] = d_qk; - q_strides[s_dim] = b * h * d_qk; - q_strides[d_dim] = 1; + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); k_strides[b_dim] = 2 * hg * d_qk; k_strides[h_dim] = 2 * d_qk; k_strides[s_dim] = b * 2 * hg * d_qk; @@ -157,21 +153,6 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in v_strides[i] = k_strides[i]; } break; - case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: - q_strides[b_dim] = h * d_qk; - q_strides[h_dim] = d_qk; - q_strides[s_dim] = b * h * d_qk; - q_strides[d_dim] = 1; - k_strides[b_dim] = hg * d_qk; - k_strides[h_dim] = d_qk; - k_strides[s_dim] = b * hg * d_qk; - k_strides[d_dim] = 1; - v_strides[b_dim] = hg * d_v; - v_strides[h_dim] = d_v; - v_strides[s_dim] = b * hg * d_v; - v_strides[d_dim] = 1; - break; case NVTE_QKV_Layout::NVTE_BS3HD: case NVTE_QKV_Layout::NVTE_T3HD: q_strides[b_dim] = s_q * 3 * h * d_qk; @@ -194,10 +175,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in break; case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: case NVTE_QKV_Layout::NVTE_THD_T2HD: - q_strides[b_dim] = s_q * h * d_qk; - q_strides[h_dim] = d_qk; - q_strides[s_dim] = h 
* d_qk; - q_strides[d_dim] = 1; + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); k_strides[b_dim] = s_kv * 2 * hg * d_qk; k_strides[h_dim] = d_qk; k_strides[s_dim] = 2 * hg * d_qk; @@ -208,10 +186,7 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in break; case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: case NVTE_QKV_Layout::NVTE_THD_TH2D: - q_strides[b_dim] = s_q * h * d_qk; - q_strides[h_dim] = d_qk; - q_strides[s_dim] = h * d_qk; - q_strides[d_dim] = 1; + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); k_strides[b_dim] = s_kv * 2 * hg * d_qk; k_strides[h_dim] = 2 * d_qk; k_strides[s_dim] = 2 * hg * d_qk; @@ -220,69 +195,23 @@ inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, in v_strides[i] = k_strides[i]; } break; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_THD_THD_THD: case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: - q_strides[b_dim] = s_q * h * d_qk; - q_strides[h_dim] = d_qk; - q_strides[s_dim] = h * d_qk; - q_strides[d_dim] = 1; - k_strides[b_dim] = s_kv * hg * d_qk; - k_strides[h_dim] = d_qk; - k_strides[s_dim] = hg * d_qk; - k_strides[d_dim] = 1; - v_strides[b_dim] = s_kv * hg * d_v; - v_strides[h_dim] = d_v; - v_strides[s_dim] = hg * d_v; - v_strides[d_dim] = 1; - break; case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: - q_strides[b_dim] = h * d_qk; - q_strides[h_dim] = d_qk; - q_strides[s_dim] = b * h * d_qk; - q_strides[d_dim] = 1; - k_strides[b_dim] = s_kv * hg * d_qk; - k_strides[h_dim] = d_qk; - k_strides[s_dim] = hg * d_qk; - k_strides[d_dim] = 1; - v_strides[b_dim] = s_kv * hg * d_v; - v_strides[h_dim] = d_v; - v_strides[s_dim] = hg * d_v; - v_strides[d_dim] = 1; - break; case 
NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: - q_strides[b_dim] = s_q * h * d_qk; - q_strides[h_dim] = d_qk; - q_strides[s_dim] = h * d_qk; - q_strides[d_dim] = 1; - k_strides[b_dim] = hg * d_qk; - k_strides[h_dim] = d_qk; - k_strides[s_dim] = b * hg * d_qk; - k_strides[d_dim] = 1; - v_strides[b_dim] = hg * d_v; - v_strides[h_dim] = d_v; - v_strides[s_dim] = b * hg * d_v; - v_strides[d_dim] = 1; - break; case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: - q_strides[b_dim] = h * s_q * d_qk; - q_strides[h_dim] = s_q * d_qk; - q_strides[s_dim] = d_qk; - q_strides[d_dim] = 1; - k_strides[b_dim] = hg * s_kv * d_qk; - k_strides[h_dim] = s_kv * d_qk; - k_strides[s_dim] = d_qk; - k_strides[d_dim] = 1; - v_strides[b_dim] = hg * s_kv * d_v; - v_strides[h_dim] = s_kv * d_v; - v_strides[s_dim] = d_v; - v_strides[d_dim] = 1; + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_strides, kv_format); + generateMatrixStridesWithFormat(b, hg, s_kv, d_v, v_strides, kv_format); break; default: NVTE_CHECK(false, "Invalid layout."); From edf1b2af38e324ad55bfb56c5e80e576515d30be Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 12:37:32 -0700 Subject: [PATCH 120/172] clean up strides/ifelse-recipe logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 92 +++++++------------ 1 file changed, 34 insertions(+), 58 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 23b15e1dfc..a7beda1e1c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1777,9 +1777,7 @@ void fused_attn_fp8_fwd_impl_v1( 
std::shared_ptr dropout_seed, dropout_offset; // Q, K, V, attn_scale - std::vector q_strides(4); - std::vector k_strides(4); - std::vector v_strides(4); + std::vector q_strides(4), k_strides(4), v_strides(4); generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), k_strides.data(), v_strides.data(), qkv_layout); Q = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -1821,8 +1819,7 @@ void fused_attn_fp8_fwd_impl_v1( if (is_current_scaling) { scale_o = mha_graph->tensor(1.0f); } - } - if (is_mxfp8) { + } else if (is_mxfp8) { NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); std::vector q_scale_strides(4); @@ -1870,11 +1867,13 @@ void fused_attn_fp8_fwd_impl_v1( : fe::DiagonalAlignment_t::TOP_LEFT; sdpa_options.set_diagonal_alignment(diagonal_alignment); - if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); - } - if (cudnn_runtime_version >= 90600 && window_size_right != -1) { - sdpa_options.set_diagonal_band_right_bound(window_size_right); + if (cudnn_runtime_version >= 92100) { + if (window_size_left != -1) { + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (window_size_right != -1) { + sdpa_options.set_diagonal_band_right_bound(window_size_right); + } } // sdpa_options.set_alibi_mask(is_alibi); @@ -2031,7 +2030,7 @@ void fused_attn_fp8_fwd_impl_v1( if (is_delayed_scaling) { variant_pack[scale_o] = devPtrScaleO; } - if (!is_mxfp8) { + if (is_delayed_scaling || is_current_scaling) { variant_pack[descale_s] = devPtrDescaleS; variant_pack[scale_s] = devPtrScaleS; variant_pack[amax_s] = devPtrAmaxS; @@ -2234,11 +2233,7 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr dropout_seed, dropout_offset; // Q, K, V, O, dO, stats, attn_scale - std::vector q_strides(4); - std::vector k_strides(4); - std::vector v_strides(4); - std::vector o_strides(4); - std::vector 
dO_strides(4); + std::vector q_strides(4), k_strides(4), v_strides(4), o_strides(4), dO_strides(4); generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), k_strides.data(), v_strides.data(), qkv_layout); generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), o_format); @@ -2309,14 +2304,11 @@ void fused_attn_fp8_bwd_impl_v1( scale_dK = mha_graph->tensor(1.0f); scale_dV = mha_graph->tensor(1.0f); } - } - if (is_mxfp8) { + } else if (is_mxfp8) { NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); // Q_t, K_t, dO_t, dO_f16 - std::vector q_t_strides(4); - std::vector k_t_strides(4); - std::vector dO_t_strides(4); + std::vector q_t_strides(4), k_t_strides(4), dO_t_strides(4); generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_strides.data(), q_format); generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_strides.data(), kv_format); generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_strides.data(), do_format); @@ -2342,13 +2334,7 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); - std::vector q_scale_strides(4); - std::vector q_t_scale_strides(4); - std::vector k_scale_strides(4); - std::vector k_t_scale_strides(4); - std::vector v_scale_strides(4); - std::vector dO_scale_strides(4); - std::vector dO_t_scale_strides(4); + std::vector q_scale_strides(4), q_t_scale_strides(4), k_scale_strides(4), k_t_scale_strides(4), v_scale_strides(4), dO_scale_strides(4), dO_t_scale_strides(4); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, @@ -2425,11 +2411,13 @@ void fused_attn_fp8_bwd_impl_v1( : fe::DiagonalAlignment_t::TOP_LEFT; 
sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); - if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); - } - if (cudnn_runtime_version >= 90600 && window_size_right != -1) { - sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + if (cudnn_runtime_version >= 92100) { + if (window_size_left != -1) { + sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (window_size_right != -1) { + sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } } // sdpa_backward_options.set_alibi_mask(is_alibi); @@ -2502,32 +2490,20 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; if (is_delayed_scaling || is_current_scaling) { - auto outputs = mha_graph->sdpa_fp8_backward(Q, K, V, O, dO, Stats, descale_q, descale_k, - descale_v, descale_o, descale_dO, descale_s, - descale_dP, scale_s, scale_dQ, scale_dK, - scale_dV, scale_dP, sdpa_backward_options); - dQ = outputs[0]; - dK = outputs[1]; - dV = outputs[2]; - amax_dQ = outputs[3]; - amax_dK = outputs[4]; - amax_dV = outputs[5]; - amax_dP = outputs[6]; - } - if (is_mxfp8) { - auto outputs = mha_graph->sdpa_fp8_backward( - Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, - descale_k_t, descale_v, descale_dO, descale_dO_t, sdpa_backward_options); - dQ = outputs[0]; - dK = outputs[1]; - dV = outputs[2]; - amax_dQ = outputs[3]; - amax_dK = outputs[4]; - amax_dV = outputs[5]; + std::tie(dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP) = std::apply( + [](const auto &...elems) { return std::make_tuple(elems...); }, + mha_graph->sdpa_fp8_backward(Q, K, V, O, dO, Stats, descale_q, descale_k, descale_v, + descale_o, descale_dO, descale_s, descale_dP, scale_s, + scale_dQ, scale_dK, scale_dV, scale_dP, + sdpa_backward_options)); + } else if (is_mxfp8) { + std::tie(dQ, dK, dV, amax_dQ, amax_dK, 
amax_dV) = std::apply( + [](const auto &...elems) { return std::make_tuple(elems...); }, + mha_graph->sdpa_fp8_backward(Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, + descale_q_t, descale_k, descale_k_t, descale_v, descale_dO, + descale_dO_t, sdpa_backward_options)); } - std::vector dq_strides(4); - std::vector dk_strides(4); - std::vector dv_strides(4); + std::vector dq_strides(4), dk_strides(4), dv_strides(4); generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, dq_strides.data(), dk_strides.data(), dv_strides.data(), dqkv_layout); dQ->set_output(true) From 09b21ee312ba3fa9331969851567d342658e16a3 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:32:35 -0700 Subject: [PATCH 121/172] tweak checks in utils.py Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/utils.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index a003226651..bb4f5b1065 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -460,19 +460,23 @@ def get_attention_backend( torch.Tensor, Float8Tensor, Float8TensorStorage, - MXFP8Tensor, - MXFP8TensorStorage, ): if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: logger.debug( "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s." " Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}," - " qkv_type = {torch.Tensor, Float8Tensor, Float8TensorStorage, MXFP8Tensor," - " MXFP8TensorStorage}. ", + " qkv_type = {torch.Tensor, Float8Tensor, Float8TensorStorage}. 
", qkv_dtype, qkv_type, ) use_flash_attention_3 = False + if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in ( + torch.Tensor, + Float8Tensor, + Float8TensorStorage, + MXFP8Tensor, + MXFP8TensorStorage, + ): if use_fused_attention: logger.debug( "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. Supported:" @@ -486,11 +490,10 @@ def get_attention_backend( # Filter: Execution type fp8_recipe = None - if fp8: - fp8_recipe = fp8_meta["recipe"] if fp8_meta is not None else None + if fp8 and fp8_meta["recipe"].fp8_dpa: + fp8_recipe = fp8_meta["recipe"] if fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] - if fp8 and fp8_recipe.fp8_dpa: if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug("Disabling FlashAttention 2 for FP8 attention") use_flash_attention_2 = False @@ -765,10 +768,15 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) use_flash_attention = False if fp8 and fp8_recipe.fp8_dpa: - logger.debug( - "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type - ) - use_unfused_attention = False + if use_fused_attention and (device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0)): + logger.debug("Disabling FusedAttention for softmax_type = %s in FP8 on sm < 100 with cuDNN" + " version < 9.21", softmax_type) + use_fused_attention = False + if use_unfused_attention: + logger.debug( + "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type + ) + use_unfused_attention = False if qkv_format == "thd" and cudnn_version < (9, 18, 0): logger.debug( "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" @@ -957,7 +965,13 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if window_size is None: window_size = 
check_set_window_size(attn_mask_type, window_size) if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if attention_dropout != 0.0: + if fp8 and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha) and (device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0)): + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention for FP8 on sm < 100 with cuDNN" + " version < 9.21" + ) + use_fused_attention = False + elif attention_dropout != 0.0: logger.debug( "Disabling FusedAttention as it only supports sliding window attention " "without dropout" From 49a54c0cc043650451b2de042e67f080f28dad84 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:54:39 -0700 Subject: [PATCH 122/172] tweak UnfusedDPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 5ab3ff2507..e07a6c7d5f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -185,7 +185,7 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] ] - # sbhd_sbhd_sbhd should always be the shape at this point + # always in sbhd_sbhd_sbhd shape at this point q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, query_layer, key_layer, value_layer, quantizer ) @@ -193,8 +193,7 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype ) if isinstance(quantizer, MXFP8Quantizer): - 
# bhsd_bhsd_bhsd should always be the shape at this point - # permute back to sbhd_sbhd_sbhd + # always in bhsd_bhsd_bhsd shape at this point; permute it back to sbhd_sbhd_sbhd tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: if quantizer is not None: @@ -220,15 +219,15 @@ def backward(ctx, grad1, grad2, grad3): tensors = grad1, grad2, grad3 elif ctx.quantizer_name == "dQKV_quantizer": query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] - dq_fp8, dk_fp8, dv_fp8, ctx.qkv_layout = combine_and_quantize( + # always in sbhd_sbhd_sbhd shape at this point + dq_fp8, dk_fp8, dv_fp8, new_qkv_layout = combine_and_quantize( ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer ) tensors = combine_and_dequantize( - ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype + new_qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) if isinstance(ctx.quantizer, MXFP8Quantizer): - # bhsd_bhsd_bhsd should always be the shape at this point - # permute back to sbhd_sbhd_sbhd + # always in bhsd_bhsd_bhsd shape at this point; permute it back to sbhd_sbhd_sbhd tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: tensors = grad1, grad2, grad3 @@ -406,7 +405,6 @@ def forward( query_layer.shape[0], key_layer.shape[0], ) - apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 if "padding" in attn_mask_type and attention_mask is None: attention_mask = dpa_utils.get_padding_mask( @@ -433,6 +431,8 @@ def forward( ) ) + apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 + # [b, h, sq, sk] output_size = ( query_layer.size(1), @@ -483,7 +483,7 @@ def forward( fp8_dtype=dP_quantizer.dtype, device="cuda" ) # disable swizzle for MXFP8Quantizer - for q in [ + for quantizer in [ QKV_quantizer, O_quantizer, S_quantizer, @@ -491,9 +491,9 @@ def forward( dO_quantizer, 
dP_quantizer, ]: - if isinstance(q, MXFP8Quantizer): - q.optimize_for_gemm = False - q.internal = False + if isinstance(quantizer, MXFP8Quantizer): + quantizer.optimize_for_gemm = False + quantizer.internal = False # q, k, v are in sbhd after previous reshaping # quantize and dequantize QKV to emulate FP8 @@ -1245,7 +1245,7 @@ def forward( fp8_recipe = fp8_meta["local_recipes"][0] # qkv_layout may change due to MXFP8 quantization - # o_format should stay the same as original qkv_format + # o_format should stay the same as original q_format original_qkv_layout = qkv_layout _, o_format, _ = dpa_utils.get_qkv_format(qkv_layout) From e5d49d2f2d300b218afae5afddde5af7fd5c1342 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:53:52 -0700 Subject: [PATCH 123/172] enable testing for ag+swa and disable fp8_mha Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention_with_cp.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 94e009e377..dc41d5292b 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -313,14 +313,6 @@ def test_cp_with_fused_attention( ]: pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") - # TODO: Remove this once the issue is fixed! 
- if ( - dtype == "fp8" - and (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) - and cp_comm_type == "all_gather" - ): - pytest.skip("No support for SWA with FP8 attention and cp_comm_type=all_gather!") - if cp_comm_type in ["a2a", "a2a+p2p"] and ( config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 ): @@ -350,6 +342,8 @@ def test_cp_with_fused_attention( pytest.skip("scaling_mode=delayed requires f16_O=False!") if scaling_mode == "mxfp8" and not f16_O: pytest.skip("scaling_mode=mxfp8 requires f16_O=True!") + if scaling_mode == "mxfp8" and fp8_mha: + pytest.skip("No support for scaling_mode=mxfp8 with fp8_mha=True!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} From 2c63d835afdc10f1b5e85353378de07579053d84 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 15:58:23 -0700 Subject: [PATCH 124/172] tweak FusedAttn, fp8/f16 tensor naming/docstring Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 61 ++++++++++--------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index e07a6c7d5f..cf55a632fd 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1249,8 +1249,8 @@ def forward( original_qkv_layout = qkv_layout _, o_format, _ = dpa_utils.get_qkv_format(qkv_layout) - # input types are inferred from the real data while output types are controlled by fp8_output - # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha) + # input types are inferred from real data while output types are controlled by fp8_output + # fp8_output should be set upstream assert isinstance(k, q.__class__) and 
isinstance( v, q.__class__ ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." @@ -1392,7 +1392,7 @@ def forward( # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ - out = out_ + out_f16 = out_ bwd_requires_o_f16 = is_training and ( not is_bwd_fp8 or ( @@ -1413,7 +1413,7 @@ def forward( ) if isinstance(out_, QuantizedTensorStorage): if not is_output_fp8 or bwd_requires_o_f16: - out = out_.dequantize().view(out_.shape) + out_f16 = out_.dequantize().view(out_.shape) else: if is_output_fp8 or bwd_requires_o_fp8: out_fp8 = O_quantizer(out_) @@ -1431,27 +1431,27 @@ def forward( ) # return appropriate tensors - out_ret = out_fp8 if is_output_fp8 else out + out_ret = out_fp8 if is_output_fp8 else out_f16 if _run_shadow_f16_fwd and _replace_out_return_with_shadow_f16: out_ret = out_f16_ if _run_shadow_f16_fwd and _replace_aux_with_shadow_f16: aux_ctx_tensors[0] = aux_ctx_tensors_f16[0] - # save appropriate tensors + # save q, k, v, o tensors fp8_tensors = (None, None, None, None) - qkvo_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) if is_bwd_fp8: if ( fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 ) or fp8_recipe.mxfp8(): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) - qkvo_tensors = (None, None, None, out) + f16_tensors = (None, None, None, out_f16) elif fp8_recipe.delayed() or ( fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 ): fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) if _run_shadow_f16_bwd: - qkvo_tensors = (q, k, v, out) + f16_tensors = (q, k, v, out_f16) else: if is_input_fp8: q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) @@ -1475,8 +1475,8 @@ def forward( else: q, k = q_, k_ if _run_shadow_f16_fwd and _replace_out_save_with_shadow_f16: - out = out_f16_ - qkvo_tensors = (q, k, v, out) + out_f16 = out_f16_ + f16_tensors = (q, k, v, out_f16) else: # q, k, v, out_: torch.Tensor; dtype = 
torch.float16 or torch.bfloat16 out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( @@ -1512,10 +1512,10 @@ def forward( return_max_logit, is_graph_capturing(), ) - out = out_ + out_f16 = out_ out_ret = out_ fp8_tensors = (None, None, None, None) - qkvo_tensors = (q, k, v, out) + f16_tensors = (q, k, v, out_f16) nvtx_range_pop(f"{nvtx_label}") @@ -1529,7 +1529,7 @@ def forward( if ctx.fp8: tensor_list = fp8_tensors else: - tensor_list = [q, k, v, out] + tensor_list = [q, k, v, out_f16] mark_activation_offload(*tensor_list) mark_activation_offload(*aux_ctx_tensors) @@ -1539,7 +1539,7 @@ def forward( tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, - *qkvo_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, @@ -1620,6 +1620,8 @@ def backward(ctx, d_out, *_args): # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 + if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: + d_out = d_out.contiguous() d_out_fp8 = None d_out_format = ctx.o_format if ctx.fp8: @@ -1629,8 +1631,6 @@ def backward(ctx, d_out, *_args): d_out_fp8 = d_out else: d_out_fp8 = ctx.dO_quantizer(d_out) - if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: - d_out = d_out.contiguous() ( q_fp8, k_fp8, @@ -1707,20 +1707,25 @@ def backward(ctx, d_out, *_args): ctx.dP_quantizer, ) - # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16, + # DelayedScaling/Float8CurrentScaling/MXFP8BlockScaling: + # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16, # fp8_dtype = tex.DType.kFloat8E4M3 - # d_out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 # DelayedScaling: - # out_, dq_, dk_, dv_: Float8Tensor; dtype = 
torch.float16 or torch.bfloat16 + # out_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # dq_, dk_, dv_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + # Float8CurrentScaling: + # out_: NVTE_DPA_FP8CS_O_in_F16=1: + # torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # NVTE_DPA_FP8CS_O_in_F16=0: + # Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 - # Float8CurrentScaling + NVTE_DPA_FP8CS_O_in_F16=1: - # out_, dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # Float8CurrentScaling + NVTE_DPA_FP8CS_O_in_F16=0: - # out_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # MXFP8BlockScaling: - # out_, dq_, dk_, dv_, d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # out_, dq_, dk_, dv_, d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_ = out_fp8 if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: out_ = out @@ -1844,7 +1849,7 @@ def backward(ctx, d_out, *_args): if is_quantized_tensor and not ctx.is_input_fp8: # return in F16 dq, dk, dv = combine_and_dequantize( - ctx.qkv_layout, + ctx.dqkv_layout, dq_, dk_, dv_, @@ -1853,7 +1858,7 @@ def backward(ctx, d_out, *_args): if not is_quantized_tensor and ctx.is_input_fp8: # return in FP8 dq, dk, dv, _ = combine_and_quantize( - ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer + ctx.dqkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer ) # print quantizers From 7f62b98c5a330e0cd743ebd304c52ac8584415c4 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 16:06:24 -0700 Subject: [PATCH 125/172] replace d_out_format with do_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 
.../dot_product_attention/backends.py | 8 +++---- .../dot_product_attention/context_parallel.py | 22 +++++++++---------- .../pytorch/cpp_extensions/fused_attn.py | 6 ++--- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index cf55a632fd..d598e011d1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1623,10 +1623,10 @@ def backward(ctx, d_out, *_args): if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() d_out_fp8 = None - d_out_format = ctx.o_format + do_format = ctx.o_format if ctx.fp8: if ctx.fp8_recipe.mxfp8(): - d_out, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, d_out) + d_out, do_format = dpa_utils.permute_to_grouped_tensor(do_format, d_out) if isinstance(d_out, QuantizedTensorStorage): d_out_fp8 = d_out else: @@ -1755,7 +1755,7 @@ def backward(ctx, d_out, *_args): ctx.fast_zero_fill, ctx.qkv_layout, ctx.o_format, - d_out_format, + do_format, ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, @@ -1899,7 +1899,7 @@ def backward(ctx, d_out, *_args): ctx.fast_zero_fill, ctx.qkv_layout, ctx.o_format, - d_out_format, + do_format, ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 90df998e09..fe2019ef80 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1145,7 +1145,7 @@ def cp_p2p_bwd_fused_attn( dropout_p, qkv_layout, o_format, - d_out_format, + do_format, dqkv_layout, attn_mask_type, attn_bias_type, @@ -1223,7 
+1223,7 @@ def cp_p2p_bwd_fused_attn( out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) else: - dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout_part) + dout_part, do_format = dpa_utils.permute_to_grouped_tensor(do_format, dout_part) aux_tensors.append(dout_part) dout_part = dO_quantizer_per_step(dout_part) fp8_meta_kwargs["s_quantizer"] = S_quantizer @@ -1249,7 +1249,7 @@ def cp_p2p_bwd_fused_attn( dropout=dropout_p, qkv_layout=qkv_layout, o_format=o_format, - d_out_format=d_out_format, + do_format=do_format, dqkv_layout=dqkv_layout, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, @@ -3571,7 +3571,7 @@ def backward(ctx, dout, *_args): fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8_meta_kwargs = {} qkv_layout = ctx.qkv_layout - d_out_format = ctx.o_format + do_format = ctx.o_format if ctx.fp8: fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer @@ -3604,8 +3604,8 @@ def backward(ctx, dout, *_args): q_part, k_part, v_part, qkv_layout = combine_and_quantize( qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer ) - dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor( - d_out_format, dout_part + dout_part, do_format = dpa_utils.permute_to_grouped_tensor( + do_format, dout_part ) aux_ctx_tensors.append(dout_part) dout_part = ctx.dO_quantizer(dout_part) @@ -3628,7 +3628,7 @@ def backward(ctx, dout, *_args): dropout=ctx.dropout_p, qkv_layout=qkv_layout, o_format=ctx.o_format, - d_out_format=d_out_format, + do_format=do_format, dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, @@ -4311,7 +4311,7 @@ def backward(ctx, dout, *_args): dq_fp8, dk_fp8, dv_fp8 = None, None, None if ctx.use_fused_attention: - d_out_format = ctx.o_format + do_format = ctx.o_format q_part, 
k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 @@ -4322,8 +4322,8 @@ def backward(ctx, dout, *_args): if not ctx.fp8_recipe.mxfp8(): dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) else: - # d_out_format = bhsd for both dout (F16) and dout_part (MXFP8) - dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) + # do_format = bhsd for both dout (F16) and dout_part (MXFP8) + dout, do_format = dpa_utils.permute_to_grouped_tensor(do_format, dout) aux_ctx_tensors.append(dout) dout_part = ctx.dO_quantizer(dout) dq, dk, dv, *rest = fused_attn_bwd( @@ -4345,7 +4345,7 @@ def backward(ctx, dout, *_args): dropout=ctx.dropout_p, qkv_layout=ctx.qkv_layout, o_format=ctx.o_format, - d_out_format=d_out_format, + do_format=do_format, dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 0ca738cdb8..61ee95662d 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -386,7 +386,7 @@ def fused_attn_bwd( fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", o_format: str = "sbhd", - d_out_format: str = "sbhd", + do_format: str = "sbhd", dqkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", @@ -454,7 +454,7 @@ def fused_attn_bwd( "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} o_format : str, default = "sbhd" format of O; {"sbhd", "bshd", "thd"} - d_out_format : str, default = "sbhd" + do_format : str, default = "sbhd" format of dO; {"sbhd", "bshd", "thd"} dqkv_layout : str, default = "sbh3d" layout of dQ, dK and dV; @@ -529,7 +529,7 @@ def fused_attn_bwd( fast_zero_fill, QKVLayout[qkv_layout], QKVFormat[o_format], - QKVFormat[d_out_format], + 
QKVFormat[do_format], QKVLayout[dqkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], From 4b9240c7e8040739c3cc9ddb52b4305fa55cff5e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 16:37:56 -0700 Subject: [PATCH 126/172] fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 2 +- .../debug/features/log_fp8_tensor_stats.py | 4 ++-- .../dot_product_attention/backends.py | 2 -- .../dot_product_attention/context_parallel.py | 21 ++++++------------- .../dot_product_attention.py | 1 + .../attention/dot_product_attention/utils.py | 13 ++++++------ 6 files changed, 17 insertions(+), 26 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index a7beda1e1c..7c063e5465 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2693,7 +2693,7 @@ void fused_attn_fp8_bwd_impl_v1( } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); } -} +} // NOLINT(readability/fn_size) #endif diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index cf11964e25..16c27c2b62 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -10,11 +10,11 @@ import torch import nvdlfw_inspect.api as debug_api -import transformer_engine_torch as tex - from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.registry import Registry, api_method +import transformer_engine_torch as tex + from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter from 
transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index d598e011d1..53d57c56b4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -36,9 +36,7 @@ restore_from_saved, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8TensorStorage from transformer_engine.pytorch.constants import ( - TE_DType, QKVLayouts, dist_group_type, ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index fe2019ef80..89c4bdc3f9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -25,10 +25,7 @@ from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.graph import is_graph_capturing -from transformer_engine.pytorch.constants import ( - dist_group_type, - TE_DType, -) +from transformer_engine.pytorch.constants import dist_group_type from transformer_engine.pytorch.distributed import ( get_distributed_world_size, get_distributed_rank, @@ -536,7 +533,7 @@ def flash_attn_a2a_communicate( # [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] # [cp, 2, b, h//cp, s//2, d] -> [2, b, cp, h//cp, s//2, d] # [cp, t, h//cp, d] -> [t, cp, h//cp, d] - tmp_list = [x for x in qkv_format] + tmp_list = list(qkv_format) if "t" not in qkv_format: tmp_list.insert(0, "2") tmp_list.insert(0, "c") @@ -3067,7 +3064,7 @@ def forward( fp8_meta_kwargs["o_quantizer"] = O_quantizer elif 
use_fused_attention: fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen - orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape + orig_q_shape, _, orig_v_shape = q.shape, k.shape, v.shape orig_o_shape = orig_q_shape[:-1] + orig_v_shape[-1:] # q, k, v: @@ -3201,7 +3198,7 @@ def forward( if fp8: softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + softmax_lse_per_step[i], rng_states[i], *_ = aux_ctx_tensors if return_max_logit: max_logit_per_step[i] = max_logit_[0] if fp8 and isinstance(out_per_step[i], QuantizedTensorStorage): @@ -4220,7 +4217,7 @@ def backward(ctx, dout, *_args): *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - batch_dim_dqkv, seq_dim_dqkv, _ = get_bsh_dims(ctx.dqkv_format) + _, seq_dim_dqkv, _ = get_bsh_dims(ctx.dqkv_format) _, seq_dim_do, _ = get_bsh_dims(ctx.o_format) bwd_nominal_dtype = ctx.fwd_nominal_dtype fused_attn_backend = None @@ -4579,6 +4576,7 @@ def attn_forward_func_with_cp( in Megatron-LM. """ + if cp_comm_type == "a2a+p2p": assert ( isinstance(cp_group, list) and len(cp_group) == 2 @@ -4624,13 +4622,6 @@ def attn_forward_func_with_cp( "all_gather", ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" - enable_mla = k.shape[-1] != v.shape[-1] - # assert not enable_mla or cp_comm_type in [ - # "p2p", - # "a2a+p2p", - # "a2a", - # ], f"Context parallelism does not support MLA with {cp_comm_type=}!" 
- if fp8 and fp8_meta is not None: if fp8_meta["recipe"].fp8_dpa: assert ( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index d1fd0b0ed0..0d4d31f405 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -19,6 +19,7 @@ Recipe, DelayedScaling, Float8CurrentScaling, + MXFP8BlockScaling, ) from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.quantization import ( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index bb4f5b1065..4c3439c3a8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -346,7 +346,7 @@ def get_attention_backend( attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel cp_comm_type = attention_params.cp_comm_type - cp_size = attention_params.cp_size + cp_size = attention_params.cp_size # pylint: disable=unused-variable deterministic = attention_params.deterministic is_training = attention_params.is_training fp8 = attention_params.fp8 @@ -505,7 +505,7 @@ def get_attention_backend( fp8_recipe.delayed() or fp8_recipe.float8_current_scaling() ): if FlashAttentionUtils.v3_is_installed: - logger.debug(f"Disabling FlashAttention 3 for {fp8_recipe.__class__.__name__}") + logger.debug("Disabling FlashAttention 3 for %s", fp8_recipe.__class__.__name__) use_flash_attention_3 = False if use_unfused_attention: allow_emulation = ( @@ -557,7 +557,7 @@ def get_attention_backend( logger.debug("Disabling FusedAttention for MXFP8 with qkv_format = thd") use_fused_attention = False if use_fused_attention 
and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()): - logger.debug(f"Disabling FusedAttention for {fp8_recipe.__class__.__name__}") + logger.debug("Disabling FusedAttention for %s", fp8_recipe.__class__.__name__) use_fused_attention = False if device_compute_capability == (12, 0): @@ -2339,7 +2339,8 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): device="cuda", dtype=q.dtype, ) - q_fp8, k_fp8, v_fp8 = grouped_tensor.quantize(input_tensors) + quantized_tensors = grouped_tensor.quantize(input_tensors) + q_fp8, k_fp8, v_fp8 = quantized_tensors[0], quantized_tensors[1], quantized_tensors[2] else: input_tensors = [q, k] num_tensors = len(input_tensors) @@ -2351,7 +2352,8 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): device="cuda", dtype=q.dtype, ) - q_fp8, k_fp8 = grouped_tensor.quantize(input_tensors) + quantized_tensors = grouped_tensor.quantize(input_tensors) + q_fp8, k_fp8 = quantized_tensors[0], quantized_tensors[1] v_fp8 = qkv_quantizer(v) # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] @@ -2420,7 +2422,6 @@ def combine_and_dequantize( return q, k, v qkv_layout = qkv_layout.replace("paged_kv_", "") - qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) qkv_group = len(qkv_layout.split("_")) q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] # 1: qkv packed, 2: kv packed, 3: qkv separate From 2a21a3a7a32627762a8f68d353dd4aa6fbbe0d3c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:38:40 -0700 Subject: [PATCH 127/172] clean up a2a Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git 
a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 89c4bdc3f9..57a177274d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -22,6 +22,7 @@ ) from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.graph import is_graph_capturing @@ -3813,6 +3814,7 @@ def forward( qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format original_qkv_layout = qkv_layout orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape + orig_o_shape = orig_q_shape[:-1] + orig_v_shape[-1:] o_format = qkv_format batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) _, seq_dim_o, _ = get_bsh_dims(o_format) @@ -3954,7 +3956,6 @@ def forward( # _part: inputs to attention kernel and saved for backward # note: they have post a2a shapes - batch_size = q.shape[batch_dim_qkv] q_part, k_part, v_part = q, k, v out_part, out_fp8, out_f16 = None, None, None bwd_requires_o_f16 = is_training and ( @@ -4015,6 +4016,7 @@ def forward( cuda_graph=is_graph_capturing(), ) # construct out_part for backward + # out_fp8 and out_f16 store the FP8 or F16 tensor for backward saves out_fp8 = out_ out_f16 = out_ if bwd_requires_o_fp8: @@ -4056,7 +4058,7 @@ def forward( # [b, s, h//cp, d] -> [b*s//cp, h, d] # [s, b, h//cp, d] -> [s//cp*b, h, d] # [t, h//cp, d] -> [t//cp, h, d] - if isinstance(out_, QuantizedTensorStorage): + if isinstance(out_, Float8TensorStorage): out_ = out_._data chunk_ids_for_a2a = 
get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) out_ = flash_attn_a2a_communicate( @@ -4074,12 +4076,10 @@ def forward( # [b*s//cp, h, d] -> [b, s//cp, h, d] # [s//cp*b, h, d] -> [s//cp, b, h, d] # [t//cp, h, d] -> [t//cp, h, d] - if o_format == "bshd": - out_ = out_.view(batch_size, -1, *out_.shape[-2:]) - elif o_format == "sbhd": - out_ = out_.view(-1, batch_size, *out_.shape[-2:]) + out_ = out_.view(orig_o_shape) # out_ret: output tensor for forward pass + # out_fp8 and out_f16 are reused here to store the FP8 or F16 tensor for forward returns if fp8: if fp8_recipe.delayed(): out_fp8 = Float8Tensor.make_like(out_fp8, data=out_, dtype=fwd_nominal_dtype) @@ -4109,8 +4109,7 @@ def forward( ctx.orig_q_shape = orig_q_shape ctx.orig_k_shape = orig_k_shape ctx.orig_v_shape = orig_v_shape - ctx.out_part_shape = out_part.shape - ctx.out_ret_shape = out_ret.shape + ctx.orig_o_shape = orig_o_shape # save tensors for backward ctx.fp8 = fp8 and is_bwd_fp8 @@ -4130,9 +4129,8 @@ def forward( ): fp8_tensors = (q_part, k_part, v_part, out_part) elif fp8: - # FP8DS/CS: convert post-a2a FP8 q/k/v to F16 - # MXFP8: save post-a2a pre-quantization F16 q/k/v - # out_part is already converted to the right precision + # FP8DS/CS: convert post-a2a FP8 q/k/v to F16; out_part already in F16 + # MXFP8: save post-a2a pre-quantization F16 q/k/v; out_part already in F16 if fp8_recipe.mxfp8(): f16_tensors = (q, k, v, out_part) ctx.qkv_layout = original_qkv_layout @@ -4142,7 +4140,7 @@ def forward( ) f16_tensors = (q_part, k_part, v_part, out_part) else: - # all tensors are already in F16 + # all tensors are in F16 f16_tensors = (q_part, k_part, v_part, out_part) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -4243,7 +4241,7 @@ def backward(ctx, dout, *_args): dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - dout = dout.view(*ctx.out_ret_shape) + dout = 
dout.view(*ctx.orig_o_shape) # dout: # FP8DS/CS: torch.uint8 @@ -4352,7 +4350,7 @@ def backward(ctx, dout, *_args): **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) - if all(isinstance(x, QuantizedTensorStorage) for x in [dq, dk, dv]): + if all(isinstance(x, Float8TensorStorage) for x in [dq, dk, dv]): dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv dq, dk, dv = [x._data for x in [dq, dk, dv]] else: From a18cd7cfe4362cc846379510c06e18535e31bdd5 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:24:42 -0700 Subject: [PATCH 128/172] clean up ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 72 ++++++++----------- 1 file changed, 29 insertions(+), 43 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 57a177274d..6ab6825aeb 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2983,7 +2983,7 @@ def forward( ) assert q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0, ( "cp_comm_type='all_gather' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q =" - f" {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}, cp_size = {cp_size}." + f" {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}." 
) flash_attn_fwd = None @@ -3101,10 +3101,10 @@ def forward( # v: [s, b, h, d] # k_ag: [cp*s, b, h, d] # v_ag: [cp*s, b, h, d] - # out: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # out_f16: [b, 2, s//2, h, d] or [2, s//2, b, h, d] q_shape, k_shape, v_shape = q.shape, k.shape, v.shape o_shape = q.shape[:-1] + v.shape[-1:] - out = torch.empty(o_shape, dtype=fwd_nominal_dtype, device=q.device) + out_f16 = torch.empty(o_shape, dtype=fwd_nominal_dtype, device=q.device) # create two streams to resolve wave quantization issue of Flash Attn in each step flash_attn_streams = [torch.cuda.current_stream(), cp_stream] @@ -3154,15 +3154,9 @@ def forward( new_qkv_layout = qkv_layout if fp8: if not fp8_recipe.mxfp8(): - q_part = Float8Tensor.make_like( - q_fp8, data=q_part, dtype=fwd_nominal_dtype - ) - k_part = Float8Tensor.make_like( - k_fp8, data=k_part, dtype=fwd_nominal_dtype - ) - v_part = Float8Tensor.make_like( - v_fp8, data=v_part, dtype=fwd_nominal_dtype - ) + q_part, k_part, v_part = [Float8Tensor.make_like( + x, data=y, dtype=fwd_nominal_dtype + ) for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part])] else: q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( qkv_layout, q_part, k_part, v_part, QKV_quantizer @@ -3239,7 +3233,7 @@ def forward( rng_states[i] = fa_outputs[3] # out_per_step[i]: fwd_nominal_dtype, [b, s//2, h, d] or [s//2, b, h, d] - # out: fwd_nominal_dtype, [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # out_f16: fwd_nominal_dtype, [b, 2, s//2, h, d] or [2, s//2, b, h, d] # max_logit_per_step[i]: torch.float32, [h] # max_logit: torch.float32, [h] if return_max_logit and i == 0: @@ -3247,9 +3241,9 @@ def forward( if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if o_format == "bshd": - out[:, i - 1].copy_(out_per_step[i - 1]) + out_f16[:, i - 1].copy_(out_per_step[i - 1]) elif o_format == "sbhd": - out[i - 1].copy_(out_per_step[i - 1]) + out_f16[i - 1].copy_(out_per_step[i - 1]) if return_max_logit: max_logit = 
torch.maximum(max_logit, max_logit_per_step[i - 1]) @@ -3261,10 +3255,10 @@ def forward( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group ) - # out: fwd_nominal_dtype + # out_f16: fwd_nominal_dtype # [b, 2, s//2, h, d] -> [b, s, h, d] # [2, s//2, b, h, d] -> [s, b, h, d] - out = out.view(orig_o_shape) + out_f16 = out_f16.view(orig_o_shape) # prepare for forward output and backward saves of out out_fp8 = None @@ -3277,8 +3271,8 @@ def forward( ) ) if fp8 and (is_output_fp8 or bwd_requires_o_fp8): - out_fp8 = O_quantizer(out) - out_ret = out_fp8 if is_output_fp8 else out + out_fp8 = O_quantizer(out_f16) + out_ret = out_fp8 if is_output_fp8 else out_f16 # save tensors for backward ctx.fp8 = fp8 and is_bwd_fp8 @@ -3288,11 +3282,11 @@ def forward( # True: q split along s; k/v with s first, i.e. [s, b, h, d] # False: original [b, s, h, d] or [s, b, h, d] ctx.qkv_reshaped = True - # no load-balance related token shuffling; original token order in q/k/v/out + # no load-balance related token shuffling; original token order in q/k/v/out_f16 # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] # k: [s, b, h, d] # v: [s, b, h, d] - # out/out_fp8: [b, s, h, d] or [s, b, h, d] + # out_f16/out_fp8: [b, s, h, d] or [s, b, h, d] if ctx.fp8: # q_fp8_save: [b, 2, s//2, h, d] or [2, s//2, b, h, d] # k_fp8_save: [s, b, h, d] @@ -3311,22 +3305,22 @@ def forward( fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) elif fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, None) - f16_tensors = (None, None, None, out) + f16_tensors = (None, None, None, out_f16) elif fp8_recipe.mxfp8(): - f16_tensors = (q, k, v, out) + f16_tensors = (q, k, v, out_f16) elif fp8: # convert q/k/v to F16 if necessary, and save q/k/v/o all in F16 and original format if is_input_fp8: q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) - f16_tensors = (q_f16, k_f16, v_f16, out) + f16_tensors = (q_f16, k_f16, v_f16, 
out_f16) ctx.qkv_reshaped = False else: # save all in F16 # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] # k: [s, b, h, d] # v: [s, b, h, d] - # out: [b, s, h, d] or [s, b, h, d] - f16_tensors = (q, k, v, out) + # out_f16: [b, s, h, d] or [s, b, h, d] + f16_tensors = (q, k, v, out_f16) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, *f16_tensors, @@ -3345,11 +3339,10 @@ def forward( ctx.dqkv_format = qkv_format ctx.dqkv_layout = qkv_layout ctx.fwd_nominal_dtype = fwd_nominal_dtype - ctx.orig_o_shape = orig_o_shape - ctx.o_shape = o_shape ctx.q_shape = q_shape ctx.k_shape = k_shape ctx.v_shape = v_shape + ctx.o_shape = o_shape ctx.kv_seq_range_per_step = kv_seq_range_per_step ctx.window_size_per_step = window_size_per_step @@ -3365,7 +3358,6 @@ def forward( ctx.use_flash_attn_3 = use_flash_attn_3 ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 - ctx.is_output_fp8 = is_output_fp8 ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer @@ -3568,7 +3560,7 @@ def backward(ctx, dout, *_args): ] fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8_meta_kwargs = {} - qkv_layout = ctx.qkv_layout + new_qkv_layout = ctx.qkv_layout do_format = ctx.o_format if ctx.fp8: fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 @@ -3579,15 +3571,9 @@ def backward(ctx, dout, *_args): # FP8CS+_dpa_fp8_cs_o_in_f16: q/k/v/do in FP8, o in f16 # MXFP8: q/k/v/do all in MXFP8, o/do_f16 in F16 if not ctx.fp8_recipe.mxfp8(): - q_part = Float8Tensor.make_like( - q_fp8, data=q_part, dtype=ctx.fwd_nominal_dtype - ) - k_part = Float8Tensor.make_like( - k_fp8, data=k_part, dtype=ctx.fwd_nominal_dtype - ) - v_part = Float8Tensor.make_like( - v_fp8, data=v_part, dtype=ctx.fwd_nominal_dtype - ) + q_part, k_part, v_part = [Float8Tensor.make_like( + x, data=y, dtype=ctx.fwd_nominal_dtype + ) for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part])] if ctx.fp8_recipe.delayed() or ( ctx.fp8_recipe.float8_current_scaling() and not 
_dpa_fp8_cs_o_in_f16 @@ -3599,8 +3585,8 @@ def backward(ctx, dout, *_args): dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype ) else: - q_part, k_part, v_part, qkv_layout = combine_and_quantize( - qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + ctx.qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer ) dout_part, do_format = dpa_utils.permute_to_grouped_tensor( do_format, dout_part @@ -3624,7 +3610,7 @@ def backward(ctx, dout, *_args): cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, o_format=ctx.o_format, do_format=do_format, dqkv_layout=ctx.dqkv_layout, @@ -3736,7 +3722,7 @@ def backward(ctx, dout, *_args): # quantize if necessary if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv, _ = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _ = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( From a19ccb38cf74f24fc6dcd5e846e4caa7c03e5eb3 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:52:43 -0700 Subject: [PATCH 129/172] clean up p2p/a2a+p2p Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 6ab6825aeb..35f1336d80 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2078,8 +2078,8 @@ def forward( # fwd: fp8, bwd: f16, save all f16 # there is already an F16 
version of the inputs q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q, k, v) - kv = torch.cat((k_f16.view(-1), v_f16.view(-1)), dim=-1) - f16_tensors = (q_f16, kv, out_f16) + kv_f16 = torch.cat((k_f16.view(-1), v_f16.view(-1)), dim=-1) + f16_tensors = (q_f16, kv_f16, out_f16) elif fp8 and not is_input_fp8 and fp8_recipe.mxfp8(): f16_tensors = (q, kv, out_f16) elif fp8: @@ -2131,15 +2131,15 @@ def forward( ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 - ctx.k_numel = k_numel - ctx.k_shape = k_shape - ctx.v_shape = v_shape - ctx.o_shape = o_shape ctx.orig_q_shape = orig_q_shape ctx.orig_k_shape = orig_k_shape ctx.orig_v_shape = orig_v_shape ctx.orig_o_shape = orig_o_shape ctx.post_a2a_o_shape = post_a2a_o_shape + ctx.k_numel = k_numel + ctx.k_shape = k_shape + ctx.v_shape = v_shape + ctx.o_shape = o_shape ctx.qkv_format = qkv_format ctx.qkv_layout = qkv_layout ctx.fwd_nominal_dtype = fwd_nominal_dtype @@ -2173,7 +2173,7 @@ def backward(ctx, dout, *_args): nvtx_range_push(f"{nvtx_label}") # dout is expected to be in FP8 if is_output_fp8=True, - # but in the case it's not, convert it to FP8 before any operation + # but in the case it's not, convert it to FP8 (except for MXFP8) before any operation if ( ctx.fp8 and ctx.is_output_fp8 From 4ba2ef50edfca75f2d5b8d8e3dbb9a74af380490 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 24 Mar 2026 19:46:17 -0700 Subject: [PATCH 130/172] tweak test configs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/run_attention_with_cp.py | 24 +++++----- tests/pytorch/attention/test_attention.py | 45 ++++++++++++------- .../attention/test_attention_with_cp.py | 33 +++++++++----- tests/pytorch/utils.py | 1 - 4 files changed, 62 insertions(+), 41 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index a5804c6888..4bdaa6f4b9 100644 --- 
a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -193,15 +193,16 @@ def run_dpa_with_cp( logging.root.setLevel(log_level) # When is_training is False, gradient outputs are None. is_training = is_training == "True" + + # set up environment variables and config if deterministic == "True": os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" else: os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - # set up environment variables and config fp8_bwd = fp8_bwd == "True" and dtype == "fp8" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" fp8_dpa = fp8_dpa == "True" and dtype == "fp8" - fp8_mha = fp8_mha == "True" and dtype == "fp8" + fp8_mha = fp8_mha == "True" and dtype == "fp8" and scaling_mode != "mxfp8" f16_O = dtype == "fp8" and scaling_mode in ["current", "mxfp8"] and f16_O == "True" os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0" os.environ["NVTE_FLASH_ATTN"] = "0" @@ -259,7 +260,7 @@ def run_dpa_with_cp( fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "mxfp8": fp8_recipe = MXFP8BlockScaling( - fp8_format=Format.HYBRID, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha + fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha ) # instantiate attention module @@ -333,7 +334,7 @@ def run_dpa_with_cp( dout_quantizer.internal = False qkv_layout = "_".join([qkv_format] * 3) q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] - if fp8_mha and scaling_mode != "mxfp8": + if fp8_mha: q, k, v, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) for x in [q, k, v]: x.requires_grad = True @@ -380,12 +381,12 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - # fp8_output=fp8_mha, + fp8_output=fp8_mha, ) if config.return_max_logit: out, max_logit = out if is_training: - if fp8_bwd and fp8_mha and scaling_mode != "mxfp8": + if fp8_bwd and 
fp8_mha: dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: @@ -441,7 +442,7 @@ def run_dpa_with_cp( qkv_quantizer.amax.fill_(0.0) dout_quantizer.scale.fill_(1.0) dout_quantizer.amax.fill_(0.0) - if fp8_mha and scaling_mode != "mxfp8": + if fp8_mha: q_, k_, v_, qkv_layout = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) if is_training: q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] @@ -497,12 +498,12 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - # fp8_output=fp8_mha, + fp8_output=fp8_mha, ) if config.return_max_logit: out_, max_logit_ = out_ if is_training: - if fp8_bwd and fp8_mha and scaling_mode != "mxfp8": + if fp8_bwd and fp8_mha: dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: @@ -523,6 +524,7 @@ def run_dpa_with_cp( # get outputs tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] + names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] if fp8_mha: tensors_to_deq = [out, out_] if not fp8_bwd else tensors for i, tensor in enumerate(tensors_to_deq): @@ -534,8 +536,8 @@ def run_dpa_with_cp( for i, tensor in enumerate(tensors): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - assert torch.all(~torch.isnan(tensor)) - assert torch.all(~torch.isinf(tensor)) + assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" + assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 6b92c2e3ed..129ff1c472 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1804,8 +1804,6 @@ def get_model(dtype, config): attn_mask_type = "causal" -# attn_mask_type = "no_mask" -# 
attn_mask_type = "causal_bottom_right" model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig( @@ -1814,27 +1812,39 @@ def get_model(dtype, config): 128, 192, head_dim_v=128, - attn_mask_type=attn_mask_type, ), - "fp8_10": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type), - "fp8_11": ModelConfig(2, 8192, 32, 128, attn_mask_type=attn_mask_type, window_size=(128, 0)), - "fp8_12": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type), - "fp8_13": ModelConfig(2, 8192, 64, 64, attn_mask_type=attn_mask_type, window_size=(128, 0)), - "fp8_14": ModelConfig( + "fp8_10": ModelConfig( + 2, + 4096, + 128, + 192, + head_dim_v=128, + attn_mask_type="causal", + ), + "fp8_11": ModelConfig( + 2, + 4096, + 128, + 192, + head_dim_v=128, + attn_mask_type="causal_bottom_right", + ), + "fp8_12": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + "fp8_13": ModelConfig(2, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)), + "fp8_14": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "fp8_15": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), + "fp8_16": ModelConfig( 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" ), - "fp8_15": ModelConfig( + "fp8_17": ModelConfig( 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" ), - # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), - # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), - # "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), - # "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), - # "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), - # "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + "fp8_18": 
ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + "fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), + "fp8_20": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), } -param_types_fp8_vs_f16 = [torch.bfloat16] # [torch.float16, torch.bfloat16] +param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"] qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] @@ -1883,7 +1893,7 @@ def test_mha_fp8_vs_f16( fp8_recipe = recipe.MXFP8BlockScaling( fp8_format=recipe.Format.E4M3, fp8_dpa=True, - fp8_mha=True, + fp8_mha=False, ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe @@ -2068,6 +2078,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: hidden_states.requires_grad = True tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") out_grad = tensor.view(*tensor.shape[:-2], -1) + with autocast(enabled=fp8_mha, recipe=fp8_recipe): out = mha( hidden_states, diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index dc41d5292b..693c2cf960 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -79,10 +79,10 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_2_0", "cp_3_0", "cp_2_2"] # , "cp_1_2", "cp_2_1"]#, "cp_1_1", "cp_3_3"] + configs = ["cp_2_0", "cp_2_2", "cp_3_0"] # ["cp_1_0", "cp_1_2", "cp_2_1", "cp_3_2", "cp_3_3"] model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} dtypes = ["bf16"] - qkv_formats = ["bshd", "sbhd", "thd"] + # qkv_formats = ["sbhd", "thd"] @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") @@ -165,7 +165,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, 
cp_comm_type): "cp_1_4": ModelConfig( 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" ), # MHA - "cp_1_5": ModelConfig(2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0)), # MHA + "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig( 2, 4096, @@ -175,8 +175,13 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): attn_mask_type="causal", ), # GQA "cp_2_1": ModelConfig( - 2, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal" - ), # num_gqa_groups=4, attn_mask_type="causal"), # GQA + 2, + 4096, + 32, + 128, + attn_mask_type="causal", + window_size=(128, 0), + ), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -214,7 +219,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA - "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA + "cp_3_1": ModelConfig(2, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal"), # MLA "cp_3_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA @@ -231,6 +236,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_4_2": ModelConfig( 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" ), # GQA + "cp_4_3": ModelConfig( + 2, 4096, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + ), # GQA } @@ -242,20 +250,21 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): # "cp_1_0", # "cp_1_1", # "cp_1_4", - "cp_1_5", + # "cp_1_5", "cp_2_0", "cp_2_1", # "cp_2_2", # "cp_2_3", # "cp_2_4", - # "cp_3_1", + "cp_3_1", # "cp_3_2", # "cp_3_4", - # "cp_4_2", + "cp_4_2", + "cp_4_3", ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in 
configs} dtypes = ["bf16", "fp8"] - qkv_formats = ["bshd", "sbhd", "thd"] + # qkv_formats = ["sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @@ -368,9 +377,9 @@ def test_cp_with_fused_attention( DelayedScaling(fp8_dpa=True), ] if fp8 and scaling_mode == "mxfp8": - fp8_meta["recipe"] = MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=True) + fp8_meta["recipe"] = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=True) fp8_meta["local_recipes"] = [ - MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=True), + MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=True), ] # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 795b3c3441..1d2fe7f06f 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -181,7 +181,6 @@ def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): rmse = torch.sqrt((a - b).square().mean()).item() logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) - # rmse_tol = rmse_tol * 1.1 assert rmse < rmse_tol * rmse_range, ( name_a + " vs " From 875931ced79665284c71ea2df867692d71d697d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Mar 2026 02:47:08 +0000 Subject: [PATCH 131/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/run_attention_with_cp.py | 4 +-- .../common/fused_attn/fused_attn.cpp | 12 ++++----- .../common/fused_attn/fused_attn_fp8.cu | 17 ++++++------ .../dot_product_attention/context_parallel.py | 14 +++++----- .../attention/dot_product_attention/utils.py | 26 +++++++++++++------ 5 files changed, 42 insertions(+), 31 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py 
b/tests/pytorch/attention/run_attention_with_cp.py index 4bdaa6f4b9..cda5c42d50 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -259,9 +259,7 @@ def run_dpa_with_cp( if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "mxfp8": - fp8_recipe = MXFP8BlockScaling( - fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha - ) + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) # instantiate attention module core_attn = DotProductAttention( diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 22c02e7664..77c4dd143a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -358,19 +358,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || // 9.21: d_qk=192, d_v=128 - (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && - head_dim_qk <= 192 && head_dim_v <= 128 && - head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 && + head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && // pre-9.21: {bshd, sbhd}, {vanilla} // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable} ((cudnn_runtime_version < 92100 && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) 
|| (cudnn_runtime_version >= 92100 && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHSD))) && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || + qkv_format == NVTE_QKV_Format::NVTE_BHSD))) && !requires_64bit_ragged_offset && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 7c063e5465..4a25df2185 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2334,7 +2334,8 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); - std::vector q_scale_strides(4), q_t_scale_strides(4), k_scale_strides(4), k_t_scale_strides(4), v_scale_strides(4), dO_scale_strides(4), dO_t_scale_strides(4); + std::vector q_scale_strides(4), q_t_scale_strides(4), k_scale_strides(4), + k_t_scale_strides(4), v_scale_strides(4), dO_scale_strides(4), dO_t_scale_strides(4); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, @@ -2490,15 +2491,15 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; if (is_delayed_scaling || is_current_scaling) { - std::tie(dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP) = std::apply( - [](const auto &...elems) { return std::make_tuple(elems...); }, - mha_graph->sdpa_fp8_backward(Q, K, V, O, dO, Stats, descale_q, descale_k, descale_v, - descale_o, descale_dO, descale_s, descale_dP, scale_s, - scale_dQ, scale_dK, scale_dV, scale_dP, - 
sdpa_backward_options)); + std::tie(dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP) = + std::apply([](const auto&... elems) { return std::make_tuple(elems...); }, + mha_graph->sdpa_fp8_backward(Q, K, V, O, dO, Stats, descale_q, descale_k, + descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, + scale_dV, scale_dP, sdpa_backward_options)); } else if (is_mxfp8) { std::tie(dQ, dK, dV, amax_dQ, amax_dK, amax_dV) = std::apply( - [](const auto &...elems) { return std::make_tuple(elems...); }, + [](const auto&... elems) { return std::make_tuple(elems...); }, mha_graph->sdpa_fp8_backward(Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, descale_k_t, descale_v, descale_dO, descale_dO_t, sdpa_backward_options)); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 35f1336d80..328538e8d1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3154,9 +3154,10 @@ def forward( new_qkv_layout = qkv_layout if fp8: if not fp8_recipe.mxfp8(): - q_part, k_part, v_part = [Float8Tensor.make_like( - x, data=y, dtype=fwd_nominal_dtype - ) for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part])] + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] else: q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( qkv_layout, q_part, k_part, v_part, QKV_quantizer @@ -3571,9 +3572,10 @@ def backward(ctx, dout, *_args): # FP8CS+_dpa_fp8_cs_o_in_f16: q/k/v/do in FP8, o in f16 # MXFP8: q/k/v/do all in MXFP8, o/do_f16 in F16 if not ctx.fp8_recipe.mxfp8(): - q_part, k_part, v_part = [Float8Tensor.make_like( - x, data=y, dtype=ctx.fwd_nominal_dtype - ) for x, y in 
zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part])] + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=ctx.fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] if ctx.fp8_recipe.delayed() or ( ctx.fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 4c3439c3a8..1658c92c83 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -346,7 +346,7 @@ def get_attention_backend( attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel cp_comm_type = attention_params.cp_comm_type - cp_size = attention_params.cp_size # pylint: disable=unused-variable + cp_size = attention_params.cp_size # pylint: disable=unused-variable deterministic = attention_params.deterministic is_training = attention_params.is_training fp8 = attention_params.fp8 @@ -768,13 +768,19 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) use_flash_attention = False if fp8 and fp8_recipe.fp8_dpa: - if use_fused_attention and (device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0)): - logger.debug("Disabling FusedAttention for softmax_type = %s in FP8 on sm < 100 with cuDNN" - " version < 9.21", softmax_type) + if use_fused_attention and ( + device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0) + ): + logger.debug( + "Disabling FusedAttention for softmax_type = %s in FP8 on sm < 100 with cuDNN" + " version < 9.21", + softmax_type, + ) use_fused_attention = False if use_unfused_attention: logger.debug( - "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type + "Disabling 
UnfusedDotProductAttention for softmax_type = %s in FP8", + softmax_type, ) use_unfused_attention = False if qkv_format == "thd" and cudnn_version < (9, 18, 0): @@ -965,10 +971,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha) and (device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0)): + if ( + fp8 + and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha) + and (device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0)) + ): logger.debug( - "Disabling FusedAttention as it does not support sliding window attention for FP8 on sm < 100 with cuDNN" - " version < 9.21" + "Disabling FusedAttention as it does not support sliding window attention for FP8" + " on sm < 100 with cuDNN version < 9.21" ) use_fused_attention = False elif attention_dropout != 0.0: From f0bf68015e8fb93461a6408a0a5b770bdb9bd696 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 25 Mar 2026 15:26:19 -0700 Subject: [PATCH 132/172] qdq dO in bwd shadow f16 path Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/backends.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 53d57c56b4..8e8f539228 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -162,6 +162,7 @@ _replace_dq_with_shadow_f16 = os.getenv("NVTE_REPLACE_DQ_WITH_SHADOW_F16", "0") == "1" _replace_dk_with_shadow_f16 = os.getenv("NVTE_REPLACE_DK_WITH_SHADOW_F16", "0") == "1" 
_replace_dv_with_shadow_f16 = os.getenv("NVTE_REPLACE_DV_WITH_SHADOW_F16", "0") == "1" +_qdq_dO_in_bwd = os.getenv("NVTE_QDQ_DO_IN_BWD", "0") == "1" class FP8EmulationFunc(torch.autograd.Function): @@ -1614,6 +1615,16 @@ def forward( def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring d_out_shadow_f16 = d_out + if ctx.fp8 and _run_shadow_f16_bwd and _qdq_dO_in_bwd and ctx.fp8_recipe.mxfp8(): + d_out_shadow_f16, _ = dpa_utils.permute_to_grouped_tensor(ctx.o_format, d_out_shadow_f16) + tmp_quantizer = ctx.dO_quantizer.copy() + tmp_quantizer.optimize_for_gemm = False + d_out_shadow_fp8 = tmp_quantizer(d_out_shadow_f16) + d_out_shadow_f16 = d_out_shadow_fp8.dequantize(dtype=ctx.nominal_dtype) + if ctx.o_format == "bshd": + d_out_shadow_f16 = d_out_shadow_f16.permute(0, 2, 1, 3).contiguous() + elif ctx.o_format == "sbhd": + d_out_shadow_f16 = d_out_shadow_f16.permute(2, 0, 1, 3).contiguous() # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 From 2d80d38c914e6307f35614f26a6c54f5a240c379 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 25 Mar 2026 16:17:18 -0700 Subject: [PATCH 133/172] tweak qdq dO logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 38 +++++++++++++------ 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 8e8f539228..59f538871c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -162,7 +162,8 @@ _replace_dq_with_shadow_f16 = os.getenv("NVTE_REPLACE_DQ_WITH_SHADOW_F16", "0") == "1" _replace_dk_with_shadow_f16 = os.getenv("NVTE_REPLACE_DK_WITH_SHADOW_F16", "0") 
== "1" _replace_dv_with_shadow_f16 = os.getenv("NVTE_REPLACE_DV_WITH_SHADOW_F16", "0") == "1" -_qdq_dO_in_bwd = os.getenv("NVTE_QDQ_DO_IN_BWD", "0") == "1" +_qdq_dO_in_mxfp8_bprop = os.getenv("NVTE_QDQ_DO_IN_MXFP8_BPROP", "0") == "1" +_qdq_dO_in_f16_bprop = os.getenv("NVTE_QDQ_DO_IN_F16_BPROP", "0") == "1" class FP8EmulationFunc(torch.autograd.Function): @@ -1615,16 +1616,31 @@ def forward( def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring d_out_shadow_f16 = d_out - if ctx.fp8 and _run_shadow_f16_bwd and _qdq_dO_in_bwd and ctx.fp8_recipe.mxfp8(): - d_out_shadow_f16, _ = dpa_utils.permute_to_grouped_tensor(ctx.o_format, d_out_shadow_f16) - tmp_quantizer = ctx.dO_quantizer.copy() - tmp_quantizer.optimize_for_gemm = False - d_out_shadow_fp8 = tmp_quantizer(d_out_shadow_f16) - d_out_shadow_f16 = d_out_shadow_fp8.dequantize(dtype=ctx.nominal_dtype) - if ctx.o_format == "bshd": - d_out_shadow_f16 = d_out_shadow_f16.permute(0, 2, 1, 3).contiguous() - elif ctx.o_format == "sbhd": - d_out_shadow_f16 = d_out_shadow_f16.permute(2, 0, 1, 3).contiguous() + + d_out_qdq_f16 = d_out + d_out_qdq_f16, _ = dpa_utils.permute_to_grouped_tensor(ctx.o_format, d_out_qdq_f16) + tmp_quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True) + tmp_quantizer.optimize_for_gemm = False + d_out_qdq_fp8 = tmp_quantizer(d_out_qdq_f16) + d_out_qdq_f16 = d_out_qdq_fp8.dequantize(dtype=ctx.nominal_dtype) + if ctx.o_format == "bshd": + d_out_qdq_f16 = d_out_qdq_f16.permute(0, 2, 1, 3).contiguous() + elif ctx.o_format == "sbhd": + d_out_qdq_f16 = d_out_qdq_f16.permute(2, 0, 1, 3).contiguous() + swapped_do_with_qdq_do = False + if ctx.fp8 and _qdq_dO_in_mxfp8_bprop: + d_out = d_out_qdq_f16 + swapped_do_with_qdq_do = True + if ctx.fp8 and _qdq_dO_in_mxfp8_bprop and _run_shadow_f16_bwd: + d_out_shadow_f16 = d_out_qdq_f16 + swapped_do_with_qdq_do = True + if not ctx.fp8 and _qdq_dO_in_f16_bprop: + d_out = d_out_qdq_f16 + swapped_do_with_qdq_do = 
True + if swapped_do_with_qdq_do: + print(f"swapped, {ctx.fp8=},{_qdq_dO_in_mxfp8_bprop=}, {_qdq_dO_in_f16_bprop=}, {_run_shadow_f16_bwd=}, {_replace_dq_with_shadow_f16=}, {_replace_dk_with_shadow_f16=}, {_replace_dv_with_shadow_f16=}") + else: + print(f"not swapped, {ctx.fp8=}, {_qdq_dO_in_mxfp8_bprop=}, {_qdq_dO_in_f16_bprop=}, {_run_shadow_f16_bwd=}, {_replace_dq_with_shadow_f16=}, {_replace_dk_with_shadow_f16=}, {_replace_dv_with_shadow_f16=}") # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 From 0cf973842de005b104d3e5f1dddb31e02ed414e9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 25 Mar 2026 16:19:07 -0700 Subject: [PATCH 134/172] remove prints in shadow paths Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 59f538871c..215ec8a078 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1376,17 +1376,17 @@ def forward( return_max_logit, is_graph_capturing(), ) - if torch.cuda.current_device() == 0: - print( - f"L{layer_number}: real/shadow out min:" - f" {out_.min():.4f}/{out_f16_.min():.4f}, max:" - f" {out_.max():.4f}/{out_f16_.max():.4f}" - ) - print( - f"L{layer_number}: real/shadow stats min:" - f" {aux_ctx_tensors[0].min():.4f}/{aux_ctx_tensors_f16[0].min():.4f}, max:" - f" {aux_ctx_tensors[0].max():.4f}/{aux_ctx_tensors_f16[0].max():.4f}" - ) + # if torch.cuda.current_device() == 0: + # print( + # f"L{layer_number}: real/shadow out min:" + # f" {out_.min():.4f}/{out_f16_.min():.4f}, max:" + # f" 
{out_.max():.4f}/{out_f16_.max():.4f}" + # ) + # print( + # f"L{layer_number}: real/shadow stats min:" + # f" {aux_ctx_tensors[0].min():.4f}/{aux_ctx_tensors_f16[0].min():.4f}, max:" + # f" {aux_ctx_tensors[0].max():.4f}/{aux_ctx_tensors_f16[0].max():.4f}" + # ) # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 @@ -1637,10 +1637,10 @@ def backward(ctx, d_out, *_args): if not ctx.fp8 and _qdq_dO_in_f16_bprop: d_out = d_out_qdq_f16 swapped_do_with_qdq_do = True - if swapped_do_with_qdq_do: - print(f"swapped, {ctx.fp8=},{_qdq_dO_in_mxfp8_bprop=}, {_qdq_dO_in_f16_bprop=}, {_run_shadow_f16_bwd=}, {_replace_dq_with_shadow_f16=}, {_replace_dk_with_shadow_f16=}, {_replace_dv_with_shadow_f16=}") - else: - print(f"not swapped, {ctx.fp8=}, {_qdq_dO_in_mxfp8_bprop=}, {_qdq_dO_in_f16_bprop=}, {_run_shadow_f16_bwd=}, {_replace_dq_with_shadow_f16=}, {_replace_dk_with_shadow_f16=}, {_replace_dv_with_shadow_f16=}") + # if swapped_do_with_qdq_do: + # print(f"swapped, {ctx.fp8=},{_qdq_dO_in_mxfp8_bprop=}, {_qdq_dO_in_f16_bprop=}, {_run_shadow_f16_bwd=}, {_replace_dq_with_shadow_f16=}, {_replace_dk_with_shadow_f16=}, {_replace_dv_with_shadow_f16=}") + # else: + # print(f"not swapped, {ctx.fp8=}, {_qdq_dO_in_mxfp8_bprop=}, {_qdq_dO_in_f16_bprop=}, {_run_shadow_f16_bwd=}, {_replace_dq_with_shadow_f16=}, {_replace_dk_with_shadow_f16=}, {_replace_dv_with_shadow_f16=}") # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1851,22 +1851,22 @@ def backward(ctx, d_out, *_args): dk_ = dk_shadow_f16 if _replace_dv_with_shadow_f16: dv_ = dv_shadow_f16 - if torch.cuda.current_device() == 0: - print( - f"L{ctx.layer_number}: real/shadow dq min:" - f" {dq_.min():.4f}/{dq_shadow_f16.min():.4f}, max:" - f" {dq_.max():.4f}/{dq_shadow_f16.max():.4f}" - ) - print( - f"L{ctx.layer_number}: real/shadow dk min:" - f" {dk_.min():.4f}/{dk_shadow_f16.min():.4f}, max:" 
- f" {dk_.max():.4f}/{dk_shadow_f16.max():.4f}" - ) - print( - f"L{ctx.layer_number}: real/shadow dv min:" - f" {dv_.min():.4f}/{dv_shadow_f16.min():.4f}, max:" - f" {dv_.max():.4f}/{dv_shadow_f16.max():.4f}" - ) + # if torch.cuda.current_device() == 0: + # print( + # f"L{ctx.layer_number}: real/shadow dq min:" + # f" {dq_.min():.4f}/{dq_shadow_f16.min():.4f}, max:" + # f" {dq_.max():.4f}/{dq_shadow_f16.max():.4f}" + # ) + # print( + # f"L{ctx.layer_number}: real/shadow dk min:" + # f" {dk_.min():.4f}/{dk_shadow_f16.min():.4f}, max:" + # f" {dk_.max():.4f}/{dk_shadow_f16.max():.4f}" + # ) + # print( + # f"L{ctx.layer_number}: real/shadow dv min:" + # f" {dv_.min():.4f}/{dv_shadow_f16.min():.4f}, max:" + # f" {dv_.max():.4f}/{dv_shadow_f16.max():.4f}" + # ) # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ From 813d39d2537deb0d4a502dbf5cf5bb1fe856ac31 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 25 Mar 2026 16:39:25 -0700 Subject: [PATCH 135/172] update FE to allow non-determinism Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 562c25b493..28ebf81b5e 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 562c25b493ea6965d6997d8620b490c8d9ef2fcb +Subproject commit 28ebf81b5e19b92045d1af0fd4c0b9c4599c2b53 From bdc0c471aa3a08e414b98f0d68445d414f023522 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 25 Mar 2026 21:11:05 -0700 Subject: [PATCH 136/172] fuse qkv transposes; first pass Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/flash_attn.cu | 175 ++++++++++++++++++ .../include/transformer_engine/fused_attn.h | 28 +++ .../dot_product_attention/backends.py 
| 1 + .../attention/dot_product_attention/utils.py | 48 ++++- transformer_engine/pytorch/csrc/extensions.h | 7 + .../pytorch/csrc/extensions/attention.cpp | 102 ++++++++++ .../pytorch/csrc/extensions/pybind.cpp | 7 + 7 files changed, 367 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 6c66746e62..c01d106201 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -4,6 +4,8 @@ * See LICENSE for license information. ************************************************************************/ +#include + #include "../common.h" #include "transformer_engine/fused_attn.h" @@ -133,6 +135,156 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream NVTE_CHECK_CUDA(cudaGetLastError()); } +template +__launch_bounds__(block_size) __global__ void permute_to_grouped_tensor_fwd_kernel( + T *q, T *k, T *v, T *q_out, T *k_out, T *v_out, + size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, + NVTE_QKV_Layout original_layout) { + const int which_tensor = blockIdx.y; + T *tensor_in = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); + T *tensor_out = which_tensor == 0 ? q_out : (which_tensor == 1 ? k_out : v_out); + const size_t s = which_tensor == 0 ? s_q : s_kv; + const size_t h = which_tensor == 0 ? h_q : h_kv; + const size_t d = which_tensor == 0 ? d_qk : (which_tensor == 1 ? 
d_qk : d_v); + + const size_t warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; + size_t s_i, b_i; + if (original_layout == NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) { + s_i = warpid / b; + b_i = warpid % b; + } else if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + b_i = warpid / s; + s_i = warpid % s; + } else { + return; + } + if (s_i >= s) return; + if (b_i >= b) return; + + const T *input_base; + if (original_layout == NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) { + input_base = tensor_in + (s_i * b + b_i) * h * d; + } else { + input_base = tensor_in + (b_i * s + s_i) * h * d; + } + + const size_t id_in_warp = threadIdx.x % warp_size; + // SBHD/BSHD [..,H,D] -> BHSD out[b,h,s,d] = in[...,h,d] with (s_i,b_i) fixed per warp. + for (int jj = 0; jj < static_cast(d); jj += load_size) { + const size_t d_off = static_cast(jj) + id_in_warp * nvec; + if (d_off + nvec > d) continue; + for (int i = 0; i < static_cast(h); ++i) { + const T *input_ptr = input_base + static_cast(i) * d + d_off; + T *output_ptr = tensor_out + b_i * h * s * d + static_cast(i) * s * d + s_i * d + d_off; + *reinterpret_cast(output_ptr) = *reinterpret_cast(input_ptr); + } + } +} + +template +__launch_bounds__(block_size) __global__ void permute_to_grouped_tensor_bwd_kernel( + T *grad_q, T *grad_k, T *grad_v, T *q, T *k, T *v, + size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, + NVTE_QKV_Layout original_layout) { + const int which_tensor = blockIdx.y; + T *tensor_in = which_tensor == 0 ? grad_q : (which_tensor == 1 ? grad_k : grad_v); + T *tensor_out = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); + const size_t s = which_tensor == 0 ? s_q : s_kv; + const size_t h = which_tensor == 0 ? h_q : h_kv; + const size_t d = which_tensor == 0 ? d_qk : (which_tensor == 1 ? 
d_qk : d_v); + + const size_t warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; + size_t s_i, b_i; + if (original_layout == NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) { + s_i = warpid / b; + b_i = warpid % b; + } else if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + b_i = warpid / s; + s_i = warpid % s; + } else { + return; + } + if (s_i >= s) return; + if (b_i >= b) return; + + T *output_base; + if (original_layout == NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) { + output_base = tensor_out + (s_i * b + b_i) * h * d; + } else { + output_base = tensor_out + (b_i * s + s_i) * h * d; + } + + const size_t id_in_warp = threadIdx.x % warp_size; + for (int jj = 0; jj < static_cast(d); jj += load_size) { + const size_t d_off = static_cast(jj) + id_in_warp * nvec; + if (d_off + nvec > d) continue; + for (int i = 0; i < static_cast(h); ++i) { + const T *input_ptr = + tensor_in + b_i * h * s * d + static_cast(i) * s * d + s_i * d + d_off; + T *output_ptr = output_base + static_cast(i) * d + d_off; + *reinterpret_cast(output_ptr) = *reinterpret_cast(input_ptr); + } + } +} + +void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, Tensor k_out, + Tensor v_out, NVTE_QKV_Layout original_layout, cudaStream_t stream) { + using namespace transformer_engine; + size_t b=0, s_q=0, s_kv=0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; + b = q_out.shape()[0]; + h_q = q_out.shape()[1]; + s_q = q_out.shape()[2]; + d_qk = q_out.shape()[3]; + h_kv = k_out.shape()[1]; + s_kv = k_out.shape()[2]; + d_v = v_out.shape()[3]; + + size_t warps = b * std::max(s_q, s_kv); + size_t warps_per_block = block_size / warp_size; + size_t blocks = (warps + warps_per_block - 1) / warps_per_block; + dim3 grid(blocks, 3); + int threads = block_size; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + q.dtype(), dtype, + permute_to_grouped_tensor_fwd_kernel<<>>( + reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), 
reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, original_layout); + ); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, + Tensor q, Tensor k, Tensor v, + NVTE_QKV_Layout original_layout, cudaStream_t stream) { + using namespace transformer_engine; + size_t b=0, s_q=0, s_kv=0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; + b = grad_q.shape()[0]; + h_q = grad_q.shape()[1]; + s_q = grad_q.shape()[2]; + d_qk = grad_q.shape()[3]; + h_kv = grad_k.shape()[1]; + s_kv = grad_k.shape()[2]; + d_v = grad_v.shape()[3]; + + size_t warps = b * std::max(s_q, s_kv); + size_t warps_per_block = block_size / warp_size; + size_t blocks = (warps + warps_per_block - 1) / warps_per_block; + dim3 grid(blocks, 3); + int threads = block_size; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + grad_q.dtype(), dtype, + permute_to_grouped_tensor_bwd_kernel<<>>( + reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), reinterpret_cast(v.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, original_layout); + ); + NVTE_CHECK_CUDA(cudaGetLastError()); +} } // namespace flash_attention } // namespace transformer_engine @@ -153,3 +305,26 @@ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTET *convertNVTETensorCheck(v), *convertNVTETensorCheck(qkv), stream); } + +void nvte_permute_to_grouped_tensor_fwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor q_out, + NVTETensor k_out, NVTETensor v_out, NVTE_QKV_Layout original_layout, + cudaStream_t stream) { + NVTE_API_CALL(nvte_permute_to_grouped_tensor_fwd); + using namespace transformer_engine; + + flash_attention::permute_to_grouped_tensor_fwd( + *convertNVTETensorCheck(q), *convertNVTETensorCheck(k), *convertNVTETensorCheck(v), + 
*convertNVTETensorCheck(q_out), *convertNVTETensorCheck(k_out), + *convertNVTETensorCheck(v_out), original_layout, stream); +} + +void nvte_permute_to_grouped_tensor_bwd(NVTETensor grad_q, NVTETensor grad_k, NVTETensor grad_v, + NVTETensor q, NVTETensor k, NVTETensor v, NVTE_QKV_Layout original_layout, + cudaStream_t stream) { + NVTE_API_CALL(nvte_permute_to_grouped_tensor_bwd); + using namespace transformer_engine; + + flash_attention::permute_to_grouped_tensor_bwd( + *convertNVTETensorCheck(grad_q), *convertNVTETensorCheck(grad_k), *convertNVTETensorCheck(grad_v), + *convertNVTETensorCheck(q), *convertNVTETensorCheck(k), *convertNVTETensorCheck(v), original_layout, stream); +} \ No newline at end of file diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 012492dab7..c3c4213c24 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -608,6 +608,34 @@ void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t s void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv, cudaStream_t stream); +/*! \brief Permute Q, K, V to grouped tensors. + * + * \param[in] q Query tensor + * \param[in] k Key tensor + * \param[in] v Value tensor + * \param[out] q_out Output query tensor + * \param[out] k_out Output key tensor + * \param[out] v_out Output value tensor + * \param[in] original_layout Original QKV layout. + * \param[in] stream CUDA stream. + */ +void nvte_permute_to_grouped_tensor_fwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, NVTE_QKV_Layout original_layout, cudaStream_t stream); + +/*! \brief Permute Q, K, V back to original layout. 
+ * + * \param[in] grad_q Gradient of query tensor + * \param[in] grad_k Gradient of key tensor + * \param[in] grad_v Gradient of value tensor + * \param[out] q Original query tensor + * \param[out] k Original key tensor + * \param[out] v Original value tensor + * \param[in] original_layout Original QKV layout. + * \param[in] stream CUDA stream. + */ +void nvte_permute_to_grouped_tensor_bwd(NVTETensor grad_q, NVTETensor grad_k, NVTETensor grad_v, + NVTETensor q, NVTETensor k, NVTETensor v, + NVTE_QKV_Layout original_layout, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 215ec8a078..392c1c8c0d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -46,6 +46,7 @@ FusedAttnBackend, META_O, META_QKV, + QKVLayout, ) from transformer_engine.pytorch.quantization import get_fp8_torch_dtype, FP8GlobalStateManager from transformer_engine.pytorch.distributed import get_distributed_world_size diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 1658c92c83..6627c55b49 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2317,16 +2317,62 @@ def permute_to_grouped_tensor(src_format, tensor): return tensor, des_format +class PermuteToGroupedTensor(torch.autograd.Function): + """Permute Q, K, V from {bshd_bshd_bshd, sbhd_sbhd_sbhd} to bhsd_bhsd_bhsd.""" + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + input_layout: str = "bshd_bshd_bshd", + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + 
# pylint: disable=missing-function-docstring + ctx.original_layout = QKVLayout[input_layout] + return tex.permute_to_grouped_tensor_fwd(query, key, value, ctx.original_layout) + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + query_grad: torch.Tensor, + key_grad: torch.Tensor, + value_grad: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # pylint: disable=missing-function-docstring + q, k, v = tex.permute_to_grouped_tensor_bwd( + query_grad, + key_grad, + value_grad, + ctx.original_layout, + ) + return q, k, v, None + + def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" if isinstance(qkv_quantizer, MXFP8Quantizer): qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) - # permute q, k, v to bhsd/htd format + # q_orig, k_orig, v_orig = q, k, v + # # permute q, k, v to bhsd/htd format + # if qkv_layout in ["bshd_bshd_bshd", "sbhd_sbhd_sbhd"]: + # print(f">>>>>>>>>>>> {qkv_layout} PermuteToGroupedTensor") + # q, k, v = PermuteToGroupedTensor.apply(q, k, v, qkv_layout) + # # else: + # if q_format not in ["bhsd", "htd"]: + # q_, _ = permute_to_grouped_tensor(q_format, q_orig) + # if kv_format not in ["bhsd", "htd"]: + # k_, _ = permute_to_grouped_tensor(kv_format, k_orig) + # v_, _ = permute_to_grouped_tensor(kv_format, v_orig) + # torch.testing.assert_close(q_, q) + # torch.testing.assert_close(k_, k) + # torch.testing.assert_close(v_, v) if q_format not in ["bhsd", "htd"]: q, _ = permute_to_grouped_tensor(q_format, q) if kv_format not in ["bhsd", "htd"]: k, _ = permute_to_grouped_tensor(kv_format, k) v, _ = permute_to_grouped_tensor(kv_format, v) + qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" # check shapes original_shapes = [x.shape for x in [q, k, v]] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e8b588f0ee..ccb1f286b7 100644 --- 
a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -112,6 +112,13 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); +std::tuple permute_to_grouped_tensor_fwd(at::Tensor query, + at::Tensor key, + at::Tensor value, + NVTE_QKV_Layout input_layout); +std::tuple permute_to_grouped_tensor_bwd( + at::Tensor query_grad, at::Tensor key_grad, at::Tensor value_grad, NVTE_QKV_Layout input_layout); + at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len); at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index d76e29964a..463ad65e06 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -647,6 +647,108 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } +std::tuple permute_to_grouped_tensor_fwd(at::Tensor query, + at::Tensor key, + at::Tensor value, + NVTE_QKV_Layout original_layout) { + NVTE_CHECK(original_layout == NVTE_SBHD_SBHD_SBHD || original_layout == NVTE_BSHD_BSHD_BSHD, + "permute_to_grouped_tensor_fwd: original_layout must be NVTE_SBHD_SBHD_SBHD or NVTE_BSHD_BSHD_BSHD."); + NVTE_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda()); + NVTE_CHECK(query.is_contiguous() && key.is_contiguous() && value.is_contiguous()); + NVTE_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4); + NVTE_CHECK(query.scalar_type() == at::ScalarType::Half || + query.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(key.scalar_type() == query.scalar_type() && value.scalar_type() == query.scalar_type()); + + int64_t B = 0; 
+ int64_t S_q = 0, H_q = 0, D_qk = 0; + int64_t S_kv = 0, H_kv = 0, D_v = 0; + if (original_layout == NVTE_SBHD_SBHD_SBHD) { + S_q = query.size(0); + B = query.size(1); + H_q = query.size(2); + D_qk = query.size(3); + S_kv = key.size(0); + H_kv = key.size(2); + D_v = value.size(3); + } else { + B = query.size(0); + S_q = query.size(1); + H_q = query.size(2); + D_qk = query.size(3); + S_kv = key.size(1); + H_kv = key.size(2); + D_v = value.size(3); + } + NVTE_CHECK(key.size(original_layout == NVTE_SBHD_SBHD_SBHD ? 1 : 0) == B && + value.size(original_layout == NVTE_SBHD_SBHD_SBHD ? 1 : 0) == B, + "permute_to_grouped_tensor_fwd: Q/K/V batch dimension must match."); + + at::Tensor q_out = at::empty({B, H_q, S_q, D_qk}, query.options()); + at::Tensor k_out = at::empty({B, H_kv, S_kv, D_qk}, key.options()); + at::Tensor v_out = at::empty({B, H_kv, S_kv, D_v}, value.options()); + + auto te_q = makeTransformerEngineTensor(query); + auto te_k = makeTransformerEngineTensor(key); + auto te_v = makeTransformerEngineTensor(value); + auto te_qo = makeTransformerEngineTensor(q_out); + auto te_ko = makeTransformerEngineTensor(k_out); + auto te_vo = makeTransformerEngineTensor(v_out); + + nvte_permute_to_grouped_tensor_fwd( + te_q.data(), te_k.data(), te_v.data(), te_qo.data(), te_ko.data(), te_vo.data(), + original_layout, at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(q_out, k_out, v_out); +} + +std::tuple permute_to_grouped_tensor_bwd( + at::Tensor query_grad, at::Tensor key_grad, at::Tensor value_grad, NVTE_QKV_Layout original_layout) { + NVTE_CHECK(original_layout == NVTE_SBHD_SBHD_SBHD || original_layout == NVTE_BSHD_BSHD_BSHD, + "permute_to_grouped_tensor_bwd: original_layout must be NVTE_SBHD_SBHD_SBHD or NVTE_BSHD_BSHD_BSHD."); + NVTE_CHECK(query_grad.is_cuda() && key_grad.is_cuda() && value_grad.is_cuda()); + NVTE_CHECK(query_grad.is_contiguous() && key_grad.is_contiguous() && value_grad.is_contiguous()); + NVTE_CHECK(query_grad.dim() == 4 && 
key_grad.dim() == 4 && value_grad.dim() == 4); + NVTE_CHECK(query_grad.scalar_type() == at::ScalarType::Half || + query_grad.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(key_grad.scalar_type() == query_grad.scalar_type() && + value_grad.scalar_type() == query_grad.scalar_type()); + + const int64_t B = query_grad.size(0); + const int64_t H_q = query_grad.size(1); + const int64_t S_q = query_grad.size(2); + const int64_t D_qk = query_grad.size(3); + const int64_t H_kv = key_grad.size(1); + const int64_t S_kv = key_grad.size(2); + const int64_t D_v = value_grad.size(3); + + at::Tensor query; + at::Tensor key; + at::Tensor value; + if (original_layout == NVTE_SBHD_SBHD_SBHD) { + query = at::empty({S_q, B, H_q, D_qk}, query_grad.options()); + key = at::empty({S_kv, B, H_kv, D_qk}, key_grad.options()); + value = at::empty({S_kv, B, H_kv, D_v}, value_grad.options()); + } else { + query = at::empty({B, S_q, H_q, D_qk}, query_grad.options()); + key = at::empty({B, S_kv, H_kv, D_qk}, key_grad.options()); + value = at::empty({B, S_kv, H_kv, D_v}, value_grad.options()); + } + + auto te_gq = makeTransformerEngineTensor(query_grad); + auto te_gk = makeTransformerEngineTensor(key_grad); + auto te_gv = makeTransformerEngineTensor(value_grad); + auto te_q = makeTransformerEngineTensor(query); + auto te_k = makeTransformerEngineTensor(key); + auto te_v = makeTransformerEngineTensor(value); + + nvte_permute_to_grouped_tensor_bwd( + te_gq.data(), te_gk.data(), te_gv.data(), te_q.data(), te_k.data(), te_v.data(), + original_layout, at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(query, key, value); +} + /*************************************************************************************************** * Support THD format for Context Parallel: Read the half of a THD tensor **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp 
b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c590a3c9e2..58e27c875b 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -394,6 +394,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fa_prepare_bwd", &transformer_engine::pytorch::fa_prepare_bwd, "Backward of QKV preparation for Flash Attention", py::call_guard()); + m.def("permute_to_grouped_tensor_fwd", &transformer_engine::pytorch::permute_to_grouped_tensor_fwd, + "Permute Q, K, V to grouped tensors.", + py::arg("query"), py::arg("key"), py::arg("value"), py::arg("original_layout"), + py::call_guard()); + m.def("permute_to_grouped_tensor_bwd", &transformer_engine::pytorch::permute_to_grouped_tensor_bwd, + "Permute Q, K, V back to original layout.", py::arg("query_grad"), py::arg("key_grad"), + py::arg("value_grad"), py::arg("original_layout"), py::call_guard()); m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd, "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd, From e69a06a42dbcb601a40564c4c059f8f87a72e0ce Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Mar 2026 13:50:07 -0700 Subject: [PATCH 137/172] remap parallelism to grid(bh, splits, 3) block(s/splits x d); use nvec = 128 bits Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/flash_attn.cu | 312 +++++++++++------- 1 file changed, 200 insertions(+), 112 deletions(-) diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index c01d106201..34340babda 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -4,17 +4,22 @@ * See LICENSE for license information. 
************************************************************************/ -#include - #include "../common.h" #include "transformer_engine/fused_attn.h" namespace transformer_engine { namespace flash_attention { +/// Packed vector of N elements of T; alignment matches a single wide load/store of N * sizeof(T) bytes. +template +struct alignas(sizeof(T) * N) Vec { + T data[N]; +}; + constexpr int warp_size = 32; constexpr int type_size = 2; // FP16 or BF16 constexpr int nvec = sizeof(uint64_t) / type_size; +constexpr int nvec128 = sizeof(uint4) / type_size; constexpr int load_size = warp_size * nvec; constexpr int block_size = 512; @@ -37,8 +42,8 @@ __launch_bounds__(block_size) __global__ T *my_output = qkv + offset_output; for (int i = 0; i < Z; ++i) { - uint64_t *out = reinterpret_cast(my_output + i * load_size); - *out = *reinterpret_cast(my_input + i * load_size * 3); + Vec *const out = reinterpret_cast *>(my_output + i * load_size); + *out = *reinterpret_cast *>(my_input + i * load_size * 3); } } @@ -63,8 +68,8 @@ __launch_bounds__(block_size) __global__ T *my_output = qkv + offset_output; for (int i = 0; i < Z; ++i) { - uint64_t *out = reinterpret_cast(my_output + i * load_size * 3); - *out = *reinterpret_cast(my_input + i * load_size); + Vec *const out = reinterpret_cast *>(my_output + i * load_size * 3); + *out = *reinterpret_cast *>(my_input + i * load_size); } } @@ -135,94 +140,150 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream NVTE_CHECK_CUDA(cudaGetLastError()); } -template -__launch_bounds__(block_size) __global__ void permute_to_grouped_tensor_fwd_kernel( - T *q, T *k, T *v, T *q_out, T *k_out, T *v_out, - size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, - NVTE_QKV_Layout original_layout) { - const int which_tensor = blockIdx.y; - T *tensor_in = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); - T *tensor_out = which_tensor == 0 ? q_out : (which_tensor == 1 ? 
k_out : v_out); - const size_t s = which_tensor == 0 ? s_q : s_kv; - const size_t h = which_tensor == 0 ? h_q : h_kv; - const size_t d = which_tensor == 0 ? d_qk : (which_tensor == 1 ? d_qk : d_v); - - const size_t warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; - size_t s_i, b_i; - if (original_layout == NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) { - s_i = warpid / b; - b_i = warpid % b; - } else if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { - b_i = warpid / s; - s_i = warpid % s; - } else { - return; - } - if (s_i >= s) return; - if (b_i >= b) return; +template +__launch_bounds__(1024) __global__ void permute_to_grouped_tensor_fwd_kernel( + const T *__restrict__ q, const T *__restrict__ k, const T *__restrict__ v, T *__restrict__ q_out, + T *__restrict__ k_out, T *__restrict__ v_out, size_t b, size_t s_q, size_t h_q, size_t d_qk, + size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits) { + const int which_tensor = blockIdx.z; + const T *__restrict__ tensor_in = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); + T *__restrict__ tensor_out = which_tensor == 0 ? q_out : (which_tensor == 1 ? k_out : v_out); + const size_t Sdim = which_tensor == 0 ? s_q : s_kv; + const size_t Hdim = which_tensor == 0 ? h_q : h_kv; + const size_t Ddim = which_tensor == 0 ? d_qk : (which_tensor == 1 ? d_qk : d_v); + + const size_t h_grid = h_q > h_kv ? h_q : h_kv; + const size_t b_i = static_cast(blockIdx.x) / h_grid; + const size_t h_i = static_cast(blockIdx.x) % h_grid; - const T *input_base; - if (original_layout == NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) { - input_base = tensor_in + (s_i * b + b_i) * h * d; + if (b_i >= b) return; + if (which_tensor == 0) { + if (h_i >= h_q) return; } else { - input_base = tensor_in + (b_i * s + s_i) * h * d; + if (h_i >= h_kv) return; } - - const size_t id_in_warp = threadIdx.x % warp_size; - // SBHD/BSHD [..,H,D] -> BHSD out[b,h,s,d] = in[...,h,d] with (s_i,b_i) fixed per warp. 
- for (int jj = 0; jj < static_cast(d); jj += load_size) { - const size_t d_off = static_cast(jj) + id_in_warp * nvec; - if (d_off + nvec > d) continue; - for (int i = 0; i < static_cast(h); ++i) { - const T *input_ptr = input_base + static_cast(i) * d + d_off; - T *output_ptr = tensor_out + b_i * h * s * d + static_cast(i) * s * d + s_i * d + d_off; - *reinterpret_cast(output_ptr) = *reinterpret_cast(input_ptr); + if (Ddim % static_cast(nvec) != 0) return; + + const unsigned int s_part = blockIdx.y; + const size_t s_begin = (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); + const size_t S_chunk = s_end - s_begin; + + const size_t in_base = kIsBshdBshdBshd ? b_i * Sdim * Hdim * Ddim : b_i * Hdim * Ddim; + const size_t out_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; + const bool use_vec128 = (Ddim % static_cast(nvec128) == 0) && + ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && + ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); + + if (use_vec128) { + const size_t d_vec = Ddim / static_cast(nvec128); + const size_t total_work = S_chunk * d_vec; + for (size_t w = static_cast(threadIdx.x); w < total_work; w += static_cast(blockDim.x)) { + const size_t s_local = w / d_vec; + const size_t s_i = s_begin + s_local; + const size_t v = w % d_vec; + const size_t d_off = v * static_cast(nvec128); + + const T *__restrict__ in_ptr; + if constexpr (kIsBshdBshdBshd) { + in_ptr = tensor_in + in_base + s_i * Hdim * Ddim + h_i * Ddim + d_off; + } else { + in_ptr = tensor_in + s_i * b * Hdim * Ddim + in_base + h_i * Ddim + d_off; + } + T *__restrict__ out_ptr = tensor_out + out_base + s_i * Ddim + d_off; + *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + } + } else { + const size_t d_vec = Ddim / static_cast(nvec); + const size_t total_work = S_chunk * d_vec; + for (size_t w = static_cast(threadIdx.x); w < total_work; w += static_cast(blockDim.x)) 
{ + const size_t s_local = w / d_vec; + const size_t s_i = s_begin + s_local; + const size_t v = w % d_vec; + const size_t d_off = v * static_cast(nvec); + + const T *__restrict__ in_ptr; + if constexpr (kIsBshdBshdBshd) { + in_ptr = tensor_in + in_base + s_i * Hdim * Ddim + h_i * Ddim + d_off; + } else { + in_ptr = tensor_in + s_i * b * Hdim * Ddim + in_base + h_i * Ddim + d_off; + } + T *__restrict__ out_ptr = tensor_out + out_base + s_i * Ddim + d_off; + *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); } } } -template -__launch_bounds__(block_size) __global__ void permute_to_grouped_tensor_bwd_kernel( - T *grad_q, T *grad_k, T *grad_v, T *q, T *k, T *v, - size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, - NVTE_QKV_Layout original_layout) { - const int which_tensor = blockIdx.y; - T *tensor_in = which_tensor == 0 ? grad_q : (which_tensor == 1 ? grad_k : grad_v); - T *tensor_out = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); - const size_t s = which_tensor == 0 ? s_q : s_kv; - const size_t h = which_tensor == 0 ? h_q : h_kv; - const size_t d = which_tensor == 0 ? d_qk : (which_tensor == 1 ? 
d_qk : d_v); - - const size_t warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; - size_t s_i, b_i; - if (original_layout == NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) { - s_i = warpid / b; - b_i = warpid % b; - } else if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { - b_i = warpid / s; - s_i = warpid % s; - } else { - return; - } - if (s_i >= s) return; - if (b_i >= b) return; +template +__launch_bounds__(1024) __global__ void permute_to_grouped_tensor_bwd_kernel( + const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, + T *__restrict__ q, T *__restrict__ k, T *__restrict__ v, size_t b, size_t s_q, size_t h_q, + size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits) { + const int which_tensor = blockIdx.z; + const T *__restrict__ tensor_in = which_tensor == 0 ? grad_q : (which_tensor == 1 ? grad_k : grad_v); + T *__restrict__ tensor_out = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); + const size_t Sdim = which_tensor == 0 ? s_q : s_kv; + const size_t Hdim = which_tensor == 0 ? h_q : h_kv; + const size_t Ddim = which_tensor == 0 ? d_qk : (which_tensor == 1 ? d_qk : d_v); + + const size_t h_grid = h_q > h_kv ? 
h_q : h_kv; + const size_t b_i = static_cast(blockIdx.x) / h_grid; + const size_t h_i = static_cast(blockIdx.x) % h_grid; - T *output_base; - if (original_layout == NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) { - output_base = tensor_out + (s_i * b + b_i) * h * d; + if (b_i >= b) return; + if (which_tensor == 0) { + if (h_i >= h_q) return; } else { - output_base = tensor_out + (b_i * s + s_i) * h * d; + if (h_i >= h_kv) return; } - - const size_t id_in_warp = threadIdx.x % warp_size; - for (int jj = 0; jj < static_cast(d); jj += load_size) { - const size_t d_off = static_cast(jj) + id_in_warp * nvec; - if (d_off + nvec > d) continue; - for (int i = 0; i < static_cast(h); ++i) { - const T *input_ptr = - tensor_in + b_i * h * s * d + static_cast(i) * s * d + s_i * d + d_off; - T *output_ptr = output_base + static_cast(i) * d + d_off; - *reinterpret_cast(output_ptr) = *reinterpret_cast(input_ptr); + if (Ddim % static_cast(nvec) != 0) return; + + const unsigned int s_part = blockIdx.y; + const size_t s_begin = (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); + const size_t S_chunk = s_end - s_begin; + + const size_t in_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; + const size_t out_base = kIsBshdBshdBshd ? 
b_i * Sdim * Hdim * Ddim + h_i * Ddim : b_i * Hdim * Ddim + h_i * Ddim; + const bool use_vec128 = (Ddim % static_cast(nvec128) == 0) && + ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && + ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); + + if (use_vec128) { + const size_t d_vec = Ddim / static_cast(nvec128); + const size_t total_work = S_chunk * d_vec; + for (size_t w = static_cast(threadIdx.x); w < total_work; w += static_cast(blockDim.x)) { + const size_t s_local = w / d_vec; + const size_t s_i = s_begin + s_local; + const size_t v = w % d_vec; + const size_t d_off = v * static_cast(nvec128); + + const T *__restrict__ in_ptr = tensor_in + in_base + s_i * Ddim + d_off; + T *__restrict__ out_ptr; + if constexpr (kIsBshdBshdBshd) { + out_ptr = tensor_out + out_base + s_i * Hdim * Ddim + d_off; + } else { + out_ptr = tensor_out + s_i * b * Hdim * Ddim + out_base + d_off; + } + *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + } + } else { + const size_t d_vec = Ddim / static_cast(nvec); + const size_t total_work = S_chunk * d_vec; + for (size_t w = static_cast(threadIdx.x); w < total_work; w += static_cast(blockDim.x)) { + const size_t s_local = w / d_vec; + const size_t s_i = s_begin + s_local; + const size_t v = w % d_vec; + const size_t d_off = v * static_cast(nvec); + + const T *__restrict__ in_ptr = tensor_in + in_base + s_i * Ddim + d_off; + T *__restrict__ out_ptr; + if constexpr (kIsBshdBshdBshd) { + out_ptr = tensor_out + out_base + s_i * Hdim * Ddim + d_off; + } else { + out_ptr = tensor_out + s_i * b * Hdim * Ddim + out_base + d_off; + } + *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); } } } @@ -239,20 +300,33 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T s_kv = k_out.shape()[2]; d_v = v_out.shape()[3]; - size_t warps = b * std::max(s_q, s_kv); - size_t warps_per_block = block_size / warp_size; - size_t blocks = (warps + warps_per_block - 1) / warps_per_block; - dim3 
grid(blocks, 3); - int threads = block_size; - - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( - q.dtype(), dtype, - permute_to_grouped_tensor_fwd_kernel<<>>( - reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), reinterpret_cast(q_out.data.dptr), - reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, original_layout); - ); + NVTE_CHECK(d_qk % nvec == 0 && d_v % nvec == 0, + "permute_to_grouped_tensor_fwd: head dim must be divisible by vector width."); + // Split S across grid.y; work out permute_s_splits so S_chunk >= threads + const int threads = 1024; + const size_t s_min = std::min(s_q, s_kv); + const unsigned int permute_s_splits = + std::max(1u, static_cast(s_min / static_cast(threads))); + const size_t h_grid = std::max(h_q, h_kv); + dim3 grid(b * h_grid, permute_s_splits, 3); + + if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + q.dtype(), dtype, + permute_to_grouped_tensor_fwd_kernel<<>>( + reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), b, + s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + q.dtype(), dtype, + permute_to_grouped_tensor_fwd_kernel<<>>( + reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), b, + s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -269,20 +343,34 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, s_kv = grad_k.shape()[2]; d_v = grad_v.shape()[3]; - size_t warps = b * std::max(s_q, s_kv); - size_t warps_per_block = block_size / warp_size; - size_t 
blocks = (warps + warps_per_block - 1) / warps_per_block; - dim3 grid(blocks, 3); - int threads = block_size; - - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( - grad_q.dtype(), dtype, - permute_to_grouped_tensor_bwd_kernel<<>>( - reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), - reinterpret_cast(grad_v.data.dptr), reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), reinterpret_cast(v.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, original_layout); - ); + NVTE_CHECK(d_qk % nvec == 0 && d_v % nvec == 0, + "permute_to_grouped_tensor_bwd: head dim must be divisible by vector width."); + const int threads = 1024; + const size_t s_min = std::min(s_q, s_kv); + const unsigned int permute_s_splits = + std::max(1u, static_cast(s_min / static_cast(threads))); + const size_t h_grid = std::max(h_q, h_kv); + dim3 grid(b * h_grid, permute_s_splits, 3); + + if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + grad_q.dtype(), dtype, + permute_to_grouped_tensor_bwd_kernel<<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), reinterpret_cast(v.data.dptr), b, s_q, h_q, + d_qk, s_kv, h_kv, d_v, permute_s_splits);); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + grad_q.dtype(), dtype, + permute_to_grouped_tensor_bwd_kernel<<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), reinterpret_cast(v.data.dptr), b, s_q, h_q, + d_qk, s_kv, h_kv, d_v, permute_s_splits);); + } NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace flash_attention From aab8856e4e206931abb6c79b9ff81e9e51c69137 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Mar 2026 14:55:06 -0700 Subject: [PATCH 138/172] allocate 
contiguous block for qkv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/csrc/extensions/attention.cpp | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 463ad65e06..54bb731961 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -684,9 +684,13 @@ std::tuple permute_to_grouped_tensor_fwd(at: value.size(original_layout == NVTE_SBHD_SBHD_SBHD ? 1 : 0) == B, "permute_to_grouped_tensor_fwd: Q/K/V batch dimension must match."); - at::Tensor q_out = at::empty({B, H_q, S_q, D_qk}, query.options()); - at::Tensor k_out = at::empty({B, H_kv, S_kv, D_qk}, key.options()); - at::Tensor v_out = at::empty({B, H_kv, S_kv, D_v}, value.options()); + const int64_t numel_q = B * H_q * S_q * D_qk; + const int64_t numel_k = B * H_kv * S_kv * D_qk; + const int64_t numel_v = B * H_kv * S_kv * D_v; + at::Tensor qkv_out_flat = at::empty({numel_q + numel_k + numel_v}, query.options()); + at::Tensor q_out = qkv_out_flat.narrow(0, 0, numel_q).view({B, H_q, S_q, D_qk}); + at::Tensor k_out = qkv_out_flat.narrow(0, numel_q, numel_k).view({B, H_kv, S_kv, D_qk}); + at::Tensor v_out = qkv_out_flat.narrow(0, numel_q + numel_k, numel_v).view({B, H_kv, S_kv, D_v}); auto te_q = makeTransformerEngineTensor(query); auto te_k = makeTransformerEngineTensor(key); @@ -722,17 +726,22 @@ std::tuple permute_to_grouped_tensor_bwd( const int64_t S_kv = key_grad.size(2); const int64_t D_v = value_grad.size(3); + const int64_t numel_q = S_q * B * H_q * D_qk; + const int64_t numel_k = S_kv * B * H_kv * D_qk; + const int64_t numel_v = S_kv * B * H_kv * D_v; + at::Tensor qkv_grad_flat = at::empty({numel_q + numel_k + numel_v}, query_grad.options()); + at::Tensor query; at::Tensor key; at::Tensor value; if (original_layout == 
NVTE_SBHD_SBHD_SBHD) { - query = at::empty({S_q, B, H_q, D_qk}, query_grad.options()); - key = at::empty({S_kv, B, H_kv, D_qk}, key_grad.options()); - value = at::empty({S_kv, B, H_kv, D_v}, value_grad.options()); + query = qkv_grad_flat.narrow(0, 0, numel_q).view({S_q, B, H_q, D_qk}); + key = qkv_grad_flat.narrow(0, numel_q, numel_k).view({S_kv, B, H_kv, D_qk}); + value = qkv_grad_flat.narrow(0, numel_q + numel_k, numel_v).view({S_kv, B, H_kv, D_v}); } else { - query = at::empty({B, S_q, H_q, D_qk}, query_grad.options()); - key = at::empty({B, S_kv, H_kv, D_qk}, key_grad.options()); - value = at::empty({B, S_kv, H_kv, D_v}, value_grad.options()); + query = qkv_grad_flat.narrow(0, 0, numel_q).view({B, S_q, H_q, D_qk}); + key = qkv_grad_flat.narrow(0, numel_q, numel_k).view({B, S_kv, H_kv, D_qk}); + value = qkv_grad_flat.narrow(0, numel_q + numel_k, numel_v).view({B, S_kv, H_kv, D_v}); } auto te_gq = makeTransformerEngineTensor(query_grad); From 78055e4e2d0414f58d2e79696a47fa4dbdb59784 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Mar 2026 14:56:07 -0700 Subject: [PATCH 139/172] fix grouped tensor row/col scale_inv offsets Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../tensor/storage/grouped_tensor_storage.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 68097259c6..7d416bcba9 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -508,7 +508,7 @@ def make_grouped_tensor( total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] for i, s in enumerate(shape): - scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_inv_shape = quantizer.get_scale_shape(s, True) 
columnwise_scale_elements = math.prod(scale_inv_shape) total_columnwise_scale_elements += columnwise_scale_elements columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) @@ -746,15 +746,25 @@ def split_into_quantized_tensors( # populate scale_inv_offsets from the tensor offsets if self.scale_inv is not None and self.scale_inv_offsets is None: - if recipe.nvfp4(): - self.scale_inv_offsets = self.tensor_offsets // 16 - if recipe.mxfp8(): - self.scale_inv_offsets = self.tensor_offsets // 32 + if recipe.nvfp4() or recipe.mxfp8() or recipe.float8_block_scaling(): + cum = 0 + scale_inv_offsets: List[int] = [0] + for i in range(self.num_tensors): + tensor_shape = self.tensor_shapes[i] + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) + cum += math.prod(scale_shape) + scale_inv_offsets.append(cum) + self.scale_inv_offsets = scale_inv_offsets if self.columnwise_scale_inv is not None and self.columnwise_scale_inv_offsets is None: - if recipe.nvfp4(): - self.columnwise_scale_inv_offsets = self.tensor_offsets // 16 - if recipe.mxfp8(): - self.columnwise_scale_inv_offsets = self.tensor_offsets // 32 + if recipe.nvfp4() or recipe.mxfp8() or recipe.float8_block_scaling(): + cum = 0 + columnwise_scale_inv_offsets: List[int] = [0] + for i in range(self.num_tensors): + tensor_shape = self.tensor_shapes[i] + scale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + cum += math.prod(scale_shape) + columnwise_scale_inv_offsets.append(cum) + self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets for i in range(self.num_tensors): quantizer = self.quantizer From d8f9ac9d4e76e40e6ab276b94a8ac3d488ee43d8 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Mar 2026 15:04:53 -0700 Subject: [PATCH 140/172] use fused permute kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/utils.py | 88 +++++++++---------- 1 file changed, 
42 insertions(+), 46 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 6627c55b49..f3f0541ea4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2353,25 +2353,17 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" if isinstance(qkv_quantizer, MXFP8Quantizer): qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) - # q_orig, k_orig, v_orig = q, k, v - # # permute q, k, v to bhsd/htd format - # if qkv_layout in ["bshd_bshd_bshd", "sbhd_sbhd_sbhd"]: - # print(f">>>>>>>>>>>> {qkv_layout} PermuteToGroupedTensor") - # q, k, v = PermuteToGroupedTensor.apply(q, k, v, qkv_layout) - # # else: - # if q_format not in ["bhsd", "htd"]: - # q_, _ = permute_to_grouped_tensor(q_format, q_orig) - # if kv_format not in ["bhsd", "htd"]: - # k_, _ = permute_to_grouped_tensor(kv_format, k_orig) - # v_, _ = permute_to_grouped_tensor(kv_format, v_orig) - # torch.testing.assert_close(q_, q) - # torch.testing.assert_close(k_, k) - # torch.testing.assert_close(v_, v) - if q_format not in ["bhsd", "htd"]: - q, _ = permute_to_grouped_tensor(q_format, q) - if kv_format not in ["bhsd", "htd"]: - k, _ = permute_to_grouped_tensor(kv_format, k) - v, _ = permute_to_grouped_tensor(kv_format, v) + # permute q, k, v to bhsd/htd format + qkv_contiguous_block = False + if qkv_layout in ["bshd_bshd_bshd", "sbhd_sbhd_sbhd"]: + q, k, v = PermuteToGroupedTensor.apply(q, k, v, qkv_layout) + qkv_contiguous_block = True + else: + if q_format not in ["bhsd", "htd"]: + q, _ = permute_to_grouped_tensor(q_format, q) + if kv_format not in ["bhsd", "htd"]: + k, _ = permute_to_grouped_tensor(kv_format, k) + v, _ = permute_to_grouped_tensor(kv_format, v) qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" 
else "htd_htd_htd" # check shapes @@ -2384,33 +2376,37 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): ) q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] # quantize q, k, v - if d_qk == d_v: - input_tensors = [q, k, v] - num_tensors = len(input_tensors) - shapes = [x.shape for x in input_tensors] - grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shapes=shapes, - quantizer=qkv_quantizer, - device="cuda", - dtype=q.dtype, - ) - quantized_tensors = grouped_tensor.quantize(input_tensors) - q_fp8, k_fp8, v_fp8 = quantized_tensors[0], quantized_tensors[1], quantized_tensors[2] - else: - input_tensors = [q, k] - num_tensors = len(input_tensors) - shapes = [x.shape for x in input_tensors] - grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shapes=shapes, - quantizer=qkv_quantizer, - device="cuda", - dtype=q.dtype, - ) - quantized_tensors = grouped_tensor.quantize(input_tensors) - q_fp8, k_fp8 = quantized_tensors[0], quantized_tensors[1] - v_fp8 = qkv_quantizer(v) + # if qkv_contiguous_block: + # if d_qk == d_v: + # first_dims = torch.tensor( + # [q.shape[0], k.shape[0], v.shape[0]], dtype=torch.int64, device=q.device + # ) + # qkv_2d = torch.cat([q, k, v], dim=0) + # grouped_tensor = tex.group_quantize(qkv_2d, qkv_quantizer, 3, first_dims) + # quantized_tensors = grouped_tensor.split_into_quantized_tensors() + # q_fp8, k_fp8, v_fp8 = quantized_tensors[0], quantized_tensors[1], quantized_tensors[2] + # else: + # first_dims = torch.tensor([q.shape[0], k.shape[0]], dtype=torch.int64, device=q.device) + # qk_2d = torch.cat([q, k], dim=0) + # grouped_tensor = tex.group_quantize(qk_2d, qkv_quantizer, 2, first_dims) + # q_fp8, k_fp8 = grouped_tensor.split_into_quantized_tensors() + # v_fp8 = qkv_quantizer(v) + # else: + # input_tensors = [q, k, v] + # num_tensors = len(input_tensors) + # shapes = [x.shape for x in input_tensors] + # grouped_tensor = 
GroupedTensor.make_grouped_tensor_with_shapes( + # num_tensors=num_tensors, + # shapes=shapes, + # quantizer=qkv_quantizer, + # device="cuda", + # dtype=q.dtype, + # ) + # quantized_tensors = grouped_tensor.quantize(input_tensors) + # q_fp8, k_fp8, v_fp8 = quantized_tensors[0], quantized_tensors[1], quantized_tensors[2] + # else: + # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] From ca5376956e8b8f662c7fa88661695b3e9eda4f8f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Mar 2026 17:14:34 -0700 Subject: [PATCH 141/172] quantize row/col as needed in fwd/bwd, non-cp/cp Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 2 +- .../dot_product_attention/context_parallel.py | 4 ++-- .../attention/dot_product_attention/utils.py | 20 +++++++++++++++++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 392c1c8c0d..5344f57706 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1286,7 +1286,7 @@ def forward( q_fp8, k_fp8, v_fp8 = q, k, v else: q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( - qkv_layout, q, k, v, QKV_quantizer + qkv_layout, q, k, v, QKV_quantizer, used_in_backward=is_training ) # print quantizers diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 328538e8d1..0287bb0992 100644 
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1214,7 +1214,7 @@ def cp_p2p_bwd_fused_attn( ] else: q_part, k_part, v_part, qkv_layout = combine_and_quantize( - qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step + qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step, used_in_forward=False, used_in_backward=True ) if not fp8_recipe.mxfp8(): if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): @@ -3588,7 +3588,7 @@ def backward(ctx, dout, *_args): ) else: q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( - ctx.qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer + ctx.qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer, used_in_forward=False, used_in_backward=True ) dout_part, do_format = dpa_utils.permute_to_grouped_tensor( do_format, dout_part diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index f3f0541ea4..1f35616cb7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2349,7 +2349,7 @@ def backward( return q, k, v, None -def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): +def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer, used_in_forward=True, used_in_backward=False): """Combine q,k,v based on qkv_layout and quantize them together""" if isinstance(qkv_quantizer, MXFP8Quantizer): qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) @@ -2406,7 +2406,23 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): # q_fp8, k_fp8, v_fp8 = quantized_tensors[0], quantized_tensors[1], quantized_tensors[2] # else: # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + if used_in_forward and 
used_in_backward: + q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + if used_in_forward and not used_in_backward: + qkv_quantizer.rowwise_usage = True + qkv_quantizer.columnwise_usage = False + q_fp8, k_fp8 = [qkv_quantizer(x) for x in [q, k]] + qkv_quantizer.rowwise_usage = False + qkv_quantizer.columnwise_usage = True + v_fp8 = qkv_quantizer(v) + if (not used_in_forward) and used_in_backward: + qkv_quantizer.rowwise_usage = True + qkv_quantizer.columnwise_usage = True + q_fp8, k_fp8 = [qkv_quantizer(x) for x in [q, k]] + qkv_quantizer.rowwise_usage = True + qkv_quantizer.columnwise_usage = False + v_fp8 = qkv_quantizer(v) + # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] From f19e852be3463210f2b3be5839ae8931e5ad92d0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Mar 2026 17:22:09 -0700 Subject: [PATCH 142/172] Revert "quantize row/col as needed in fwd/bwd, non-cp/cp" This reverts commit ca5376956e8b8f662c7fa88661695b3e9eda4f8f. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 2 +- .../dot_product_attention/context_parallel.py | 4 ++-- .../attention/dot_product_attention/utils.py | 20 ++----------------- 3 files changed, 5 insertions(+), 21 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 5344f57706..392c1c8c0d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1286,7 +1286,7 @@ def forward( q_fp8, k_fp8, v_fp8 = q, k, v else: q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( - qkv_layout, q, k, v, QKV_quantizer, used_in_backward=is_training + qkv_layout, q, k, v, QKV_quantizer ) # print quantizers diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 0287bb0992..328538e8d1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1214,7 +1214,7 @@ def cp_p2p_bwd_fused_attn( ] else: q_part, k_part, v_part, qkv_layout = combine_and_quantize( - qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step, used_in_forward=False, used_in_backward=True + qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step ) if not fp8_recipe.mxfp8(): if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): @@ -3588,7 +3588,7 @@ def backward(ctx, dout, *_args): ) else: q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( - ctx.qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer, used_in_forward=False, used_in_backward=True + ctx.qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer ) dout_part, do_format = 
dpa_utils.permute_to_grouped_tensor( do_format, dout_part diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 1f35616cb7..f3f0541ea4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2349,7 +2349,7 @@ def backward( return q, k, v, None -def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer, used_in_forward=True, used_in_backward=False): +def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" if isinstance(qkv_quantizer, MXFP8Quantizer): qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) @@ -2406,23 +2406,7 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer, used_in_forward=Tru # q_fp8, k_fp8, v_fp8 = quantized_tensors[0], quantized_tensors[1], quantized_tensors[2] # else: # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - if used_in_forward and used_in_backward: - q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - if used_in_forward and not used_in_backward: - qkv_quantizer.rowwise_usage = True - qkv_quantizer.columnwise_usage = False - q_fp8, k_fp8 = [qkv_quantizer(x) for x in [q, k]] - qkv_quantizer.rowwise_usage = False - qkv_quantizer.columnwise_usage = True - v_fp8 = qkv_quantizer(v) - if (not used_in_forward) and used_in_backward: - qkv_quantizer.rowwise_usage = True - qkv_quantizer.columnwise_usage = True - q_fp8, k_fp8 = [qkv_quantizer(x) for x in [q, k]] - qkv_quantizer.rowwise_usage = True - qkv_quantizer.columnwise_usage = False - v_fp8 = qkv_quantizer(v) - + q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] From 
2d403f9854954c6e078eaf7a2702faf489c11ee1 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Mar 2026 20:18:39 -0700 Subject: [PATCH 143/172] Reapply "quantize row/col as needed in fwd/bwd, non-cp/cp" This reverts commit f19e852be3463210f2b3be5839ae8931e5ad92d0. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 2 +- .../dot_product_attention/context_parallel.py | 4 ++-- .../attention/dot_product_attention/utils.py | 20 +++++++++++++++++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 392c1c8c0d..5344f57706 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1286,7 +1286,7 @@ def forward( q_fp8, k_fp8, v_fp8 = q, k, v else: q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( - qkv_layout, q, k, v, QKV_quantizer + qkv_layout, q, k, v, QKV_quantizer, used_in_backward=is_training ) # print quantizers diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 328538e8d1..0287bb0992 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1214,7 +1214,7 @@ def cp_p2p_bwd_fused_attn( ] else: q_part, k_part, v_part, qkv_layout = combine_and_quantize( - qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step + qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step, used_in_forward=False, used_in_backward=True ) if not fp8_recipe.mxfp8(): if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): @@ -3588,7 +3588,7 @@ def 
backward(ctx, dout, *_args): ) else: q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( - ctx.qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer + ctx.qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer, used_in_forward=False, used_in_backward=True ) dout_part, do_format = dpa_utils.permute_to_grouped_tensor( do_format, dout_part diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index f3f0541ea4..1f35616cb7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2349,7 +2349,7 @@ def backward( return q, k, v, None -def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): +def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer, used_in_forward=True, used_in_backward=False): """Combine q,k,v based on qkv_layout and quantize them together""" if isinstance(qkv_quantizer, MXFP8Quantizer): qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) @@ -2406,7 +2406,23 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): # q_fp8, k_fp8, v_fp8 = quantized_tensors[0], quantized_tensors[1], quantized_tensors[2] # else: # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + if used_in_forward and used_in_backward: + q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + if used_in_forward and not used_in_backward: + qkv_quantizer.rowwise_usage = True + qkv_quantizer.columnwise_usage = False + q_fp8, k_fp8 = [qkv_quantizer(x) for x in [q, k]] + qkv_quantizer.rowwise_usage = False + qkv_quantizer.columnwise_usage = True + v_fp8 = qkv_quantizer(v) + if (not used_in_forward) and used_in_backward: + qkv_quantizer.rowwise_usage = True + qkv_quantizer.columnwise_usage = True + q_fp8, k_fp8 = [qkv_quantizer(x) for x in [q, k]] + qkv_quantizer.rowwise_usage = True + 
qkv_quantizer.columnwise_usage = False + v_fp8 = qkv_quantizer(v) + # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] From f9e4e2011224c7643f1e6c26efaf8d001eaa796d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Mar 2026 20:19:54 -0700 Subject: [PATCH 144/172] fix v_col format when row is quantized Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 77c4dd143a..b4b1cf485a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -675,8 +675,13 @@ void nvte_fused_attn_fwd( &t_q); nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); - nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, - &d_v, &t_kv); + if (input_V->scaling_mode != NVTE_MXFP8_1D_SCALING) { + nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_v, &t_kv); + } else { + nvte_convert_qkv_format(kv_format, input_V->columnwise_data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_v, &t_kv); + } if (q_format == NVTE_QKV_Format::NVTE_THD) { b = input_cu_seqlens_q->data.shape[0] - 1; } else if (kv_format == NVTE_QKV_Format::NVTE_THD) { From fde366a7ebd9400c061daf2777e157d4db8561c9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Mar 2026 20:20:42 -0700 Subject: [PATCH 145/172] add back necessary bwd quants for shadow paths/cp a2a Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/attention/dot_product_attention/backends.py | 4 ++-- .../attention/dot_product_attention/context_parallel.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 5344f57706..4eddb0f736 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1461,7 +1461,7 @@ def forward( if isinstance(tmp_quantizer, MXFP8Quantizer): tmp_quantizer.optimize_for_gemm = False q_fp8_, k_fp8_, _, _ = combine_and_quantize( - original_qkv_layout, q, k, v, tmp_quantizer + original_qkv_layout, q, k, v, tmp_quantizer, used_in_backward=True ) q_ = q_fp8_.dequantize(dtype=out_nominal_dtype) k_ = k_fp8_.dequantize(dtype=out_nominal_dtype) @@ -1797,7 +1797,7 @@ def backward(ctx, d_out, *_args): if isinstance(tmp_quantizer, MXFP8Quantizer): tmp_quantizer.optimize_for_gemm = False q_fp8_, k_fp8_, v_fp8_, _ = combine_and_quantize( - original_qkv_layout, q, k, v, tmp_quantizer + original_qkv_layout, q, k, v, tmp_quantizer, used_in_backward=True ) q_shadow_f16, k_shadow_f16, v_shadow_f16 = [ x.dequantize(dtype=dqkv_nominal_dtype) for x in (q_fp8_, k_fp8_, v_fp8_) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 0287bb0992..64eb991710 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3968,7 +3968,7 @@ def forward( if fp8: if fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( - qkv_layout, q_part, k_part, v_part, QKV_quantizer + qkv_layout, q_part, k_part, v_part, QKV_quantizer, 
used_in_backward=is_training ) q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] else: From 81f723d247bf1a9d95877f008bdfbe02609b89a8 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Mar 2026 20:54:34 -0700 Subject: [PATCH 146/172] remove ZInv for all layouts except T3HD Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 5 +++- .../common/fused_attn/fused_attn_fp8.cu | 29 ++++++++++++------- .../dot_product_attention/context_parallel.py | 26 ++++++++++++----- .../pytorch/cpp_extensions/fused_attn.py | 2 +- .../pytorch/csrc/extensions/attention.cpp | 6 ++-- 5 files changed, 46 insertions(+), 22 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index b4b1cf485a..580bb75814 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -863,7 +863,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #if (CUDNN_VERSION >= 8900) size_t i = 0; const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + const Tensor *input_ZInv = nullptr; + if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); const Tensor *input_SoftmaxOffset = nullptr; if (softmax_type != NVTE_VANILLA_SOFTMAX) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 4a25df2185..71d45b1f91 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2745,14 +2745,16 @@ void fused_attn_fp8_fwd(size_t batch, size_t 
num_attn_heads, size_t num_gqa_grou if (Aux_CTX_Tensors->size == 0) { int i = 0; Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_ZInv->data.dtype = DType::kFloat32; + if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_ZInv->data.dptr = nullptr; + output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_ZInv->data.dtype = DType::kFloat32; + } + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; @@ -2763,13 +2765,16 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou output_softmax_offset->data.dtype = DType::kFloat32; } Aux_CTX_Tensors->size = i; - } else if (Aux_CTX_Tensors->size >= 3) { + } else if (Aux_CTX_Tensors->size >= 2) { int i = 0; Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrM = output_M->data.dptr; - devPtrZInv = output_ZInv->data.dptr; + devPtrZInv = nullptr; + if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrZInv = output_ZInv->data.dptr; + } + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = 
rng_state->data.dptr; if (softmax_type != NVTE_VANILLA_SOFTMAX) { Tensor* output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); @@ -2874,7 +2879,8 @@ void fused_attn_fp8_bwd( } void* devPtrM = input_M->data.dptr; - void* devPtrZInv = input_ZInv->data.dptr; + void* devPtrZInv = + (input_ZInv != nullptr) ? input_ZInv->data.dptr : nullptr; void* devPtrScaleS = input_S->scale.dptr; void* devPtrDescaleS = input_S->scale_inv.dptr; @@ -2930,6 +2936,9 @@ void fused_attn_fp8_bwd( get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + // remove this when cuDNN FE supports FP8 + THD + NVTE_CHECK(input_ZInv != nullptr && input_ZInv->data.dptr != nullptr, + "ZInv tensor required for FP8 fused attention backward with T3HD layout."); fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 64eb991710..d3e19b2197 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -977,7 +977,10 @@ def cp_p2p_fwd_fused_attn( ) if fp8: - softmax_lse_per_step, _, rng_states = aux_ctx_tensors + if qkv_layout != "t3hd": + softmax_lse_per_step, rng_states = aux_ctx_tensors + else: + softmax_lse_per_step, _, rng_states = aux_ctx_tensors else: softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors attn_bias = rest[0] if len(rest) > 0 else None @@ -3192,7 +3195,10 @@ def forward( **fp8_meta_kwargs, ) if fp8: - softmax_lse_per_step[i], _, rng_states[i] = 
aux_ctx_tensors + if qkv_layout != "t3hd": + softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors else: softmax_lse_per_step[i], rng_states[i], *_ = aux_ctx_tensors if return_max_logit: @@ -3554,11 +3560,17 @@ def backward(ctx, dout, *_args): out_part = out.select(seq_dim_o, i).contiguous() dout_part = dout.select(seq_dim_o, i).contiguous() if ctx.use_fused_attention: - aux_ctx_tensors = [ - softmax_lse_per_step[i], - softmax_lse_per_step[i], - rng_states[i], - ] + if ctx.fp8 and ctx.qkv_layout == "t3hd": + aux_ctx_tensors = [ + softmax_lse_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + ] + else: + aux_ctx_tensors = [ + softmax_lse_per_step[i], + rng_states[i], + ] fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8_meta_kwargs = {} new_qkv_layout = ctx.qkv_layout diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 61ee95662d..52d08e38bc 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -256,7 +256,7 @@ def fused_attn_fwd( M: torch.Tensor max(Q*K.T) shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - ZInv: torch.Tensor + ZInv: torch.Tensor, only allocated for T3HD path 1/sum(e^(x - max(x))), where x=Q*K.T shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 54bb731961..5822136a8e 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -278,7 +278,7 @@ std::vector fused_attn_fwd( // f16_arbitrary: // return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], 
(optional) SoftmaxOffset [1, h, 1, 1] // return_max_logit=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] - // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] + // fp8 : M [b, h, sq, 1], optional ZInv [b, h, sq, 1] (T3HD path), rng_state [2] size_t i = 0; at::Tensor output_tensor; // intermediate softmax tensor, S or M @@ -286,8 +286,8 @@ std::vector fused_attn_fwd( allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); set_tensor_param(i++, output_tensor); - // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Sum_Exp tensor - if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // fp8 T3HD: second softmax stats ZInv; return_max_logit=true: Max (FP8 FE v1 BSHD/SBHD/BHSD skips ZInv) + if (((qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) && qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) || return_max_logit) { output_tensor = allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); From 89daa491b7a9285ed4d5bce3b6aec5ea8eaa343e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Mar 2026 21:50:52 -0700 Subject: [PATCH 147/172] fix cp p2p with zinv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index d3e19b2197..dbf5f30a57 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1167,11 +1167,14 @@ def cp_p2p_bwd_fused_attn( ): """Per-tile backward call of CP P2P with FusedAttention backend""" if fp8: - aux_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - step - 1], - ] + if qkv_layout == "t3hd": + aux_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] else: aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] @@ -1190,11 +1193,14 @@ def cp_p2p_bwd_fused_attn( elif section == "upper-triangle": q_part, out_part, dout_part = [x.contiguous() for x in [q_part, out_part, dout_part]] if fp8: - aux_tensors = [ - softmax_lse_, - softmax_lse_, - rng_states[cp_size - step - 1], - ] + if qkv_layout == "t3hd": + aux_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] else: aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] From 60740fa851e5d86658e5f054a867f072df261522 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Mar 2026 21:55:32 -0700 Subject: [PATCH 148/172] temporarily switch to GH FE main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 +-- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index 8c7646c00d..4b188d6bb1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,8 +3,7 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git - branch = develop + url = https://github.com/NVIDIA/cudnn-frontend.git [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = 
https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 28ebf81b5e..7b9b711c22 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 28ebf81b5e19b92045d1af0fd4c0b9c4599c2b53 +Subproject commit 7b9b711c22b6823e87150213ecd8449260db8610 From a7ff000464140ab972b0b8aee98b1451895b203a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 28 Mar 2026 05:00:07 +0000 Subject: [PATCH 149/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/flash_attn.cu | 122 +++++++++++------- .../common/fused_attn/fused_attn.cpp | 4 +- .../common/fused_attn/fused_attn_fp8.cu | 3 +- .../include/transformer_engine/fused_attn.h | 4 +- .../dot_product_attention/backends.py | 4 +- .../dot_product_attention/context_parallel.py | 23 +++- .../attention/dot_product_attention/utils.py | 4 +- transformer_engine/pytorch/csrc/extensions.h | 9 +- .../pytorch/csrc/extensions/attention.cpp | 34 ++--- .../pytorch/csrc/extensions/pybind.cpp | 15 ++- 10 files changed, 136 insertions(+), 86 deletions(-) diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 34340babda..c30cc3d2d9 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -141,10 +141,13 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream } template -__launch_bounds__(1024) __global__ void permute_to_grouped_tensor_fwd_kernel( - const T *__restrict__ q, const T *__restrict__ k, const T *__restrict__ v, T *__restrict__ q_out, - T *__restrict__ k_out, T *__restrict__ v_out, size_t b, size_t s_q, size_t h_q, size_t d_qk, - size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits) { +__launch_bounds__(1024) __global__ + void 
permute_to_grouped_tensor_fwd_kernel(const T *__restrict__ q, const T *__restrict__ k, + const T *__restrict__ v, T *__restrict__ q_out, + T *__restrict__ k_out, T *__restrict__ v_out, + size_t b, size_t s_q, size_t h_q, size_t d_qk, + size_t s_kv, size_t h_kv, size_t d_v, + unsigned int permute_s_splits) { const int which_tensor = blockIdx.z; const T *__restrict__ tensor_in = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); T *__restrict__ tensor_out = which_tensor == 0 ? q_out : (which_tensor == 1 ? k_out : v_out); @@ -165,20 +168,24 @@ __launch_bounds__(1024) __global__ void permute_to_grouped_tensor_fwd_kernel( if (Ddim % static_cast(nvec) != 0) return; const unsigned int s_part = blockIdx.y; - const size_t s_begin = (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); - const size_t s_end = (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); + const size_t s_begin = + (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = + (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); const size_t S_chunk = s_end - s_begin; const size_t in_base = kIsBshdBshdBshd ? 
b_i * Sdim * Hdim * Ddim : b_i * Hdim * Ddim; const size_t out_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; - const bool use_vec128 = (Ddim % static_cast(nvec128) == 0) && - ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && - ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); + const bool use_vec128 = + (Ddim % static_cast(nvec128) == 0) && + ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && + ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); if (use_vec128) { const size_t d_vec = Ddim / static_cast(nvec128); const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; w += static_cast(blockDim.x)) { + for (size_t w = static_cast(threadIdx.x); w < total_work; + w += static_cast(blockDim.x)) { const size_t s_local = w / d_vec; const size_t s_i = s_begin + s_local; const size_t v = w % d_vec; @@ -191,12 +198,14 @@ __launch_bounds__(1024) __global__ void permute_to_grouped_tensor_fwd_kernel( in_ptr = tensor_in + s_i * b * Hdim * Ddim + in_base + h_i * Ddim + d_off; } T *__restrict__ out_ptr = tensor_out + out_base + s_i * Ddim + d_off; - *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + *reinterpret_cast *>(out_ptr) = + *reinterpret_cast *>(in_ptr); } } else { const size_t d_vec = Ddim / static_cast(nvec); const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; w += static_cast(blockDim.x)) { + for (size_t w = static_cast(threadIdx.x); w < total_work; + w += static_cast(blockDim.x)) { const size_t s_local = w / d_vec; const size_t s_i = s_begin + s_local; const size_t v = w % d_vec; @@ -220,7 +229,8 @@ __launch_bounds__(1024) __global__ void permute_to_grouped_tensor_bwd_kernel( T *__restrict__ q, T *__restrict__ k, T *__restrict__ v, size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits) { const int which_tensor = blockIdx.z; - const T *__restrict__ tensor_in = which_tensor == 0 ? 
grad_q : (which_tensor == 1 ? grad_k : grad_v); + const T *__restrict__ tensor_in = + which_tensor == 0 ? grad_q : (which_tensor == 1 ? grad_k : grad_v); T *__restrict__ tensor_out = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); const size_t Sdim = which_tensor == 0 ? s_q : s_kv; const size_t Hdim = which_tensor == 0 ? h_q : h_kv; @@ -239,20 +249,25 @@ __launch_bounds__(1024) __global__ void permute_to_grouped_tensor_bwd_kernel( if (Ddim % static_cast(nvec) != 0) return; const unsigned int s_part = blockIdx.y; - const size_t s_begin = (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); - const size_t s_end = (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); + const size_t s_begin = + (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = + (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); const size_t S_chunk = s_end - s_begin; const size_t in_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; - const size_t out_base = kIsBshdBshdBshd ? b_i * Sdim * Hdim * Ddim + h_i * Ddim : b_i * Hdim * Ddim + h_i * Ddim; - const bool use_vec128 = (Ddim % static_cast(nvec128) == 0) && - ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && - ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); + const size_t out_base = + kIsBshdBshdBshd ? 
b_i * Sdim * Hdim * Ddim + h_i * Ddim : b_i * Hdim * Ddim + h_i * Ddim; + const bool use_vec128 = + (Ddim % static_cast(nvec128) == 0) && + ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && + ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); if (use_vec128) { const size_t d_vec = Ddim / static_cast(nvec128); const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; w += static_cast(blockDim.x)) { + for (size_t w = static_cast(threadIdx.x); w < total_work; + w += static_cast(blockDim.x)) { const size_t s_local = w / d_vec; const size_t s_i = s_begin + s_local; const size_t v = w % d_vec; @@ -265,12 +280,14 @@ __launch_bounds__(1024) __global__ void permute_to_grouped_tensor_bwd_kernel( } else { out_ptr = tensor_out + s_i * b * Hdim * Ddim + out_base + d_off; } - *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + *reinterpret_cast *>(out_ptr) = + *reinterpret_cast *>(in_ptr); } } else { const size_t d_vec = Ddim / static_cast(nvec); const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; w += static_cast(blockDim.x)) { + for (size_t w = static_cast(threadIdx.x); w < total_work; + w += static_cast(blockDim.x)) { const size_t s_local = w / d_vec; const size_t s_i = s_begin + s_local; const size_t v = w % d_vec; @@ -289,9 +306,10 @@ __launch_bounds__(1024) __global__ void permute_to_grouped_tensor_bwd_kernel( } void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, Tensor k_out, - Tensor v_out, NVTE_QKV_Layout original_layout, cudaStream_t stream) { + Tensor v_out, NVTE_QKV_Layout original_layout, + cudaStream_t stream) { using namespace transformer_engine; - size_t b=0, s_q=0, s_kv=0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; + size_t b = 0, s_q = 0, s_kv = 0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; b = q_out.shape()[0]; h_q = q_out.shape()[1]; s_q = q_out.shape()[2]; @@ -314,27 +332,30 @@ void permute_to_grouped_tensor_fwd(Tensor q, 
Tensor k, Tensor v, Tensor q_out, T TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( q.dtype(), dtype, permute_to_grouped_tensor_fwd_kernel<<>>( - reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), reinterpret_cast(q_out.data.dptr), - reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), b, - s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( q.dtype(), dtype, permute_to_grouped_tensor_fwd_kernel<<>>( - reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), reinterpret_cast(q_out.data.dptr), - reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), b, - s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); } -void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, - Tensor q, Tensor k, Tensor v, - NVTE_QKV_Layout original_layout, cudaStream_t stream) { +void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, Tensor q, Tensor k, + Tensor v, NVTE_QKV_Layout original_layout, cudaStream_t stream) { using namespace transformer_engine; - size_t b=0, s_q=0, s_kv=0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; + size_t b = 0, s_q = 0, s_kv = 0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; b = grad_q.shape()[0]; h_q = grad_q.shape()[1]; s_q = grad_q.shape()[2]; @@ -358,18 +379,20 @@ void permute_to_grouped_tensor_bwd(Tensor 
grad_q, Tensor grad_k, Tensor grad_v, permute_to_grouped_tensor_bwd_kernel<<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), - reinterpret_cast(grad_v.data.dptr), reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), reinterpret_cast(v.data.dptr), b, s_q, h_q, - d_qk, s_kv, h_kv, d_v, permute_s_splits);); + reinterpret_cast(grad_v.data.dptr), + reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( grad_q.dtype(), dtype, permute_to_grouped_tensor_bwd_kernel<<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), - reinterpret_cast(grad_v.data.dptr), reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), reinterpret_cast(v.data.dptr), b, s_q, h_q, - d_qk, s_kv, h_kv, d_v, permute_s_splits);); + reinterpret_cast(grad_v.data.dptr), + reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -395,8 +418,8 @@ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTET } void nvte_permute_to_grouped_tensor_fwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor q_out, - NVTETensor k_out, NVTETensor v_out, NVTE_QKV_Layout original_layout, - cudaStream_t stream) { + NVTETensor k_out, NVTETensor v_out, + NVTE_QKV_Layout original_layout, cudaStream_t stream) { NVTE_API_CALL(nvte_permute_to_grouped_tensor_fwd); using namespace transformer_engine; @@ -407,12 +430,13 @@ void nvte_permute_to_grouped_tensor_fwd(NVTETensor q, NVTETensor k, NVTETensor v } void nvte_permute_to_grouped_tensor_bwd(NVTETensor grad_q, NVTETensor grad_k, NVTETensor grad_v, - NVTETensor q, NVTETensor k, NVTETensor v, NVTE_QKV_Layout original_layout, - cudaStream_t stream) { + NVTETensor q, NVTETensor k, NVTETensor v, + NVTE_QKV_Layout 
original_layout, cudaStream_t stream) { NVTE_API_CALL(nvte_permute_to_grouped_tensor_bwd); using namespace transformer_engine; flash_attention::permute_to_grouped_tensor_bwd( - *convertNVTETensorCheck(grad_q), *convertNVTETensorCheck(grad_k), *convertNVTETensorCheck(grad_v), - *convertNVTETensorCheck(q), *convertNVTETensorCheck(k), *convertNVTETensorCheck(v), original_layout, stream); -} \ No newline at end of file + *convertNVTETensorCheck(grad_q), *convertNVTETensorCheck(grad_k), + *convertNVTETensorCheck(grad_v), *convertNVTETensorCheck(q), *convertNVTETensorCheck(k), + *convertNVTETensorCheck(v), original_layout, stream); +} diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 827629f238..e3e8040c45 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -700,8 +700,8 @@ void nvte_fused_attn_fwd( nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); } else { - nvte_convert_qkv_format(kv_format, input_V->columnwise_data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, - &d_v, &t_kv); + nvte_convert_qkv_format(kv_format, input_V->columnwise_data.shape, kv_format, tmp_shape, &b, + &h_kv, &s_kv, &d_v, &t_kv); } if (q_format == NVTE_QKV_Format::NVTE_THD) { b = input_cu_seqlens_q->data.shape[0] - 1; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 71d45b1f91..5158630937 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2879,8 +2879,7 @@ void fused_attn_fp8_bwd( } void* devPtrM = input_M->data.dptr; - void* devPtrZInv = - (input_ZInv != nullptr) ? input_ZInv->data.dptr : nullptr; + void* devPtrZInv = (input_ZInv != nullptr) ? 
input_ZInv->data.dptr : nullptr; void* devPtrScaleS = input_S->scale.dptr; void* devPtrDescaleS = input_S->scale_inv.dptr; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 827fdf2ae3..02c351ced7 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -619,7 +619,9 @@ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTET * \param[in] original_layout Original QKV layout. * \param[in] stream CUDA stream. */ -void nvte_permute_to_grouped_tensor_fwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, NVTE_QKV_Layout original_layout, cudaStream_t stream); +void nvte_permute_to_grouped_tensor_fwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor q_out, + NVTETensor k_out, NVTETensor v_out, + NVTE_QKV_Layout original_layout, cudaStream_t stream); /*! \brief Permute Q, K, V back to original layout. 
* diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 6343f71ad6..92d6cd016f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1620,7 +1620,9 @@ def backward(ctx, d_out, *_args): d_out_qdq_f16 = d_out d_out_qdq_f16, _ = dpa_utils.permute_to_grouped_tensor(ctx.o_format, d_out_qdq_f16) - tmp_quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True) + tmp_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True + ) tmp_quantizer.optimize_for_gemm = False d_out_qdq_fp8 = tmp_quantizer(d_out_qdq_f16) d_out_qdq_f16 = d_out_qdq_fp8.dequantize(dtype=ctx.nominal_dtype) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 024d14f744..9093e91349 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1223,7 +1223,13 @@ def cp_p2p_bwd_fused_attn( ] else: q_part, k_part, v_part, qkv_layout = combine_and_quantize( - qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step, used_in_forward=False, used_in_backward=True + qkv_layout, + q_part, + k_part, + v_part, + QKV_quantizer_per_step, + used_in_forward=False, + used_in_backward=True, ) if not fp8_recipe.mxfp8(): if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): @@ -3610,7 +3616,13 @@ def backward(ctx, dout, *_args): ) else: q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( - ctx.qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer, used_in_forward=False, used_in_backward=True + ctx.qkv_layout, + q_part, + k_part, + v_part, + ctx.QKV_quantizer, + 
used_in_forward=False, + used_in_backward=True, ) dout_part, do_format = dpa_utils.permute_to_grouped_tensor( do_format, dout_part @@ -3990,7 +4002,12 @@ def forward( if fp8: if fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( - qkv_layout, q_part, k_part, v_part, QKV_quantizer, used_in_backward=is_training + qkv_layout, + q_part, + k_part, + v_part, + QKV_quantizer, + used_in_backward=is_training, ) q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index d0840e5ffe..820c319e05 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2362,7 +2362,9 @@ def backward( return q, k, v, None -def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer, used_in_forward=True, used_in_backward=False): +def combine_and_quantize( + qkv_layout, q, k, v, qkv_quantizer, used_in_forward=True, used_in_backward=False +): """Combine q,k,v based on qkv_layout and quantize them together""" if isinstance(qkv_quantizer, MXFP8Quantizer): qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index ccb1f286b7..498107af1d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -112,12 +112,11 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); -std::tuple permute_to_grouped_tensor_fwd(at::Tensor query, - at::Tensor key, - at::Tensor value, - NVTE_QKV_Layout input_layout); +std::tuple permute_to_grouped_tensor_fwd( + at::Tensor query, at::Tensor key, at::Tensor value, NVTE_QKV_Layout input_layout); std::tuple permute_to_grouped_tensor_bwd( - at::Tensor 
query_grad, at::Tensor key_grad, at::Tensor value_grad, NVTE_QKV_Layout input_layout); + at::Tensor query_grad, at::Tensor key_grad, at::Tensor value_grad, + NVTE_QKV_Layout input_layout); at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len); at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index bf4af5829a..e3c25b396a 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -287,7 +287,9 @@ std::vector fused_attn_fwd( static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); set_tensor_param(i++, output_tensor); // fp8 T3HD has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor - if (((qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) && qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) || return_max_logit) { + if (((qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) && + qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) || + return_max_logit) { output_tensor = allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); @@ -647,18 +649,18 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } -std::tuple permute_to_grouped_tensor_fwd(at::Tensor query, - at::Tensor key, - at::Tensor value, - NVTE_QKV_Layout original_layout) { +std::tuple permute_to_grouped_tensor_fwd( + at::Tensor query, at::Tensor key, at::Tensor value, NVTE_QKV_Layout original_layout) { NVTE_CHECK(original_layout == NVTE_SBHD_SBHD_SBHD || original_layout == NVTE_BSHD_BSHD_BSHD, - "permute_to_grouped_tensor_fwd: original_layout must be NVTE_SBHD_SBHD_SBHD or NVTE_BSHD_BSHD_BSHD."); + "permute_to_grouped_tensor_fwd: 
original_layout must be NVTE_SBHD_SBHD_SBHD or " + "NVTE_BSHD_BSHD_BSHD."); NVTE_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda()); NVTE_CHECK(query.is_contiguous() && key.is_contiguous() && value.is_contiguous()); NVTE_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4); NVTE_CHECK(query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16); - NVTE_CHECK(key.scalar_type() == query.scalar_type() && value.scalar_type() == query.scalar_type()); + NVTE_CHECK(key.scalar_type() == query.scalar_type() && + value.scalar_type() == query.scalar_type()); int64_t B = 0; int64_t S_q = 0, H_q = 0, D_qk = 0; @@ -699,17 +701,19 @@ std::tuple permute_to_grouped_tensor_fwd(at: auto te_ko = makeTransformerEngineTensor(k_out); auto te_vo = makeTransformerEngineTensor(v_out); - nvte_permute_to_grouped_tensor_fwd( - te_q.data(), te_k.data(), te_v.data(), te_qo.data(), te_ko.data(), te_vo.data(), - original_layout, at::cuda::getCurrentCUDAStream()); + nvte_permute_to_grouped_tensor_fwd(te_q.data(), te_k.data(), te_v.data(), te_qo.data(), + te_ko.data(), te_vo.data(), original_layout, + at::cuda::getCurrentCUDAStream()); return std::make_tuple(q_out, k_out, v_out); } std::tuple permute_to_grouped_tensor_bwd( - at::Tensor query_grad, at::Tensor key_grad, at::Tensor value_grad, NVTE_QKV_Layout original_layout) { + at::Tensor query_grad, at::Tensor key_grad, at::Tensor value_grad, + NVTE_QKV_Layout original_layout) { NVTE_CHECK(original_layout == NVTE_SBHD_SBHD_SBHD || original_layout == NVTE_BSHD_BSHD_BSHD, - "permute_to_grouped_tensor_bwd: original_layout must be NVTE_SBHD_SBHD_SBHD or NVTE_BSHD_BSHD_BSHD."); + "permute_to_grouped_tensor_bwd: original_layout must be NVTE_SBHD_SBHD_SBHD or " + "NVTE_BSHD_BSHD_BSHD."); NVTE_CHECK(query_grad.is_cuda() && key_grad.is_cuda() && value_grad.is_cuda()); NVTE_CHECK(query_grad.is_contiguous() && key_grad.is_contiguous() && value_grad.is_contiguous()); NVTE_CHECK(query_grad.dim() == 4 && 
key_grad.dim() == 4 && value_grad.dim() == 4); @@ -751,9 +755,9 @@ std::tuple permute_to_grouped_tensor_bwd( auto te_k = makeTransformerEngineTensor(key); auto te_v = makeTransformerEngineTensor(value); - nvte_permute_to_grouped_tensor_bwd( - te_gq.data(), te_gk.data(), te_gv.data(), te_q.data(), te_k.data(), te_v.data(), - original_layout, at::cuda::getCurrentCUDAStream()); + nvte_permute_to_grouped_tensor_bwd(te_gq.data(), te_gk.data(), te_gv.data(), te_q.data(), + te_k.data(), te_v.data(), original_layout, + at::cuda::getCurrentCUDAStream()); return std::make_tuple(query, key, value); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 58e27c875b..7e5f853454 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -394,13 +394,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fa_prepare_bwd", &transformer_engine::pytorch::fa_prepare_bwd, "Backward of QKV preparation for Flash Attention", py::call_guard()); - m.def("permute_to_grouped_tensor_fwd", &transformer_engine::pytorch::permute_to_grouped_tensor_fwd, - "Permute Q, K, V to grouped tensors.", - py::arg("query"), py::arg("key"), py::arg("value"), py::arg("original_layout"), - py::call_guard()); - m.def("permute_to_grouped_tensor_bwd", &transformer_engine::pytorch::permute_to_grouped_tensor_bwd, - "Permute Q, K, V back to original layout.", py::arg("query_grad"), py::arg("key_grad"), - py::arg("value_grad"), py::arg("original_layout"), py::call_guard()); + m.def("permute_to_grouped_tensor_fwd", + &transformer_engine::pytorch::permute_to_grouped_tensor_fwd, + "Permute Q, K, V to grouped tensors.", py::arg("query"), py::arg("key"), py::arg("value"), + py::arg("original_layout"), py::call_guard()); + m.def( + "permute_to_grouped_tensor_bwd", &transformer_engine::pytorch::permute_to_grouped_tensor_bwd, + "Permute Q, K, V back to original layout.", 
py::arg("query_grad"), py::arg("key_grad"), + py::arg("value_grad"), py::arg("original_layout"), py::call_guard()); m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd, "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd, From b0db79e82e691093279663363bc7350abfc9a059 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Mar 2026 22:02:23 -0700 Subject: [PATCH 150/172] switch back to GL FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 ++- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..8c7646c00d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,8 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://github.com/NVIDIA/cudnn-frontend.git + url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git + branch = develop [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 7b9b711c22..28ebf81b5e 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 7b9b711c22b6823e87150213ecd8449260db8610 +Subproject commit 28ebf81b5e19b92045d1af0fd4c0b9c4599c2b53 From f662a4a070eeeb149ccf51000a666eca91c8d6a1 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 28 Mar 2026 10:46:20 -0700 Subject: [PATCH 151/172] fix ag after merge main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/attention/dot_product_attention/context_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py 
b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 9093e91349..142a375007 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3429,7 +3429,7 @@ def backward(ctx, dout, *_args): softmax_lse_per_step[1], rng_states[0], rng_states[1], - ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + ) = restore_from_func_ctx(ctx) kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step From cbf6edd5d98ba9bdab5a5942a8156ec632fd0bae Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 28 Mar 2026 10:47:31 -0700 Subject: [PATCH 152/172] add condition for qdq(do) to not affect other tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 92d6cd016f..3c0d0fdc8a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1618,18 +1618,19 @@ def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring d_out_shadow_f16 = d_out - d_out_qdq_f16 = d_out - d_out_qdq_f16, _ = dpa_utils.permute_to_grouped_tensor(ctx.o_format, d_out_qdq_f16) - tmp_quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True - ) - tmp_quantizer.optimize_for_gemm = False - d_out_qdq_fp8 = tmp_quantizer(d_out_qdq_f16) - d_out_qdq_f16 = d_out_qdq_fp8.dequantize(dtype=ctx.nominal_dtype) - if ctx.o_format == "bshd": - d_out_qdq_f16 = d_out_qdq_f16.permute(0, 2, 1, 3).contiguous() - elif ctx.o_format == 
"sbhd": - d_out_qdq_f16 = d_out_qdq_f16.permute(2, 0, 1, 3).contiguous() + if _qdq_dO_in_f16_bprop or _qdq_dO_in_mxfp8_bprop: + d_out_qdq_f16 = d_out + d_out_qdq_f16, _ = dpa_utils.permute_to_grouped_tensor(ctx.o_format, d_out_qdq_f16) + tmp_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True + ) + tmp_quantizer.optimize_for_gemm = False + d_out_qdq_fp8 = tmp_quantizer(d_out_qdq_f16) + d_out_qdq_f16 = d_out_qdq_fp8.dequantize(dtype=ctx.nominal_dtype) + if ctx.o_format == "bshd": + d_out_qdq_f16 = d_out_qdq_f16.permute(0, 2, 1, 3).contiguous() + elif ctx.o_format == "sbhd": + d_out_qdq_f16 = d_out_qdq_f16.permute(2, 0, 1, 3).contiguous() swapped_do_with_qdq_do = False if ctx.fp8 and _qdq_dO_in_mxfp8_bprop: d_out = d_out_qdq_f16 From 0642251bc054cb4c773aa30bf4ee4c721b14c688 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 28 Mar 2026 10:48:17 -0700 Subject: [PATCH 153/172] fix custom_mha_fp8 test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 165169a799..61435fdffb 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2683,6 +2683,8 @@ def forward( quantization_params=qkv_quantizer, use_split_accumulator=_2X_ACC_FPROP, ) + qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd" + o_format="bshd" if cudnn_frontend_version == 1 else "thd" qkv = qkv.view(-1, 3, h, d) qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous() torch.save(qkv_fp16, "qkv.pt") @@ -2711,7 +2713,8 @@ def forward( attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, - qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", + qkv_layout=qkv_layout, + o_format=o_format, 
attn_bias_type="no_bias", attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding", rng_gen=None, @@ -2734,6 +2737,8 @@ def forward( ctx.num_heads = num_heads ctx.mask_type = mask_type ctx.dtype = inp.dtype + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer @@ -2751,7 +2756,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_func_ctx(ctx) proj_dgrad = ctx.dO_quantizer(grad_output) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_s, @@ -2764,7 +2768,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], out, proj_dgrad.view_as(out), ctx.qkv_dtype, - fp8_dtype_backward, ctx.aux_ctx_tensors, FusedAttnBackend["FP8"], None, @@ -2775,7 +2778,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], attn_scale=None, dropout=ctx.p_dropout, fast_zero_fill=ctx.fast_zero_fill, - qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", + qkv_layout=ctx.qkv_layout, + o_format=ctx.o_format, + do_format=ctx.o_format, + dqkv_layout=ctx.qkv_layout, attn_bias_type="no_bias", attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding", ) From e6ffc6bc72d36b49a934b4aaa61af0ee915b83d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 28 Mar 2026 17:49:38 +0000 Subject: [PATCH 154/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 61435fdffb..58f5ebb7bb 100644 --- a/tests/pytorch/attention/test_attention.py +++ 
b/tests/pytorch/attention/test_attention.py @@ -2683,8 +2683,8 @@ def forward( quantization_params=qkv_quantizer, use_split_accumulator=_2X_ACC_FPROP, ) - qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd" - o_format="bshd" if cudnn_frontend_version == 1 else "thd" + qkv_layout = "bs3hd" if cudnn_frontend_version == 1 else "t3hd" + o_format = "bshd" if cudnn_frontend_version == 1 else "thd" qkv = qkv.view(-1, 3, h, d) qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous() torch.save(qkv_fp16, "qkv.pt") From fd9a750b4d6a8f0b14b9e201151d52b2e44fed09 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 30 Mar 2026 14:56:03 -0700 Subject: [PATCH 155/172] fix amax dqkv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn_fp8.cu | 8 ++++---- transformer_engine/pytorch/csrc/quantizer.cpp | 10 +++++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 5158630937..c8e405e87e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2898,11 +2898,11 @@ void fused_attn_fp8_bwd( void* devPtrdK = output_dK->data.dptr; void* devPtrdV = output_dV->data.dptr; void* devPtrAmaxdQ = output_dQ->amax.dptr; - void* devPtrAmaxdK = output_dQ->amax.dptr; - void* devPtrAmaxdV = output_dQ->amax.dptr; + void* devPtrAmaxdK = output_dK->amax.dptr; + void* devPtrAmaxdV = output_dV->amax.dptr; void* devPtrScaledQ = output_dQ->scale.dptr; - void* devPtrScaledK = output_dQ->scale.dptr; - void* devPtrScaledV = output_dQ->scale.dptr; + void* devPtrScaledK = output_dK->scale.dptr; + void* devPtrScaledV = output_dV->scale.dptr; void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); diff --git 
a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7f026fe1b1..07798843b4 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1447,13 +1447,17 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve std::pair MXFP8Quantizer::create_unquantized_tensor_with_amax( const std::vector& shape, DType dtype, std::optional data) { - at::Tensor amax_tensor = - at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + static std::once_flag once; + static at::Tensor amax_tensor; + std::call_once(once, []() { + amax_tensor = at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + }); auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) : NoneQuantizer(py::none()).create_tensor(shape, dtype); TensorWrapper out_cpp = std::move(out.first); py::object out_py = std::move(out.second); - out_cpp.set_amax(amax_tensor.data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_tensor.data_ptr(), GetTransformerEngineDType(amax_tensor.scalar_type()), + getTensorShape(amax_tensor)); return {std::move(out_cpp), std::move(out_py)}; } From 4f2e4f47c13de273786593ef3575cf0753248209 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 30 Mar 2026 15:06:54 -0700 Subject: [PATCH 156/172] fix fp8_recipe in DPA utils Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/utils.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 820c319e05..8ee16649e7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ 
b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -489,7 +489,6 @@ def get_attention_backend( use_fused_attention = False # Filter: Execution type - fp8_recipe = None if fp8 and fp8_meta["recipe"].fp8_dpa: fp8_recipe = fp8_meta["recipe"] if fp8_meta.get("local_recipes", None) is not None: @@ -591,7 +590,7 @@ def get_attention_backend( if use_flash_attention: use_flash_attention = False logger.debug("Disabling FlashAttention for max_logit") - if fp8 and fp8_recipe.fp8_dpa: + if fp8 and fp8_meta["recipe"].fp8_dpa: use_flash_attention = False use_fused_attention = False use_unfused_attention = False @@ -620,8 +619,8 @@ def get_attention_backend( use_flash_attention = False use_fused_attention = False use_unfused_attention = False - if fp8 and fp8_recipe.fp8_dpa: - if fp8_recipe.fp8_mha: + if fp8 and fp8_meta["recipe"].fp8_dpa: + if fp8_meta["recipe"].fp8_mha: logger.debug("Disabling all backends for KV caching with FP8 MHA") use_flash_attention = False use_fused_attention = False @@ -780,7 +779,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if softmax_type != "vanilla": logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) use_flash_attention = False - if fp8 and fp8_recipe.fp8_dpa: + if fp8 and fp8_meta["recipe"].fp8_dpa: if use_fused_attention and ( device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0) ): @@ -831,7 +830,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_unfused_attention = False if context_parallel and (use_flash_attention_2 or use_flash_attention_3): if FlashAttentionUtils.is_installed or FlashAttentionUtils.v3_is_installed: - if fp8 and fp8_recipe.fp8_dpa: + if fp8 and fp8_meta["recipe"].fp8_dpa: logger.debug( "Disabling FlashAttention as it does not support context parallelism with FP8" ) @@ -888,13 +887,13 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt " bias for THD format" 
) use_fused_attention = False - elif fp8 and fp8_recipe.fp8_dpa and qkv_format == "thd": + elif fp8 and fp8_meta["recipe"].fp8_dpa and qkv_format == "thd": logger.debug( "Disabling FusedAttention as it does not support context parallelism with FP8" " attention and THD format" ) use_fused_attention = False - elif fp8 and fp8_recipe.fp8_dpa and core_attention_bias_type != "no_bias": + elif fp8 and fp8_meta["recipe"].fp8_dpa and core_attention_bias_type != "no_bias": logger.debug( "Disabling FusedAttention as it does not support context parallelism with FP8" " attention and bias" @@ -986,7 +985,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): if ( fp8 - and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha) + and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha) and (device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0)) ): logger.debug( @@ -1098,8 +1097,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if use_fused_attention: q_type = TE_DType[qkv_dtype] kv_type = q_type - if fp8 and fp8_recipe.fp8_dpa: - q_type = get_fp8_te_dtype(fp8_recipe, fprop_tensor=True) + if fp8 and fp8_meta["recipe"].fp8_dpa: + q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) kv_type = q_type fused_attention_backend = tex.get_fused_attn_backend( is_training, From 386914561417cd80fe2404fedd1f77842a78e0b2 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 30 Mar 2026 18:36:19 -0700 Subject: [PATCH 157/172] remove use of amax for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 46 +++++++++++-------- transformer_engine/pytorch/csrc/quantizer.cpp | 14 +++--- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu 
b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index c8e405e87e..8d2717601e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1947,7 +1947,7 @@ void fused_attn_fp8_fwd_impl_v1( .set_dim({b, h, s_q, d_v}) .set_stride(o_strides) .set_data_type(o_tensor_type); - amax_o->set_output(true) + amax_o->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); @@ -2024,7 +2024,6 @@ void fused_attn_fp8_fwd_impl_v1( {descale_v, devPtrDescaleV}, {attn_scale, &scaling_factor}, {O, devPtrO}, - {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; if (is_delayed_scaling) { @@ -2034,6 +2033,7 @@ void fused_attn_fp8_fwd_impl_v1( variant_pack[descale_s] = devPtrDescaleS; variant_pack[scale_s] = devPtrScaleS; variant_pack[amax_s] = devPtrAmaxS; + variant_pack[amax_o] = devPtrAmaxO; } /* if (is_bias) { @@ -2519,15 +2519,15 @@ void fused_attn_fp8_bwd_impl_v1( .set_dim({b, hg, s_kv, d_v}) .set_stride(dv_strides) .set_data_type(dqkv_tensor_type); - amax_dQ->set_output(true) + amax_dQ->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - amax_dK->set_output(true) + amax_dK->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - amax_dV->set_output(true) + amax_dV->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); @@ -2628,9 +2628,6 @@ void fused_attn_fp8_bwd_impl_v1( {dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV}, - {amax_dQ, devPtrAmaxdQ}, - {amax_dK, devPtrAmaxdK}, - {amax_dV, devPtrAmaxdV}, }; if (is_delayed_scaling || is_current_scaling) { variant_pack[descale_s] = devPtrDescaleS; @@ -2638,6 +2635,9 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[scale_s] = devPtrScaleS; variant_pack[scale_dP] = devPtrScaledP; variant_pack[amax_dP] = devPtrAmaxdP; + variant_pack[amax_dQ] = 
devPtrAmaxdQ; + variant_pack[amax_dK] = devPtrAmaxdK; + variant_pack[amax_dV] = devPtrAmaxdV; } if (is_delayed_scaling || (is_current_scaling && !is_O_in_F16)) { variant_pack[descale_o] = devPtrDescaleO; @@ -2724,7 +2724,6 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrK = input_K->data.dptr; devPtrDescaleK = input_K->scale_inv.dptr; devPtrO = output_O->data.dptr; - devPtrAmaxO = output_O->amax.dptr; if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { devPtrV = input_V->data.dptr; devPtrDescaleV = input_V->scale_inv.dptr; @@ -2732,6 +2731,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrAmaxS = input_output_S->amax.dptr; devPtrScaleS = input_output_S->scale.dptr; devPtrDescaleS = input_output_S->scale_inv.dptr; + devPtrAmaxO = output_O->amax.dptr; } else if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { devPtrV = input_V->columnwise_data.dptr; devPtrDescaleV = input_V->columnwise_scale_inv.dptr; @@ -2881,11 +2881,14 @@ void fused_attn_fp8_bwd( void* devPtrM = input_M->data.dptr; void* devPtrZInv = (input_ZInv != nullptr) ? 
input_ZInv->data.dptr : nullptr; - void* devPtrScaleS = input_S->scale.dptr; - void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdP = input_output_dP->amax.dptr; - void* devPtrScaledP = input_output_dP->scale.dptr; - void* devPtrDescaledP = input_output_dP->scale_inv.dptr; + void *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr, *devPtrAmaxdP = nullptr, *devPtrScaledP = nullptr, *devPtrDescaledP = nullptr; + if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + devPtrScaleS = input_S->scale.dptr; + devPtrDescaleS = input_S->scale_inv.dptr; + devPtrAmaxdP = input_output_dP->amax.dptr; + devPtrScaledP = input_output_dP->scale.dptr; + devPtrDescaledP = input_output_dP->scale_inv.dptr; + } void* devPtrSoftmaxOffset = nullptr; void* devPtrdSoftmaxOffset = nullptr; @@ -2897,12 +2900,15 @@ void fused_attn_fp8_bwd( void* devPtrdQ = output_dQ->data.dptr; void* devPtrdK = output_dK->data.dptr; void* devPtrdV = output_dV->data.dptr; - void* devPtrAmaxdQ = output_dQ->amax.dptr; - void* devPtrAmaxdK = output_dK->amax.dptr; - void* devPtrAmaxdV = output_dV->amax.dptr; - void* devPtrScaledQ = output_dQ->scale.dptr; - void* devPtrScaledK = output_dK->scale.dptr; - void* devPtrScaledV = output_dV->scale.dptr; + void *devPtrAmaxdQ = nullptr, *devPtrAmaxdK = nullptr, *devPtrAmaxdV = nullptr, *devPtrScaledQ = nullptr, *devPtrScaledK = nullptr, *devPtrScaledV = nullptr; + if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + devPtrAmaxdQ = output_dQ->amax.dptr; + devPtrAmaxdK = output_dK->amax.dptr; + devPtrAmaxdV = output_dV->amax.dptr; + devPtrScaledQ = output_dQ->scale.dptr; + devPtrScaledK = output_dK->scale.dptr; + devPtrScaledV = output_dV->scale.dptr; + } void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 07798843b4..9610880093 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ 
b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1447,17 +1447,17 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve std::pair MXFP8Quantizer::create_unquantized_tensor_with_amax( const std::vector& shape, DType dtype, std::optional data) { - static std::once_flag once; - static at::Tensor amax_tensor; - std::call_once(once, []() { - amax_tensor = at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); - }); + // static std::once_flag once; + // static at::Tensor amax_tensor; + // std::call_once(once, []() { + // amax_tensor = at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + // }); auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) : NoneQuantizer(py::none()).create_tensor(shape, dtype); TensorWrapper out_cpp = std::move(out.first); py::object out_py = std::move(out.second); - out_cpp.set_amax(amax_tensor.data_ptr(), GetTransformerEngineDType(amax_tensor.scalar_type()), - getTensorShape(amax_tensor)); + // out_cpp.set_amax(amax_tensor.data_ptr(), GetTransformerEngineDType(amax_tensor.scalar_type()), + // getTensorShape(amax_tensor)); return {std::move(out_cpp), std::move(out_py)}; } From 641c05ccd645577a5b64ae5fd20bccc85a05d585 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 01:38:50 +0000 Subject: [PATCH 158/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/fused_attn/fused_attn_fp8.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 8d2717601e..70503caba8 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2881,7 +2881,8 @@ void fused_attn_fp8_bwd( void* devPtrM = 
input_M->data.dptr; void* devPtrZInv = (input_ZInv != nullptr) ? input_ZInv->data.dptr : nullptr; - void *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr, *devPtrAmaxdP = nullptr, *devPtrScaledP = nullptr, *devPtrDescaledP = nullptr; + void *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr, *devPtrAmaxdP = nullptr, + *devPtrScaledP = nullptr, *devPtrDescaledP = nullptr; if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { devPtrScaleS = input_S->scale.dptr; devPtrDescaleS = input_S->scale_inv.dptr; @@ -2900,7 +2901,8 @@ void fused_attn_fp8_bwd( void* devPtrdQ = output_dQ->data.dptr; void* devPtrdK = output_dK->data.dptr; void* devPtrdV = output_dV->data.dptr; - void *devPtrAmaxdQ = nullptr, *devPtrAmaxdK = nullptr, *devPtrAmaxdV = nullptr, *devPtrScaledQ = nullptr, *devPtrScaledK = nullptr, *devPtrScaledV = nullptr; + void *devPtrAmaxdQ = nullptr, *devPtrAmaxdK = nullptr, *devPtrAmaxdV = nullptr, + *devPtrScaledQ = nullptr, *devPtrScaledK = nullptr, *devPtrScaledV = nullptr; if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { devPtrAmaxdQ = output_dQ->amax.dptr; devPtrAmaxdK = output_dK->amax.dptr; From 59db112ce8ccd5d1fef55d4d21f4fed26f286158 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 30 Mar 2026 19:22:26 -0700 Subject: [PATCH 159/172] add o_format/do_format/dqkv_layout to cache indicators for fp8 and f16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 10 +-- .../fused_attn_f16_arbitrary_seqlen.cu | 66 +++++++++++-------- .../fused_attn_f16_arbitrary_seqlen.h | 4 +- .../common/fused_attn/fused_attn_fp8.cu | 6 ++ transformer_engine/common/fused_attn/utils.h | 9 ++- .../include/transformer_engine/fused_attn.h | 3 + 6 files changed, 63 insertions(+), 35 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index e3e8040c45..5f3fe5288c 100644 --- 
a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -760,8 +760,9 @@ void nvte_fused_attn_fwd( fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, + return_max_logit, attn_scale, dropout, qkv_layout, o_format, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, + input_V, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); @@ -869,8 +870,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, + qkv_layout, o_format, do_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, input_K, + input_V, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu 
b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index eed6740740..78c5b5a027 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -54,7 +54,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, @@ -80,8 +81,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); - NVTE_QKV_Format q_format = nvte_get_q_format(layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -89,7 +90,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const int sm_arch_ = cuda::sm_arch(device_id); bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); bool 
is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); if (is_paged_kv) { NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); @@ -135,7 +136,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( scaling_factor, is_training, dropout_probability, - layout, + qkv_layout, + o_format, + NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Layout_NOT_SET, bias_type, mask_type, softmax_type, @@ -202,17 +206,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); if (is_paged_kv) { generateMatrixStrides(num_pages_k, hg, page_size_k, page_size_v, d_qk, k_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(num_pages_v, hg, page_size_k, page_size_v, d_v, v_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); } else { - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); } @@ -368,7 +372,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options)); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_O_Matrix); O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); if (is_ragged_q) { @@ -513,7 +517,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( (static_cast(is_ragged_q) + 
static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; } - const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); cu_seqlens_padded_to_offsets<<>>( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, @@ -551,7 +555,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, @@ -578,8 +583,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); - NVTE_QKV_Format q_format = nvte_get_q_format(layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -587,7 +592,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const int sm_arch_ = cuda::sm_arch(device_id); bool use_ragged_stats = is_ragged_q && 
cudnn_runtime_version >= 90600 && sm_arch_ != 120; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); if (is_paged_kv) { NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); @@ -632,7 +637,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( scaling_factor, true, dropout_probability, - layout, + qkv_layout, + o_format, + do_format, + dqkv_layout, bias_type, mask_type, softmax_type, @@ -703,13 +711,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_O_Matrix); q = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -1024,7 +1032,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; } - const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); cu_seqlens_padded_to_offsets<<>>( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, 
devOffsetsK, @@ -1067,7 +1075,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, @@ -1202,8 +1211,9 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, - is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, + is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, + devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, @@ -1228,6 +1238,7 @@ void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t 
num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, @@ -1300,8 +1311,9 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, + p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, + devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 4dd7f3d1da..09234f12ac 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -24,7 +24,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, 
float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, @@ -36,6 +37,7 @@ void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 70503caba8..6fa366dc2c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1715,6 +1715,9 @@ void fused_attn_fp8_fwd_impl_v1( is_training, dropout_probability, qkv_layout, + o_format, + NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Layout_NOT_SET, bias_type, mask_type, softmax_type, @@ -2143,6 +2146,9 @@ void fused_attn_fp8_bwd_impl_v1( true, dropout_probability, qkv_layout, + o_format, + do_format, + dqkv_layout, bias_type, mask_type, softmax_type, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 9c4fbaa7c8..0f88b06f7f 100644 --- a/transformer_engine/common/fused_attn/utils.h 
+++ b/transformer_engine/common/fused_attn/utils.h @@ -293,7 +293,10 @@ struct FADescriptor_v1 { float attnScale; bool isTraining; float dropoutProbability; - NVTE_QKV_Layout layout; + NVTE_QKV_Layout qkv_layout; + NVTE_QKV_Format o_format; + NVTE_QKV_Format do_format; + NVTE_QKV_Layout dqkv_layout; NVTE_Bias_Type bias_type; NVTE_Mask_Type mask_type; NVTE_Softmax_Type softmax_type; @@ -310,14 +313,14 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, - bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type, + bias_skv, attnScale, isTraining, dropoutProbability, qkv_layout, o_format, do_format, dqkv_layout, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, return_max_logit) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, - rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, + rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.qkv_layout, rhs.o_format, rhs.do_format, rhs.dqkv_layout, rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 02c351ced7..65cdaca7d0 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -53,6 +53,7 @@ enum 
NVTE_QKV_Layout { NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ NVTE_BHSD_BHSD_BHSD = 25, /*!< BHSD_BHSD_BHSD layout */ + NVTE_QKV_Layout_NOT_SET, /*!< Not set */ }; /*! \enum NVTE_QKV_Layout_Group @@ -95,6 +96,8 @@ enum NVTE_QKV_Format { NVTE_THD_2SBHD = 6, /*! BHSD QKV format, e.g. BHSD_BHSD_BHSD */ NVTE_BHSD = 7, + /*! Not set */ + NVTE_QKV_Format_NOT_SET, }; /*! \enum NVTE_Bias_Type From f1d1809cebbd2546e43cbb9531f1624462b5b0fd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 02:23:28 +0000 Subject: [PATCH 160/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 11 +++-- .../fused_attn_f16_arbitrary_seqlen.cu | 41 +++++++++---------- .../fused_attn_f16_arbitrary_seqlen.h | 14 +++---- transformer_engine/common/fused_attn/utils.h | 19 +++++---- 4 files changed, 41 insertions(+), 44 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 5f3fe5288c..5498c601a6 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -762,8 +762,7 @@ void nvte_fused_attn_fwd( page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, o_format, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, - input_V, - input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_V, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, 
handle); #else @@ -872,10 +871,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, input_K, - input_V, input_O, input_dO, - input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, - output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + input_V, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, + output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, + handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 78c5b5a027..f8c3992587 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -55,11 +55,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, - void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void 
*devPtrO, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, + void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, + void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -1076,13 +1075,13 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const 
Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1213,11 +1212,10 @@ void fused_attn_arbitrary_seqlen_fwd( page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, - devPtrK, - devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, + devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1313,11 +1311,10 @@ void fused_attn_arbitrary_seqlen_bwd( max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, - devPtrV, devPtrO, - devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, 
handle); + devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, + devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 09234f12ac..19dc94e755 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -25,13 +25,13 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor 
*cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 0f88b06f7f..b600261f40 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -313,18 +313,19 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, - bias_skv, attnScale, isTraining, dropoutProbability, qkv_layout, o_format, do_format, dqkv_layout, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, - deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, - dqkv_tensor_type, return_max_logit) < + bias_skv, attnScale, isTraining, dropoutProbability, qkv_layout, o_format, + do_format, dqkv_layout, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, deterministic, bias_type, + qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, + return_max_logit) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, - rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.qkv_layout, rhs.o_format, rhs.do_format, rhs.dqkv_layout, - rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, - rhs.bottom_right_diagonal, rhs.deterministic, 
rhs.bias_type, - rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, - rhs.dqkv_tensor_type, rhs.return_max_logit); + rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.qkv_layout, + rhs.o_format, rhs.do_format, rhs.dqkv_layout, rhs.mask_type, rhs.softmax_type, + rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, + rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, + rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.return_max_logit); } }; From c49190819a18e78b40ab3022fb28e04ea0ff0d05 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 31 Mar 2026 14:40:16 -0700 Subject: [PATCH 161/172] enable sink attn + FP8 in CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention_with_cp.py | 2 -- .../attention/dot_product_attention/context_parallel.py | 5 ----- 2 files changed, 7 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 693c2cf960..1d77d042ea 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -330,8 +330,6 @@ def test_cp_with_fused_attention( f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" 
) - if config.softmax_type != "vanilla" and dtype == "fp8": - pytest.skip("No support for non-vanilla softmax with FP8 attention!") if config.softmax_type != "vanilla" and cp_comm_type != "a2a": pytest.skip(f"No support for non-vanilla softmax with cp_comm_type={cp_comm_type}!") if ( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 142a375007..e76ebf66b0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4647,11 +4647,6 @@ def attn_forward_func_with_cp( "all_gather", ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" - if fp8 and fp8_meta is not None: - if fp8_meta["recipe"].fp8_dpa: - assert ( - softmax_type == "vanilla" - ), f"Context parallelism does not support {softmax_type=} with FP8 attention!" assert ( softmax_type == "vanilla" or use_fused_attention ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!" 
From 6af310575f3f13080a3cadc25105cfb06a48485a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 3 Apr 2026 15:20:03 -0700 Subject: [PATCH 162/172] update FE to GH v1.22.0 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 +-- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index 8c7646c00d..4b188d6bb1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,8 +3,7 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git - branch = develop + url = https://github.com/NVIDIA/cudnn-frontend.git [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 28ebf81b5e..97f6cb3b88 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 28ebf81b5e19b92045d1af0fd4c0b9c4599c2b53 +Subproject commit 97f6cb3b88cacff507cca1280db5650a457d92b3 From 508044bdbc754f89261aca65932b53ab12183852 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:44:35 -0700 Subject: [PATCH 163/172] fix for inconsistent kwarg name in permute to grouped tensor Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/attention/dot_product_attention/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 8ee16649e7..d6171d04f5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2338,10 +2338,10 @@ def forward( query: 
torch.Tensor, key: torch.Tensor, value: torch.Tensor, - input_layout: str = "bshd_bshd_bshd", + original_layout: str = "bshd_bshd_bshd", ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # pylint: disable=missing-function-docstring - ctx.original_layout = QKVLayout[input_layout] + ctx.original_layout = QKVLayout[original_layout] return tex.permute_to_grouped_tensor_fwd(query, key, value, ctx.original_layout) @staticmethod From 2532a50e829144bee290fc94acb8f3f154a62ea9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:45:35 -0700 Subject: [PATCH 164/172] add TMA permute Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/flash_attn.cu | 543 ++++++++++++------ 1 file changed, 366 insertions(+), 177 deletions(-) diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index c30cc3d2d9..8f03c8ed0c 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -4,7 +4,13 @@ * See LICENSE for license information. 
************************************************************************/ +#include +#include + #include "../common.h" +#include "../util/cuda_driver.h" +#include "../util/ptx.cuh" +#include "../utils.cuh" #include "transformer_engine/fused_attn.h" namespace transformer_engine { @@ -23,6 +29,97 @@ constexpr int nvec128 = sizeof(uint4) / type_size; constexpr int load_size = warp_size * nvec; constexpr int block_size = 512; +// TMA permute kernel configuration +constexpr int tma_permute_threads = 32; +constexpr int tma_permute_s_tile = 64; + +// ---- 4D TMA PTX wrappers ---- + +__device__ __forceinline__ void cp_async_bulk_tensor_4d_global_to_shared( + void *dst_shmem, const CUtensorMap *tensor_map, + uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, + uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t dst = __cvta_generic_to_shared(dst_shmem); + uint32_t bar = __cvta_generic_to_shared(mbar); + asm volatile( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3, %4, %5}], [%6];" + ::"r"(dst), "l"(tensor_map), + "r"(c0), "r"(c1), "r"(c2), "r"(c3), + "r"(bar) + : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_global_to_shared requires SM 10.0+."); +#endif +} + +__device__ __forceinline__ void cp_async_bulk_tensor_4d_shared_to_global( + const CUtensorMap *tensor_map, + uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, + void *src_shmem) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + uint32_t src = __cvta_generic_to_shared(src_shmem); + asm volatile( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" + " [%0, {%1, %2, %3, %4}], [%5];" + ::"l"(tensor_map), + "r"(c0), "r"(c1), "r"(c2), "r"(c3), + "r"(src) + : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_shared_to_global requires SM 9.0+."); +#endif +} + +// ---- Host-side 4D tensor map creation ---- +// +// Creates a 4D TMA descriptor for a densely-packed tensor whose logical +// 
dimensions (innermost-first) are [dim0, dim1, dim2, dim3]. +// +// The box (tile) copied per TMA instruction is [box0, box1, box2, box3]. + +static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dtype, + uint64_t dim0, uint64_t dim1, uint64_t dim2, uint64_t dim3, + uint32_t box0, uint32_t box1, uint32_t box2, uint32_t box3) { + cuda_driver::ensure_context_exists(); + static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { + void *ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); + return reinterpret_cast(ptr); + }(); + + CUtensorMapDataType tma_dtype; + size_t elem_bytes; + switch (dtype) { + case DType::kFloat16: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + elem_bytes = 2; + break; + case DType::kBFloat16: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + elem_bytes = 2; + break; + default: + NVTE_ERROR("create_4D_tensor_map: unsupported dtype"); + } + + constexpr uint32_t rank = 4; + uint64_t size[rank] = {dim0, dim1, dim2, dim3}; + uint64_t stride[rank - 1] = { + dim0 * elem_bytes, + dim0 * dim1 * elem_bytes, + dim0 * dim1 * dim2 * elem_bytes, + }; + uint32_t boxSize[rank] = {box0, box1, box2, box3}; + uint32_t elemStride[rank] = {1, 1, 1, 1}; + + NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( + &tensorMap, tma_dtype, rank, dataPtr, size, stride, boxSize, elemStride, + CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, + CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA)); +} + template __launch_bounds__(block_size) __global__ void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z, @@ -140,168 +237,235 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream NVTE_CHECK_CUDA(cudaGetLastError()); } +// ---- TMA helpers for strided (BSHD/SBHD) tensors ---- +// +// Strided BSHD: TMA dims [D, H, S, B], coords [0, h, s, b] +// Strided SBHD: TMA dims [D, H, B, S], coords [0, h, b, s] + template 
-__launch_bounds__(1024) __global__ - void permute_to_grouped_tensor_fwd_kernel(const T *__restrict__ q, const T *__restrict__ k, - const T *__restrict__ v, T *__restrict__ q_out, - T *__restrict__ k_out, T *__restrict__ v_out, - size_t b, size_t s_q, size_t h_q, size_t d_qk, - size_t s_kv, size_t h_kv, size_t d_v, - unsigned int permute_s_splits) { - const int which_tensor = blockIdx.z; - const T *__restrict__ tensor_in = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); - T *__restrict__ tensor_out = which_tensor == 0 ? q_out : (which_tensor == 1 ? k_out : v_out); - const size_t Sdim = which_tensor == 0 ? s_q : s_kv; - const size_t Hdim = which_tensor == 0 ? h_q : h_kv; - const size_t Ddim = which_tensor == 0 ? d_qk : (which_tensor == 1 ? d_qk : d_v); +__device__ __forceinline__ void issue_tma_load_strided( + T *smem_buf, const CUtensorMap *tma, + size_t h_i, size_t s_tile, size_t b_i, + uint64_t *mbar, size_t tile_bytes) { + ptx::mbarrier_arrive_expect_tx(mbar, static_cast(tile_bytes)); + if constexpr (kIsBshdBshdBshd) { + cp_async_bulk_tensor_4d_global_to_shared( + smem_buf, tma, + 0, static_cast(h_i), + static_cast(s_tile), static_cast(b_i), + mbar); + } else { + cp_async_bulk_tensor_4d_global_to_shared( + smem_buf, tma, + 0, static_cast(h_i), + static_cast(b_i), static_cast(s_tile), + mbar); + } +} + +template +__device__ __forceinline__ void issue_tma_store_strided( + const CUtensorMap *tma, T *smem_buf, + size_t h_i, size_t s_tile, size_t b_i) { + if constexpr (kIsBshdBshdBshd) { + cp_async_bulk_tensor_4d_shared_to_global( + tma, + 0, static_cast(h_i), + static_cast(s_tile), static_cast(b_i), + smem_buf); + } else { + cp_async_bulk_tensor_4d_shared_to_global( + tma, + 0, static_cast(h_i), + static_cast(b_i), static_cast(s_tile), + smem_buf); + } + ptx::cp_async_bulk_commit_group(); +} + +__device__ __forceinline__ void st_global_cs_uint4(uint4 *ptr, uint4 val) { + asm volatile("st.global.cs.v4.b32 [%0], {%1, %2, %3, %4};" + :: "l"(ptr), "r"(val.x), 
"r"(val.y), "r"(val.z), "r"(val.w) + : "memory"); +} + +// ---- Forward: BSHD/SBHD → BHSD ---- +// +// TMA load from strided input → smem → non-temporal stores to contiguous output. + +template +__launch_bounds__(tma_permute_threads) __global__ + void permute_to_grouped_tensor_fwd_kernel( + const __grid_constant__ CUtensorMap tma_q_in, + const __grid_constant__ CUtensorMap tma_k_in, + const __grid_constant__ CUtensorMap tma_v_in, + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, + size_t b, size_t s_q, size_t h_q, size_t d_qk, + size_t s_kv, size_t h_kv, size_t d_v, + unsigned int permute_s_splits) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + const int which = blockIdx.z; + const CUtensorMap *tma_in = which == 0 ? &tma_q_in : (which == 1 ? &tma_k_in : &tma_v_in); + T *__restrict__ tensor_out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t Sdim = which == 0 ? s_q : s_kv; + const size_t Hdim = which == 0 ? h_q : h_kv; + const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); const size_t h_grid = h_q > h_kv ? h_q : h_kv; const size_t b_i = static_cast(blockIdx.x) / h_grid; const size_t h_i = static_cast(blockIdx.x) % h_grid; if (b_i >= b) return; - if (which_tensor == 0) { - if (h_i >= h_q) return; - } else { - if (h_i >= h_kv) return; - } - if (Ddim % static_cast(nvec) != 0) return; + if (which == 0) { if (h_i >= h_q) return; } + else { if (h_i >= h_kv) return; } const unsigned int s_part = blockIdx.y; const size_t s_begin = (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); const size_t s_end = (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); - const size_t S_chunk = s_end - s_begin; + if (s_begin >= s_end) return; - const size_t in_base = kIsBshdBshdBshd ? 
b_i * Sdim * Hdim * Ddim : b_i * Hdim * Ddim; const size_t out_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; - const bool use_vec128 = - (Ddim % static_cast(nvec128) == 0) && - ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && - ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); - - if (use_vec128) { - const size_t d_vec = Ddim / static_cast(nvec128); - const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; - w += static_cast(blockDim.x)) { - const size_t s_local = w / d_vec; - const size_t s_i = s_begin + s_local; - const size_t v = w % d_vec; - const size_t d_off = v * static_cast(nvec128); - - const T *__restrict__ in_ptr; - if constexpr (kIsBshdBshdBshd) { - in_ptr = tensor_in + in_base + s_i * Hdim * Ddim + h_i * Ddim + d_off; - } else { - in_ptr = tensor_in + s_i * b * Hdim * Ddim + in_base + h_i * Ddim + d_off; - } - T *__restrict__ out_ptr = tensor_out + out_base + s_i * Ddim + d_off; - *reinterpret_cast *>(out_ptr) = - *reinterpret_cast *>(in_ptr); + + extern __shared__ __align__(128) char smem_raw[]; + T *smem = reinterpret_cast(smem_raw); + + __shared__ __align__(8) uint64_t mbar; + const bool is_leader = (threadIdx.x == 0); + + if (is_leader) { + ptx::mbarrier_init(&mbar, static_cast(blockDim.x)); + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + constexpr size_t S_TILE = tma_permute_s_tile; + const uint32_t tile_bytes = static_cast(S_TILE * Ddim * sizeof(T)); + int parity = 0; + + for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { + const size_t tile_rows = min(S_TILE, s_end - s_tile); + + if (is_leader) { + issue_tma_load_strided( + smem, tma_in, h_i, s_tile, b_i, &mbar, tile_bytes); + } else { + ptx::mbarrier_arrive(&mbar); } - } else { - const size_t d_vec = Ddim / static_cast(nvec); - const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; - w += static_cast(blockDim.x)) { - const size_t s_local = w / d_vec; - 
const size_t s_i = s_begin + s_local; - const size_t v = w % d_vec; - const size_t d_off = v * static_cast(nvec); - - const T *__restrict__ in_ptr; - if constexpr (kIsBshdBshdBshd) { - in_ptr = tensor_in + in_base + s_i * Hdim * Ddim + h_i * Ddim + d_off; - } else { - in_ptr = tensor_in + s_i * b * Hdim * Ddim + in_base + h_i * Ddim + d_off; - } - T *__restrict__ out_ptr = tensor_out + out_base + s_i * Ddim + d_off; - *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + + ptx::mbarrier_wait_parity(&mbar, parity); + parity ^= 1; + + T *__restrict__ out_ptr = tensor_out + out_base + s_tile * Ddim; + const size_t total_elems = tile_rows * Ddim; + constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); + + for (size_t i = threadIdx.x * vec_elems; i < total_elems; + i += static_cast(blockDim.x) * vec_elems) { + uint4 v = *reinterpret_cast(smem + i); + st_global_cs_uint4(reinterpret_cast(out_ptr + i), v); } + + __syncthreads(); } + + if (is_leader) { + ptx::mbarrier_invalid(&mbar); + } +#endif } +// ---- Backward: BHSD → BSHD/SBHD ---- +// +// Vectorized loads from contiguous input → smem → TMA store to strided output. 
+ template -__launch_bounds__(1024) __global__ void permute_to_grouped_tensor_bwd_kernel( - const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, - T *__restrict__ q, T *__restrict__ k, T *__restrict__ v, size_t b, size_t s_q, size_t h_q, - size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits) { - const int which_tensor = blockIdx.z; +__launch_bounds__(tma_permute_threads) __global__ + void permute_to_grouped_tensor_bwd_kernel( + const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, + const __grid_constant__ CUtensorMap tma_q_out, + const __grid_constant__ CUtensorMap tma_k_out, + const __grid_constant__ CUtensorMap tma_v_out, + size_t b, size_t s_q, size_t h_q, size_t d_qk, + size_t s_kv, size_t h_kv, size_t d_v, + unsigned int permute_s_splits) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + const int which = blockIdx.z; const T *__restrict__ tensor_in = - which_tensor == 0 ? grad_q : (which_tensor == 1 ? grad_k : grad_v); - T *__restrict__ tensor_out = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); - const size_t Sdim = which_tensor == 0 ? s_q : s_kv; - const size_t Hdim = which_tensor == 0 ? h_q : h_kv; - const size_t Ddim = which_tensor == 0 ? d_qk : (which_tensor == 1 ? d_qk : d_v); + which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); + const CUtensorMap *tma_out = which == 0 ? &tma_q_out : (which == 1 ? &tma_k_out : &tma_v_out); + const size_t Sdim = which == 0 ? s_q : s_kv; + const size_t Hdim = which == 0 ? h_q : h_kv; + const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); const size_t h_grid = h_q > h_kv ? 
h_q : h_kv; const size_t b_i = static_cast(blockIdx.x) / h_grid; const size_t h_i = static_cast(blockIdx.x) % h_grid; if (b_i >= b) return; - if (which_tensor == 0) { - if (h_i >= h_q) return; - } else { - if (h_i >= h_kv) return; - } - if (Ddim % static_cast(nvec) != 0) return; + if (which == 0) { if (h_i >= h_q) return; } + else { if (h_i >= h_kv) return; } const unsigned int s_part = blockIdx.y; const size_t s_begin = (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); const size_t s_end = (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); - const size_t S_chunk = s_end - s_begin; + if (s_begin >= s_end) return; const size_t in_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; - const size_t out_base = - kIsBshdBshdBshd ? b_i * Sdim * Hdim * Ddim + h_i * Ddim : b_i * Hdim * Ddim + h_i * Ddim; - const bool use_vec128 = - (Ddim % static_cast(nvec128) == 0) && - ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && - ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); - - if (use_vec128) { - const size_t d_vec = Ddim / static_cast(nvec128); - const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; - w += static_cast(blockDim.x)) { - const size_t s_local = w / d_vec; - const size_t s_i = s_begin + s_local; - const size_t v = w % d_vec; - const size_t d_off = v * static_cast(nvec128); - - const T *__restrict__ in_ptr = tensor_in + in_base + s_i * Ddim + d_off; - T *__restrict__ out_ptr; - if constexpr (kIsBshdBshdBshd) { - out_ptr = tensor_out + out_base + s_i * Hdim * Ddim + d_off; - } else { - out_ptr = tensor_out + s_i * b * Hdim * Ddim + out_base + d_off; - } - *reinterpret_cast *>(out_ptr) = - *reinterpret_cast *>(in_ptr); + + extern __shared__ __align__(128) char smem_raw[]; + T *smem = reinterpret_cast(smem_raw); + + constexpr size_t S_TILE = tma_permute_s_tile; + constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); + + for (size_t s_tile = s_begin; s_tile < s_end; s_tile += 
S_TILE) { + const size_t tile_rows = min(S_TILE, s_end - s_tile); + + const T *__restrict__ in_ptr = tensor_in + in_base + s_tile * Ddim; + const size_t total_elems = tile_rows * Ddim; + + for (size_t i = threadIdx.x * vec_elems; i < total_elems; + i += static_cast(blockDim.x) * vec_elems) { + *reinterpret_cast(smem + i) = + *reinterpret_cast(in_ptr + i); } - } else { - const size_t d_vec = Ddim / static_cast(nvec); - const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; - w += static_cast(blockDim.x)) { - const size_t s_local = w / d_vec; - const size_t s_i = s_begin + s_local; - const size_t v = w % d_vec; - const size_t d_off = v * static_cast(nvec); - - const T *__restrict__ in_ptr = tensor_in + in_base + s_i * Ddim + d_off; - T *__restrict__ out_ptr; - if constexpr (kIsBshdBshdBshd) { - out_ptr = tensor_out + out_base + s_i * Hdim * Ddim + d_off; - } else { - out_ptr = tensor_out + s_i * b * Hdim * Ddim + out_base + d_off; - } - *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + + if (threadIdx.x == 0) { + issue_tma_store_strided(tma_out, smem, h_i, s_tile, b_i); } + + ptx::cp_async_bulk_wait_group(); + __syncthreads(); + } +#endif +} + +// Helper: create a 4D TMA descriptor for the strided (BSHD or SBHD) tensor. 
+// +// For BSHD [B, S, H, D]: TMA dims [D, H, S, B], box [D, 1, S_TILE, 1] +// For SBHD [S, B, H, D]: TMA dims [D, H, B, S], box [D, 1, 1, S_TILE] +static void create_strided_tensor_map(CUtensorMap &map, void *ptr, DType dtype, + size_t b, size_t s, size_t h, size_t d, + bool is_bshd) { + if (is_bshd) { + create_4D_tensor_map(map, ptr, dtype, + static_cast(d), static_cast(h), + static_cast(s), static_cast(b), + static_cast(d), 1, + static_cast(tma_permute_s_tile), 1); + } else { + create_4D_tensor_map(map, ptr, dtype, + static_cast(d), static_cast(h), + static_cast(b), static_cast(s), + static_cast(d), 1, 1, + static_cast(tma_permute_s_tile)); } } @@ -309,45 +473,57 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T Tensor v_out, NVTE_QKV_Layout original_layout, cudaStream_t stream) { using namespace transformer_engine; - size_t b = 0, s_q = 0, s_kv = 0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; - b = q_out.shape()[0]; - h_q = q_out.shape()[1]; - s_q = q_out.shape()[2]; - d_qk = q_out.shape()[3]; - h_kv = k_out.shape()[1]; - s_kv = k_out.shape()[2]; - d_v = v_out.shape()[3]; - - NVTE_CHECK(d_qk % nvec == 0 && d_v % nvec == 0, - "permute_to_grouped_tensor_fwd: head dim must be divisible by vector width."); - // Split S across grid.y; work out permute_s_splits so S_chunk >= threads - const int threads = 1024; + const size_t b = q_out.shape()[0]; + const size_t h_q = q_out.shape()[1]; + const size_t s_q = q_out.shape()[2]; + const size_t d_qk = q_out.shape()[3]; + const size_t h_kv = k_out.shape()[1]; + const size_t s_kv = k_out.shape()[2]; + const size_t d_v = v_out.shape()[3]; + + NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, + "permute_to_grouped_tensor_fwd: head dim must be divisible by ", nvec128, "."); + + const bool is_bshd = (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + + alignas(64) CUtensorMap tma_q_in{}, tma_k_in{}, tma_v_in{}; + create_strided_tensor_map(tma_q_in, q.data.dptr, q.dtype(), b, s_q, h_q, 
d_qk, is_bshd); + create_strided_tensor_map(tma_k_in, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, is_bshd); + create_strided_tensor_map(tma_v_in, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, is_bshd); + const size_t s_min = std::min(s_q, s_kv); const unsigned int permute_s_splits = - std::max(1u, static_cast(s_min / static_cast(threads))); + std::max(1u, static_cast(s_min / static_cast(tma_permute_threads))); const size_t h_grid = std::max(h_q, h_kv); - dim3 grid(b * h_grid, permute_s_splits, 3); + dim3 grid(static_cast(b * h_grid), permute_s_splits, 3); - if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + const size_t d_max = std::max(d_qk, d_v); + const size_t smem_bytes = tma_permute_s_tile * d_max * sizeof(uint16_t); + + if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( q.dtype(), dtype, - permute_to_grouped_tensor_fwd_kernel<<>>( - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - reinterpret_cast(q_out.data.dptr), reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, - permute_s_splits);); + auto kernel = permute_to_grouped_tensor_fwd_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( + tma_q_in, tma_k_in, tma_v_in, + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( q.dtype(), dtype, - permute_to_grouped_tensor_fwd_kernel<<>>( - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - reinterpret_cast(q_out.data.dptr), reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, - permute_s_splits);); + auto kernel = permute_to_grouped_tensor_fwd_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( + tma_q_in, tma_k_in, tma_v_in, + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -355,44 +531,57 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, Tensor q, Tensor k, Tensor v, NVTE_QKV_Layout original_layout, cudaStream_t stream) { using namespace transformer_engine; - size_t b = 0, s_q = 0, s_kv = 0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; - b = grad_q.shape()[0]; - h_q = grad_q.shape()[1]; - s_q = grad_q.shape()[2]; - d_qk = grad_q.shape()[3]; - h_kv = grad_k.shape()[1]; - s_kv = grad_k.shape()[2]; - d_v = grad_v.shape()[3]; - - NVTE_CHECK(d_qk % nvec == 0 && d_v % nvec == 0, - "permute_to_grouped_tensor_bwd: head dim must be divisible by vector width."); - const int threads = 1024; + const size_t b = grad_q.shape()[0]; + const size_t h_q = grad_q.shape()[1]; + const size_t s_q = grad_q.shape()[2]; + const size_t d_qk = grad_q.shape()[3]; + const size_t h_kv = grad_k.shape()[1]; + const size_t s_kv = grad_k.shape()[2]; + const size_t d_v = grad_v.shape()[3]; + + NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, + "permute_to_grouped_tensor_bwd: head dim must be divisible by ", nvec128, "."); + + const bool is_bshd = (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + + alignas(64) CUtensorMap tma_q_out{}, tma_k_out{}, tma_v_out{}; + create_strided_tensor_map(tma_q_out, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, is_bshd); + create_strided_tensor_map(tma_k_out, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, is_bshd); + create_strided_tensor_map(tma_v_out, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, is_bshd); + const size_t s_min = std::min(s_q, s_kv); const unsigned int permute_s_splits = - std::max(1u, 
static_cast(s_min / static_cast(threads))); + std::max(1u, static_cast(s_min / static_cast(tma_permute_threads))); const size_t h_grid = std::max(h_q, h_kv); - dim3 grid(b * h_grid, permute_s_splits, 3); + dim3 grid(static_cast(b * h_grid), permute_s_splits, 3); + + const size_t d_max = std::max(d_qk, d_v); + const size_t smem_bytes = tma_permute_s_tile * d_max * sizeof(uint16_t); - if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( grad_q.dtype(), dtype, - permute_to_grouped_tensor_bwd_kernel<<>>( + auto kernel = permute_to_grouped_tensor_bwd_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), reinterpret_cast(grad_v.data.dptr), - reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, - permute_s_splits);); + tma_q_out, tma_k_out, tma_v_out, + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( grad_q.dtype(), dtype, - permute_to_grouped_tensor_bwd_kernel<<>>( + auto kernel = permute_to_grouped_tensor_bwd_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), reinterpret_cast(grad_v.data.dptr), - reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, - permute_s_splits);); + tma_q_out, tma_k_out, tma_v_out, + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); } From d7c27f660260dc6dbb99f690c6a92030780ec7f0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:46:21 -0700 Subject: [PATCH 165/172] Revert 
"add TMA permute" This reverts commit 2532a50e829144bee290fc94acb8f3f154a62ea9. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/flash_attn.cu | 543 ++++++------------ 1 file changed, 177 insertions(+), 366 deletions(-) diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 8f03c8ed0c..c30cc3d2d9 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -4,13 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include - #include "../common.h" -#include "../util/cuda_driver.h" -#include "../util/ptx.cuh" -#include "../utils.cuh" #include "transformer_engine/fused_attn.h" namespace transformer_engine { @@ -29,97 +23,6 @@ constexpr int nvec128 = sizeof(uint4) / type_size; constexpr int load_size = warp_size * nvec; constexpr int block_size = 512; -// TMA permute kernel configuration -constexpr int tma_permute_threads = 32; -constexpr int tma_permute_s_tile = 64; - -// ---- 4D TMA PTX wrappers ---- - -__device__ __forceinline__ void cp_async_bulk_tensor_4d_global_to_shared( - void *dst_shmem, const CUtensorMap *tensor_map, - uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, - uint64_t *mbar) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t dst = __cvta_generic_to_shared(dst_shmem); - uint32_t bar = __cvta_generic_to_shared(mbar); - asm volatile( - "cp.async.bulk.tensor.4d.shared::cluster.global.tile" - ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3, %4, %5}], [%6];" - ::"r"(dst), "l"(tensor_map), - "r"(c0), "r"(c1), "r"(c2), "r"(c3), - "r"(bar) - : "memory"); -#else - NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_global_to_shared requires SM 10.0+."); -#endif -} - -__device__ __forceinline__ void cp_async_bulk_tensor_4d_shared_to_global( - const CUtensorMap *tensor_map, - uint32_t c0, 
uint32_t c1, uint32_t c2, uint32_t c3, - void *src_shmem) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - uint32_t src = __cvta_generic_to_shared(src_shmem); - asm volatile( - "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" - " [%0, {%1, %2, %3, %4}], [%5];" - ::"l"(tensor_map), - "r"(c0), "r"(c1), "r"(c2), "r"(c3), - "r"(src) - : "memory"); -#else - NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_shared_to_global requires SM 9.0+."); -#endif -} - -// ---- Host-side 4D tensor map creation ---- -// -// Creates a 4D TMA descriptor for a densely-packed tensor whose logical -// dimensions (innermost-first) are [dim0, dim1, dim2, dim3]. -// -// The box (tile) copied per TMA instruction is [box0, box1, box2, box3]. - -static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dtype, - uint64_t dim0, uint64_t dim1, uint64_t dim2, uint64_t dim3, - uint32_t box0, uint32_t box1, uint32_t box2, uint32_t box3) { - cuda_driver::ensure_context_exists(); - static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { - void *ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); - return reinterpret_cast(ptr); - }(); - - CUtensorMapDataType tma_dtype; - size_t elem_bytes; - switch (dtype) { - case DType::kFloat16: - tma_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - elem_bytes = 2; - break; - case DType::kBFloat16: - tma_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; - elem_bytes = 2; - break; - default: - NVTE_ERROR("create_4D_tensor_map: unsupported dtype"); - } - - constexpr uint32_t rank = 4; - uint64_t size[rank] = {dim0, dim1, dim2, dim3}; - uint64_t stride[rank - 1] = { - dim0 * elem_bytes, - dim0 * dim1 * elem_bytes, - dim0 * dim1 * dim2 * elem_bytes, - }; - uint32_t boxSize[rank] = {box0, box1, box2, box3}; - uint32_t elemStride[rank] = {1, 1, 1, 1}; - - NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( - &tensorMap, tma_dtype, rank, dataPtr, size, stride, boxSize, elemStride, - CU_TENSOR_MAP_INTERLEAVE_NONE, 
CU_TENSOR_MAP_SWIZZLE_NONE, - CU_TENSOR_MAP_L2_PROMOTION_NONE, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA)); -} - template __launch_bounds__(block_size) __global__ void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z, @@ -237,235 +140,168 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream NVTE_CHECK_CUDA(cudaGetLastError()); } -// ---- TMA helpers for strided (BSHD/SBHD) tensors ---- -// -// Strided BSHD: TMA dims [D, H, S, B], coords [0, h, s, b] -// Strided SBHD: TMA dims [D, H, B, S], coords [0, h, b, s] - template -__device__ __forceinline__ void issue_tma_load_strided( - T *smem_buf, const CUtensorMap *tma, - size_t h_i, size_t s_tile, size_t b_i, - uint64_t *mbar, size_t tile_bytes) { - ptx::mbarrier_arrive_expect_tx(mbar, static_cast(tile_bytes)); - if constexpr (kIsBshdBshdBshd) { - cp_async_bulk_tensor_4d_global_to_shared( - smem_buf, tma, - 0, static_cast(h_i), - static_cast(s_tile), static_cast(b_i), - mbar); - } else { - cp_async_bulk_tensor_4d_global_to_shared( - smem_buf, tma, - 0, static_cast(h_i), - static_cast(b_i), static_cast(s_tile), - mbar); - } -} - -template -__device__ __forceinline__ void issue_tma_store_strided( - const CUtensorMap *tma, T *smem_buf, - size_t h_i, size_t s_tile, size_t b_i) { - if constexpr (kIsBshdBshdBshd) { - cp_async_bulk_tensor_4d_shared_to_global( - tma, - 0, static_cast(h_i), - static_cast(s_tile), static_cast(b_i), - smem_buf); - } else { - cp_async_bulk_tensor_4d_shared_to_global( - tma, - 0, static_cast(h_i), - static_cast(b_i), static_cast(s_tile), - smem_buf); - } - ptx::cp_async_bulk_commit_group(); -} - -__device__ __forceinline__ void st_global_cs_uint4(uint4 *ptr, uint4 val) { - asm volatile("st.global.cs.v4.b32 [%0], {%1, %2, %3, %4};" - :: "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) - : "memory"); -} - -// ---- Forward: BSHD/SBHD → BHSD ---- -// -// TMA load from strided input → smem → non-temporal stores to 
contiguous output. - -template -__launch_bounds__(tma_permute_threads) __global__ - void permute_to_grouped_tensor_fwd_kernel( - const __grid_constant__ CUtensorMap tma_q_in, - const __grid_constant__ CUtensorMap tma_k_in, - const __grid_constant__ CUtensorMap tma_v_in, - T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, - size_t b, size_t s_q, size_t h_q, size_t d_qk, - size_t s_kv, size_t h_kv, size_t d_v, - unsigned int permute_s_splits) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - const int which = blockIdx.z; - const CUtensorMap *tma_in = which == 0 ? &tma_q_in : (which == 1 ? &tma_k_in : &tma_v_in); - T *__restrict__ tensor_out = which == 0 ? q_out : (which == 1 ? k_out : v_out); - const size_t Sdim = which == 0 ? s_q : s_kv; - const size_t Hdim = which == 0 ? h_q : h_kv; - const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); +__launch_bounds__(1024) __global__ + void permute_to_grouped_tensor_fwd_kernel(const T *__restrict__ q, const T *__restrict__ k, + const T *__restrict__ v, T *__restrict__ q_out, + T *__restrict__ k_out, T *__restrict__ v_out, + size_t b, size_t s_q, size_t h_q, size_t d_qk, + size_t s_kv, size_t h_kv, size_t d_v, + unsigned int permute_s_splits) { + const int which_tensor = blockIdx.z; + const T *__restrict__ tensor_in = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); + T *__restrict__ tensor_out = which_tensor == 0 ? q_out : (which_tensor == 1 ? k_out : v_out); + const size_t Sdim = which_tensor == 0 ? s_q : s_kv; + const size_t Hdim = which_tensor == 0 ? h_q : h_kv; + const size_t Ddim = which_tensor == 0 ? d_qk : (which_tensor == 1 ? d_qk : d_v); const size_t h_grid = h_q > h_kv ? 
h_q : h_kv; const size_t b_i = static_cast(blockIdx.x) / h_grid; const size_t h_i = static_cast(blockIdx.x) % h_grid; if (b_i >= b) return; - if (which == 0) { if (h_i >= h_q) return; } - else { if (h_i >= h_kv) return; } + if (which_tensor == 0) { + if (h_i >= h_q) return; + } else { + if (h_i >= h_kv) return; + } + if (Ddim % static_cast(nvec) != 0) return; const unsigned int s_part = blockIdx.y; const size_t s_begin = (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); const size_t s_end = (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); - if (s_begin >= s_end) return; + const size_t S_chunk = s_end - s_begin; + const size_t in_base = kIsBshdBshdBshd ? b_i * Sdim * Hdim * Ddim : b_i * Hdim * Ddim; const size_t out_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; - - extern __shared__ __align__(128) char smem_raw[]; - T *smem = reinterpret_cast(smem_raw); - - __shared__ __align__(8) uint64_t mbar; - const bool is_leader = (threadIdx.x == 0); - - if (is_leader) { - ptx::mbarrier_init(&mbar, static_cast(blockDim.x)); - ptx::fence_proxy_async_shared_cta(); - } - __syncthreads(); - - constexpr size_t S_TILE = tma_permute_s_tile; - const uint32_t tile_bytes = static_cast(S_TILE * Ddim * sizeof(T)); - int parity = 0; - - for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { - const size_t tile_rows = min(S_TILE, s_end - s_tile); - - if (is_leader) { - issue_tma_load_strided( - smem, tma_in, h_i, s_tile, b_i, &mbar, tile_bytes); - } else { - ptx::mbarrier_arrive(&mbar); + const bool use_vec128 = + (Ddim % static_cast(nvec128) == 0) && + ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && + ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); + + if (use_vec128) { + const size_t d_vec = Ddim / static_cast(nvec128); + const size_t total_work = S_chunk * d_vec; + for (size_t w = static_cast(threadIdx.x); w < total_work; + w += static_cast(blockDim.x)) { + const size_t s_local = w / d_vec; + const size_t s_i = s_begin + 
s_local; + const size_t v = w % d_vec; + const size_t d_off = v * static_cast(nvec128); + + const T *__restrict__ in_ptr; + if constexpr (kIsBshdBshdBshd) { + in_ptr = tensor_in + in_base + s_i * Hdim * Ddim + h_i * Ddim + d_off; + } else { + in_ptr = tensor_in + s_i * b * Hdim * Ddim + in_base + h_i * Ddim + d_off; + } + T *__restrict__ out_ptr = tensor_out + out_base + s_i * Ddim + d_off; + *reinterpret_cast *>(out_ptr) = + *reinterpret_cast *>(in_ptr); } - - ptx::mbarrier_wait_parity(&mbar, parity); - parity ^= 1; - - T *__restrict__ out_ptr = tensor_out + out_base + s_tile * Ddim; - const size_t total_elems = tile_rows * Ddim; - constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); - - for (size_t i = threadIdx.x * vec_elems; i < total_elems; - i += static_cast(blockDim.x) * vec_elems) { - uint4 v = *reinterpret_cast(smem + i); - st_global_cs_uint4(reinterpret_cast(out_ptr + i), v); + } else { + const size_t d_vec = Ddim / static_cast(nvec); + const size_t total_work = S_chunk * d_vec; + for (size_t w = static_cast(threadIdx.x); w < total_work; + w += static_cast(blockDim.x)) { + const size_t s_local = w / d_vec; + const size_t s_i = s_begin + s_local; + const size_t v = w % d_vec; + const size_t d_off = v * static_cast(nvec); + + const T *__restrict__ in_ptr; + if constexpr (kIsBshdBshdBshd) { + in_ptr = tensor_in + in_base + s_i * Hdim * Ddim + h_i * Ddim + d_off; + } else { + in_ptr = tensor_in + s_i * b * Hdim * Ddim + in_base + h_i * Ddim + d_off; + } + T *__restrict__ out_ptr = tensor_out + out_base + s_i * Ddim + d_off; + *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); } - - __syncthreads(); } - - if (is_leader) { - ptx::mbarrier_invalid(&mbar); - } -#endif } -// ---- Backward: BHSD → BSHD/SBHD ---- -// -// Vectorized loads from contiguous input → smem → TMA store to strided output. 
- template -__launch_bounds__(tma_permute_threads) __global__ - void permute_to_grouped_tensor_bwd_kernel( - const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, - const __grid_constant__ CUtensorMap tma_q_out, - const __grid_constant__ CUtensorMap tma_k_out, - const __grid_constant__ CUtensorMap tma_v_out, - size_t b, size_t s_q, size_t h_q, size_t d_qk, - size_t s_kv, size_t h_kv, size_t d_v, - unsigned int permute_s_splits) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - const int which = blockIdx.z; +__launch_bounds__(1024) __global__ void permute_to_grouped_tensor_bwd_kernel( + const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, + T *__restrict__ q, T *__restrict__ k, T *__restrict__ v, size_t b, size_t s_q, size_t h_q, + size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits) { + const int which_tensor = blockIdx.z; const T *__restrict__ tensor_in = - which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); - const CUtensorMap *tma_out = which == 0 ? &tma_q_out : (which == 1 ? &tma_k_out : &tma_v_out); - const size_t Sdim = which == 0 ? s_q : s_kv; - const size_t Hdim = which == 0 ? h_q : h_kv; - const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + which_tensor == 0 ? grad_q : (which_tensor == 1 ? grad_k : grad_v); + T *__restrict__ tensor_out = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); + const size_t Sdim = which_tensor == 0 ? s_q : s_kv; + const size_t Hdim = which_tensor == 0 ? h_q : h_kv; + const size_t Ddim = which_tensor == 0 ? d_qk : (which_tensor == 1 ? d_qk : d_v); const size_t h_grid = h_q > h_kv ? 
h_q : h_kv; const size_t b_i = static_cast(blockIdx.x) / h_grid; const size_t h_i = static_cast(blockIdx.x) % h_grid; if (b_i >= b) return; - if (which == 0) { if (h_i >= h_q) return; } - else { if (h_i >= h_kv) return; } + if (which_tensor == 0) { + if (h_i >= h_q) return; + } else { + if (h_i >= h_kv) return; + } + if (Ddim % static_cast(nvec) != 0) return; const unsigned int s_part = blockIdx.y; const size_t s_begin = (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); const size_t s_end = (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); - if (s_begin >= s_end) return; + const size_t S_chunk = s_end - s_begin; const size_t in_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; - - extern __shared__ __align__(128) char smem_raw[]; - T *smem = reinterpret_cast(smem_raw); - - constexpr size_t S_TILE = tma_permute_s_tile; - constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); - - for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { - const size_t tile_rows = min(S_TILE, s_end - s_tile); - - const T *__restrict__ in_ptr = tensor_in + in_base + s_tile * Ddim; - const size_t total_elems = tile_rows * Ddim; - - for (size_t i = threadIdx.x * vec_elems; i < total_elems; - i += static_cast(blockDim.x) * vec_elems) { - *reinterpret_cast(smem + i) = - *reinterpret_cast(in_ptr + i); - } - - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - - if (threadIdx.x == 0) { - issue_tma_store_strided(tma_out, smem, h_i, s_tile, b_i); + const size_t out_base = + kIsBshdBshdBshd ? 
b_i * Sdim * Hdim * Ddim + h_i * Ddim : b_i * Hdim * Ddim + h_i * Ddim; + const bool use_vec128 = + (Ddim % static_cast(nvec128) == 0) && + ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && + ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); + + if (use_vec128) { + const size_t d_vec = Ddim / static_cast(nvec128); + const size_t total_work = S_chunk * d_vec; + for (size_t w = static_cast(threadIdx.x); w < total_work; + w += static_cast(blockDim.x)) { + const size_t s_local = w / d_vec; + const size_t s_i = s_begin + s_local; + const size_t v = w % d_vec; + const size_t d_off = v * static_cast(nvec128); + + const T *__restrict__ in_ptr = tensor_in + in_base + s_i * Ddim + d_off; + T *__restrict__ out_ptr; + if constexpr (kIsBshdBshdBshd) { + out_ptr = tensor_out + out_base + s_i * Hdim * Ddim + d_off; + } else { + out_ptr = tensor_out + s_i * b * Hdim * Ddim + out_base + d_off; + } + *reinterpret_cast *>(out_ptr) = + *reinterpret_cast *>(in_ptr); } - - ptx::cp_async_bulk_wait_group(); - __syncthreads(); - } -#endif -} - -// Helper: create a 4D TMA descriptor for the strided (BSHD or SBHD) tensor. 
-// -// For BSHD [B, S, H, D]: TMA dims [D, H, S, B], box [D, 1, S_TILE, 1] -// For SBHD [S, B, H, D]: TMA dims [D, H, B, S], box [D, 1, 1, S_TILE] -static void create_strided_tensor_map(CUtensorMap &map, void *ptr, DType dtype, - size_t b, size_t s, size_t h, size_t d, - bool is_bshd) { - if (is_bshd) { - create_4D_tensor_map(map, ptr, dtype, - static_cast(d), static_cast(h), - static_cast(s), static_cast(b), - static_cast(d), 1, - static_cast(tma_permute_s_tile), 1); } else { - create_4D_tensor_map(map, ptr, dtype, - static_cast(d), static_cast(h), - static_cast(b), static_cast(s), - static_cast(d), 1, 1, - static_cast(tma_permute_s_tile)); + const size_t d_vec = Ddim / static_cast(nvec); + const size_t total_work = S_chunk * d_vec; + for (size_t w = static_cast(threadIdx.x); w < total_work; + w += static_cast(blockDim.x)) { + const size_t s_local = w / d_vec; + const size_t s_i = s_begin + s_local; + const size_t v = w % d_vec; + const size_t d_off = v * static_cast(nvec); + + const T *__restrict__ in_ptr = tensor_in + in_base + s_i * Ddim + d_off; + T *__restrict__ out_ptr; + if constexpr (kIsBshdBshdBshd) { + out_ptr = tensor_out + out_base + s_i * Hdim * Ddim + d_off; + } else { + out_ptr = tensor_out + s_i * b * Hdim * Ddim + out_base + d_off; + } + *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + } } } @@ -473,57 +309,45 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T Tensor v_out, NVTE_QKV_Layout original_layout, cudaStream_t stream) { using namespace transformer_engine; - const size_t b = q_out.shape()[0]; - const size_t h_q = q_out.shape()[1]; - const size_t s_q = q_out.shape()[2]; - const size_t d_qk = q_out.shape()[3]; - const size_t h_kv = k_out.shape()[1]; - const size_t s_kv = k_out.shape()[2]; - const size_t d_v = v_out.shape()[3]; - - NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, - "permute_to_grouped_tensor_fwd: head dim must be divisible by ", nvec128, "."); - - const bool is_bshd = 
(original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); - - alignas(64) CUtensorMap tma_q_in{}, tma_k_in{}, tma_v_in{}; - create_strided_tensor_map(tma_q_in, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, is_bshd); - create_strided_tensor_map(tma_k_in, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, is_bshd); - create_strided_tensor_map(tma_v_in, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, is_bshd); - + size_t b = 0, s_q = 0, s_kv = 0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; + b = q_out.shape()[0]; + h_q = q_out.shape()[1]; + s_q = q_out.shape()[2]; + d_qk = q_out.shape()[3]; + h_kv = k_out.shape()[1]; + s_kv = k_out.shape()[2]; + d_v = v_out.shape()[3]; + + NVTE_CHECK(d_qk % nvec == 0 && d_v % nvec == 0, + "permute_to_grouped_tensor_fwd: head dim must be divisible by vector width."); + // Split S across grid.y; work out permute_s_splits so S_chunk >= threads + const int threads = 1024; const size_t s_min = std::min(s_q, s_kv); const unsigned int permute_s_splits = - std::max(1u, static_cast(s_min / static_cast(tma_permute_threads))); + std::max(1u, static_cast(s_min / static_cast(threads))); const size_t h_grid = std::max(h_q, h_kv); - dim3 grid(static_cast(b * h_grid), permute_s_splits, 3); + dim3 grid(b * h_grid, permute_s_splits, 3); - const size_t d_max = std::max(d_qk, d_v); - const size_t smem_bytes = tma_permute_s_tile * d_max * sizeof(uint16_t); - - if (is_bshd) { + if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( q.dtype(), dtype, - auto kernel = permute_to_grouped_tensor_fwd_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); - kernel<<>>( - tma_q_in, tma_k_in, tma_v_in, - reinterpret_cast(q_out.data.dptr), - reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + permute_to_grouped_tensor_fwd_kernel<<>>( + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), 
+ reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( q.dtype(), dtype, - auto kernel = permute_to_grouped_tensor_fwd_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); - kernel<<>>( - tma_q_in, tma_k_in, tma_v_in, - reinterpret_cast(q_out.data.dptr), - reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + permute_to_grouped_tensor_fwd_kernel<<>>( + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -531,57 +355,44 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, Tensor q, Tensor k, Tensor v, NVTE_QKV_Layout original_layout, cudaStream_t stream) { using namespace transformer_engine; - const size_t b = grad_q.shape()[0]; - const size_t h_q = grad_q.shape()[1]; - const size_t s_q = grad_q.shape()[2]; - const size_t d_qk = grad_q.shape()[3]; - const size_t h_kv = grad_k.shape()[1]; - const size_t s_kv = grad_k.shape()[2]; - const size_t d_v = grad_v.shape()[3]; - - NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, - "permute_to_grouped_tensor_bwd: head dim must be divisible by ", nvec128, "."); - - const bool is_bshd = (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); - - alignas(64) CUtensorMap tma_q_out{}, tma_k_out{}, tma_v_out{}; - create_strided_tensor_map(tma_q_out, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, is_bshd); - 
create_strided_tensor_map(tma_k_out, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, is_bshd); - create_strided_tensor_map(tma_v_out, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, is_bshd); - + size_t b = 0, s_q = 0, s_kv = 0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; + b = grad_q.shape()[0]; + h_q = grad_q.shape()[1]; + s_q = grad_q.shape()[2]; + d_qk = grad_q.shape()[3]; + h_kv = grad_k.shape()[1]; + s_kv = grad_k.shape()[2]; + d_v = grad_v.shape()[3]; + + NVTE_CHECK(d_qk % nvec == 0 && d_v % nvec == 0, + "permute_to_grouped_tensor_bwd: head dim must be divisible by vector width."); + const int threads = 1024; const size_t s_min = std::min(s_q, s_kv); const unsigned int permute_s_splits = - std::max(1u, static_cast(s_min / static_cast(tma_permute_threads))); + std::max(1u, static_cast(s_min / static_cast(threads))); const size_t h_grid = std::max(h_q, h_kv); - dim3 grid(static_cast(b * h_grid), permute_s_splits, 3); - - const size_t d_max = std::max(d_qk, d_v); - const size_t smem_bytes = tma_permute_s_tile * d_max * sizeof(uint16_t); + dim3 grid(b * h_grid, permute_s_splits, 3); - if (is_bshd) { + if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( grad_q.dtype(), dtype, - auto kernel = permute_to_grouped_tensor_bwd_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); - kernel<<>>( + permute_to_grouped_tensor_bwd_kernel<<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), reinterpret_cast(grad_v.data.dptr), - tma_q_out, tma_k_out, tma_v_out, - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( grad_q.dtype(), dtype, - auto kernel = permute_to_grouped_tensor_bwd_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); - kernel<<>>( + permute_to_grouped_tensor_bwd_kernel<<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), reinterpret_cast(grad_v.data.dptr), - tma_q_out, tma_k_out, tma_v_out, - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); } From ba411a2b87072843997e00a060f96fe2bb3b9e1c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 6 Apr 2026 13:58:09 -0700 Subject: [PATCH 166/172] TMA load for bhsd transposes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/flash_attn.cu | 543 ++++++++++++------ 1 file changed, 366 insertions(+), 177 deletions(-) diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index c30cc3d2d9..5526dfcd59 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -4,7 +4,13 @@ * See LICENSE for license information. 
************************************************************************/ +#include +#include + #include "../common.h" +#include "../util/cuda_driver.h" +#include "../util/ptx.cuh" +#include "../utils.cuh" #include "transformer_engine/fused_attn.h" namespace transformer_engine { @@ -23,6 +29,97 @@ constexpr int nvec128 = sizeof(uint4) / type_size; constexpr int load_size = warp_size * nvec; constexpr int block_size = 512; +// TMA permute kernel configuration +constexpr int tma_permute_threads = 32; +constexpr int tma_permute_s_tile = 32; + +// ---- 4D TMA PTX wrappers ---- + +__device__ __forceinline__ void cp_async_bulk_tensor_4d_global_to_shared( + void *dst_shmem, const CUtensorMap *tensor_map, + uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, + uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t dst = __cvta_generic_to_shared(dst_shmem); + uint32_t bar = __cvta_generic_to_shared(mbar); + asm volatile( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3, %4, %5}], [%6];" + ::"r"(dst), "l"(tensor_map), + "r"(c0), "r"(c1), "r"(c2), "r"(c3), + "r"(bar) + : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_global_to_shared requires SM 10.0+."); +#endif +} + +__device__ __forceinline__ void cp_async_bulk_tensor_4d_shared_to_global( + const CUtensorMap *tensor_map, + uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, + void *src_shmem) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + uint32_t src = __cvta_generic_to_shared(src_shmem); + asm volatile( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" + " [%0, {%1, %2, %3, %4}], [%5];" + ::"l"(tensor_map), + "r"(c0), "r"(c1), "r"(c2), "r"(c3), + "r"(src) + : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_shared_to_global requires SM 9.0+."); +#endif +} + +// ---- Host-side 4D tensor map creation ---- +// +// Creates a 4D TMA descriptor for a densely-packed tensor whose logical +// 
dimensions (innermost-first) are [dim0, dim1, dim2, dim3]. +// +// The box (tile) copied per TMA instruction is [box0, box1, box2, box3]. + +static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dtype, + uint64_t dim0, uint64_t dim1, uint64_t dim2, uint64_t dim3, + uint32_t box0, uint32_t box1, uint32_t box2, uint32_t box3) { + cuda_driver::ensure_context_exists(); + static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { + void *ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); + return reinterpret_cast(ptr); + }(); + + CUtensorMapDataType tma_dtype; + size_t elem_bytes; + switch (dtype) { + case DType::kFloat16: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + elem_bytes = 2; + break; + case DType::kBFloat16: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + elem_bytes = 2; + break; + default: + NVTE_ERROR("create_4D_tensor_map: unsupported dtype"); + } + + constexpr uint32_t rank = 4; + uint64_t size[rank] = {dim0, dim1, dim2, dim3}; + uint64_t stride[rank - 1] = { + dim0 * elem_bytes, + dim0 * dim1 * elem_bytes, + dim0 * dim1 * dim2 * elem_bytes, + }; + uint32_t boxSize[rank] = {box0, box1, box2, box3}; + uint32_t elemStride[rank] = {1, 1, 1, 1}; + + NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( + &tensorMap, tma_dtype, rank, dataPtr, size, stride, boxSize, elemStride, + CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, + CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA)); +} + template __launch_bounds__(block_size) __global__ void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z, @@ -140,168 +237,235 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream NVTE_CHECK_CUDA(cudaGetLastError()); } +// ---- TMA helpers for strided (BSHD/SBHD) tensors ---- +// +// Strided BSHD: TMA dims [D, H, S, B], coords [0, h, s, b] +// Strided SBHD: TMA dims [D, H, B, S], coords [0, h, b, s] + template 
-__launch_bounds__(1024) __global__ - void permute_to_grouped_tensor_fwd_kernel(const T *__restrict__ q, const T *__restrict__ k, - const T *__restrict__ v, T *__restrict__ q_out, - T *__restrict__ k_out, T *__restrict__ v_out, - size_t b, size_t s_q, size_t h_q, size_t d_qk, - size_t s_kv, size_t h_kv, size_t d_v, - unsigned int permute_s_splits) { - const int which_tensor = blockIdx.z; - const T *__restrict__ tensor_in = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); - T *__restrict__ tensor_out = which_tensor == 0 ? q_out : (which_tensor == 1 ? k_out : v_out); - const size_t Sdim = which_tensor == 0 ? s_q : s_kv; - const size_t Hdim = which_tensor == 0 ? h_q : h_kv; - const size_t Ddim = which_tensor == 0 ? d_qk : (which_tensor == 1 ? d_qk : d_v); +__device__ __forceinline__ void issue_tma_load_strided( + T *smem_buf, const CUtensorMap *tma, + size_t h_i, size_t s_tile, size_t b_i, + uint64_t *mbar, size_t tile_bytes) { + ptx::mbarrier_arrive_expect_tx(mbar, static_cast(tile_bytes)); + if constexpr (kIsBshdBshdBshd) { + cp_async_bulk_tensor_4d_global_to_shared( + smem_buf, tma, + 0, static_cast(h_i), + static_cast(s_tile), static_cast(b_i), + mbar); + } else { + cp_async_bulk_tensor_4d_global_to_shared( + smem_buf, tma, + 0, static_cast(h_i), + static_cast(b_i), static_cast(s_tile), + mbar); + } +} + +template +__device__ __forceinline__ void issue_tma_store_strided( + const CUtensorMap *tma, T *smem_buf, + size_t h_i, size_t s_tile, size_t b_i) { + if constexpr (kIsBshdBshdBshd) { + cp_async_bulk_tensor_4d_shared_to_global( + tma, + 0, static_cast(h_i), + static_cast(s_tile), static_cast(b_i), + smem_buf); + } else { + cp_async_bulk_tensor_4d_shared_to_global( + tma, + 0, static_cast(h_i), + static_cast(b_i), static_cast(s_tile), + smem_buf); + } + ptx::cp_async_bulk_commit_group(); +} + +__device__ __forceinline__ void st_global_cs_uint4(uint4 *ptr, uint4 val) { + asm volatile("st.global.cs.v4.b32 [%0], {%1, %2, %3, %4};" + :: "l"(ptr), "r"(val.x), 
"r"(val.y), "r"(val.z), "r"(val.w) + : "memory"); +} + +// ---- Forward: BSHD/SBHD → BHSD ---- +// +// TMA load from strided input → smem → non-temporal stores to contiguous output. + +template +__launch_bounds__(tma_permute_threads) __global__ + void permute_to_grouped_tensor_fwd_kernel( + const __grid_constant__ CUtensorMap tma_q_in, + const __grid_constant__ CUtensorMap tma_k_in, + const __grid_constant__ CUtensorMap tma_v_in, + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, + size_t b, size_t s_q, size_t h_q, size_t d_qk, + size_t s_kv, size_t h_kv, size_t d_v, + unsigned int permute_s_splits) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + const int which = blockIdx.z; + const CUtensorMap *tma_in = which == 0 ? &tma_q_in : (which == 1 ? &tma_k_in : &tma_v_in); + T *__restrict__ tensor_out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t Sdim = which == 0 ? s_q : s_kv; + const size_t Hdim = which == 0 ? h_q : h_kv; + const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); const size_t h_grid = h_q > h_kv ? h_q : h_kv; const size_t b_i = static_cast(blockIdx.x) / h_grid; const size_t h_i = static_cast(blockIdx.x) % h_grid; if (b_i >= b) return; - if (which_tensor == 0) { - if (h_i >= h_q) return; - } else { - if (h_i >= h_kv) return; - } - if (Ddim % static_cast(nvec) != 0) return; + if (which == 0) { if (h_i >= h_q) return; } + else { if (h_i >= h_kv) return; } const unsigned int s_part = blockIdx.y; const size_t s_begin = (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); const size_t s_end = (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); - const size_t S_chunk = s_end - s_begin; + if (s_begin >= s_end) return; - const size_t in_base = kIsBshdBshdBshd ? 
b_i * Sdim * Hdim * Ddim : b_i * Hdim * Ddim; const size_t out_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; - const bool use_vec128 = - (Ddim % static_cast(nvec128) == 0) && - ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && - ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); - - if (use_vec128) { - const size_t d_vec = Ddim / static_cast(nvec128); - const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; - w += static_cast(blockDim.x)) { - const size_t s_local = w / d_vec; - const size_t s_i = s_begin + s_local; - const size_t v = w % d_vec; - const size_t d_off = v * static_cast(nvec128); - - const T *__restrict__ in_ptr; - if constexpr (kIsBshdBshdBshd) { - in_ptr = tensor_in + in_base + s_i * Hdim * Ddim + h_i * Ddim + d_off; - } else { - in_ptr = tensor_in + s_i * b * Hdim * Ddim + in_base + h_i * Ddim + d_off; - } - T *__restrict__ out_ptr = tensor_out + out_base + s_i * Ddim + d_off; - *reinterpret_cast *>(out_ptr) = - *reinterpret_cast *>(in_ptr); + + extern __shared__ __align__(128) char smem_raw[]; + T *smem = reinterpret_cast(smem_raw); + + __shared__ __align__(8) uint64_t mbar; + const bool is_leader = (threadIdx.x == 0); + + if (is_leader) { + ptx::mbarrier_init(&mbar, static_cast(blockDim.x)); + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + constexpr size_t S_TILE = tma_permute_s_tile; + const uint32_t tile_bytes = static_cast(S_TILE * Ddim * sizeof(T)); + int parity = 0; + + for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { + const size_t tile_rows = min(S_TILE, s_end - s_tile); + + if (is_leader) { + issue_tma_load_strided( + smem, tma_in, h_i, s_tile, b_i, &mbar, tile_bytes); + } else { + ptx::mbarrier_arrive(&mbar); } - } else { - const size_t d_vec = Ddim / static_cast(nvec); - const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; - w += static_cast(blockDim.x)) { - const size_t s_local = w / d_vec; - 
const size_t s_i = s_begin + s_local; - const size_t v = w % d_vec; - const size_t d_off = v * static_cast(nvec); - - const T *__restrict__ in_ptr; - if constexpr (kIsBshdBshdBshd) { - in_ptr = tensor_in + in_base + s_i * Hdim * Ddim + h_i * Ddim + d_off; - } else { - in_ptr = tensor_in + s_i * b * Hdim * Ddim + in_base + h_i * Ddim + d_off; - } - T *__restrict__ out_ptr = tensor_out + out_base + s_i * Ddim + d_off; - *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + + ptx::mbarrier_wait_parity(&mbar, parity); + parity ^= 1; + + T *__restrict__ out_ptr = tensor_out + out_base + s_tile * Ddim; + const size_t total_elems = tile_rows * Ddim; + constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); + + for (size_t i = threadIdx.x * vec_elems; i < total_elems; + i += static_cast(blockDim.x) * vec_elems) { + uint4 v = *reinterpret_cast(smem + i); + st_global_cs_uint4(reinterpret_cast(out_ptr + i), v); } + + __syncthreads(); } + + if (is_leader) { + ptx::mbarrier_invalid(&mbar); + } +#endif } +// ---- Backward: BHSD → BSHD/SBHD ---- +// +// Vectorized loads from contiguous input → smem → TMA store to strided output. 
+ template -__launch_bounds__(1024) __global__ void permute_to_grouped_tensor_bwd_kernel( - const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, - T *__restrict__ q, T *__restrict__ k, T *__restrict__ v, size_t b, size_t s_q, size_t h_q, - size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits) { - const int which_tensor = blockIdx.z; +__launch_bounds__(tma_permute_threads) __global__ + void permute_to_grouped_tensor_bwd_kernel( + const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, + const __grid_constant__ CUtensorMap tma_q_out, + const __grid_constant__ CUtensorMap tma_k_out, + const __grid_constant__ CUtensorMap tma_v_out, + size_t b, size_t s_q, size_t h_q, size_t d_qk, + size_t s_kv, size_t h_kv, size_t d_v, + unsigned int permute_s_splits) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + const int which = blockIdx.z; const T *__restrict__ tensor_in = - which_tensor == 0 ? grad_q : (which_tensor == 1 ? grad_k : grad_v); - T *__restrict__ tensor_out = which_tensor == 0 ? q : (which_tensor == 1 ? k : v); - const size_t Sdim = which_tensor == 0 ? s_q : s_kv; - const size_t Hdim = which_tensor == 0 ? h_q : h_kv; - const size_t Ddim = which_tensor == 0 ? d_qk : (which_tensor == 1 ? d_qk : d_v); + which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); + const CUtensorMap *tma_out = which == 0 ? &tma_q_out : (which == 1 ? &tma_k_out : &tma_v_out); + const size_t Sdim = which == 0 ? s_q : s_kv; + const size_t Hdim = which == 0 ? h_q : h_kv; + const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); const size_t h_grid = h_q > h_kv ? 
h_q : h_kv; const size_t b_i = static_cast(blockIdx.x) / h_grid; const size_t h_i = static_cast(blockIdx.x) % h_grid; if (b_i >= b) return; - if (which_tensor == 0) { - if (h_i >= h_q) return; - } else { - if (h_i >= h_kv) return; - } - if (Ddim % static_cast(nvec) != 0) return; + if (which == 0) { if (h_i >= h_q) return; } + else { if (h_i >= h_kv) return; } const unsigned int s_part = blockIdx.y; const size_t s_begin = (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); const size_t s_end = (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); - const size_t S_chunk = s_end - s_begin; + if (s_begin >= s_end) return; const size_t in_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; - const size_t out_base = - kIsBshdBshdBshd ? b_i * Sdim * Hdim * Ddim + h_i * Ddim : b_i * Hdim * Ddim + h_i * Ddim; - const bool use_vec128 = - (Ddim % static_cast(nvec128) == 0) && - ((reinterpret_cast(tensor_in) % alignof(Vec)) == 0) && - ((reinterpret_cast(tensor_out) % alignof(Vec)) == 0); - - if (use_vec128) { - const size_t d_vec = Ddim / static_cast(nvec128); - const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; - w += static_cast(blockDim.x)) { - const size_t s_local = w / d_vec; - const size_t s_i = s_begin + s_local; - const size_t v = w % d_vec; - const size_t d_off = v * static_cast(nvec128); - - const T *__restrict__ in_ptr = tensor_in + in_base + s_i * Ddim + d_off; - T *__restrict__ out_ptr; - if constexpr (kIsBshdBshdBshd) { - out_ptr = tensor_out + out_base + s_i * Hdim * Ddim + d_off; - } else { - out_ptr = tensor_out + s_i * b * Hdim * Ddim + out_base + d_off; - } - *reinterpret_cast *>(out_ptr) = - *reinterpret_cast *>(in_ptr); + + extern __shared__ __align__(128) char smem_raw[]; + T *smem = reinterpret_cast(smem_raw); + + constexpr size_t S_TILE = tma_permute_s_tile; + constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); + + for (size_t s_tile = s_begin; s_tile < s_end; s_tile += 
S_TILE) { + const size_t tile_rows = min(S_TILE, s_end - s_tile); + + const T *__restrict__ in_ptr = tensor_in + in_base + s_tile * Ddim; + const size_t total_elems = tile_rows * Ddim; + + for (size_t i = threadIdx.x * vec_elems; i < total_elems; + i += static_cast(blockDim.x) * vec_elems) { + *reinterpret_cast(smem + i) = + *reinterpret_cast(in_ptr + i); } - } else { - const size_t d_vec = Ddim / static_cast(nvec); - const size_t total_work = S_chunk * d_vec; - for (size_t w = static_cast(threadIdx.x); w < total_work; - w += static_cast(blockDim.x)) { - const size_t s_local = w / d_vec; - const size_t s_i = s_begin + s_local; - const size_t v = w % d_vec; - const size_t d_off = v * static_cast(nvec); - - const T *__restrict__ in_ptr = tensor_in + in_base + s_i * Ddim + d_off; - T *__restrict__ out_ptr; - if constexpr (kIsBshdBshdBshd) { - out_ptr = tensor_out + out_base + s_i * Hdim * Ddim + d_off; - } else { - out_ptr = tensor_out + s_i * b * Hdim * Ddim + out_base + d_off; - } - *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + + if (threadIdx.x == 0) { + issue_tma_store_strided(tma_out, smem, h_i, s_tile, b_i); } + + ptx::cp_async_bulk_wait_group(); + __syncthreads(); + } +#endif +} + +// Helper: create a 4D TMA descriptor for the strided (BSHD or SBHD) tensor. 
+// +// For BSHD [B, S, H, D]: TMA dims [D, H, S, B], box [D, 1, S_TILE, 1] +// For SBHD [S, B, H, D]: TMA dims [D, H, B, S], box [D, 1, 1, S_TILE] +static void create_strided_tensor_map(CUtensorMap &map, void *ptr, DType dtype, + size_t b, size_t s, size_t h, size_t d, + bool is_bshd) { + if (is_bshd) { + create_4D_tensor_map(map, ptr, dtype, + static_cast(d), static_cast(h), + static_cast(s), static_cast(b), + static_cast(d), 1, + static_cast(tma_permute_s_tile), 1); + } else { + create_4D_tensor_map(map, ptr, dtype, + static_cast(d), static_cast(h), + static_cast(b), static_cast(s), + static_cast(d), 1, 1, + static_cast(tma_permute_s_tile)); } } @@ -309,45 +473,57 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T Tensor v_out, NVTE_QKV_Layout original_layout, cudaStream_t stream) { using namespace transformer_engine; - size_t b = 0, s_q = 0, s_kv = 0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; - b = q_out.shape()[0]; - h_q = q_out.shape()[1]; - s_q = q_out.shape()[2]; - d_qk = q_out.shape()[3]; - h_kv = k_out.shape()[1]; - s_kv = k_out.shape()[2]; - d_v = v_out.shape()[3]; - - NVTE_CHECK(d_qk % nvec == 0 && d_v % nvec == 0, - "permute_to_grouped_tensor_fwd: head dim must be divisible by vector width."); - // Split S across grid.y; work out permute_s_splits so S_chunk >= threads - const int threads = 1024; + const size_t b = q_out.shape()[0]; + const size_t h_q = q_out.shape()[1]; + const size_t s_q = q_out.shape()[2]; + const size_t d_qk = q_out.shape()[3]; + const size_t h_kv = k_out.shape()[1]; + const size_t s_kv = k_out.shape()[2]; + const size_t d_v = v_out.shape()[3]; + + NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, + "permute_to_grouped_tensor_fwd: head dim must be divisible by ", nvec128, "."); + + const bool is_bshd = (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + + alignas(64) CUtensorMap tma_q_in{}, tma_k_in{}, tma_v_in{}; + create_strided_tensor_map(tma_q_in, q.data.dptr, q.dtype(), b, s_q, h_q, 
d_qk, is_bshd); + create_strided_tensor_map(tma_k_in, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, is_bshd); + create_strided_tensor_map(tma_v_in, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, is_bshd); + const size_t s_min = std::min(s_q, s_kv); const unsigned int permute_s_splits = - std::max(1u, static_cast(s_min / static_cast(threads))); + std::max(1u, static_cast(s_min / static_cast(tma_permute_threads))); const size_t h_grid = std::max(h_q, h_kv); - dim3 grid(b * h_grid, permute_s_splits, 3); + dim3 grid(static_cast(b * h_grid), permute_s_splits, 3); - if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + const size_t d_max = std::max(d_qk, d_v); + const size_t smem_bytes = tma_permute_s_tile * d_max * sizeof(uint16_t); + + if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( q.dtype(), dtype, - permute_to_grouped_tensor_fwd_kernel<<>>( - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - reinterpret_cast(q_out.data.dptr), reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, - permute_s_splits);); + auto kernel = permute_to_grouped_tensor_fwd_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( + tma_q_in, tma_k_in, tma_v_in, + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( q.dtype(), dtype, - permute_to_grouped_tensor_fwd_kernel<<>>( - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - reinterpret_cast(q_out.data.dptr), reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, - permute_s_splits);); + auto kernel = permute_to_grouped_tensor_fwd_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( + tma_q_in, tma_k_in, tma_v_in, + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -355,44 +531,57 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, Tensor q, Tensor k, Tensor v, NVTE_QKV_Layout original_layout, cudaStream_t stream) { using namespace transformer_engine; - size_t b = 0, s_q = 0, s_kv = 0, h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; - b = grad_q.shape()[0]; - h_q = grad_q.shape()[1]; - s_q = grad_q.shape()[2]; - d_qk = grad_q.shape()[3]; - h_kv = grad_k.shape()[1]; - s_kv = grad_k.shape()[2]; - d_v = grad_v.shape()[3]; - - NVTE_CHECK(d_qk % nvec == 0 && d_v % nvec == 0, - "permute_to_grouped_tensor_bwd: head dim must be divisible by vector width."); - const int threads = 1024; + const size_t b = grad_q.shape()[0]; + const size_t h_q = grad_q.shape()[1]; + const size_t s_q = grad_q.shape()[2]; + const size_t d_qk = grad_q.shape()[3]; + const size_t h_kv = grad_k.shape()[1]; + const size_t s_kv = grad_k.shape()[2]; + const size_t d_v = grad_v.shape()[3]; + + NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, + "permute_to_grouped_tensor_bwd: head dim must be divisible by ", nvec128, "."); + + const bool is_bshd = (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + + alignas(64) CUtensorMap tma_q_out{}, tma_k_out{}, tma_v_out{}; + create_strided_tensor_map(tma_q_out, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, is_bshd); + create_strided_tensor_map(tma_k_out, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, is_bshd); + create_strided_tensor_map(tma_v_out, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, is_bshd); + const size_t s_min = std::min(s_q, s_kv); const unsigned int permute_s_splits = - std::max(1u, 
static_cast(s_min / static_cast(threads))); + std::max(1u, static_cast(s_min / static_cast(tma_permute_threads))); const size_t h_grid = std::max(h_q, h_kv); - dim3 grid(b * h_grid, permute_s_splits, 3); + dim3 grid(static_cast(b * h_grid), permute_s_splits, 3); + + const size_t d_max = std::max(d_qk, d_v); + const size_t smem_bytes = tma_permute_s_tile * d_max * sizeof(uint16_t); - if (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( grad_q.dtype(), dtype, - permute_to_grouped_tensor_bwd_kernel<<>>( + auto kernel = permute_to_grouped_tensor_bwd_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), reinterpret_cast(grad_v.data.dptr), - reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, - permute_s_splits);); + tma_q_out, tma_k_out, tma_v_out, + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( grad_q.dtype(), dtype, - permute_to_grouped_tensor_bwd_kernel<<>>( + auto kernel = permute_to_grouped_tensor_bwd_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), reinterpret_cast(grad_v.data.dptr), - reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, - permute_s_splits);); + tma_q_out, tma_k_out, tma_v_out, + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); } From 5ada28d5ecbf8df3b3a0904ad9e89689352de2ae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Apr 2026 20:59:01 +0000 Subject: 
[PATCH 167/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/flash_attn.cu | 214 ++++++++---------- 1 file changed, 90 insertions(+), 124 deletions(-) diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 5526dfcd59..97e9a620ba 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -36,18 +36,15 @@ constexpr int tma_permute_s_tile = 32; // ---- 4D TMA PTX wrappers ---- __device__ __forceinline__ void cp_async_bulk_tensor_4d_global_to_shared( - void *dst_shmem, const CUtensorMap *tensor_map, - uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, - uint64_t *mbar) { + void *dst_shmem, const CUtensorMap *tensor_map, uint32_t c0, uint32_t c1, uint32_t c2, + uint32_t c3, uint64_t *mbar) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t dst = __cvta_generic_to_shared(dst_shmem); uint32_t bar = __cvta_generic_to_shared(mbar); asm volatile( "cp.async.bulk.tensor.4d.shared::cluster.global.tile" - ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3, %4, %5}], [%6];" - ::"r"(dst), "l"(tensor_map), - "r"(c0), "r"(c1), "r"(c2), "r"(c3), - "r"(bar) + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3, %4, %5}], [%6];" ::"r"(dst), + "l"(tensor_map), "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(bar) : "memory"); #else NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_global_to_shared requires SM 10.0+."); @@ -55,17 +52,14 @@ __device__ __forceinline__ void cp_async_bulk_tensor_4d_global_to_shared( } __device__ __forceinline__ void cp_async_bulk_tensor_4d_shared_to_global( - const CUtensorMap *tensor_map, - uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, + const CUtensorMap *tensor_map, uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, void *src_shmem) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) uint32_t src = __cvta_generic_to_shared(src_shmem); 
asm volatile( "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" - " [%0, {%1, %2, %3, %4}], [%5];" - ::"l"(tensor_map), - "r"(c0), "r"(c1), "r"(c2), "r"(c3), - "r"(src) + " [%0, {%1, %2, %3, %4}], [%5];" ::"l"(tensor_map), + "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(src) : "memory"); #else NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_shared_to_global requires SM 9.0+."); @@ -79,9 +73,9 @@ __device__ __forceinline__ void cp_async_bulk_tensor_4d_shared_to_global( // // The box (tile) copied per TMA instruction is [box0, box1, box2, box3]. -static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dtype, - uint64_t dim0, uint64_t dim1, uint64_t dim2, uint64_t dim3, - uint32_t box0, uint32_t box1, uint32_t box2, uint32_t box3) { +static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dtype, uint64_t dim0, + uint64_t dim1, uint64_t dim2, uint64_t dim3, uint32_t box0, + uint32_t box1, uint32_t box2, uint32_t box3) { cuda_driver::ensure_context_exists(); static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { void *ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); @@ -115,8 +109,7 @@ static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dt NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( &tensorMap, tma_dtype, rank, dataPtr, size, stride, boxSize, elemStride, - CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, - CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA)); } @@ -243,50 +236,40 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream // Strided SBHD: TMA dims [D, H, B, S], coords [0, h, b, s] template -__device__ __forceinline__ void issue_tma_load_strided( - T *smem_buf, const CUtensorMap *tma, - size_t h_i, size_t s_tile, size_t b_i, - uint64_t *mbar, size_t tile_bytes) { +__device__ __forceinline__ 
void issue_tma_load_strided(T *smem_buf, const CUtensorMap *tma, + size_t h_i, size_t s_tile, size_t b_i, + uint64_t *mbar, size_t tile_bytes) { ptx::mbarrier_arrive_expect_tx(mbar, static_cast(tile_bytes)); if constexpr (kIsBshdBshdBshd) { - cp_async_bulk_tensor_4d_global_to_shared( - smem_buf, tma, - 0, static_cast(h_i), - static_cast(s_tile), static_cast(b_i), - mbar); + cp_async_bulk_tensor_4d_global_to_shared(smem_buf, tma, 0, static_cast(h_i), + static_cast(s_tile), + static_cast(b_i), mbar); } else { - cp_async_bulk_tensor_4d_global_to_shared( - smem_buf, tma, - 0, static_cast(h_i), - static_cast(b_i), static_cast(s_tile), - mbar); + cp_async_bulk_tensor_4d_global_to_shared(smem_buf, tma, 0, static_cast(h_i), + static_cast(b_i), + static_cast(s_tile), mbar); } } template -__device__ __forceinline__ void issue_tma_store_strided( - const CUtensorMap *tma, T *smem_buf, - size_t h_i, size_t s_tile, size_t b_i) { +__device__ __forceinline__ void issue_tma_store_strided(const CUtensorMap *tma, T *smem_buf, + size_t h_i, size_t s_tile, size_t b_i) { if constexpr (kIsBshdBshdBshd) { - cp_async_bulk_tensor_4d_shared_to_global( - tma, - 0, static_cast(h_i), - static_cast(s_tile), static_cast(b_i), - smem_buf); + cp_async_bulk_tensor_4d_shared_to_global(tma, 0, static_cast(h_i), + static_cast(s_tile), + static_cast(b_i), smem_buf); } else { - cp_async_bulk_tensor_4d_shared_to_global( - tma, - 0, static_cast(h_i), - static_cast(b_i), static_cast(s_tile), - smem_buf); + cp_async_bulk_tensor_4d_shared_to_global(tma, 0, static_cast(h_i), + static_cast(b_i), + static_cast(s_tile), smem_buf); } ptx::cp_async_bulk_commit_group(); } __device__ __forceinline__ void st_global_cs_uint4(uint4 *ptr, uint4 val) { - asm volatile("st.global.cs.v4.b32 [%0], {%1, %2, %3, %4};" - :: "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) - : "memory"); + asm volatile("st.global.cs.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(val.x), "r"(val.y), + "r"(val.z), "r"(val.w) + : "memory"); 
} // ---- Forward: BSHD/SBHD → BHSD ---- @@ -295,14 +278,13 @@ __device__ __forceinline__ void st_global_cs_uint4(uint4 *ptr, uint4 val) { template __launch_bounds__(tma_permute_threads) __global__ - void permute_to_grouped_tensor_fwd_kernel( - const __grid_constant__ CUtensorMap tma_q_in, - const __grid_constant__ CUtensorMap tma_k_in, - const __grid_constant__ CUtensorMap tma_v_in, - T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, - size_t b, size_t s_q, size_t h_q, size_t d_qk, - size_t s_kv, size_t h_kv, size_t d_v, - unsigned int permute_s_splits) { + void permute_to_grouped_tensor_fwd_kernel(const __grid_constant__ CUtensorMap tma_q_in, + const __grid_constant__ CUtensorMap tma_k_in, + const __grid_constant__ CUtensorMap tma_v_in, + T *__restrict__ q_out, T *__restrict__ k_out, + T *__restrict__ v_out, size_t b, size_t s_q, + size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, + size_t d_v, unsigned int permute_s_splits) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const int which = blockIdx.z; const CUtensorMap *tma_in = which == 0 ? &tma_q_in : (which == 1 ? 
&tma_k_in : &tma_v_in); @@ -316,8 +298,11 @@ __launch_bounds__(tma_permute_threads) __global__ const size_t h_i = static_cast(blockIdx.x) % h_grid; if (b_i >= b) return; - if (which == 0) { if (h_i >= h_q) return; } - else { if (h_i >= h_kv) return; } + if (which == 0) { + if (h_i >= h_q) return; + } else { + if (h_i >= h_kv) return; + } const unsigned int s_part = blockIdx.y; const size_t s_begin = @@ -348,8 +333,7 @@ __launch_bounds__(tma_permute_threads) __global__ const size_t tile_rows = min(S_TILE, s_end - s_tile); if (is_leader) { - issue_tma_load_strided( - smem, tma_in, h_i, s_tile, b_i, &mbar, tile_bytes); + issue_tma_load_strided(smem, tma_in, h_i, s_tile, b_i, &mbar, tile_bytes); } else { ptx::mbarrier_arrive(&mbar); } @@ -381,19 +365,14 @@ __launch_bounds__(tma_permute_threads) __global__ // Vectorized loads from contiguous input → smem → TMA store to strided output. template -__launch_bounds__(tma_permute_threads) __global__ - void permute_to_grouped_tensor_bwd_kernel( - const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, - const __grid_constant__ CUtensorMap tma_q_out, - const __grid_constant__ CUtensorMap tma_k_out, - const __grid_constant__ CUtensorMap tma_v_out, - size_t b, size_t s_q, size_t h_q, size_t d_qk, - size_t s_kv, size_t h_kv, size_t d_v, - unsigned int permute_s_splits) { +__launch_bounds__(tma_permute_threads) __global__ void permute_to_grouped_tensor_bwd_kernel( + const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, + const __grid_constant__ CUtensorMap tma_q_out, const __grid_constant__ CUtensorMap tma_k_out, + const __grid_constant__ CUtensorMap tma_v_out, size_t b, size_t s_q, size_t h_q, size_t d_qk, + size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const int which = blockIdx.z; - const T *__restrict__ tensor_in = - which == 0 ? grad_q : (which == 1 ? 
grad_k : grad_v); + const T *__restrict__ tensor_in = which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); const CUtensorMap *tma_out = which == 0 ? &tma_q_out : (which == 1 ? &tma_k_out : &tma_v_out); const size_t Sdim = which == 0 ? s_q : s_kv; const size_t Hdim = which == 0 ? h_q : h_kv; @@ -404,8 +383,11 @@ __launch_bounds__(tma_permute_threads) __global__ const size_t h_i = static_cast(blockIdx.x) % h_grid; if (b_i >= b) return; - if (which == 0) { if (h_i >= h_q) return; } - else { if (h_i >= h_kv) return; } + if (which == 0) { + if (h_i >= h_q) return; + } else { + if (h_i >= h_kv) return; + } const unsigned int s_part = blockIdx.y; const size_t s_begin = @@ -430,8 +412,7 @@ __launch_bounds__(tma_permute_threads) __global__ for (size_t i = threadIdx.x * vec_elems; i < total_elems; i += static_cast(blockDim.x) * vec_elems) { - *reinterpret_cast(smem + i) = - *reinterpret_cast(in_ptr + i); + *reinterpret_cast(smem + i) = *reinterpret_cast(in_ptr + i); } ptx::fence_proxy_async_shared_cta(); @@ -451,21 +432,16 @@ __launch_bounds__(tma_permute_threads) __global__ // // For BSHD [B, S, H, D]: TMA dims [D, H, S, B], box [D, 1, S_TILE, 1] // For SBHD [S, B, H, D]: TMA dims [D, H, B, S], box [D, 1, 1, S_TILE] -static void create_strided_tensor_map(CUtensorMap &map, void *ptr, DType dtype, - size_t b, size_t s, size_t h, size_t d, - bool is_bshd) { +static void create_strided_tensor_map(CUtensorMap &map, void *ptr, DType dtype, size_t b, size_t s, + size_t h, size_t d, bool is_bshd) { if (is_bshd) { - create_4D_tensor_map(map, ptr, dtype, - static_cast(d), static_cast(h), + create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), static_cast(s), static_cast(b), - static_cast(d), 1, - static_cast(tma_permute_s_tile), 1); + static_cast(d), 1, static_cast(tma_permute_s_tile), 1); } else { - create_4D_tensor_map(map, ptr, dtype, - static_cast(d), static_cast(h), + create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), static_cast(b), 
static_cast(s), - static_cast(d), 1, 1, - static_cast(tma_permute_s_tile)); + static_cast(d), 1, 1, static_cast(tma_permute_s_tile)); } } @@ -473,13 +449,13 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T Tensor v_out, NVTE_QKV_Layout original_layout, cudaStream_t stream) { using namespace transformer_engine; - const size_t b = q_out.shape()[0]; - const size_t h_q = q_out.shape()[1]; - const size_t s_q = q_out.shape()[2]; + const size_t b = q_out.shape()[0]; + const size_t h_q = q_out.shape()[1]; + const size_t s_q = q_out.shape()[2]; const size_t d_qk = q_out.shape()[3]; const size_t h_kv = k_out.shape()[1]; const size_t s_kv = k_out.shape()[2]; - const size_t d_v = v_out.shape()[3]; + const size_t d_v = v_out.shape()[3]; NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, "permute_to_grouped_tensor_fwd: head dim must be divisible by ", nvec128, "."); @@ -502,27 +478,21 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( - q.dtype(), dtype, - auto kernel = permute_to_grouped_tensor_fwd_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_fwd_kernel; + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); kernel<<>>( - tma_q_in, tma_k_in, tma_v_in, - reinterpret_cast(q_out.data.dptr), - reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), + tma_q_in, tma_k_in, tma_v_in, reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( - q.dtype(), dtype, - auto kernel = permute_to_grouped_tensor_fwd_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 
smem_bytes)); + q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_fwd_kernel; + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); kernel<<>>( - tma_q_in, tma_k_in, tma_v_in, - reinterpret_cast(q_out.data.dptr), - reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), + tma_q_in, tma_k_in, tma_v_in, reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -531,13 +501,13 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, Tensor q, Tensor k, Tensor v, NVTE_QKV_Layout original_layout, cudaStream_t stream) { using namespace transformer_engine; - const size_t b = grad_q.shape()[0]; - const size_t h_q = grad_q.shape()[1]; - const size_t s_q = grad_q.shape()[2]; + const size_t b = grad_q.shape()[0]; + const size_t h_q = grad_q.shape()[1]; + const size_t s_q = grad_q.shape()[2]; const size_t d_qk = grad_q.shape()[3]; const size_t h_kv = grad_k.shape()[1]; const size_t s_kv = grad_k.shape()[2]; - const size_t d_v = grad_v.shape()[3]; + const size_t d_v = grad_v.shape()[3]; NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, "permute_to_grouped_tensor_bwd: head dim must be divisible by ", nvec128, "."); @@ -560,28 +530,24 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( - grad_q.dtype(), dtype, - auto kernel = permute_to_grouped_tensor_bwd_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + grad_q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_bwd_kernel; + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); kernel<<>>( 
reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), - reinterpret_cast(grad_v.data.dptr), - tma_q_out, tma_k_out, tma_v_out, - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + reinterpret_cast(grad_v.data.dptr), tma_q_out, tma_k_out, tma_v_out, b, + s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( - grad_q.dtype(), dtype, - auto kernel = permute_to_grouped_tensor_bwd_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + grad_q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_bwd_kernel; + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); kernel<<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), - reinterpret_cast(grad_v.data.dptr), - tma_q_out, tma_k_out, tma_v_out, - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + reinterpret_cast(grad_v.data.dptr), tma_q_out, tma_k_out, tma_v_out, b, + s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); } From 6911aba3d18e229cc37e8c71a400d468282cd3a3 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 6 Apr 2026 15:26:42 -0700 Subject: [PATCH 168/172] fix some lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/debug/features/log_fp8_tensor_stats.py | 2 -- .../pytorch/attention/dot_product_attention/backends.py | 1 - .../pytorch/attention/dot_product_attention/context_parallel.py | 2 +- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index e092f3d30e..d26f9ef7f6 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -14,8 +14,6 @@ from 
nvdlfw_inspect.registry import Registry, api_method import transformer_engine_torch as tex -import transformer_engine_torch as tex - from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 3c0d0fdc8a..08e3e9aa46 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -46,7 +46,6 @@ FusedAttnBackend, META_O, META_QKV, - QKVLayout, ) from transformer_engine.pytorch.quantization import get_fp8_torch_dtype, FP8GlobalStateManager from transformer_engine.pytorch.distributed import get_distributed_world_size diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index e76ebf66b0..94422b2750 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3838,7 +3838,7 @@ def forward( orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape orig_o_shape = orig_q_shape[:-1] + orig_v_shape[-1:] o_format = qkv_format - batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + _, seq_dim_qkv, _ = get_bsh_dims(qkv_format) _, seq_dim_o, _ = get_bsh_dims(o_format) if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) From a27e30d85b691998241eec59c0a1bb03201a7912 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 9 Apr 2026 20:53:25 -0700 Subject: [PATCH 169/172] temp: quant+perm+swizzle, rope, perm_fused Signed-off-by: Charlene 
Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 2 +- .../common/fused_attn/flash_attn.cu | 785 ++++++++++++++++-- .../common/fused_attn/fused_attn.cpp | 12 +- .../fused_attn_f16_arbitrary_seqlen.cu | 4 + .../common/fused_attn/fused_attn_fp8.cu | 73 +- .../common/fused_attn/fused_attn_fp8.h | 7 +- transformer_engine/common/fused_attn/utils.h | 9 +- .../common/fused_rope/fused_rope.cu | 510 ++++++++++++ .../include/transformer_engine/fused_attn.h | 66 +- .../include/transformer_engine/fused_rope.h | 75 ++ transformer_engine/common/swizzle/swizzle.cu | 156 +++- .../common/transformer_engine.cpp | 10 +- .../common/util/pybind_helper.h | 3 +- .../dot_product_attention/backends.py | 110 ++- .../attention/dot_product_attention/utils.py | 145 ++-- transformer_engine/pytorch/attention/rope.py | 232 +++++- .../pytorch/cpp_extensions/fused_attn.py | 21 + transformer_engine/pytorch/csrc/extensions.h | 42 +- .../pytorch/csrc/extensions/apply_rope.cpp | 212 +++++ .../pytorch/csrc/extensions/attention.cpp | 243 ++++-- .../pytorch/csrc/extensions/pybind.cpp | 29 +- 21 files changed, 2450 insertions(+), 296 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 58f5ebb7bb..a947a8a373 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1814,7 +1814,7 @@ def get_model(dtype, config): head_dim_v=128, ), "fp8_10": ModelConfig( - 2, + 1, 4096, 128, 192, diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 97e9a620ba..9ebd13fb1c 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -23,15 +23,15 @@ struct alignas(sizeof(T) * N) Vec { }; constexpr int warp_size = 32; -constexpr int type_size = 2; // FP16 or BF16 +constexpr int type_size = 2; constexpr int nvec = sizeof(uint64_t) / 
type_size; constexpr int nvec128 = sizeof(uint4) / type_size; constexpr int load_size = warp_size * nvec; constexpr int block_size = 512; // TMA permute kernel configuration -constexpr int tma_permute_threads = 32; -constexpr int tma_permute_s_tile = 32; +constexpr int tma_permute_threads = 128; +constexpr int tma_permute_s_tile_default = 32; // ---- 4D TMA PTX wrappers ---- @@ -93,8 +93,16 @@ static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dt tma_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; elem_bytes = 2; break; + case DType::kFloat8E4M3: + case DType::kFloat8E5M2: + case DType::kFloat8E8M0: + case DType::kByte: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bytes = 1; + break; default: - NVTE_ERROR("create_4D_tensor_map: unsupported dtype"); + NVTE_ERROR("create_4D_tensor_map: unsupported dtype ", + to_string(static_cast(dtype))); } constexpr uint32_t rank = 4; @@ -107,10 +115,14 @@ static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dt uint32_t boxSize[rank] = {box0, box1, box2, box3}; uint32_t elemStride[rank] = {1, 1, 1, 1}; + const auto oob_fill = (tma_dtype == CU_TENSOR_MAP_DATA_TYPE_UINT8) + ? 
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + : CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; + NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( &tensorMap, tma_dtype, rank, dataPtr, size, stride, boxSize, elemStride, CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA)); + oob_fill)); } template @@ -284,7 +296,8 @@ __launch_bounds__(tma_permute_threads) __global__ T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, - size_t d_v, unsigned int permute_s_splits) { + size_t d_v, unsigned int permute_s_splits, + size_t s_tile_size) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const int which = blockIdx.z; const CUtensorMap *tma_in = which == 0 ? &tma_q_in : (which == 1 ? &tma_k_in : &tma_v_in); @@ -325,7 +338,7 @@ __launch_bounds__(tma_permute_threads) __global__ } __syncthreads(); - constexpr size_t S_TILE = tma_permute_s_tile; + const size_t S_TILE = s_tile_size; const uint32_t tile_bytes = static_cast(S_TILE * Ddim * sizeof(T)); int parity = 0; @@ -369,7 +382,8 @@ __launch_bounds__(tma_permute_threads) __global__ void permute_to_grouped_tensor const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, const __grid_constant__ CUtensorMap tma_q_out, const __grid_constant__ CUtensorMap tma_k_out, const __grid_constant__ CUtensorMap tma_v_out, size_t b, size_t s_q, size_t h_q, size_t d_qk, - size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits) { + size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits, + size_t s_tile_size) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const int which = blockIdx.z; const T *__restrict__ tensor_in = which == 0 ? grad_q : (which == 1 ? 
grad_k : grad_v); @@ -401,7 +415,7 @@ __launch_bounds__(tma_permute_threads) __global__ void permute_to_grouped_tensor extern __shared__ __align__(128) char smem_raw[]; T *smem = reinterpret_cast(smem_raw); - constexpr size_t S_TILE = tma_permute_s_tile; + const size_t S_TILE = s_tile_size; constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { @@ -428,26 +442,393 @@ __launch_bounds__(tma_permute_threads) __global__ void permute_to_grouped_tensor #endif } +// ---- Fallback: BSHD/SBHD ↔ BHSD (no TMA, works on any SM / dtype / D) ---- +// +// Same grid structure as the pre-TMA permute kernels: one block per (b, h) +// pair per S-partition; blockIdx.z selects Q (0), K (1), or V (2). +// +// Two strategies depending on D alignment: +// 1) "vec-flat": D divides evenly into wide vectors (16/8/4 bytes). +// Each thread handles one vector chunk; work = S_chunk * d_vec. +// 2) "row-copy": D is small / misaligned. Each thread handles complete rows +// to avoid expensive runtime integer division by D. Inner copy uses the +// widest loads/stores that fit, with smaller ops for the remainder. 
+ +constexpr int fallback_permute_threads = 1024; + +// ---------- vec-flat helpers (D well-aligned) ---------- + +template +__device__ __forceinline__ void permute_fwd_vec_loop( + const T *__restrict__ in, T *__restrict__ out, size_t b, size_t S, size_t H, size_t D, + size_t b_i, size_t h_i, size_t s_begin, size_t S_chunk) { + const size_t out_base = b_i * H * S * D + h_i * S * D; + const size_t d_vec = D / static_cast(N); + const size_t total_work = S_chunk * d_vec; + for (size_t w = static_cast(threadIdx.x); w < total_work; + w += static_cast(blockDim.x)) { + const size_t s_local = w / d_vec; + const size_t s_i = s_begin + s_local; + const size_t d_off = (w % d_vec) * static_cast(N); + const T *__restrict__ in_ptr; + if constexpr (kIsSbhd) { + in_ptr = in + s_i * (b * H * D) + b_i * (H * D) + h_i * D + d_off; + } else { + in_ptr = in + b_i * (S * H * D) + s_i * (H * D) + h_i * D + d_off; + } + T *__restrict__ out_ptr = out + out_base + s_i * D + d_off; + *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + } +} + +template +__device__ __forceinline__ void permute_bwd_vec_loop( + const T *__restrict__ in, T *__restrict__ out, size_t b, size_t S, size_t H, size_t D, + size_t b_i, size_t h_i, size_t s_begin, size_t S_chunk) { + const size_t in_base = b_i * H * S * D + h_i * S * D; + const size_t d_vec = D / static_cast(N); + const size_t total_work = S_chunk * d_vec; + for (size_t w = static_cast(threadIdx.x); w < total_work; + w += static_cast(blockDim.x)) { + const size_t s_local = w / d_vec; + const size_t s_i = s_begin + s_local; + const size_t d_off = (w % d_vec) * static_cast(N); + const T *__restrict__ in_ptr = in + in_base + s_i * D + d_off; + T *__restrict__ out_ptr; + if constexpr (kIsSbhd) { + out_ptr = out + s_i * (b * H * D) + b_i * (H * D) + h_i * D + d_off; + } else { + out_ptr = out + b_i * (S * H * D) + s_i * (H * D) + h_i * D + d_off; + } + *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + } +} + +// ---------- row-copy 
helper (D small / misaligned) ---------- +// +// Copies D_bytes from src to dst using the widest loads that fit, +// stepping down through uint4 (16B) → uint2 (8B) → uint (4B) → ushort (2B) → uchar. + +__device__ __forceinline__ void copy_row_bytes(const char *__restrict__ src, + char *__restrict__ dst, size_t D_bytes) { + size_t off = 0; + for (; off + 16 <= D_bytes; off += 16) { + uint4 tmp; + memcpy(&tmp, src + off, 16); + memcpy(dst + off, &tmp, 16); + } + for (; off + 8 <= D_bytes; off += 8) { + uint2 tmp; + memcpy(&tmp, src + off, 8); + memcpy(dst + off, &tmp, 8); + } + for (; off + 4 <= D_bytes; off += 4) { + unsigned int tmp; + memcpy(&tmp, src + off, 4); + memcpy(dst + off, &tmp, 4); + } + for (; off + 2 <= D_bytes; off += 2) { + unsigned short tmp; + memcpy(&tmp, src + off, 2); + memcpy(dst + off, &tmp, 2); + } + for (; off < D_bytes; ++off) dst[off] = src[off]; +} + +// ---------- tiled-transpose kernels for small / misaligned D ---------- +// +// Problem: when D_bytes % 4 != 0 (e.g. D=6, unsigned char), the old row-copy +// path assigned one thread per row with a fixed (b, h) per block. Adjacent +// threads read S-rows that are B*H*D bytes apart => ~5 % cache-line use. +// +// Fix: treat the permutation as a 2-D transpose of [S, H] "atoms" of D bytes. 
+// Load a [TILE_S, TILE_H] tile through shared memory so that: +// Load – consecutive threads cover consecutive H (stride D in input) => coalesced reads +// Store – consecutive threads cover consecutive S (stride D in output) => coalesced writes +// +// Grid: (B * s_tiles, h_tiles, num_tensors) Block: TRANSPOSE_BLOCK + +constexpr int TRANSPOSE_TILE = 32; +constexpr int TRANSPOSE_BLOCK = 256; +constexpr int TRANSPOSE_WARPS = TRANSPOSE_BLOCK / 32; // 8 + +// FWD: strided (SBHD / BSHD) → contiguous BHSD +template +__launch_bounds__(TRANSPOSE_BLOCK) __global__ + void permute_fwd_tiled_transpose_kernel( + const T *__restrict__ q_in, const T *__restrict__ k_in, const T *__restrict__ v_in, + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, + size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, + unsigned int s_tiles) { + const int which = blockIdx.z; + const T *__restrict__ in = which == 0 ? q_in : (which == 1 ? k_in : v_in); + T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t S = which == 0 ? s_q : s_kv; + const size_t H = which == 0 ? h_q : h_kv; + const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + const size_t D_bytes = D * sizeof(T); + const size_t D_pad = (D_bytes + 3u) & ~size_t(3); // 4-byte aligned for smem + + const size_t tile_s = static_cast(blockIdx.x) % static_cast(s_tiles); + const size_t b_i = static_cast(blockIdx.x) / static_cast(s_tiles); + if (b_i >= b) return; + const size_t tile_h = static_cast(blockIdx.y); + + const size_t s_base = tile_s * TRANSPOSE_TILE; + const size_t h_base = tile_h * TRANSPOSE_TILE; + + extern __shared__ char smem[]; + // +4 padding per S-row avoids 32-way bank conflicts during the store phase. 
+ const size_t smem_row = static_cast(TRANSPOSE_TILE) * D_pad + 4; + + // ---- Phase 1: global → smem (sweep consecutive H → coalesced reads) ---- + for (unsigned int warp_off = threadIdx.x >> 5; + warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + const size_t local_s = warp_off; + const size_t local_h = threadIdx.x & 31u; + const size_t s_i = s_base + local_s; + const size_t h_i = h_base + local_h; + if (s_i < S && h_i < H) { + const char *__restrict__ src; + if constexpr (kIsSbhd) + src = reinterpret_cast(in + s_i * b * H * D + b_i * H * D + h_i * D); + else + src = reinterpret_cast(in + b_i * S * H * D + s_i * H * D + h_i * D); + copy_row_bytes(src, smem + local_s * smem_row + local_h * D_pad, D_bytes); + } + } + + __syncthreads(); + + // ---- Phase 2: smem → global (sweep consecutive S → coalesced writes) ---- + for (unsigned int warp_off = threadIdx.x >> 5; + warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + const size_t local_h = warp_off; + const size_t local_s = threadIdx.x & 31u; + const size_t s_i = s_base + local_s; + const size_t h_i = h_base + local_h; + if (s_i < S && h_i < H) { + copy_row_bytes(smem + local_s * smem_row + local_h * D_pad, + reinterpret_cast(out + b_i * H * S * D + h_i * S * D + s_i * D), + D_bytes); + } + } +} + +// BWD: contiguous BHSD → strided (SBHD / BSHD) +template +__launch_bounds__(TRANSPOSE_BLOCK) __global__ + void permute_bwd_tiled_transpose_kernel( + const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, + size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, + unsigned int s_tiles) { + const int which = blockIdx.z; + const T *__restrict__ in = which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); + T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t S = which == 0 ? s_q : s_kv; + const size_t H = which == 0 ? 
h_q : h_kv; + const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + const size_t D_bytes = D * sizeof(T); + const size_t D_pad = (D_bytes + 3u) & ~size_t(3); + + const size_t tile_s = static_cast(blockIdx.x) % static_cast(s_tiles); + const size_t b_i = static_cast(blockIdx.x) / static_cast(s_tiles); + if (b_i >= b) return; + const size_t tile_h = static_cast(blockIdx.y); + + const size_t s_base = tile_s * TRANSPOSE_TILE; + const size_t h_base = tile_h * TRANSPOSE_TILE; + + extern __shared__ char smem[]; + const size_t smem_row = static_cast(TRANSPOSE_TILE) * D_pad + 4; + + // ---- Phase 1: global → smem (sweep consecutive S → coalesced reads from BHSD) ---- + for (unsigned int warp_off = threadIdx.x >> 5; + warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + const size_t local_h = warp_off; + const size_t local_s = threadIdx.x & 31u; + const size_t s_i = s_base + local_s; + const size_t h_i = h_base + local_h; + if (s_i < S && h_i < H) { + copy_row_bytes( + reinterpret_cast(in + b_i * H * S * D + h_i * S * D + s_i * D), + smem + local_s * smem_row + local_h * D_pad, D_bytes); + } + } + + __syncthreads(); + + // ---- Phase 2: smem → global (sweep consecutive H → coalesced writes to SBHD/BSHD) ---- + for (unsigned int warp_off = threadIdx.x >> 5; + warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + const size_t local_s = warp_off; + const size_t local_h = threadIdx.x & 31u; + const size_t s_i = s_base + local_s; + const size_t h_i = h_base + local_h; + if (s_i < S && h_i < H) { + char *__restrict__ dst; + if constexpr (kIsSbhd) + dst = reinterpret_cast(out + s_i * b * H * D + b_i * H * D + h_i * D); + else + dst = reinterpret_cast(out + b_i * S * H * D + s_i * H * D + h_i * D); + copy_row_bytes(smem + local_s * smem_row + local_h * D_pad, dst, D_bytes); + } + } +} + +// ---------- forward kernel (well-aligned D) ---------- + +template +__launch_bounds__(fallback_permute_threads) __global__ + void 
permute_to_grouped_tensor_fwd_fallback_kernel( + const T *__restrict__ q_in, const T *__restrict__ k_in, const T *__restrict__ v_in, + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, size_t b, size_t s_q, + size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, + unsigned int permute_s_splits) { + const int which = blockIdx.z; + const T *__restrict__ in = which == 0 ? q_in : (which == 1 ? k_in : v_in); + T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t S = which == 0 ? s_q : s_kv; + const size_t H = which == 0 ? h_q : h_kv; + const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + + const size_t h_grid = h_q > h_kv ? h_q : h_kv; + const size_t b_i = static_cast(blockIdx.x) / h_grid; + const size_t h_i = static_cast(blockIdx.x) % h_grid; + if (b_i >= b) return; + if (which == 0) { + if (h_i >= h_q) return; + } else { + if (h_i >= h_kv) return; + } + + const unsigned int s_part = blockIdx.y; + const size_t s_begin = + (S * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = + (S * static_cast(s_part + 1)) / static_cast(permute_s_splits); + if (s_begin >= s_end) return; + const size_t S_chunk = s_end - s_begin; + + const size_t D_bytes = D * sizeof(T); + + if (D_bytes % 16 == 0) { + constexpr size_t N = 16 / sizeof(T); + permute_fwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; + } + if (D_bytes % 8 == 0) { + constexpr size_t N = 8 / sizeof(T); + permute_fwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; + } + if constexpr (sizeof(T) <= 4) { + if (D_bytes % 4 == 0) { + constexpr size_t N = 4 / sizeof(T); + permute_fwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; + } + } +} + +// ---------- backward kernel (well-aligned D) ---------- + +template +__launch_bounds__(fallback_permute_threads) __global__ + void permute_to_grouped_tensor_bwd_fallback_kernel( + const T *__restrict__ grad_q, const T 
*__restrict__ grad_k, const T *__restrict__ grad_v, + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, size_t b, size_t s_q, + size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, + unsigned int permute_s_splits) { + const int which = blockIdx.z; + const T *__restrict__ in = which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); + T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t S = which == 0 ? s_q : s_kv; + const size_t H = which == 0 ? h_q : h_kv; + const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + + const size_t h_grid = h_q > h_kv ? h_q : h_kv; + const size_t b_i = static_cast(blockIdx.x) / h_grid; + const size_t h_i = static_cast(blockIdx.x) % h_grid; + if (b_i >= b) return; + if (which == 0) { + if (h_i >= h_q) return; + } else { + if (h_i >= h_kv) return; + } + + const unsigned int s_part = blockIdx.y; + const size_t s_begin = + (S * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = + (S * static_cast(s_part + 1)) / static_cast(permute_s_splits); + if (s_begin >= s_end) return; + const size_t S_chunk = s_end - s_begin; + + const size_t D_bytes = D * sizeof(T); + + if (D_bytes % 16 == 0) { + constexpr size_t N = 16 / sizeof(T); + permute_bwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; + } + if (D_bytes % 8 == 0) { + constexpr size_t N = 8 / sizeof(T); + permute_bwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; + } + if constexpr (sizeof(T) <= 4) { + if (D_bytes % 4 == 0) { + constexpr size_t N = 4 / sizeof(T); + permute_bwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; + } + } +} + +// ---- TMA feasibility check ---- + +static bool can_use_tma_permute(DType dtype, size_t d_qk, size_t d_v) { + switch (dtype) { + case DType::kFloat16: + case DType::kBFloat16: + case DType::kFloat8E4M3: + case DType::kFloat8E5M2: + case DType::kFloat8E8M0: + case DType::kByte: + break; + default: + 
return false; + } + const size_t elem_size = typeToSize(dtype); + const size_t inner_qk = d_qk * elem_size; + const size_t inner_v = d_v * elem_size; + if (inner_qk < 32 || inner_v < 32) return false; + if (inner_qk % 16 != 0 || inner_v % 16 != 0) return false; + return true; +} + // Helper: create a 4D TMA descriptor for the strided (BSHD or SBHD) tensor. // // For BSHD [B, S, H, D]: TMA dims [D, H, S, B], box [D, 1, S_TILE, 1] // For SBHD [S, B, H, D]: TMA dims [D, H, B, S], box [D, 1, 1, S_TILE] static void create_strided_tensor_map(CUtensorMap &map, void *ptr, DType dtype, size_t b, size_t s, - size_t h, size_t d, bool is_bshd) { + size_t h, size_t d, size_t s_tile, bool is_bshd) { if (is_bshd) { create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), static_cast(s), static_cast(b), - static_cast(d), 1, static_cast(tma_permute_s_tile), 1); + static_cast(d), 1, static_cast(s_tile), 1); } else { create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), static_cast(b), static_cast(s), - static_cast(d), 1, 1, static_cast(tma_permute_s_tile)); + static_cast(d), 1, 1, static_cast(s_tile)); } } void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, Tensor k_out, - Tensor v_out, NVTE_QKV_Layout original_layout, - cudaStream_t stream) { + Tensor v_out, NVTE_QKV_Format original_format, + size_t num_tensors, cudaStream_t stream) { using namespace transformer_engine; const size_t b = q_out.shape()[0]; const size_t h_q = q_out.shape()[1]; @@ -457,49 +838,144 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T const size_t s_kv = k_out.shape()[2]; const size_t d_v = v_out.shape()[3]; - NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, - "permute_to_grouped_tensor_fwd: head dim must be divisible by ", nvec128, "."); + const bool is_bshd = (original_format == NVTE_QKV_Format::NVTE_BSHD); + + if (!can_use_tma_permute(q.dtype(), d_qk, d_v)) { + const size_t elem_sz = typeToSize(q.dtype()); + 
const size_t d_qk_bytes = d_qk * elem_sz; + const size_t d_v_bytes = d_v * elem_sz; + const bool needs_transpose = (d_qk_bytes % 4 != 0) || (d_v_bytes % 4 != 0); + + if (needs_transpose) { + // Tiled transpose path: grid = (B * s_tiles, h_tiles, num_tensors) + const size_t s_max = std::max(s_q, s_kv); + const size_t h_max = std::max(h_q, h_kv); + const unsigned int st = static_cast( + (s_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); + const unsigned int ht = static_cast( + (h_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); + dim3 grid(static_cast(b) * st, ht, + static_cast(num_tensors)); + const size_t d_max = std::max(d_qk, d_v); + const size_t D_pad = (d_max * elem_sz + 3u) & ~size_t(3); + const size_t smem_bytes = + static_cast(TRANSPOSE_TILE) * + (static_cast(TRANSPOSE_TILE) * D_pad + 4); + + if (is_bshd) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + q.dtype(), dtype, + permute_fwd_tiled_transpose_kernel + <<>>( + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, st);); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + q.dtype(), dtype, + permute_fwd_tiled_transpose_kernel + <<>>( + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, st);); + } + NVTE_CHECK_CUDA(cudaGetLastError()); + return; + } - const bool is_bshd = (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + // Well-aligned D: vec-flat fallback (uses the original grid over (b*h, s_splits)) + const size_t s_min = std::min(s_q, s_kv); + const unsigned int permute_s_splits = std::max( + 1u, static_cast(s_min / static_cast(fallback_permute_threads))); + const size_t h_grid = std::max(h_q, h_kv); + dim3 
grid(static_cast(b * h_grid), permute_s_splits, + static_cast(num_tensors)); + + if (is_bshd) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + q.dtype(), dtype, + permute_to_grouped_tensor_fwd_fallback_kernel + <<>>( + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + q.dtype(), dtype, + permute_to_grouped_tensor_fwd_fallback_kernel + <<>>( + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + } + NVTE_CHECK_CUDA(cudaGetLastError()); + return; + } + + const size_t elem_size = typeToSize(q.dtype()); + const size_t s_min = std::min(s_q, s_kv); + const size_t s_tile = std::min(static_cast(tma_permute_s_tile_default), s_min); + NVTE_CHECK((s_tile * d_qk * elem_size) % sizeof(uint4) == 0 && + (s_tile * d_v * elem_size) % sizeof(uint4) == 0, + "permute_to_grouped_tensor_fwd: S_TILE(", s_tile, ") * D * elem_size must " + "be divisible by ", sizeof(uint4), ". 
d_qk=", d_qk, ", d_v=", d_v, + ", elem_size=", elem_size, "."); alignas(64) CUtensorMap tma_q_in{}, tma_k_in{}, tma_v_in{}; - create_strided_tensor_map(tma_q_in, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, is_bshd); - create_strided_tensor_map(tma_k_in, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, is_bshd); - create_strided_tensor_map(tma_v_in, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, is_bshd); + create_strided_tensor_map(tma_q_in, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, s_tile, is_bshd); + create_strided_tensor_map(tma_k_in, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, s_tile, is_bshd); + create_strided_tensor_map(tma_v_in, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, s_tile, is_bshd); - const size_t s_min = std::min(s_q, s_kv); const unsigned int permute_s_splits = - std::max(1u, static_cast(s_min / static_cast(tma_permute_threads))); + std::max(1u, static_cast(s_min / s_tile)); const size_t h_grid = std::max(h_q, h_kv); - dim3 grid(static_cast(b * h_grid), permute_s_splits, 3); + dim3 grid(static_cast(b * h_grid), permute_s_splits, + static_cast(num_tensors)); const size_t d_max = std::max(d_qk, d_v); - const size_t smem_bytes = tma_permute_s_tile * d_max * sizeof(uint16_t); + const size_t smem_bytes = s_tile * d_max * elem_size; if (is_bshd) { - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_fwd_kernel; NVTE_CHECK_CUDA( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); kernel<<>>( tma_q_in, tma_k_in, tma_v_in, reinterpret_cast(q_out.data.dptr), reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits, s_tile);); } else { - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_fwd_kernel; NVTE_CHECK_CUDA( 
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); kernel<<>>( tma_q_in, tma_k_in, tma_v_in, reinterpret_cast(q_out.data.dptr), reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits, s_tile);); } NVTE_CHECK_CUDA(cudaGetLastError()); } void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, Tensor q, Tensor k, - Tensor v, NVTE_QKV_Layout original_layout, cudaStream_t stream) { + Tensor v, NVTE_QKV_Format original_format, + size_t num_tensors, cudaStream_t stream) { using namespace transformer_engine; const size_t b = grad_q.shape()[0]; const size_t h_q = grad_q.shape()[1]; @@ -509,27 +985,119 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, const size_t s_kv = grad_k.shape()[2]; const size_t d_v = grad_v.shape()[3]; - NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, - "permute_to_grouped_tensor_bwd: head dim must be divisible by ", nvec128, "."); + const bool is_bshd = (original_format == NVTE_QKV_Format::NVTE_BSHD); + + if (!can_use_tma_permute(grad_q.dtype(), d_qk, d_v)) { + const size_t elem_sz = typeToSize(grad_q.dtype()); + const size_t d_qk_bytes = d_qk * elem_sz; + const size_t d_v_bytes = d_v * elem_sz; + const bool needs_transpose = (d_qk_bytes % 4 != 0) || (d_v_bytes % 4 != 0); + + if (needs_transpose) { + const size_t s_max = std::max(s_q, s_kv); + const size_t h_max = std::max(h_q, h_kv); + const unsigned int st = static_cast( + (s_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); + const unsigned int ht = static_cast( + (h_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); + dim3 grid(static_cast(b) * st, ht, + static_cast(num_tensors)); + const size_t d_max = std::max(d_qk, d_v); + const size_t D_pad = (d_max * elem_sz + 3u) & ~size_t(3); + const size_t smem_bytes = + static_cast(TRANSPOSE_TILE) * + (static_cast(TRANSPOSE_TILE) * D_pad + 
4); + + if (is_bshd) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + grad_q.dtype(), dtype, + permute_bwd_tiled_transpose_kernel + <<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, st);); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + grad_q.dtype(), dtype, + permute_bwd_tiled_transpose_kernel + <<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, st);); + } + NVTE_CHECK_CUDA(cudaGetLastError()); + return; + } - const bool is_bshd = (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + const size_t s_min = std::min(s_q, s_kv); + const unsigned int permute_s_splits = std::max( + 1u, static_cast(s_min / static_cast(fallback_permute_threads))); + const size_t h_grid = std::max(h_q, h_kv); + dim3 grid(static_cast(b * h_grid), permute_s_splits, + static_cast(num_tensors)); + + if (is_bshd) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + grad_q.dtype(), dtype, + permute_to_grouped_tensor_bwd_fallback_kernel + <<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + grad_q.dtype(), dtype, + permute_to_grouped_tensor_bwd_fallback_kernel + <<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, 
permute_s_splits);); + } + NVTE_CHECK_CUDA(cudaGetLastError()); + return; + } + + const size_t elem_size = typeToSize(grad_q.dtype()); + const size_t s_min = std::min(s_q, s_kv); + const size_t s_tile = std::min(static_cast(tma_permute_s_tile_default), s_min); + NVTE_CHECK((s_tile * d_qk * elem_size) % sizeof(uint4) == 0 && + (s_tile * d_v * elem_size) % sizeof(uint4) == 0, + "permute_to_grouped_tensor_bwd: S_TILE(", s_tile, ") * D * elem_size must " + "be divisible by ", sizeof(uint4), ". d_qk=", d_qk, ", d_v=", d_v, + ", elem_size=", elem_size, "."); alignas(64) CUtensorMap tma_q_out{}, tma_k_out{}, tma_v_out{}; - create_strided_tensor_map(tma_q_out, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, is_bshd); - create_strided_tensor_map(tma_k_out, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, is_bshd); - create_strided_tensor_map(tma_v_out, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, is_bshd); + create_strided_tensor_map(tma_q_out, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, s_tile, is_bshd); + create_strided_tensor_map(tma_k_out, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, s_tile, is_bshd); + create_strided_tensor_map(tma_v_out, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, s_tile, is_bshd); - const size_t s_min = std::min(s_q, s_kv); const unsigned int permute_s_splits = - std::max(1u, static_cast(s_min / static_cast(tma_permute_threads))); + std::max(1u, static_cast(s_min / s_tile)); const size_t h_grid = std::max(h_q, h_kv); - dim3 grid(static_cast(b * h_grid), permute_s_splits, 3); + dim3 grid(static_cast(b * h_grid), permute_s_splits, + static_cast(num_tensors)); const size_t d_max = std::max(d_qk, d_v); - const size_t smem_bytes = tma_permute_s_tile * d_max * sizeof(uint16_t); + const size_t smem_bytes = s_tile * d_max * elem_size; if (is_bshd) { - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( grad_q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_bwd_kernel; NVTE_CHECK_CUDA( cudaFuncSetAttribute(kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); @@ -537,9 +1105,9 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), reinterpret_cast(grad_v.data.dptr), tma_q_out, tma_k_out, tma_v_out, b, - s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits, s_tile);); } else { - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( grad_q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_bwd_kernel; NVTE_CHECK_CUDA( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); @@ -547,10 +1115,118 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), reinterpret_cast(grad_v.data.dptr), tma_q_out, tma_k_out, tma_v_out, b, - s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits, s_tile);); + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} +// ---- Multi-tensor pad last dimension with zeros ---- +// +// Pads multiple 2D row-major tensors in a single kernel launch. +// Each tensor copies (rows, in_cols) → (rows, padded_cols), zero-filling [in_cols, padded_cols). +// Uses uint32 (4-byte) granularity for coalesced memory access. +// blockIdx.y selects the tensor; blockIdx.x * blockDim.x + threadIdx.x selects the uint32 element. 
+ +constexpr int pad_threads_per_block = 256; +constexpr int kMaxPadTensors = 16; + +struct PadLastDimArgs { + const uint8_t *input; + uint32_t *output; + size_t n_uint32; + uint32_t in_row_bytes; + uint32_t out_row_uint32; +}; + +struct MultiPadParams { + PadLastDimArgs tensors[kMaxPadTensors]; +}; + +__launch_bounds__(pad_threads_per_block) __global__ + void multi_pad_last_dim_kernel(MultiPadParams params) { + const auto &a = params.tensors[blockIdx.y]; + + for (size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; idx < a.n_uint32; + idx += static_cast(gridDim.x) * blockDim.x) { + const uint32_t col_byte = (idx % a.out_row_uint32) * 4; + const size_t row = idx / a.out_row_uint32; + const uint8_t *__restrict__ src = a.input + row * static_cast(a.in_row_bytes); + + uint32_t val; + if (col_byte + 4 <= a.in_row_bytes) { + memcpy(&val, src + col_byte, 4); + } else if (col_byte >= a.in_row_bytes) { + val = 0; + } else { + val = 0; + memcpy(&val, src + col_byte, a.in_row_bytes - col_byte); + } + a.output[idx] = val; + } +} + +void multi_pad_last_dim(Tensor *inputs, Tensor *outputs, size_t num_tensors, + cudaStream_t stream) { + using namespace transformer_engine; + + NVTE_CHECK(num_tensors > 0 && num_tensors <= kMaxPadTensors, + "num_tensors must be in [1, ", kMaxPadTensors, "], got ", num_tensors, "."); + + MultiPadParams params{}; + size_t max_n_uint32 = 0; + int kernel_count = 0; + + for (size_t i = 0; i < num_tensors; ++i) { + auto &inp = inputs[i]; + auto &out = outputs[i]; + + NVTE_CHECK(inp.data.shape.size() == 2, "Expected 2D input tensor at index ", i, "."); + NVTE_CHECK(out.data.shape.size() == 2, "Expected 2D output tensor at index ", i, "."); + NVTE_CHECK(inp.data.dtype == out.data.dtype, "Dtype mismatch at index ", i, "."); + + const size_t rows = inp.data.shape[0]; + const size_t in_cols = inp.data.shape[1]; + const size_t out_cols = out.data.shape[1]; + + NVTE_CHECK(out.data.shape[0] == rows, "Row count mismatch at index ", i, "."); + 
NVTE_CHECK(out_cols >= in_cols, "out_cols < in_cols at index ", i, "."); + + if (rows == 0) continue; + + if (in_cols == out_cols) { + const size_t total_bytes = rows * in_cols * typeToSize(inp.data.dtype); + NVTE_CHECK_CUDA(cudaMemcpyAsync(out.data.dptr, inp.data.dptr, total_bytes, + cudaMemcpyDeviceToDevice, stream)); + continue; + } + + const size_t elem_size = typeToSize(inp.data.dtype); + const auto in_row_bytes = static_cast(in_cols * elem_size); + const auto out_row_bytes = static_cast(out_cols * elem_size); + NVTE_CHECK(out_row_bytes % 4 == 0, + "Padded row size in bytes (", out_row_bytes, ") must be a multiple of 4."); + + const uint32_t out_row_uint32 = out_row_bytes / 4; + const size_t n_uint32 = rows * out_row_uint32; + + params.tensors[kernel_count] = {reinterpret_cast(inp.data.dptr), + reinterpret_cast(out.data.dptr), n_uint32, + in_row_bytes, out_row_uint32}; + max_n_uint32 = std::max(max_n_uint32, n_uint32); + ++kernel_count; } + + if (kernel_count == 0) return; + + constexpr int threads = pad_threads_per_block; + const int blocks_x = + static_cast(std::min(DIVUP(max_n_uint32, static_cast(threads)), + static_cast(65535))); + dim3 grid(blocks_x, kernel_count); + + multi_pad_last_dim_kernel<<>>(params); NVTE_CHECK_CUDA(cudaGetLastError()); } + } // namespace flash_attention } // namespace transformer_engine @@ -574,24 +1250,39 @@ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTET void nvte_permute_to_grouped_tensor_fwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, - NVTE_QKV_Layout original_layout, cudaStream_t stream) { + NVTE_QKV_Format original_format, size_t num_tensors, + cudaStream_t stream) { NVTE_API_CALL(nvte_permute_to_grouped_tensor_fwd); using namespace transformer_engine; flash_attention::permute_to_grouped_tensor_fwd( *convertNVTETensorCheck(q), *convertNVTETensorCheck(k), *convertNVTETensorCheck(v), *convertNVTETensorCheck(q_out), 
*convertNVTETensorCheck(k_out), - *convertNVTETensorCheck(v_out), original_layout, stream); + *convertNVTETensorCheck(v_out), original_format, num_tensors, stream); } void nvte_permute_to_grouped_tensor_bwd(NVTETensor grad_q, NVTETensor grad_k, NVTETensor grad_v, NVTETensor q, NVTETensor k, NVTETensor v, - NVTE_QKV_Layout original_layout, cudaStream_t stream) { + NVTE_QKV_Format original_format, size_t num_tensors, + cudaStream_t stream) { NVTE_API_CALL(nvte_permute_to_grouped_tensor_bwd); using namespace transformer_engine; flash_attention::permute_to_grouped_tensor_bwd( *convertNVTETensorCheck(grad_q), *convertNVTETensorCheck(grad_k), *convertNVTETensorCheck(grad_v), *convertNVTETensorCheck(q), *convertNVTETensorCheck(k), - *convertNVTETensorCheck(v), original_layout, stream); + *convertNVTETensorCheck(v), original_format, num_tensors, stream); +} + +void nvte_multi_pad_last_dim(NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_pad_last_dim); + using namespace transformer_engine; + + std::vector in_vec(num_tensors), out_vec(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + in_vec[i] = *convertNVTETensorCheck(inputs[i]); + out_vec[i] = *convertNVTETensorCheck(outputs[i]); + } + flash_attention::multi_pad_last_dim(in_vec.data(), out_vec.data(), num_tensors, stream); } diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 5498c601a6..1792e21fcc 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -669,7 +669,8 @@ void nvte_fused_attn_fwd( bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) 
{ + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream, + NVTE_QKV_Format qkv_scale_inv_format) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -777,7 +778,7 @@ void nvte_fused_attn_fwd( softmax_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_SoftmaxOffset, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, - wkspace, stream, handle); + wkspace, stream, handle, qkv_scale_inv_format); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -799,7 +800,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { + bool cuda_graph, NVTETensor workspace, cudaStream_t stream, + NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -904,7 +907,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso deterministic, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_SoftmaxOffset, input_output_dP, output_dQ, output_dK, output_dV, output_dSoftmaxOffset, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle, + qkv_scale_inv_format, do_scale_inv_format); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index f8c3992587..e2c1092a17 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -139,6 +139,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( o_format, NVTE_QKV_Format_NOT_SET, NVTE_QKV_Layout_NOT_SET, + NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format_NOT_SET, bias_type, mask_type, softmax_type, @@ -640,6 +642,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( o_format, do_format, dqkv_layout, + NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format_NOT_SET, bias_type, mask_type, softmax_type, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 6fa366dc2c..e2503ea881 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1662,7 +1662,8 @@ void fused_attn_fp8_fwd_impl_v1( void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, - NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, + NVTEScalingMode scaling_mode, NVTE_QKV_Format qkv_scale_inv_format, + void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); @@ -1718,6 +1719,8 @@ void fused_attn_fp8_fwd_impl_v1( o_format, NVTE_QKV_Format_NOT_SET, NVTE_QKV_Layout_NOT_SET, + qkv_scale_inv_format, + NVTE_QKV_Format_NOT_SET, bias_type, mask_type, softmax_type, @@ -1823,18 +1826,35 @@ void fused_attn_fp8_fwd_impl_v1( scale_o = mha_graph->tensor(1.0f); } } else if 
(is_mxfp8) {
-    NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
-    NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
+    NVTE_QKV_Format q_scale_format =
+        (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET)
+            ? qkv_scale_inv_format : nvte_get_q_format(qkv_layout);
+    NVTE_QKV_Format kv_scale_format =
+        (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET)
+            ? qkv_scale_inv_format : nvte_get_kv_format(qkv_layout);
     std::vector<int64_t> q_scale_strides(4);
     std::vector<int64_t> k_scale_strides(4);
     std::vector<int64_t> v_scale_strides(4);
     auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v);
     generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded,
-                                    q_scale_strides.data(), q_format);
+                                    q_scale_strides.data(), q_scale_format);
     generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded,
-                                    k_scale_strides.data(), kv_format);
+                                    k_scale_strides.data(), kv_scale_format);
     generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded,
-                                    v_scale_strides.data(), kv_format);
+                                    v_scale_strides.data(), kv_scale_format);
+    printf("q_scale_strides: %lld, %lld, %lld, %lld\n", (long long)q_scale_strides[0], (long long)q_scale_strides[1], (long long)q_scale_strides[2], (long long)q_scale_strides[3]);
+    printf("k_scale_strides: %lld, %lld, %lld, %lld\n", (long long)k_scale_strides[0], (long long)k_scale_strides[1], (long long)k_scale_strides[2], (long long)k_scale_strides[3]);
+    printf("v_scale_strides: %lld, %lld, %lld, %lld\n", (long long)v_scale_strides[0], (long long)v_scale_strides[1], (long long)v_scale_strides[2], (long long)v_scale_strides[3]);
+    printf("qkv_layout: %d\n", static_cast<int>(qkv_layout));
+    printf("qkv_scale_inv_format: %d\n", static_cast<int>(qkv_scale_inv_format));
+    printf("q_scale_format: %d\n", static_cast<int>(q_scale_format));
+    printf("kv_scale_format: %d\n", static_cast<int>(kv_scale_format));
+    printf("padded.s_q_padded: %lld\n", (long long)padded.s_q_padded);
+    printf("padded.d_qk_scale_padded: %lld\n", (long long)padded.d_qk_scale_padded);
+    printf("padded.s_kv_padded: %lld\n", (long long)padded.s_kv_padded);
+    printf("padded.d_qk_scale_padded: %lld\n", (long long)padded.d_qk_scale_padded);
+    printf("padded.s_kv_scale_padded: %lld\n", (long long)padded.s_kv_scale_padded);
+
printf("padded.d_v_padded: %d\n", padded.d_v_padded); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -2090,8 +2110,9 @@ void fused_attn_fp8_bwd_impl_v1( void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, - size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -2149,6 +2170,8 @@ void fused_attn_fp8_bwd_impl_v1( o_format, do_format, dqkv_layout, + qkv_scale_inv_format, + do_scale_inv_format, bias_type, mask_type, softmax_type, @@ -2313,6 +2336,15 @@ void fused_attn_fp8_bwd_impl_v1( } else if (is_mxfp8) { NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + NVTE_QKV_Format q_scale_format = + (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) + ? qkv_scale_inv_format : q_format; + NVTE_QKV_Format kv_scale_format = + (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) + ? qkv_scale_inv_format : kv_format; + NVTE_QKV_Format do_scale_format = + (do_scale_inv_format != NVTE_QKV_Format_NOT_SET) + ? 
do_scale_inv_format : do_format; // Q_t, K_t, dO_t, dO_f16 std::vector q_t_strides(4), k_t_strides(4), dO_t_strides(4); generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_strides.data(), q_format); @@ -2343,19 +2375,19 @@ void fused_attn_fp8_bwd_impl_v1( std::vector q_scale_strides(4), q_t_scale_strides(4), k_scale_strides(4), k_t_scale_strides(4), v_scale_strides(4), dO_scale_strides(4), dO_t_scale_strides(4); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, - q_scale_strides.data(), q_format); + q_scale_strides.data(), q_scale_format); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, - q_t_scale_strides.data(), q_format); + q_t_scale_strides.data(), q_scale_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, - k_scale_strides.data(), kv_format); + k_scale_strides.data(), kv_scale_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, - k_t_scale_strides.data(), kv_format); + k_t_scale_strides.data(), kv_scale_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, - v_scale_strides.data(), kv_format); + v_scale_strides.data(), kv_scale_format); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, - dO_scale_strides.data(), do_format); + dO_scale_strides.data(), do_scale_format); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, - dO_t_scale_strides.data(), do_format); + dO_t_scale_strides.data(), do_scale_format); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -2719,7 +2751,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou const Tensor* input_SoftmaxOffset, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, - cudaStream_t 
stream, cudnnHandle_t handle) { + cudaStream_t stream, cudnnHandle_t handle, + NVTE_QKV_Format qkv_scale_inv_format) { using namespace transformer_engine; void *devPtrQ = nullptr, *devPtrK = nullptr, *devPtrV = nullptr; void *devPtrDescaleQ = nullptr, *devPtrDescaleK = nullptr, *devPtrDescaleV = nullptr; @@ -2814,6 +2847,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, + qkv_scale_inv_format, workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( @@ -2852,7 +2886,8 @@ void fused_attn_fp8_bwd( const Tensor* input_S, const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, Tensor* output_dSoftmaxOffset, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -2947,7 +2982,9 @@ void fused_attn_fp8_bwd( devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); + input_dO->scaling_mode, + qkv_scale_inv_format, do_scale_inv_format, + workspace->data.dptr, &workspace_size, stream, handle); } else 
if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { // remove this when cuDNN FE supports FP8 + THD NVTE_CHECK(input_ZInv != nullptr && input_ZInv->data.dptr != nullptr, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 2f6c1105bd..a5b8fced24 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -25,7 +25,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou const Tensor *input_SoftmaxOffset, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + cudaStream_t stream, cudnnHandle_t handle, + NVTE_QKV_Format qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET); // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd( @@ -40,6 +41,8 @@ void fused_attn_fp8_bwd( const Tensor *input_S, const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle, + NVTE_QKV_Format qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format do_scale_inv_format = NVTE_QKV_Format_NOT_SET); #endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index b600261f40..822a0f61f3 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -297,6 +297,8 @@ struct FADescriptor_v1 { NVTE_QKV_Format o_format; NVTE_QKV_Format do_format; 
NVTE_QKV_Layout dqkv_layout; + NVTE_QKV_Format qkv_scale_inv_format; + NVTE_QKV_Format do_scale_inv_format; NVTE_Bias_Type bias_type; NVTE_Mask_Type mask_type; NVTE_Softmax_Type softmax_type; @@ -314,7 +316,8 @@ struct FADescriptor_v1 { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, attnScale, isTraining, dropoutProbability, qkv_layout, o_format, - do_format, dqkv_layout, mask_type, softmax_type, window_size_left, + do_format, dqkv_layout, qkv_scale_inv_format, do_scale_inv_format, + mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, return_max_logit) < @@ -322,7 +325,9 @@ struct FADescriptor_v1 { rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.qkv_layout, - rhs.o_format, rhs.do_format, rhs.dqkv_layout, rhs.mask_type, rhs.softmax_type, + rhs.o_format, rhs.do_format, rhs.dqkv_layout, + rhs.qkv_scale_inv_format, rhs.do_scale_inv_format, + rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.return_max_logit); diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index 27dc11ab43..df396e246c 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -648,8 +648,518 @@ void fused_qkv_rope_backward(const Tensor &q_grad_out, const Tensor &k_grad_out, qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2, stream);); } +// 
============================================================================ +// MLA YARN RoPE kernels +// ============================================================================ + +__device__ int mla_get_thd_token_idx(const int *cu_seqlens, int pid_m, int seq_num, int cp_rank, + int cp_size) { + int token_idx = -1; + int this_seq_len = 0; + int last_cum = cu_seqlens[0] / cp_size; + for (int seq_idx = 0; seq_idx < seq_num; seq_idx++) { + int cur_cum = cu_seqlens[seq_idx + 1] / cp_size; + if (token_idx == -1 && cur_cum > pid_m) { + token_idx = pid_m - last_cum; + this_seq_len = cur_cum - last_cum; + } + last_cum = cur_cum; + } + if (cp_size > 1) { + if (token_idx < this_seq_len / 2) { + token_idx = token_idx + cp_rank * this_seq_len / 2; + } else { + token_idx = + (token_idx - this_seq_len / 2) + (2 * cp_size - cp_rank - 1) * this_seq_len / 2; + } + } + return token_idx; +} + +template +__global__ void mla_yarn_rope_q_forward_kernel(const scalar_t *q_input, const float *cos_data, + const float *sin_data, scalar_t *q_output, + const int *cu_seqlens, const int qk_head_dim, + const int emb_dim, const int h, const int d, + const int s, const int b, const int cp_size, + const int cp_rank) { + int pid_m = blockIdx.x; + const int half_emb = emb_dim / 2; + const int stride_t = h * d; + const int stride_h_val = d; + + int token_idx; + if (cu_seqlens == nullptr) { + int s_id = pid_m / b; + token_idx = s_id; + if (cp_size > 1) { + if (s_id < s / 2) { + token_idx = s_id + cp_rank * s / 2; + } else { + token_idx = s * cp_size - (cp_rank + 1) * s / 2 + s_id - s / 2; + } + } + } else { + token_idx = mla_get_thd_token_idx(cu_seqlens, pid_m, b, cp_rank, cp_size); + } + + extern __shared__ float shared_mem_q_fwd[]; + float *sh_cos_l = shared_mem_q_fwd; + float *sh_sin_l = shared_mem_q_fwd + half_emb; + float *sh_cos_r = shared_mem_q_fwd + 2 * half_emb; + float *sh_sin_r = shared_mem_q_fwd + 3 * half_emb; + + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int num_threads = 
blockDim.x * blockDim.y; + for (int i = tid; i < half_emb; i += num_threads) { + sh_cos_l[i] = cos_data[token_idx * emb_dim + i]; + sh_sin_l[i] = sin_data[token_idx * emb_dim + i]; + sh_cos_r[i] = cos_data[token_idx * emb_dim + half_emb + i]; + sh_sin_r[i] = sin_data[token_idx * emb_dim + half_emb + i]; + } + __syncthreads(); + + int base = pid_m * stride_t; + + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int head_offset = base + h_id * stride_h_val; + + for (int i = threadIdx.x; i < qk_head_dim; i += blockDim.x) { + q_output[head_offset + i] = q_input[head_offset + i]; + } + + int rope_in = head_offset + qk_head_dim; + int rope_out = head_offset + qk_head_dim; + for (int i = threadIdx.x; i < half_emb; i += blockDim.x) { + float x1 = static_cast(q_input[rope_in + i * 2]); + float x2 = static_cast(q_input[rope_in + i * 2 + 1]); + + q_output[rope_out + i] = static_cast(x1 * sh_cos_l[i] - x2 * sh_sin_l[i]); + q_output[rope_out + half_emb + i] = + static_cast(x2 * sh_cos_r[i] + x1 * sh_sin_r[i]); + } + } +} + +template +__global__ void mla_yarn_rope_q_backward_kernel(const scalar_t *grad_output, + const float *cos_data, const float *sin_data, + scalar_t *grad_input, const int *cu_seqlens, + const int qk_head_dim, const int emb_dim, + const int h, const int d, const int s, const int b, + const int cp_size, const int cp_rank) { + int pid_m = blockIdx.x; + const int half_emb = emb_dim / 2; + const int stride_t = h * d; + const int stride_h_val = d; + + int token_idx; + if (cu_seqlens == nullptr) { + int s_id = pid_m / b; + token_idx = s_id; + if (cp_size > 1) { + if (s_id < s / 2) { + token_idx = s_id + cp_rank * s / 2; + } else { + token_idx = s * cp_size - (cp_rank + 1) * s / 2 + s_id - s / 2; + } + } + } else { + token_idx = mla_get_thd_token_idx(cu_seqlens, pid_m, b, cp_rank, cp_size); + } + + extern __shared__ float shared_mem_q_bwd[]; + float *sh_cos_l = shared_mem_q_bwd; + float *sh_sin_l = shared_mem_q_bwd + half_emb; + float *sh_cos_r = 
shared_mem_q_bwd + 2 * half_emb; + float *sh_sin_r = shared_mem_q_bwd + 3 * half_emb; + + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int num_threads = blockDim.x * blockDim.y; + for (int i = tid; i < half_emb; i += num_threads) { + sh_cos_l[i] = cos_data[token_idx * emb_dim + i]; + sh_sin_l[i] = sin_data[token_idx * emb_dim + i]; + sh_cos_r[i] = cos_data[token_idx * emb_dim + half_emb + i]; + sh_sin_r[i] = sin_data[token_idx * emb_dim + half_emb + i]; + } + __syncthreads(); + + int base = pid_m * stride_t; + + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int head_offset = base + h_id * stride_h_val; + + for (int i = threadIdx.x; i < qk_head_dim; i += blockDim.x) { + grad_input[head_offset + i] = grad_output[head_offset + i]; + } + + int rope_offset = head_offset + qk_head_dim; + for (int i = threadIdx.x; i < half_emb; i += blockDim.x) { + float gl = static_cast(grad_output[rope_offset + i]); + float gr = static_cast(grad_output[rope_offset + half_emb + i]); + + grad_input[rope_offset + i * 2] = static_cast(gl * sh_cos_l[i] + gr * sh_sin_r[i]); + grad_input[rope_offset + i * 2 + 1] = + static_cast(-gl * sh_sin_l[i] + gr * sh_cos_r[i]); + } + } +} + +template +__global__ void mla_yarn_rope_kv_forward_kernel( + const scalar_t *kv_input, const scalar_t *k_pos_emb, const float *cos_data, + const float *sin_data, scalar_t *o_key, scalar_t *o_value, const int *cu_seqlens, + const int emb_dim, const int k_dim, const int v_dim, const int h, const int s, const int b, + const int cp_size, const int cp_rank) { + int pid_m = blockIdx.x; + const int half_emb = emb_dim / 2; + const int kv_stride_t = h * (k_dim + v_dim); + const int kv_stride_h = k_dim + v_dim; + const int emb_stride_t = emb_dim; + const int k_stride_t = h * (k_dim + emb_dim); + const int k_stride_h = k_dim + emb_dim; + const int v_stride_t = h * v_dim; + const int v_stride_h = v_dim; + + int token_idx; + if (cu_seqlens == nullptr) { + int s_id = pid_m / b; + token_idx = s_id; + if 
(cp_size > 1) { + if (s_id < s / 2) { + token_idx = s_id + cp_rank * s / 2; + } else { + token_idx = s * cp_size - (cp_rank + 1) * s / 2 + s_id - s / 2; + } + } + } else { + token_idx = mla_get_thd_token_idx(cu_seqlens, pid_m, b, cp_rank, cp_size); + } + + extern __shared__ float shared_mem_kv_fwd[]; + float *sh_cos_l = shared_mem_kv_fwd; + float *sh_sin_l = shared_mem_kv_fwd + half_emb; + float *sh_cos_r = shared_mem_kv_fwd + 2 * half_emb; + float *sh_sin_r = shared_mem_kv_fwd + 3 * half_emb; + float *sh_rot_left = shared_mem_kv_fwd + 4 * half_emb; + float *sh_rot_right = shared_mem_kv_fwd + 5 * half_emb; + + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int num_threads = blockDim.x * blockDim.y; + for (int i = tid; i < half_emb; i += num_threads) { + sh_cos_l[i] = cos_data[token_idx * emb_dim + i]; + sh_sin_l[i] = sin_data[token_idx * emb_dim + i]; + sh_cos_r[i] = cos_data[token_idx * emb_dim + half_emb + i]; + sh_sin_r[i] = sin_data[token_idx * emb_dim + half_emb + i]; + } + __syncthreads(); + + for (int i = tid; i < half_emb; i += num_threads) { + float x1 = static_cast(k_pos_emb[pid_m * emb_stride_t + i * 2]); + float x2 = static_cast(k_pos_emb[pid_m * emb_stride_t + i * 2 + 1]); + sh_rot_left[i] = x1 * sh_cos_l[i] - x2 * sh_sin_l[i]; + sh_rot_right[i] = x2 * sh_cos_r[i] + x1 * sh_sin_r[i]; + } + __syncthreads(); + + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int kv_head = pid_m * kv_stride_t + h_id * kv_stride_h; + int k_head = pid_m * k_stride_t + h_id * k_stride_h; + int v_head = pid_m * v_stride_t + h_id * v_stride_h; + + for (int i = threadIdx.x; i < k_dim; i += blockDim.x) { + o_key[k_head + i] = kv_input[kv_head + i]; + } + + for (int i = threadIdx.x; i < v_dim; i += blockDim.x) { + o_value[v_head + i] = kv_input[kv_head + k_dim + i]; + } + + for (int i = threadIdx.x; i < half_emb; i += blockDim.x) { + o_key[k_head + k_dim + i] = static_cast(sh_rot_left[i]); + o_key[k_head + k_dim + half_emb + i] = static_cast(sh_rot_right[i]); 
+ } + } +} + +template +__global__ void mla_yarn_rope_kv_backward_kernel( + const scalar_t *dk, const scalar_t *dv, const float *cos_data, const float *sin_data, + scalar_t *d_kv, scalar_t *d_emb, const int *cu_seqlens, const int emb_dim, const int k_dim, + const int v_dim, const int h, const int s, const int b, const int cp_size, const int cp_rank) { + int pid_m = blockIdx.x; + const int half_emb = emb_dim / 2; + const int dk_stride_t = h * (k_dim + emb_dim); + const int dk_stride_h = k_dim + emb_dim; + const int dv_stride_t = h * v_dim; + const int dv_stride_h = v_dim; + const int dkv_stride_t = h * (k_dim + v_dim); + const int dkv_stride_h = k_dim + v_dim; + + int token_idx; + if (cu_seqlens == nullptr) { + int s_id = pid_m / b; + token_idx = s_id; + if (cp_size > 1) { + if (s_id < s / 2) { + token_idx = s_id + cp_rank * s / 2; + } else { + token_idx = s * cp_size - (cp_rank + 1) * s / 2 + s_id - s / 2; + } + } + } else { + token_idx = mla_get_thd_token_idx(cu_seqlens, pid_m, b, cp_rank, cp_size); + } + + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int dk_head = pid_m * dk_stride_t + h_id * dk_stride_h; + int dv_head = pid_m * dv_stride_t + h_id * dv_stride_h; + int dkv_head = pid_m * dkv_stride_t + h_id * dkv_stride_h; + + for (int i = threadIdx.x; i < k_dim; i += blockDim.x) { + d_kv[dkv_head + i] = dk[dk_head + i]; + } + for (int i = threadIdx.x; i < v_dim; i += blockDim.x) { + d_kv[dkv_head + k_dim + i] = dv[dv_head + i]; + } + } + + extern __shared__ float shared_mem_kv_bwd[]; + float *sh_cos_l = shared_mem_kv_bwd; + float *sh_sin_l = shared_mem_kv_bwd + half_emb; + float *sh_cos_r = shared_mem_kv_bwd + 2 * half_emb; + float *sh_sin_r = shared_mem_kv_bwd + 3 * half_emb; + + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int num_threads = blockDim.x * blockDim.y; + for (int i = tid; i < half_emb; i += num_threads) { + sh_cos_l[i] = cos_data[token_idx * emb_dim + i]; + sh_sin_l[i] = sin_data[token_idx * emb_dim + i]; + sh_cos_r[i] = 
cos_data[token_idx * emb_dim + half_emb + i]; + sh_sin_r[i] = sin_data[token_idx * emb_dim + half_emb + i]; + } + __syncthreads(); + + for (int i = threadIdx.x; i < half_emb; i += blockDim.x) { + if (threadIdx.y == 0) { + float accum_l = 0.0f, accum_r = 0.0f; + for (int h_id = 0; h_id < h; h_id++) { + int dk_head = pid_m * dk_stride_t + h_id * dk_stride_h; + accum_l += static_cast(dk[dk_head + k_dim + i]); + accum_r += static_cast(dk[dk_head + k_dim + half_emb + i]); + } + float dx1 = accum_l * sh_cos_l[i] + accum_r * sh_sin_r[i]; + float dx2 = -accum_l * sh_sin_l[i] + accum_r * sh_cos_r[i]; + d_emb[pid_m * emb_dim + i * 2] = static_cast(dx1); + d_emb[pid_m * emb_dim + i * 2 + 1] = static_cast(dx2); + } + } +} + +template +void mla_yarn_rope_q_forward_launcher(const scalar_t *q_input, const float *cos_data, + const float *sin_data, scalar_t *q_output, + const int *cu_seqlens, const int qk_head_dim, + const int emb_dim, const int h, const int d, + const int total_seqlen, const int s, const int b, + const int cp_size, const int cp_rank, cudaStream_t stream) { + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(total_seqlen); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 4 * (emb_dim / 2) * sizeof(float); + + mla_yarn_rope_q_forward_kernel<<>>( + q_input, cos_data, sin_data, q_output, cu_seqlens, qk_head_dim, emb_dim, h, d, s, b, + cp_size, cp_rank); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void mla_yarn_rope_q_backward_launcher(const scalar_t *grad_output, const float *cos_data, + const float *sin_data, scalar_t *grad_input, + const int *cu_seqlens, const int qk_head_dim, + const int emb_dim, const int h, const int d, + const int total_seqlen, const int s, const int b, + const int cp_size, const int cp_rank, cudaStream_t stream) { + int warps_per_block = h < 16 ? 
4 : 8; + dim3 blocks(total_seqlen); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 4 * (emb_dim / 2) * sizeof(float); + + mla_yarn_rope_q_backward_kernel<<>>( + grad_output, cos_data, sin_data, grad_input, cu_seqlens, qk_head_dim, emb_dim, h, d, s, b, + cp_size, cp_rank); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void mla_yarn_rope_kv_forward_launcher(const scalar_t *kv_input, const scalar_t *k_pos_emb, + const float *cos_data, const float *sin_data, + scalar_t *o_key, scalar_t *o_value, const int *cu_seqlens, + const int emb_dim, const int k_dim, const int v_dim, + const int h, const int total_seqlen, const int s, + const int b, const int cp_size, const int cp_rank, + cudaStream_t stream) { + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(total_seqlen); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 6 * (emb_dim / 2) * sizeof(float); + + mla_yarn_rope_kv_forward_kernel<<>>( + kv_input, k_pos_emb, cos_data, sin_data, o_key, o_value, cu_seqlens, emb_dim, k_dim, v_dim, + h, s, b, cp_size, cp_rank); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void mla_yarn_rope_kv_backward_launcher(const scalar_t *dk, const scalar_t *dv, + const float *cos_data, const float *sin_data, + scalar_t *d_kv, scalar_t *d_emb, const int *cu_seqlens, + const int emb_dim, const int k_dim, const int v_dim, + const int h, const int total_seqlen, const int s, + const int b, const int cp_size, const int cp_rank, + cudaStream_t stream) { + int warps_per_block = h < 16 ? 
4 : 8; + dim3 blocks(total_seqlen); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 4 * (emb_dim / 2) * sizeof(float); + + mla_yarn_rope_kv_backward_kernel<<>>( + dk, dv, cos_data, sin_data, d_kv, d_emb, cu_seqlens, emb_dim, k_dim, v_dim, h, s, b, + cp_size, cp_rank); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void fused_mla_rope_q_forward(const Tensor &q_input, const Tensor &cos, const Tensor &sin, + Tensor *q_output, const Tensor &cu_seqlens, const int qk_head_dim, + const int emb_dim, const int h, const int d, const int total_seqlen, + const int s, const int b, const int cp_size, const int cp_rank, + cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + q_input.data.dtype, scalar_t, + mla_yarn_rope_q_forward_launcher( + reinterpret_cast(q_input.data.dptr), + reinterpret_cast(cos.data.dptr), + reinterpret_cast(sin.data.dptr), + reinterpret_cast(q_output->data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), qk_head_dim, emb_dim, h, d, + total_seqlen, s, b, cp_size, cp_rank, stream);); +} + +void fused_mla_rope_q_backward(const Tensor &grad_output, const Tensor &cos, const Tensor &sin, + Tensor *grad_input, const Tensor &cu_seqlens, + const int qk_head_dim, const int emb_dim, const int h, const int d, + const int total_seqlen, const int s, const int b, const int cp_size, + const int cp_rank, cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + grad_output.data.dtype, scalar_t, + mla_yarn_rope_q_backward_launcher( + reinterpret_cast(grad_output.data.dptr), + reinterpret_cast(cos.data.dptr), + reinterpret_cast(sin.data.dptr), + reinterpret_cast(grad_input->data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), qk_head_dim, emb_dim, h, d, + total_seqlen, s, b, cp_size, cp_rank, stream);); +} + +void fused_mla_rope_kv_forward(const Tensor &kv_input, const Tensor &k_pos_emb, const Tensor &cos, + const Tensor &sin, Tensor *o_key, Tensor *o_value, + const Tensor &cu_seqlens, const int emb_dim, const int k_dim, + 
const int v_dim, const int h, const int total_seqlen, const int s, + const int b, const int cp_size, const int cp_rank, + cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + kv_input.data.dtype, scalar_t, + mla_yarn_rope_kv_forward_launcher( + reinterpret_cast(kv_input.data.dptr), + reinterpret_cast(k_pos_emb.data.dptr), + reinterpret_cast(cos.data.dptr), + reinterpret_cast(sin.data.dptr), + reinterpret_cast(o_key->data.dptr), + reinterpret_cast(o_value->data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), emb_dim, k_dim, v_dim, h, + total_seqlen, s, b, cp_size, cp_rank, stream);); +} + +void fused_mla_rope_kv_backward(const Tensor &dk, const Tensor &dv, const Tensor &cos, + const Tensor &sin, Tensor *d_kv, Tensor *d_emb, + const Tensor &cu_seqlens, const int emb_dim, const int k_dim, + const int v_dim, const int h, const int total_seqlen, const int s, + const int b, const int cp_size, const int cp_rank, + cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + dk.data.dtype, scalar_t, + mla_yarn_rope_kv_backward_launcher( + reinterpret_cast(dk.data.dptr), + reinterpret_cast(dv.data.dptr), + reinterpret_cast(cos.data.dptr), + reinterpret_cast(sin.data.dptr), + reinterpret_cast(d_kv->data.dptr), + reinterpret_cast(d_emb->data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), emb_dim, k_dim, v_dim, h, + total_seqlen, s, b, cp_size, cp_rank, stream);); +} + } // end namespace transformer_engine +void nvte_fused_mla_rope_q_forward(const NVTETensor q_input, const NVTETensor cos, + const NVTETensor sin, NVTETensor q_output, + const NVTETensor cu_seqlens, const int qk_head_dim, + const int emb_dim, const int h, const int d, + const int total_seqlen, const int s, const int b, + const int cp_size, const int cp_rank, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_mla_rope_q_forward); + using namespace transformer_engine; + fused_mla_rope_q_forward(*convertNVTETensorCheck(q_input), *convertNVTETensorCheck(cos), + *convertNVTETensorCheck(sin), 
convertNVTETensorCheck(q_output), + *convertNVTETensorCheck(cu_seqlens), qk_head_dim, emb_dim, h, d, + total_seqlen, s, b, cp_size, cp_rank, stream); +} + +void nvte_fused_mla_rope_q_backward(const NVTETensor grad_output, const NVTETensor cos, + const NVTETensor sin, NVTETensor grad_input, + const NVTETensor cu_seqlens, const int qk_head_dim, + const int emb_dim, const int h, const int d, + const int total_seqlen, const int s, const int b, + const int cp_size, const int cp_rank, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_mla_rope_q_backward); + using namespace transformer_engine; + fused_mla_rope_q_backward(*convertNVTETensorCheck(grad_output), *convertNVTETensorCheck(cos), + *convertNVTETensorCheck(sin), convertNVTETensorCheck(grad_input), + *convertNVTETensorCheck(cu_seqlens), qk_head_dim, emb_dim, h, d, + total_seqlen, s, b, cp_size, cp_rank, stream); +} + +void nvte_fused_mla_rope_kv_forward(const NVTETensor kv_input, const NVTETensor k_pos_emb, + const NVTETensor cos, const NVTETensor sin, NVTETensor o_key, + NVTETensor o_value, const NVTETensor cu_seqlens, + const int emb_dim, const int k_dim, const int v_dim, + const int h, const int total_seqlen, const int s, const int b, + const int cp_size, const int cp_rank, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_mla_rope_kv_forward); + using namespace transformer_engine; + fused_mla_rope_kv_forward( + *convertNVTETensorCheck(kv_input), *convertNVTETensorCheck(k_pos_emb), + *convertNVTETensorCheck(cos), *convertNVTETensorCheck(sin), convertNVTETensorCheck(o_key), + convertNVTETensorCheck(o_value), *convertNVTETensorCheck(cu_seqlens), emb_dim, k_dim, v_dim, + h, total_seqlen, s, b, cp_size, cp_rank, stream); +} + +void nvte_fused_mla_rope_kv_backward(const NVTETensor dk, const NVTETensor dv, const NVTETensor cos, + const NVTETensor sin, NVTETensor d_kv, NVTETensor d_emb, + const NVTETensor cu_seqlens, const int emb_dim, + const int k_dim, const int v_dim, const int h, + const int total_seqlen, const int 
s, const int b, + const int cp_size, const int cp_rank, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_mla_rope_kv_backward); + using namespace transformer_engine; + fused_mla_rope_kv_backward(*convertNVTETensorCheck(dk), *convertNVTETensorCheck(dv), + *convertNVTETensorCheck(cos), *convertNVTETensorCheck(sin), + convertNVTETensorCheck(d_kv), convertNVTETensorCheck(d_emb), + *convertNVTETensorCheck(cu_seqlens), emb_dim, k_dim, v_dim, h, + total_seqlen, s, b, cp_size, cp_rank, stream); +} + void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, const NVTETensor freqs, const NVTETensor start_positions, NVTETensor output, const NVTE_QKV_Format qkv_format, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 65cdaca7d0..62d4369768 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -318,7 +318,8 @@ void nvte_fused_attn_fwd( bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream, + NVTE_QKV_Format qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. 
* @@ -397,7 +398,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream); + bool cuda_graph, NVTETensor workspace, cudaStream_t stream, + NVTE_QKV_Format qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format do_scale_inv_format = NVTE_QKV_Format_NOT_SET); /*! \brief Update the RNG state with the seed and calculated offset. * @@ -611,35 +614,58 @@ void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t s void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv, cudaStream_t stream); -/*! \brief Permute Q, K, V to grouped tensors. +/*! \brief Permute Q, K, V to grouped tensors (BSHD/SBHD → BHSD). * - * \param[in] q Query tensor - * \param[in] k Key tensor - * \param[in] v Value tensor - * \param[out] q_out Output query tensor - * \param[out] k_out Output key tensor - * \param[out] v_out Output value tensor - * \param[in] original_layout Original QKV layout. + * When num_tensors == 1, only q/q_out are used (k/v/k_out/v_out are ignored). + * + * \param[in] q Query tensor (or the single tensor). + * \param[in] k Key tensor (ignored when num_tensors == 1). + * \param[in] v Value tensor (ignored when num_tensors == 1). + * \param[out] q_out Output query tensor. + * \param[out] k_out Output key tensor (ignored when num_tensors == 1). + * \param[out] v_out Output value tensor (ignored when num_tensors == 1). + * \param[in] original_format Original QKV format (NVTE_BSHD or NVTE_SBHD). + * \param[in] num_tensors Number of tensors to permute (1 or 3). * \param[in] stream CUDA stream. 
*/ void nvte_permute_to_grouped_tensor_fwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, - NVTE_QKV_Layout original_layout, cudaStream_t stream); + NVTE_QKV_Format original_format, size_t num_tensors, + cudaStream_t stream); -/*! \brief Permute Q, K, V back to original layout. +/*! \brief Permute Q, K, V back to original format (BHSD → BSHD/SBHD). + * + * When num_tensors == 1, only grad_q/q are used (others are ignored). * - * \param[in] grad_q Gradient of query tensor - * \param[in] grad_k Gradient of key tensor - * \param[in] grad_v Gradient of value tensor - * \param[out] q Original query tensor - * \param[out] k Original key tensor - * \param[out] v Original value tensor - * \param[in] original_layout Original QKV layout. + * \param[in] grad_q Gradient of query tensor. + * \param[in] grad_k Gradient of key tensor (ignored when num_tensors == 1). + * \param[in] grad_v Gradient of value tensor (ignored when num_tensors == 1). + * \param[out] q Original query tensor. + * \param[out] k Original key tensor (ignored when num_tensors == 1). + * \param[out] v Original value tensor (ignored when num_tensors == 1). + * \param[in] original_format Original QKV format (NVTE_BSHD or NVTE_SBHD). + * \param[in] num_tensors Number of tensors to permute (1 or 3). * \param[in] stream CUDA stream. */ void nvte_permute_to_grouped_tensor_bwd(NVTETensor grad_q, NVTETensor grad_k, NVTETensor grad_v, NVTETensor q, NVTETensor k, NVTETensor v, - NVTE_QKV_Layout original_layout, cudaStream_t stream); + NVTE_QKV_Format original_format, size_t num_tensors, + cudaStream_t stream); + +/*! \brief Pad the last dimension of multiple 2D tensors with zeros in one kernel launch. + * + * Each tensor copies a row-major (rows, in_cols) input to a (rows, out_cols) output, + * zero-filling the region [in_cols, out_cols) in every row. + * Outputs must be pre-allocated with out_cols >= in_cols and matching dtype. 
+ * Up to 16 tensors may be processed in a single call. + * + * \param[in] inputs Array of num_tensors 2D input tensors. + * \param[out] outputs Array of num_tensors 2D output tensors, pre-allocated. + * \param[in] num_tensors Number of tensor pairs to process (1..16). + * \param[in] stream CUDA stream. + */ +void nvte_multi_pad_last_dim(NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index aea5256a2c..69b5168212 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -139,6 +139,81 @@ void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, const int qkv_split_arg_list_2, cudaStream_t stream); +/*! \brief Apply YARN RoPE to MLA query tensor (forward). + * + * Reads the last emb_dim elements interleaved, applies YARN rotation with + * split cos/sin, and writes de-interleaved. First qk_head_dim elements are + * copied unchanged. Input is [total_seqlen, h, d] (flattened from SBHD or THD). + * + * \param[in] q_input Input Q tensor. + * \param[in] cos Pre-computed cosine tensor [max_s, emb_dim]. + * \param[in] sin Pre-computed sine tensor [max_s, emb_dim]. + * \param[out] q_output Output Q tensor (same shape as input). + * \param[in] cu_seqlens Cumulative sequence lengths for THD (empty for SBHD). + * \param[in] qk_head_dim Non-RoPE prefix dimension per head. + * \param[in] emb_dim RoPE embedding dimension. + * \param[in] h Number of heads. + * \param[in] d Total head dimension (qk_head_dim + emb_dim). + * \param[in] total_seqlen Total tokens (s*b for SBHD, total_t for THD). + * \param[in] s Sequence length (SBHD) or max_s (THD). 
+ * \param[in] b Batch size (SBHD) or num_seqs (THD). + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. + * \param[in] stream CUDA stream. + */ +void nvte_fused_mla_rope_q_forward(const NVTETensor q_input, const NVTETensor cos, + const NVTETensor sin, NVTETensor q_output, + const NVTETensor cu_seqlens, const int qk_head_dim, + const int emb_dim, const int h, const int d, + const int total_seqlen, const int s, const int b, + const int cp_size, const int cp_rank, cudaStream_t stream); + +/*! \brief Backward of YARN RoPE for MLA query tensor. */ +void nvte_fused_mla_rope_q_backward(const NVTETensor grad_output, const NVTETensor cos, + const NVTETensor sin, NVTETensor grad_input, + const NVTETensor cu_seqlens, const int qk_head_dim, + const int emb_dim, const int h, const int d, + const int total_seqlen, const int s, const int b, + const int cp_size, const int cp_rank, cudaStream_t stream); + +/*! \brief Apply YARN RoPE to MLA key-value tensor (forward). + * + * Splits KV into key and value, applies YARN rotation to k_pos_emb (shared + * across heads), concatenates rotated embedding to each head of output key. + * + * \param[in] kv_input Input KV tensor [total_t, h, k_dim+v_dim]. + * \param[in] k_pos_emb Positional embedding [total_t, emb_dim]. + * \param[in] cos Pre-computed cosine [max_s, emb_dim]. + * \param[in] sin Pre-computed sine [max_s, emb_dim]. + * \param[out] o_key Output key [total_t, h, k_dim+emb_dim]. + * \param[out] o_value Output value [total_t, h, v_dim]. + * \param[in] cu_seqlens Cumulative sequence lengths for THD (empty for SBHD). + * \param[in] emb_dim RoPE embedding dimension. + * \param[in] k_dim Key dimension per head (from KV). + * \param[in] v_dim Value dimension per head (from KV). + * \param[in] h Number of heads. + * \param[in] total_seqlen Total tokens. + * \param[in] s Sequence length (SBHD) or max_s (THD). + * \param[in] b Batch size (SBHD) or num_seqs (THD). 
+ * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. + * \param[in] stream CUDA stream. + */ +void nvte_fused_mla_rope_kv_forward(const NVTETensor kv_input, const NVTETensor k_pos_emb, + const NVTETensor cos, const NVTETensor sin, NVTETensor o_key, + NVTETensor o_value, const NVTETensor cu_seqlens, + const int emb_dim, const int k_dim, const int v_dim, + const int h, const int total_seqlen, const int s, const int b, + const int cp_size, const int cp_rank, cudaStream_t stream); + +/*! \brief Backward of YARN RoPE for MLA key-value tensor. */ +void nvte_fused_mla_rope_kv_backward(const NVTETensor dk, const NVTETensor dv, const NVTETensor cos, + const NVTETensor sin, NVTETensor d_kv, NVTETensor d_emb, + const NVTETensor cu_seqlens, const int emb_dim, + const int k_dim, const int v_dim, const int h, + const int total_seqlen, const int s, const int b, + const int cp_size, const int cp_rank, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 28a879a376..09639d9998 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -430,6 +430,79 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); } +// Narrow-K specialization for row scaling swizzle. +// When K is small (num_tiles_k < TB_DIM), the standard kernel wastes threadIdx.x +// because there aren't enough K-tiles to distribute across threads. +// This kernel repurposes the thread dimensions: threadIdx.x iterates rows within +// an M-tile, threadIdx.y indexes M-tiles within the block, processing TB_DIM +// M-tiles per block with full thread utilization. 
+template +__device__ void swizzle_row_scaling_narrow_k_impl( + const void* input, void* output, const int M, const int K, + const int original_M, const int original_K, + const int bid, const int grid_dim) { + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + const int K_i32 = K / 4; + const int num_tiles_m = M / SF_TILE_DIM_M; + + const int m_tile = bid * blockDim.y + threadIdx.y; + const bool active = (m_tile < num_tiles_m); + + extern __shared__ int4 slm_v4i[]; + const int slm_tile_v4i = K_i32 * (SF_TILE_SIZE_I32 / 4); + + if (active) { + const bool padding_m = (m_tile == num_tiles_m - 1) && (original_M < M); + const bool padding_k = (original_K < K); + + int4* my_slm = slm_v4i + threadIdx.y * slm_tile_v4i; + + for (int k = 0; k < K_i32; k++) { + const int input_base = m_tile * SF_TILE_DIM_M * K_i32 + k; + const int* input_i32 = reinterpret_cast(input) + input_base; + + int regs[N_SF_PER_TD_PER_TILE]; +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + const int row = i * TB_DIM + threadIdx.x; + regs[i] = __ldg(input_i32 + row * K_i32); + if (padding_m || padding_k) { + for (int j = 0; j < 4; j++) { + const int byte_row = m_tile * SF_TILE_DIM_M + row; + const int byte_col = k * 4 + j; + if (byte_row >= original_M || byte_col >= original_K) { + reinterpret_cast(®s[i])[j] = 0; + } + } + } + } + + my_slm[k * (SF_TILE_SIZE_I32 / 4) + threadIdx.x] = + *reinterpret_cast(regs); + } + } + + __syncthreads(); + + if (active) { + int4* my_slm = slm_v4i + threadIdx.y * slm_tile_v4i; + int4* out_v4i = reinterpret_cast( + reinterpret_cast(output) + m_tile * SF_TILE_DIM_M * K_i32); + + for (int i = threadIdx.x; i < slm_tile_v4i; i += blockDim.x) { + out_v4i[i] = my_slm[i]; + } + } +} + +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + swizzle_row_scaling_narrow_k_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K) { + swizzle_row_scaling_narrow_k_impl( + input, output, 
M, K, original_M, original_K, blockIdx.x, gridDim.x); +} + constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB struct MultiSwizzleArgs { // (input) Data buffers for input scaling factors @@ -719,13 +792,6 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s // Perform row-wise swizzle if (rowwise_swizzle) { - int vec_load_size = (num_tiles_k - 1) % 4 + 1; - /* there is no int3 and misaligned if using int4/int2 */ - if (vec_load_size == 3) vec_load_size = 1; - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - int original_M{0}, original_K{0}; void *input_scale_inv_ptr{nullptr}, *output_scale_inv_ptr{nullptr}; switch (scaling_mode) { @@ -754,34 +820,54 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s NVTE_ERROR("Invalid scaling mode"); } - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); - break; - case 2: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); - break; - case 1: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; + if (num_tiles_k < TB_DIM) { + // Narrow-K: batch TB_DIM M-tiles per block, fully utilizing all threads. 
+ dim3 num_blocks_narrow(DIVUP(num_tiles_m, TB_DIM)); + int slm_size = TB_DIM * num_tiles_k * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute( + swizzle_row_scaling_narrow_k_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_narrow_k_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + } else { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + /* there is no int3 and misaligned if using int4/int2 */ + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } } NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index b97504f2ae..68946f6a2b 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -106,9 +106,8 @@ void CheckScaleTensorShape(const 
Tensor &t, const std::string &name) { alignment; const auto &expected = std::vector{expected_x, expected_y}; - NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, - "\" has invalid scale_inv shape (expected ", expected, ", got ", - t.scale_inv.shape, ")"); + // TODO(charleney): re-enable after scale shape rework + (void)expected; } if (t.has_columnwise_data()) { alignment = block_alignment[1]; @@ -119,9 +118,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; const auto &expected = std::vector{expected_x, expected_y}; - NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, - "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", - t.columnwise_scale_inv.shape, ")"); + // TODO(charleney): re-enable after scale shape rework + (void)expected; } } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) { if (t.has_data()) { diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 96e6803ec5..6e1f45ccde 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -49,7 +49,8 @@ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ - .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD); \ + .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD) \ + .value("NVTE_QKV_Format_NOT_SET", NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 08e3e9aa46..7a8319aa9e 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -29,7 +29,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor from transformer_engine.pytorch.quantized_tensor import ( QuantizedTensorStorage, prepare_for_saving, @@ -72,6 +72,7 @@ print_quantizers, ConvertTHDtoBSHD, ConvertBSHDtoTHD, + _mxfp8_pad_and_swizzle_scales, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import ( AttentionLogging as attn_log, @@ -166,6 +167,79 @@ _qdq_dO_in_f16_bprop = os.getenv("NVTE_QDQ_DO_IN_F16_BPROP", "0") == "1" +def _mxfp8_clone_with_new_scale_inv(tensor: MXFP8Tensor, new_rs) -> MXFP8Tensor: + """Return a new MXFP8Tensor sharing data but with replaced rowwise_scale_inv.""" + return MXFP8Tensor( + shape=tensor.shape, + dtype=tensor.dtype, + rowwise_data=tensor._rowwise_data, + rowwise_scale_inv=new_rs, + columnwise_data=tensor._columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + quantizer=tensor._quantizer, + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + ) + + +def _mxfp8_permute_qkv_scale_inv_to_bhsd( + q: MXFP8Tensor, k: MXFP8Tensor, v: MXFP8Tensor, + q_fmt: str, kv_fmt: str, +): + """Permute Q/K/V scale_inv from their original format to BHSD in a single + batched kernel launch, then pad+swizzle for F8_128x4.""" + if q_fmt in ("bhsd", "htd"): + q_out = _mxfp8_clone_with_new_scale_inv(q, q._rowwise_scale_inv) + k_out = _mxfp8_clone_with_new_scale_inv(k, k._rowwise_scale_inv) + v_out = _mxfp8_clone_with_new_scale_inv(v, v._rowwise_scale_inv) + _mxfp8_pad_and_swizzle_scales(q_out, k_out, v_out) + return q_out, k_out, v_out + + def _view_scale_4d(tensor): + rs = tensor._rowwise_scale_inv + d_scale = 
rs.shape[-1] + shape_4d = list(tensor.shape[:-1]) + [d_scale] + return rs.view(shape_4d).contiguous(), d_scale + + q_rs_4d, d_scale_q = _view_scale_4d(q) + k_rs_4d, d_scale_k = _view_scale_4d(k) + v_rs_4d, d_scale_v = _view_scale_4d(v) + + fmt = dpa_utils._FORMAT_STR_TO_ENUM[q_fmt] + q_rs_bhsd, k_rs_bhsd, v_rs_bhsd = tex.permute_to_grouped_tensor_fwd( + q_rs_4d, k_rs_4d, v_rs_4d, original_format=fmt, + ) + + q_out = _mxfp8_clone_with_new_scale_inv(q, q_rs_bhsd.view(-1, d_scale_q)) + k_out = _mxfp8_clone_with_new_scale_inv(k, k_rs_bhsd.view(-1, d_scale_k)) + v_out = _mxfp8_clone_with_new_scale_inv(v, v_rs_bhsd.view(-1, d_scale_v)) + _mxfp8_pad_and_swizzle_scales(q_out, k_out, v_out) + return q_out, k_out, v_out + + +def _mxfp8_permute_scale_inv_to_bhsd( + tensor: MXFP8Tensor, src_format: str, +) -> MXFP8Tensor: + """Single-tensor variant for dO in the backward pass.""" + if src_format in ("bhsd", "htd"): + out = _mxfp8_clone_with_new_scale_inv(tensor, tensor._rowwise_scale_inv) + _mxfp8_pad_and_swizzle_scales(out) + return out + + rs = tensor._rowwise_scale_inv + d_scale = rs.shape[-1] + shape_4d = list(tensor.shape[:-1]) + [d_scale] + fmt = dpa_utils._FORMAT_STR_TO_ENUM[src_format] + new_rs = tex.permute_to_grouped_tensor_fwd( + rs.view(shape_4d).contiguous(), original_format=fmt, + )[0].view(-1, d_scale) + + out = _mxfp8_clone_with_new_scale_inv(tensor, new_rs) + _mxfp8_pad_and_swizzle_scales(out) + return out + + class FP8EmulationFunc(torch.autograd.Function): """ Emulate the effects of FP8 quantization on tensors. 
Used in UnfusedDotProductAttention as follows: @@ -192,9 +266,6 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou tensors = combine_and_dequantize( qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype ) - if isinstance(quantizer, MXFP8Quantizer): - # always in bhsd_bhsd_bhsd shape at this point; permute it back to sbhd_sbhd_sbhd - tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: if quantizer is not None: t_fp8 = quantizer(tensor1) @@ -226,9 +297,6 @@ def backward(ctx, grad1, grad2, grad3): tensors = combine_and_dequantize( new_qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) - if isinstance(ctx.quantizer, MXFP8Quantizer): - # always in bhsd_bhsd_bhsd shape at this point; permute it back to sbhd_sbhd_sbhd - tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None @@ -1274,6 +1342,7 @@ def forward( max_logit = None orig_q, orig_k, orig_v = q, k, v orig_qkv_layout = qkv_layout + qkv_scale_inv_format = None if fp8: fused_attention_backend = FusedAttnBackend["FP8"] @@ -1288,6 +1357,16 @@ def forward( qkv_layout, q, k, v, QKV_quantizer, used_in_backward=is_training ) + # For MXFP8: data stays in its original layout (e.g. SBHD). + # Permute only the (much smaller) scale_inv to BHSD (single + # batched kernel for Q/K/V), then pad+swizzle for F8_128x4. 
+ if isinstance(QKV_quantizer, MXFP8Quantizer) and not is_input_fp8: + _, q_fmt, kv_fmt = dpa_utils.get_qkv_format(qkv_layout) + q_fp8, k_fp8, v_fp8 = _mxfp8_permute_qkv_scale_inv_to_bhsd( + q_fp8, k_fp8, v_fp8, q_fmt, kv_fmt, + ) + qkv_scale_inv_format = "bhsd" + # print quantizers print_quantizers( "FusedAttnFunc.forward >> before: ", @@ -1336,6 +1415,7 @@ def forward( rng_gen, softmax_offset, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=qkv_scale_inv_format, ) if _run_shadow_f16_fwd: @@ -1595,6 +1675,7 @@ def forward( ctx.qkv_layout = original_qkv_layout ctx.o_format = o_format + ctx.qkv_scale_inv_format = qkv_scale_inv_format # dqkv should have the same layout as the original qkv ctx.dqkv_layout = original_qkv_layout ctx.attn_bias_type = attn_bias_type @@ -1653,10 +1734,13 @@ def backward(ctx, d_out, *_args): d_out_fp8 = None do_format = ctx.o_format if ctx.fp8: - if ctx.fp8_recipe.mxfp8(): - d_out, do_format = dpa_utils.permute_to_grouped_tensor(do_format, d_out) if isinstance(d_out, QuantizedTensorStorage): d_out_fp8 = d_out + elif isinstance(ctx.dO_quantizer, MXFP8Quantizer): + orig_opt = ctx.dO_quantizer.optimize_for_gemm + ctx.dO_quantizer.optimize_for_gemm = False + d_out_fp8 = ctx.dO_quantizer(d_out) + ctx.dO_quantizer.optimize_for_gemm = orig_opt else: d_out_fp8 = ctx.dO_quantizer(d_out) ( @@ -1760,6 +1844,12 @@ def backward(ctx, d_out, *_args): if ctx.fp8_recipe.mxfp8(): out_ = out aux_ctx_tensors.append(d_out) + do_scale_inv_format = None + if isinstance(ctx.dO_quantizer, MXFP8Quantizer) and d_out_fp8 is not None: + d_out_fp8 = _mxfp8_permute_scale_inv_to_bhsd( + d_out_fp8, do_format, + ) + do_scale_inv_format = "bhsd" dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1792,6 +1882,8 @@ def backward(ctx, d_out, *_args): ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), + qkv_scale_inv_format=ctx.qkv_scale_inv_format, + do_scale_inv_format=do_scale_inv_format, ) if _run_shadow_f16_bwd: 
original_qkv_layout = ctx.dqkv_layout diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index d6171d04f5..61ea8666ae 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -22,6 +22,7 @@ import transformer_engine_torch as tex import transformer_engine as te from transformer_engine.pytorch.cpp_extensions.fused_attn import ( + QKVFormat, QKVLayout, AttnBiasType, AttnMaskType, @@ -2314,14 +2315,24 @@ def print_quantizers( print(f"{label} >> {names[i]:14s}: {type_str}") +_FORMAT_STR_TO_ENUM = { + "bshd": QKVFormat["bshd"], + "sbhd": QKVFormat["sbhd"], +} + + def permute_to_grouped_tensor(src_format, tensor): """Permute tensor from src_format = {bshd, sbhd, thd} to des_format = {bhsd, htd} for MXFP8 quantization.""" if src_format in ["bhsd", "htd"]: return tensor, src_format des_format = "bhsd" if src_format != "thd" else "htd" - # make tensor contiguous bshd/sbhd/thd tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor - # permute bshd/sbhd to bhsd, and thd to htd + + fmt = _FORMAT_STR_TO_ENUM.get(src_format) + if fmt is not None and tensor.dim() == 4: + result = tex.permute_to_grouped_tensor_fwd(tensor, original_format=fmt) + return result[0], des_format + dim_s_or_t = src_format.find("s") if "s" in src_format else src_format.find("t") dim_others = [i for i in range(len(tensor.shape)) if i != dim_s_or_t] new_dims = [*dim_others[:-1], dim_s_or_t, dim_others[-1]] @@ -2330,96 +2341,96 @@ def permute_to_grouped_tensor(src_format, tensor): class PermuteToGroupedTensor(torch.autograd.Function): - """Permute Q, K, V from {bshd_bshd_bshd, sbhd_sbhd_sbhd} to bhsd_bhsd_bhsd.""" + """Permute tensors from {bshd, sbhd} to bhsd format. + + Accepts 1 tensor (key=None, value=None) or 3 tensors (Q, K, V). 
+ """ @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - original_layout: str = "bshd_bshd_bshd", - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + key: torch.Tensor = None, + value: torch.Tensor = None, + original_format: str = "bshd", + ): # pylint: disable=missing-function-docstring - ctx.original_layout = QKVLayout[original_layout] - return tex.permute_to_grouped_tensor_fwd(query, key, value, ctx.original_layout) + fmt = _FORMAT_STR_TO_ENUM[original_format] + ctx.original_format = fmt + ctx.num_tensors = 1 if key is None else 3 + results = tex.permute_to_grouped_tensor_fwd(query, key, value, fmt) + if ctx.num_tensors == 1: + return results[0] + return tuple(results) @staticmethod - def backward( - ctx: torch.autograd.function.FunctionCtx, - query_grad: torch.Tensor, - key_grad: torch.Tensor, - value_grad: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def backward(ctx, *grad_outputs): # pylint: disable=missing-function-docstring + if ctx.num_tensors == 1: + result = tex.permute_to_grouped_tensor_bwd( + grad_outputs[0], original_format=ctx.original_format, + ) + return result[0], None, None, None q, k, v = tex.permute_to_grouped_tensor_bwd( - query_grad, - key_grad, - value_grad, - ctx.original_layout, + grad_outputs[0], grad_outputs[1], grad_outputs[2], ctx.original_format, ) return q, k, v, None +def _mxfp8_pad_and_swizzle_scales(*fp8_tensors): + """Pad and swizzle scales for MXFP8 tensors quantized with optimize_for_gemm=False. + + When quantizing with optimize_for_gemm=False, the scales are in their natural + (non-swizzled) layout. This function pads the scale dimensions to the alignment + required by cuDNN and then applies the GEMM swizzle pattern. + Rowwise scales are padded to a multiple of 4 in the last dim. + Columnwise scales are padded to a multiple of 128 in the last dim. 
+ """ + rs_list = [t._rowwise_scale_inv for t in fp8_tensors if t._rowwise_scale_inv is not None] + cs_list = [t._columnwise_scale_inv for t in fp8_tensors if t._columnwise_scale_inv is not None] + if rs_list: + rs_padded = tex.pad_last_dim(rs_list, 4) + idx = 0 + for t in fp8_tensors: + if t._rowwise_scale_inv is not None: + t._rowwise_scale_inv = rs_padded[idx] + idx += 1 + if cs_list: + cs_padded = tex.pad_last_dim(cs_list, 128) + idx = 0 + for t in fp8_tensors: + if t._columnwise_scale_inv is not None: + t._columnwise_scale_inv = cs_padded[idx] + idx += 1 + for t in fp8_tensors: + tex.swizzle_scales_for_gemm_(t) + + def combine_and_quantize( qkv_layout, q, k, v, qkv_quantizer, used_in_forward=True, used_in_backward=False ): """Combine q,k,v based on qkv_layout and quantize them together""" if isinstance(qkv_quantizer, MXFP8Quantizer): qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) - # permute q, k, v to bhsd/htd format - qkv_contiguous_block = False - if qkv_layout in ["bshd_bshd_bshd", "sbhd_sbhd_sbhd"]: - q, k, v = PermuteToGroupedTensor.apply(q, k, v, qkv_layout) - qkv_contiguous_block = True - else: - if q_format not in ["bhsd", "htd"]: - q, _ = permute_to_grouped_tensor(q_format, q) - if kv_format not in ["bhsd", "htd"]: - k, _ = permute_to_grouped_tensor(kv_format, k) - v, _ = permute_to_grouped_tensor(kv_format, v) - - qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" - # check shapes + original_shapes = [x.shape for x in [q, k, v]] - s_q, d_qk = q.shape[-2:] - s_kv, d_v = v.shape[-2:] + _seq_dim = {"sbhd": 0, "bshd": 1, "bhsd": 2, "htd": 1} + d_qk = q.shape[-1] + d_v = v.shape[-1] + s_q = q.shape[_seq_dim.get(q_format, 2)] + s_kv = v.shape[_seq_dim.get(kv_format, 2)] assert s_q % 128 == 0 and s_kv % 128 == 0 and d_qk % 32 == 0 and d_v % 32 == 0, ( "MXFP8 quantization requires s_q % 128 == 0, s_kv % 128 == 0, d_qk % 32 == 0, d_v % 32" f" == 0. Found {s_q=}, {s_kv=}, {d_qk=}, {d_v=}." 
) q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] - # quantize q, k, v - # if qkv_contiguous_block: - # if d_qk == d_v: - # first_dims = torch.tensor( - # [q.shape[0], k.shape[0], v.shape[0]], dtype=torch.int64, device=q.device - # ) - # qkv_2d = torch.cat([q, k, v], dim=0) - # grouped_tensor = tex.group_quantize(qkv_2d, qkv_quantizer, 3, first_dims) - # quantized_tensors = grouped_tensor.split_into_quantized_tensors() - # q_fp8, k_fp8, v_fp8 = quantized_tensors[0], quantized_tensors[1], quantized_tensors[2] - # else: - # first_dims = torch.tensor([q.shape[0], k.shape[0]], dtype=torch.int64, device=q.device) - # qk_2d = torch.cat([q, k], dim=0) - # grouped_tensor = tex.group_quantize(qk_2d, qkv_quantizer, 2, first_dims) - # q_fp8, k_fp8 = grouped_tensor.split_into_quantized_tensors() - # v_fp8 = qkv_quantizer(v) - # else: - # input_tensors = [q, k, v] - # num_tensors = len(input_tensors) - # shapes = [x.shape for x in input_tensors] - # grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( - # num_tensors=num_tensors, - # shapes=shapes, - # quantizer=qkv_quantizer, - # device="cuda", - # dtype=q.dtype, - # ) - # quantized_tensors = grouped_tensor.quantize(input_tensors) - # q_fp8, k_fp8, v_fp8 = quantized_tensors[0], quantized_tensors[1], quantized_tensors[2] - # else: - # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + + # Quantize without internal swizzle. The caller is responsible for + # permuting scale_inv to the required format (e.g. BHSD for cuDNN) + # and applying pad + swizzle before passing to fused attention. 
+ orig_optimize = qkv_quantizer.optimize_for_gemm + qkv_quantizer.optimize_for_gemm = False + if used_in_forward and used_in_backward: q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] if used_in_forward and not used_in_backward: @@ -2437,6 +2448,8 @@ def combine_and_quantize( qkv_quantizer.columnwise_usage = False v_fp8 = qkv_quantizer(v) + qkv_quantizer.optimize_for_gemm = orig_optimize + # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 77ad57ed8f..db960d1655 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -12,7 +12,13 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat -__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"] +__all__ = [ + "RotaryPositionEmbedding", + "apply_rotary_pos_emb", + "apply_fused_qkv_rotary_pos_emb", + "fused_apply_mla_rope_for_q", + "fused_apply_mla_rope_for_kv", +] class RotaryPositionEmbedding(torch.nn.Module): @@ -255,6 +261,230 @@ def backward( return grad_input, None, None, None, None, None, None, None, None +class FusedMLARoPEQFunc(torch.autograd.Function): + """ + Autograd function for applying YARN RoPE to MLA's query using CUDA kernels. + + Reads interleaved elements from the last emb_dim of each head, applies YARN + rotation with split cos/sin (left and right halves), and writes de-interleaved. + The first qk_head_dim elements per head are copied unchanged. + + Supports both SBHD [s, b, h, d] and THD [t, h, d] input formats. 
+ """ + + @staticmethod + def forward( + ctx, + q: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + qk_head_dim: int, + emb_dim: int, + cu_seqlens_q: Union[torch.Tensor, None] = None, + cp_rank: int = 0, + cp_size: int = 1, + ) -> torch.Tensor: + if cos.dtype != torch.float32: + cos = cos.float() + if sin.dtype != torch.float32: + sin = sin.float() + cos = cos.contiguous().view(-1, emb_dim) + sin = sin.contiguous().view(-1, emb_dim) + + output = tex.fused_mla_rope_q_forward( + q, cos, sin, cu_seqlens_q, qk_head_dim, emb_dim, cp_size, cp_rank + ) + ctx.save_for_backward(cos, sin) + ctx.qk_head_dim = qk_head_dim + ctx.emb_dim = emb_dim + ctx.cu_seqlens_q = cu_seqlens_q + ctx.cp_rank = cp_rank + ctx.cp_size = cp_size + return output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + cos, sin = ctx.saved_tensors + grad_input = tex.fused_mla_rope_q_backward( + grad_output, + cos, + sin, + ctx.cu_seqlens_q, + ctx.qk_head_dim, + ctx.emb_dim, + ctx.cp_size, + ctx.cp_rank, + ) + return grad_input, None, None, None, None, None, None, None + + +class FusedMLARoPEKVFunc(torch.autograd.Function): + """ + Autograd function for applying YARN RoPE to MLA's key and value using CUDA kernels. + + Splits the input KV tensor into key and value, applies YARN rotation to a + separate k_pos_emb (shared across heads), and concatenates the rotated + embedding to each head of the output key. + + Supports both SBHD [s, b, h, k_dim+v_dim] and THD [t, h, k_dim+v_dim] formats. 
+ """ + + @staticmethod + def forward( + ctx, + kv: torch.Tensor, + k_pos_emb: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + emb_dim: int, + k_dim: int, + v_dim: int, + cu_seqlens_kv: Union[torch.Tensor, None] = None, + cp_rank: int = 0, + cp_size: int = 1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if cos.dtype != torch.float32: + cos = cos.float() + if sin.dtype != torch.float32: + sin = sin.float() + cos = cos.contiguous().view(-1, emb_dim) + sin = sin.contiguous().view(-1, emb_dim) + + o_key, o_value = tex.fused_mla_rope_kv_forward( + kv, k_pos_emb, cos, sin, cu_seqlens_kv, emb_dim, k_dim, v_dim, cp_size, cp_rank + ) + ctx.save_for_backward(cos, sin) + ctx.emb_dim = emb_dim + ctx.k_dim = k_dim + ctx.v_dim = v_dim + ctx.cu_seqlens_kv = cu_seqlens_kv + ctx.cp_rank = cp_rank + ctx.cp_size = cp_size + return o_key, o_value + + @staticmethod + def backward( + ctx, dk: torch.Tensor, dv: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + cos, sin = ctx.saved_tensors + d_kv, d_emb = tex.fused_mla_rope_kv_backward( + dk, + dv, + cos, + sin, + ctx.cu_seqlens_kv, + ctx.emb_dim, + ctx.k_dim, + ctx.v_dim, + ctx.cp_size, + ctx.cp_rank, + ) + return d_kv, d_emb, None, None, None, None, None, None, None, None + + +def fused_apply_mla_rope_for_q( + t: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + qk_head_dim: int, + emb_dim: int, + cu_seqlens_q: Optional[torch.Tensor] = None, + cp_rank: int = 0, + cp_size: int = 1, +) -> torch.Tensor: + """ + Apply YARN RoPE to MLA's query using fused CUDA kernels. + + Along the last dimension of each head, the first qk_head_dim elements are + unchanged and the last emb_dim elements receive YARN rotation. The input is + read interleaved and written de-interleaved. + + Parameters + ---------- + t : torch.Tensor + Query tensor of shape [s, b, h, qk_head_dim + emb_dim] (SBHD) + or [total_t, h, qk_head_dim + emb_dim] (THD). 
+ cos : torch.Tensor + Pre-computed cosine tensor [max_s, 1, 1, emb_dim] or [max_s, emb_dim]. + sin : torch.Tensor + Pre-computed sine tensor [max_s, 1, 1, emb_dim] or [max_s, emb_dim]. + qk_head_dim : int + Dimension of the non-RoPE prefix per head. + emb_dim : int + RoPE embedding dimension. + cu_seqlens_q : torch.Tensor, optional + Cumulative sequence lengths [num_seqs + 1] for THD format. + cp_rank : int + Context parallel rank. + cp_size : int + Context parallel world size. + + Returns + ------- + torch.Tensor + Output tensor with same shape as input, YARN RoPE applied. + """ + return FusedMLARoPEQFunc.apply( + t, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, cp_rank, cp_size + ) + + +def fused_apply_mla_rope_for_kv( + kv: torch.Tensor, + k_pos_emb: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + emb_dim: int, + k_dim: int, + v_dim: int, + cu_seqlens_kv: Optional[torch.Tensor] = None, + cp_rank: int = 0, + cp_size: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply YARN RoPE to MLA's key and value using fused CUDA kernels. + + Splits KV into key and value, applies YARN rotation to k_pos_emb (shared + across heads), and concatenates the rotated embedding to each head of the + output key. + + Parameters + ---------- + kv : torch.Tensor + Combined KV tensor [s, b, h, k_dim + v_dim] (SBHD) + or [total_t, h, k_dim + v_dim] (THD). + k_pos_emb : torch.Tensor + Positional embedding [s, b, 1, emb_dim] or [total_t, 1, emb_dim]. + cos : torch.Tensor + Pre-computed cosine tensor [max_s, 1, 1, emb_dim] or [max_s, emb_dim]. + sin : torch.Tensor + Pre-computed sine tensor [max_s, 1, 1, emb_dim] or [max_s, emb_dim]. + emb_dim : int + RoPE embedding dimension. + k_dim : int + Key dimension per head. + v_dim : int + Value dimension per head. + cu_seqlens_kv : torch.Tensor, optional + Cumulative sequence lengths [num_seqs + 1] for THD format. + cp_rank : int + Context parallel rank. + cp_size : int + Context parallel world size. 
+ + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + (o_key, o_value) where o_key has shape [..., h, k_dim + emb_dim] + and o_value has shape [..., h, v_dim]. + """ + return FusedMLARoPEKVFunc.apply( + kv, k_pos_emb, cos, sin, emb_dim, k_dim, v_dim, cu_seqlens_kv, cp_rank, cp_size + ) + + def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: """Change sign so the last dimension becomes [-odd, +even] diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index f086c4bcd0..89ff372c40 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -146,6 +146,7 @@ def fused_attn_fwd( softmax_offset: torch.Tensor = None, return_max_logit: bool = False, cuda_graph: bool = False, + qkv_scale_inv_format: str = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -312,6 +313,11 @@ def fused_attn_fwd( # execute kernel + _qkv_scale_inv_fmt = ( + QKVFormat[qkv_scale_inv_format] + if qkv_scale_inv_format is not None + else NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET + ) output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, @@ -344,6 +350,7 @@ def fused_attn_fwd( rng_elts_per_thread, return_max_logit, cuda_graph, + _qkv_scale_inv_fmt, ) if return_max_logit: @@ -431,6 +438,8 @@ def fused_attn_bwd( bottom_right_diagonal: bool = None, deterministic: bool = False, cuda_graph: bool = False, + qkv_scale_inv_format: str = None, + do_scale_inv_format: str = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed KV input. @@ -557,6 +566,16 @@ def fused_attn_bwd( f" for backend={fused_attention_backend}." 
) + _qkv_scale_inv_fmt = ( + QKVFormat[qkv_scale_inv_format] + if qkv_scale_inv_format is not None + else NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET + ) + _do_scale_inv_fmt = ( + QKVFormat[do_scale_inv_format] + if do_scale_inv_format is not None + else NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET + ) output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, @@ -588,6 +607,8 @@ def fused_attn_bwd( dp_quantizer, dqkv_quantizer, cuda_graph, + _qkv_scale_inv_fmt, + _do_scale_inv_fmt, ) return output_tensors diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2ecef7d79d..ba765d1922 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -94,7 +94,8 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph); + size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph, + NVTE_QKV_Format qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -107,16 +108,21 @@ std::vector fused_attn_bwd( const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph); + py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph, + NVTE_QKV_Format qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format do_scale_inv_format = NVTE_QKV_Format_NOT_SET); at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); -std::tuple permute_to_grouped_tensor_fwd( - at::Tensor query, at::Tensor key, 
at::Tensor value, NVTE_QKV_Layout input_layout); -std::tuple permute_to_grouped_tensor_bwd( - at::Tensor query_grad, at::Tensor key_grad, at::Tensor value_grad, - NVTE_QKV_Layout input_layout); +std::vector permute_to_grouped_tensor_fwd( + at::Tensor query, std::optional key, std::optional value, + NVTE_QKV_Format original_format); +std::vector permute_to_grouped_tensor_bwd( + at::Tensor query_grad, std::optional key_grad, + std::optional value_grad, NVTE_QKV_Format original_format); + +std::vector pad_last_dim(std::vector inputs, int64_t alignment); at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len); at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); @@ -456,6 +462,28 @@ at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tenso const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank); +at::Tensor fused_mla_rope_q_forward(const at::Tensor &q_input, const at::Tensor &cos, + const at::Tensor &sin, + const std::optional cu_seqlens, + const int qk_head_dim, const int emb_dim, const int cp_size, + const int cp_rank); + +at::Tensor fused_mla_rope_q_backward(const at::Tensor &grad_output, const at::Tensor &cos, + const at::Tensor &sin, + const std::optional cu_seqlens, + const int qk_head_dim, const int emb_dim, const int cp_size, + const int cp_rank); + +std::tuple fused_mla_rope_kv_forward( + const at::Tensor &kv_input, const at::Tensor &k_pos_emb, const at::Tensor &cos, + const at::Tensor &sin, const std::optional cu_seqlens, const int emb_dim, + const int k_dim, const int v_dim, const int cp_size, const int cp_rank); + +std::tuple fused_mla_rope_kv_backward( + const at::Tensor &dk, const at::Tensor &dv, const at::Tensor &cos, const at::Tensor &sin, + const std::optional cu_seqlens, const int emb_dim, const int k_dim, + const int v_dim, const int cp_size, const int cp_rank); + 
/*************************************************************************************************** * Miscellaneous **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 4392fa4b43..7e6bdef40d 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -298,4 +298,216 @@ at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tenso return qkv_grad_input; } +at::Tensor fused_mla_rope_q_forward(const at::Tensor &q_input, const at::Tensor &cos, + const at::Tensor &sin, + const std::optional cu_seqlens, + const int qk_head_dim, const int emb_dim, const int cp_size, + const int cp_rank) { + TORCH_CHECK(cos.scalar_type() == at::ScalarType::Float, "cos must be float32"); + TORCH_CHECK(sin.scalar_type() == at::ScalarType::Float, "sin must be float32"); + TORCH_CHECK(cos.is_contiguous(), "cos must be contiguous"); + TORCH_CHECK(sin.is_contiguous(), "sin must be contiguous"); + + int max_seqlen = 0, batch_size = 0, nheads = 0, headdim = 0, total_seqlen = 0, s = 0, b = 0; + at::Tensor q_flat; + if (cu_seqlens.has_value()) { + TORCH_CHECK(q_input.dim() == 3, "expected 3D tensor for THD format"); + total_seqlen = q_input.size(0); + nheads = q_input.size(1); + headdim = q_input.size(2); + b = cu_seqlens.value().size(0) - 1; + s = 0; + q_flat = q_input.contiguous(); + } else { + TORCH_CHECK(q_input.dim() == 4, "expected 4D tensor for SBHD format"); + max_seqlen = q_input.size(0); + batch_size = q_input.size(1); + nheads = q_input.size(2); + headdim = q_input.size(3); + q_flat = q_input.contiguous().view({max_seqlen * batch_size, nheads, headdim}); + total_seqlen = q_flat.size(0); + s = max_seqlen; + b = batch_size; + } + TORCH_CHECK(headdim == qk_head_dim + emb_dim, "headdim must equal qk_head_dim + 
emb_dim"); + + auto q_out = at::empty_like(q_flat); + auto q_in_cu = makeTransformerEngineTensor(q_flat); + auto cos_cu = makeTransformerEngineTensor(cos); + auto sin_cu = makeTransformerEngineTensor(sin); + auto q_out_cu = makeTransformerEngineTensor(q_out); + auto cu_seqlens_cu = TensorWrapper(); + if (cu_seqlens.has_value()) { + cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); + } + + nvte_fused_mla_rope_q_forward(q_in_cu.data(), cos_cu.data(), sin_cu.data(), q_out_cu.data(), + cu_seqlens_cu.data(), qk_head_dim, emb_dim, nheads, headdim, + total_seqlen, s, b, cp_size, cp_rank, + at::cuda::getCurrentCUDAStream()); + + if (!cu_seqlens.has_value()) { + q_out = q_out.view({max_seqlen, batch_size, nheads, headdim}); + } + return q_out; +} + +at::Tensor fused_mla_rope_q_backward(const at::Tensor &grad_output, const at::Tensor &cos, + const at::Tensor &sin, + const std::optional cu_seqlens, + const int qk_head_dim, const int emb_dim, const int cp_size, + const int cp_rank) { + int max_seqlen = 0, batch_size = 0, nheads = 0, headdim = 0, total_seqlen = 0, s = 0, b = 0; + at::Tensor grad_flat; + if (cu_seqlens.has_value()) { + total_seqlen = grad_output.size(0); + nheads = grad_output.size(1); + headdim = grad_output.size(2); + b = cu_seqlens.value().size(0) - 1; + s = 0; + grad_flat = grad_output.contiguous(); + } else { + max_seqlen = grad_output.size(0); + batch_size = grad_output.size(1); + nheads = grad_output.size(2); + headdim = grad_output.size(3); + grad_flat = grad_output.contiguous().view({max_seqlen * batch_size, nheads, headdim}); + total_seqlen = grad_flat.size(0); + s = max_seqlen; + b = batch_size; + } + + auto grad_in = at::empty_like(grad_flat); + auto grad_out_cu = makeTransformerEngineTensor(grad_flat); + auto cos_cu = makeTransformerEngineTensor(cos); + auto sin_cu = makeTransformerEngineTensor(sin); + auto grad_in_cu = makeTransformerEngineTensor(grad_in); + auto cu_seqlens_cu = TensorWrapper(); + if (cu_seqlens.has_value()) { + 
cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); + } + + nvte_fused_mla_rope_q_backward(grad_out_cu.data(), cos_cu.data(), sin_cu.data(), + grad_in_cu.data(), cu_seqlens_cu.data(), qk_head_dim, emb_dim, + nheads, headdim, total_seqlen, s, b, cp_size, cp_rank, + at::cuda::getCurrentCUDAStream()); + + if (!cu_seqlens.has_value()) { + grad_in = grad_in.view({max_seqlen, batch_size, nheads, headdim}); + } + return grad_in; +} + +std::tuple fused_mla_rope_kv_forward( + const at::Tensor &kv_input, const at::Tensor &k_pos_emb, const at::Tensor &cos, + const at::Tensor &sin, const std::optional cu_seqlens, const int emb_dim, + const int k_dim, const int v_dim, const int cp_size, const int cp_rank) { + TORCH_CHECK(cos.scalar_type() == at::ScalarType::Float, "cos must be float32"); + TORCH_CHECK(sin.scalar_type() == at::ScalarType::Float, "sin must be float32"); + TORCH_CHECK(cos.is_contiguous(), "cos must be contiguous"); + TORCH_CHECK(sin.is_contiguous(), "sin must be contiguous"); + TORCH_CHECK(kv_input.size(-1) == k_dim + v_dim, "last dim of kv must be k_dim + v_dim"); + + int max_seqlen = 0, batch_size = 0, nheads = 0, total_seqlen = 0, s = 0, b_val = 0; + at::Tensor kv_flat, emb_flat; + if (cu_seqlens.has_value()) { + TORCH_CHECK(kv_input.dim() == 3, "expected 3D tensor for THD format"); + total_seqlen = kv_input.size(0); + nheads = kv_input.size(1); + b_val = cu_seqlens.value().size(0) - 1; + s = 0; + kv_flat = kv_input.contiguous(); + emb_flat = k_pos_emb.contiguous().view({total_seqlen, emb_dim}); + } else { + TORCH_CHECK(kv_input.dim() == 4, "expected 4D tensor for SBHD format"); + max_seqlen = kv_input.size(0); + batch_size = kv_input.size(1); + nheads = kv_input.size(2); + kv_flat = kv_input.contiguous().view({max_seqlen * batch_size, nheads, k_dim + v_dim}); + emb_flat = k_pos_emb.contiguous().view({max_seqlen * batch_size, emb_dim}); + total_seqlen = kv_flat.size(0); + s = max_seqlen; + b_val = batch_size; + } + + auto opts = 
at::TensorOptions().dtype(kv_input.scalar_type()).device(kv_input.device()); + auto o_key = at::empty({total_seqlen, nheads, k_dim + emb_dim}, opts); + auto o_value = at::empty({total_seqlen, nheads, v_dim}, opts); + + auto kv_cu = makeTransformerEngineTensor(kv_flat); + auto emb_cu = makeTransformerEngineTensor(emb_flat); + auto cos_cu = makeTransformerEngineTensor(cos); + auto sin_cu = makeTransformerEngineTensor(sin); + auto okey_cu = makeTransformerEngineTensor(o_key); + auto oval_cu = makeTransformerEngineTensor(o_value); + auto cu_seqlens_cu = TensorWrapper(); + if (cu_seqlens.has_value()) { + cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); + } + + nvte_fused_mla_rope_kv_forward(kv_cu.data(), emb_cu.data(), cos_cu.data(), sin_cu.data(), + okey_cu.data(), oval_cu.data(), cu_seqlens_cu.data(), emb_dim, + k_dim, v_dim, nheads, total_seqlen, s, b_val, cp_size, cp_rank, + at::cuda::getCurrentCUDAStream()); + + if (!cu_seqlens.has_value()) { + o_key = o_key.view({max_seqlen, batch_size, nheads, k_dim + emb_dim}); + o_value = o_value.view({max_seqlen, batch_size, nheads, v_dim}); + } + return std::make_tuple(o_key, o_value); +} + +std::tuple fused_mla_rope_kv_backward( + const at::Tensor &dk, const at::Tensor &dv, const at::Tensor &cos, const at::Tensor &sin, + const std::optional cu_seqlens, const int emb_dim, const int k_dim, + const int v_dim, const int cp_size, const int cp_rank) { + int max_seqlen = 0, batch_size = 0, nheads = 0, total_seqlen = 0, s = 0, b_val = 0; + at::Tensor dk_flat, dv_flat; + if (cu_seqlens.has_value()) { + total_seqlen = dk.size(0); + nheads = dk.size(1); + b_val = cu_seqlens.value().size(0) - 1; + s = 0; + dk_flat = dk.contiguous(); + dv_flat = dv.contiguous(); + } else { + max_seqlen = dk.size(0); + batch_size = dk.size(1); + nheads = dk.size(2); + dk_flat = dk.contiguous().view({max_seqlen * batch_size, nheads, k_dim + emb_dim}); + dv_flat = dv.contiguous().view({max_seqlen * batch_size, nheads, v_dim}); + total_seqlen 
= dk_flat.size(0); + s = max_seqlen; + b_val = batch_size; + } + + auto opts = at::TensorOptions().dtype(dk.scalar_type()).device(dk.device()); + auto d_kv = at::empty({total_seqlen, nheads, k_dim + v_dim}, opts); + auto d_emb = at::empty({total_seqlen, emb_dim}, opts); + + auto dk_cu = makeTransformerEngineTensor(dk_flat); + auto dv_cu = makeTransformerEngineTensor(dv_flat); + auto cos_cu = makeTransformerEngineTensor(cos); + auto sin_cu = makeTransformerEngineTensor(sin); + auto dkv_cu = makeTransformerEngineTensor(d_kv); + auto demb_cu = makeTransformerEngineTensor(d_emb); + auto cu_seqlens_cu = TensorWrapper(); + if (cu_seqlens.has_value()) { + cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); + } + + nvte_fused_mla_rope_kv_backward(dk_cu.data(), dv_cu.data(), cos_cu.data(), sin_cu.data(), + dkv_cu.data(), demb_cu.data(), cu_seqlens_cu.data(), emb_dim, + k_dim, v_dim, nheads, total_seqlen, s, b_val, cp_size, cp_rank, + at::cuda::getCurrentCUDAStream()); + + if (!cu_seqlens.has_value()) { + d_kv = d_kv.view({max_seqlen, batch_size, nheads, k_dim + v_dim}); + d_emb = d_emb.view({max_seqlen, batch_size, 1, emb_dim}); + } else { + d_emb = d_emb.view({total_seqlen, 1, emb_dim}); + } + return std::make_tuple(d_kv, d_emb); +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index e3c25b396a..733e1e1602 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -123,7 +123,8 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) { + size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph, + 
NVTE_QKV_Format qkv_scale_inv_format) { // Ensure that cuDNN handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. @@ -255,7 +256,7 @@ std::vector fused_attn_fwd( te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, - workspace.data(), at::cuda::getCurrentCUDAStream()); + workspace.data(), at::cuda::getCurrentCUDAStream(), qkv_scale_inv_format); }); // allocate memory for workspace and auxiliary output tensors @@ -317,7 +318,7 @@ std::vector fused_attn_fwd( te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, - workspace.data(), at::cuda::getCurrentCUDAStream()); + workspace.data(), at::cuda::getCurrentCUDAStream(), qkv_scale_inv_format); }); // destroy tensor wrappers, but not allocated memory @@ -339,7 +340,8 @@ std::vector fused_attn_bwd( const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) { + py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format) { auto none = py::none(); // create QKV, O, dO tensor wrappers @@ -575,7 +577,8 @@ std::vector fused_attn_bwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, - 
deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream(), + qkv_scale_inv_format, do_scale_inv_format); }); // allocate memory for workspace @@ -592,7 +595,8 @@ std::vector fused_attn_bwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, - deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream(), + qkv_scale_inv_format, do_scale_inv_format); }); // destroy tensor wrappers @@ -649,96 +653,124 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } -std::tuple permute_to_grouped_tensor_fwd( - at::Tensor query, at::Tensor key, at::Tensor value, NVTE_QKV_Layout original_layout) { - NVTE_CHECK(original_layout == NVTE_SBHD_SBHD_SBHD || original_layout == NVTE_BSHD_BSHD_BSHD, - "permute_to_grouped_tensor_fwd: original_layout must be NVTE_SBHD_SBHD_SBHD or " - "NVTE_BSHD_BSHD_BSHD."); - NVTE_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda()); - NVTE_CHECK(query.is_contiguous() && key.is_contiguous() && value.is_contiguous()); - NVTE_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4); +std::vector permute_to_grouped_tensor_fwd( + at::Tensor query, std::optional key, std::optional value, + NVTE_QKV_Format original_format) { + NVTE_CHECK(original_format == NVTE_SBHD || original_format == NVTE_BSHD, + "permute_to_grouped_tensor_fwd: original_format must be NVTE_SBHD or NVTE_BSHD."); + NVTE_CHECK(query.is_cuda() && query.is_contiguous() && query.dim() == 4); NVTE_CHECK(query.scalar_type() == at::ScalarType::Half || - query.scalar_type() == at::ScalarType::BFloat16); - NVTE_CHECK(key.scalar_type() == query.scalar_type() && 
- value.scalar_type() == query.scalar_type()); - - int64_t B = 0; - int64_t S_q = 0, H_q = 0, D_qk = 0; - int64_t S_kv = 0, H_kv = 0, D_v = 0; - if (original_layout == NVTE_SBHD_SBHD_SBHD) { - S_q = query.size(0); - B = query.size(1); - H_q = query.size(2); - D_qk = query.size(3); - S_kv = key.size(0); - H_kv = key.size(2); - D_v = value.size(3); + query.scalar_type() == at::ScalarType::BFloat16 || + query.scalar_type() == at::ScalarType::Float || + query.scalar_type() == at::ScalarType::Byte); + + const bool has_kv = key.has_value() && value.has_value(); + const size_t num_tensors = has_kv ? 3 : 1; + + int64_t B, S_q, H_q, D_qk; + if (original_format == NVTE_SBHD) { + S_q = query.size(0); B = query.size(1); H_q = query.size(2); D_qk = query.size(3); } else { - B = query.size(0); - S_q = query.size(1); - H_q = query.size(2); - D_qk = query.size(3); - S_kv = key.size(1); - H_kv = key.size(2); - D_v = value.size(3); + B = query.size(0); S_q = query.size(1); H_q = query.size(2); D_qk = query.size(3); + } + + at::Tensor q_out = at::empty({B, H_q, S_q, D_qk}, query.options()); + + if (!has_kv) { + auto te_q = makeTransformerEngineTensor(query); + auto te_qo = makeTransformerEngineTensor(q_out); + nvte_permute_to_grouped_tensor_fwd(te_q.data(), te_q.data(), te_q.data(), + te_qo.data(), te_qo.data(), te_qo.data(), + original_format, 1, at::cuda::getCurrentCUDAStream()); + return {q_out}; + } + + auto &k = key.value(); + auto &v = value.value(); + NVTE_CHECK(k.is_cuda() && k.is_contiguous() && k.dim() == 4); + NVTE_CHECK(v.is_cuda() && v.is_contiguous() && v.dim() == 4); + NVTE_CHECK(k.scalar_type() == query.scalar_type() && v.scalar_type() == query.scalar_type()); + + int64_t S_kv, H_kv, D_v; + if (original_format == NVTE_SBHD) { + S_kv = k.size(0); H_kv = k.size(2); D_v = v.size(3); + } else { + S_kv = k.size(1); H_kv = k.size(2); D_v = v.size(3); } - NVTE_CHECK(key.size(original_layout == NVTE_SBHD_SBHD_SBHD ? 
1 : 0) == B && - value.size(original_layout == NVTE_SBHD_SBHD_SBHD ? 1 : 0) == B, - "permute_to_grouped_tensor_fwd: Q/K/V batch dimension must match."); const int64_t numel_q = B * H_q * S_q * D_qk; const int64_t numel_k = B * H_kv * S_kv * D_qk; const int64_t numel_v = B * H_kv * S_kv * D_v; at::Tensor qkv_out_flat = at::empty({numel_q + numel_k + numel_v}, query.options()); - at::Tensor q_out = qkv_out_flat.narrow(0, 0, numel_q).view({B, H_q, S_q, D_qk}); + q_out = qkv_out_flat.narrow(0, 0, numel_q).view({B, H_q, S_q, D_qk}); at::Tensor k_out = qkv_out_flat.narrow(0, numel_q, numel_k).view({B, H_kv, S_kv, D_qk}); at::Tensor v_out = qkv_out_flat.narrow(0, numel_q + numel_k, numel_v).view({B, H_kv, S_kv, D_v}); auto te_q = makeTransformerEngineTensor(query); - auto te_k = makeTransformerEngineTensor(key); - auto te_v = makeTransformerEngineTensor(value); + auto te_k = makeTransformerEngineTensor(k); + auto te_v = makeTransformerEngineTensor(v); auto te_qo = makeTransformerEngineTensor(q_out); auto te_ko = makeTransformerEngineTensor(k_out); auto te_vo = makeTransformerEngineTensor(v_out); - nvte_permute_to_grouped_tensor_fwd(te_q.data(), te_k.data(), te_v.data(), te_qo.data(), - te_ko.data(), te_vo.data(), original_layout, - at::cuda::getCurrentCUDAStream()); + nvte_permute_to_grouped_tensor_fwd(te_q.data(), te_k.data(), te_v.data(), + te_qo.data(), te_ko.data(), te_vo.data(), + original_format, 3, at::cuda::getCurrentCUDAStream()); - return std::make_tuple(q_out, k_out, v_out); + return {q_out, k_out, v_out}; } -std::tuple permute_to_grouped_tensor_bwd( - at::Tensor query_grad, at::Tensor key_grad, at::Tensor value_grad, - NVTE_QKV_Layout original_layout) { - NVTE_CHECK(original_layout == NVTE_SBHD_SBHD_SBHD || original_layout == NVTE_BSHD_BSHD_BSHD, - "permute_to_grouped_tensor_bwd: original_layout must be NVTE_SBHD_SBHD_SBHD or " - "NVTE_BSHD_BSHD_BSHD."); - NVTE_CHECK(query_grad.is_cuda() && key_grad.is_cuda() && value_grad.is_cuda()); - 
NVTE_CHECK(query_grad.is_contiguous() && key_grad.is_contiguous() && value_grad.is_contiguous()); - NVTE_CHECK(query_grad.dim() == 4 && key_grad.dim() == 4 && value_grad.dim() == 4); +std::vector permute_to_grouped_tensor_bwd( + at::Tensor query_grad, std::optional key_grad, + std::optional value_grad, NVTE_QKV_Format original_format) { + NVTE_CHECK(original_format == NVTE_SBHD || original_format == NVTE_BSHD, + "permute_to_grouped_tensor_bwd: original_format must be NVTE_SBHD or NVTE_BSHD."); + NVTE_CHECK(query_grad.is_cuda() && query_grad.is_contiguous() && query_grad.dim() == 4); NVTE_CHECK(query_grad.scalar_type() == at::ScalarType::Half || - query_grad.scalar_type() == at::ScalarType::BFloat16); - NVTE_CHECK(key_grad.scalar_type() == query_grad.scalar_type() && - value_grad.scalar_type() == query_grad.scalar_type()); + query_grad.scalar_type() == at::ScalarType::BFloat16 || + query_grad.scalar_type() == at::ScalarType::Float || + query_grad.scalar_type() == at::ScalarType::Byte); + + const bool has_kv = key_grad.has_value() && value_grad.has_value(); const int64_t B = query_grad.size(0); const int64_t H_q = query_grad.size(1); const int64_t S_q = query_grad.size(2); const int64_t D_qk = query_grad.size(3); - const int64_t H_kv = key_grad.size(1); - const int64_t S_kv = key_grad.size(2); - const int64_t D_v = value_grad.size(3); + + if (!has_kv) { + at::Tensor q; + if (original_format == NVTE_SBHD) { + q = at::empty({S_q, B, H_q, D_qk}, query_grad.options()); + } else { + q = at::empty({B, S_q, H_q, D_qk}, query_grad.options()); + } + auto te_gq = makeTransformerEngineTensor(query_grad); + auto te_q = makeTransformerEngineTensor(q); + nvte_permute_to_grouped_tensor_bwd(te_gq.data(), te_gq.data(), te_gq.data(), + te_q.data(), te_q.data(), te_q.data(), + original_format, 1, at::cuda::getCurrentCUDAStream()); + return {q}; + } + + auto &kg = key_grad.value(); + auto &vg = value_grad.value(); + NVTE_CHECK(kg.is_cuda() && kg.is_contiguous() && kg.dim() == 4); + 
NVTE_CHECK(vg.is_cuda() && vg.is_contiguous() && vg.dim() == 4); + NVTE_CHECK(kg.scalar_type() == query_grad.scalar_type() && + vg.scalar_type() == query_grad.scalar_type()); + + const int64_t H_kv = kg.size(1); + const int64_t S_kv = kg.size(2); + const int64_t D_v = vg.size(3); const int64_t numel_q = S_q * B * H_q * D_qk; const int64_t numel_k = S_kv * B * H_kv * D_qk; const int64_t numel_v = S_kv * B * H_kv * D_v; at::Tensor qkv_grad_flat = at::empty({numel_q + numel_k + numel_v}, query_grad.options()); - at::Tensor query; - at::Tensor key; - at::Tensor value; - if (original_layout == NVTE_SBHD_SBHD_SBHD) { + at::Tensor query, key, value; + if (original_format == NVTE_SBHD) { query = qkv_grad_flat.narrow(0, 0, numel_q).view({S_q, B, H_q, D_qk}); key = qkv_grad_flat.narrow(0, numel_q, numel_k).view({S_kv, B, H_kv, D_qk}); value = qkv_grad_flat.narrow(0, numel_q + numel_k, numel_v).view({S_kv, B, H_kv, D_v}); @@ -749,17 +781,86 @@ std::tuple permute_to_grouped_tensor_bwd( } auto te_gq = makeTransformerEngineTensor(query_grad); - auto te_gk = makeTransformerEngineTensor(key_grad); - auto te_gv = makeTransformerEngineTensor(value_grad); + auto te_gk = makeTransformerEngineTensor(kg); + auto te_gv = makeTransformerEngineTensor(vg); auto te_q = makeTransformerEngineTensor(query); auto te_k = makeTransformerEngineTensor(key); auto te_v = makeTransformerEngineTensor(value); - nvte_permute_to_grouped_tensor_bwd(te_gq.data(), te_gk.data(), te_gv.data(), te_q.data(), - te_k.data(), te_v.data(), original_layout, - at::cuda::getCurrentCUDAStream()); + nvte_permute_to_grouped_tensor_bwd(te_gq.data(), te_gk.data(), te_gv.data(), + te_q.data(), te_k.data(), te_v.data(), + original_format, 3, at::cuda::getCurrentCUDAStream()); + + return {query, key, value}; +} + +/*************************************************************************************************** + * Pad the last dimension of 2D tensors to a common alignment, zero-filling the padding. 
+ * All tensors share the same alignment; launches a single fused kernel. + **************************************************************************************************/ + +std::vector pad_last_dim(std::vector inputs, int64_t alignment) { + const auto align = static_cast(alignment); + NVTE_CHECK(align > 0, "pad_last_dim: alignment must be > 0."); + NVTE_CHECK(!inputs.empty(), "pad_last_dim: inputs must not be empty."); + + auto stream = at::cuda::getCurrentCUDAStream(); + std::vector outputs; + outputs.reserve(inputs.size()); + + std::vector kernel_indices; + + for (size_t i = 0; i < inputs.size(); ++i) { + auto &input = inputs[i]; + + NVTE_CHECK(input.dim() == 2, "pad_last_dim: expected 2D input at index ", i, ", got ", + input.dim(), "D."); + NVTE_CHECK(input.is_cuda(), "pad_last_dim: input must be a CUDA tensor at index ", i, "."); + NVTE_CHECK(input.is_contiguous(), "pad_last_dim: input must be contiguous at index ", i, "."); + + const int64_t rows = input.size(0); + const int64_t in_cols = input.size(1); + const int64_t padded_cols = + static_cast(DIVUP_TO_MULTIPLE(static_cast(in_cols), align)); + + if (in_cols == padded_cols) { + outputs.push_back(input); + continue; + } + + at::Tensor output = at::empty({rows, padded_cols}, input.options()); + outputs.push_back(output); + kernel_indices.push_back(outputs.size() - 1); + } + + if (kernel_indices.empty()) return outputs; + + std::vector te_in_wrappers, te_out_wrappers; + te_in_wrappers.reserve(kernel_indices.size()); + te_out_wrappers.reserve(kernel_indices.size()); + + size_t ki = 0; + for (size_t i = 0; i < inputs.size(); ++i) { + const int64_t in_cols = inputs[i].size(1); + const int64_t padded_cols = + static_cast(DIVUP_TO_MULTIPLE(static_cast(in_cols), align)); + if (in_cols == padded_cols) continue; + + te_in_wrappers.push_back(makeTransformerEngineTensor(inputs[i])); + te_out_wrappers.push_back(makeTransformerEngineTensor(outputs[kernel_indices[ki]])); + ++ki; + } + + std::vector 
nvte_inputs(te_in_wrappers.size()); + std::vector nvte_outputs(te_out_wrappers.size()); + for (size_t i = 0; i < te_in_wrappers.size(); ++i) { + nvte_inputs[i] = te_in_wrappers[i].data(); + nvte_outputs[i] = te_out_wrappers[i].data(); + } + + nvte_multi_pad_last_dim(nvte_inputs.data(), nvte_outputs.data(), te_in_wrappers.size(), stream); - return std::make_tuple(query, key, value); + return outputs; } /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 7f32892015..d0d54c9283 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -401,12 +401,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); m.def("permute_to_grouped_tensor_fwd", &transformer_engine::pytorch::permute_to_grouped_tensor_fwd, - "Permute Q, K, V to grouped tensors.", py::arg("query"), py::arg("key"), py::arg("value"), - py::arg("original_layout"), py::call_guard()); - m.def( - "permute_to_grouped_tensor_bwd", &transformer_engine::pytorch::permute_to_grouped_tensor_bwd, - "Permute Q, K, V back to original layout.", py::arg("query_grad"), py::arg("key_grad"), - py::arg("value_grad"), py::arg("original_layout"), py::call_guard()); + "Permute tensors from BSHD/SBHD to BHSD.", py::arg("query"), + py::arg("key") = py::none(), py::arg("value") = py::none(), + py::arg("original_format") = static_cast(NVTE_BSHD), + py::call_guard()); + m.def("permute_to_grouped_tensor_bwd", + &transformer_engine::pytorch::permute_to_grouped_tensor_bwd, + "Permute tensors back to original format.", py::arg("query_grad"), + py::arg("key_grad") = py::none(), py::arg("value_grad") = py::none(), + py::arg("original_format") = static_cast(NVTE_BSHD), + py::call_guard()); + m.def("pad_last_dim", &transformer_engine::pytorch::pad_last_dim, + "Pad last dimension of 2D 
tensors to a common alignment.", py::arg("inputs"), + py::arg("alignment"), py::call_guard()); m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd, "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd, @@ -428,6 +435,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward, "Fused Apply QKV RoPE BWD", py::call_guard()); + // fused MLA rope + m.def("fused_mla_rope_q_forward", &transformer_engine::pytorch::fused_mla_rope_q_forward, + "Fused MLA RoPE Q FWD", py::call_guard()); + m.def("fused_mla_rope_q_backward", &transformer_engine::pytorch::fused_mla_rope_q_backward, + "Fused MLA RoPE Q BWD", py::call_guard()); + m.def("fused_mla_rope_kv_forward", &transformer_engine::pytorch::fused_mla_rope_kv_forward, + "Fused MLA RoPE KV FWD", py::call_guard()); + m.def("fused_mla_rope_kv_backward", &transformer_engine::pytorch::fused_mla_rope_kv_backward, + "Fused MLA RoPE KV BWD", py::call_guard()); + // fused router m.def("fused_topk_with_score_function_fwd", &transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"), From e87102b2c525fdab893c88e5a2aae31fecf7672b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:32:14 -0700 Subject: [PATCH 170/172] remove mla_rope for now; clean up quant+permute+pad_swizzle; create multi_tensor_swizzle Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/run_attention_with_cp.py | 4 +- transformer_engine/common/common.h | 3 +- .../common/fused_attn/flash_attn.cu | 1058 ++++++++--------- .../common/fused_attn/fused_attn_fp8.cu | 43 +- .../common/fused_rope/fused_rope.cu | 510 -------- .../include/transformer_engine/fused_attn.h | 4 +- .../include/transformer_engine/fused_rope.h | 75 -- .../include/transformer_engine/swizzle.h | 3 +- 
transformer_engine/common/swizzle/swizzle.cu | 390 +++--- .../common/transformer_engine.cpp | 17 +- .../dot_product_attention/backends.py | 126 +- .../dot_product_attention/context_parallel.py | 113 +- .../attention/dot_product_attention/utils.py | 264 ++-- transformer_engine/pytorch/attention/rope.py | 232 +--- .../pytorch/cpp_extensions/fused_attn.py | 12 +- transformer_engine/pytorch/csrc/extensions.h | 32 +- .../pytorch/csrc/extensions/apply_rope.cpp | 212 ---- .../pytorch/csrc/extensions/attention.cpp | 46 +- .../pytorch/csrc/extensions/pybind.cpp | 20 +- .../pytorch/csrc/extensions/swizzle.cpp | 99 ++ 20 files changed, 1213 insertions(+), 2050 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index cda5c42d50..8dfea644a5 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -333,7 +333,7 @@ def run_dpa_with_cp( qkv_layout = "_".join([qkv_format] * 3) q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] if fp8_mha: - q, k, v, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + q, k, v, qkv_layout, _ = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) for x in [q, k, v]: x.requires_grad = True @@ -441,7 +441,7 @@ def run_dpa_with_cp( dout_quantizer.scale.fill_(1.0) dout_quantizer.amax.fill_(0.0) if fp8_mha: - q_, k_, v_, qkv_layout = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) + q_, k_, v_, qkv_layout, _ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) if is_training: q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 6e207370dd..b32d6bbbcd 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1003,7 +1003,8 @@ size_t typeToSize(const DType type); size_t typeToNumBits(const DType 
type); void CheckNoopTensor(const Tensor &t, const std::string &name); -void CheckInputTensor(const Tensor &t, const std::string &name); +void CheckInputTensor(const Tensor &t, const std::string &name, + bool check_scale_inv_shapes = true); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); /*! \brief Update a tensor's FP8 scale-inverse diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 9ebd13fb1c..f34a11422d 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -9,11 +9,17 @@ #include "../common.h" #include "../util/cuda_driver.h" +#include "../util/cuda_runtime.h" #include "../util/ptx.cuh" #include "../utils.cuh" #include "transformer_engine/fused_attn.h" namespace transformer_engine { + +// ============================================================================ +// prepare_flash_attn: repack interleaved QKV for FlashAttention backend +// ============================================================================ + namespace flash_attention { /// Packed vector of N elements of T; alignment matches a single wide load/store of N * sizeof(T) bytes. 
@@ -29,102 +35,6 @@ constexpr int nvec128 = sizeof(uint4) / type_size; constexpr int load_size = warp_size * nvec; constexpr int block_size = 512; -// TMA permute kernel configuration -constexpr int tma_permute_threads = 128; -constexpr int tma_permute_s_tile_default = 32; - -// ---- 4D TMA PTX wrappers ---- - -__device__ __forceinline__ void cp_async_bulk_tensor_4d_global_to_shared( - void *dst_shmem, const CUtensorMap *tensor_map, uint32_t c0, uint32_t c1, uint32_t c2, - uint32_t c3, uint64_t *mbar) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t dst = __cvta_generic_to_shared(dst_shmem); - uint32_t bar = __cvta_generic_to_shared(mbar); - asm volatile( - "cp.async.bulk.tensor.4d.shared::cluster.global.tile" - ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3, %4, %5}], [%6];" ::"r"(dst), - "l"(tensor_map), "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(bar) - : "memory"); -#else - NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_global_to_shared requires SM 10.0+."); -#endif -} - -__device__ __forceinline__ void cp_async_bulk_tensor_4d_shared_to_global( - const CUtensorMap *tensor_map, uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, - void *src_shmem) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - uint32_t src = __cvta_generic_to_shared(src_shmem); - asm volatile( - "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" - " [%0, {%1, %2, %3, %4}], [%5];" ::"l"(tensor_map), - "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(src) - : "memory"); -#else - NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_shared_to_global requires SM 9.0+."); -#endif -} - -// ---- Host-side 4D tensor map creation ---- -// -// Creates a 4D TMA descriptor for a densely-packed tensor whose logical -// dimensions (innermost-first) are [dim0, dim1, dim2, dim3]. -// -// The box (tile) copied per TMA instruction is [box0, box1, box2, box3]. 
- -static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dtype, uint64_t dim0, - uint64_t dim1, uint64_t dim2, uint64_t dim3, uint32_t box0, - uint32_t box1, uint32_t box2, uint32_t box3) { - cuda_driver::ensure_context_exists(); - static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { - void *ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); - return reinterpret_cast(ptr); - }(); - - CUtensorMapDataType tma_dtype; - size_t elem_bytes; - switch (dtype) { - case DType::kFloat16: - tma_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - elem_bytes = 2; - break; - case DType::kBFloat16: - tma_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; - elem_bytes = 2; - break; - case DType::kFloat8E4M3: - case DType::kFloat8E5M2: - case DType::kFloat8E8M0: - case DType::kByte: - tma_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; - elem_bytes = 1; - break; - default: - NVTE_ERROR("create_4D_tensor_map: unsupported dtype ", - to_string(static_cast(dtype))); - } - - constexpr uint32_t rank = 4; - uint64_t size[rank] = {dim0, dim1, dim2, dim3}; - uint64_t stride[rank - 1] = { - dim0 * elem_bytes, - dim0 * dim1 * elem_bytes, - dim0 * dim1 * dim2 * elem_bytes, - }; - uint32_t boxSize[rank] = {box0, box1, box2, box3}; - uint32_t elemStride[rank] = {1, 1, 1, 1}; - - const auto oob_fill = (tma_dtype == CU_TENSOR_MAP_DATA_TYPE_UINT8) - ? 
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE - : CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; - - NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( - &tensorMap, tma_dtype, rank, dataPtr, size, stride, boxSize, elemStride, - CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, - oob_fill)); -} - template __launch_bounds__(block_size) __global__ void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z, @@ -242,222 +152,178 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream NVTE_CHECK_CUDA(cudaGetLastError()); } -// ---- TMA helpers for strided (BSHD/SBHD) tensors ---- -// -// Strided BSHD: TMA dims [D, H, S, B], coords [0, h, s, b] -// Strided SBHD: TMA dims [D, H, B, S], coords [0, h, b, s] - -template -__device__ __forceinline__ void issue_tma_load_strided(T *smem_buf, const CUtensorMap *tma, - size_t h_i, size_t s_tile, size_t b_i, - uint64_t *mbar, size_t tile_bytes) { - ptx::mbarrier_arrive_expect_tx(mbar, static_cast(tile_bytes)); - if constexpr (kIsBshdBshdBshd) { - cp_async_bulk_tensor_4d_global_to_shared(smem_buf, tma, 0, static_cast(h_i), - static_cast(s_tile), - static_cast(b_i), mbar); - } else { - cp_async_bulk_tensor_4d_global_to_shared(smem_buf, tma, 0, static_cast(h_i), - static_cast(b_i), - static_cast(s_tile), mbar); - } -} - -template -__device__ __forceinline__ void issue_tma_store_strided(const CUtensorMap *tma, T *smem_buf, - size_t h_i, size_t s_tile, size_t b_i) { - if constexpr (kIsBshdBshdBshd) { - cp_async_bulk_tensor_4d_shared_to_global(tma, 0, static_cast(h_i), - static_cast(s_tile), - static_cast(b_i), smem_buf); - } else { - cp_async_bulk_tensor_4d_shared_to_global(tma, 0, static_cast(h_i), - static_cast(b_i), - static_cast(s_tile), smem_buf); - } - ptx::cp_async_bulk_commit_group(); -} +} // namespace flash_attention -__device__ __forceinline__ void st_global_cs_uint4(uint4 *ptr, uint4 val) { - asm volatile("st.global.cs.v4.b32 
[%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(val.x), "r"(val.y), - "r"(val.z), "r"(val.w) - : "memory"); -} +// ============================================================================ +// permute_to_grouped_tensor: BSHD/SBHD ↔ BHSD permutation +// ============================================================================ -// ---- Forward: BSHD/SBHD → BHSD ---- -// -// TMA load from strided input → smem → non-temporal stores to contiguous output. +namespace permute_to_grouped_tensor { -template -__launch_bounds__(tma_permute_threads) __global__ - void permute_to_grouped_tensor_fwd_kernel(const __grid_constant__ CUtensorMap tma_q_in, - const __grid_constant__ CUtensorMap tma_k_in, - const __grid_constant__ CUtensorMap tma_v_in, - T *__restrict__ q_out, T *__restrict__ k_out, - T *__restrict__ v_out, size_t b, size_t s_q, - size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, - size_t d_v, unsigned int permute_s_splits, - size_t s_tile_size) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - const int which = blockIdx.z; - const CUtensorMap *tma_in = which == 0 ? &tma_q_in : (which == 1 ? &tma_k_in : &tma_v_in); - T *__restrict__ tensor_out = which == 0 ? q_out : (which == 1 ? k_out : v_out); - const size_t Sdim = which == 0 ? s_q : s_kv; - const size_t Hdim = which == 0 ? h_q : h_kv; - const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); +using flash_attention::Vec; - const size_t h_grid = h_q > h_kv ? 
h_q : h_kv; - const size_t b_i = static_cast(blockIdx.x) / h_grid; - const size_t h_i = static_cast(blockIdx.x) % h_grid; +// ---------- fallback_not_vec_aligned: row-copy helper (D is small / misaligned) ---------- - if (b_i >= b) return; - if (which == 0) { - if (h_i >= h_q) return; - } else { - if (h_i >= h_kv) return; +__device__ __forceinline__ void copy_row_bytes(const char *__restrict__ src, + char *__restrict__ dst, size_t D_bytes) { + size_t off = 0; + for (; off + 16 <= D_bytes; off += 16) { + uint4 tmp; + memcpy(&tmp, src + off, 16); + memcpy(dst + off, &tmp, 16); + } + for (; off + 8 <= D_bytes; off += 8) { + uint2 tmp; + memcpy(&tmp, src + off, 8); + memcpy(dst + off, &tmp, 8); + } + for (; off + 4 <= D_bytes; off += 4) { + unsigned int tmp; + memcpy(&tmp, src + off, 4); + memcpy(dst + off, &tmp, 4); + } + for (; off + 2 <= D_bytes; off += 2) { + unsigned short tmp; + memcpy(&tmp, src + off, 2); + memcpy(dst + off, &tmp, 2); } + for (; off < D_bytes; ++off) dst[off] = src[off]; +} - const unsigned int s_part = blockIdx.y; - const size_t s_begin = - (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); - const size_t s_end = - (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); - if (s_begin >= s_end) return; - const size_t out_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; +// ---------- fallback_not_vec_aligned: tiled-transpose kernels ---------- - extern __shared__ __align__(128) char smem_raw[]; - T *smem = reinterpret_cast(smem_raw); +constexpr int TRANSPOSE_TILE = 32; +constexpr int TRANSPOSE_BLOCK = 256; +constexpr int TRANSPOSE_WARPS = TRANSPOSE_BLOCK / 32; // 8 - __shared__ __align__(8) uint64_t mbar; - const bool is_leader = (threadIdx.x == 0); +template +__launch_bounds__(TRANSPOSE_BLOCK) __global__ + void permute_to_grouped_tensor_fwd_fallback_not_vec_aligned_kernel( + const T *__restrict__ q_in, const T *__restrict__ k_in, const T *__restrict__ v_in, + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ 
v_out, + size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, + unsigned int s_tiles) { + const int which = blockIdx.z; + const T *__restrict__ in = which == 0 ? q_in : (which == 1 ? k_in : v_in); + T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t S = which == 0 ? s_q : s_kv; + const size_t H = which == 0 ? h_q : h_kv; + const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + const size_t D_bytes = D * sizeof(T); + const size_t D_pad = (D_bytes + 3u) & ~size_t(3); // 4-byte aligned for smem - if (is_leader) { - ptx::mbarrier_init(&mbar, static_cast(blockDim.x)); - ptx::fence_proxy_async_shared_cta(); - } - __syncthreads(); + const size_t tile_s = static_cast(blockIdx.x) % static_cast(s_tiles); + const size_t b_i = static_cast(blockIdx.x) / static_cast(s_tiles); + if (b_i >= b) return; + const size_t tile_h = static_cast(blockIdx.y); - const size_t S_TILE = s_tile_size; - const uint32_t tile_bytes = static_cast(S_TILE * Ddim * sizeof(T)); - int parity = 0; + const size_t s_base = tile_s * TRANSPOSE_TILE; + const size_t h_base = tile_h * TRANSPOSE_TILE; - for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { - const size_t tile_rows = min(S_TILE, s_end - s_tile); + extern __shared__ char smem[]; + // +4 padding per S-row avoids 32-way bank conflicts during the store phase. 
+ const size_t smem_row = static_cast(TRANSPOSE_TILE) * D_pad + 4; - if (is_leader) { - issue_tma_load_strided(smem, tma_in, h_i, s_tile, b_i, &mbar, tile_bytes); - } else { - ptx::mbarrier_arrive(&mbar); + // ---- Phase 1: global → smem (sweep consecutive H → coalesced reads) ---- + for (unsigned int warp_off = threadIdx.x >> 5; + warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + const size_t local_s = warp_off; + const size_t local_h = threadIdx.x & 31u; + const size_t s_i = s_base + local_s; + const size_t h_i = h_base + local_h; + if (s_i < S && h_i < H) { + const char *__restrict__ src; + if constexpr (kIsSbhd) + src = reinterpret_cast(in + s_i * b * H * D + b_i * H * D + h_i * D); + else + src = reinterpret_cast(in + b_i * S * H * D + s_i * H * D + h_i * D); + copy_row_bytes(src, smem + local_s * smem_row + local_h * D_pad, D_bytes); } + } - ptx::mbarrier_wait_parity(&mbar, parity); - parity ^= 1; - - T *__restrict__ out_ptr = tensor_out + out_base + s_tile * Ddim; - const size_t total_elems = tile_rows * Ddim; - constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); + __syncthreads(); - for (size_t i = threadIdx.x * vec_elems; i < total_elems; - i += static_cast(blockDim.x) * vec_elems) { - uint4 v = *reinterpret_cast(smem + i); - st_global_cs_uint4(reinterpret_cast(out_ptr + i), v); + // ---- Phase 2: smem → global (sweep consecutive S → coalesced writes) ---- + for (unsigned int warp_off = threadIdx.x >> 5; + warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + const size_t local_h = warp_off; + const size_t local_s = threadIdx.x & 31u; + const size_t s_i = s_base + local_s; + const size_t h_i = h_base + local_h; + if (s_i < S && h_i < H) { + copy_row_bytes(smem + local_s * smem_row + local_h * D_pad, + reinterpret_cast(out + b_i * H * S * D + h_i * S * D + s_i * D), + D_bytes); } - - __syncthreads(); - } - - if (is_leader) { - ptx::mbarrier_invalid(&mbar); } -#endif } -// ---- Backward: BHSD → BSHD/SBHD ---- -// -// Vectorized loads from 
contiguous input → smem → TMA store to strided output. - -template -__launch_bounds__(tma_permute_threads) __global__ void permute_to_grouped_tensor_bwd_kernel( - const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, - const __grid_constant__ CUtensorMap tma_q_out, const __grid_constant__ CUtensorMap tma_k_out, - const __grid_constant__ CUtensorMap tma_v_out, size_t b, size_t s_q, size_t h_q, size_t d_qk, - size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits, - size_t s_tile_size) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +template +__launch_bounds__(TRANSPOSE_BLOCK) __global__ + void permute_to_grouped_tensor_bwd_fallback_not_vec_aligned_kernel( + const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, + size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, + unsigned int s_tiles) { const int which = blockIdx.z; - const T *__restrict__ tensor_in = which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); - const CUtensorMap *tma_out = which == 0 ? &tma_q_out : (which == 1 ? &tma_k_out : &tma_v_out); - const size_t Sdim = which == 0 ? s_q : s_kv; - const size_t Hdim = which == 0 ? h_q : h_kv; - const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); - - const size_t h_grid = h_q > h_kv ? h_q : h_kv; - const size_t b_i = static_cast(blockIdx.x) / h_grid; - const size_t h_i = static_cast(blockIdx.x) % h_grid; + const T *__restrict__ in = which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); + T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t S = which == 0 ? s_q : s_kv; + const size_t H = which == 0 ? h_q : h_kv; + const size_t D = which == 0 ? d_qk : (which == 1 ? 
d_qk : d_v); + const size_t D_bytes = D * sizeof(T); + const size_t D_pad = (D_bytes + 3u) & ~size_t(3); + const size_t tile_s = static_cast(blockIdx.x) % static_cast(s_tiles); + const size_t b_i = static_cast(blockIdx.x) / static_cast(s_tiles); if (b_i >= b) return; - if (which == 0) { - if (h_i >= h_q) return; - } else { - if (h_i >= h_kv) return; - } - - const unsigned int s_part = blockIdx.y; - const size_t s_begin = - (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); - const size_t s_end = - (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); - if (s_begin >= s_end) return; - - const size_t in_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; - - extern __shared__ __align__(128) char smem_raw[]; - T *smem = reinterpret_cast(smem_raw); - - const size_t S_TILE = s_tile_size; - constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); + const size_t tile_h = static_cast(blockIdx.y); - for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { - const size_t tile_rows = min(S_TILE, s_end - s_tile); + const size_t s_base = tile_s * TRANSPOSE_TILE; + const size_t h_base = tile_h * TRANSPOSE_TILE; - const T *__restrict__ in_ptr = tensor_in + in_base + s_tile * Ddim; - const size_t total_elems = tile_rows * Ddim; + extern __shared__ char smem[]; + const size_t smem_row = static_cast(TRANSPOSE_TILE) * D_pad + 4; - for (size_t i = threadIdx.x * vec_elems; i < total_elems; - i += static_cast(blockDim.x) * vec_elems) { - *reinterpret_cast(smem + i) = *reinterpret_cast(in_ptr + i); + // ---- Phase 1: global → smem (sweep consecutive S → coalesced reads from BHSD) ---- + for (unsigned int warp_off = threadIdx.x >> 5; + warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + const size_t local_h = warp_off; + const size_t local_s = threadIdx.x & 31u; + const size_t s_i = s_base + local_s; + const size_t h_i = h_base + local_h; + if (s_i < S && h_i < H) { + copy_row_bytes( + reinterpret_cast(in + b_i * H * S * D + h_i * S * D + s_i * D), 
+ smem + local_s * smem_row + local_h * D_pad, D_bytes); } + } - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); + __syncthreads(); - if (threadIdx.x == 0) { - issue_tma_store_strided(tma_out, smem, h_i, s_tile, b_i); + // ---- Phase 2: smem → global (sweep consecutive H → coalesced writes to SBHD/BSHD) ---- + for (unsigned int warp_off = threadIdx.x >> 5; + warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + const size_t local_s = warp_off; + const size_t local_h = threadIdx.x & 31u; + const size_t s_i = s_base + local_s; + const size_t h_i = h_base + local_h; + if (s_i < S && h_i < H) { + char *__restrict__ dst; + if constexpr (kIsSbhd) + dst = reinterpret_cast(out + s_i * b * H * D + b_i * H * D + h_i * D); + else + dst = reinterpret_cast(out + b_i * S * H * D + s_i * H * D + h_i * D); + copy_row_bytes(smem + local_s * smem_row + local_h * D_pad, dst, D_bytes); } - - ptx::cp_async_bulk_wait_group(); - __syncthreads(); } -#endif } -// ---- Fallback: BSHD/SBHD ↔ BHSD (no TMA, works on any SM / dtype / D) ---- -// -// Same grid structure as the pre-TMA permute kernels: one block per (b, h) -// pair per S-partition; blockIdx.z selects Q (0), K (1), or V (2). -// -// Two strategies depending on D alignment: -// 1) "vec-flat": D divides evenly into wide vectors (16/8/4 bytes). -// Each thread handles one vector chunk; work = S_chunk * d_vec. -// 2) "row-copy": D is small / misaligned. Each thread handles complete rows -// to avoid expensive runtime integer division by D. Inner copy uses the -// widest loads/stores that fit, with smaller ops for the remainder. 
+// ---------- fallback_vec_aligned: ---------- constexpr int fallback_permute_threads = 1024; -// ---------- vec-flat helpers (D well-aligned) ---------- - template __device__ __forceinline__ void permute_fwd_vec_loop( const T *__restrict__ in, T *__restrict__ out, size_t b, size_t S, size_t H, size_t D, @@ -504,198 +370,271 @@ __device__ __forceinline__ void permute_bwd_vec_loop( } } -// ---------- row-copy helper (D small / misaligned) ---------- -// -// Copies D_bytes from src to dst using the widest loads that fit, -// stepping down through uint4 (16B) → uint2 (8B) → uint (4B) → ushort (2B) → uchar. - -__device__ __forceinline__ void copy_row_bytes(const char *__restrict__ src, - char *__restrict__ dst, size_t D_bytes) { - size_t off = 0; - for (; off + 16 <= D_bytes; off += 16) { - uint4 tmp; - memcpy(&tmp, src + off, 16); - memcpy(dst + off, &tmp, 16); - } - for (; off + 8 <= D_bytes; off += 8) { - uint2 tmp; - memcpy(&tmp, src + off, 8); - memcpy(dst + off, &tmp, 8); - } - for (; off + 4 <= D_bytes; off += 4) { - unsigned int tmp; - memcpy(&tmp, src + off, 4); - memcpy(dst + off, &tmp, 4); - } - for (; off + 2 <= D_bytes; off += 2) { - unsigned short tmp; - memcpy(&tmp, src + off, 2); - memcpy(dst + off, &tmp, 2); - } - for (; off < D_bytes; ++off) dst[off] = src[off]; -} - -// ---------- tiled-transpose kernels for small / misaligned D ---------- -// -// Problem: when D_bytes % 4 != 0 (e.g. D=6, unsigned char), the old row-copy -// path assigned one thread per row with a fixed (b, h) per block. Adjacent -// threads read S-rows that are B*H*D bytes apart => ~5 % cache-line use. -// -// Fix: treat the permutation as a 2-D transpose of [S, H] "atoms" of D bytes. 
-// Load a [TILE_S, TILE_H] tile through shared memory so that: -// Load – consecutive threads cover consecutive H (stride D in input) => coalesced reads -// Store – consecutive threads cover consecutive S (stride D in output) => coalesced writes -// -// Grid: (B * s_tiles, h_tiles, num_tensors) Block: TRANSPOSE_BLOCK - -constexpr int TRANSPOSE_TILE = 32; -constexpr int TRANSPOSE_BLOCK = 256; -constexpr int TRANSPOSE_WARPS = TRANSPOSE_BLOCK / 32; // 8 - -// FWD: strided (SBHD / BSHD) → contiguous BHSD template -__launch_bounds__(TRANSPOSE_BLOCK) __global__ - void permute_fwd_tiled_transpose_kernel( +__launch_bounds__(fallback_permute_threads) __global__ + void permute_to_grouped_tensor_fwd_fallback_vec_aligned_kernel( const T *__restrict__ q_in, const T *__restrict__ k_in, const T *__restrict__ v_in, - T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, - size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, - unsigned int s_tiles) { + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, size_t b, size_t s_q, + size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, + unsigned int permute_s_splits) { const int which = blockIdx.z; const T *__restrict__ in = which == 0 ? q_in : (which == 1 ? k_in : v_in); - T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); - const size_t S = which == 0 ? s_q : s_kv; - const size_t H = which == 0 ? h_q : h_kv; - const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); - const size_t D_bytes = D * sizeof(T); - const size_t D_pad = (D_bytes + 3u) & ~size_t(3); // 4-byte aligned for smem + T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t S = which == 0 ? s_q : s_kv; + const size_t H = which == 0 ? h_q : h_kv; + const size_t D = which == 0 ? d_qk : (which == 1 ? 
d_qk : d_v); - const size_t tile_s = static_cast(blockIdx.x) % static_cast(s_tiles); - const size_t b_i = static_cast(blockIdx.x) / static_cast(s_tiles); + const size_t h_grid = h_q > h_kv ? h_q : h_kv; + const size_t b_i = static_cast(blockIdx.x) / h_grid; + const size_t h_i = static_cast(blockIdx.x) % h_grid; if (b_i >= b) return; - const size_t tile_h = static_cast(blockIdx.y); + if (which == 0) { + if (h_i >= h_q) return; + } else { + if (h_i >= h_kv) return; + } - const size_t s_base = tile_s * TRANSPOSE_TILE; - const size_t h_base = tile_h * TRANSPOSE_TILE; + const unsigned int s_part = blockIdx.y; + const size_t s_begin = + (S * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = + (S * static_cast(s_part + 1)) / static_cast(permute_s_splits); + if (s_begin >= s_end) return; + const size_t S_chunk = s_end - s_begin; - extern __shared__ char smem[]; - // +4 padding per S-row avoids 32-way bank conflicts during the store phase. - const size_t smem_row = static_cast(TRANSPOSE_TILE) * D_pad + 4; + const size_t D_bytes = D * sizeof(T); - // ---- Phase 1: global → smem (sweep consecutive H → coalesced reads) ---- - for (unsigned int warp_off = threadIdx.x >> 5; - warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { - const size_t local_s = warp_off; - const size_t local_h = threadIdx.x & 31u; - const size_t s_i = s_base + local_s; - const size_t h_i = h_base + local_h; - if (s_i < S && h_i < H) { - const char *__restrict__ src; - if constexpr (kIsSbhd) - src = reinterpret_cast(in + s_i * b * H * D + b_i * H * D + h_i * D); - else - src = reinterpret_cast(in + b_i * S * H * D + s_i * H * D + h_i * D); - copy_row_bytes(src, smem + local_s * smem_row + local_h * D_pad, D_bytes); - } + if (D_bytes % 16 == 0) { + constexpr size_t N = 16 / sizeof(T); + permute_fwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; } - - __syncthreads(); - - // ---- Phase 2: smem → global (sweep consecutive S → coalesced writes) ---- - for 
(unsigned int warp_off = threadIdx.x >> 5; - warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { - const size_t local_h = warp_off; - const size_t local_s = threadIdx.x & 31u; - const size_t s_i = s_base + local_s; - const size_t h_i = h_base + local_h; - if (s_i < S && h_i < H) { - copy_row_bytes(smem + local_s * smem_row + local_h * D_pad, - reinterpret_cast(out + b_i * H * S * D + h_i * S * D + s_i * D), - D_bytes); + if (D_bytes % 8 == 0) { + constexpr size_t N = 8 / sizeof(T); + permute_fwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; + } + if constexpr (sizeof(T) <= 4) { + if (D_bytes % 4 == 0) { + constexpr size_t N = 4 / sizeof(T); + permute_fwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; } } } -// BWD: contiguous BHSD → strided (SBHD / BSHD) template -__launch_bounds__(TRANSPOSE_BLOCK) __global__ - void permute_bwd_tiled_transpose_kernel( +__launch_bounds__(fallback_permute_threads) __global__ + void permute_to_grouped_tensor_bwd_fallback_vec_aligned_kernel( const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, - T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, - size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, - unsigned int s_tiles) { + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, size_t b, size_t s_q, + size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, + unsigned int permute_s_splits) { const int which = blockIdx.z; const T *__restrict__ in = which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); - T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); - const size_t S = which == 0 ? s_q : s_kv; - const size_t H = which == 0 ? h_q : h_kv; - const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); - const size_t D_bytes = D * sizeof(T); - const size_t D_pad = (D_bytes + 3u) & ~size_t(3); + T *__restrict__ out = which == 0 ? q_out : (which == 1 ? 
k_out : v_out); + const size_t S = which == 0 ? s_q : s_kv; + const size_t H = which == 0 ? h_q : h_kv; + const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); - const size_t tile_s = static_cast(blockIdx.x) % static_cast(s_tiles); - const size_t b_i = static_cast(blockIdx.x) / static_cast(s_tiles); + const size_t h_grid = h_q > h_kv ? h_q : h_kv; + const size_t b_i = static_cast(blockIdx.x) / h_grid; + const size_t h_i = static_cast(blockIdx.x) % h_grid; if (b_i >= b) return; - const size_t tile_h = static_cast(blockIdx.y); + if (which == 0) { + if (h_i >= h_q) return; + } else { + if (h_i >= h_kv) return; + } - const size_t s_base = tile_s * TRANSPOSE_TILE; - const size_t h_base = tile_h * TRANSPOSE_TILE; + const unsigned int s_part = blockIdx.y; + const size_t s_begin = + (S * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = + (S * static_cast(s_part + 1)) / static_cast(permute_s_splits); + if (s_begin >= s_end) return; + const size_t S_chunk = s_end - s_begin; - extern __shared__ char smem[]; - const size_t smem_row = static_cast(TRANSPOSE_TILE) * D_pad + 4; + const size_t D_bytes = D * sizeof(T); - // ---- Phase 1: global → smem (sweep consecutive S → coalesced reads from BHSD) ---- - for (unsigned int warp_off = threadIdx.x >> 5; - warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { - const size_t local_h = warp_off; - const size_t local_s = threadIdx.x & 31u; - const size_t s_i = s_base + local_s; - const size_t h_i = h_base + local_h; - if (s_i < S && h_i < H) { - copy_row_bytes( - reinterpret_cast(in + b_i * H * S * D + h_i * S * D + s_i * D), - smem + local_s * smem_row + local_h * D_pad, D_bytes); + if (D_bytes % 16 == 0) { + constexpr size_t N = 16 / sizeof(T); + permute_bwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; + } + if (D_bytes % 8 == 0) { + constexpr size_t N = 8 / sizeof(T); + permute_bwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; + } + if constexpr 
(sizeof(T) <= 4) { + if (D_bytes % 4 == 0) { + constexpr size_t N = 4 / sizeof(T); + permute_bwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); + return; } } +} + + +// ---------- main path: TMA ---------- + +constexpr int tma_permute_threads = 128; +constexpr int tma_permute_s_tile_default = 32; + +// ---- 4D TMA PTX wrappers ---- + +__device__ __forceinline__ void cp_async_bulk_tensor_4d_global_to_shared( + void *dst_shmem, const CUtensorMap *tensor_map, uint32_t c0, uint32_t c1, uint32_t c2, + uint32_t c3, uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t dst = __cvta_generic_to_shared(dst_shmem); + uint32_t bar = __cvta_generic_to_shared(mbar); + asm volatile( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3, %4, %5}], [%6];" ::"r"(dst), + "l"(tensor_map), "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(bar) + : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_global_to_shared requires SM 10.0+."); +#endif +} + +__device__ __forceinline__ void cp_async_bulk_tensor_4d_shared_to_global( + const CUtensorMap *tensor_map, uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, + void *src_shmem) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + uint32_t src = __cvta_generic_to_shared(src_shmem); + asm volatile( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" + " [%0, {%1, %2, %3, %4}], [%5];" ::"l"(tensor_map), + "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(src) + : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_shared_to_global requires SM 9.0+."); +#endif +} + +// ---- Host-side 4D tensor map creation ---- + +static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dtype, uint64_t dim0, + uint64_t dim1, uint64_t dim2, uint64_t dim3, uint32_t box0, + uint32_t box1, uint32_t box2, uint32_t box3) { + cuda_driver::ensure_context_exists(); + static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = 
[]() { + void *ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); + return reinterpret_cast(ptr); + }(); + + CUtensorMapDataType tma_dtype; + size_t elem_bytes; + switch (dtype) { + case DType::kFloat16: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + elem_bytes = 2; + break; + case DType::kBFloat16: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + elem_bytes = 2; + break; + case DType::kFloat8E4M3: + case DType::kFloat8E5M2: + case DType::kFloat8E8M0: + case DType::kByte: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bytes = 1; + break; + default: + NVTE_ERROR("create_4D_tensor_map: unsupported dtype ", + to_string(static_cast(dtype))); + } + + constexpr uint32_t rank = 4; + uint64_t size[rank] = {dim0, dim1, dim2, dim3}; + uint64_t stride[rank - 1] = { + dim0 * elem_bytes, + dim0 * dim1 * elem_bytes, + dim0 * dim1 * dim2 * elem_bytes, + }; + uint32_t boxSize[rank] = {box0, box1, box2, box3}; + uint32_t elemStride[rank] = {1, 1, 1, 1}; + + const auto oob_fill = (tma_dtype == CU_TENSOR_MAP_DATA_TYPE_UINT8) + ? 
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + : CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; + + NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( + &tensorMap, tma_dtype, rank, dataPtr, size, stride, boxSize, elemStride, + CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + oob_fill)); +} - __syncthreads(); +// ---- TMA helpers ---- +// Strided BSHD: TMA dims [D, H, S, B], coords [0, h, s, b] +// Strided SBHD: TMA dims [D, H, B, S], coords [0, h, b, s] - // ---- Phase 2: smem → global (sweep consecutive H → coalesced writes to SBHD/BSHD) ---- - for (unsigned int warp_off = threadIdx.x >> 5; - warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { - const size_t local_s = warp_off; - const size_t local_h = threadIdx.x & 31u; - const size_t s_i = s_base + local_s; - const size_t h_i = h_base + local_h; - if (s_i < S && h_i < H) { - char *__restrict__ dst; - if constexpr (kIsSbhd) - dst = reinterpret_cast(out + s_i * b * H * D + b_i * H * D + h_i * D); - else - dst = reinterpret_cast(out + b_i * S * H * D + s_i * H * D + h_i * D); - copy_row_bytes(smem + local_s * smem_row + local_h * D_pad, dst, D_bytes); - } +template +__device__ __forceinline__ void issue_tma_load_strided(T *smem_buf, const CUtensorMap *tma, + size_t h_i, size_t s_tile, size_t b_i, + uint64_t *mbar, size_t tile_bytes) { + ptx::mbarrier_arrive_expect_tx(mbar, static_cast(tile_bytes)); + if constexpr (kIsBshdBshdBshd) { + cp_async_bulk_tensor_4d_global_to_shared(smem_buf, tma, 0, static_cast(h_i), + static_cast(s_tile), + static_cast(b_i), mbar); + } else { + cp_async_bulk_tensor_4d_global_to_shared(smem_buf, tma, 0, static_cast(h_i), + static_cast(b_i), + static_cast(s_tile), mbar); } } -// ---------- forward kernel (well-aligned D) ---------- +template +__device__ __forceinline__ void issue_tma_store_strided(const CUtensorMap *tma, T *smem_buf, + size_t h_i, size_t s_tile, size_t b_i) { + if constexpr (kIsBshdBshdBshd) { + 
cp_async_bulk_tensor_4d_shared_to_global(tma, 0, static_cast(h_i), + static_cast(s_tile), + static_cast(b_i), smem_buf); + } else { + cp_async_bulk_tensor_4d_shared_to_global(tma, 0, static_cast(h_i), + static_cast(b_i), + static_cast(s_tile), smem_buf); + } + ptx::cp_async_bulk_commit_group(); +} -template -__launch_bounds__(fallback_permute_threads) __global__ - void permute_to_grouped_tensor_fwd_fallback_kernel( - const T *__restrict__ q_in, const T *__restrict__ k_in, const T *__restrict__ v_in, - T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, size_t b, size_t s_q, - size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, - unsigned int permute_s_splits) { +__device__ __forceinline__ void st_global_cs_uint4(uint4 *ptr, uint4 val) { + asm volatile("st.global.cs.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(val.x), "r"(val.y), + "r"(val.z), "r"(val.w) + : "memory"); +} + +// ---- forward: BSHD/SBHD → BHSD ---- +// TMA load from strided input → smem → non-temporal stores to contiguous output. + +template +__launch_bounds__(tma_permute_threads) __global__ + void permute_to_grouped_tensor_fwd_kernel(const __grid_constant__ CUtensorMap tma_q_in, + const __grid_constant__ CUtensorMap tma_k_in, + const __grid_constant__ CUtensorMap tma_v_in, + T *__restrict__ q_out, T *__restrict__ k_out, + T *__restrict__ v_out, size_t b, size_t s_q, + size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, + size_t d_v, unsigned int permute_s_splits, + size_t s_tile_size) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const int which = blockIdx.z; - const T *__restrict__ in = which == 0 ? q_in : (which == 1 ? k_in : v_in); - T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); - const size_t S = which == 0 ? s_q : s_kv; - const size_t H = which == 0 ? h_q : h_kv; - const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + const CUtensorMap *tma_in = which == 0 ? &tma_q_in : (which == 1 ? 
&tma_k_in : &tma_v_in); + T *__restrict__ tensor_out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t Sdim = which == 0 ? s_q : s_kv; + const size_t Hdim = which == 0 ? h_q : h_kv; + const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); const size_t h_grid = h_q > h_kv ? h_q : h_kv; const size_t b_i = static_cast(blockIdx.x) / h_grid; const size_t h_i = static_cast(blockIdx.x) % h_grid; + if (b_i >= b) return; if (which == 0) { if (h_i >= h_q) return; @@ -705,52 +644,82 @@ __launch_bounds__(fallback_permute_threads) __global__ const unsigned int s_part = blockIdx.y; const size_t s_begin = - (S * static_cast(s_part)) / static_cast(permute_s_splits); + (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); const size_t s_end = - (S * static_cast(s_part + 1)) / static_cast(permute_s_splits); + (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); if (s_begin >= s_end) return; - const size_t S_chunk = s_end - s_begin; - const size_t D_bytes = D * sizeof(T); + const size_t out_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; - if (D_bytes % 16 == 0) { - constexpr size_t N = 16 / sizeof(T); - permute_fwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); - return; - } - if (D_bytes % 8 == 0) { - constexpr size_t N = 8 / sizeof(T); - permute_fwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); - return; + extern __shared__ __align__(128) char smem_raw[]; + T *smem = reinterpret_cast(smem_raw); + + __shared__ __align__(8) uint64_t mbar; + const bool is_leader = (threadIdx.x == 0); + + if (is_leader) { + ptx::mbarrier_init(&mbar, static_cast(blockDim.x)); + ptx::fence_proxy_async_shared_cta(); } - if constexpr (sizeof(T) <= 4) { - if (D_bytes % 4 == 0) { - constexpr size_t N = 4 / sizeof(T); - permute_fwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); - return; + __syncthreads(); + + const size_t S_TILE = s_tile_size; + const uint32_t tile_bytes = static_cast(S_TILE * Ddim * sizeof(T)); + 
int parity = 0; + + for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { + const size_t tile_rows = min(S_TILE, s_end - s_tile); + + if (is_leader) { + issue_tma_load_strided(smem, tma_in, h_i, s_tile, b_i, &mbar, tile_bytes); + } else { + ptx::mbarrier_arrive(&mbar); + } + + ptx::mbarrier_wait_parity(&mbar, parity); + parity ^= 1; + + T *__restrict__ out_ptr = tensor_out + out_base + s_tile * Ddim; + const size_t total_elems = tile_rows * Ddim; + constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); + + for (size_t i = threadIdx.x * vec_elems; i < total_elems; + i += static_cast(blockDim.x) * vec_elems) { + uint4 v = *reinterpret_cast(smem + i); + st_global_cs_uint4(reinterpret_cast(out_ptr + i), v); } + + __syncthreads(); + } + + if (is_leader) { + ptx::mbarrier_invalid(&mbar); } +#endif } -// ---------- backward kernel (well-aligned D) ---------- +// ---- backward: BHSD → BSHD/SBHD ---- +// Vectorized loads from contiguous input → smem → TMA store to strided output. -template -__launch_bounds__(fallback_permute_threads) __global__ - void permute_to_grouped_tensor_bwd_fallback_kernel( - const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, - T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, size_t b, size_t s_q, - size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, - unsigned int permute_s_splits) { +template +__launch_bounds__(tma_permute_threads) __global__ void permute_to_grouped_tensor_bwd_kernel( + const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, + const __grid_constant__ CUtensorMap tma_q_out, const __grid_constant__ CUtensorMap tma_k_out, + const __grid_constant__ CUtensorMap tma_v_out, size_t b, size_t s_q, size_t h_q, size_t d_qk, + size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits, + size_t s_tile_size) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const int which = blockIdx.z; - const T *__restrict__ in 
= which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); - T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); - const size_t S = which == 0 ? s_q : s_kv; - const size_t H = which == 0 ? h_q : h_kv; - const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + const T *__restrict__ tensor_in = which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); + const CUtensorMap *tma_out = which == 0 ? &tma_q_out : (which == 1 ? &tma_k_out : &tma_v_out); + const size_t Sdim = which == 0 ? s_q : s_kv; + const size_t Hdim = which == 0 ? h_q : h_kv; + const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); const size_t h_grid = h_q > h_kv ? h_q : h_kv; const size_t b_i = static_cast(blockIdx.x) / h_grid; const size_t h_i = static_cast(blockIdx.x) % h_grid; + if (b_i >= b) return; if (which == 0) { if (h_i >= h_q) return; @@ -760,36 +729,67 @@ __launch_bounds__(fallback_permute_threads) __global__ const unsigned int s_part = blockIdx.y; const size_t s_begin = - (S * static_cast(s_part)) / static_cast(permute_s_splits); + (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); const size_t s_end = - (S * static_cast(s_part + 1)) / static_cast(permute_s_splits); + (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); if (s_begin >= s_end) return; - const size_t S_chunk = s_end - s_begin; - const size_t D_bytes = D * sizeof(T); + const size_t in_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; - if (D_bytes % 16 == 0) { - constexpr size_t N = 16 / sizeof(T); - permute_bwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); - return; - } - if (D_bytes % 8 == 0) { - constexpr size_t N = 8 / sizeof(T); - permute_bwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); - return; - } - if constexpr (sizeof(T) <= 4) { - if (D_bytes % 4 == 0) { - constexpr size_t N = 4 / sizeof(T); - permute_bwd_vec_loop(in, out, b, S, H, D, b_i, h_i, s_begin, S_chunk); - return; + extern __shared__ __align__(128) char smem_raw[]; + T 
*smem = reinterpret_cast(smem_raw); + + const size_t S_TILE = s_tile_size; + constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); + + for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { + const size_t tile_rows = min(S_TILE, s_end - s_tile); + + const T *__restrict__ in_ptr = tensor_in + in_base + s_tile * Ddim; + const size_t total_elems = tile_rows * Ddim; + + for (size_t i = threadIdx.x * vec_elems; i < total_elems; + i += static_cast(blockDim.x) * vec_elems) { + *reinterpret_cast(smem + i) = *reinterpret_cast(in_ptr + i); + } + + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + + if (threadIdx.x == 0) { + issue_tma_store_strided(tma_out, smem, h_i, s_tile, b_i); } + + ptx::cp_async_bulk_wait_group(); + __syncthreads(); + } +#endif +} + + +// ---- create a 4D TMA descriptor ---- +// For BSHD [B, S, H, D]: TMA dims [D, H, S, B], box [D, 1, S_TILE, 1] +// For SBHD [S, B, H, D]: TMA dims [D, H, B, S], box [D, 1, 1, S_TILE] + +static void create_strided_tensor_map(CUtensorMap &map, void *ptr, DType dtype, size_t b, size_t s, + size_t h, size_t d, size_t s_tile, bool is_bshd) { + if (is_bshd) { + create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), + static_cast(s), static_cast(b), + static_cast(d), 1, static_cast(s_tile), 1); + } else { + create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), + static_cast(b), static_cast(s), + static_cast(d), 1, 1, static_cast(s_tile)); } } -// ---- TMA feasibility check ---- +// ---- check if TMA path is feasible ---- static bool can_use_tma_permute(DType dtype, size_t d_qk, size_t d_v) { + const int sm = cuda::sm_arch(cuda::current_device()); + if (sm < 90) return false; + switch (dtype) { case DType::kFloat16: case DType::kBFloat16: @@ -804,27 +804,12 @@ static bool can_use_tma_permute(DType dtype, size_t d_qk, size_t d_v) { const size_t elem_size = typeToSize(dtype); const size_t inner_qk = d_qk * elem_size; const size_t inner_v = d_v * elem_size; + // hardware 
requirements for TMA if (inner_qk < 32 || inner_v < 32) return false; if (inner_qk % 16 != 0 || inner_v % 16 != 0) return false; return true; } -// Helper: create a 4D TMA descriptor for the strided (BSHD or SBHD) tensor. -// -// For BSHD [B, S, H, D]: TMA dims [D, H, S, B], box [D, 1, S_TILE, 1] -// For SBHD [S, B, H, D]: TMA dims [D, H, B, S], box [D, 1, 1, S_TILE] -static void create_strided_tensor_map(CUtensorMap &map, void *ptr, DType dtype, size_t b, size_t s, - size_t h, size_t d, size_t s_tile, bool is_bshd) { - if (is_bshd) { - create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), - static_cast(s), static_cast(b), - static_cast(d), 1, static_cast(s_tile), 1); - } else { - create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), - static_cast(b), static_cast(s), - static_cast(d), 1, 1, static_cast(s_tile)); - } -} void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, Tensor k_out, Tensor v_out, NVTE_QKV_Format original_format, @@ -840,14 +825,19 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T const bool is_bshd = (original_format == NVTE_QKV_Format::NVTE_BSHD); + // Permute paths (most to least preferred): + // 1. TMA: D*elem_size >= 32 and D*elem_size % 16 == 0. + // 2. fallback_vec_aligned: D*elem_size % 4 == 0 but TMA requirements not met. + // Uses vectorized uint4 loads/stores with a flat grid. + // 3. fallback_not_vec_aligned: D*elem_size % 4 != 0. + // Uses shared-memory tiled transpose with padded rows. 
if (!can_use_tma_permute(q.dtype(), d_qk, d_v)) { - const size_t elem_sz = typeToSize(q.dtype()); - const size_t d_qk_bytes = d_qk * elem_sz; - const size_t d_v_bytes = d_v * elem_sz; + const size_t elem_size = typeToSize(q.dtype()); + const size_t d_qk_bytes = d_qk * elem_size; + const size_t d_v_bytes = d_v * elem_size; const bool needs_transpose = (d_qk_bytes % 4 != 0) || (d_v_bytes % 4 != 0); if (needs_transpose) { - // Tiled transpose path: grid = (B * s_tiles, h_tiles, num_tensors) const size_t s_max = std::max(s_q, s_kv); const size_t h_max = std::max(h_q, h_kv); const unsigned int st = static_cast( @@ -857,7 +847,7 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T dim3 grid(static_cast(b) * st, ht, static_cast(num_tensors)); const size_t d_max = std::max(d_qk, d_v); - const size_t D_pad = (d_max * elem_sz + 3u) & ~size_t(3); + const size_t D_pad = (d_max * elem_size + 3u) & ~size_t(3); const size_t smem_bytes = static_cast(TRANSPOSE_TILE) * (static_cast(TRANSPOSE_TILE) * D_pad + 4); @@ -865,7 +855,7 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( q.dtype(), dtype, - permute_fwd_tiled_transpose_kernel + permute_to_grouped_tensor_fwd_fallback_not_vec_aligned_kernel <<>>( reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), @@ -877,7 +867,7 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T } else { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( q.dtype(), dtype, - permute_fwd_tiled_transpose_kernel + permute_to_grouped_tensor_fwd_fallback_not_vec_aligned_kernel <<>>( reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), @@ -902,7 +892,7 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( q.dtype(), dtype, - permute_to_grouped_tensor_fwd_fallback_kernel + permute_to_grouped_tensor_fwd_fallback_vec_aligned_kernel <<>>( 
reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), @@ -914,7 +904,7 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T } else { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( q.dtype(), dtype, - permute_to_grouped_tensor_fwd_fallback_kernel + permute_to_grouped_tensor_fwd_fallback_vec_aligned_kernel <<>>( reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), @@ -987,10 +977,16 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, const bool is_bshd = (original_format == NVTE_QKV_Format::NVTE_BSHD); + // Permute paths (most to least preferred): + // 1. TMA: D*elem_size >= 32 and D*elem_size % 16 == 0. + // 2. fallback_vec_aligned: D*elem_size % 4 == 0 but TMA requirements not met. + // Uses vectorized uint4 loads/stores with a flat grid. + // 3. fallback_not_vec_aligned: D*elem_size % 4 != 0. + // Uses shared-memory tiled transpose with padded rows. if (!can_use_tma_permute(grad_q.dtype(), d_qk, d_v)) { - const size_t elem_sz = typeToSize(grad_q.dtype()); - const size_t d_qk_bytes = d_qk * elem_sz; - const size_t d_v_bytes = d_v * elem_sz; + const size_t elem_size = typeToSize(grad_q.dtype()); + const size_t d_qk_bytes = d_qk * elem_size; + const size_t d_v_bytes = d_v * elem_size; const bool needs_transpose = (d_qk_bytes % 4 != 0) || (d_v_bytes % 4 != 0); if (needs_transpose) { @@ -1003,7 +999,7 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, dim3 grid(static_cast(b) * st, ht, static_cast(num_tensors)); const size_t d_max = std::max(d_qk, d_v); - const size_t D_pad = (d_max * elem_sz + 3u) & ~size_t(3); + const size_t D_pad = (d_max * elem_size + 3u) & ~size_t(3); const size_t smem_bytes = static_cast(TRANSPOSE_TILE) * (static_cast(TRANSPOSE_TILE) * D_pad + 4); @@ -1011,7 +1007,7 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( grad_q.dtype(), dtype, - 
permute_bwd_tiled_transpose_kernel + permute_to_grouped_tensor_bwd_fallback_not_vec_aligned_kernel <<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), @@ -1023,7 +1019,7 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, } else { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( grad_q.dtype(), dtype, - permute_bwd_tiled_transpose_kernel + permute_to_grouped_tensor_bwd_fallback_not_vec_aligned_kernel <<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), @@ -1047,7 +1043,7 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( grad_q.dtype(), dtype, - permute_to_grouped_tensor_bwd_fallback_kernel + permute_to_grouped_tensor_bwd_fallback_vec_aligned_kernel <<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), @@ -1059,7 +1055,7 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, } else { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( grad_q.dtype(), dtype, - permute_to_grouped_tensor_bwd_fallback_kernel + permute_to_grouped_tensor_bwd_fallback_vec_aligned_kernel <<>>( reinterpret_cast(grad_q.data.dptr), reinterpret_cast(grad_k.data.dptr), @@ -1119,12 +1115,14 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, } NVTE_CHECK_CUDA(cudaGetLastError()); } -// ---- Multi-tensor pad last dimension with zeros ---- -// -// Pads multiple 2D row-major tensors in a single kernel launch. -// Each tensor copies (rows, in_cols) → (rows, padded_cols), zero-filling [in_cols, padded_cols). -// Uses uint32 (4-byte) granularity for coalesced memory access. -// blockIdx.y selects the tensor; blockIdx.x * blockDim.x + threadIdx.x selects the uint32 element. 
+ +} // namespace permute_to_grouped_tensor + +// =================================================================================== +// multi_tensor_pad_last_dim: pad the last dimension of multiple tensors to a certain alignment +// =================================================================================== + +namespace multi_tensor_pad_last_dim { constexpr int pad_threads_per_block = 256; constexpr int kMaxPadTensors = 16; @@ -1142,7 +1140,7 @@ struct MultiPadParams { }; __launch_bounds__(pad_threads_per_block) __global__ - void multi_pad_last_dim_kernel(MultiPadParams params) { + void multi_tensor_pad_last_dim_kernel(MultiPadParams params) { const auto &a = params.tensors[blockIdx.y]; for (size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; idx < a.n_uint32; @@ -1164,7 +1162,7 @@ __launch_bounds__(pad_threads_per_block) __global__ } } -void multi_pad_last_dim(Tensor *inputs, Tensor *outputs, size_t num_tensors, +void multi_tensor_pad_last_dim(Tensor *inputs, Tensor *outputs, size_t num_tensors, cudaStream_t stream) { using namespace transformer_engine; @@ -1223,11 +1221,11 @@ void multi_pad_last_dim(Tensor *inputs, Tensor *outputs, size_t num_tensors, static_cast(65535))); dim3 grid(blocks_x, kernel_count); - multi_pad_last_dim_kernel<<>>(params); + multi_tensor_pad_last_dim_kernel<<>>(params); NVTE_CHECK_CUDA(cudaGetLastError()); } -} // namespace flash_attention +} // namespace multi_tensor_pad_last_dim } // namespace transformer_engine void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream) { @@ -1255,7 +1253,7 @@ void nvte_permute_to_grouped_tensor_fwd(NVTETensor q, NVTETensor k, NVTETensor v NVTE_API_CALL(nvte_permute_to_grouped_tensor_fwd); using namespace transformer_engine; - flash_attention::permute_to_grouped_tensor_fwd( + permute_to_grouped_tensor::permute_to_grouped_tensor_fwd( *convertNVTETensorCheck(q), *convertNVTETensorCheck(k), *convertNVTETensorCheck(v), *convertNVTETensorCheck(q_out), 
*convertNVTETensorCheck(k_out), *convertNVTETensorCheck(v_out), original_format, num_tensors, stream); @@ -1268,15 +1266,15 @@ void nvte_permute_to_grouped_tensor_bwd(NVTETensor grad_q, NVTETensor grad_k, NV NVTE_API_CALL(nvte_permute_to_grouped_tensor_bwd); using namespace transformer_engine; - flash_attention::permute_to_grouped_tensor_bwd( + permute_to_grouped_tensor::permute_to_grouped_tensor_bwd( *convertNVTETensorCheck(grad_q), *convertNVTETensorCheck(grad_k), *convertNVTETensorCheck(grad_v), *convertNVTETensorCheck(q), *convertNVTETensorCheck(k), *convertNVTETensorCheck(v), original_format, num_tensors, stream); } -void nvte_multi_pad_last_dim(NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, +void nvte_multi_tensor_pad_last_dim(NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, cudaStream_t stream) { - NVTE_API_CALL(nvte_multi_pad_last_dim); + NVTE_API_CALL(nvte_multi_tensor_pad_last_dim); using namespace transformer_engine; std::vector in_vec(num_tensors), out_vec(num_tensors); @@ -1284,5 +1282,5 @@ void nvte_multi_pad_last_dim(NVTETensor *inputs, NVTETensor *outputs, size_t num in_vec[i] = *convertNVTETensorCheck(inputs[i]); out_vec[i] = *convertNVTETensorCheck(outputs[i]); } - flash_attention::multi_pad_last_dim(in_vec.data(), out_vec.data(), num_tensors, stream); + multi_tensor_pad_last_dim::multi_tensor_pad_last_dim(in_vec.data(), out_vec.data(), num_tensors, stream); } diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index e2503ea881..6022c59c2c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1826,10 +1826,10 @@ void fused_attn_fp8_fwd_impl_v1( scale_o = mha_graph->tensor(1.0f); } } else if (is_mxfp8) { - NVTE_QKV_Format q_scale_format = + NVTE_QKV_Format q_scale_inv_format = (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) ? 
qkv_scale_inv_format : nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_scale_format = + NVTE_QKV_Format kv_scale_inv_format = (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) ? qkv_scale_inv_format : nvte_get_kv_format(qkv_layout); std::vector q_scale_strides(4); @@ -1837,24 +1837,11 @@ void fused_attn_fp8_fwd_impl_v1( std::vector v_scale_strides(4); auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, - q_scale_strides.data(), q_scale_format); + q_scale_strides.data(), q_scale_inv_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, - k_scale_strides.data(), kv_scale_format); + k_scale_strides.data(), kv_scale_inv_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, - v_scale_strides.data(), kv_scale_format); - printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); - printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); - printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); - printf("qkv_layout: %d\n", qkv_layout); - printf("qkv_scale_inv_format: %d\n", qkv_scale_inv_format); - printf("q_scale_format: %d\n", q_scale_format); - printf("kv_scale_format: %d\n", kv_scale_format); - printf("padded.s_q_padded: %d\n", padded.s_q_padded); - printf("padded.d_qk_scale_padded: %d\n", padded.d_qk_scale_padded); - printf("padded.s_kv_padded: %d\n", padded.s_kv_padded); - printf("padded.d_qk_scale_padded: %d\n", padded.d_qk_scale_padded); - printf("padded.s_kv_scale_padded: %d\n", padded.s_kv_scale_padded); - printf("padded.d_v_padded: %d\n", padded.d_v_padded); + v_scale_strides.data(), kv_scale_inv_format); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ 
-2336,13 +2323,13 @@ void fused_attn_fp8_bwd_impl_v1( } else if (is_mxfp8) { NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - NVTE_QKV_Format q_scale_format = + NVTE_QKV_Format q_scale_inv_format = (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) ? qkv_scale_inv_format : q_format; - NVTE_QKV_Format kv_scale_format = + NVTE_QKV_Format kv_scale_inv_format = (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) ? qkv_scale_inv_format : kv_format; - NVTE_QKV_Format do_scale_format = + NVTE_QKV_Format do_scale_format_ = (do_scale_inv_format != NVTE_QKV_Format_NOT_SET) ? do_scale_inv_format : do_format; // Q_t, K_t, dO_t, dO_f16 @@ -2375,19 +2362,19 @@ void fused_attn_fp8_bwd_impl_v1( std::vector q_scale_strides(4), q_t_scale_strides(4), k_scale_strides(4), k_t_scale_strides(4), v_scale_strides(4), dO_scale_strides(4), dO_t_scale_strides(4); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, - q_scale_strides.data(), q_scale_format); + q_scale_strides.data(), q_scale_inv_format); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, - q_t_scale_strides.data(), q_scale_format); + q_t_scale_strides.data(), q_scale_inv_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, - k_scale_strides.data(), kv_scale_format); + k_scale_strides.data(), kv_scale_inv_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, - k_t_scale_strides.data(), kv_scale_format); + k_t_scale_strides.data(), kv_scale_inv_format); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, - v_scale_strides.data(), kv_scale_format); + v_scale_strides.data(), kv_scale_inv_format); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, - dO_scale_strides.data(), do_scale_format); + dO_scale_strides.data(), do_scale_format_); 
generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, - dO_t_scale_strides.data(), do_scale_format); + dO_t_scale_strides.data(), do_scale_format_); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index df396e246c..27dc11ab43 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -648,518 +648,8 @@ void fused_qkv_rope_backward(const Tensor &q_grad_out, const Tensor &k_grad_out, qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2, stream);); } -// ============================================================================ -// MLA YARN RoPE kernels -// ============================================================================ - -__device__ int mla_get_thd_token_idx(const int *cu_seqlens, int pid_m, int seq_num, int cp_rank, - int cp_size) { - int token_idx = -1; - int this_seq_len = 0; - int last_cum = cu_seqlens[0] / cp_size; - for (int seq_idx = 0; seq_idx < seq_num; seq_idx++) { - int cur_cum = cu_seqlens[seq_idx + 1] / cp_size; - if (token_idx == -1 && cur_cum > pid_m) { - token_idx = pid_m - last_cum; - this_seq_len = cur_cum - last_cum; - } - last_cum = cur_cum; - } - if (cp_size > 1) { - if (token_idx < this_seq_len / 2) { - token_idx = token_idx + cp_rank * this_seq_len / 2; - } else { - token_idx = - (token_idx - this_seq_len / 2) + (2 * cp_size - cp_rank - 1) * this_seq_len / 2; - } - } - return token_idx; -} - -template -__global__ void mla_yarn_rope_q_forward_kernel(const scalar_t *q_input, const float *cos_data, - const float *sin_data, scalar_t *q_output, - const int *cu_seqlens, const int qk_head_dim, - const int emb_dim, const int h, const int d, - const int s, const int b, const int cp_size, - const int cp_rank) { - int pid_m = blockIdx.x; - const int half_emb = emb_dim / 2; - const int 
stride_t = h * d; - const int stride_h_val = d; - - int token_idx; - if (cu_seqlens == nullptr) { - int s_id = pid_m / b; - token_idx = s_id; - if (cp_size > 1) { - if (s_id < s / 2) { - token_idx = s_id + cp_rank * s / 2; - } else { - token_idx = s * cp_size - (cp_rank + 1) * s / 2 + s_id - s / 2; - } - } - } else { - token_idx = mla_get_thd_token_idx(cu_seqlens, pid_m, b, cp_rank, cp_size); - } - - extern __shared__ float shared_mem_q_fwd[]; - float *sh_cos_l = shared_mem_q_fwd; - float *sh_sin_l = shared_mem_q_fwd + half_emb; - float *sh_cos_r = shared_mem_q_fwd + 2 * half_emb; - float *sh_sin_r = shared_mem_q_fwd + 3 * half_emb; - - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int num_threads = blockDim.x * blockDim.y; - for (int i = tid; i < half_emb; i += num_threads) { - sh_cos_l[i] = cos_data[token_idx * emb_dim + i]; - sh_sin_l[i] = sin_data[token_idx * emb_dim + i]; - sh_cos_r[i] = cos_data[token_idx * emb_dim + half_emb + i]; - sh_sin_r[i] = sin_data[token_idx * emb_dim + half_emb + i]; - } - __syncthreads(); - - int base = pid_m * stride_t; - - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int head_offset = base + h_id * stride_h_val; - - for (int i = threadIdx.x; i < qk_head_dim; i += blockDim.x) { - q_output[head_offset + i] = q_input[head_offset + i]; - } - - int rope_in = head_offset + qk_head_dim; - int rope_out = head_offset + qk_head_dim; - for (int i = threadIdx.x; i < half_emb; i += blockDim.x) { - float x1 = static_cast(q_input[rope_in + i * 2]); - float x2 = static_cast(q_input[rope_in + i * 2 + 1]); - - q_output[rope_out + i] = static_cast(x1 * sh_cos_l[i] - x2 * sh_sin_l[i]); - q_output[rope_out + half_emb + i] = - static_cast(x2 * sh_cos_r[i] + x1 * sh_sin_r[i]); - } - } -} - -template -__global__ void mla_yarn_rope_q_backward_kernel(const scalar_t *grad_output, - const float *cos_data, const float *sin_data, - scalar_t *grad_input, const int *cu_seqlens, - const int qk_head_dim, const int emb_dim, - const int h, 
const int d, const int s, const int b, - const int cp_size, const int cp_rank) { - int pid_m = blockIdx.x; - const int half_emb = emb_dim / 2; - const int stride_t = h * d; - const int stride_h_val = d; - - int token_idx; - if (cu_seqlens == nullptr) { - int s_id = pid_m / b; - token_idx = s_id; - if (cp_size > 1) { - if (s_id < s / 2) { - token_idx = s_id + cp_rank * s / 2; - } else { - token_idx = s * cp_size - (cp_rank + 1) * s / 2 + s_id - s / 2; - } - } - } else { - token_idx = mla_get_thd_token_idx(cu_seqlens, pid_m, b, cp_rank, cp_size); - } - - extern __shared__ float shared_mem_q_bwd[]; - float *sh_cos_l = shared_mem_q_bwd; - float *sh_sin_l = shared_mem_q_bwd + half_emb; - float *sh_cos_r = shared_mem_q_bwd + 2 * half_emb; - float *sh_sin_r = shared_mem_q_bwd + 3 * half_emb; - - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int num_threads = blockDim.x * blockDim.y; - for (int i = tid; i < half_emb; i += num_threads) { - sh_cos_l[i] = cos_data[token_idx * emb_dim + i]; - sh_sin_l[i] = sin_data[token_idx * emb_dim + i]; - sh_cos_r[i] = cos_data[token_idx * emb_dim + half_emb + i]; - sh_sin_r[i] = sin_data[token_idx * emb_dim + half_emb + i]; - } - __syncthreads(); - - int base = pid_m * stride_t; - - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int head_offset = base + h_id * stride_h_val; - - for (int i = threadIdx.x; i < qk_head_dim; i += blockDim.x) { - grad_input[head_offset + i] = grad_output[head_offset + i]; - } - - int rope_offset = head_offset + qk_head_dim; - for (int i = threadIdx.x; i < half_emb; i += blockDim.x) { - float gl = static_cast(grad_output[rope_offset + i]); - float gr = static_cast(grad_output[rope_offset + half_emb + i]); - - grad_input[rope_offset + i * 2] = static_cast(gl * sh_cos_l[i] + gr * sh_sin_r[i]); - grad_input[rope_offset + i * 2 + 1] = - static_cast(-gl * sh_sin_l[i] + gr * sh_cos_r[i]); - } - } -} - -template -__global__ void mla_yarn_rope_kv_forward_kernel( - const scalar_t *kv_input, const 
scalar_t *k_pos_emb, const float *cos_data, - const float *sin_data, scalar_t *o_key, scalar_t *o_value, const int *cu_seqlens, - const int emb_dim, const int k_dim, const int v_dim, const int h, const int s, const int b, - const int cp_size, const int cp_rank) { - int pid_m = blockIdx.x; - const int half_emb = emb_dim / 2; - const int kv_stride_t = h * (k_dim + v_dim); - const int kv_stride_h = k_dim + v_dim; - const int emb_stride_t = emb_dim; - const int k_stride_t = h * (k_dim + emb_dim); - const int k_stride_h = k_dim + emb_dim; - const int v_stride_t = h * v_dim; - const int v_stride_h = v_dim; - - int token_idx; - if (cu_seqlens == nullptr) { - int s_id = pid_m / b; - token_idx = s_id; - if (cp_size > 1) { - if (s_id < s / 2) { - token_idx = s_id + cp_rank * s / 2; - } else { - token_idx = s * cp_size - (cp_rank + 1) * s / 2 + s_id - s / 2; - } - } - } else { - token_idx = mla_get_thd_token_idx(cu_seqlens, pid_m, b, cp_rank, cp_size); - } - - extern __shared__ float shared_mem_kv_fwd[]; - float *sh_cos_l = shared_mem_kv_fwd; - float *sh_sin_l = shared_mem_kv_fwd + half_emb; - float *sh_cos_r = shared_mem_kv_fwd + 2 * half_emb; - float *sh_sin_r = shared_mem_kv_fwd + 3 * half_emb; - float *sh_rot_left = shared_mem_kv_fwd + 4 * half_emb; - float *sh_rot_right = shared_mem_kv_fwd + 5 * half_emb; - - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int num_threads = blockDim.x * blockDim.y; - for (int i = tid; i < half_emb; i += num_threads) { - sh_cos_l[i] = cos_data[token_idx * emb_dim + i]; - sh_sin_l[i] = sin_data[token_idx * emb_dim + i]; - sh_cos_r[i] = cos_data[token_idx * emb_dim + half_emb + i]; - sh_sin_r[i] = sin_data[token_idx * emb_dim + half_emb + i]; - } - __syncthreads(); - - for (int i = tid; i < half_emb; i += num_threads) { - float x1 = static_cast(k_pos_emb[pid_m * emb_stride_t + i * 2]); - float x2 = static_cast(k_pos_emb[pid_m * emb_stride_t + i * 2 + 1]); - sh_rot_left[i] = x1 * sh_cos_l[i] - x2 * sh_sin_l[i]; - sh_rot_right[i] = x2 * 
sh_cos_r[i] + x1 * sh_sin_r[i]; - } - __syncthreads(); - - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int kv_head = pid_m * kv_stride_t + h_id * kv_stride_h; - int k_head = pid_m * k_stride_t + h_id * k_stride_h; - int v_head = pid_m * v_stride_t + h_id * v_stride_h; - - for (int i = threadIdx.x; i < k_dim; i += blockDim.x) { - o_key[k_head + i] = kv_input[kv_head + i]; - } - - for (int i = threadIdx.x; i < v_dim; i += blockDim.x) { - o_value[v_head + i] = kv_input[kv_head + k_dim + i]; - } - - for (int i = threadIdx.x; i < half_emb; i += blockDim.x) { - o_key[k_head + k_dim + i] = static_cast(sh_rot_left[i]); - o_key[k_head + k_dim + half_emb + i] = static_cast(sh_rot_right[i]); - } - } -} - -template -__global__ void mla_yarn_rope_kv_backward_kernel( - const scalar_t *dk, const scalar_t *dv, const float *cos_data, const float *sin_data, - scalar_t *d_kv, scalar_t *d_emb, const int *cu_seqlens, const int emb_dim, const int k_dim, - const int v_dim, const int h, const int s, const int b, const int cp_size, const int cp_rank) { - int pid_m = blockIdx.x; - const int half_emb = emb_dim / 2; - const int dk_stride_t = h * (k_dim + emb_dim); - const int dk_stride_h = k_dim + emb_dim; - const int dv_stride_t = h * v_dim; - const int dv_stride_h = v_dim; - const int dkv_stride_t = h * (k_dim + v_dim); - const int dkv_stride_h = k_dim + v_dim; - - int token_idx; - if (cu_seqlens == nullptr) { - int s_id = pid_m / b; - token_idx = s_id; - if (cp_size > 1) { - if (s_id < s / 2) { - token_idx = s_id + cp_rank * s / 2; - } else { - token_idx = s * cp_size - (cp_rank + 1) * s / 2 + s_id - s / 2; - } - } - } else { - token_idx = mla_get_thd_token_idx(cu_seqlens, pid_m, b, cp_rank, cp_size); - } - - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int dk_head = pid_m * dk_stride_t + h_id * dk_stride_h; - int dv_head = pid_m * dv_stride_t + h_id * dv_stride_h; - int dkv_head = pid_m * dkv_stride_t + h_id * dkv_stride_h; - - for (int i = 
threadIdx.x; i < k_dim; i += blockDim.x) { - d_kv[dkv_head + i] = dk[dk_head + i]; - } - for (int i = threadIdx.x; i < v_dim; i += blockDim.x) { - d_kv[dkv_head + k_dim + i] = dv[dv_head + i]; - } - } - - extern __shared__ float shared_mem_kv_bwd[]; - float *sh_cos_l = shared_mem_kv_bwd; - float *sh_sin_l = shared_mem_kv_bwd + half_emb; - float *sh_cos_r = shared_mem_kv_bwd + 2 * half_emb; - float *sh_sin_r = shared_mem_kv_bwd + 3 * half_emb; - - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int num_threads = blockDim.x * blockDim.y; - for (int i = tid; i < half_emb; i += num_threads) { - sh_cos_l[i] = cos_data[token_idx * emb_dim + i]; - sh_sin_l[i] = sin_data[token_idx * emb_dim + i]; - sh_cos_r[i] = cos_data[token_idx * emb_dim + half_emb + i]; - sh_sin_r[i] = sin_data[token_idx * emb_dim + half_emb + i]; - } - __syncthreads(); - - for (int i = threadIdx.x; i < half_emb; i += blockDim.x) { - if (threadIdx.y == 0) { - float accum_l = 0.0f, accum_r = 0.0f; - for (int h_id = 0; h_id < h; h_id++) { - int dk_head = pid_m * dk_stride_t + h_id * dk_stride_h; - accum_l += static_cast(dk[dk_head + k_dim + i]); - accum_r += static_cast(dk[dk_head + k_dim + half_emb + i]); - } - float dx1 = accum_l * sh_cos_l[i] + accum_r * sh_sin_r[i]; - float dx2 = -accum_l * sh_sin_l[i] + accum_r * sh_cos_r[i]; - d_emb[pid_m * emb_dim + i * 2] = static_cast(dx1); - d_emb[pid_m * emb_dim + i * 2 + 1] = static_cast(dx2); - } - } -} - -template -void mla_yarn_rope_q_forward_launcher(const scalar_t *q_input, const float *cos_data, - const float *sin_data, scalar_t *q_output, - const int *cu_seqlens, const int qk_head_dim, - const int emb_dim, const int h, const int d, - const int total_seqlen, const int s, const int b, - const int cp_size, const int cp_rank, cudaStream_t stream) { - int warps_per_block = h < 16 ? 
4 : 8; - dim3 blocks(total_seqlen); - dim3 threads(THREADS_PER_WARP, warps_per_block); - const int shared_mem_size = 4 * (emb_dim / 2) * sizeof(float); - - mla_yarn_rope_q_forward_kernel<<>>( - q_input, cos_data, sin_data, q_output, cu_seqlens, qk_head_dim, emb_dim, h, d, s, b, - cp_size, cp_rank); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void mla_yarn_rope_q_backward_launcher(const scalar_t *grad_output, const float *cos_data, - const float *sin_data, scalar_t *grad_input, - const int *cu_seqlens, const int qk_head_dim, - const int emb_dim, const int h, const int d, - const int total_seqlen, const int s, const int b, - const int cp_size, const int cp_rank, cudaStream_t stream) { - int warps_per_block = h < 16 ? 4 : 8; - dim3 blocks(total_seqlen); - dim3 threads(THREADS_PER_WARP, warps_per_block); - const int shared_mem_size = 4 * (emb_dim / 2) * sizeof(float); - - mla_yarn_rope_q_backward_kernel<<>>( - grad_output, cos_data, sin_data, grad_input, cu_seqlens, qk_head_dim, emb_dim, h, d, s, b, - cp_size, cp_rank); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void mla_yarn_rope_kv_forward_launcher(const scalar_t *kv_input, const scalar_t *k_pos_emb, - const float *cos_data, const float *sin_data, - scalar_t *o_key, scalar_t *o_value, const int *cu_seqlens, - const int emb_dim, const int k_dim, const int v_dim, - const int h, const int total_seqlen, const int s, - const int b, const int cp_size, const int cp_rank, - cudaStream_t stream) { - int warps_per_block = h < 16 ? 
4 : 8; - dim3 blocks(total_seqlen); - dim3 threads(THREADS_PER_WARP, warps_per_block); - const int shared_mem_size = 6 * (emb_dim / 2) * sizeof(float); - - mla_yarn_rope_kv_forward_kernel<<>>( - kv_input, k_pos_emb, cos_data, sin_data, o_key, o_value, cu_seqlens, emb_dim, k_dim, v_dim, - h, s, b, cp_size, cp_rank); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void mla_yarn_rope_kv_backward_launcher(const scalar_t *dk, const scalar_t *dv, - const float *cos_data, const float *sin_data, - scalar_t *d_kv, scalar_t *d_emb, const int *cu_seqlens, - const int emb_dim, const int k_dim, const int v_dim, - const int h, const int total_seqlen, const int s, - const int b, const int cp_size, const int cp_rank, - cudaStream_t stream) { - int warps_per_block = h < 16 ? 4 : 8; - dim3 blocks(total_seqlen); - dim3 threads(THREADS_PER_WARP, warps_per_block); - const int shared_mem_size = 4 * (emb_dim / 2) * sizeof(float); - - mla_yarn_rope_kv_backward_kernel<<>>( - dk, dv, cos_data, sin_data, d_kv, d_emb, cu_seqlens, emb_dim, k_dim, v_dim, h, s, b, - cp_size, cp_rank); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -void fused_mla_rope_q_forward(const Tensor &q_input, const Tensor &cos, const Tensor &sin, - Tensor *q_output, const Tensor &cu_seqlens, const int qk_head_dim, - const int emb_dim, const int h, const int d, const int total_seqlen, - const int s, const int b, const int cp_size, const int cp_rank, - cudaStream_t stream) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - q_input.data.dtype, scalar_t, - mla_yarn_rope_q_forward_launcher( - reinterpret_cast(q_input.data.dptr), - reinterpret_cast(cos.data.dptr), - reinterpret_cast(sin.data.dptr), - reinterpret_cast(q_output->data.dptr), - reinterpret_cast(cu_seqlens.data.dptr), qk_head_dim, emb_dim, h, d, - total_seqlen, s, b, cp_size, cp_rank, stream);); -} - -void fused_mla_rope_q_backward(const Tensor &grad_output, const Tensor &cos, const Tensor &sin, - Tensor *grad_input, const Tensor &cu_seqlens, - const int qk_head_dim, 
const int emb_dim, const int h, const int d, - const int total_seqlen, const int s, const int b, const int cp_size, - const int cp_rank, cudaStream_t stream) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - grad_output.data.dtype, scalar_t, - mla_yarn_rope_q_backward_launcher( - reinterpret_cast(grad_output.data.dptr), - reinterpret_cast(cos.data.dptr), - reinterpret_cast(sin.data.dptr), - reinterpret_cast(grad_input->data.dptr), - reinterpret_cast(cu_seqlens.data.dptr), qk_head_dim, emb_dim, h, d, - total_seqlen, s, b, cp_size, cp_rank, stream);); -} - -void fused_mla_rope_kv_forward(const Tensor &kv_input, const Tensor &k_pos_emb, const Tensor &cos, - const Tensor &sin, Tensor *o_key, Tensor *o_value, - const Tensor &cu_seqlens, const int emb_dim, const int k_dim, - const int v_dim, const int h, const int total_seqlen, const int s, - const int b, const int cp_size, const int cp_rank, - cudaStream_t stream) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - kv_input.data.dtype, scalar_t, - mla_yarn_rope_kv_forward_launcher( - reinterpret_cast(kv_input.data.dptr), - reinterpret_cast(k_pos_emb.data.dptr), - reinterpret_cast(cos.data.dptr), - reinterpret_cast(sin.data.dptr), - reinterpret_cast(o_key->data.dptr), - reinterpret_cast(o_value->data.dptr), - reinterpret_cast(cu_seqlens.data.dptr), emb_dim, k_dim, v_dim, h, - total_seqlen, s, b, cp_size, cp_rank, stream);); -} - -void fused_mla_rope_kv_backward(const Tensor &dk, const Tensor &dv, const Tensor &cos, - const Tensor &sin, Tensor *d_kv, Tensor *d_emb, - const Tensor &cu_seqlens, const int emb_dim, const int k_dim, - const int v_dim, const int h, const int total_seqlen, const int s, - const int b, const int cp_size, const int cp_rank, - cudaStream_t stream) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - dk.data.dtype, scalar_t, - mla_yarn_rope_kv_backward_launcher( - reinterpret_cast(dk.data.dptr), - reinterpret_cast(dv.data.dptr), - reinterpret_cast(cos.data.dptr), - reinterpret_cast(sin.data.dptr), - 
reinterpret_cast(d_kv->data.dptr), - reinterpret_cast(d_emb->data.dptr), - reinterpret_cast(cu_seqlens.data.dptr), emb_dim, k_dim, v_dim, h, - total_seqlen, s, b, cp_size, cp_rank, stream);); -} - } // end namespace transformer_engine -void nvte_fused_mla_rope_q_forward(const NVTETensor q_input, const NVTETensor cos, - const NVTETensor sin, NVTETensor q_output, - const NVTETensor cu_seqlens, const int qk_head_dim, - const int emb_dim, const int h, const int d, - const int total_seqlen, const int s, const int b, - const int cp_size, const int cp_rank, cudaStream_t stream) { - NVTE_API_CALL(nvte_fused_mla_rope_q_forward); - using namespace transformer_engine; - fused_mla_rope_q_forward(*convertNVTETensorCheck(q_input), *convertNVTETensorCheck(cos), - *convertNVTETensorCheck(sin), convertNVTETensorCheck(q_output), - *convertNVTETensorCheck(cu_seqlens), qk_head_dim, emb_dim, h, d, - total_seqlen, s, b, cp_size, cp_rank, stream); -} - -void nvte_fused_mla_rope_q_backward(const NVTETensor grad_output, const NVTETensor cos, - const NVTETensor sin, NVTETensor grad_input, - const NVTETensor cu_seqlens, const int qk_head_dim, - const int emb_dim, const int h, const int d, - const int total_seqlen, const int s, const int b, - const int cp_size, const int cp_rank, cudaStream_t stream) { - NVTE_API_CALL(nvte_fused_mla_rope_q_backward); - using namespace transformer_engine; - fused_mla_rope_q_backward(*convertNVTETensorCheck(grad_output), *convertNVTETensorCheck(cos), - *convertNVTETensorCheck(sin), convertNVTETensorCheck(grad_input), - *convertNVTETensorCheck(cu_seqlens), qk_head_dim, emb_dim, h, d, - total_seqlen, s, b, cp_size, cp_rank, stream); -} - -void nvte_fused_mla_rope_kv_forward(const NVTETensor kv_input, const NVTETensor k_pos_emb, - const NVTETensor cos, const NVTETensor sin, NVTETensor o_key, - NVTETensor o_value, const NVTETensor cu_seqlens, - const int emb_dim, const int k_dim, const int v_dim, - const int h, const int total_seqlen, const int s, const int b, - 
const int cp_size, const int cp_rank, cudaStream_t stream) { - NVTE_API_CALL(nvte_fused_mla_rope_kv_forward); - using namespace transformer_engine; - fused_mla_rope_kv_forward( - *convertNVTETensorCheck(kv_input), *convertNVTETensorCheck(k_pos_emb), - *convertNVTETensorCheck(cos), *convertNVTETensorCheck(sin), convertNVTETensorCheck(o_key), - convertNVTETensorCheck(o_value), *convertNVTETensorCheck(cu_seqlens), emb_dim, k_dim, v_dim, - h, total_seqlen, s, b, cp_size, cp_rank, stream); -} - -void nvte_fused_mla_rope_kv_backward(const NVTETensor dk, const NVTETensor dv, const NVTETensor cos, - const NVTETensor sin, NVTETensor d_kv, NVTETensor d_emb, - const NVTETensor cu_seqlens, const int emb_dim, - const int k_dim, const int v_dim, const int h, - const int total_seqlen, const int s, const int b, - const int cp_size, const int cp_rank, cudaStream_t stream) { - NVTE_API_CALL(nvte_fused_mla_rope_kv_backward); - using namespace transformer_engine; - fused_mla_rope_kv_backward(*convertNVTETensorCheck(dk), *convertNVTETensorCheck(dv), - *convertNVTETensorCheck(cos), *convertNVTETensorCheck(sin), - convertNVTETensorCheck(d_kv), convertNVTETensorCheck(d_emb), - *convertNVTETensorCheck(cu_seqlens), emb_dim, k_dim, v_dim, h, - total_seqlen, s, b, cp_size, cp_rank, stream); -} - void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, const NVTETensor freqs, const NVTETensor start_positions, NVTETensor output, const NVTE_QKV_Format qkv_format, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 62d4369768..1ad6b8f889 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -664,8 +664,8 @@ void nvte_permute_to_grouped_tensor_bwd(NVTETensor grad_q, NVTETensor grad_k, NV * \param[in] num_tensors Number of tensor pairs to process (1..16). 
* \param[in] stream CUDA stream. */ -void nvte_multi_pad_last_dim(NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, - cudaStream_t stream); +void nvte_multi_tensor_pad_last_dim(NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index 69b5168212..aea5256a2c 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -139,81 +139,6 @@ void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, const int qkv_split_arg_list_2, cudaStream_t stream); -/*! \brief Apply YARN RoPE to MLA query tensor (forward). - * - * Reads the last emb_dim elements interleaved, applies YARN rotation with - * split cos/sin, and writes de-interleaved. First qk_head_dim elements are - * copied unchanged. Input is [total_seqlen, h, d] (flattened from SBHD or THD). - * - * \param[in] q_input Input Q tensor. - * \param[in] cos Pre-computed cosine tensor [max_s, emb_dim]. - * \param[in] sin Pre-computed sine tensor [max_s, emb_dim]. - * \param[out] q_output Output Q tensor (same shape as input). - * \param[in] cu_seqlens Cumulative sequence lengths for THD (empty for SBHD). - * \param[in] qk_head_dim Non-RoPE prefix dimension per head. - * \param[in] emb_dim RoPE embedding dimension. - * \param[in] h Number of heads. - * \param[in] d Total head dimension (qk_head_dim + emb_dim). - * \param[in] total_seqlen Total tokens (s*b for SBHD, total_t for THD). - * \param[in] s Sequence length (SBHD) or max_s (THD). - * \param[in] b Batch size (SBHD) or num_seqs (THD). - * \param[in] cp_size Context parallel world size. - * \param[in] cp_rank Context parallel rank. - * \param[in] stream CUDA stream. 
- */ -void nvte_fused_mla_rope_q_forward(const NVTETensor q_input, const NVTETensor cos, - const NVTETensor sin, NVTETensor q_output, - const NVTETensor cu_seqlens, const int qk_head_dim, - const int emb_dim, const int h, const int d, - const int total_seqlen, const int s, const int b, - const int cp_size, const int cp_rank, cudaStream_t stream); - -/*! \brief Backward of YARN RoPE for MLA query tensor. */ -void nvte_fused_mla_rope_q_backward(const NVTETensor grad_output, const NVTETensor cos, - const NVTETensor sin, NVTETensor grad_input, - const NVTETensor cu_seqlens, const int qk_head_dim, - const int emb_dim, const int h, const int d, - const int total_seqlen, const int s, const int b, - const int cp_size, const int cp_rank, cudaStream_t stream); - -/*! \brief Apply YARN RoPE to MLA key-value tensor (forward). - * - * Splits KV into key and value, applies YARN rotation to k_pos_emb (shared - * across heads), concatenates rotated embedding to each head of output key. - * - * \param[in] kv_input Input KV tensor [total_t, h, k_dim+v_dim]. - * \param[in] k_pos_emb Positional embedding [total_t, emb_dim]. - * \param[in] cos Pre-computed cosine [max_s, emb_dim]. - * \param[in] sin Pre-computed sine [max_s, emb_dim]. - * \param[out] o_key Output key [total_t, h, k_dim+emb_dim]. - * \param[out] o_value Output value [total_t, h, v_dim]. - * \param[in] cu_seqlens Cumulative sequence lengths for THD (empty for SBHD). - * \param[in] emb_dim RoPE embedding dimension. - * \param[in] k_dim Key dimension per head (from KV). - * \param[in] v_dim Value dimension per head (from KV). - * \param[in] h Number of heads. - * \param[in] total_seqlen Total tokens. - * \param[in] s Sequence length (SBHD) or max_s (THD). - * \param[in] b Batch size (SBHD) or num_seqs (THD). - * \param[in] cp_size Context parallel world size. - * \param[in] cp_rank Context parallel rank. - * \param[in] stream CUDA stream. 
- */ -void nvte_fused_mla_rope_kv_forward(const NVTETensor kv_input, const NVTETensor k_pos_emb, - const NVTETensor cos, const NVTETensor sin, NVTETensor o_key, - NVTETensor o_value, const NVTETensor cu_seqlens, - const int emb_dim, const int k_dim, const int v_dim, - const int h, const int total_seqlen, const int s, const int b, - const int cp_size, const int cp_rank, cudaStream_t stream); - -/*! \brief Backward of YARN RoPE for MLA key-value tensor. */ -void nvte_fused_mla_rope_kv_backward(const NVTETensor dk, const NVTETensor dv, const NVTETensor cos, - const NVTETensor sin, NVTETensor d_kv, NVTETensor d_emb, - const NVTETensor cu_seqlens, const int emb_dim, - const int k_dim, const int v_dim, const int h, - const int total_seqlen, const int s, const int b, - const int cp_size, const int cp_rank, cudaStream_t stream); - #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index aa697aafe1..9ea524c7b7 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -43,7 +43,8 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud * - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. */ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, - const size_t num_tensors, cudaStream_t stream); + const size_t num_tensors, cudaStream_t stream, + bool check_scale_inv_shapes = true); /*! 
\brief Unswizzling scaling factors from the interleaved layout used by GEMM back to row-major * diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 09639d9998..b6b4a05626 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -282,6 +282,87 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, } } +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); +} + +// Narrow-K specialization for row scaling swizzle. +// When K is small (num_tiles_k < TB_DIM), the standard kernel wastes threadIdx.x +// because there aren't enough K-tiles to distribute across threads. +// This kernel repurposes the thread dimensions: threadIdx.x iterates rows within +// an M-tile, threadIdx.y indexes M-tiles within the block, processing TB_DIM +// M-tiles per block with full thread utilization. 
+template +__device__ void swizzle_row_scaling_narrow_k_kernel_impl( + const void* input, void* output, const int M, const int K, + const int original_M, const int original_K, + const int bid, const int grid_dim) { + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + const int K_i32 = K / 4; + const int num_tiles_m = M / SF_TILE_DIM_M; + + const int m_tile = bid * blockDim.y + threadIdx.y; + const bool active = (m_tile < num_tiles_m); + + extern __shared__ int4 slm_v4i[]; + const int slm_tile_v4i = K_i32 * (SF_TILE_SIZE_I32 / 4); + + if (active) { + const bool padding_m = (m_tile == num_tiles_m - 1) && (original_M < M); + const bool padding_k = (original_K < K); + + int4* my_slm = slm_v4i + threadIdx.y * slm_tile_v4i; + + for (int k = 0; k < K_i32; k++) { + const int input_base = m_tile * SF_TILE_DIM_M * K_i32 + k; + const int* input_i32 = reinterpret_cast(input) + input_base; + + int regs[N_SF_PER_TD_PER_TILE]; +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + const int row = i * TB_DIM + threadIdx.x; + regs[i] = __ldg(input_i32 + row * K_i32); + if (padding_m || padding_k) { + for (int j = 0; j < 4; j++) { + const int byte_row = m_tile * SF_TILE_DIM_M + row; + const int byte_col = k * 4 + j; + if (byte_row >= original_M || byte_col >= original_K) { + reinterpret_cast(®s[i])[j] = 0; + } + } + } + } + + my_slm[k * (SF_TILE_SIZE_I32 / 4) + threadIdx.x] = + *reinterpret_cast(regs); + } + } + + __syncthreads(); + + if (active) { + int4* my_slm = slm_v4i + threadIdx.y * slm_tile_v4i; + int4* out_v4i = reinterpret_cast( + reinterpret_cast(output) + m_tile * SF_TILE_DIM_M * K_i32); + + for (int i = threadIdx.x; i < slm_tile_v4i; i += blockDim.x) { + out_v4i[i] = my_slm[i]; + } + } +} + +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + swizzle_row_scaling_narrow_k_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K) { + swizzle_row_scaling_narrow_k_kernel_impl( + 
input, output, M, K, original_M, original_K, blockIdx.x, gridDim.x); +} + template __device__ void unswizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, const int K, const int bid_x, const int bid_y, @@ -422,87 +503,6 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) } } -template -__global__ void __launch_bounds__(TB_DIM* TB_DIM) - swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K, - const int original_M, const int original_K) { - swizzle_row_scaling_kernel_impl( - input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); -} - -// Narrow-K specialization for row scaling swizzle. -// When K is small (num_tiles_k < TB_DIM), the standard kernel wastes threadIdx.x -// because there aren't enough K-tiles to distribute across threads. -// This kernel repurposes the thread dimensions: threadIdx.x iterates rows within -// an M-tile, threadIdx.y indexes M-tiles within the block, processing TB_DIM -// M-tiles per block with full thread utilization. 
-template -__device__ void swizzle_row_scaling_narrow_k_impl( - const void* input, void* output, const int M, const int K, - const int original_M, const int original_K, - const int bid, const int grid_dim) { - constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; - const int K_i32 = K / 4; - const int num_tiles_m = M / SF_TILE_DIM_M; - - const int m_tile = bid * blockDim.y + threadIdx.y; - const bool active = (m_tile < num_tiles_m); - - extern __shared__ int4 slm_v4i[]; - const int slm_tile_v4i = K_i32 * (SF_TILE_SIZE_I32 / 4); - - if (active) { - const bool padding_m = (m_tile == num_tiles_m - 1) && (original_M < M); - const bool padding_k = (original_K < K); - - int4* my_slm = slm_v4i + threadIdx.y * slm_tile_v4i; - - for (int k = 0; k < K_i32; k++) { - const int input_base = m_tile * SF_TILE_DIM_M * K_i32 + k; - const int* input_i32 = reinterpret_cast(input) + input_base; - - int regs[N_SF_PER_TD_PER_TILE]; -#pragma unroll - for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { - const int row = i * TB_DIM + threadIdx.x; - regs[i] = __ldg(input_i32 + row * K_i32); - if (padding_m || padding_k) { - for (int j = 0; j < 4; j++) { - const int byte_row = m_tile * SF_TILE_DIM_M + row; - const int byte_col = k * 4 + j; - if (byte_row >= original_M || byte_col >= original_K) { - reinterpret_cast(®s[i])[j] = 0; - } - } - } - } - - my_slm[k * (SF_TILE_SIZE_I32 / 4) + threadIdx.x] = - *reinterpret_cast(regs); - } - } - - __syncthreads(); - - if (active) { - int4* my_slm = slm_v4i + threadIdx.y * slm_tile_v4i; - int4* out_v4i = reinterpret_cast( - reinterpret_cast(output) + m_tile * SF_TILE_DIM_M * K_i32); - - for (int i = threadIdx.x; i < slm_tile_v4i; i += blockDim.x) { - out_v4i[i] = my_slm[i]; - } - } -} - -template -__global__ void __launch_bounds__(TB_DIM* TB_DIM) - swizzle_row_scaling_narrow_k_kernel(const void* input, void* output, const int M, const int K, - const int original_M, const int original_K) { - swizzle_row_scaling_narrow_k_impl( - input, output, 
M, K, original_M, original_K, blockIdx.x, gridDim.x); -} - constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB struct MultiSwizzleArgs { // (input) Data buffers for input scaling factors @@ -672,6 +672,28 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); } +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + multi_tensor_swizzle_row_scaling_narrow_k_kernel(MultiSwizzleArgs kernel_args) { + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + const int flat_bid = bid - kernel_args.block_range[tensor_id]; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int grid_dim = DIVUP(num_tiles_m, TB_DIM); + + swizzle_row_scaling_narrow_k_kernel_impl( + input, output, M, K, original_M, original_K, flat_bid, grid_dim); +} + } // namespace void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { @@ -921,83 +943,120 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s template void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, const int vec_load_size, const bool is_rowwise, + const bool use_narrow_k, cudaStream_t stream) { - int n_tiles_in_tb = TB_DIM * vec_load_size; - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - /* Calculate number of CUDA blocks needed for each tensor. - * We have to do it here because we have to iterate over all tensors in this batch to - * get the minimum vec_load_size. 
- */ - for (size_t j = 0; j < kernel_args.num_tensors; j++) { - const int m = kernel_args.m_list[j]; - const int k = kernel_args.k_list[j]; - int num_tiles_m = m / SF_TILE_DIM_M; - int num_tiles_k = k / SF_TILE_DIM_K; - if (is_rowwise) { - kernel_args.block_range[j + 1] = - kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m; - } else { - kernel_args.block_range[j + 1] = - kernel_args.block_range[j] + - DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size); + // cudaFuncSetAttribute is a host-synchronous driver call; cache the max shared memory + // setting per kernel variant so we only pay the cost when slm_size actually increases. + auto set_smem_if_needed = [](auto kernel_fn, int slm, int& cached) { + if (cached < slm) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, slm)); + cached = slm; } - } - // Launch kernel - const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + }; + dim3 block_size(TB_DIM, TB_DIM); - if (is_rowwise) { - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_row_scaling_kernel - <<>>(kernel_args); - break; - case 2: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_row_scaling_kernel - <<>>(kernel_args); - break; - case 1: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_row_scaling_kernel - <<>>(kernel_args); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; + + if (is_rowwise && use_narrow_k) { + // Narrow-K path: each block handles TB_DIM M-tiles with full thread utilization. 
+ // slm_size depends on num_tiles_k, which can vary per tensor — use the max. + int max_num_tiles_k = 0; + for (size_t j = 0; j < kernel_args.num_tensors; j++) { + const int num_tiles_m = kernel_args.m_list[j] / SF_TILE_DIM_M; + const int num_tiles_k = kernel_args.k_list[j] / SF_TILE_DIM_K; + max_num_tiles_k = std::max(max_num_tiles_k, num_tiles_k); + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + DIVUP(num_tiles_m, TB_DIM); } + int slm_size = TB_DIM * max_num_tiles_k * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + + static int cached_narrow_k = -1; + set_smem_if_needed( + multi_tensor_swizzle_row_scaling_narrow_k_kernel, + slm_size, cached_narrow_k); + multi_tensor_swizzle_row_scaling_narrow_k_kernel + <<>>(kernel_args); } else { - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_col_scaling_kernel - <<>>(kernel_args); - break; - case 2: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_col_scaling_kernel - <<>>(kernel_args); - break; - case 1: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_col_scaling_kernel - <<>>(kernel_args); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; + int n_tiles_in_tb = TB_DIM * vec_load_size; + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + /* Calculate number of CUDA blocks needed for each tensor. + * We have to do it here because we have to iterate over all tensors in this batch to + * get the minimum vec_load_size. 
+ */ + for (size_t j = 0; j < kernel_args.num_tensors; j++) { + const int m = kernel_args.m_list[j]; + const int k = kernel_args.k_list[j]; + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + if (is_rowwise) { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m; + } else { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + + DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size); + } + } + const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + + static int cached_row_int4 = -1, cached_row_int2 = -1, cached_row_int1 = -1; + static int cached_col_int4 = -1, cached_col_int2 = -1, cached_col_int1 = -1; + + if (is_rowwise) { + switch (vec_load_size) { + case 4: + set_smem_if_needed( + multi_tensor_swizzle_row_scaling_kernel, + slm_size, cached_row_int4); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 2: + set_smem_if_needed( + multi_tensor_swizzle_row_scaling_kernel, + slm_size, cached_row_int2); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 1: + set_smem_if_needed( + multi_tensor_swizzle_row_scaling_kernel, + slm_size, cached_row_int1); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } else { + switch (vec_load_size) { + case 4: + set_smem_if_needed( + multi_tensor_swizzle_col_scaling_kernel, + slm_size, cached_col_int4); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 2: + set_smem_if_needed( + multi_tensor_swizzle_col_scaling_kernel, + slm_size, cached_col_int2); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 1: + set_smem_if_needed( + multi_tensor_swizzle_col_scaling_kernel, + slm_size, cached_col_int1); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid 
vec_load_size."); + break; + } } } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -1087,7 +1146,8 @@ void launch_multi_tensor_unswizzle_scaling_factors(MultiSwizzleArgs& kernel_args } void multi_tensor_swizzle_scaling_factors(const std::vector& input, - std::vector& output, cudaStream_t stream) { + std::vector& output, cudaStream_t stream, + bool check_scale_inv_shapes) { auto num_tensors = input.size(); bool all_has_data = true; bool all_has_columnwise_data = true; @@ -1106,8 +1166,10 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, // We don't allow empty tensors. They should be filtered out before calling this function. NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty."); - CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]"); - CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); + CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]", + check_scale_inv_shapes); + CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]", + check_scale_inv_shapes); all_has_data = all_has_data && input[i]->scale_inv.has_data(); all_has_columnwise_data = (all_has_columnwise_data && input[i]->columnwise_scale_inv.has_data()); @@ -1128,16 +1190,18 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args.num_tensors = 0; kernel_args.block_range[0] = 0; int vec_load_size = 4; + bool all_narrow_k = true; for (size_t i = 0; i < num_tensors; i++) { //Launch kernel if argument struct is full if (kernel_args.num_tensors == kMaxTensorsPerKernel) { // There is no int3 and misaligned if using int4/int2. 
if (vec_load_size == 3) vec_load_size = 1; launch_multi_tensor_swizzle_scaling_factors( - kernel_args, vec_load_size, true, stream); + kernel_args, vec_load_size, true, all_narrow_k, stream); // Reset the argument struct and vec_load_size kernel_args.num_tensors = 0; vec_load_size = 4; + all_narrow_k = true; } int m, k; @@ -1171,6 +1235,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, } int num_tiles_k = k / SF_TILE_DIM_K; + all_narrow_k = all_narrow_k && (num_tiles_k < TB_DIM); int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; // We use the minimum vec_load_size across all tensors. // TODO(zhongbo): fix vec_load_size for NVFP4 @@ -1200,7 +1265,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, // There is no int3 and misaligned if using int4/int2. if (vec_load_size == 3) vec_load_size = 1; launch_multi_tensor_swizzle_scaling_factors( - kernel_args, vec_load_size, true, stream); + kernel_args, vec_load_size, true, all_narrow_k, stream); } if (columnwise_swizzle) { @@ -1217,7 +1282,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, // There is no int3 and misaligned if using int4/int2. if (vec_load_size == 3) vec_load_size = 1; launch_multi_tensor_swizzle_scaling_factors( - kernel_args, vec_load_size, false, stream); + kernel_args, vec_load_size, false, false, stream); // Reset the argument struct and vec_load_size kernel_args.num_tensors = 0; vec_load_size = 4; @@ -1252,7 +1317,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, // There is no int3 and misaligned if using int4/int2. 
if (vec_load_size == 3) vec_load_size = 1; launch_multi_tensor_swizzle_scaling_factors( - kernel_args, vec_load_size, false, stream); + kernel_args, vec_load_size, false, false, stream); } } @@ -1588,7 +1653,8 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud } void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, - const size_t num_tensors, cudaStream_t stream) { + const size_t num_tensors, cudaStream_t stream, + bool check_scale_inv_shapes) { NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors); using namespace transformer_engine; NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); @@ -1597,7 +1663,7 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen input_list.push_back(convertNVTETensorCheck(inputs[i])); output_list.push_back(convertNVTETensorCheck(outputs[i])); } - multi_tensor_swizzle_scaling_factors(input_list, output_list, stream); + multi_tensor_swizzle_scaling_factors(input_list, output_list, stream, check_scale_inv_shapes); } void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output, diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 68946f6a2b..4351ee3061 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -106,8 +106,9 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { alignment; const auto &expected = std::vector{expected_x, expected_y}; - // TODO(charleney): re-enable after scale shape rework - (void)expected; + NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid scale_inv shape (expected ", expected, ", got ", + t.scale_inv.shape, ")"); } if (t.has_columnwise_data()) { alignment = block_alignment[1]; @@ -118,8 +119,9 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { expected_y = 
DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; const auto &expected = std::vector{expected_x, expected_y}; - // TODO(charleney): re-enable after scale shape rework - (void)expected; + NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", + t.columnwise_scale_inv.shape, ")"); } } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) { if (t.has_data()) { @@ -142,7 +144,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { } } -void CheckInputTensor(const Tensor &t, const std::string &name) { +void CheckInputTensor(const Tensor &t, const std::string &name, + bool check_scale_inv_shapes) { const DType type = t.dtype(); if (is_fp8_dtype(type)) { // FP8 input needs to have scale_inv @@ -193,7 +196,9 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { } NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!"); - CheckScaleTensorShape(t, name); + if (check_scale_inv_shapes) { + CheckScaleTensorShape(t, name); + } } void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) { diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 7a8319aa9e..b35f2380e4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -72,7 +72,8 @@ print_quantizers, ConvertTHDtoBSHD, ConvertBSHDtoTHD, - _mxfp8_pad_and_swizzle_scales, + mxfp8_pad_and_swizzle_scales, + mxfp8_quantize_single_tensor, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import ( AttentionLogging as attn_log, @@ -167,79 +168,6 @@ _qdq_dO_in_f16_bprop = os.getenv("NVTE_QDQ_DO_IN_F16_BPROP", "0") == "1" -def _mxfp8_clone_with_new_scale_inv(tensor: MXFP8Tensor, new_rs) -> MXFP8Tensor: 
- """Return a new MXFP8Tensor sharing data but with replaced rowwise_scale_inv.""" - return MXFP8Tensor( - shape=tensor.shape, - dtype=tensor.dtype, - rowwise_data=tensor._rowwise_data, - rowwise_scale_inv=new_rs, - columnwise_data=tensor._columnwise_data, - columnwise_scale_inv=tensor._columnwise_scale_inv, - quantizer=tensor._quantizer, - requires_grad=False, - fp8_dtype=tensor._fp8_dtype, - with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, - ) - - -def _mxfp8_permute_qkv_scale_inv_to_bhsd( - q: MXFP8Tensor, k: MXFP8Tensor, v: MXFP8Tensor, - q_fmt: str, kv_fmt: str, -): - """Permute Q/K/V scale_inv from their original format to BHSD in a single - batched kernel launch, then pad+swizzle for F8_128x4.""" - if q_fmt in ("bhsd", "htd"): - q_out = _mxfp8_clone_with_new_scale_inv(q, q._rowwise_scale_inv) - k_out = _mxfp8_clone_with_new_scale_inv(k, k._rowwise_scale_inv) - v_out = _mxfp8_clone_with_new_scale_inv(v, v._rowwise_scale_inv) - _mxfp8_pad_and_swizzle_scales(q_out, k_out, v_out) - return q_out, k_out, v_out - - def _view_scale_4d(tensor): - rs = tensor._rowwise_scale_inv - d_scale = rs.shape[-1] - shape_4d = list(tensor.shape[:-1]) + [d_scale] - return rs.view(shape_4d).contiguous(), d_scale - - q_rs_4d, d_scale_q = _view_scale_4d(q) - k_rs_4d, d_scale_k = _view_scale_4d(k) - v_rs_4d, d_scale_v = _view_scale_4d(v) - - fmt = dpa_utils._FORMAT_STR_TO_ENUM[q_fmt] - q_rs_bhsd, k_rs_bhsd, v_rs_bhsd = tex.permute_to_grouped_tensor_fwd( - q_rs_4d, k_rs_4d, v_rs_4d, original_format=fmt, - ) - - q_out = _mxfp8_clone_with_new_scale_inv(q, q_rs_bhsd.view(-1, d_scale_q)) - k_out = _mxfp8_clone_with_new_scale_inv(k, k_rs_bhsd.view(-1, d_scale_k)) - v_out = _mxfp8_clone_with_new_scale_inv(v, v_rs_bhsd.view(-1, d_scale_v)) - _mxfp8_pad_and_swizzle_scales(q_out, k_out, v_out) - return q_out, k_out, v_out - - -def _mxfp8_permute_scale_inv_to_bhsd( - tensor: MXFP8Tensor, src_format: str, -) -> MXFP8Tensor: - """Single-tensor variant for dO in the backward pass.""" - 
if src_format in ("bhsd", "htd"): - out = _mxfp8_clone_with_new_scale_inv(tensor, tensor._rowwise_scale_inv) - _mxfp8_pad_and_swizzle_scales(out) - return out - - rs = tensor._rowwise_scale_inv - d_scale = rs.shape[-1] - shape_4d = list(tensor.shape[:-1]) + [d_scale] - fmt = dpa_utils._FORMAT_STR_TO_ENUM[src_format] - new_rs = tex.permute_to_grouped_tensor_fwd( - rs.view(shape_4d).contiguous(), original_format=fmt, - )[0].view(-1, d_scale) - - out = _mxfp8_clone_with_new_scale_inv(tensor, new_rs) - _mxfp8_pad_and_swizzle_scales(out) - return out - - class FP8EmulationFunc(torch.autograd.Function): """ Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows: @@ -260,12 +188,16 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou x.contiguous() for x in [tensor1, tensor2, tensor3] ] # always in sbhd_sbhd_sbhd shape at this point - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( - qkv_layout, query_layer, key_layer, value_layer, quantizer + q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( + qkv_layout, query_layer, key_layer, value_layer, quantizer, + keep_same_data_and_scale_inv_format=True, ) tensors = combine_and_dequantize( qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype ) + if isinstance(quantizer, MXFP8Quantizer): + # bhsd_bhsd_bhsd after combine_and_quantize; permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: if quantizer is not None: t_fp8 = quantizer(tensor1) @@ -291,12 +223,16 @@ def backward(ctx, grad1, grad2, grad3): elif ctx.quantizer_name == "dQKV_quantizer": query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] # always in sbhd_sbhd_sbhd shape at this point - dq_fp8, dk_fp8, dv_fp8, new_qkv_layout = combine_and_quantize( - ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer + dq_fp8, dk_fp8, dv_fp8, 
new_qkv_layout, _ = combine_and_quantize( + ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer, + keep_same_data_and_scale_inv_format=True, ) tensors = combine_and_dequantize( new_qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) + if isinstance(ctx.quantizer, MXFP8Quantizer): + # bhsd_bhsd_bhsd after combine_and_quantize; permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None @@ -1353,20 +1289,10 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + q_fp8, k_fp8, v_fp8, qkv_layout, qkv_scale_inv_format = combine_and_quantize( qkv_layout, q, k, v, QKV_quantizer, used_in_backward=is_training ) - # For MXFP8: data stays in its original layout (e.g. SBHD). - # Permute only the (much smaller) scale_inv to BHSD (single - # batched kernel for Q/K/V), then pad+swizzle for F8_128x4. 
- if isinstance(QKV_quantizer, MXFP8Quantizer) and not is_input_fp8: - _, q_fmt, kv_fmt = dpa_utils.get_qkv_format(qkv_layout) - q_fp8, k_fp8, v_fp8 = _mxfp8_permute_qkv_scale_inv_to_bhsd( - q_fp8, k_fp8, v_fp8, q_fmt, kv_fmt, - ) - qkv_scale_inv_format = "bhsd" - # print quantizers print_quantizers( "FusedAttnFunc.forward >> before: ", @@ -1539,7 +1465,7 @@ def forward( tmp_quantizer = QKV_quantizer.copy() if isinstance(tmp_quantizer, MXFP8Quantizer): tmp_quantizer.optimize_for_gemm = False - q_fp8_, k_fp8_, _, _ = combine_and_quantize( + q_fp8_, k_fp8_, _, _, _ = combine_and_quantize( original_qkv_layout, q, k, v, tmp_quantizer, used_in_backward=True ) q_ = q_fp8_.dequantize(dtype=out_nominal_dtype) @@ -1700,7 +1626,7 @@ def backward(ctx, d_out, *_args): if _qdq_dO_in_f16_bprop or _qdq_dO_in_mxfp8_bprop: d_out_qdq_f16 = d_out - d_out_qdq_f16, _ = dpa_utils.permute_to_grouped_tensor(ctx.o_format, d_out_qdq_f16) + d_out_qdq_f16 = dpa_utils.permute_to_grouped_tensor_pytorch(d_out_qdq_f16, ctx.o_format) tmp_quantizer = MXFP8Quantizer( fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True ) @@ -1733,14 +1659,14 @@ def backward(ctx, d_out, *_args): d_out = d_out.contiguous() d_out_fp8 = None do_format = ctx.o_format + do_scale_inv_format = None if ctx.fp8: if isinstance(d_out, QuantizedTensorStorage): d_out_fp8 = d_out elif isinstance(ctx.dO_quantizer, MXFP8Quantizer): - orig_opt = ctx.dO_quantizer.optimize_for_gemm - ctx.dO_quantizer.optimize_for_gemm = False - d_out_fp8 = ctx.dO_quantizer(d_out) - ctx.dO_quantizer.optimize_for_gemm = orig_opt + d_out_fp8, do_scale_inv_format = mxfp8_quantize_single_tensor( + d_out, ctx.dO_quantizer, do_format, + ) else: d_out_fp8 = ctx.dO_quantizer(d_out) ( @@ -1844,12 +1770,6 @@ def backward(ctx, d_out, *_args): if ctx.fp8_recipe.mxfp8(): out_ = out aux_ctx_tensors.append(d_out) - do_scale_inv_format = None - if isinstance(ctx.dO_quantizer, MXFP8Quantizer) and d_out_fp8 is not None: - d_out_fp8 = 
_mxfp8_permute_scale_inv_to_bhsd( - d_out_fp8, do_format, - ) - do_scale_inv_format = "bhsd" dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1890,7 +1810,7 @@ def backward(ctx, d_out, *_args): tmp_quantizer = ctx.QKV_quantizer.copy() if isinstance(tmp_quantizer, MXFP8Quantizer): tmp_quantizer.optimize_for_gemm = False - q_fp8_, k_fp8_, v_fp8_, _ = combine_and_quantize( + q_fp8_, k_fp8_, v_fp8_, _, _ = combine_and_quantize( original_qkv_layout, q, k, v, tmp_quantizer, used_in_backward=True ) q_shadow_f16, k_shadow_f16, v_shadow_f16 = [ @@ -1977,7 +1897,7 @@ def backward(ctx, d_out, *_args): ) if not is_quantized_tensor and ctx.is_input_fp8: # return in FP8 - dq, dk, dv, _ = combine_and_quantize( + dq, dk, dv, _, _ = combine_and_quantize( ctx.dqkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 94422b2750..c467dfb45a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -46,6 +46,7 @@ combine_and_quantize, combine_and_dequantize, print_quantizers, + mxfp8_quantize_single_tensor, ) _cu_seqlens_info_with_cp_cache = {} @@ -938,6 +939,7 @@ def cp_p2p_fwd_fused_attn( fp8_meta_kwargs = {} new_qkv_layout = qkv_layout + qkv_scale_inv_format = None if fp8: if not fp8_recipe.mxfp8(): q_part, k_part, v_part = [ @@ -945,8 +947,10 @@ def cp_p2p_fwd_fused_attn( for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] else: - q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( - qkv_layout, q_part, k_part, v_part, QKV_quantizer + q_part, k_part, v_part, new_qkv_layout, qkv_scale_inv_format = ( + combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) ) fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step 
fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step @@ -974,6 +978,7 @@ def cp_p2p_fwd_fused_attn( **fp8_meta_kwargs, return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=qkv_scale_inv_format, ) if fp8: @@ -1212,6 +1217,8 @@ def cp_p2p_bwd_fused_attn( aux_tensors += [attn_biases[cp_size - step - 1]] fp8_meta_kwargs = {} + qkv_scale_inv_format = None + do_scale_inv_format = None if fp8: if not fp8_recipe.mxfp8(): q_part, k_part, v_part = [ @@ -1222,23 +1229,26 @@ def cp_p2p_bwd_fused_attn( ) ] else: - q_part, k_part, v_part, qkv_layout = combine_and_quantize( - qkv_layout, - q_part, - k_part, - v_part, - QKV_quantizer_per_step, - used_in_forward=False, - used_in_backward=True, + q_part, k_part, v_part, qkv_layout, qkv_scale_inv_format = ( + combine_and_quantize( + qkv_layout, + q_part, + k_part, + v_part, + QKV_quantizer_per_step, + used_in_forward=False, + used_in_backward=True, + ) ) if not fp8_recipe.mxfp8(): if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) else: - dout_part, do_format = dpa_utils.permute_to_grouped_tensor(do_format, dout_part) aux_tensors.append(dout_part) - dout_part = dO_quantizer_per_step(dout_part) + dout_part, do_scale_inv_format = mxfp8_quantize_single_tensor( + dout_part, dO_quantizer_per_step, do_format, + ) fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step fp8_meta_kwargs["dqkv_quantizer"] = dQKV_quantizer_per_step @@ -1268,6 +1278,8 @@ def cp_p2p_bwd_fused_attn( attn_bias_type=attn_bias_type, deterministic=deterministic, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=qkv_scale_inv_format, + do_scale_inv_format=do_scale_inv_format, **fp8_meta_kwargs, ) @@ -1521,7 +1533,7 @@ def forward( # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, 
k, v: torch.Tensor, dtype=torch.uint8 q_f16 = q - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( qkv_layout, q, k, v, QKV_quantizer ) if not fp8_recipe.mxfp8(): @@ -2826,7 +2838,7 @@ def backward(ctx, dout, *_args): dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv, _ = combine_and_quantize(ctx.qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _, _ = combine_and_quantize(ctx.qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8: # print quantizers @@ -3075,7 +3087,7 @@ def forward( assert use_fused_attention, "FP8 is only supported with FusedAttention backend!" fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 if not is_input_fp8 and not fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( qkv_layout, q, k, v, QKV_quantizer ) if not fp8_recipe.mxfp8(): @@ -3171,6 +3183,7 @@ def forward( ] if use_fused_attention: new_qkv_layout = qkv_layout + qkv_scale_inv_format = None if fp8: if not fp8_recipe.mxfp8(): q_part, k_part, v_part = [ @@ -3178,8 +3191,10 @@ def forward( for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] else: - q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( - qkv_layout, q_part, k_part, v_part, QKV_quantizer + q_part, k_part, v_part, new_qkv_layout, qkv_scale_inv_format = ( + combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) ) ( out_per_step[i], @@ -3208,6 +3223,7 @@ def forward( window_size=window_size_per_step[i], return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=qkv_scale_inv_format, **fp8_meta_kwargs, ) if fp8: @@ -3591,6 +3607,8 @@ def backward(ctx, dout, *_args): fp8_meta_kwargs = {} new_qkv_layout = ctx.qkv_layout do_format = ctx.o_format + qkv_scale_inv_format = None + do_scale_inv_format = None if ctx.fp8: fused_attn_backend = 
tex.NVTE_Fused_Attn_Backend.NVTE_FP8 fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer @@ -3615,20 +3633,21 @@ def backward(ctx, dout, *_args): dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype ) else: - q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( - ctx.qkv_layout, - q_part, - k_part, - v_part, - ctx.QKV_quantizer, - used_in_forward=False, - used_in_backward=True, - ) - dout_part, do_format = dpa_utils.permute_to_grouped_tensor( - do_format, dout_part + q_part, k_part, v_part, new_qkv_layout, qkv_scale_inv_format = ( + combine_and_quantize( + ctx.qkv_layout, + q_part, + k_part, + v_part, + ctx.QKV_quantizer, + used_in_forward=False, + used_in_backward=True, + ) ) aux_ctx_tensors.append(dout_part) - dout_part = ctx.dO_quantizer(dout_part) + dout_part, do_scale_inv_format = mxfp8_quantize_single_tensor( + dout_part, ctx.dO_quantizer, do_format, + ) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, @@ -3655,6 +3674,8 @@ def backward(ctx, dout, *_args): window_size=window_size_per_step[i], deterministic=ctx.deterministic, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=qkv_scale_inv_format, + do_scale_inv_format=do_scale_inv_format, **fp8_meta_kwargs, ) if ctx.fp8 and all( @@ -3758,7 +3779,7 @@ def backward(ctx, dout, *_args): # quantize if necessary if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv, _ = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _, _ = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( @@ -3935,7 +3956,7 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v elif not fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( qkv_layout, q, k, v, QKV_quantizer ) if not fp8_recipe.mxfp8(): @@ -3998,16 +4019,19 @@ def forward( or 
(fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) ) ) + qkv_scale_inv_format = None if use_fused_attention: if fp8: if fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( - qkv_layout, - q_part, - k_part, - v_part, - QKV_quantizer, - used_in_backward=is_training, + q_fp8, k_fp8, v_fp8, qkv_layout, qkv_scale_inv_format = ( + combine_and_quantize( + qkv_layout, + q_part, + k_part, + v_part, + QKV_quantizer, + used_in_backward=is_training, + ) ) q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] else: @@ -4041,6 +4065,7 @@ def forward( softmax_offset=softmax_offset, return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=qkv_scale_inv_format, ) # construct out_part for backward # out_fp8 and out_f16 store the FP8 or F16 tensor for backward saves @@ -4131,6 +4156,7 @@ def forward( ctx.qkv_layout = qkv_layout ctx.o_format = o_format + ctx.qkv_scale_inv_format = qkv_scale_inv_format ctx.dqkv_layout = original_qkv_layout ctx.dqkv_format = qkv_format ctx.orig_q_shape = orig_q_shape @@ -4334,6 +4360,7 @@ def backward(ctx, dout, *_args): dq_fp8, dk_fp8, dv_fp8 = None, None, None if ctx.use_fused_attention: do_format = ctx.o_format + do_scale_inv_format = None q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 @@ -4344,10 +4371,10 @@ def backward(ctx, dout, *_args): if not ctx.fp8_recipe.mxfp8(): dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) else: - # do_format = bhsd for both dout (F16) and dout_part (MXFP8) - dout, do_format = dpa_utils.permute_to_grouped_tensor(do_format, dout) aux_ctx_tensors.append(dout) - dout_part = ctx.dO_quantizer(dout) + dout_part, do_scale_inv_format = mxfp8_quantize_single_tensor( + dout, ctx.dO_quantizer, do_format, + ) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -4374,6 +4401,8 @@ def backward(ctx, dout, *_args): 
window_size=ctx.window_size, deterministic=ctx.deterministic, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=ctx.qkv_scale_inv_format, + do_scale_inv_format=do_scale_inv_format, **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) @@ -4455,7 +4484,7 @@ def backward(ctx, dout, *_args): if ( ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() ) and ctx.is_input_fp8: - dq, dk, dv, _ = combine_and_quantize( + dq, dk, dv, _, _ = combine_and_quantize( ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer ) if ctx.fp8_recipe.delayed(): diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 61ea8666ae..c3de5d3f86 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -22,7 +22,6 @@ import transformer_engine_torch as tex import transformer_engine as te from transformer_engine.pytorch.cpp_extensions.fused_attn import ( - QKVFormat, QKVLayout, AttnBiasType, AttnMaskType, @@ -2315,69 +2314,10 @@ def print_quantizers( print(f"{label} >> {names[i]:14s}: {type_str}") -_FORMAT_STR_TO_ENUM = { - "bshd": QKVFormat["bshd"], - "sbhd": QKVFormat["sbhd"], -} -def permute_to_grouped_tensor(src_format, tensor): - """Permute tensor from src_format = {bshd, sbhd, thd} to des_format = {bhsd, htd} for MXFP8 quantization.""" - if src_format in ["bhsd", "htd"]: - return tensor, src_format - des_format = "bhsd" if src_format != "thd" else "htd" - tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor - fmt = _FORMAT_STR_TO_ENUM.get(src_format) - if fmt is not None and tensor.dim() == 4: - result = tex.permute_to_grouped_tensor_fwd(tensor, original_format=fmt) - return result[0], des_format - - dim_s_or_t = src_format.find("s") if "s" in src_format else src_format.find("t") - dim_others = [i for i in range(len(tensor.shape)) if i != dim_s_or_t] - new_dims = 
[*dim_others[:-1], dim_s_or_t, dim_others[-1]] - tensor = tensor.permute(*new_dims).contiguous() - return tensor, des_format - - -class PermuteToGroupedTensor(torch.autograd.Function): - """Permute tensors from {bshd, sbhd} to bhsd format. - - Accepts 1 tensor (key=None, value=None) or 3 tensors (Q, K, V). - """ - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - query: torch.Tensor, - key: torch.Tensor = None, - value: torch.Tensor = None, - original_format: str = "bshd", - ): - # pylint: disable=missing-function-docstring - fmt = _FORMAT_STR_TO_ENUM[original_format] - ctx.original_format = fmt - ctx.num_tensors = 1 if key is None else 3 - results = tex.permute_to_grouped_tensor_fwd(query, key, value, fmt) - if ctx.num_tensors == 1: - return results[0] - return tuple(results) - - @staticmethod - def backward(ctx, *grad_outputs): - # pylint: disable=missing-function-docstring - if ctx.num_tensors == 1: - result = tex.permute_to_grouped_tensor_bwd( - grad_outputs[0], original_format=ctx.original_format, - ) - return result[0], None, None, None - q, k, v = tex.permute_to_grouped_tensor_bwd( - grad_outputs[0], grad_outputs[1], grad_outputs[2], ctx.original_format, - ) - return q, k, v, None - - -def _mxfp8_pad_and_swizzle_scales(*fp8_tensors): +def mxfp8_pad_and_swizzle_scales(*fp8_tensors): """Pad and swizzle scales for MXFP8 tensors quantized with optimize_for_gemm=False. 
When quantizing with optimize_for_gemm=False, the scales are in their natural @@ -2389,31 +2329,176 @@ def _mxfp8_pad_and_swizzle_scales(*fp8_tensors): rs_list = [t._rowwise_scale_inv for t in fp8_tensors if t._rowwise_scale_inv is not None] cs_list = [t._columnwise_scale_inv for t in fp8_tensors if t._columnwise_scale_inv is not None] if rs_list: - rs_padded = tex.pad_last_dim(rs_list, 4) + rs_padded = tex.multi_tensor_pad_last_dim(rs_list, 4) idx = 0 for t in fp8_tensors: if t._rowwise_scale_inv is not None: t._rowwise_scale_inv = rs_padded[idx] idx += 1 if cs_list: - cs_padded = tex.pad_last_dim(cs_list, 128) + cs_padded = tex.multi_tensor_pad_last_dim(cs_list, 128) idx = 0 for t in fp8_tensors: if t._columnwise_scale_inv is not None: t._columnwise_scale_inv = cs_padded[idx] idx += 1 - for t in fp8_tensors: - tex.swizzle_scales_for_gemm_(t) + has_rs = len(rs_list) > 0 + has_cs = len(cs_list) > 0 + tensor_list = list(fp8_tensors) + if has_rs: + tex.multi_swizzle_scales_for_gemm_( + tensor_list, True, False, check_scale_inv_shapes=False + ) + if has_cs: + tex.multi_swizzle_scales_for_gemm_( + tensor_list, False, True, check_scale_inv_shapes=False + ) + for t in tensor_list: + t._with_gemm_swizzled_scales = True + + +def mxfp8_permute_scale_inv_to_bhsd(*tensors, src_format): + """Permute scale_inv of one or more MXFP8Tensors from *src_format* to BHSD, + then pad+swizzle for F8_128x4. + + Uses a single batched ``tex.permute_to_grouped_tensor_fwd`` call regardless + of how many tensors are passed (1 for dO, 3 for Q/K/V, etc.). 
+ """ + if src_format in ("bhsd", "htd"): + outs = [ + MXFP8Tensor( + shape=t.shape, dtype=t.dtype, + rowwise_data=t._rowwise_data, rowwise_scale_inv=t._rowwise_scale_inv, + columnwise_data=t._columnwise_data, columnwise_scale_inv=t._columnwise_scale_inv, + quantizer=t._quantizer, requires_grad=False, + fp8_dtype=t._fp8_dtype, with_gemm_swizzled_scales=t._with_gemm_swizzled_scales, + ) + for t in tensors + ] + mxfp8_pad_and_swizzle_scales(*outs) + return tuple(outs) + + rs_4d_list = [] + d_scales = [] + for t in tensors: + rs = t._rowwise_scale_inv + d_scale = rs.shape[-1] + shape_4d = list(t.shape[:-1]) + [d_scale] + rs_4d_list.append(rs.view(shape_4d).contiguous()) + d_scales.append(d_scale) + + permuted = tex.permute_to_grouped_tensor_fwd(*rs_4d_list, original_format=src_format) + + outs = [ + MXFP8Tensor( + shape=t.shape, dtype=t.dtype, + rowwise_data=t._rowwise_data, rowwise_scale_inv=p.view(-1, d), + columnwise_data=t._columnwise_data, columnwise_scale_inv=t._columnwise_scale_inv, + quantizer=t._quantizer, requires_grad=False, + fp8_dtype=t._fp8_dtype, with_gemm_swizzled_scales=t._with_gemm_swizzled_scales, + ) + for t, p, d in zip(tensors, permuted, d_scales) + ] + mxfp8_pad_and_swizzle_scales(*outs) + return tuple(outs) + + +def permute_to_grouped_tensor_pytorch(tensor, src_format): + """Permute to BHSD using pure PyTorch operations (.permute().contiguous()).""" + if src_format in ("bhsd", "htd"): + return tensor + dim_s = src_format.find("s") if "s" in src_format else src_format.find("t") + dim_others = [i for i in range(tensor.ndim) if i != dim_s] + new_dims = [*dim_others[:-1], dim_s, dim_others[-1]] + return tensor.permute(*new_dims).contiguous() + + +def mxfp8_quantize_single_tensor(tensor, quantizer, src_format): + """Quantize a single tensor with MXFP8 (no swizzle) and permute its scale_inv to BHSD. + + Combines the quantize-without-swizzle and scale_inv permute+pad+swizzle + steps into a single call. Used for dO in the backward pass. 
+ + Returns (fp8_tensor, scale_inv_format) where scale_inv_format is "bhsd". + """ + orig_optimize = quantizer.optimize_for_gemm + quantizer.optimize_for_gemm = False + fp8_tensor = quantizer(tensor) + quantizer.optimize_for_gemm = orig_optimize + (fp8_tensor,) = mxfp8_permute_scale_inv_to_bhsd(fp8_tensor, src_format=src_format) + return fp8_tensor, "bhsd" + def combine_and_quantize( - qkv_layout, q, k, v, qkv_quantizer, used_in_forward=True, used_in_backward=False + qkv_layout, + q, + k, + v, + qkv_quantizer, + used_in_forward=True, + used_in_backward=False, + keep_same_data_and_scale_inv_format=False, ): - """Combine q,k,v based on qkv_layout and quantize them together""" + """Combine Q, K, V tensors based on qkv_layout and quantize them together. + + For non-MXFP8 quantizers the tensors are concatenated (packed or flat), + quantized in one shot, then split back. + + For MXFP8, the behaviour depends on ``keep_same_data_and_scale_inv_format``: + + * **True** – permute the high-precision data to BHSD *before* quantising + so that both the FP8 data and its scale_inv end up in the same (BHSD) + layout. This is the simpler, pre-existing path used by + ``FP8EmulationFunc`` (unfused attention). + * **False** (default) – quantise in the original layout *without* the + internal GEMM-swizzle, then permute only the much smaller scale_inv + tensors to BHSD and apply padding + swizzle. This avoids an expensive + permutation of the large FP8 data tensors and is the optimised path + used by fused attention. + + Parameters + ---------- + qkv_layout : str + Layout descriptor, e.g. ``"sbhd_sbhd_sbhd"`` or ``"bshd_bshd_bshd"``. + q : torch.Tensor + Query tensor in the layout indicated by *qkv_layout*. + k : torch.Tensor + Key tensor. + v : torch.Tensor + Value tensor. + qkv_quantizer : Quantizer + Quantizer instance (e.g. ``MXFP8Quantizer``, ``Float8Quantizer``). + used_in_forward : bool, default = True + Hint for MXFP8 rowwise/columnwise allocation. 
+ used_in_backward : bool, default = False + Hint for MXFP8 rowwise/columnwise allocation. + keep_same_data_and_scale_inv_format : bool, default = False + MXFP8 only. When True, permute high-precision data to BHSD before + quantising (with swizzle) so that data and scale_inv share the same + layout. When False, data stays in the original layout and only the + scale_inv is permuted to BHSD afterwards. + + Returns + ------- + q_fp8 : QuantizedTensor + Quantized query. + k_fp8 : QuantizedTensor + Quantized key. + v_fp8 : QuantizedTensor + Quantized value. + qkv_layout : str + May change to ``"bhsd_bhsd_bhsd"`` when + ``keep_same_data_and_scale_inv_format=True`` and a permutation was + needed. + qkv_scale_inv_format : str or None + ``"bhsd"`` when the scale_inv layout differs from the data layout + (i.e. ``keep_same_data_and_scale_inv_format=False`` and a permutation + was needed); ``None`` otherwise. + """ if isinstance(qkv_quantizer, MXFP8Quantizer): qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) - original_shapes = [x.shape for x in [q, k, v]] _seq_dim = {"sbhd": 0, "bshd": 1, "bhsd": 2, "htd": 1} d_qk = q.shape[-1] d_v = v.shape[-1] @@ -2423,24 +2508,38 @@ def combine_and_quantize( "MXFP8 quantization requires s_q % 128 == 0, s_kv % 128 == 0, d_qk % 32 == 0, d_v % 32" f" == 0. Found {s_q=}, {s_kv=}, {d_qk=}, {d_v=}." 
) + + needs_permute = qkv_format not in ("bhsd", "htd") + qkv_scale_inv_format = None + + if keep_same_data_and_scale_inv_format and needs_permute: + # Permute f16 data to BHSD, then quantize with swizzle + # Prefer the fused custom kernel; fall back to pytorch if unsupported + if qkv_format in ("bshd", "sbhd") and q.dim() == 4: + q, k, v = tex.permute_to_grouped_tensor_fwd(q, k, v, original_format=q_format) + else: + q = permute_to_grouped_tensor_pytorch(q, q_format) + k = permute_to_grouped_tensor_pytorch(k, kv_format) + v = permute_to_grouped_tensor_pytorch(v, kv_format) + qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" + + original_shapes = [x.shape for x in [q, k, v]] q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] - # Quantize without internal swizzle. The caller is responsible for - # permuting scale_inv to the required format (e.g. BHSD for cuDNN) - # and applying pad + swizzle before passing to fused attention. - orig_optimize = qkv_quantizer.optimize_for_gemm - qkv_quantizer.optimize_for_gemm = False + if not keep_same_data_and_scale_inv_format: + orig_optimize = qkv_quantizer.optimize_for_gemm + qkv_quantizer.optimize_for_gemm = False if used_in_forward and used_in_backward: q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - if used_in_forward and not used_in_backward: + elif used_in_forward and not used_in_backward: qkv_quantizer.rowwise_usage = True qkv_quantizer.columnwise_usage = False q_fp8, k_fp8 = [qkv_quantizer(x) for x in [q, k]] qkv_quantizer.rowwise_usage = False qkv_quantizer.columnwise_usage = True v_fp8 = qkv_quantizer(v) - if (not used_in_forward) and used_in_backward: + elif (not used_in_forward) and used_in_backward: qkv_quantizer.rowwise_usage = True qkv_quantizer.columnwise_usage = True q_fp8, k_fp8 = [qkv_quantizer(x) for x in [q, k]] @@ -2448,12 +2547,19 @@ def combine_and_quantize( qkv_quantizer.columnwise_usage = False v_fp8 = qkv_quantizer(v) - qkv_quantizer.optimize_for_gemm = 
orig_optimize + if not keep_same_data_and_scale_inv_format: + qkv_quantizer.optimize_for_gemm = orig_optimize - # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] - return q_fp8, k_fp8, v_fp8, qkv_layout + if not keep_same_data_and_scale_inv_format: + # Permute only scale_inv to BHSD + pad + swizzle + q_fp8, k_fp8, v_fp8 = mxfp8_permute_scale_inv_to_bhsd( + q_fp8, k_fp8, v_fp8, src_format=q_format, + ) + qkv_scale_inv_format = "bhsd" + + return q_fp8, k_fp8, v_fp8, qkv_layout, qkv_scale_inv_format qkv_layout = qkv_layout.replace("paged_kv_", "") qkv_group = len(qkv_layout.split("_")) @@ -2498,7 +2604,7 @@ def combine_and_quantize( for x in [q_data, k_data, v_data] ] - return q_fp8, k_fp8, v_fp8, qkv_layout + return q_fp8, k_fp8, v_fp8, qkv_layout, None def combine_and_dequantize( diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index db960d1655..77ad57ed8f 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -12,13 +12,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat -__all__ = [ - "RotaryPositionEmbedding", - "apply_rotary_pos_emb", - "apply_fused_qkv_rotary_pos_emb", - "fused_apply_mla_rope_for_q", - "fused_apply_mla_rope_for_kv", -] +__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"] class RotaryPositionEmbedding(torch.nn.Module): @@ -261,230 +255,6 @@ def backward( return grad_input, None, None, None, None, None, None, None, None -class FusedMLARoPEQFunc(torch.autograd.Function): - """ - Autograd function for applying YARN RoPE to MLA's query using CUDA kernels. - - Reads interleaved elements from the last emb_dim of each head, applies YARN - rotation with split cos/sin (left and right halves), and writes de-interleaved. 
- The first qk_head_dim elements per head are copied unchanged. - - Supports both SBHD [s, b, h, d] and THD [t, h, d] input formats. - """ - - @staticmethod - def forward( - ctx, - q: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - qk_head_dim: int, - emb_dim: int, - cu_seqlens_q: Union[torch.Tensor, None] = None, - cp_rank: int = 0, - cp_size: int = 1, - ) -> torch.Tensor: - if cos.dtype != torch.float32: - cos = cos.float() - if sin.dtype != torch.float32: - sin = sin.float() - cos = cos.contiguous().view(-1, emb_dim) - sin = sin.contiguous().view(-1, emb_dim) - - output = tex.fused_mla_rope_q_forward( - q, cos, sin, cu_seqlens_q, qk_head_dim, emb_dim, cp_size, cp_rank - ) - ctx.save_for_backward(cos, sin) - ctx.qk_head_dim = qk_head_dim - ctx.emb_dim = emb_dim - ctx.cu_seqlens_q = cu_seqlens_q - ctx.cp_rank = cp_rank - ctx.cp_size = cp_size - return output - - @staticmethod - def backward( - ctx, grad_output: torch.Tensor - ) -> Tuple[Union[torch.Tensor, None], ...]: - cos, sin = ctx.saved_tensors - grad_input = tex.fused_mla_rope_q_backward( - grad_output, - cos, - sin, - ctx.cu_seqlens_q, - ctx.qk_head_dim, - ctx.emb_dim, - ctx.cp_size, - ctx.cp_rank, - ) - return grad_input, None, None, None, None, None, None, None - - -class FusedMLARoPEKVFunc(torch.autograd.Function): - """ - Autograd function for applying YARN RoPE to MLA's key and value using CUDA kernels. - - Splits the input KV tensor into key and value, applies YARN rotation to a - separate k_pos_emb (shared across heads), and concatenates the rotated - embedding to each head of the output key. - - Supports both SBHD [s, b, h, k_dim+v_dim] and THD [t, h, k_dim+v_dim] formats. 
- """ - - @staticmethod - def forward( - ctx, - kv: torch.Tensor, - k_pos_emb: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - emb_dim: int, - k_dim: int, - v_dim: int, - cu_seqlens_kv: Union[torch.Tensor, None] = None, - cp_rank: int = 0, - cp_size: int = 1, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if cos.dtype != torch.float32: - cos = cos.float() - if sin.dtype != torch.float32: - sin = sin.float() - cos = cos.contiguous().view(-1, emb_dim) - sin = sin.contiguous().view(-1, emb_dim) - - o_key, o_value = tex.fused_mla_rope_kv_forward( - kv, k_pos_emb, cos, sin, cu_seqlens_kv, emb_dim, k_dim, v_dim, cp_size, cp_rank - ) - ctx.save_for_backward(cos, sin) - ctx.emb_dim = emb_dim - ctx.k_dim = k_dim - ctx.v_dim = v_dim - ctx.cu_seqlens_kv = cu_seqlens_kv - ctx.cp_rank = cp_rank - ctx.cp_size = cp_size - return o_key, o_value - - @staticmethod - def backward( - ctx, dk: torch.Tensor, dv: torch.Tensor - ) -> Tuple[Union[torch.Tensor, None], ...]: - cos, sin = ctx.saved_tensors - d_kv, d_emb = tex.fused_mla_rope_kv_backward( - dk, - dv, - cos, - sin, - ctx.cu_seqlens_kv, - ctx.emb_dim, - ctx.k_dim, - ctx.v_dim, - ctx.cp_size, - ctx.cp_rank, - ) - return d_kv, d_emb, None, None, None, None, None, None, None, None - - -def fused_apply_mla_rope_for_q( - t: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - qk_head_dim: int, - emb_dim: int, - cu_seqlens_q: Optional[torch.Tensor] = None, - cp_rank: int = 0, - cp_size: int = 1, -) -> torch.Tensor: - """ - Apply YARN RoPE to MLA's query using fused CUDA kernels. - - Along the last dimension of each head, the first qk_head_dim elements are - unchanged and the last emb_dim elements receive YARN rotation. The input is - read interleaved and written de-interleaved. - - Parameters - ---------- - t : torch.Tensor - Query tensor of shape [s, b, h, qk_head_dim + emb_dim] (SBHD) - or [total_t, h, qk_head_dim + emb_dim] (THD). 
- cos : torch.Tensor - Pre-computed cosine tensor [max_s, 1, 1, emb_dim] or [max_s, emb_dim]. - sin : torch.Tensor - Pre-computed sine tensor [max_s, 1, 1, emb_dim] or [max_s, emb_dim]. - qk_head_dim : int - Dimension of the non-RoPE prefix per head. - emb_dim : int - RoPE embedding dimension. - cu_seqlens_q : torch.Tensor, optional - Cumulative sequence lengths [num_seqs + 1] for THD format. - cp_rank : int - Context parallel rank. - cp_size : int - Context parallel world size. - - Returns - ------- - torch.Tensor - Output tensor with same shape as input, YARN RoPE applied. - """ - return FusedMLARoPEQFunc.apply( - t, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, cp_rank, cp_size - ) - - -def fused_apply_mla_rope_for_kv( - kv: torch.Tensor, - k_pos_emb: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - emb_dim: int, - k_dim: int, - v_dim: int, - cu_seqlens_kv: Optional[torch.Tensor] = None, - cp_rank: int = 0, - cp_size: int = 1, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply YARN RoPE to MLA's key and value using fused CUDA kernels. - - Splits KV into key and value, applies YARN rotation to k_pos_emb (shared - across heads), and concatenates the rotated embedding to each head of the - output key. - - Parameters - ---------- - kv : torch.Tensor - Combined KV tensor [s, b, h, k_dim + v_dim] (SBHD) - or [total_t, h, k_dim + v_dim] (THD). - k_pos_emb : torch.Tensor - Positional embedding [s, b, 1, emb_dim] or [total_t, 1, emb_dim]. - cos : torch.Tensor - Pre-computed cosine tensor [max_s, 1, 1, emb_dim] or [max_s, emb_dim]. - sin : torch.Tensor - Pre-computed sine tensor [max_s, 1, 1, emb_dim] or [max_s, emb_dim]. - emb_dim : int - RoPE embedding dimension. - k_dim : int - Key dimension per head. - v_dim : int - Value dimension per head. - cu_seqlens_kv : torch.Tensor, optional - Cumulative sequence lengths [num_seqs + 1] for THD format. - cp_rank : int - Context parallel rank. - cp_size : int - Context parallel world size. 
- - Returns - ------- - tuple[torch.Tensor, torch.Tensor] - (o_key, o_value) where o_key has shape [..., h, k_dim + emb_dim] - and o_value has shape [..., h, v_dim]. - """ - return FusedMLARoPEKVFunc.apply( - kv, k_pos_emb, cos, sin, emb_dim, k_dim, v_dim, cu_seqlens_kv, cp_rank, cp_size - ) - - def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: """Change sign so the last dimension becomes [-odd, +even] diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 89ff372c40..6c78325587 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -313,7 +313,7 @@ def fused_attn_fwd( # execute kernel - _qkv_scale_inv_fmt = ( + _qkv_scale_inv_format = ( QKVFormat[qkv_scale_inv_format] if qkv_scale_inv_format is not None else NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET @@ -350,7 +350,7 @@ def fused_attn_fwd( rng_elts_per_thread, return_max_logit, cuda_graph, - _qkv_scale_inv_fmt, + _qkv_scale_inv_format, ) if return_max_logit: @@ -566,12 +566,12 @@ def fused_attn_bwd( f" for backend={fused_attention_backend}." 
) - _qkv_scale_inv_fmt = ( + _qkv_scale_inv_format = ( QKVFormat[qkv_scale_inv_format] if qkv_scale_inv_format is not None else NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET ) - _do_scale_inv_fmt = ( + _do_scale_inv_format = ( QKVFormat[do_scale_inv_format] if do_scale_inv_format is not None else NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET @@ -607,8 +607,8 @@ def fused_attn_bwd( dp_quantizer, dqkv_quantizer, cuda_graph, - _qkv_scale_inv_fmt, - _do_scale_inv_fmt, + _qkv_scale_inv_format, + _do_scale_inv_format, ) return output_tensors diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index ba765d1922..c591957aa5 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -117,12 +117,12 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); std::vector permute_to_grouped_tensor_fwd( at::Tensor query, std::optional key, std::optional value, - NVTE_QKV_Format original_format); + const std::string &original_format); std::vector permute_to_grouped_tensor_bwd( at::Tensor query_grad, std::optional key_grad, - std::optional value_grad, NVTE_QKV_Format original_format); + std::optional value_grad, const std::string &original_format); -std::vector pad_last_dim(std::vector inputs, int64_t alignment); +std::vector multi_tensor_pad_last_dim(std::vector inputs, int64_t alignment); at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len); at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); @@ -462,28 +462,6 @@ at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tenso const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank); -at::Tensor fused_mla_rope_q_forward(const at::Tensor &q_input, const at::Tensor &cos, - const at::Tensor &sin, - const std::optional cu_seqlens, - const int qk_head_dim, const int emb_dim, const int cp_size, - const 
int cp_rank); - -at::Tensor fused_mla_rope_q_backward(const at::Tensor &grad_output, const at::Tensor &cos, - const at::Tensor &sin, - const std::optional cu_seqlens, - const int qk_head_dim, const int emb_dim, const int cp_size, - const int cp_rank); - -std::tuple fused_mla_rope_kv_forward( - const at::Tensor &kv_input, const at::Tensor &k_pos_emb, const at::Tensor &cos, - const at::Tensor &sin, const std::optional cu_seqlens, const int emb_dim, - const int k_dim, const int v_dim, const int cp_size, const int cp_rank); - -std::tuple fused_mla_rope_kv_backward( - const at::Tensor &dk, const at::Tensor &dv, const at::Tensor &cos, const at::Tensor &sin, - const std::optional cu_seqlens, const int emb_dim, const int k_dim, - const int v_dim, const int cp_size, const int cp_rank); - /*************************************************************************************************** * Miscellaneous **************************************************************************************************/ @@ -605,6 +583,10 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, void inplace_swizzle_scale_for_gemm(py::handle &tensor); +void inplace_multi_swizzle_scales_for_gemm(std::vector &tensors, bool rowwise_usage, + bool columnwise_usage, + bool check_scale_inv_shapes = true); + void grouped_swizzle_for_gemm(py::handle &tensor, bool rowwise, bool columnwise); /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 7e6bdef40d..4392fa4b43 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -298,216 +298,4 @@ at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tenso return qkv_grad_input; } -at::Tensor fused_mla_rope_q_forward(const at::Tensor &q_input, const at::Tensor &cos, - 
const at::Tensor &sin, - const std::optional cu_seqlens, - const int qk_head_dim, const int emb_dim, const int cp_size, - const int cp_rank) { - TORCH_CHECK(cos.scalar_type() == at::ScalarType::Float, "cos must be float32"); - TORCH_CHECK(sin.scalar_type() == at::ScalarType::Float, "sin must be float32"); - TORCH_CHECK(cos.is_contiguous(), "cos must be contiguous"); - TORCH_CHECK(sin.is_contiguous(), "sin must be contiguous"); - - int max_seqlen = 0, batch_size = 0, nheads = 0, headdim = 0, total_seqlen = 0, s = 0, b = 0; - at::Tensor q_flat; - if (cu_seqlens.has_value()) { - TORCH_CHECK(q_input.dim() == 3, "expected 3D tensor for THD format"); - total_seqlen = q_input.size(0); - nheads = q_input.size(1); - headdim = q_input.size(2); - b = cu_seqlens.value().size(0) - 1; - s = 0; - q_flat = q_input.contiguous(); - } else { - TORCH_CHECK(q_input.dim() == 4, "expected 4D tensor for SBHD format"); - max_seqlen = q_input.size(0); - batch_size = q_input.size(1); - nheads = q_input.size(2); - headdim = q_input.size(3); - q_flat = q_input.contiguous().view({max_seqlen * batch_size, nheads, headdim}); - total_seqlen = q_flat.size(0); - s = max_seqlen; - b = batch_size; - } - TORCH_CHECK(headdim == qk_head_dim + emb_dim, "headdim must equal qk_head_dim + emb_dim"); - - auto q_out = at::empty_like(q_flat); - auto q_in_cu = makeTransformerEngineTensor(q_flat); - auto cos_cu = makeTransformerEngineTensor(cos); - auto sin_cu = makeTransformerEngineTensor(sin); - auto q_out_cu = makeTransformerEngineTensor(q_out); - auto cu_seqlens_cu = TensorWrapper(); - if (cu_seqlens.has_value()) { - cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); - } - - nvte_fused_mla_rope_q_forward(q_in_cu.data(), cos_cu.data(), sin_cu.data(), q_out_cu.data(), - cu_seqlens_cu.data(), qk_head_dim, emb_dim, nheads, headdim, - total_seqlen, s, b, cp_size, cp_rank, - at::cuda::getCurrentCUDAStream()); - - if (!cu_seqlens.has_value()) { - q_out = q_out.view({max_seqlen, batch_size, nheads, 
headdim}); - } - return q_out; -} - -at::Tensor fused_mla_rope_q_backward(const at::Tensor &grad_output, const at::Tensor &cos, - const at::Tensor &sin, - const std::optional cu_seqlens, - const int qk_head_dim, const int emb_dim, const int cp_size, - const int cp_rank) { - int max_seqlen = 0, batch_size = 0, nheads = 0, headdim = 0, total_seqlen = 0, s = 0, b = 0; - at::Tensor grad_flat; - if (cu_seqlens.has_value()) { - total_seqlen = grad_output.size(0); - nheads = grad_output.size(1); - headdim = grad_output.size(2); - b = cu_seqlens.value().size(0) - 1; - s = 0; - grad_flat = grad_output.contiguous(); - } else { - max_seqlen = grad_output.size(0); - batch_size = grad_output.size(1); - nheads = grad_output.size(2); - headdim = grad_output.size(3); - grad_flat = grad_output.contiguous().view({max_seqlen * batch_size, nheads, headdim}); - total_seqlen = grad_flat.size(0); - s = max_seqlen; - b = batch_size; - } - - auto grad_in = at::empty_like(grad_flat); - auto grad_out_cu = makeTransformerEngineTensor(grad_flat); - auto cos_cu = makeTransformerEngineTensor(cos); - auto sin_cu = makeTransformerEngineTensor(sin); - auto grad_in_cu = makeTransformerEngineTensor(grad_in); - auto cu_seqlens_cu = TensorWrapper(); - if (cu_seqlens.has_value()) { - cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); - } - - nvte_fused_mla_rope_q_backward(grad_out_cu.data(), cos_cu.data(), sin_cu.data(), - grad_in_cu.data(), cu_seqlens_cu.data(), qk_head_dim, emb_dim, - nheads, headdim, total_seqlen, s, b, cp_size, cp_rank, - at::cuda::getCurrentCUDAStream()); - - if (!cu_seqlens.has_value()) { - grad_in = grad_in.view({max_seqlen, batch_size, nheads, headdim}); - } - return grad_in; -} - -std::tuple fused_mla_rope_kv_forward( - const at::Tensor &kv_input, const at::Tensor &k_pos_emb, const at::Tensor &cos, - const at::Tensor &sin, const std::optional cu_seqlens, const int emb_dim, - const int k_dim, const int v_dim, const int cp_size, const int cp_rank) { - 
TORCH_CHECK(cos.scalar_type() == at::ScalarType::Float, "cos must be float32"); - TORCH_CHECK(sin.scalar_type() == at::ScalarType::Float, "sin must be float32"); - TORCH_CHECK(cos.is_contiguous(), "cos must be contiguous"); - TORCH_CHECK(sin.is_contiguous(), "sin must be contiguous"); - TORCH_CHECK(kv_input.size(-1) == k_dim + v_dim, "last dim of kv must be k_dim + v_dim"); - - int max_seqlen = 0, batch_size = 0, nheads = 0, total_seqlen = 0, s = 0, b_val = 0; - at::Tensor kv_flat, emb_flat; - if (cu_seqlens.has_value()) { - TORCH_CHECK(kv_input.dim() == 3, "expected 3D tensor for THD format"); - total_seqlen = kv_input.size(0); - nheads = kv_input.size(1); - b_val = cu_seqlens.value().size(0) - 1; - s = 0; - kv_flat = kv_input.contiguous(); - emb_flat = k_pos_emb.contiguous().view({total_seqlen, emb_dim}); - } else { - TORCH_CHECK(kv_input.dim() == 4, "expected 4D tensor for SBHD format"); - max_seqlen = kv_input.size(0); - batch_size = kv_input.size(1); - nheads = kv_input.size(2); - kv_flat = kv_input.contiguous().view({max_seqlen * batch_size, nheads, k_dim + v_dim}); - emb_flat = k_pos_emb.contiguous().view({max_seqlen * batch_size, emb_dim}); - total_seqlen = kv_flat.size(0); - s = max_seqlen; - b_val = batch_size; - } - - auto opts = at::TensorOptions().dtype(kv_input.scalar_type()).device(kv_input.device()); - auto o_key = at::empty({total_seqlen, nheads, k_dim + emb_dim}, opts); - auto o_value = at::empty({total_seqlen, nheads, v_dim}, opts); - - auto kv_cu = makeTransformerEngineTensor(kv_flat); - auto emb_cu = makeTransformerEngineTensor(emb_flat); - auto cos_cu = makeTransformerEngineTensor(cos); - auto sin_cu = makeTransformerEngineTensor(sin); - auto okey_cu = makeTransformerEngineTensor(o_key); - auto oval_cu = makeTransformerEngineTensor(o_value); - auto cu_seqlens_cu = TensorWrapper(); - if (cu_seqlens.has_value()) { - cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); - } - - nvte_fused_mla_rope_kv_forward(kv_cu.data(), 
emb_cu.data(), cos_cu.data(), sin_cu.data(), - okey_cu.data(), oval_cu.data(), cu_seqlens_cu.data(), emb_dim, - k_dim, v_dim, nheads, total_seqlen, s, b_val, cp_size, cp_rank, - at::cuda::getCurrentCUDAStream()); - - if (!cu_seqlens.has_value()) { - o_key = o_key.view({max_seqlen, batch_size, nheads, k_dim + emb_dim}); - o_value = o_value.view({max_seqlen, batch_size, nheads, v_dim}); - } - return std::make_tuple(o_key, o_value); -} - -std::tuple fused_mla_rope_kv_backward( - const at::Tensor &dk, const at::Tensor &dv, const at::Tensor &cos, const at::Tensor &sin, - const std::optional cu_seqlens, const int emb_dim, const int k_dim, - const int v_dim, const int cp_size, const int cp_rank) { - int max_seqlen = 0, batch_size = 0, nheads = 0, total_seqlen = 0, s = 0, b_val = 0; - at::Tensor dk_flat, dv_flat; - if (cu_seqlens.has_value()) { - total_seqlen = dk.size(0); - nheads = dk.size(1); - b_val = cu_seqlens.value().size(0) - 1; - s = 0; - dk_flat = dk.contiguous(); - dv_flat = dv.contiguous(); - } else { - max_seqlen = dk.size(0); - batch_size = dk.size(1); - nheads = dk.size(2); - dk_flat = dk.contiguous().view({max_seqlen * batch_size, nheads, k_dim + emb_dim}); - dv_flat = dv.contiguous().view({max_seqlen * batch_size, nheads, v_dim}); - total_seqlen = dk_flat.size(0); - s = max_seqlen; - b_val = batch_size; - } - - auto opts = at::TensorOptions().dtype(dk.scalar_type()).device(dk.device()); - auto d_kv = at::empty({total_seqlen, nheads, k_dim + v_dim}, opts); - auto d_emb = at::empty({total_seqlen, emb_dim}, opts); - - auto dk_cu = makeTransformerEngineTensor(dk_flat); - auto dv_cu = makeTransformerEngineTensor(dv_flat); - auto cos_cu = makeTransformerEngineTensor(cos); - auto sin_cu = makeTransformerEngineTensor(sin); - auto dkv_cu = makeTransformerEngineTensor(d_kv); - auto demb_cu = makeTransformerEngineTensor(d_emb); - auto cu_seqlens_cu = TensorWrapper(); - if (cu_seqlens.has_value()) { - cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); - 
} - - nvte_fused_mla_rope_kv_backward(dk_cu.data(), dv_cu.data(), cos_cu.data(), sin_cu.data(), - dkv_cu.data(), demb_cu.data(), cu_seqlens_cu.data(), emb_dim, - k_dim, v_dim, nheads, total_seqlen, s, b_val, cp_size, cp_rank, - at::cuda::getCurrentCUDAStream()); - - if (!cu_seqlens.has_value()) { - d_kv = d_kv.view({max_seqlen, batch_size, nheads, k_dim + v_dim}); - d_emb = d_emb.view({max_seqlen, batch_size, 1, emb_dim}); - } else { - d_emb = d_emb.view({total_seqlen, 1, emb_dim}); - } - return std::make_tuple(d_kv, d_emb); -} - } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 733e1e1602..529fca45d6 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -655,20 +655,20 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { std::vector permute_to_grouped_tensor_fwd( at::Tensor query, std::optional key, std::optional value, - NVTE_QKV_Format original_format) { - NVTE_CHECK(original_format == NVTE_SBHD || original_format == NVTE_BSHD, - "permute_to_grouped_tensor_fwd: original_format must be NVTE_SBHD or NVTE_BSHD."); + const std::string &original_format) { + NVTE_CHECK(original_format == "sbhd" || original_format == "bshd", + "Unsupported original_format \"", original_format, "\"; expected \"sbhd\" or \"bshd\"."); + const auto original_format_enum = (original_format == "sbhd") ? NVTE_SBHD : NVTE_BSHD; NVTE_CHECK(query.is_cuda() && query.is_contiguous() && query.dim() == 4); NVTE_CHECK(query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16 || - query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Byte); const bool has_kv = key.has_value() && value.has_value(); const size_t num_tensors = has_kv ? 
3 : 1; int64_t B, S_q, H_q, D_qk; - if (original_format == NVTE_SBHD) { + if (original_format_enum == NVTE_SBHD) { S_q = query.size(0); B = query.size(1); H_q = query.size(2); D_qk = query.size(3); } else { B = query.size(0); S_q = query.size(1); H_q = query.size(2); D_qk = query.size(3); @@ -681,7 +681,7 @@ std::vector permute_to_grouped_tensor_fwd( auto te_qo = makeTransformerEngineTensor(q_out); nvte_permute_to_grouped_tensor_fwd(te_q.data(), te_q.data(), te_q.data(), te_qo.data(), te_qo.data(), te_qo.data(), - original_format, 1, at::cuda::getCurrentCUDAStream()); + original_format_enum, 1, at::cuda::getCurrentCUDAStream()); return {q_out}; } @@ -692,7 +692,7 @@ std::vector permute_to_grouped_tensor_fwd( NVTE_CHECK(k.scalar_type() == query.scalar_type() && v.scalar_type() == query.scalar_type()); int64_t S_kv, H_kv, D_v; - if (original_format == NVTE_SBHD) { + if (original_format_enum == NVTE_SBHD) { S_kv = k.size(0); H_kv = k.size(2); D_v = v.size(3); } else { S_kv = k.size(1); H_kv = k.size(2); D_v = v.size(3); @@ -715,20 +715,20 @@ std::vector permute_to_grouped_tensor_fwd( nvte_permute_to_grouped_tensor_fwd(te_q.data(), te_k.data(), te_v.data(), te_qo.data(), te_ko.data(), te_vo.data(), - original_format, 3, at::cuda::getCurrentCUDAStream()); + original_format_enum, 3, at::cuda::getCurrentCUDAStream()); return {q_out, k_out, v_out}; } std::vector permute_to_grouped_tensor_bwd( at::Tensor query_grad, std::optional key_grad, - std::optional value_grad, NVTE_QKV_Format original_format) { - NVTE_CHECK(original_format == NVTE_SBHD || original_format == NVTE_BSHD, - "permute_to_grouped_tensor_bwd: original_format must be NVTE_SBHD or NVTE_BSHD."); + std::optional value_grad, const std::string &original_format) { + NVTE_CHECK(original_format == "sbhd" || original_format == "bshd", + "Unsupported original_format \"", original_format, "\"; expected \"sbhd\" or \"bshd\"."); + const auto original_format_enum = (original_format == "sbhd") ? 
NVTE_SBHD : NVTE_BSHD; NVTE_CHECK(query_grad.is_cuda() && query_grad.is_contiguous() && query_grad.dim() == 4); NVTE_CHECK(query_grad.scalar_type() == at::ScalarType::Half || query_grad.scalar_type() == at::ScalarType::BFloat16 || - query_grad.scalar_type() == at::ScalarType::Float || query_grad.scalar_type() == at::ScalarType::Byte); const bool has_kv = key_grad.has_value() && value_grad.has_value(); @@ -740,7 +740,7 @@ std::vector permute_to_grouped_tensor_bwd( if (!has_kv) { at::Tensor q; - if (original_format == NVTE_SBHD) { + if (original_format_enum == NVTE_SBHD) { q = at::empty({S_q, B, H_q, D_qk}, query_grad.options()); } else { q = at::empty({B, S_q, H_q, D_qk}, query_grad.options()); @@ -749,7 +749,7 @@ std::vector permute_to_grouped_tensor_bwd( auto te_q = makeTransformerEngineTensor(q); nvte_permute_to_grouped_tensor_bwd(te_gq.data(), te_gq.data(), te_gq.data(), te_q.data(), te_q.data(), te_q.data(), - original_format, 1, at::cuda::getCurrentCUDAStream()); + original_format_enum, 1, at::cuda::getCurrentCUDAStream()); return {q}; } @@ -770,7 +770,7 @@ std::vector permute_to_grouped_tensor_bwd( at::Tensor qkv_grad_flat = at::empty({numel_q + numel_k + numel_v}, query_grad.options()); at::Tensor query, key, value; - if (original_format == NVTE_SBHD) { + if (original_format_enum == NVTE_SBHD) { query = qkv_grad_flat.narrow(0, 0, numel_q).view({S_q, B, H_q, D_qk}); key = qkv_grad_flat.narrow(0, numel_q, numel_k).view({S_kv, B, H_kv, D_qk}); value = qkv_grad_flat.narrow(0, numel_q + numel_k, numel_v).view({S_kv, B, H_kv, D_v}); @@ -789,7 +789,7 @@ std::vector permute_to_grouped_tensor_bwd( nvte_permute_to_grouped_tensor_bwd(te_gq.data(), te_gk.data(), te_gv.data(), te_q.data(), te_k.data(), te_v.data(), - original_format, 3, at::cuda::getCurrentCUDAStream()); + original_format_enum, 3, at::cuda::getCurrentCUDAStream()); return {query, key, value}; } @@ -799,10 +799,10 @@ std::vector permute_to_grouped_tensor_bwd( * All tensors share the same alignment; 
launches a single fused kernel. **************************************************************************************************/ -std::vector pad_last_dim(std::vector inputs, int64_t alignment) { +std::vector multi_tensor_pad_last_dim(std::vector inputs, int64_t alignment) { const auto align = static_cast(alignment); - NVTE_CHECK(align > 0, "pad_last_dim: alignment must be > 0."); - NVTE_CHECK(!inputs.empty(), "pad_last_dim: inputs must not be empty."); + NVTE_CHECK(align > 0, "multi_tensor_pad_last_dim: alignment must be > 0."); + NVTE_CHECK(!inputs.empty(), "multi_tensor_pad_last_dim: inputs must not be empty."); auto stream = at::cuda::getCurrentCUDAStream(); std::vector outputs; @@ -813,10 +813,10 @@ std::vector pad_last_dim(std::vector inputs, int64_t ali for (size_t i = 0; i < inputs.size(); ++i) { auto &input = inputs[i]; - NVTE_CHECK(input.dim() == 2, "pad_last_dim: expected 2D input at index ", i, ", got ", + NVTE_CHECK(input.dim() == 2, "multi_tensor_pad_last_dim: expected 2D input at index ", i, ", got ", input.dim(), "D."); - NVTE_CHECK(input.is_cuda(), "pad_last_dim: input must be a CUDA tensor at index ", i, "."); - NVTE_CHECK(input.is_contiguous(), "pad_last_dim: input must be contiguous at index ", i, "."); + NVTE_CHECK(input.is_cuda(), "multi_tensor_pad_last_dim: input must be a CUDA tensor at index ", i, "."); + NVTE_CHECK(input.is_contiguous(), "multi_tensor_pad_last_dim: input must be contiguous at index ", i, "."); const int64_t rows = input.size(0); const int64_t in_cols = input.size(1); @@ -858,7 +858,7 @@ std::vector pad_last_dim(std::vector inputs, int64_t ali nvte_outputs[i] = te_out_wrappers[i].data(); } - nvte_multi_pad_last_dim(nvte_inputs.data(), nvte_outputs.data(), te_in_wrappers.size(), stream); + nvte_multi_tensor_pad_last_dim(nvte_inputs.data(), nvte_outputs.data(), te_in_wrappers.size(), stream); return outputs; } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp 
b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d0d54c9283..de2b75b2b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -389,6 +389,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Multi-tensor unpadding", py::call_guard()); m.def("swizzle_scales_for_gemm_", &transformer_engine::pytorch::inplace_swizzle_scale_for_gemm, "Convert tensor block scales into GEMM swizzled format"); + m.def("multi_swizzle_scales_for_gemm_", + &transformer_engine::pytorch::inplace_multi_swizzle_scales_for_gemm, + "Batch-swizzle block scales for multiple tensors in a single kernel launch", + py::arg("tensors"), py::arg("rowwise_usage"), py::arg("columnwise_usage"), + py::arg("check_scale_inv_shapes") = true); m.def("grouped_swizzle_for_gemm", &transformer_engine::pytorch::grouped_swizzle_for_gemm, "In-place swizzle of grouped tensor scales for GEMM", py::arg("tensor"), py::arg("rowwise"), py::arg("columnwise")); @@ -403,15 +408,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::permute_to_grouped_tensor_fwd, "Permute tensors from BSHD/SBHD to BHSD.", py::arg("query"), py::arg("key") = py::none(), py::arg("value") = py::none(), - py::arg("original_format") = static_cast(NVTE_BSHD), + py::arg("original_format") = std::string("bshd"), py::call_guard()); m.def("permute_to_grouped_tensor_bwd", &transformer_engine::pytorch::permute_to_grouped_tensor_bwd, "Permute tensors back to original format.", py::arg("query_grad"), py::arg("key_grad") = py::none(), py::arg("value_grad") = py::none(), - py::arg("original_format") = static_cast(NVTE_BSHD), + py::arg("original_format") = std::string("bshd"), py::call_guard()); - m.def("pad_last_dim", &transformer_engine::pytorch::pad_last_dim, + m.def("multi_tensor_pad_last_dim", &transformer_engine::pytorch::multi_tensor_pad_last_dim, "Pad last dimension of 2D tensors to a common alignment.", py::arg("inputs"), py::arg("alignment"), 
py::call_guard()); m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd, @@ -435,15 +440,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward, "Fused Apply QKV RoPE BWD", py::call_guard()); - // fused MLA rope - m.def("fused_mla_rope_q_forward", &transformer_engine::pytorch::fused_mla_rope_q_forward, - "Fused MLA RoPE Q FWD", py::call_guard()); - m.def("fused_mla_rope_q_backward", &transformer_engine::pytorch::fused_mla_rope_q_backward, - "Fused MLA RoPE Q BWD", py::call_guard()); - m.def("fused_mla_rope_kv_forward", &transformer_engine::pytorch::fused_mla_rope_kv_forward, - "Fused MLA RoPE KV FWD", py::call_guard()); - m.def("fused_mla_rope_kv_backward", &transformer_engine::pytorch::fused_mla_rope_kv_backward, - "Fused MLA RoPE KV BWD", py::call_guard()); // fused router m.def("fused_topk_with_score_function_fwd", diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index a6b4e7569d..bd4ad086a3 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -443,6 +443,105 @@ void grouped_swizzle_for_gemm(py::handle &tensor, bool rowwise, bool columnwise) } } +void inplace_multi_swizzle_scales_for_gemm(std::vector &tensors, bool rowwise_usage, + bool columnwise_usage, + bool check_scale_inv_shapes) { + NVTE_CHECK(rowwise_usage != columnwise_usage, + "Expect exactly one of rowwise_usage and columnwise_usage."); + if (tensors.empty()) { + return; + } + + // Convert Python tensors to C++ TensorWrappers + std::vector te_wrappers; + te_wrappers.reserve(tensors.size()); + for (auto &t : tensors) { + te_wrappers.push_back(makeTransformerEngineTensor(t, py::none())); + } + + // Check scaling mode and filter tensors needing swizzle + const auto scaling_mode = te_wrappers.front().scaling_mode(); + switch (scaling_mode) { + case 
NVTE_MXFP8_1D_SCALING: + case NVTE_NVFP4_1D_SCALING: + break; + default: + return; + } + + struct SwizzleItem { + size_t py_idx; + at::Tensor output_pyt; + }; + std::vector items; + std::vector inputs_nvte, outputs_nvte; + + for (size_t i = 0; i < te_wrappers.size(); ++i) { + auto &tw = te_wrappers[i]; + if (tw.get_with_gemm_swizzled_scales()) { + continue; + } + const auto scales_nvte = + rowwise_usage ? tw.get_rowwise_scale_inv() : tw.get_columnwise_scale_inv(); + if (scales_nvte.data_ptr == nullptr || + (scales_nvte.shape.ndim == 1 && scales_nvte.shape.data[0] == 0)) { + continue; + } + const auto data_nvte = rowwise_usage ? tw.get_rowwise_data() : tw.get_columnwise_data(); + const auto data_dtype = static_cast(data_nvte.dtype); + const auto scales_dtype = static_cast(scales_nvte.dtype); + + // Allocate a separate output tensor that properly owns its memory + auto output_pyt = allocateSpace(scales_nvte.shape, scales_dtype, false); + + inputs_nvte.emplace_back(scaling_mode); + outputs_nvte.emplace_back(scaling_mode); + auto &in_nvte = inputs_nvte.back(); + auto &out_nvte = outputs_nvte.back(); + if (rowwise_usage) { + in_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); + in_nvte.set_rowwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape); + out_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); + out_nvte.set_rowwise_scale_inv(getDataPtr(output_pyt), scales_dtype, scales_nvte.shape); + } else { + in_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); + in_nvte.set_columnwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape); + out_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); + out_nvte.set_columnwise_scale_inv(getDataPtr(output_pyt), scales_dtype, scales_nvte.shape); + } + out_nvte.set_with_gemm_swizzled_scales(true); + items.push_back({i, std::move(output_pyt)}); + } + + if (items.empty()) { + return; + } + + // Pack raw NVTETensors and launch single batched kernel + std::vector 
inputs_raw, outputs_raw; + inputs_raw.reserve(inputs_nvte.size()); + outputs_raw.reserve(outputs_nvte.size()); + for (auto &t : inputs_nvte) inputs_raw.push_back(t.data()); + for (auto &t : outputs_nvte) outputs_raw.push_back(t.data()); + + auto stream = at::cuda::getCurrentCUDAStream(); + NVTE_SCOPED_GIL_RELEASE({ + nvte_multi_tensor_swizzle_scaling_factors(inputs_raw.data(), outputs_raw.data(), + inputs_raw.size(), stream, + check_scale_inv_shapes); + }); + + // Update Python tensors with the owning output tensors + for (auto &item : items) { + auto &t = tensors[item.py_idx]; + if (rowwise_usage) { + t.attr("_rowwise_scale_inv") = py::cast(item.output_pyt); + } else { + t.attr("_columnwise_scale_inv") = py::cast(item.output_pyt); + } + } +} + void inplace_swizzle_scale_for_gemm(py::handle &tensor) { // Convert Python tensor to C++ tensor auto tensor_nvte = makeTransformerEngineTensor(tensor, py::none()); From e44029101d23786bedaa27700ea8c4fc76aa9c33 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 10 Apr 2026 22:02:39 -0700 Subject: [PATCH 171/172] fix last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/utils.py | 118 ++++++++++++++---- 1 file changed, 91 insertions(+), 27 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index c3de5d3f86..c0a3ce7dda 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -46,7 +46,7 @@ from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.pytorch.quantization import get_fp8_te_dtype -from transformer_engine.pytorch.constants import TE_DType +from 
transformer_engine.pytorch.constants import TE_DType, MXFP8_BLOCK_SCALING_SIZE from transformer_engine.pytorch.utils import ( @@ -2378,28 +2378,48 @@ def mxfp8_permute_scale_inv_to_bhsd(*tensors, src_format): mxfp8_pad_and_swizzle_scales(*outs) return tuple(outs) + # Permute rowwise scale_inv from src_format to BHSD rs_4d_list = [] - d_scales = [] + rs_d_scales = [] for t in tensors: rs = t._rowwise_scale_inv - d_scale = rs.shape[-1] - shape_4d = list(t.shape[:-1]) + [d_scale] - rs_4d_list.append(rs.view(shape_4d).contiguous()) - d_scales.append(d_scale) - - permuted = tex.permute_to_grouped_tensor_fwd(*rs_4d_list, original_format=src_format) - - outs = [ - MXFP8Tensor( - shape=t.shape, dtype=t.dtype, - rowwise_data=t._rowwise_data, rowwise_scale_inv=p.view(-1, d), - columnwise_data=t._columnwise_data, columnwise_scale_inv=t._columnwise_scale_inv, - quantizer=t._quantizer, requires_grad=False, - fp8_dtype=t._fp8_dtype, with_gemm_swizzled_scales=t._with_gemm_swizzled_scales, + if rs is not None: + rs_4d_list.append(rs) + rs_d_scales.append(rs.shape[-1]) + + rs_permuted = None + if rs_4d_list: + rs_permuted = tex.permute_to_grouped_tensor_fwd(*rs_4d_list, original_format=src_format) + + # Permute columnwise scale_inv from src_format to BHSD + cs_4d_list = [] + cs_d_scales = [] + for t in tensors: + cs = t._columnwise_scale_inv + if cs is not None: + cs_4d_list.append(cs) + cs_d_scales.append(cs.shape[-1]) + + cs_permuted = None + if cs_4d_list: + cs_permuted = tex.permute_to_grouped_tensor_fwd(*cs_4d_list, original_format=src_format) + + outs = [] + for i, (t, rp, rd, cp, cd) in enumerate(zip(tensors, rs_permuted, rs_d_scales, cs_permuted, cs_d_scales)): + rp = rp.view(-1, rd) if rd is not None else None + cp = cp.view(-1, cd) if cd is not None else None + outs.append( + MXFP8Tensor( + shape=t.shape, dtype=t.dtype, + rowwise_data=t._rowwise_data, rowwise_scale_inv=rp, + columnwise_data=t._columnwise_data, columnwise_scale_inv=cp, + quantizer=t._quantizer, 
requires_grad=False, + fp8_dtype=t._fp8_dtype, with_gemm_swizzled_scales=t._with_gemm_swizzled_scales, + ) ) - for t, p, d in zip(tensors, permuted, d_scales) - ] + mxfp8_pad_and_swizzle_scales(*outs) + return tuple(outs) @@ -2421,10 +2441,25 @@ def mxfp8_quantize_single_tensor(tensor, quantizer, src_format): Returns (fp8_tensor, scale_inv_format) where scale_inv_format is "bhsd". """ + original_shape = tensor.shape + _s_dim = {"bshd": 2, "sbhd": 0, "bhsd": 2} + _d_dim = {"bshd": 3, "sbhd": 3, "bhsd": 3} + rowwise_scale_inv_shape = list(tensor.shape) + rowwise_scale_inv_shape[_d_dim[src_format]] = rowwise_scale_inv_shape[_d_dim[src_format]]//MXFP8_BLOCK_SCALING_SIZE + columnwise_scale_inv_shape = list(tensor.shape) + columnwise_scale_inv_shape[_s_dim[src_format]] = columnwise_scale_inv_shape[_s_dim[src_format]]//MXFP8_BLOCK_SCALING_SIZE + if src_format == "bhsd": + tensor = tensor.view(tensor.shape[:_s_dim[src_format]], -1) + elif src_format == "sbhd": + tensor = tensor.view(tensor.shape[_s_dim[src_format]], -1) orig_optimize = quantizer.optimize_for_gemm quantizer.optimize_for_gemm = False fp8_tensor = quantizer(tensor) quantizer.optimize_for_gemm = orig_optimize + fp8_tensor._rowwise_data = fp8_tensor._rowwise_data.view(original_shape) if fp8_tensor._rowwise_data is not None else None + fp8_tensor._columnwise_data = fp8_tensor._columnwise_data.view(original_shape) if fp8_tensor._columnwise_data is not None else None + fp8_tensor._rowwise_scale_inv = fp8_tensor._rowwise_scale_inv.view(rowwise_scale_inv_shape) if fp8_tensor._rowwise_scale_inv is not None else None + fp8_tensor._columnwise_scale_inv = fp8_tensor._columnwise_scale_inv.view(columnwise_scale_inv_shape) if fp8_tensor._columnwise_scale_inv is not None else None (fp8_tensor,) = mxfp8_permute_scale_inv_to_bhsd(fp8_tensor, src_format=src_format) return fp8_tensor, "bhsd" @@ -2499,18 +2534,31 @@ def combine_and_quantize( if isinstance(qkv_quantizer, MXFP8Quantizer): qkv_format, q_format, kv_format = 
get_qkv_format(qkv_layout) - _seq_dim = {"sbhd": 0, "bshd": 1, "bhsd": 2, "htd": 1} - d_qk = q.shape[-1] - d_v = v.shape[-1] - s_q = q.shape[_seq_dim.get(q_format, 2)] - s_kv = v.shape[_seq_dim.get(kv_format, 2)] + _s_dim = {"sbhd": 0, "bshd": 1, "bhsd": 2} + _b_dim = {"sbhd": 1, "bshd": 0, "bhsd": 0} + _h_dim = {"sbhd": 2, "bshd": 2, "bhsd": 1} + _d_dim = {"sbhd": 3, "bshd": 3, "bhsd": 3} + d_qk = q.shape[_d_dim[qkv_format]] + d_v = v.shape[_d_dim[qkv_format]] + s_q = q.shape[_s_dim[q_format]] + s_kv = v.shape[_s_dim[kv_format]] assert s_q % 128 == 0 and s_kv % 128 == 0 and d_qk % 32 == 0 and d_v % 32 == 0, ( "MXFP8 quantization requires s_q % 128 == 0, s_kv % 128 == 0, d_qk % 32 == 0, d_v % 32" f" == 0. Found {s_q=}, {s_kv=}, {d_qk=}, {d_v=}." ) + rowwise_scale_inv_shapes = [] + columnwise_scale_inv_shapes = [] + for x in [q, k, v]: + rs_shape = list(x.shape) + cs_shape = list(x.shape) + rs_shape[_d_dim[qkv_format]] = rs_shape[_d_dim[qkv_format]]//MXFP8_BLOCK_SCALING_SIZE + cs_shape[_s_dim[qkv_format]] = cs_shape[_s_dim[qkv_format]]//MXFP8_BLOCK_SCALING_SIZE + rowwise_scale_inv_shapes.append(rs_shape) + columnwise_scale_inv_shapes.append(cs_shape) needs_permute = qkv_format not in ("bhsd", "htd") qkv_scale_inv_format = None + original_shapes = [None, None, None] if keep_same_data_and_scale_inv_format and needs_permute: # Permute f16 data to BHSD, then quantize with swizzle @@ -2522,13 +2570,18 @@ def combine_and_quantize( k = permute_to_grouped_tensor_pytorch(k, kv_format) v = permute_to_grouped_tensor_pytorch(v, kv_format) qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" - - original_shapes = [x.shape for x in [q, k, v]] - q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] + qkv_scale_inv_format = "bhsd" if qkv_format != "thd" else "htd" + original_shapes = [x.shape for x in [q, k, v]] + q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] if not keep_same_data_and_scale_inv_format: orig_optimize = qkv_quantizer.optimize_for_gemm 
qkv_quantizer.optimize_for_gemm = False + original_shapes = [x.shape for x in [q, k, v]] + if qkv_format == "bshd": + q, k, v = [x.view(x.shape[:2], -1) for x in [q, k, v]] + elif qkv_format == "sbhd": + q, k, v = [x.view(x.shape[0], -1) for x in [q, k, v]] if used_in_forward and used_in_backward: q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] @@ -2550,10 +2603,21 @@ def combine_and_quantize( if not keep_same_data_and_scale_inv_format: qkv_quantizer.optimize_for_gemm = orig_optimize - q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] + q_fp8._rowwise_data = q_fp8._rowwise_data.view(original_shapes[0]) if q_fp8._rowwise_data is not None else None + q_fp8._columnwise_data = q_fp8._columnwise_data.view(original_shapes[0]) if q_fp8._columnwise_data is not None else None + k_fp8._rowwise_data = k_fp8._rowwise_data.view(original_shapes[1]) if k_fp8._rowwise_data is not None else None + k_fp8._columnwise_data = k_fp8._columnwise_data.view(original_shapes[1]) if k_fp8._columnwise_data is not None else None + v_fp8._rowwise_data = v_fp8._rowwise_data.view(original_shapes[2]) if v_fp8._rowwise_data is not None else None + v_fp8._columnwise_data = v_fp8._columnwise_data.view(original_shapes[2]) if v_fp8._columnwise_data is not None else None if not keep_same_data_and_scale_inv_format: # Permute only scale_inv to BHSD + pad + swizzle + q_fp8._rowwise_scale_inv = q_fp8._rowwise_scale_inv.view(rowwise_scale_inv_shapes[0]) if q_fp8._rowwise_scale_inv is not None else None + q_fp8._columnwise_scale_inv = q_fp8._columnwise_scale_inv.view(columnwise_scale_inv_shapes[0]) if q_fp8._columnwise_scale_inv is not None else None + k_fp8._rowwise_scale_inv = k_fp8._rowwise_scale_inv.view(rowwise_scale_inv_shapes[1]) if k_fp8._rowwise_scale_inv is not None else None + k_fp8._columnwise_scale_inv = k_fp8._columnwise_scale_inv.view(columnwise_scale_inv_shapes[1]) if k_fp8._columnwise_scale_inv is not None else None + v_fp8._rowwise_scale_inv 
= v_fp8._rowwise_scale_inv.view(rowwise_scale_inv_shapes[2]) if v_fp8._rowwise_scale_inv is not None else None + v_fp8._columnwise_scale_inv = v_fp8._columnwise_scale_inv.view(columnwise_scale_inv_shapes[2]) if v_fp8._columnwise_scale_inv is not None else None q_fp8, k_fp8, v_fp8 = mxfp8_permute_scale_inv_to_bhsd( q_fp8, k_fp8, v_fp8, src_format=q_format, ) From 346c76476f6dd685274dceb105372f4e9a9837b9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 11 Apr 2026 05:03:35 +0000 Subject: [PATCH 172/172] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/common.h | 3 +- .../common/fused_attn/flash_attn.cu | 318 ++++++++---------- .../common/fused_attn/fused_attn.cpp | 29 +- .../common/fused_attn/fused_attn_fp8.cu | 38 +-- transformer_engine/common/fused_attn/utils.h | 20 +- .../include/transformer_engine/fused_attn.h | 29 +- transformer_engine/common/swizzle/swizzle.cu | 58 ++-- .../common/transformer_engine.cpp | 3 +- .../common/util/pybind_helper.h | 4 +- .../dot_product_attention/backends.py | 16 +- .../dot_product_attention/context_parallel.py | 52 +-- .../attention/dot_product_attention/utils.py | 161 ++++++--- transformer_engine/pytorch/csrc/extensions.h | 19 +- .../pytorch/csrc/extensions/attention.cpp | 78 +++-- .../pytorch/csrc/extensions/pybind.cpp | 9 +- .../pytorch/csrc/extensions/swizzle.cpp | 6 +- 16 files changed, 451 insertions(+), 392 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index b32d6bbbcd..68aa0f4c51 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1003,8 +1003,7 @@ size_t typeToSize(const DType type); size_t typeToNumBits(const DType type); void CheckNoopTensor(const Tensor &t, const std::string &name); -void CheckInputTensor(const Tensor &t, const std::string &name, - bool check_scale_inv_shapes = 
true); +void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale_inv_shapes = true); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); /*! \brief Update a tensor's FP8 scale-inverse diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index f34a11422d..d225af9553 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -164,8 +164,8 @@ using flash_attention::Vec; // ---------- fallback_not_vec_aligned: row-copy helper (D is small / misaligned) ---------- -__device__ __forceinline__ void copy_row_bytes(const char *__restrict__ src, - char *__restrict__ dst, size_t D_bytes) { +__device__ __forceinline__ void copy_row_bytes(const char *__restrict__ src, char *__restrict__ dst, + size_t D_bytes) { size_t off = 0; for (; off + 16 <= D_bytes; off += 16) { uint4 tmp; @@ -190,31 +190,29 @@ __device__ __forceinline__ void copy_row_bytes(const char *__restrict__ src, for (; off < D_bytes; ++off) dst[off] = src[off]; } - // ---------- fallback_not_vec_aligned: tiled-transpose kernels ---------- -constexpr int TRANSPOSE_TILE = 32; +constexpr int TRANSPOSE_TILE = 32; constexpr int TRANSPOSE_BLOCK = 256; -constexpr int TRANSPOSE_WARPS = TRANSPOSE_BLOCK / 32; // 8 +constexpr int TRANSPOSE_WARPS = TRANSPOSE_BLOCK / 32; // 8 template __launch_bounds__(TRANSPOSE_BLOCK) __global__ void permute_to_grouped_tensor_fwd_fallback_not_vec_aligned_kernel( const T *__restrict__ q_in, const T *__restrict__ k_in, const T *__restrict__ v_in, - T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, - size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, - unsigned int s_tiles) { + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, size_t b, size_t s_q, + size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, unsigned int s_tiles) { 
const int which = blockIdx.z; const T *__restrict__ in = which == 0 ? q_in : (which == 1 ? k_in : v_in); - T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); - const size_t S = which == 0 ? s_q : s_kv; - const size_t H = which == 0 ? h_q : h_kv; - const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t S = which == 0 ? s_q : s_kv; + const size_t H = which == 0 ? h_q : h_kv; + const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); const size_t D_bytes = D * sizeof(T); - const size_t D_pad = (D_bytes + 3u) & ~size_t(3); // 4-byte aligned for smem + const size_t D_pad = (D_bytes + 3u) & ~size_t(3); // 4-byte aligned for smem const size_t tile_s = static_cast(blockIdx.x) % static_cast(s_tiles); - const size_t b_i = static_cast(blockIdx.x) / static_cast(s_tiles); + const size_t b_i = static_cast(blockIdx.x) / static_cast(s_tiles); if (b_i >= b) return; const size_t tile_h = static_cast(blockIdx.y); @@ -226,8 +224,8 @@ __launch_bounds__(TRANSPOSE_BLOCK) __global__ const size_t smem_row = static_cast(TRANSPOSE_TILE) * D_pad + 4; // ---- Phase 1: global → smem (sweep consecutive H → coalesced reads) ---- - for (unsigned int warp_off = threadIdx.x >> 5; - warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + for (unsigned int warp_off = threadIdx.x >> 5; warp_off < TRANSPOSE_TILE; + warp_off += TRANSPOSE_WARPS) { const size_t local_s = warp_off; const size_t local_h = threadIdx.x & 31u; const size_t s_i = s_base + local_s; @@ -245,8 +243,8 @@ __launch_bounds__(TRANSPOSE_BLOCK) __global__ __syncthreads(); // ---- Phase 2: smem → global (sweep consecutive S → coalesced writes) ---- - for (unsigned int warp_off = threadIdx.x >> 5; - warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + for (unsigned int warp_off = threadIdx.x >> 5; warp_off < TRANSPOSE_TILE; + warp_off += TRANSPOSE_WARPS) { const size_t local_h = warp_off; const size_t 
local_s = threadIdx.x & 31u; const size_t s_i = s_base + local_s; @@ -263,20 +261,19 @@ template __launch_bounds__(TRANSPOSE_BLOCK) __global__ void permute_to_grouped_tensor_bwd_fallback_not_vec_aligned_kernel( const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, - T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, - size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, - unsigned int s_tiles) { + T *__restrict__ q_out, T *__restrict__ k_out, T *__restrict__ v_out, size_t b, size_t s_q, + size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, size_t d_v, unsigned int s_tiles) { const int which = blockIdx.z; const T *__restrict__ in = which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); - T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); - const size_t S = which == 0 ? s_q : s_kv; - const size_t H = which == 0 ? h_q : h_kv; - const size_t D = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + T *__restrict__ out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t S = which == 0 ? s_q : s_kv; + const size_t H = which == 0 ? h_q : h_kv; + const size_t D = which == 0 ? d_qk : (which == 1 ? 
d_qk : d_v); const size_t D_bytes = D * sizeof(T); - const size_t D_pad = (D_bytes + 3u) & ~size_t(3); + const size_t D_pad = (D_bytes + 3u) & ~size_t(3); const size_t tile_s = static_cast(blockIdx.x) % static_cast(s_tiles); - const size_t b_i = static_cast(blockIdx.x) / static_cast(s_tiles); + const size_t b_i = static_cast(blockIdx.x) / static_cast(s_tiles); if (b_i >= b) return; const size_t tile_h = static_cast(blockIdx.y); @@ -287,24 +284,23 @@ __launch_bounds__(TRANSPOSE_BLOCK) __global__ const size_t smem_row = static_cast(TRANSPOSE_TILE) * D_pad + 4; // ---- Phase 1: global → smem (sweep consecutive S → coalesced reads from BHSD) ---- - for (unsigned int warp_off = threadIdx.x >> 5; - warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + for (unsigned int warp_off = threadIdx.x >> 5; warp_off < TRANSPOSE_TILE; + warp_off += TRANSPOSE_WARPS) { const size_t local_h = warp_off; const size_t local_s = threadIdx.x & 31u; const size_t s_i = s_base + local_s; const size_t h_i = h_base + local_h; if (s_i < S && h_i < H) { - copy_row_bytes( - reinterpret_cast(in + b_i * H * S * D + h_i * S * D + s_i * D), - smem + local_s * smem_row + local_h * D_pad, D_bytes); + copy_row_bytes(reinterpret_cast(in + b_i * H * S * D + h_i * S * D + s_i * D), + smem + local_s * smem_row + local_h * D_pad, D_bytes); } } __syncthreads(); // ---- Phase 2: smem → global (sweep consecutive H → coalesced writes to SBHD/BSHD) ---- - for (unsigned int warp_off = threadIdx.x >> 5; - warp_off < TRANSPOSE_TILE; warp_off += TRANSPOSE_WARPS) { + for (unsigned int warp_off = threadIdx.x >> 5; warp_off < TRANSPOSE_TILE; + warp_off += TRANSPOSE_WARPS) { const size_t local_s = warp_off; const size_t local_h = threadIdx.x & 31u; const size_t s_i = s_base + local_s; @@ -325,9 +321,10 @@ __launch_bounds__(TRANSPOSE_BLOCK) __global__ constexpr int fallback_permute_threads = 1024; template -__device__ __forceinline__ void permute_fwd_vec_loop( - const T *__restrict__ in, T *__restrict__ out, size_t 
b, size_t S, size_t H, size_t D, - size_t b_i, size_t h_i, size_t s_begin, size_t S_chunk) { +__device__ __forceinline__ void permute_fwd_vec_loop(const T *__restrict__ in, T *__restrict__ out, + size_t b, size_t S, size_t H, size_t D, + size_t b_i, size_t h_i, size_t s_begin, + size_t S_chunk) { const size_t out_base = b_i * H * S * D + h_i * S * D; const size_t d_vec = D / static_cast(N); const size_t total_work = S_chunk * d_vec; @@ -348,9 +345,10 @@ __device__ __forceinline__ void permute_fwd_vec_loop( } template -__device__ __forceinline__ void permute_bwd_vec_loop( - const T *__restrict__ in, T *__restrict__ out, size_t b, size_t S, size_t H, size_t D, - size_t b_i, size_t h_i, size_t s_begin, size_t S_chunk) { +__device__ __forceinline__ void permute_bwd_vec_loop(const T *__restrict__ in, T *__restrict__ out, + size_t b, size_t S, size_t H, size_t D, + size_t b_i, size_t h_i, size_t s_begin, + size_t S_chunk) { const size_t in_base = b_i * H * S * D + h_i * S * D; const size_t d_vec = D / static_cast(N); const size_t total_work = S_chunk * d_vec; @@ -395,8 +393,7 @@ __launch_bounds__(fallback_permute_threads) __global__ } const unsigned int s_part = blockIdx.y; - const size_t s_begin = - (S * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_begin = (S * static_cast(s_part)) / static_cast(permute_s_splits); const size_t s_end = (S * static_cast(s_part + 1)) / static_cast(permute_s_splits); if (s_begin >= s_end) return; @@ -448,8 +445,7 @@ __launch_bounds__(fallback_permute_threads) __global__ } const unsigned int s_part = blockIdx.y; - const size_t s_begin = - (S * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_begin = (S * static_cast(s_part)) / static_cast(permute_s_splits); const size_t s_end = (S * static_cast(s_part + 1)) / static_cast(permute_s_splits); if (s_begin >= s_end) return; @@ -476,7 +472,6 @@ __launch_bounds__(fallback_permute_threads) __global__ } } - // ---------- main path: TMA ---------- 
constexpr int tma_permute_threads = 128; @@ -545,8 +540,7 @@ static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dt elem_bytes = 1; break; default: - NVTE_ERROR("create_4D_tensor_map: unsupported dtype ", - to_string(static_cast(dtype))); + NVTE_ERROR("create_4D_tensor_map: unsupported dtype ", to_string(static_cast(dtype))); } constexpr uint32_t rank = 4; @@ -560,8 +554,8 @@ static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dt uint32_t elemStride[rank] = {1, 1, 1, 1}; const auto oob_fill = (tma_dtype == CU_TENSOR_MAP_DATA_TYPE_UINT8) - ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE - : CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; + ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + : CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( &tensorMap, tma_dtype, rank, dataPtr, size, stride, boxSize, elemStride, @@ -614,15 +608,11 @@ __device__ __forceinline__ void st_global_cs_uint4(uint4 *ptr, uint4 val) { // TMA load from strided input → smem → non-temporal stores to contiguous output. 
template -__launch_bounds__(tma_permute_threads) __global__ - void permute_to_grouped_tensor_fwd_kernel(const __grid_constant__ CUtensorMap tma_q_in, - const __grid_constant__ CUtensorMap tma_k_in, - const __grid_constant__ CUtensorMap tma_v_in, - T *__restrict__ q_out, T *__restrict__ k_out, - T *__restrict__ v_out, size_t b, size_t s_q, - size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, - size_t d_v, unsigned int permute_s_splits, - size_t s_tile_size) { +__launch_bounds__(tma_permute_threads) __global__ void permute_to_grouped_tensor_fwd_kernel( + const __grid_constant__ CUtensorMap tma_q_in, const __grid_constant__ CUtensorMap tma_k_in, + const __grid_constant__ CUtensorMap tma_v_in, T *__restrict__ q_out, T *__restrict__ k_out, + T *__restrict__ v_out, size_t b, size_t s_q, size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, + size_t d_v, unsigned int permute_s_splits, size_t s_tile_size) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const int which = blockIdx.z; const CUtensorMap *tma_in = which == 0 ? &tma_q_in : (which == 1 ? &tma_k_in : &tma_v_in); @@ -706,8 +696,7 @@ __launch_bounds__(tma_permute_threads) __global__ void permute_to_grouped_tensor const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, const __grid_constant__ CUtensorMap tma_q_out, const __grid_constant__ CUtensorMap tma_k_out, const __grid_constant__ CUtensorMap tma_v_out, size_t b, size_t s_q, size_t h_q, size_t d_qk, - size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits, - size_t s_tile_size) { + size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits, size_t s_tile_size) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const int which = blockIdx.z; const T *__restrict__ tensor_in = which == 0 ? grad_q : (which == 1 ? 
grad_k : grad_v); @@ -766,7 +755,6 @@ __launch_bounds__(tma_permute_threads) __global__ void permute_to_grouped_tensor #endif } - // ---- create a 4D TMA descriptor ---- // For BSHD [B, S, H, D]: TMA dims [D, H, S, B], box [D, 1, S_TILE, 1] // For SBHD [S, B, H, D]: TMA dims [D, H, B, S], box [D, 1, 1, S_TILE] @@ -810,7 +798,6 @@ static bool can_use_tma_permute(DType dtype, size_t d_qk, size_t d_v) { return true; } - void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, Tensor k_out, Tensor v_out, NVTE_QKV_Format original_format, size_t num_tensors, cudaStream_t stream) { @@ -834,48 +821,46 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T if (!can_use_tma_permute(q.dtype(), d_qk, d_v)) { const size_t elem_size = typeToSize(q.dtype()); const size_t d_qk_bytes = d_qk * elem_size; - const size_t d_v_bytes = d_v * elem_size; + const size_t d_v_bytes = d_v * elem_size; const bool needs_transpose = (d_qk_bytes % 4 != 0) || (d_v_bytes % 4 != 0); if (needs_transpose) { const size_t s_max = std::max(s_q, s_kv); const size_t h_max = std::max(h_q, h_kv); - const unsigned int st = static_cast( - (s_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); - const unsigned int ht = static_cast( - (h_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); - dim3 grid(static_cast(b) * st, ht, - static_cast(num_tensors)); + const unsigned int st = + static_cast((s_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); + const unsigned int ht = + static_cast((h_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); + dim3 grid(static_cast(b) * st, ht, static_cast(num_tensors)); const size_t d_max = std::max(d_qk, d_v); const size_t D_pad = (d_max * elem_size + 3u) & ~size_t(3); const size_t smem_bytes = - static_cast(TRANSPOSE_TILE) * - (static_cast(TRANSPOSE_TILE) * D_pad + 4); + static_cast(TRANSPOSE_TILE) * (static_cast(TRANSPOSE_TILE) * D_pad + 4); if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( q.dtype(), dtype, 
permute_to_grouped_tensor_fwd_fallback_not_vec_aligned_kernel - <<>>( - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - reinterpret_cast(q_out.data.dptr), - reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, st);); + <<>>( + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + st);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( q.dtype(), dtype, permute_to_grouped_tensor_fwd_fallback_not_vec_aligned_kernel - <<>>( - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - reinterpret_cast(q_out.data.dptr), - reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, st);); + <<>>( + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + st);); } NVTE_CHECK_CUDA(cudaGetLastError()); return; @@ -893,26 +878,26 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( q.dtype(), dtype, permute_to_grouped_tensor_fwd_fallback_vec_aligned_kernel - <<>>( - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - reinterpret_cast(q_out.data.dptr), - reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + <<>>( + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), b, 
s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( q.dtype(), dtype, permute_to_grouped_tensor_fwd_fallback_vec_aligned_kernel - <<>>( - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - reinterpret_cast(q_out.data.dptr), - reinterpret_cast(k_out.data.dptr), - reinterpret_cast(v_out.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + <<>>( + reinterpret_cast(q.data.dptr), + reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), + reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), + reinterpret_cast(v_out.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); return; @@ -922,18 +907,18 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T const size_t s_min = std::min(s_q, s_kv); const size_t s_tile = std::min(static_cast(tma_permute_s_tile_default), s_min); NVTE_CHECK((s_tile * d_qk * elem_size) % sizeof(uint4) == 0 && - (s_tile * d_v * elem_size) % sizeof(uint4) == 0, - "permute_to_grouped_tensor_fwd: S_TILE(", s_tile, ") * D * elem_size must " - "be divisible by ", sizeof(uint4), ". d_qk=", d_qk, ", d_v=", d_v, - ", elem_size=", elem_size, "."); + (s_tile * d_v * elem_size) % sizeof(uint4) == 0, + "permute_to_grouped_tensor_fwd: S_TILE(", s_tile, + ") * D * elem_size must " + "be divisible by ", + sizeof(uint4), ". 
d_qk=", d_qk, ", d_v=", d_v, ", elem_size=", elem_size, "."); alignas(64) CUtensorMap tma_q_in{}, tma_k_in{}, tma_v_in{}; create_strided_tensor_map(tma_q_in, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, s_tile, is_bshd); create_strided_tensor_map(tma_k_in, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, s_tile, is_bshd); create_strided_tensor_map(tma_v_in, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, s_tile, is_bshd); - const unsigned int permute_s_splits = - std::max(1u, static_cast(s_min / s_tile)); + const unsigned int permute_s_splits = std::max(1u, static_cast(s_min / s_tile)); const size_t h_grid = std::max(h_q, h_kv); dim3 grid(static_cast(b * h_grid), permute_s_splits, static_cast(num_tensors)); @@ -964,8 +949,8 @@ void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, T } void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, Tensor q, Tensor k, - Tensor v, NVTE_QKV_Format original_format, - size_t num_tensors, cudaStream_t stream) { + Tensor v, NVTE_QKV_Format original_format, size_t num_tensors, + cudaStream_t stream) { using namespace transformer_engine; const size_t b = grad_q.shape()[0]; const size_t h_q = grad_q.shape()[1]; @@ -986,48 +971,42 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, if (!can_use_tma_permute(grad_q.dtype(), d_qk, d_v)) { const size_t elem_size = typeToSize(grad_q.dtype()); const size_t d_qk_bytes = d_qk * elem_size; - const size_t d_v_bytes = d_v * elem_size; + const size_t d_v_bytes = d_v * elem_size; const bool needs_transpose = (d_qk_bytes % 4 != 0) || (d_v_bytes % 4 != 0); if (needs_transpose) { const size_t s_max = std::max(s_q, s_kv); const size_t h_max = std::max(h_q, h_kv); - const unsigned int st = static_cast( - (s_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); - const unsigned int ht = static_cast( - (h_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); - dim3 grid(static_cast(b) * st, ht, - static_cast(num_tensors)); + const unsigned int st = + 
static_cast((s_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); + const unsigned int ht = + static_cast((h_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); + dim3 grid(static_cast(b) * st, ht, static_cast(num_tensors)); const size_t d_max = std::max(d_qk, d_v); const size_t D_pad = (d_max * elem_size + 3u) & ~size_t(3); const size_t smem_bytes = - static_cast(TRANSPOSE_TILE) * - (static_cast(TRANSPOSE_TILE) * D_pad + 4); + static_cast(TRANSPOSE_TILE) * (static_cast(TRANSPOSE_TILE) * D_pad + 4); if (is_bshd) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( grad_q.dtype(), dtype, permute_to_grouped_tensor_bwd_fallback_not_vec_aligned_kernel - <<>>( - reinterpret_cast(grad_q.data.dptr), - reinterpret_cast(grad_k.data.dptr), - reinterpret_cast(grad_v.data.dptr), - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, st);); + <<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), + reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, st);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( grad_q.dtype(), dtype, permute_to_grouped_tensor_bwd_fallback_not_vec_aligned_kernel - <<>>( - reinterpret_cast(grad_q.data.dptr), - reinterpret_cast(grad_k.data.dptr), - reinterpret_cast(grad_v.data.dptr), - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, st);); + <<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), + reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, st);); } NVTE_CHECK_CUDA(cudaGetLastError()); return; @@ -1044,26 +1023,24 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, 
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( grad_q.dtype(), dtype, permute_to_grouped_tensor_bwd_fallback_vec_aligned_kernel - <<>>( - reinterpret_cast(grad_q.data.dptr), - reinterpret_cast(grad_k.data.dptr), - reinterpret_cast(grad_v.data.dptr), - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + <<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), + reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( grad_q.dtype(), dtype, permute_to_grouped_tensor_bwd_fallback_vec_aligned_kernel - <<>>( - reinterpret_cast(grad_q.data.dptr), - reinterpret_cast(grad_k.data.dptr), - reinterpret_cast(grad_v.data.dptr), - reinterpret_cast(q.data.dptr), - reinterpret_cast(k.data.dptr), - reinterpret_cast(v.data.dptr), - b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + <<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), + reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), + reinterpret_cast(v.data.dptr), b, s_q, h_q, d_qk, s_kv, h_kv, d_v, + permute_s_splits);); } NVTE_CHECK_CUDA(cudaGetLastError()); return; @@ -1073,18 +1050,19 @@ void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, const size_t s_min = std::min(s_q, s_kv); const size_t s_tile = std::min(static_cast(tma_permute_s_tile_default), s_min); NVTE_CHECK((s_tile * d_qk * elem_size) % sizeof(uint4) == 0 && - (s_tile * d_v * elem_size) % sizeof(uint4) == 0, - "permute_to_grouped_tensor_bwd: S_TILE(", s_tile, ") * D * elem_size must " - "be divisible by ", sizeof(uint4), ". 
d_qk=", d_qk, ", d_v=", d_v, - ", elem_size=", elem_size, "."); + (s_tile * d_v * elem_size) % sizeof(uint4) == 0, + "permute_to_grouped_tensor_bwd: S_TILE(", s_tile, + ") * D * elem_size must " + "be divisible by ", + sizeof(uint4), ". d_qk=", d_qk, ", d_v=", d_v, ", elem_size=", elem_size, "."); alignas(64) CUtensorMap tma_q_out{}, tma_k_out{}, tma_v_out{}; create_strided_tensor_map(tma_q_out, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, s_tile, is_bshd); - create_strided_tensor_map(tma_k_out, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, s_tile, is_bshd); + create_strided_tensor_map(tma_k_out, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, s_tile, + is_bshd); create_strided_tensor_map(tma_v_out, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, s_tile, is_bshd); - const unsigned int permute_s_splits = - std::max(1u, static_cast(s_min / s_tile)); + const unsigned int permute_s_splits = std::max(1u, static_cast(s_min / s_tile)); const size_t h_grid = std::max(h_q, h_kv); dim3 grid(static_cast(b * h_grid), permute_s_splits, static_cast(num_tensors)); @@ -1163,11 +1141,11 @@ __launch_bounds__(pad_threads_per_block) __global__ } void multi_tensor_pad_last_dim(Tensor *inputs, Tensor *outputs, size_t num_tensors, - cudaStream_t stream) { + cudaStream_t stream) { using namespace transformer_engine; - NVTE_CHECK(num_tensors > 0 && num_tensors <= kMaxPadTensors, - "num_tensors must be in [1, ", kMaxPadTensors, "], got ", num_tensors, "."); + NVTE_CHECK(num_tensors > 0 && num_tensors <= kMaxPadTensors, "num_tensors must be in [1, ", + kMaxPadTensors, "], got ", num_tensors, "."); MultiPadParams params{}; size_t max_n_uint32 = 0; @@ -1193,15 +1171,15 @@ void multi_tensor_pad_last_dim(Tensor *inputs, Tensor *outputs, size_t num_tenso if (in_cols == out_cols) { const size_t total_bytes = rows * in_cols * typeToSize(inp.data.dtype); NVTE_CHECK_CUDA(cudaMemcpyAsync(out.data.dptr, inp.data.dptr, total_bytes, - cudaMemcpyDeviceToDevice, stream)); + cudaMemcpyDeviceToDevice, stream)); 
continue; } const size_t elem_size = typeToSize(inp.data.dtype); const auto in_row_bytes = static_cast(in_cols * elem_size); const auto out_row_bytes = static_cast(out_cols * elem_size); - NVTE_CHECK(out_row_bytes % 4 == 0, - "Padded row size in bytes (", out_row_bytes, ") must be a multiple of 4."); + NVTE_CHECK(out_row_bytes % 4 == 0, "Padded row size in bytes (", out_row_bytes, + ") must be a multiple of 4."); const uint32_t out_row_uint32 = out_row_bytes / 4; const size_t n_uint32 = rows * out_row_uint32; @@ -1216,9 +1194,8 @@ void multi_tensor_pad_last_dim(Tensor *inputs, Tensor *outputs, size_t num_tenso if (kernel_count == 0) return; constexpr int threads = pad_threads_per_block; - const int blocks_x = - static_cast(std::min(DIVUP(max_n_uint32, static_cast(threads)), - static_cast(65535))); + const int blocks_x = static_cast( + std::min(DIVUP(max_n_uint32, static_cast(threads)), static_cast(65535))); dim3 grid(blocks_x, kernel_count); multi_tensor_pad_last_dim_kernel<<>>(params); @@ -1273,7 +1250,7 @@ void nvte_permute_to_grouped_tensor_bwd(NVTETensor grad_q, NVTETensor grad_k, NV } void nvte_multi_tensor_pad_last_dim(NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, - cudaStream_t stream) { + cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_pad_last_dim); using namespace transformer_engine; @@ -1282,5 +1259,6 @@ void nvte_multi_tensor_pad_last_dim(NVTETensor *inputs, NVTETensor *outputs, siz in_vec[i] = *convertNVTETensorCheck(inputs[i]); out_vec[i] = *convertNVTETensorCheck(outputs[i]); } - multi_tensor_pad_last_dim::multi_tensor_pad_last_dim(in_vec.data(), out_vec.data(), num_tensors, stream); + multi_tensor_pad_last_dim::multi_tensor_pad_last_dim(in_vec.data(), out_vec.data(), num_tensors, + stream); } diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 1792e21fcc..a2d329f5eb 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ 
b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -787,22 +787,19 @@ void nvte_fused_attn_fwd( } } // NVTE fused attention BWD with separate Q, K and V -void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, - size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream, - NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format) { +void nvte_fused_attn_bwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, + const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, + NVTETensor dQ, NVTETensor dK, NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, NVTETensor workspace, 
+ cudaStream_t stream, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 6022c59c2c..152f9b320d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1662,9 +1662,8 @@ void fused_attn_fp8_fwd_impl_v1( void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, - NVTEScalingMode scaling_mode, NVTE_QKV_Format qkv_scale_inv_format, - void* workspace, size_t* workspace_size, cudaStream_t stream, - cudnnHandle_t handle) { + NVTEScalingMode scaling_mode, NVTE_QKV_Format qkv_scale_inv_format, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -1826,12 +1825,12 @@ void fused_attn_fp8_fwd_impl_v1( scale_o = mha_graph->tensor(1.0f); } } else if (is_mxfp8) { - NVTE_QKV_Format q_scale_inv_format = - (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) - ? qkv_scale_inv_format : nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_scale_inv_format = - (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) - ? qkv_scale_inv_format : nvte_get_kv_format(qkv_layout); + NVTE_QKV_Format q_scale_inv_format = (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) + ? qkv_scale_inv_format + : nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_scale_inv_format = (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) + ? 
qkv_scale_inv_format + : nvte_get_kv_format(qkv_layout); std::vector q_scale_strides(4); std::vector k_scale_strides(4); std::vector v_scale_strides(4); @@ -2098,8 +2097,8 @@ void fused_attn_fp8_bwd_impl_v1( void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, - void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -2324,14 +2323,11 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); NVTE_QKV_Format q_scale_inv_format = - (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) - ? qkv_scale_inv_format : q_format; + (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) ? qkv_scale_inv_format : q_format; NVTE_QKV_Format kv_scale_inv_format = - (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) - ? qkv_scale_inv_format : kv_format; + (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) ? qkv_scale_inv_format : kv_format; NVTE_QKV_Format do_scale_format_ = - (do_scale_inv_format != NVTE_QKV_Format_NOT_SET) - ? do_scale_inv_format : do_format; + (do_scale_inv_format != NVTE_QKV_Format_NOT_SET) ? 
do_scale_inv_format : do_format; // Q_t, K_t, dO_t, dO_f16 std::vector q_t_strides(4), k_t_strides(4), dO_t_strides(4); generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_strides.data(), q_format); @@ -2834,8 +2830,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, - qkv_scale_inv_format, - workspace->data.dptr, &workspace_size, stream, handle); + qkv_scale_inv_format, workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, is_training, attn_scale, @@ -2969,9 +2964,8 @@ void fused_attn_fp8_bwd( devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - input_dO->scaling_mode, - qkv_scale_inv_format, do_scale_inv_format, - workspace->data.dptr, &workspace_size, stream, handle); + input_dO->scaling_mode, qkv_scale_inv_format, do_scale_inv_format, workspace->data.dptr, + &workspace_size, stream, handle); } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { // remove this when cuDNN FE supports FP8 + THD NVTE_CHECK(input_ZInv != nullptr && input_ZInv->data.dptr != nullptr, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 822a0f61f3..d68277d92b 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -316,21 +316,19 @@ struct FADescriptor_v1 { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, 
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, attnScale, isTraining, dropoutProbability, qkv_layout, o_format, - do_format, dqkv_layout, qkv_scale_inv_format, do_scale_inv_format, - mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, deterministic, bias_type, - qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, - return_max_logit) < + do_format, dqkv_layout, qkv_scale_inv_format, do_scale_inv_format, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, + dqkv_tensor_type, return_max_logit) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.qkv_layout, - rhs.o_format, rhs.do_format, rhs.dqkv_layout, - rhs.qkv_scale_inv_format, rhs.do_scale_inv_format, - rhs.mask_type, rhs.softmax_type, - rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, - rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, - rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.return_max_logit); + rhs.o_format, rhs.do_format, rhs.dqkv_layout, rhs.qkv_scale_inv_format, + rhs.do_scale_inv_format, rhs.mask_type, rhs.softmax_type, rhs.window_size_left, + rhs.window_size_right, rhs.bottom_right_diagonal, rhs.deterministic, + rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, + rhs.dqkv_tensor_type, rhs.return_max_logit); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 1ad6b8f889..028abe9138 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ 
b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -385,22 +385,19 @@ void nvte_fused_attn_fwd( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, - size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream, - NVTE_QKV_Format qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET, - NVTE_QKV_Format do_scale_inv_format = NVTE_QKV_Format_NOT_SET); +void nvte_fused_attn_bwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, + const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, + NVTETensor dQ, NVTETensor dK, NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, 
int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, NVTETensor workspace, + cudaStream_t stream, NVTE_QKV_Format qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format do_scale_inv_format = NVTE_QKV_Format_NOT_SET); /*! \brief Update the RNG state with the seed and calculated offset. * diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index b6b4a05626..ab8d03774a 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -297,10 +297,10 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) // an M-tile, threadIdx.y indexes M-tiles within the block, processing TB_DIM // M-tiles per block with full thread utilization. template -__device__ void swizzle_row_scaling_narrow_k_kernel_impl( - const void* input, void* output, const int M, const int K, - const int original_M, const int original_K, - const int bid, const int grid_dim) { +__device__ void swizzle_row_scaling_narrow_k_kernel_impl(const void* input, void* output, + const int M, const int K, + const int original_M, const int original_K, + const int bid, const int grid_dim) { constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; const int K_i32 = K / 4; const int num_tiles_m = M / SF_TILE_DIM_M; @@ -337,8 +337,7 @@ __device__ void swizzle_row_scaling_narrow_k_kernel_impl( } } - my_slm[k * (SF_TILE_SIZE_I32 / 4) + threadIdx.x] = - *reinterpret_cast(regs); + my_slm[k * (SF_TILE_SIZE_I32 / 4) + threadIdx.x] = *reinterpret_cast(regs); } } @@ -346,8 +345,8 @@ __device__ void swizzle_row_scaling_narrow_k_kernel_impl( if (active) { int4* my_slm = slm_v4i + threadIdx.y * slm_tile_v4i; - int4* out_v4i = reinterpret_cast( - reinterpret_cast(output) + m_tile * SF_TILE_DIM_M * K_i32); + int4* out_v4i = + reinterpret_cast(reinterpret_cast(output) + m_tile * SF_TILE_DIM_M * K_i32); for (int i = threadIdx.x; i < slm_tile_v4i; i += blockDim.x) { out_v4i[i] = 
my_slm[i]; @@ -847,9 +846,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s dim3 num_blocks_narrow(DIVUP(num_tiles_m, TB_DIM)); int slm_size = TB_DIM * num_tiles_k * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); NVTE_CHECK_CUDA( - cudaFuncSetAttribute( - swizzle_row_scaling_narrow_k_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + cudaFuncSetAttribute(swizzle_row_scaling_narrow_k_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_row_scaling_narrow_k_kernel <<>>( input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); @@ -943,14 +941,13 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s template void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, const int vec_load_size, const bool is_rowwise, - const bool use_narrow_k, - cudaStream_t stream) { + const bool use_narrow_k, cudaStream_t stream) { // cudaFuncSetAttribute is a host-synchronous driver call; cache the max shared memory // setting per kernel variant so we only pay the cost when slm_size actually increases. 
auto set_smem_if_needed = [](auto kernel_fn, int slm, int& cached) { if (cached < slm) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, slm)); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, slm)); cached = slm; } }; @@ -965,16 +962,15 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, const int num_tiles_m = kernel_args.m_list[j] / SF_TILE_DIM_M; const int num_tiles_k = kernel_args.k_list[j] / SF_TILE_DIM_K; max_num_tiles_k = std::max(max_num_tiles_k, num_tiles_k); - kernel_args.block_range[j + 1] = - kernel_args.block_range[j] + DIVUP(num_tiles_m, TB_DIM); + kernel_args.block_range[j + 1] = kernel_args.block_range[j] + DIVUP(num_tiles_m, TB_DIM); } int slm_size = TB_DIM * max_num_tiles_k * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; static int cached_narrow_k = -1; set_smem_if_needed( - multi_tensor_swizzle_row_scaling_narrow_k_kernel, - slm_size, cached_narrow_k); + multi_tensor_swizzle_row_scaling_narrow_k_kernel, slm_size, + cached_narrow_k); multi_tensor_swizzle_row_scaling_narrow_k_kernel <<>>(kernel_args); } else { @@ -1007,22 +1003,22 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, switch (vec_load_size) { case 4: set_smem_if_needed( - multi_tensor_swizzle_row_scaling_kernel, - slm_size, cached_row_int4); + multi_tensor_swizzle_row_scaling_kernel, slm_size, + cached_row_int4); multi_tensor_swizzle_row_scaling_kernel <<>>(kernel_args); break; case 2: set_smem_if_needed( - multi_tensor_swizzle_row_scaling_kernel, - slm_size, cached_row_int2); + multi_tensor_swizzle_row_scaling_kernel, slm_size, + cached_row_int2); multi_tensor_swizzle_row_scaling_kernel <<>>(kernel_args); break; case 1: set_smem_if_needed( - multi_tensor_swizzle_row_scaling_kernel, - slm_size, cached_row_int1); + 
multi_tensor_swizzle_row_scaling_kernel, slm_size, + cached_row_int1); multi_tensor_swizzle_row_scaling_kernel <<>>(kernel_args); break; @@ -1034,22 +1030,22 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, switch (vec_load_size) { case 4: set_smem_if_needed( - multi_tensor_swizzle_col_scaling_kernel, - slm_size, cached_col_int4); + multi_tensor_swizzle_col_scaling_kernel, slm_size, + cached_col_int4); multi_tensor_swizzle_col_scaling_kernel <<>>(kernel_args); break; case 2: set_smem_if_needed( - multi_tensor_swizzle_col_scaling_kernel, - slm_size, cached_col_int2); + multi_tensor_swizzle_col_scaling_kernel, slm_size, + cached_col_int2); multi_tensor_swizzle_col_scaling_kernel <<>>(kernel_args); break; case 1: set_smem_if_needed( - multi_tensor_swizzle_col_scaling_kernel, - slm_size, cached_col_int1); + multi_tensor_swizzle_col_scaling_kernel, slm_size, + cached_col_int1); multi_tensor_swizzle_col_scaling_kernel <<>>(kernel_args); break; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 4351ee3061..0d17eb9870 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -144,8 +144,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { } } -void CheckInputTensor(const Tensor &t, const std::string &name, - bool check_scale_inv_shapes) { +void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale_inv_shapes) { const DType type = t.dtype(); if (is_fp8_dtype(type)) { // FP8 input needs to have scale_inv diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 6e1f45ccde..fdfa47da8f 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -49,8 +49,8 @@ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", 
NVTE_QKV_Format::NVTE_THD_2BSHD) \ .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ - .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD) \ - .value("NVTE_QKV_Format_NOT_SET", NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET); \ + .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD) \ + .value("NVTE_QKV_Format_NOT_SET", NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index b35f2380e4..13e39b02b8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -189,7 +189,11 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou ] # always in sbhd_sbhd_sbhd shape at this point q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( - qkv_layout, query_layer, key_layer, value_layer, quantizer, + qkv_layout, + query_layer, + key_layer, + value_layer, + quantizer, keep_same_data_and_scale_inv_format=True, ) tensors = combine_and_dequantize( @@ -224,7 +228,11 @@ def backward(ctx, grad1, grad2, grad3): query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] # always in sbhd_sbhd_sbhd shape at this point dq_fp8, dk_fp8, dv_fp8, new_qkv_layout, _ = combine_and_quantize( - ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer, + ctx.qkv_layout, + query_grad, + key_grad, + value_grad, + ctx.quantizer, keep_same_data_and_scale_inv_format=True, ) tensors = combine_and_dequantize( @@ -1665,7 +1673,9 @@ def backward(ctx, d_out, *_args): d_out_fp8 = d_out elif isinstance(ctx.dO_quantizer, MXFP8Quantizer): d_out_fp8, do_scale_inv_format = mxfp8_quantize_single_tensor( - d_out, ctx.dO_quantizer, 
do_format, + d_out, + ctx.dO_quantizer, + do_format, ) else: d_out_fp8 = ctx.dO_quantizer(d_out) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index c467dfb45a..425ba7622f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -947,10 +947,8 @@ def cp_p2p_fwd_fused_attn( for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] else: - q_part, k_part, v_part, new_qkv_layout, qkv_scale_inv_format = ( - combine_and_quantize( - qkv_layout, q_part, k_part, v_part, QKV_quantizer - ) + q_part, k_part, v_part, new_qkv_layout, qkv_scale_inv_format = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer ) fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step @@ -1229,16 +1227,14 @@ def cp_p2p_bwd_fused_attn( ) ] else: - q_part, k_part, v_part, qkv_layout, qkv_scale_inv_format = ( - combine_and_quantize( - qkv_layout, - q_part, - k_part, - v_part, - QKV_quantizer_per_step, - used_in_forward=False, - used_in_backward=True, - ) + q_part, k_part, v_part, qkv_layout, qkv_scale_inv_format = combine_and_quantize( + qkv_layout, + q_part, + k_part, + v_part, + QKV_quantizer_per_step, + used_in_forward=False, + used_in_backward=True, ) if not fp8_recipe.mxfp8(): if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): @@ -1247,7 +1243,9 @@ def cp_p2p_bwd_fused_attn( else: aux_tensors.append(dout_part) dout_part, do_scale_inv_format = mxfp8_quantize_single_tensor( - dout_part, dO_quantizer_per_step, do_format, + dout_part, + dO_quantizer_per_step, + do_format, ) fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step @@ -3646,7 +3644,9 @@ def backward(ctx, dout, *_args): ) 
aux_ctx_tensors.append(dout_part) dout_part, do_scale_inv_format = mxfp8_quantize_single_tensor( - dout_part, ctx.dO_quantizer, do_format, + dout_part, + ctx.dO_quantizer, + do_format, ) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, @@ -4023,15 +4023,13 @@ def forward( if use_fused_attention: if fp8: if fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout, qkv_scale_inv_format = ( - combine_and_quantize( - qkv_layout, - q_part, - k_part, - v_part, - QKV_quantizer, - used_in_backward=is_training, - ) + q_fp8, k_fp8, v_fp8, qkv_layout, qkv_scale_inv_format = combine_and_quantize( + qkv_layout, + q_part, + k_part, + v_part, + QKV_quantizer, + used_in_backward=is_training, ) q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] else: @@ -4373,7 +4371,9 @@ def backward(ctx, dout, *_args): else: aux_ctx_tensors.append(dout) dout_part, do_scale_inv_format = mxfp8_quantize_single_tensor( - dout, ctx.dO_quantizer, do_format, + dout, + ctx.dO_quantizer, + do_format, ) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index c0a3ce7dda..029f1c3c0b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2314,9 +2314,6 @@ def print_quantizers( print(f"{label} >> {names[i]:14s}: {type_str}") - - - def mxfp8_pad_and_swizzle_scales(*fp8_tensors): """Pad and swizzle scales for MXFP8 tensors quantized with optimize_for_gemm=False. 
@@ -2346,13 +2343,9 @@ def mxfp8_pad_and_swizzle_scales(*fp8_tensors): has_cs = len(cs_list) > 0 tensor_list = list(fp8_tensors) if has_rs: - tex.multi_swizzle_scales_for_gemm_( - tensor_list, True, False, check_scale_inv_shapes=False - ) + tex.multi_swizzle_scales_for_gemm_(tensor_list, True, False, check_scale_inv_shapes=False) if has_cs: - tex.multi_swizzle_scales_for_gemm_( - tensor_list, False, True, check_scale_inv_shapes=False - ) + tex.multi_swizzle_scales_for_gemm_(tensor_list, False, True, check_scale_inv_shapes=False) for t in tensor_list: t._with_gemm_swizzled_scales = True @@ -2367,11 +2360,16 @@ def mxfp8_permute_scale_inv_to_bhsd(*tensors, src_format): if src_format in ("bhsd", "htd"): outs = [ MXFP8Tensor( - shape=t.shape, dtype=t.dtype, - rowwise_data=t._rowwise_data, rowwise_scale_inv=t._rowwise_scale_inv, - columnwise_data=t._columnwise_data, columnwise_scale_inv=t._columnwise_scale_inv, - quantizer=t._quantizer, requires_grad=False, - fp8_dtype=t._fp8_dtype, with_gemm_swizzled_scales=t._with_gemm_swizzled_scales, + shape=t.shape, + dtype=t.dtype, + rowwise_data=t._rowwise_data, + rowwise_scale_inv=t._rowwise_scale_inv, + columnwise_data=t._columnwise_data, + columnwise_scale_inv=t._columnwise_scale_inv, + quantizer=t._quantizer, + requires_grad=False, + fp8_dtype=t._fp8_dtype, + with_gemm_swizzled_scales=t._with_gemm_swizzled_scales, ) for t in tensors ] @@ -2405,16 +2403,23 @@ def mxfp8_permute_scale_inv_to_bhsd(*tensors, src_format): cs_permuted = tex.permute_to_grouped_tensor_fwd(*cs_4d_list, original_format=src_format) outs = [] - for i, (t, rp, rd, cp, cd) in enumerate(zip(tensors, rs_permuted, rs_d_scales, cs_permuted, cs_d_scales)): + for i, (t, rp, rd, cp, cd) in enumerate( + zip(tensors, rs_permuted, rs_d_scales, cs_permuted, cs_d_scales) + ): rp = rp.view(-1, rd) if rd is not None else None cp = cp.view(-1, cd) if cd is not None else None outs.append( MXFP8Tensor( - shape=t.shape, dtype=t.dtype, - rowwise_data=t._rowwise_data, 
rowwise_scale_inv=rp, - columnwise_data=t._columnwise_data, columnwise_scale_inv=cp, - quantizer=t._quantizer, requires_grad=False, - fp8_dtype=t._fp8_dtype, with_gemm_swizzled_scales=t._with_gemm_swizzled_scales, + shape=t.shape, + dtype=t.dtype, + rowwise_data=t._rowwise_data, + rowwise_scale_inv=rp, + columnwise_data=t._columnwise_data, + columnwise_scale_inv=cp, + quantizer=t._quantizer, + requires_grad=False, + fp8_dtype=t._fp8_dtype, + with_gemm_swizzled_scales=t._with_gemm_swizzled_scales, ) ) @@ -2445,26 +2450,45 @@ def mxfp8_quantize_single_tensor(tensor, quantizer, src_format): _s_dim = {"bshd": 2, "sbhd": 0, "bhsd": 2} _d_dim = {"bshd": 3, "sbhd": 3, "bhsd": 3} rowwise_scale_inv_shape = list(tensor.shape) - rowwise_scale_inv_shape[_d_dim[src_format]] = rowwise_scale_inv_shape[_d_dim[src_format]]//MXFP8_BLOCK_SCALING_SIZE + rowwise_scale_inv_shape[_d_dim[src_format]] = ( + rowwise_scale_inv_shape[_d_dim[src_format]] // MXFP8_BLOCK_SCALING_SIZE + ) columnwise_scale_inv_shape = list(tensor.shape) - columnwise_scale_inv_shape[_s_dim[src_format]] = columnwise_scale_inv_shape[_s_dim[src_format]]//MXFP8_BLOCK_SCALING_SIZE + columnwise_scale_inv_shape[_s_dim[src_format]] = ( + columnwise_scale_inv_shape[_s_dim[src_format]] // MXFP8_BLOCK_SCALING_SIZE + ) if src_format == "bhsd": - tensor = tensor.view(tensor.shape[:_s_dim[src_format]], -1) + tensor = tensor.view(tensor.shape[: _s_dim[src_format]], -1) elif src_format == "sbhd": tensor = tensor.view(tensor.shape[_s_dim[src_format]], -1) orig_optimize = quantizer.optimize_for_gemm quantizer.optimize_for_gemm = False fp8_tensor = quantizer(tensor) quantizer.optimize_for_gemm = orig_optimize - fp8_tensor._rowwise_data = fp8_tensor._rowwise_data.view(original_shape) if fp8_tensor._rowwise_data is not None else None - fp8_tensor._columnwise_data = fp8_tensor._columnwise_data.view(original_shape) if fp8_tensor._columnwise_data is not None else None - fp8_tensor._rowwise_scale_inv = 
fp8_tensor._rowwise_scale_inv.view(rowwise_scale_inv_shape) if fp8_tensor._rowwise_scale_inv is not None else None - fp8_tensor._columnwise_scale_inv = fp8_tensor._columnwise_scale_inv.view(columnwise_scale_inv_shape) if fp8_tensor._columnwise_scale_inv is not None else None + fp8_tensor._rowwise_data = ( + fp8_tensor._rowwise_data.view(original_shape) + if fp8_tensor._rowwise_data is not None + else None + ) + fp8_tensor._columnwise_data = ( + fp8_tensor._columnwise_data.view(original_shape) + if fp8_tensor._columnwise_data is not None + else None + ) + fp8_tensor._rowwise_scale_inv = ( + fp8_tensor._rowwise_scale_inv.view(rowwise_scale_inv_shape) + if fp8_tensor._rowwise_scale_inv is not None + else None + ) + fp8_tensor._columnwise_scale_inv = ( + fp8_tensor._columnwise_scale_inv.view(columnwise_scale_inv_shape) + if fp8_tensor._columnwise_scale_inv is not None + else None + ) (fp8_tensor,) = mxfp8_permute_scale_inv_to_bhsd(fp8_tensor, src_format=src_format) return fp8_tensor, "bhsd" - def combine_and_quantize( qkv_layout, q, @@ -2551,8 +2575,8 @@ def combine_and_quantize( for x in [q, k, v]: rs_shape = list(x.shape) cs_shape = list(x.shape) - rs_shape[_d_dim[qkv_format]] = rs_shape[_d_dim[qkv_format]]//MXFP8_BLOCK_SCALING_SIZE - cs_shape[_s_dim[qkv_format]] = cs_shape[_s_dim[qkv_format]]//MXFP8_BLOCK_SCALING_SIZE + rs_shape[_d_dim[qkv_format]] = rs_shape[_d_dim[qkv_format]] // MXFP8_BLOCK_SCALING_SIZE + cs_shape[_s_dim[qkv_format]] = cs_shape[_s_dim[qkv_format]] // MXFP8_BLOCK_SCALING_SIZE rowwise_scale_inv_shapes.append(rs_shape) columnwise_scale_inv_shapes.append(cs_shape) @@ -2603,23 +2627,74 @@ def combine_and_quantize( if not keep_same_data_and_scale_inv_format: qkv_quantizer.optimize_for_gemm = orig_optimize - q_fp8._rowwise_data = q_fp8._rowwise_data.view(original_shapes[0]) if q_fp8._rowwise_data is not None else None - q_fp8._columnwise_data = q_fp8._columnwise_data.view(original_shapes[0]) if q_fp8._columnwise_data is not None else None - 
k_fp8._rowwise_data = k_fp8._rowwise_data.view(original_shapes[1]) if k_fp8._rowwise_data is not None else None - k_fp8._columnwise_data = k_fp8._columnwise_data.view(original_shapes[1]) if k_fp8._columnwise_data is not None else None - v_fp8._rowwise_data = v_fp8._rowwise_data.view(original_shapes[2]) if v_fp8._rowwise_data is not None else None - v_fp8._columnwise_data = v_fp8._columnwise_data.view(original_shapes[2]) if v_fp8._columnwise_data is not None else None + q_fp8._rowwise_data = ( + q_fp8._rowwise_data.view(original_shapes[0]) + if q_fp8._rowwise_data is not None + else None + ) + q_fp8._columnwise_data = ( + q_fp8._columnwise_data.view(original_shapes[0]) + if q_fp8._columnwise_data is not None + else None + ) + k_fp8._rowwise_data = ( + k_fp8._rowwise_data.view(original_shapes[1]) + if k_fp8._rowwise_data is not None + else None + ) + k_fp8._columnwise_data = ( + k_fp8._columnwise_data.view(original_shapes[1]) + if k_fp8._columnwise_data is not None + else None + ) + v_fp8._rowwise_data = ( + v_fp8._rowwise_data.view(original_shapes[2]) + if v_fp8._rowwise_data is not None + else None + ) + v_fp8._columnwise_data = ( + v_fp8._columnwise_data.view(original_shapes[2]) + if v_fp8._columnwise_data is not None + else None + ) if not keep_same_data_and_scale_inv_format: # Permute only scale_inv to BHSD + pad + swizzle - q_fp8._rowwise_scale_inv = q_fp8._rowwise_scale_inv.view(rowwise_scale_inv_shapes[0]) if q_fp8._rowwise_scale_inv is not None else None - q_fp8._columnwise_scale_inv = q_fp8._columnwise_scale_inv.view(columnwise_scale_inv_shapes[0]) if q_fp8._columnwise_scale_inv is not None else None - k_fp8._rowwise_scale_inv = k_fp8._rowwise_scale_inv.view(rowwise_scale_inv_shapes[1]) if k_fp8._rowwise_scale_inv is not None else None - k_fp8._columnwise_scale_inv = k_fp8._columnwise_scale_inv.view(columnwise_scale_inv_shapes[1]) if k_fp8._columnwise_scale_inv is not None else None - v_fp8._rowwise_scale_inv = 
v_fp8._rowwise_scale_inv.view(rowwise_scale_inv_shapes[2]) if v_fp8._rowwise_scale_inv is not None else None - v_fp8._columnwise_scale_inv = v_fp8._columnwise_scale_inv.view(columnwise_scale_inv_shapes[2]) if v_fp8._columnwise_scale_inv is not None else None + q_fp8._rowwise_scale_inv = ( + q_fp8._rowwise_scale_inv.view(rowwise_scale_inv_shapes[0]) + if q_fp8._rowwise_scale_inv is not None + else None + ) + q_fp8._columnwise_scale_inv = ( + q_fp8._columnwise_scale_inv.view(columnwise_scale_inv_shapes[0]) + if q_fp8._columnwise_scale_inv is not None + else None + ) + k_fp8._rowwise_scale_inv = ( + k_fp8._rowwise_scale_inv.view(rowwise_scale_inv_shapes[1]) + if k_fp8._rowwise_scale_inv is not None + else None + ) + k_fp8._columnwise_scale_inv = ( + k_fp8._columnwise_scale_inv.view(columnwise_scale_inv_shapes[1]) + if k_fp8._columnwise_scale_inv is not None + else None + ) + v_fp8._rowwise_scale_inv = ( + v_fp8._rowwise_scale_inv.view(rowwise_scale_inv_shapes[2]) + if v_fp8._rowwise_scale_inv is not None + else None + ) + v_fp8._columnwise_scale_inv = ( + v_fp8._columnwise_scale_inv.view(columnwise_scale_inv_shapes[2]) + if v_fp8._columnwise_scale_inv is not None + else None + ) q_fp8, k_fp8, v_fp8 = mxfp8_permute_scale_inv_to_bhsd( - q_fp8, k_fp8, v_fp8, src_format=q_format, + q_fp8, + k_fp8, + v_fp8, + src_format=q_format, ) qkv_scale_inv_format = "bhsd" diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index c591957aa5..31b20e771c 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -115,14 +115,17 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); -std::vector permute_to_grouped_tensor_fwd( - at::Tensor query, std::optional key, std::optional value, - const std::string &original_format); -std::vector permute_to_grouped_tensor_bwd( - at::Tensor query_grad, 
std::optional key_grad, - std::optional value_grad, const std::string &original_format); - -std::vector multi_tensor_pad_last_dim(std::vector inputs, int64_t alignment); +std::vector permute_to_grouped_tensor_fwd(at::Tensor query, + std::optional key, + std::optional value, + const std::string &original_format); +std::vector permute_to_grouped_tensor_bwd(at::Tensor query_grad, + std::optional key_grad, + std::optional value_grad, + const std::string &original_format); + +std::vector multi_tensor_pad_last_dim(std::vector inputs, + int64_t alignment); at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len); at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 529fca45d6..cd7336cea0 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -653,11 +653,13 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } -std::vector permute_to_grouped_tensor_fwd( - at::Tensor query, std::optional key, std::optional value, - const std::string &original_format) { +std::vector permute_to_grouped_tensor_fwd(at::Tensor query, + std::optional key, + std::optional value, + const std::string &original_format) { NVTE_CHECK(original_format == "sbhd" || original_format == "bshd", - "Unsupported original_format \"", original_format, "\"; expected \"sbhd\" or \"bshd\"."); + "Unsupported original_format \"", original_format, + "\"; expected \"sbhd\" or \"bshd\"."); const auto original_format_enum = (original_format == "sbhd") ? 
NVTE_SBHD : NVTE_BSHD; NVTE_CHECK(query.is_cuda() && query.is_contiguous() && query.dim() == 4); NVTE_CHECK(query.scalar_type() == at::ScalarType::Half || @@ -669,9 +671,15 @@ std::vector permute_to_grouped_tensor_fwd( int64_t B, S_q, H_q, D_qk; if (original_format_enum == NVTE_SBHD) { - S_q = query.size(0); B = query.size(1); H_q = query.size(2); D_qk = query.size(3); + S_q = query.size(0); + B = query.size(1); + H_q = query.size(2); + D_qk = query.size(3); } else { - B = query.size(0); S_q = query.size(1); H_q = query.size(2); D_qk = query.size(3); + B = query.size(0); + S_q = query.size(1); + H_q = query.size(2); + D_qk = query.size(3); } at::Tensor q_out = at::empty({B, H_q, S_q, D_qk}, query.options()); @@ -679,9 +687,9 @@ std::vector permute_to_grouped_tensor_fwd( if (!has_kv) { auto te_q = makeTransformerEngineTensor(query); auto te_qo = makeTransformerEngineTensor(q_out); - nvte_permute_to_grouped_tensor_fwd(te_q.data(), te_q.data(), te_q.data(), - te_qo.data(), te_qo.data(), te_qo.data(), - original_format_enum, 1, at::cuda::getCurrentCUDAStream()); + nvte_permute_to_grouped_tensor_fwd(te_q.data(), te_q.data(), te_q.data(), te_qo.data(), + te_qo.data(), te_qo.data(), original_format_enum, 1, + at::cuda::getCurrentCUDAStream()); return {q_out}; } @@ -693,9 +701,13 @@ std::vector permute_to_grouped_tensor_fwd( int64_t S_kv, H_kv, D_v; if (original_format_enum == NVTE_SBHD) { - S_kv = k.size(0); H_kv = k.size(2); D_v = v.size(3); + S_kv = k.size(0); + H_kv = k.size(2); + D_v = v.size(3); } else { - S_kv = k.size(1); H_kv = k.size(2); D_v = v.size(3); + S_kv = k.size(1); + H_kv = k.size(2); + D_v = v.size(3); } const int64_t numel_q = B * H_q * S_q * D_qk; @@ -713,18 +725,20 @@ std::vector permute_to_grouped_tensor_fwd( auto te_ko = makeTransformerEngineTensor(k_out); auto te_vo = makeTransformerEngineTensor(v_out); - nvte_permute_to_grouped_tensor_fwd(te_q.data(), te_k.data(), te_v.data(), - te_qo.data(), te_ko.data(), te_vo.data(), - original_format_enum, 3, 
at::cuda::getCurrentCUDAStream()); + nvte_permute_to_grouped_tensor_fwd(te_q.data(), te_k.data(), te_v.data(), te_qo.data(), + te_ko.data(), te_vo.data(), original_format_enum, 3, + at::cuda::getCurrentCUDAStream()); return {q_out, k_out, v_out}; } -std::vector permute_to_grouped_tensor_bwd( - at::Tensor query_grad, std::optional key_grad, - std::optional value_grad, const std::string &original_format) { +std::vector permute_to_grouped_tensor_bwd(at::Tensor query_grad, + std::optional key_grad, + std::optional value_grad, + const std::string &original_format) { NVTE_CHECK(original_format == "sbhd" || original_format == "bshd", - "Unsupported original_format \"", original_format, "\"; expected \"sbhd\" or \"bshd\"."); + "Unsupported original_format \"", original_format, + "\"; expected \"sbhd\" or \"bshd\"."); const auto original_format_enum = (original_format == "sbhd") ? NVTE_SBHD : NVTE_BSHD; NVTE_CHECK(query_grad.is_cuda() && query_grad.is_contiguous() && query_grad.dim() == 4); NVTE_CHECK(query_grad.scalar_type() == at::ScalarType::Half || @@ -747,9 +761,9 @@ std::vector permute_to_grouped_tensor_bwd( } auto te_gq = makeTransformerEngineTensor(query_grad); auto te_q = makeTransformerEngineTensor(q); - nvte_permute_to_grouped_tensor_bwd(te_gq.data(), te_gq.data(), te_gq.data(), - te_q.data(), te_q.data(), te_q.data(), - original_format_enum, 1, at::cuda::getCurrentCUDAStream()); + nvte_permute_to_grouped_tensor_bwd(te_gq.data(), te_gq.data(), te_gq.data(), te_q.data(), + te_q.data(), te_q.data(), original_format_enum, 1, + at::cuda::getCurrentCUDAStream()); return {q}; } @@ -787,9 +801,9 @@ std::vector permute_to_grouped_tensor_bwd( auto te_k = makeTransformerEngineTensor(key); auto te_v = makeTransformerEngineTensor(value); - nvte_permute_to_grouped_tensor_bwd(te_gq.data(), te_gk.data(), te_gv.data(), - te_q.data(), te_k.data(), te_v.data(), - original_format_enum, 3, at::cuda::getCurrentCUDAStream()); + nvte_permute_to_grouped_tensor_bwd(te_gq.data(), 
te_gk.data(), te_gv.data(), te_q.data(), + te_k.data(), te_v.data(), original_format_enum, 3, + at::cuda::getCurrentCUDAStream()); return {query, key, value}; } @@ -799,7 +813,8 @@ std::vector permute_to_grouped_tensor_bwd( * All tensors share the same alignment; launches a single fused kernel. **************************************************************************************************/ -std::vector multi_tensor_pad_last_dim(std::vector inputs, int64_t alignment) { +std::vector multi_tensor_pad_last_dim(std::vector inputs, + int64_t alignment) { const auto align = static_cast(alignment); NVTE_CHECK(align > 0, "multi_tensor_pad_last_dim: alignment must be > 0."); NVTE_CHECK(!inputs.empty(), "multi_tensor_pad_last_dim: inputs must not be empty."); @@ -813,10 +828,12 @@ std::vector multi_tensor_pad_last_dim(std::vector inputs for (size_t i = 0; i < inputs.size(); ++i) { auto &input = inputs[i]; - NVTE_CHECK(input.dim() == 2, "multi_tensor_pad_last_dim: expected 2D input at index ", i, ", got ", - input.dim(), "D."); - NVTE_CHECK(input.is_cuda(), "multi_tensor_pad_last_dim: input must be a CUDA tensor at index ", i, "."); - NVTE_CHECK(input.is_contiguous(), "multi_tensor_pad_last_dim: input must be contiguous at index ", i, "."); + NVTE_CHECK(input.dim() == 2, "multi_tensor_pad_last_dim: expected 2D input at index ", i, + ", got ", input.dim(), "D."); + NVTE_CHECK(input.is_cuda(), "multi_tensor_pad_last_dim: input must be a CUDA tensor at index ", + i, "."); + NVTE_CHECK(input.is_contiguous(), + "multi_tensor_pad_last_dim: input must be contiguous at index ", i, "."); const int64_t rows = input.size(0); const int64_t in_cols = input.size(1); @@ -858,7 +875,8 @@ std::vector multi_tensor_pad_last_dim(std::vector inputs nvte_outputs[i] = te_out_wrappers[i].data(); } - nvte_multi_tensor_pad_last_dim(nvte_inputs.data(), nvte_outputs.data(), te_in_wrappers.size(), stream); + nvte_multi_tensor_pad_last_dim(nvte_inputs.data(), nvte_outputs.data(), te_in_wrappers.size(), 
+ stream); return outputs; } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index de2b75b2b9..25c62e55e3 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -406,16 +406,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); m.def("permute_to_grouped_tensor_fwd", &transformer_engine::pytorch::permute_to_grouped_tensor_fwd, - "Permute tensors from BSHD/SBHD to BHSD.", py::arg("query"), - py::arg("key") = py::none(), py::arg("value") = py::none(), - py::arg("original_format") = std::string("bshd"), + "Permute tensors from BSHD/SBHD to BHSD.", py::arg("query"), py::arg("key") = py::none(), + py::arg("value") = py::none(), py::arg("original_format") = std::string("bshd"), py::call_guard()); m.def("permute_to_grouped_tensor_bwd", &transformer_engine::pytorch::permute_to_grouped_tensor_bwd, "Permute tensors back to original format.", py::arg("query_grad"), py::arg("key_grad") = py::none(), py::arg("value_grad") = py::none(), - py::arg("original_format") = std::string("bshd"), - py::call_guard()); + py::arg("original_format") = std::string("bshd"), py::call_guard()); m.def("multi_tensor_pad_last_dim", &transformer_engine::pytorch::multi_tensor_pad_last_dim, "Pad last dimension of 2D tensors to a common alignment.", py::arg("inputs"), py::arg("alignment"), py::call_guard()); @@ -440,7 +438,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward, "Fused Apply QKV RoPE BWD", py::call_guard()); - // fused router m.def("fused_topk_with_score_function_fwd", &transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"), diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index bd4ad086a3..c232cf7e01 100644 --- 
a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -444,8 +444,7 @@ void grouped_swizzle_for_gemm(py::handle &tensor, bool rowwise, bool columnwise) } void inplace_multi_swizzle_scales_for_gemm(std::vector &tensors, bool rowwise_usage, - bool columnwise_usage, - bool check_scale_inv_shapes) { + bool columnwise_usage, bool check_scale_inv_shapes) { NVTE_CHECK(rowwise_usage != columnwise_usage, "Expect exactly one of rowwise_usage and columnwise_usage."); if (tensors.empty()) { @@ -527,8 +526,7 @@ void inplace_multi_swizzle_scales_for_gemm(std::vector &tensors, boo auto stream = at::cuda::getCurrentCUDAStream(); NVTE_SCOPED_GIL_RELEASE({ nvte_multi_tensor_swizzle_scaling_factors(inputs_raw.data(), outputs_raw.data(), - inputs_raw.size(), stream, - check_scale_inv_shapes); + inputs_raw.size(), stream, check_scale_inv_shapes); }); // Update Python tensors with the owning output tensors