Merged
37 commits
d10fa92
Initial commit
Micky774 Oct 24, 2025
eef7dc0
Updated to build from source by default
Micky774 Oct 24, 2025
cc68ab7
Updated for V3 API
Micky774 Oct 31, 2025
4455361
Fixed build, reverted AOTriton bwd changes (now V2)
Micky774 Nov 3, 2025
2586b18
Removed alterations
Micky774 Nov 3, 2025
aa80f81
Removed lazy tensor wrapper
Micky774 Nov 3, 2025
9a91b9e
Streamlined cmakelist, other PR review feedback addressed
Micky774 Nov 4, 2025
023deb4
Removed `pad_between_seqs`
Micky774 Nov 4, 2025
6b8dbe5
Updated typing to be more explicit
Micky774 Nov 4, 2025
68303d0
Minor streamlining and formatting
Micky774 Nov 4, 2025
8181972
Initial implementation
Micky774 Nov 6, 2025
6788a16
Simplified window size func for current non-SWA support
Micky774 Nov 6, 2025
182101a
Removed accidental include
Micky774 Nov 6, 2025
19a9c0f
Merge branch 'zain/aotriton' into zain/aotriton-bwd
Micky774 Nov 6, 2025
fef6baa
Corrected bwd args
Micky774 Nov 6, 2025
3a4fab8
Updated causal window default
Micky774 Nov 10, 2025
917e3c3
Updated window values for causal
Micky774 Nov 10, 2025
ce32e3b
Merge branch 'zain/aotriton' into zain/aotriton-bwd
Micky774 Nov 10, 2025
36045c8
Corrected DQ_ACC buffer, added env var for GPU kernel building
Micky774 Nov 12, 2025
d6e46c1
Update AOTriton to 0.11.1b
Micky774 Nov 12, 2025
1349a48
Merge branch 'dev' into zain/aotriton
Micky774 Nov 24, 2025
8ed0009
Merge branch 'zain/aotriton' into zain/aotriton-bwd
Micky774 Nov 24, 2025
2bd9006
Added AOTriton commit SHA
Micky774 Nov 25, 2025
a9bef37
Merge branch 'dev' into zain/aotriton-bwd
Micky774 Nov 25, 2025
0fdff86
Moved handling of env variable to makefile
Micky774 Nov 26, 2025
3f6e054
Simplified lazy tensor implementation
Micky774 Dec 1, 2025
2246da4
Merge branch 'dev' into zain/aotriton-bwd
Micky774 Dec 10, 2025
2a17f7b
Merge branch 'dev' into zain/aotriton-bwd
Micky774 Jan 29, 2026
1a267cd
Update AOTriton version
Micky774 Jan 30, 2026
51da203
Improved tests
Micky774 Feb 4, 2026
945c8b2
Fix dq_acc stride. AITER ASM expects BHS.
xinyazhang Feb 5, 2026
e478e1a
Merge branch 'dev' into zain/aotriton-bwd
Micky774 Feb 5, 2026
872bc12
Revert unnecessary changes
Micky774 Feb 10, 2026
4478af5
Updated copyright and IS_HIP_EXTENSION guard
Micky774 Feb 11, 2026
8d06ef1
Update test and cmakelist
Micky774 Feb 12, 2026
0654925
Updated tests
Micky774 Feb 17, 2026
f2798b7
Merge branch 'dev' into zain/aotriton-bwd
Micky774 Feb 18, 2026
74 changes: 63 additions & 11 deletions tests/pytorch/attention/test_attention.py
@@ -1281,17 +1281,60 @@ def test_transformer_layer(

# FusedAttention backend
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
if len(fused_attn_backends) == 1 or not IS_HIP_EXTENSION:
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
elif len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_CK"] = "0"
os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
os.environ["NVTE_FUSED_ATTN_CK"] = "1"
os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)

os.environ["NVTE_CK_USES_FWD_V3"] = "0"
os.environ["NVTE_CK_USES_BWD_V3"] = "0"
fused_attn_fwd_2, fused_attn_bwd_2 = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)


# FlashAttention backend
if flash_attn_supported:
@@ -1320,6 +1363,15 @@ def test_transformer_layer(
logging.info("[test_transformer_layer]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
if IS_HIP_EXTENSION and fused_attn_supported and len(fused_attn_backends) == 2:
logging.info("[test_transformer_layer]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
logging.info("[test_transformer_layer]: fused attn backend 0 vs 2")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
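
Note that the test above flips NVTE_FUSED_ATTN_CK / NVTE_FUSED_ATTN_AOTRITON (and the NVTE_CK_USES_*_V3 toggles) in place and leaves them set for the rest of the process. A minimal sketch of a restore-on-exit wrapper for the same pattern, assuming the variables are re-read each time the backend is selected; `_backend_env` is a hypothetical helper, not part of this test suite:

```python
import os
from contextlib import contextmanager

@contextmanager
def _backend_env(**overrides):
    """Temporarily set backend-selection env vars, restoring them on exit."""
    saved = {k: os.environ.get(k) for k in overrides}
    os.environ.update({k: str(v) for k, v in overrides.items()})
    try:
        yield
    finally:
        for k, v in saved.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v

# Hypothetical usage mirroring the test: one run per backend.
with _backend_env(NVTE_FUSED_ATTN_CK="0", NVTE_FUSED_ATTN_AOTRITON="1"):
    pass  # _run_transformer_layer(...) with the AOTriton backend
with _backend_env(NVTE_FUSED_ATTN_CK="1", NVTE_FUSED_ATTN_AOTRITON="0"):
    pass  # _run_transformer_layer(...) with the CK backend
```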


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
1 change: 0 additions & 1 deletion transformer_engine/common/CMakeLists.txt
@@ -8,7 +8,6 @@ cmake_minimum_required(VERSION 3.21)

option(USE_ROCM "Use ROCm" ON)
option(USE_FUSED_ATTN_AOTRITON "Use aotriton backend" ON)
option(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS "Build AOTriton GPU kernels" OFF)
option(USE_FUSED_ATTN_CK "Use ck backend" ON)
set(USE_CUDA OFF)

25 changes: 11 additions & 14 deletions transformer_engine/common/aotriton/CMakeLists.txt
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

cmake_minimum_required(VERSION 3.21)
@@ -8,19 +8,16 @@ project(aotriton LANGUAGES CXX)
# The AOTriton C++ runtime will be built from {TE}/3rdparty/aotriton
# Hence there is no need to add multiple ROCM version here

if(DEFINED ENV{AOTRITON_PATH})
set(AOTRITON_PATH $ENV{AOTRITON_PATH})
endif()

set(__AOTRITON_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton")
set(__AOTRITON_SUFFIX "_TEprivate")

if(NOT DEFINED AOTRITON_PATH)
# If AOTRITON_PATH is not provided, we proceed to build the runtime
# ourselves and either build or download the GPU kernels

Collaborator: Your original changes used an env variable to control it. On the other hand, this feature seems unused.
if(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
set(AOTRITON_NOIMAGE_MODE OFF)
else()
set(AOTRITON_NOIMAGE_MODE ON)
endif()

set(__AOTRITON_VER "0.11.1b")
set(__AOTRITON_VER "0.11.2b")
set(__AOTRITON_IMAGE_LIST
"amd-gfx942"
"amd-gfx950"
@@ -66,8 +63,7 @@ if(NOT DEFINED AOTRITON_PATH)

# Build the AOTriton runtime from source with custom suffix to avoid
# potential conflict with libaotriton as provided by PyTorch
function(aotriton_build_from_source noimage)
message(STATUS "No-image mode: ${noimage}.")
function(aotriton_build_from_source)
get_git_commit(${TE}/3rdparty/aotriton AOTRITON_SHA)
ExternalProject_Add(aotriton_external
LIST_SEPARATOR ","
@@ -78,7 +74,7 @@ if(NOT DEFINED AOTRITON_PATH)
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
-DAOTRITON_NOIMAGE_MODE=${noimage}
-DAOTRITON_NOIMAGE_MODE=ON
-DTE_AOTRITON_COMMIT_SHA1=${AOTRITON_SHA}
-DCMAKE_PROJECT_INCLUDE=${CMAKE_CURRENT_LIST_DIR}/aotriton_custom.cmake
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"
@@ -97,7 +93,7 @@ if(NOT DEFINED AOTRITON_PATH)
add_library(aotriton INTERFACE)
message(STATUS "Building AOTriton from source.")
string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}")
aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE})
aotriton_build_from_source()

# Download GPU kernels if needed
if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
@@ -121,8 +117,9 @@ if(NOT DEFINED AOTRITON_PATH)
else()
# Use aotriton built during initial TE building/installation
# When only need rebuild TE library itself
message(STATUS "Using existing AOTriton lib at $ENV{AOTRITON_PATH}")
unset(AOTRITON_LIB CACHE)
find_library(AOTRITON_LIB NAMES aotriton aotriton${__AOTRITON_SUFFIX}_v2 PATHS ${AOTRITON_PATH}/lib REQUIRED NO_DEFAULT_PATH)
find_library(AOTRITON_LIB NAMES aotriton aotriton${__AOTRITON_SUFFIX}_v2 PATHS ${AOTRITON_PATH} REQUIRED NO_DEFAULT_PATH)
add_library( aotriton SHARED IMPORTED )
set_target_properties( aotriton PROPERTIES IMPORTED_LOCATION ${AOTRITON_LIB} )
target_include_directories(aotriton INTERFACE ${AOTRITON_PATH}/include)
3 changes: 3 additions & 0 deletions transformer_engine/common/fused_attn_rocm/fused_attn.cpp
@@ -492,6 +492,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
fused_attn_aotriton_bwd_qkvpacked(
b, h, max_seqlen, d,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO, output_S,
output_dQKV,
@@ -678,6 +679,7 @@ void nvte_fused_attn_bwd_kvpacked(
fused_attn_aotriton_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_O, input_dO,
output_S,
@@ -858,6 +860,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_aotriton_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_O, input_dO,
output_S,
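
These hunks thread `window_size_left`/`window_size_right` through to the AOTriton backward calls. As an illustration of the parameter convention only (a sketch assuming the usual FlashAttention-style convention, where -1 means unbounded on that side and causal masking is `(-1, 0)`; this sketch is not taken from the PR):

```python
import torch

def sliding_window_mask(seqlen_q, seqlen_kv, window_size_left, window_size_right):
    """Boolean mask (True = attend) for the (left, right) window convention.

    Assumed convention: -1 means unbounded on that side; causal masking is
    commonly expressed as (-1, 0), i.e. unlimited lookback, no lookahead.
    """
    q_idx = torch.arange(seqlen_q).unsqueeze(1)    # [seqlen_q, 1]
    kv_idx = torch.arange(seqlen_kv).unsqueeze(0)  # [1, seqlen_kv]
    # Align the last query with the last key when seqlen_q != seqlen_kv.
    offset = seqlen_kv - seqlen_q
    rel = kv_idx - (q_idx + offset)  # how far the key sits right of the query
    mask = torch.ones(seqlen_q, seqlen_kv, dtype=torch.bool)
    if window_size_left >= 0:
        mask &= rel >= -window_size_left
    if window_size_right >= 0:
        mask &= rel <= window_size_right
    return mask

# Causal under this convention: unlimited left window, zero right window.
print(sliding_window_mask(4, 4, -1, 0))
```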