Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@

def install_requirements() -> List[str]:
    """Return the pip requirements for TE/PyTorch extensions.

    Returns:
        List of PEP 508 requirement strings installed alongside the
        PyTorch extensions.

    Note:
        ``torch>=2.10`` is required because the extension registers its
        custom ops through the stable ``STABLE_TORCH_LIBRARY`` ABI
        instead of the legacy pybind11 module.
    """
    return [
        "torch>=2.10",
        "einops",
        "onnxscript",
        "onnx",
        "packaging",
        "pydantic",
        "nvdlfw-inspect",
    ]


def test_requirements() -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# See LICENSE for license information.

[build-system]
# torch>=2.10 is required for the stable torch library op registration
# (matches the runtime check in transformer_engine/pytorch/__init__.py).
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.10", "jax>=0.5.0", "flax>=0.7.1"]

# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from transformer_engine.common import load_framework_extension
from transformer_engine.pytorch.torch_version import torch_version

# Fail fast at import time if the installed torch is too old for the
# stable-ABI op registration this package relies on.
assert torch_version() >= (2, 10), f"Minimum torch version 2.10 required. Found {torch_version()}."

load_framework_extension("torch")
from transformer_engine.pytorch.module import LayerNormLinear
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from typing import Callable, Tuple, Union, Optional
import torch
from torch import nn
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode

# Bind the transformer_engine custom-op namespace once at import time.
# torch.ops is a lazy namespace, so this line succeeds even if the native
# library is not yet loaded; each op is resolved at call time, and an
# unregistered op raises a clear AttributeError at the call site.
_ops = torch.ops.transformer_engine
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Module-level op namespace binding may hide load-time errors

_ops = torch.ops.transformer_engine is evaluated at import time. torch.ops provides a lazy namespace so this line itself doesn't fail, but any subsequent attribute access (e.g. _ops.scaled_softmax_forward) will raise AttributeError if the native library hasn't been loaded yet. Accessing the namespace at call time or guarding with try/except would give a clearer error message if the shared library fails to load.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.ops is a lazy namespace — torch.ops.transformer_engine succeeds even if the library is not loaded. The actual resolution happens at call time (e.g. _ops.scaled_softmax_forward(...)), which will raise a clear AttributeError if the op is not registered. Wrapping every call in try/except would add overhead without improving the error message, and failing at import time would be worse since it would prevent importing the module entirely even for code paths that don't use these ops.



# NOTE(review): these mirror the launch configuration of the fused softmax
# CUDA kernels — confirm against the C++ kernel sources if either changes.
THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128
Expand Down Expand Up @@ -47,7 +48,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
    """ScaledUpperTriangMaskedSoftmax fwd.

    Applies the fused scaled, upper-triangular-masked softmax native op.
    Uses the stable torch.ops registration (the legacy pybind11 binding
    ``tex.scaled_upper_triang_masked_softmax_forward`` was removed).
    """
    # Wrap the scale in a tensor so it can be stashed via save_for_backward.
    scale_t = torch.tensor([scale])
    softmax_results = _ops.scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])

    # Only the softmax output (not the input) is needed for the backward pass.
    ctx.save_for_backward(softmax_results, scale_t)
    return softmax_results
Expand All @@ -56,7 +57,7 @@ def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledUpperTriangMaskedSoftmax bwd"""
softmax_results, scale_t = ctx.saved_tensors
input_grads = tex.scaled_upper_triang_masked_softmax_backward(
input_grads = _ops.scaled_upper_triang_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)

Expand All @@ -75,15 +76,15 @@ class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function):
def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
    """ScaledAlignedCausalMaskedSoftmax fwd.

    Applies the fused scaled, bottom-right-aligned causal-masked softmax
    native op via the stable torch.ops registration (the legacy pybind11
    binding ``tex.scaled_aligned_causal_masked_softmax_forward`` was removed).
    """
    # Wrap the scale in a tensor so it can be stashed via save_for_backward.
    scale_t = torch.tensor([scale])
    softmax_results = _ops.scaled_aligned_causal_masked_softmax_forward(inputs, scale_t[0])
    # Only the softmax output (not the input) is needed for the backward pass.
    ctx.save_for_backward(softmax_results, scale_t)
    return softmax_results

@staticmethod
def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledAlignedCausalMaskedSoftmax bwd"""
softmax_results, scale_t = ctx.saved_tensors
input_grads = tex.scaled_aligned_causal_masked_softmax_backward(
input_grads = _ops.scaled_aligned_causal_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)

Expand All @@ -103,7 +104,7 @@ def forward(ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torc
"""ScaledMaskedSoftmax fwd"""
scale_t = torch.tensor([scale])

softmax_results = tex.scaled_masked_softmax_forward(inputs, mask, scale_t[0])
softmax_results = _ops.scaled_masked_softmax_forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results

Expand All @@ -112,7 +113,7 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None]
"""ScaledMaskedSoftmax bwd"""
softmax_results, scale_t = ctx.saved_tensors

input_grads = tex.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
input_grads = _ops.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None


Expand All @@ -128,7 +129,7 @@ def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
"""ScaledSoftmax fwd"""
scale_t = torch.tensor([scale])

softmax_results = tex.scaled_softmax_forward(inputs, scale_t[0])
softmax_results = _ops.scaled_softmax_forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results

Expand All @@ -137,7 +138,7 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None]
"""ScaledSoftmax bwd"""
softmax_results, scale_t = ctx.saved_tensors

input_grads = tex.scaled_softmax_backward(output_grads, softmax_results, scale_t[0])
input_grads = _ops.scaled_softmax_backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None


Expand Down
26 changes: 0 additions & 26 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,32 +349,6 @@ py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask,
const float dropout_probability,
std::optional<at::Tensor> grad_input = std::nullopt);

/***************************************************************************************************
* Softmax
**************************************************************************************************/

at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor);

at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
float scale_factor);

at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor);

at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
float scale_factor);

at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor);

at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
at::Tensor softmax_results_,
float scale_factor);

at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor);

at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
at::Tensor softmax_results_,
float scale_factor);

/***************************************************************************************************
* FP8 recipe
**************************************************************************************************/
Expand Down
26 changes: 0 additions & 26 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,32 +232,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("moe_unpermute_bwd", transformer_engine::pytorch::moe_unpermute_bwd, "MOE unpermute BWD",
py::call_guard<py::gil_scoped_release>());

// Softmax functions
m.def("scaled_softmax_forward", &transformer_engine::pytorch::scaled_softmax_forward,
"Scaled Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_softmax_backward", &transformer_engine::pytorch::scaled_softmax_backward,
"Scaled Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_forward",
&transformer_engine::pytorch::scaled_masked_softmax_forward, "Scaled Masked Softmax FWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_backward",
&transformer_engine::pytorch::scaled_masked_softmax_backward, "Scaled Masked Softmax BWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_forward",
&transformer_engine::pytorch::scaled_upper_triang_masked_softmax_forward,
"Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_backward",
&transformer_engine::pytorch::scaled_upper_triang_masked_softmax_backward,
"Scaled Upper-Triangular Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_aligned_causal_masked_softmax_forward",
&transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_forward,
"Scaled Bottom-Right Corner Aligned Masked Softmax FWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_aligned_causal_masked_softmax_backward",
&transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_backward,
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
py::call_guard<py::gil_scoped_release>());

// Other granular functions
m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm", py::arg("input"),
py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
Expand Down
30 changes: 30 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/registration.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../stable_common.h"

// This file defines the transformer_engine library namespace.
// All other stable ABI files use STABLE_TORCH_LIBRARY_FRAGMENT to add schemas
// and STABLE_TORCH_LIBRARY_IMPL to add implementations.
// Declare the op schemas for the `transformer_engine` namespace using the
// stable ABI. Implementations are registered separately (see the
// STABLE_TORCH_LIBRARY_IMPL blocks referenced in the header comment above).
STABLE_TORCH_LIBRARY(transformer_engine, m) {
  // Softmax ops
  // Each forward takes the input (and an explicit mask where applicable)
  // plus a scale factor; each backward takes the upstream gradient and the
  // saved softmax results, matching the Python autograd.Function callers.
  m.def("scaled_softmax_forward(Tensor input, float scale_factor) -> Tensor");
  m.def(
      "scaled_softmax_backward(Tensor output_grad, Tensor softmax_results, float scale_factor) -> "
      "Tensor");
  m.def("scaled_masked_softmax_forward(Tensor input, Tensor mask, float scale_factor) -> Tensor");
  m.def(
      "scaled_masked_softmax_backward(Tensor output_grad, Tensor softmax_results, float "
      "scale_factor) -> Tensor");
  m.def("scaled_upper_triang_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor");
  m.def(
      "scaled_upper_triang_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, "
      "float scale_factor) -> Tensor");
  m.def("scaled_aligned_causal_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor");
  m.def(
      "scaled_aligned_causal_masked_softmax_backward(Tensor output_grad, Tensor softmax_results, "
      "float scale_factor) -> Tensor");
}
Loading
Loading