-
Notifications
You must be signed in to change notification settings - Fork 687
Port softmax ops to libtorch stable ABI #2830
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,9 +7,10 @@ | |
| from typing import Callable, Tuple, Union, Optional | ||
| import torch | ||
| from torch import nn | ||
| import transformer_engine_torch as tex | ||
| from transformer_engine.pytorch.export import is_in_onnx_export_mode | ||
|
|
||
| _ops = torch.ops.transformer_engine | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
|
|
||
| THREADS_PER_WARP = 32 | ||
| THREADS_PER_BLOCK = 128 | ||
|
|
@@ -47,7 +48,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): | |
| def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor: | ||
| """ScaledUpperTriangMaskedSoftmax fwd""" | ||
| scale_t = torch.tensor([scale]) | ||
| softmax_results = tex.scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) | ||
| softmax_results = _ops.scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) | ||
|
|
||
| ctx.save_for_backward(softmax_results, scale_t) | ||
| return softmax_results | ||
|
|
@@ -56,7 +57,7 @@ def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor: | |
| def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: | ||
| """ScaledUpperTriangMaskedSoftmax bwd""" | ||
| softmax_results, scale_t = ctx.saved_tensors | ||
| input_grads = tex.scaled_upper_triang_masked_softmax_backward( | ||
| input_grads = _ops.scaled_upper_triang_masked_softmax_backward( | ||
| output_grads, softmax_results, scale_t[0] | ||
| ) | ||
|
|
||
|
|
@@ -75,15 +76,15 @@ class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function): | |
| def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor: | ||
| """ScaledAlignedCausalMaskedSoftmax fwd""" | ||
| scale_t = torch.tensor([scale]) | ||
| softmax_results = tex.scaled_aligned_causal_masked_softmax_forward(inputs, scale_t[0]) | ||
| softmax_results = _ops.scaled_aligned_causal_masked_softmax_forward(inputs, scale_t[0]) | ||
| ctx.save_for_backward(softmax_results, scale_t) | ||
| return softmax_results | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: | ||
| """ScaledAlignedCausalMaskedSoftmax bwd""" | ||
| softmax_results, scale_t = ctx.saved_tensors | ||
| input_grads = tex.scaled_aligned_causal_masked_softmax_backward( | ||
| input_grads = _ops.scaled_aligned_causal_masked_softmax_backward( | ||
| output_grads, softmax_results, scale_t[0] | ||
| ) | ||
|
|
||
|
|
@@ -103,7 +104,7 @@ def forward(ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torc | |
| """ScaledMaskedSoftmax fwd""" | ||
| scale_t = torch.tensor([scale]) | ||
|
|
||
| softmax_results = tex.scaled_masked_softmax_forward(inputs, mask, scale_t[0]) | ||
| softmax_results = _ops.scaled_masked_softmax_forward(inputs, mask, scale_t[0]) | ||
| ctx.save_for_backward(softmax_results, scale_t) | ||
| return softmax_results | ||
|
|
||
|
|
@@ -112,7 +113,7 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] | |
| """ScaledMaskedSoftmax bwd""" | ||
| softmax_results, scale_t = ctx.saved_tensors | ||
|
|
||
| input_grads = tex.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) | ||
| input_grads = _ops.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) | ||
| return input_grads, None, None | ||
|
|
||
|
|
||
|
|
@@ -128,7 +129,7 @@ def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor: | |
| """ScaledSoftmax fwd""" | ||
| scale_t = torch.tensor([scale]) | ||
|
|
||
| softmax_results = tex.scaled_softmax_forward(inputs, scale_t[0]) | ||
| softmax_results = _ops.scaled_softmax_forward(inputs, scale_t[0]) | ||
| ctx.save_for_backward(softmax_results, scale_t) | ||
| return softmax_results | ||
|
|
||
|
|
@@ -137,7 +138,7 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] | |
| """ScaledSoftmax bwd""" | ||
| softmax_results, scale_t = ctx.saved_tensors | ||
|
|
||
| input_grads = tex.scaled_softmax_backward(output_grads, softmax_results, scale_t[0]) | ||
| input_grads = _ops.scaled_softmax_backward(output_grads, softmax_results, scale_t[0]) | ||
| return input_grads, None, None | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| #include "../stable_common.h" | ||
|
|
||
| // This file defines the transformer_engine library namespace. | ||
| // All other stable ABI files use STABLE_TORCH_LIBRARY_FRAGMENT to add schemas | ||
| // and STABLE_TORCH_LIBRARY_IMPL to add implementations. | ||
| STABLE_TORCH_LIBRARY(transformer_engine, m) { | ||
| // Softmax ops | ||
| m.def("scaled_softmax_forward(Tensor input, float scale_factor) -> Tensor"); | ||
| m.def( | ||
| "scaled_softmax_backward(Tensor output_grad, Tensor softmax_results, float scale_factor) -> " | ||
| "Tensor"); | ||
| m.def("scaled_masked_softmax_forward(Tensor input, Tensor mask, float scale_factor) -> Tensor"); | ||
| m.def( | ||
| "scaled_masked_softmax_backward(Tensor output_grad, Tensor softmax_results, float " | ||
| "scale_factor) -> Tensor"); | ||
| m.def("scaled_upper_triang_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor"); | ||
| m.def( | ||
| "scaled_upper_triang_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, " | ||
| "float scale_factor) -> Tensor"); | ||
| m.def("scaled_aligned_causal_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor"); | ||
| m.def( | ||
| "scaled_aligned_causal_masked_softmax_backward(Tensor output_grad, Tensor softmax_results, " | ||
| "float scale_factor) -> Tensor"); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Requiring (close to) the latest PyTorch is a problem; to be honest, I don't think we will be able to do that before TE 3.0.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😢 fair enough -- is there a timeline for TE 3.0?
The alternative, which might let us get away with a lower PyTorch version at the cost of a more extensive refactor, is to make sure any work matrices get allocated on the Python side and passed in as arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately, we've found that allocating in Python has non-trivial CPU overhead compared to allocating in C++.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we could validate this and get through all the ops, it would be cool if we could completely remove the pybind11 library in TE 3.0 🤷