Conversation
Force-pushed `444e518` to `7d9f212`
Greptile Summary

This PR ports all 8 scaled softmax CUDA operations from pybind11 to the PyTorch stable ABI. All concerns raised in prior review rounds have been resolved: dtype validation is restored.

Confidence Score: 5/5

Safe to merge — all previously raised P0/P1 issues are resolved and no new defects were found. Every critical validation guard (dtype, mask shape, square-matrix check) that was flagged in earlier review rounds is present in the current code. The API translation is faithful to the original logic, the new stable ABI registration pattern is self-consistent (single STABLE_TORCH_LIBRARY owner in registration.cpp, STABLE_TORCH_LIBRARY_IMPL dispatch in softmax.cpp), and the version bump is applied uniformly across all four manifests. No remaining P1 or P0 findings. No files require special attention.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Py as Python Caller
    participant TO as torch.ops.transformer_engine
    participant Reg as registration.cpp (STABLE_TORCH_LIBRARY)
    participant Impl as softmax.cpp (STABLE_TORCH_LIBRARY_IMPL)
    participant SC as stable_common.h
    participant NVTE as nvte_scaled_*_forward/backward
    Note over Py,NVTE: Old path: tex.scaled_softmax_forward(input, scale)
    Note over Py,NVTE: New path via stable ABI
    Py->>TO: _ops.scaled_softmax_forward(input, scale_t[0])
    TO->>Reg: schema lookup scaled_softmax_forward(Tensor, float)->Tensor
    Reg->>Impl: dispatch to CUDA impl
    Impl->>SC: check_fp16_bf16(input)
    Impl->>SC: allocateStableTensor(shape, dtype, device)
    Impl->>SC: makeTransformerEngineTensor(input)
    Impl->>SC: getCurrentCUDAStreamRaw(device_index)
    Impl->>NVTE: nvte_scaled_softmax_forward(input, output, scale, stream)
    NVTE-->>Impl: kernel result in output tensor
    Impl-->>Py: softmax_results
```
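The `check_fp16_bf16(input)` step in the diagram is a dtype guard implemented in C++ (in stable_common.h); its source is not shown in this thread. As a purely illustrative pure-Python sketch of what such a guard does (names, the dtype set, and the error type here are assumptions, not the actual TE code):

```python
# Hypothetical sketch of a dtype guard like the check_fp16_bf16 step in the
# diagram above. The real helper is C++ in stable_common.h; this toy version
# only illustrates the validation shape: reject anything but fp16/bf16.
ALLOWED_DTYPES = {"float16", "bfloat16"}

def check_fp16_bf16(dtype_name: str) -> None:
    """Raise if the input dtype is not half or bfloat16."""
    if dtype_name not in ALLOWED_DTYPES:
        raise TypeError(
            f"scaled softmax expects fp16/bf16 input, got {dtype_name}"
        )

check_fp16_bf16("bfloat16")  # passes silently
try:
    check_fp16_bf16("float32")
except TypeError as exc:
    print(exc)
```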
Reviews (3). Last reviewed commit: "Port softmax ops to libtorch stable ABI"
```python
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode

_ops = torch.ops.transformer_engine
```
Module-level op namespace binding may hide load-time errors
_ops = torch.ops.transformer_engine is evaluated at import time. torch.ops provides a lazy namespace so this line itself doesn't fail, but any subsequent attribute access (e.g. _ops.scaled_softmax_forward) will raise AttributeError if the native library hasn't been loaded yet. Accessing the namespace at call time or guarding with try/except would give a clearer error message if the shared library fails to load.
torch.ops is a lazy namespace — torch.ops.transformer_engine succeeds even if the library is not loaded. The actual resolution happens at call time (e.g. _ops.scaled_softmax_forward(...)), which will raise a clear AttributeError if the op is not registered. Wrapping every call in try/except would add overhead without improving the error message, and failing at import time would be worse since it would prevent importing the module entirely even for code paths that don't use these ops.
/te-ci pytorch
Force-pushed `7d9f212` to `5611150`
timmoon10 left a comment:
This is quite nice and straightforward. The only real issue is that this requires bumping the minimum PyTorch version.
```cpp
// PyTorch Stable ABI headers
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
```
What is the minimum PyTorch version that will be required? I see that these files were first added in PyTorch 2.9.0. At the moment, I believe we only enforce 2.1+:
Ah, yes. Probably 2.10, to be honest: torch::stable::empty was only added in that release.
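A side note on enforcing such a minimum version: naive string comparison gets "2.10" vs "2.9" wrong, so any version check needs numeric parsing. A small self-contained sketch (illustrative only; real code would use packaging.version.Version, which also handles suffixes like "2.1+cu121" that this hand-rolled parser does not):

```python
# "2.10" compares as *less than* "2.9" lexicographically, which would make a
# torch>=2.10 check silently pass on torch 2.9 if done on raw strings.
# Parse the major/minor components to an int tuple and compare those instead.
def parse_version(v: str) -> tuple:
    return tuple(int(part) for part in v.split(".")[:2])

assert "2.10" < "2.9"                                 # string comparison is wrong
assert parse_version("2.10") > parse_version("2.9")   # numeric comparison is right

def meets_minimum(installed: str, required: str = "2.10") -> bool:
    return parse_version(installed) >= parse_version(required)
```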
Proof of concept for migrating pybind11 functions to the PyTorch stable ABI. Ports all 8 scaled softmax functions:

- Add stable_common.h with stable ABI helpers (tensor allocation, TensorWrapper construction, CUDA stream, dtype converters)
- Add registration.cpp with STABLE_TORCH_LIBRARY schema definitions
- Rewrite softmax.cpp: at::Tensor -> torch::stable::Tensor, use stable allocation and stream APIs, TORCH_BOX() for impl registration
- Remove softmax registrations from pybind.cpp
- Update Python callers to use torch.ops.transformer_engine_stable

The pattern is mechanical (API translation, no logic changes) and establishes the template for porting the remaining ~70 Category A functions that have no py::handle/py::object dependencies.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Force-pushed `5611150` to `e0bae00`
Here's a demo PR for the next set of easy modules:
```diff
 """Install dependencies for TE/PyTorch extensions."""
-return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]
+return [
+    "torch>=2.10",
```
Requiring (close to) the latest PyTorch is a problem; to be honest, I don't think we will be able to do that before TE 3.0.
😢 fair enough -- is there a timeline for TE 3.0?
The alternative, which might let us get away with a lower PyTorch version at the cost of a more extensive refactor, is to make sure any work matrices are allocated on the Python side and passed in as arguments.
Unfortunately, we've found that allocating in Python has non-trivial CPU overhead compared to allocating in C++.
If we could validate this and get through all the ops, it would be cool if we could completely remove the pybind11 library in TE 3.0 🤷
Breakdown of currently registered functions by category: https://docs.google.com/spreadsheets/d/1p9bgmas65M03-yak3zmPImMLXfb7mp2fG_VVcJPNFqo/edit?gid=1617092873#gid=1617092873