
Port softmax ops to libtorch stable ABI #2830

Open
pstjohn wants to merge 1 commit into NVIDIA:main from pstjohn:pstjohn/stable-abi-v2

Conversation

@pstjohn
Contributor

@pstjohn pstjohn commented Apr 3, 2026

Proof of concept for migrating pybind11 functions to the PyTorch stable ABI. Ports all 8 scaled softmax functions:

  • Add stable_common.h with stable ABI helpers (tensor allocation, TensorWrapper construction, CUDA stream, dtype converters)
  • Add registration.cpp with STABLE_TORCH_LIBRARY schema definitions
  • Rewrite softmax.cpp: at::Tensor -> torch::stable::Tensor, use stable allocation and stream APIs, TORCH_BOX() for impl registration
  • Remove softmax registrations from pybind.cpp
  • Update Python callers to use torch.ops.transformer_engine_stable

The pattern is mechanical (API translation, no logic changes) and establishes the template for porting the remaining ~70 Category A functions that have no py::handle/py::object dependencies.
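As a rough illustration of the registration/dispatch split this PR introduces, here is a plain-Python sketch (no torch dependency; all names hypothetical) modeling how one module owns the schema table (the `registration.cpp` analogue of `STABLE_TORCH_LIBRARY`) while backend implementations are registered separately (the `softmax.cpp` analogue of `STABLE_TORCH_LIBRARY_IMPL`):

```python
# Hypothetical sketch: models the STABLE_TORCH_LIBRARY (schema owner) /
# STABLE_TORCH_LIBRARY_IMPL (per-backend impl) split in plain Python.
SCHEMAS = {}  # registration.cpp analogue: op name -> schema string
IMPLS = {}    # softmax.cpp analogue: (op name, backend) -> callable

def define(name, schema):
    """Declare an op's schema; exactly one module owns these definitions."""
    SCHEMAS[name] = schema

def impl(name, backend, fn):
    """Register a backend implementation for an already-declared schema."""
    assert name in SCHEMAS, f"no schema for {name}"
    IMPLS[(name, backend)] = fn

def call(name, backend, *args):
    """Dispatch: look up the implementation and invoke it."""
    return IMPLS[(name, backend)](*args)

# Schema declared once (registration.cpp analogue).
define("scaled_softmax_forward", "(Tensor input, float scale) -> Tensor")

# Backend implementation registered elsewhere (softmax.cpp analogue);
# the body here just scales, standing in for the real CUDA kernel.
impl("scaled_softmax_forward", "CUDA", lambda x, scale: [v * scale for v in x])

print(call("scaled_softmax_forward", "CUDA", [1.0, 2.0], 0.5))  # [0.5, 1.0]
```

The point of the split is that only one translation unit may own a namespace's schema definitions, while any number of files can contribute backend implementations against those schemas.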

Breakdown of currently registered functions by category: https://docs.google.com/spreadsheets/d/1p9bgmas65M03-yak3zmPImMLXfb7mp2fG_VVcJPNFqo/edit?gid=1617092873#gid=1617092873

@pstjohn pstjohn force-pushed the pstjohn/stable-abi-v2 branch from 444e518 to 7d9f212 Compare April 3, 2026 14:18
@greptile-apps
Contributor

greptile-apps bot commented Apr 3, 2026

Greptile Summary

This PR ports all 8 scaled softmax CUDA operations from pybind11 (transformer_engine_torch) to PyTorch's stable ABI (torch.ops.transformer_engine), establishing the pattern for migrating the remaining ~70 Category A functions. The minimum PyTorch version is bumped from 2.1 to 2.10 (required for the stable ABI headers) consistently across all four manifest files.

All concerns raised in prior review rounds have been resolved: dtype validation is restored via check_fp16_bf16() in stable_common.h, mask shape guards are present in scaled_masked_softmax_forward, and the square-matrix check is back in scaled_upper_triang_masked_softmax_backward. The translation is otherwise mechanical with no logic changes.

Confidence Score: 5/5

Safe to merge — all previously raised P0/P1 issues are resolved and no new defects were found.

Every critical validation guard (dtype, mask shape, square-matrix check) that was flagged in earlier review rounds is present in the current code. The API translation is faithful to the original logic, the new stable ABI registration pattern is self-consistent (single STABLE_TORCH_LIBRARY owner in registration.cpp, STABLE_TORCH_LIBRARY_IMPL dispatch in softmax.cpp), and the version bump is applied uniformly across all four manifests. No remaining P1 or P0 findings.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| `transformer_engine/pytorch/csrc/stable_common.h` | New header providing stable ABI helpers: dtype conversion, stream acquisition, TensorWrapper construction, tensor allocation, and the `check_fp16_bf16` validator restored after prior review. |
| `transformer_engine/pytorch/csrc/extensions/softmax.cpp` | Core migration of all 8 softmax ops from `at::Tensor`/pybind11 to `torch::stable::Tensor`; all previously flagged validations (dtype, mask shape, square-matrix check) are present in the current file. |
| `transformer_engine/pytorch/csrc/extensions/registration.cpp` | New file that owns the `transformer_engine` STABLE_TORCH_LIBRARY namespace and registers schemas for all 8 softmax ops; sets the convention for future fragment registrations. |
| `transformer_engine/pytorch/attention/dot_product_attention/softmax.py` | Switches all 8 op call sites from `tex.*` (pybind11) to `torch.ops.transformer_engine.*`; module-level `_ops` binding is intentionally lazy and safe. |
| `transformer_engine/pytorch/__init__.py` | Version gate bumped from `(2, 1)` to `(2, 10)`; `packaging.version.Version.release` returns a full tuple, so the tuple comparison is correct. |
| `transformer_engine/pytorch/csrc/extensions/pybind.cpp` | Removes all 8 softmax pybind11 bindings; no remaining references to the removed functions. |
| `transformer_engine/pytorch/csrc/extensions.h` | Removes the 8 softmax C++ function declarations from the pybind header; clean removal. |
| `build_tools/pytorch.py` | Bumps the torch install requirement from `>=2.1` to `>=2.10`, consistent with the stable ABI dependency. |
| `pyproject.toml` | Bumps the torch build requirement from `>=2.1` to `>=2.10` in the root build-system spec. |
| `transformer_engine/pytorch/pyproject.toml` | Bumps the pytorch sub-package build requirement to `torch>=2.10`, completing the version bump across all manifests. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant Py as Python Caller
    participant TO as torch.ops.transformer_engine
    participant Reg as registration.cpp (STABLE_TORCH_LIBRARY)
    participant Impl as softmax.cpp (STABLE_TORCH_LIBRARY_IMPL)
    participant SC as stable_common.h
    participant NVTE as nvte_scaled_*_forward/backward

    Note over Py,NVTE: Old path: tex.scaled_softmax_forward(input, scale)
    Note over Py,NVTE: New path via stable ABI

    Py->>TO: _ops.scaled_softmax_forward(input, scale_t[0])
    TO->>Reg: schema lookup scaled_softmax_forward(Tensor, float)->Tensor
    Reg->>Impl: dispatch to CUDA impl
    Impl->>SC: check_fp16_bf16(input)
    Impl->>SC: allocateStableTensor(shape, dtype, device)
    Impl->>SC: makeTransformerEngineTensor(input)
    Impl->>SC: getCurrentCUDAStreamRaw(device_index)
    Impl->>NVTE: nvte_scaled_softmax_forward(input, output, scale, stream)
    NVTE-->>Impl: kernel result in output tensor
    Impl-->>Py: softmax_results
```

Reviews (3): Last reviewed commit: "Port softmax ops to libtorch stable ABI"

```python
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode

_ops = torch.ops.transformer_engine
```
Contributor


P2 Module-level op namespace binding may hide load-time errors

_ops = torch.ops.transformer_engine is evaluated at import time. torch.ops provides a lazy namespace so this line itself doesn't fail, but any subsequent attribute access (e.g. _ops.scaled_softmax_forward) will raise AttributeError if the native library hasn't been loaded yet. Accessing the namespace at call time or guarding with try/except would give a clearer error message if the shared library fails to load.

Contributor Author


torch.ops is a lazy namespace — torch.ops.transformer_engine succeeds even if the library is not loaded. The actual resolution happens at call time (e.g. _ops.scaled_softmax_forward(...)), which will raise a clear AttributeError if the op is not registered. Wrapping every call in try/except would add overhead without improving the error message, and failing at import time would be worse since it would prevent importing the module entirely even for code paths that don't use these ops.
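The lazy-resolution behavior described above can be illustrated with a small plain-Python sketch (the class and its internals are hypothetical, not torch's actual implementation): binding the namespace succeeds unconditionally, and failure only surfaces on attribute access.

```python
# Hypothetical sketch of a lazy op namespace, illustrating why
# `_ops = torch.ops.transformer_engine` succeeds at import time but
# attribute access fails later if the native library never registered ops.
class LazyOpNamespace:
    def __init__(self, name):
        self.name = name
        self._registered = {}  # filled in when the native library loads

    def __getattr__(self, op):
        # Called only when normal attribute lookup fails, i.e. at use time.
        try:
            return self._registered[op]
        except KeyError:
            raise AttributeError(
                f"'{self.name}' has no op '{op}' (native library not loaded?)"
            )

_ops = LazyOpNamespace("transformer_engine")  # import time: always succeeds

# Resolution happens at call time; an unregistered op raises a clear error.
try:
    _ops.scaled_softmax_forward
except AttributeError as e:
    print(e)
```

This matches the trade-off discussed in the thread: the error is deferred but descriptive, and the module remains importable for code paths that never touch these ops.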

@pstjohn
Contributor Author

pstjohn commented Apr 3, 2026

/te-ci pytorch

@pstjohn pstjohn force-pushed the pstjohn/stable-abi-v2 branch from 7d9f212 to 5611150 Compare April 3, 2026 14:31
Collaborator

@timmoon10 timmoon10 left a comment


This is quite nice and straightforward. The only real issue is that this requires bumping the minimum PyTorch version.

Comment on lines +15 to +20
```cpp
// PyTorch Stable ABI headers
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
```
Collaborator


What is the minimum PyTorch version that will be required? I see that these files were first added in PyTorch 2.9.0. At the moment, I believe we only enforce 2.1+:

```python
assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
```
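A note on why the tuple-based gate handles a 2.10 bump correctly: comparing version strings lexicographically would order "2.10" before "2.9", while integer tuples compare in the intended order. A minimal sketch (the `torch_version` helper here is a simplified stand-in that mimics `packaging.version.Version(...).release` for plain numeric versions):

```python
# Simplified stand-in for the version helper: parses "2.10.0" -> (2, 10, 0),
# mimicking packaging.version.Version(...).release for simple versions.
def torch_version(version_str):
    return tuple(int(p) for p in version_str.split("."))

assert torch_version("2.10.0") >= (2, 10)  # tuple comparison: correct
assert torch_version("2.9.0") < (2, 10)    # 2.9 correctly predates 2.10
assert "2.10" < "2.9"                      # string comparison: wrong order
```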

Contributor Author

@pstjohn pstjohn Apr 3, 2026


Ah, yes. Probably 2.10, to be honest; torch::stable::empty was added there.

Proof of concept for migrating pybind11 functions to the PyTorch
stable ABI. Ports all 8 scaled softmax functions:

- Add stable_common.h with stable ABI helpers (tensor allocation,
  TensorWrapper construction, CUDA stream, dtype converters)
- Add registration.cpp with STABLE_TORCH_LIBRARY schema definitions
- Rewrite softmax.cpp: at::Tensor -> torch::stable::Tensor, use
  stable allocation and stream APIs, TORCH_BOX() for impl registration
- Remove softmax registrations from pybind.cpp
- Update Python callers to use torch.ops.transformer_engine_stable

The pattern is mechanical (API translation, no logic changes) and
establishes the template for porting the remaining ~70 Category A
functions that have no py::handle/py::object dependencies.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/stable-abi-v2 branch from 5611150 to e0bae00 Compare April 3, 2026 21:37
@pstjohn
Contributor Author

pstjohn commented Apr 3, 2026

Here's a demo PR for the next set of easy modules:
pstjohn#2

"""Install dependencies for TE/PyTorch extensions."""
return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]
return [
"torch>=2.10",
Member


Requiring (close to) the latest PyTorch is a problem; I don't think we will be able to do that before TE 3.0, to be honest.

Contributor Author

@pstjohn pstjohn Apr 3, 2026


😢 fair enough -- is there a timeline for TE 3.0?

The alternative, which might let us get away with a lower PyTorch version at the cost of a more extensive refactor, is to make sure any work matrices are allocated on the Python side and passed in as arguments.
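For illustration, the out-argument API shape being proposed can be sketched in plain Python (the function name and body are hypothetical stand-ins; the real op would launch a CUDA kernel rather than scale a list):

```python
# Hypothetical sketch of the alternative API shape: the caller allocates
# the output buffer and the op writes into it, so the C++ side never
# needs a stable tensor-allocation API.
def scaled_softmax_forward_out(inp, scale, out):
    # Writes results into the caller-provided buffer; no allocation here.
    for i, v in enumerate(inp):
        out[i] = v * scale
    return out

inp = [1.0, 2.0, 4.0]
out = [0.0] * len(inp)  # allocated by the caller (the "Python side")
scaled_softmax_forward_out(inp, 0.5, out)
print(out)  # [0.5, 1.0, 2.0]
```

As the following reply notes, the cost of this design is the CPU overhead of doing the allocation in Python on every call.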

Collaborator


Unfortunately, we've found that allocating in Python has non-trivial CPU overhead compared to allocating in C++.

Contributor Author


If we could validate this and get through all the ops, it would be cool if we could completely remove the pybind11 library in TE 3.0 🤷
