From e0bae007c5eaa37faf45bcb50f04bb27ba9f3f7f Mon Sep 17 00:00:00 2001
From: "Peter St. John"
Date: Fri, 3 Apr 2026 06:11:35 -0700
Subject: [PATCH] Port softmax ops to libtorch stable ABI

Proof of concept for migrating pybind11 functions to the PyTorch stable ABI.
Ports all 8 scaled softmax functions:

- Add stable_common.h with stable ABI helpers (tensor allocation,
  TensorWrapper construction, CUDA stream, dtype converters)
- Add registration.cpp with STABLE_TORCH_LIBRARY schema definitions
- Rewrite softmax.cpp: at::Tensor -> torch::stable::Tensor, use stable
  allocation and stream APIs, TORCH_BOX() for impl registration
- Remove softmax registrations from pybind.cpp
- Update Python callers to use torch.ops.transformer_engine

The pattern is mechanical (API translation, no logic changes) and establishes
the template for porting the remaining ~70 Category A functions that have no
py::handle/py::object dependencies.

Signed-off-by: Peter St. John
---
 build_tools/pytorch.py                        |  10 +-
 pyproject.toml                                |   2 +-
 transformer_engine/pytorch/__init__.py        |   2 +-
 .../dot_product_attention/softmax.py          |  19 +-
 transformer_engine/pytorch/csrc/extensions.h  |  26 --
 .../pytorch/csrc/extensions/pybind.cpp        |  26 --
 .../pytorch/csrc/extensions/registration.cpp  |  30 ++
 .../pytorch/csrc/extensions/softmax.cpp       | 290 +++++++++---------
 .../pytorch/csrc/stable_common.h              | 135 ++++++++
 transformer_engine/pytorch/pyproject.toml     |   2 +-
 10 files changed, 331 insertions(+), 211 deletions(-)
 create mode 100644 transformer_engine/pytorch/csrc/extensions/registration.cpp
 create mode 100644 transformer_engine/pytorch/csrc/stable_common.h

diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py
index fdfdee9b1c..1c41907153 100644
--- a/build_tools/pytorch.py
+++ b/build_tools/pytorch.py
@@ -14,7 +14,15 @@

 def install_requirements() -> List[str]:
     """Install dependencies for TE/PyTorch extensions."""
-    return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]
+    return [
+        "torch>=2.10",
+        "einops",
+        "onnxscript",
+        "onnx",
+        "packaging",
+        "pydantic",
+        "nvdlfw-inspect",
+    ]


 def test_requirements() -> List[str]:
diff --git a/pyproject.toml b/pyproject.toml
index 4a8fded172..477b0f93ee 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@
 # See LICENSE for license information.

 [build-system]
-requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
+requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.10", "jax>=0.5.0", "flax>=0.7.1"]
 # Use legacy backend to import local packages in setup.py
 build-backend = "setuptools.build_meta:__legacy__"
diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py
index bbc1d7fab6..dcc90e504e 100644
--- a/transformer_engine/pytorch/__init__.py
+++ b/transformer_engine/pytorch/__init__.py
@@ -13,7 +13,7 @@
 from transformer_engine.common import load_framework_extension
 from transformer_engine.pytorch.torch_version import torch_version

-assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
+assert torch_version() >= (2, 10), f"Minimum torch version 2.10 required. Found {torch_version()}."

 load_framework_extension("torch")

 from transformer_engine.pytorch.module import LayerNormLinear
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/softmax.py b/transformer_engine/pytorch/attention/dot_product_attention/softmax.py
index 74d9583ce5..bb0a08d0e7 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/softmax.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/softmax.py
@@ -7,9 +7,10 @@
 from typing import Callable, Tuple, Union, Optional
 import torch
 from torch import nn
-import transformer_engine_torch as tex

 from transformer_engine.pytorch.export import is_in_onnx_export_mode

+_ops = torch.ops.transformer_engine
+
 THREADS_PER_WARP = 32
 THREADS_PER_BLOCK = 128
@@ -47,7 +48,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
         """ScaledUpperTriangMaskedSoftmax fwd"""
         scale_t = torch.tensor([scale])
-        softmax_results = tex.scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
+        softmax_results = _ops.scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])

         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
@@ -56,7 +57,7 @@ def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
     def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         """ScaledUpperTriangMaskedSoftmax bwd"""
         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = tex.scaled_upper_triang_masked_softmax_backward(
+        input_grads = _ops.scaled_upper_triang_masked_softmax_backward(
             output_grads, softmax_results, scale_t[0]
         )
@@ -75,7 +76,7 @@ class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function):
     def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
         """ScaledAlignedCausalMaskedSoftmax fwd"""
         scale_t = torch.tensor([scale])
-        softmax_results = tex.scaled_aligned_causal_masked_softmax_forward(inputs, scale_t[0])
+        softmax_results = _ops.scaled_aligned_causal_masked_softmax_forward(inputs, scale_t[0])

         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
@@ -83,7 +84,7 @@ def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
     def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         """ScaledAlignedCausalMaskedSoftmax bwd"""
         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = tex.scaled_aligned_causal_masked_softmax_backward(
+        input_grads = _ops.scaled_aligned_causal_masked_softmax_backward(
             output_grads, softmax_results, scale_t[0]
         )
@@ -103,7 +104,7 @@ def forward(ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torc
         """ScaledMaskedSoftmax fwd"""
         scale_t = torch.tensor([scale])

-        softmax_results = tex.scaled_masked_softmax_forward(inputs, mask, scale_t[0])
+        softmax_results = _ops.scaled_masked_softmax_forward(inputs, mask, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
@@ -112,7 +113,7 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None]
         """ScaledMaskedSoftmax bwd"""
         softmax_results, scale_t = ctx.saved_tensors

-        input_grads = tex.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
+        input_grads = _ops.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])

         return input_grads, None, None
@@ -128,7 +129,7 @@ def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
         """ScaledSoftmax fwd"""
         scale_t = torch.tensor([scale])

-        softmax_results = tex.scaled_softmax_forward(inputs, scale_t[0])
+        softmax_results = _ops.scaled_softmax_forward(inputs, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
@@ -137,7 +138,7 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None]
         """ScaledSoftmax bwd"""
         softmax_results, scale_t = ctx.saved_tensors

-        input_grads = tex.scaled_softmax_backward(output_grads, softmax_results, scale_t[0])
+        input_grads = _ops.scaled_softmax_backward(output_grads, softmax_results, scale_t[0])

         return input_grads, None, None
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 1c5116a8da..bac70904db 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -349,32 +349,6 @@ py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask,
                        const float dropout_probability,
                        std::optional<at::Tensor> grad_input = std::nullopt);

-/***************************************************************************************************
- * Softmax
- **************************************************************************************************/
-
-at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor);
-
-at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
-                                   float scale_factor);
-
-at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor);
-
-at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
-                                          float scale_factor);
-
-at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor);
-
-at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
-                                                       at::Tensor softmax_results_,
-                                                       float scale_factor);
-
-at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor);
-
-at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
-                                                         at::Tensor softmax_results_,
-                                                         float scale_factor);
-
 /***************************************************************************************************
  * FP8 recipe
  **************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index c590a3c9e2..4360496d87 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -232,32 +232,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("moe_unpermute_bwd", transformer_engine::pytorch::moe_unpermute_bwd, "MOE unpermute BWD",
         py::call_guard<py::gil_scoped_release>());

-  // Softmax functions
-  m.def("scaled_softmax_forward", &transformer_engine::pytorch::scaled_softmax_forward,
-        "Scaled Softmax FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_softmax_backward", &transformer_engine::pytorch::scaled_softmax_backward,
-        "Scaled Softmax BWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_masked_softmax_forward",
-        &transformer_engine::pytorch::scaled_masked_softmax_forward, "Scaled Masked Softmax FWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_masked_softmax_backward",
-        &transformer_engine::pytorch::scaled_masked_softmax_backward, "Scaled Masked Softmax BWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_upper_triang_masked_softmax_forward",
-        &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_forward,
-        "Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_backward", - &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_backward, - "Scaled Upper-Triangular Masked Softmax BWD", py::call_guard()); - m.def("scaled_aligned_causal_masked_softmax_forward", - &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_forward, - "Scaled Bottom-Right Corner Aligned Masked Softmax FWD", - py::call_guard()); - m.def("scaled_aligned_causal_masked_softmax_backward", - &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_backward, - "Scaled Bottom-Right Corner Aligned Masked Softmax BWD", - py::call_guard()); - // Other granular functions m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), diff --git a/transformer_engine/pytorch/csrc/extensions/registration.cpp b/transformer_engine/pytorch/csrc/extensions/registration.cpp new file mode 100644 index 0000000000..5b92979493 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/registration.cpp @@ -0,0 +1,30 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../stable_common.h" + +// This file defines the transformer_engine library namespace. +// All other stable ABI files use STABLE_TORCH_LIBRARY_FRAGMENT to add schemas +// and STABLE_TORCH_LIBRARY_IMPL to add implementations. +STABLE_TORCH_LIBRARY(transformer_engine, m) { + // Softmax ops + m.def("scaled_softmax_forward(Tensor input, float scale_factor) -> Tensor"); + m.def( + "scaled_softmax_backward(Tensor output_grad, Tensor softmax_results, float scale_factor) -> " + "Tensor"); + m.def("scaled_masked_softmax_forward(Tensor input, Tensor mask, float scale_factor) -> Tensor"); + m.def( + "scaled_masked_softmax_backward(Tensor output_grad, Tensor softmax_results, float " + "scale_factor) -> Tensor"); + m.def("scaled_upper_triang_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor"); + m.def( + "scaled_upper_triang_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, " + "float scale_factor) -> Tensor"); + m.def("scaled_aligned_causal_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor"); + m.def( + "scaled_aligned_causal_masked_softmax_backward(Tensor output_grad, Tensor softmax_results, " + "float scale_factor) -> Tensor"); +} diff --git a/transformer_engine/pytorch/csrc/extensions/softmax.cpp b/transformer_engine/pytorch/csrc/extensions/softmax.cpp index 3bb6a5e7b3..dd6ee5fe32 100644 --- a/transformer_engine/pytorch/csrc/extensions/softmax.cpp +++ b/transformer_engine/pytorch/csrc/extensions/softmax.cpp @@ -4,234 +4,232 @@ * See LICENSE for license information. 
 ************************************************************************/

-#include "../extensions.h"
+#include <transformer_engine/softmax.h>

-namespace transformer_engine::pytorch {
+#include "../stable_common.h"

-at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) {
-  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
-  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
-                 (input.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
+namespace transformer_engine::pytorch::stable {

-  const int batches = input.size(0);
-  const int attn_heads = input.size(1);
-  const int query_seq_len = input.size(2);
-  const int key_seq_len = input.size(3);
+using Tensor = torch::stable::Tensor;

-  AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
-  AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
-  AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1");
+Tensor scaled_softmax_forward(Tensor input, double scale_factor) {
+  NVTE_CHECK(input.dim() == 4, "expected 4D tensor");
+  check_fp16_bf16(input, "scaled_softmax_forward");

-  // Output
-  auto act_options = input.options().requires_grad(false);
-  auto softmax_results =
-      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
+  auto sizes = input.sizes();
+  const int64_t batches = sizes[0];
+  const int64_t attn_heads = sizes[1];
+  const int64_t query_seq_len = sizes[2];
+  const int64_t key_seq_len = sizes[3];
+
+  NVTE_CHECK(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
+  NVTE_CHECK(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
+  NVTE_CHECK(query_seq_len > 1, "Query sequence length must be greater than 1");
+
+  auto softmax_results = allocateStableTensor({batches, attn_heads, query_seq_len, key_seq_len},
+                                              input.scalar_type(), input.get_device_index());

   auto input_cu = makeTransformerEngineTensor(input);
   auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

-  nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor,
-                              at::cuda::getCurrentCUDAStream());
+  nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(),
+                              static_cast<float>(scale_factor),
+                              getCurrentCUDAStreamRaw(input.get_device_index()));

   return softmax_results;
 }

-at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
-                                   float scale_factor) {
-  auto output_grads = output_grad_.contiguous();
-  auto softmax_results = softmax_results_.contiguous();
-
-  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
-  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
+Tensor scaled_softmax_backward(Tensor output_grad, Tensor softmax_results, double scale_factor) {
+  auto output_grads = torch::stable::contiguous(output_grad);
+  softmax_results = torch::stable::contiguous(softmax_results);

-  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
-                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
-  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
-                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
+  NVTE_CHECK(output_grads.dim() == 4, "expected 4D tensor");
+  NVTE_CHECK(softmax_results.dim() == 4, "expected 4D tensor");
+  check_fp16_bf16(output_grads, "scaled_softmax_backward");
+  check_fp16_bf16(softmax_results, "scaled_softmax_backward");

   auto output_grads_cu = makeTransformerEngineTensor(output_grads);
   auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

   // Produce gradients in place.
   nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(),
-                               output_grads_cu.data(), scale_factor,
-                               at::cuda::getCurrentCUDAStream());
+                               output_grads_cu.data(), static_cast<float>(scale_factor),
+                               getCurrentCUDAStreamRaw(output_grads.get_device_index()));

   return output_grads;
 }

-at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) {
-  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
-  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
-                 (input.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
-  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
-  if (!input.is_contiguous()) input = input.contiguous();
-  if (!mask.is_contiguous()) mask = mask.contiguous();
-
-  const int batches = input.size(0);
-  const int pad_batches = mask.size(0);
-  const int attn_heads = input.size(1);
-  const int query_seq_len = input.size(2);
-  const int key_seq_len = input.size(3);
-
-  AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
-  AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
-  AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1");
-  TORCH_CHECK(pad_batches == 1 || pad_batches == batches);
-  TORCH_CHECK(mask.size(1) == 1);
-  TORCH_CHECK(mask.size(2) == query_seq_len);
-  TORCH_CHECK(mask.size(3) == key_seq_len);
-
-  auto act_options = input.options().requires_grad(false);
-  auto softmax_results =
-      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
+Tensor scaled_masked_softmax_forward(Tensor input, Tensor mask, double scale_factor) {
+  NVTE_CHECK(input.dim() == 4, "expected 4D tensor");
+  NVTE_CHECK(mask.dim() == 4, "expected 4D tensor");
+  check_fp16_bf16(input, "scaled_masked_softmax_forward");
+  input = torch::stable::contiguous(input);
+  mask = torch::stable::contiguous(mask);
+
+  auto sizes = input.sizes();
+  const int64_t batches = sizes[0];
+  const int64_t attn_heads = sizes[1];
+  const int64_t query_seq_len = sizes[2];
+  const int64_t key_seq_len = sizes[3];
+
+  auto mask_sizes = mask.sizes();
+  const int64_t pad_batches = mask_sizes[0];
+  NVTE_CHECK(pad_batches == 1 || pad_batches == batches,
+             "Mask batch dim must be 1 or match input batch dim");
+  NVTE_CHECK(mask_sizes[1] == 1, "Mask second dim must be 1");
+  NVTE_CHECK(mask_sizes[2] == query_seq_len, "Mask query dim must match input");
+  NVTE_CHECK(mask_sizes[3] == key_seq_len, "Mask key dim must match input");
+
+  NVTE_CHECK(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
+  NVTE_CHECK(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
+  NVTE_CHECK(query_seq_len > 1, "Query sequence length must be greater than 1");
+
+  auto softmax_results = allocateStableTensor({batches, attn_heads, query_seq_len, key_seq_len},
+                                              input.scalar_type(), input.get_device_index());

   auto input_cu = makeTransformerEngineTensor(input);
   auto mask_cu = makeTransformerEngineTensor(mask);
   auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

   nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(), softmax_results_cu.data(),
-                                     scale_factor, at::cuda::getCurrentCUDAStream());
+                                     static_cast<float>(scale_factor),
+                                     getCurrentCUDAStreamRaw(input.get_device_index()));

   return softmax_results;
 }

-at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
-                                          float scale_factor) {
-  auto output_grads = output_grad_.contiguous();
-  auto softmax_results = softmax_results_.contiguous();
+Tensor scaled_masked_softmax_backward(Tensor output_grad, Tensor softmax_results,
+                                      double scale_factor) {
+  auto output_grads = torch::stable::contiguous(output_grad);
+  softmax_results = torch::stable::contiguous(softmax_results);

-  AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
-  AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
-
-  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
-                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
-  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
-                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
+  NVTE_CHECK(output_grads.dim() == 4, "expected 4D tensor");
+  NVTE_CHECK(softmax_results.dim() == 4, "expected 4D tensor");
+  check_fp16_bf16(output_grads, "scaled_masked_softmax_backward");
+  check_fp16_bf16(softmax_results, "scaled_masked_softmax_backward");

   auto output_grads_cu = makeTransformerEngineTensor(output_grads);
   auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

-  // Produce gradients in place.
   nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(),
-                               output_grads_cu.data(), scale_factor,
-                               at::cuda::getCurrentCUDAStream());
+                               output_grads_cu.data(), static_cast<float>(scale_factor),
+                               getCurrentCUDAStreamRaw(output_grads.get_device_index()));

   return output_grads;
 }

-at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) {
-  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
-  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
-                 (input.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
+Tensor scaled_upper_triang_masked_softmax_forward(Tensor input, double scale_factor) {
+  NVTE_CHECK(input.dim() == 3, "expected 3D tensor");
+  check_fp16_bf16(input, "scaled_upper_triang_masked_softmax_forward");

-  const int attn_batches = input.size(0);
-  const int seq_len = input.size(1);
-  AT_ASSERTM(seq_len <= 16384, "Sequence length must be 16384 or less");
+  auto sizes = input.sizes();
+  const int64_t attn_batches = sizes[0];
+  const int64_t seq_len = sizes[1];
+  NVTE_CHECK(seq_len <= 16384, "Sequence length must be 16384 or less");

-  // Output
-  auto act_options = input.options().requires_grad(false);
-  auto softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options);
+  auto softmax_results = allocateStableTensor({attn_batches, seq_len, seq_len}, input.scalar_type(),
+                                              input.get_device_index());

   auto input_cu = makeTransformerEngineTensor(input);
   auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

-  nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(),
-                                                  scale_factor, at::cuda::getCurrentCUDAStream());
+  nvte_scaled_upper_triang_masked_softmax_forward(
+      input_cu.data(), softmax_results_cu.data(), static_cast<float>(scale_factor),
+      getCurrentCUDAStreamRaw(input.get_device_index()));

   return softmax_results;
 }

-at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
-                                                       at::Tensor softmax_results_,
-                                                       float scale_factor) {
-  auto output_grads = output_grads_.contiguous();
-  auto softmax_results = softmax_results_.contiguous();
-
-  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
-  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
+Tensor scaled_upper_triang_masked_softmax_backward(Tensor output_grads_, Tensor softmax_results_,
+                                                   double scale_factor) {
+  auto output_grads = torch::stable::contiguous(output_grads_);
+  auto softmax_results = torch::stable::contiguous(softmax_results_);

-  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
-                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
-  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
-                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
-
-  TORCH_CHECK(output_grads.size(1) == output_grads.size(2));
+  NVTE_CHECK(output_grads.dim() == 3, "expected 3D tensor");
+  NVTE_CHECK(softmax_results.dim() == 3, "expected 3D tensor");
+  check_fp16_bf16(output_grads, "scaled_upper_triang_masked_softmax_backward");
+  check_fp16_bf16(softmax_results, "scaled_upper_triang_masked_softmax_backward");
+  NVTE_CHECK(output_grads.sizes()[1] == output_grads.sizes()[2],
+             "Upper triangular softmax requires square attention matrix");

   auto output_grads_cu = makeTransformerEngineTensor(output_grads);
   auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

-  // Produce gradients in place.
   nvte_scaled_upper_triang_masked_softmax_backward(
-      output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor,
-      at::cuda::getCurrentCUDAStream());
+      output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(),
+      static_cast<float>(scale_factor), getCurrentCUDAStreamRaw(output_grads.get_device_index()));

   return output_grads;
 }

-at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) {
-  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
-  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
-                 (input.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
+Tensor scaled_aligned_causal_masked_softmax_forward(Tensor input, double scale_factor) {
+  NVTE_CHECK(input.dim() == 4, "expected 4D tensor");
+  check_fp16_bf16(input, "scaled_aligned_causal_masked_softmax_forward");

-  const int batches = input.size(0);
-  const int attn_heads = input.size(1);
-  const int query_seq_len = input.size(2);
-  const int key_seq_len = input.size(3);
+  auto sizes = input.sizes();
+  const int64_t batches = sizes[0];
+  const int64_t attn_heads = sizes[1];
+  const int64_t query_seq_len = sizes[2];
+  const int64_t key_seq_len = sizes[3];

-  AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
-  AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
-  AT_ASSERTM(query_seq_len >= 1, "Query sequence length must be greater or equal to 1");
+  NVTE_CHECK(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
+  NVTE_CHECK(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
+  NVTE_CHECK(query_seq_len >= 1, "Query sequence length must be greater or equal to 1");

-  // Output
-  auto act_options = input.options().requires_grad(false);
-  auto softmax_results =
-      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
+  auto softmax_results = allocateStableTensor({batches, attn_heads, query_seq_len, key_seq_len},
+                                              input.scalar_type(), input.get_device_index());

   auto input_cu = makeTransformerEngineTensor(input);
   auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

-  nvte_scaled_aligned_causal_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(),
-                                                    scale_factor,
-                                                    at::cuda::getCurrentCUDAStream());
+  nvte_scaled_aligned_causal_masked_softmax_forward(
+      input_cu.data(), softmax_results_cu.data(), static_cast<float>(scale_factor),
+      getCurrentCUDAStreamRaw(input.get_device_index()));

   return softmax_results;
 }

-at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_,
-                                                         at::Tensor softmax_results_,
-                                                         float scale_factor) {
-  auto output_grads = output_grad_.contiguous();
-  auto softmax_results = softmax_results_.contiguous();
-
-  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
-  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
+Tensor scaled_aligned_causal_masked_softmax_backward(Tensor output_grad, Tensor softmax_results_,
+                                                     double scale_factor) {
+  auto output_grads = torch::stable::contiguous(output_grad);
+  auto softmax_results = torch::stable::contiguous(softmax_results_);

-  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
-                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
-  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
-                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
-             "Only fp16 and bf16 are supported");
+  NVTE_CHECK(output_grads.dim() == 4, "expected 4D tensor");
+  NVTE_CHECK(softmax_results.dim() == 4, "expected 4D tensor");
+  check_fp16_bf16(output_grads, "scaled_aligned_causal_masked_softmax_backward");
+  check_fp16_bf16(softmax_results, "scaled_aligned_causal_masked_softmax_backward");

   auto output_grads_cu = makeTransformerEngineTensor(output_grads);
   auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

-  // Produce gradients in place.
   nvte_scaled_aligned_causal_masked_softmax_backward(
-      output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor,
-      at::cuda::getCurrentCUDAStream());
+      output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(),
+      static_cast<float>(scale_factor), getCurrentCUDAStreamRaw(output_grads.get_device_index()));

   return output_grads;
 }

-}  // namespace transformer_engine::pytorch
+}  // namespace transformer_engine::pytorch::stable
+
+STABLE_TORCH_LIBRARY_IMPL(transformer_engine, CUDA, m) {
+  m.impl("scaled_softmax_forward",
+         TORCH_BOX(&transformer_engine::pytorch::stable::scaled_softmax_forward));
+  m.impl("scaled_softmax_backward",
+         TORCH_BOX(&transformer_engine::pytorch::stable::scaled_softmax_backward));
+  m.impl("scaled_masked_softmax_forward",
+         TORCH_BOX(&transformer_engine::pytorch::stable::scaled_masked_softmax_forward));
+  m.impl("scaled_masked_softmax_backward",
+         TORCH_BOX(&transformer_engine::pytorch::stable::scaled_masked_softmax_backward));
+  m.impl(
+      "scaled_upper_triang_masked_softmax_forward",
+      TORCH_BOX(&transformer_engine::pytorch::stable::scaled_upper_triang_masked_softmax_forward));
+  m.impl(
+      "scaled_upper_triang_masked_softmax_backward",
+      TORCH_BOX(&transformer_engine::pytorch::stable::scaled_upper_triang_masked_softmax_backward));
+  m.impl("scaled_aligned_causal_masked_softmax_forward",
+         TORCH_BOX(
+             &transformer_engine::pytorch::stable::scaled_aligned_causal_masked_softmax_forward));
+  m.impl("scaled_aligned_causal_masked_softmax_backward",
+         TORCH_BOX(
+             &transformer_engine::pytorch::stable::scaled_aligned_causal_masked_softmax_backward));
+}
diff --git a/transformer_engine/pytorch/csrc/stable_common.h b/transformer_engine/pytorch/csrc/stable_common.h
new file mode 100644
index 0000000000..2e21422585
--- /dev/null
+++ b/transformer_engine/pytorch/csrc/stable_common.h
@@ -0,0 +1,135 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_STABLE_COMMON_H_
+#define TRANSFORMER_ENGINE_PYTORCH_CSRC_STABLE_COMMON_H_
+
+// Ensure CUDA-specific APIs are available from PyTorch's shim headers
+#ifndef USE_CUDA
+#define USE_CUDA
+#endif
+
+// PyTorch Stable ABI headers
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/stable/accelerator.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
+
+// CUDA headers
+#include <cuda_runtime.h>
+
+// Transformer Engine C API headers
+#include <transformer_engine/transformer_engine.h>
+
+#include <vector>
+
+#include "common/util/logging.h"
+
+namespace transformer_engine::pytorch::stable {
+
+using torch::headeronly::ScalarType;
+
+// ============================================================================
+// DType converter (ScalarType -> TE DType)
+// ============================================================================
+
+inline transformer_engine::DType GetTransformerEngineDType(ScalarType t) {
+  switch (t) {
+    case ScalarType::Float8_e4m3fn:
+      return transformer_engine::DType::kFloat8E4M3;
+    case ScalarType::Float8_e5m2:
+      return transformer_engine::DType::kFloat8E5M2;
+    case ScalarType::Half:
+      return transformer_engine::DType::kFloat16;
+    case ScalarType::Float:
+      return transformer_engine::DType::kFloat32;
+    case ScalarType::BFloat16:
+      return transformer_engine::DType::kBFloat16;
+    case ScalarType::Bool:
+    case ScalarType::Byte:
+      return transformer_engine::DType::kByte;
+    case ScalarType::Short:
+      return transformer_engine::DType::kInt16;
+    case ScalarType::Int:
+      return transformer_engine::DType::kInt32;
+    case ScalarType::Long:
+      return transformer_engine::DType::kInt64;
+    default:
+      NVTE_ERROR("Invalid ScalarType (", static_cast<int>(t), ").");
+  }
+}
+
+// ============================================================================
+// CUDA stream utility
+// ============================================================================
+
+inline cudaStream_t getCurrentCUDAStreamRaw(int32_t device_index = -1) {
+  if (device_index < 0) {
+    device_index = torch::stable::accelerator::getCurrentDeviceIndex();
+  }
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));
+  return reinterpret_cast<cudaStream_t>(stream_ptr);
+}
+
+// ============================================================================
+// Shape utility
+// ============================================================================
+
+inline std::vector<size_t> getStableTensorShape(const torch::stable::Tensor& t) {
+  auto sizes = t.sizes();
+  std::vector<size_t> shape;
+  shape.reserve(sizes.size());
+  for (size_t i = 0; i < sizes.size(); ++i) {
+    shape.push_back(static_cast<size_t>(sizes[i]));
+  }
+  return shape;
+}
+
+// ============================================================================
+// TensorWrapper construction from stable::Tensor
+// ============================================================================
+
+inline transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    const torch::stable::Tensor& tensor) {
+  transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type());
+  std::vector<size_t> shape = getStableTensorShape(tensor);
+  return transformer_engine::TensorWrapper(tensor.data_ptr(), shape, dtype);
+}
+
+// ============================================================================
+// Tensor allocation via stable ABI
+// ============================================================================
+
+inline torch::stable::Tensor allocateStableTensor(const std::vector<int64_t>& shape,
+                                                  ScalarType dtype, int32_t device_index = -1) {
+  if (device_index < 0) {
+    device_index = torch::stable::accelerator::getCurrentDeviceIndex();
+  }
+  torch::headeronly::IntHeaderOnlyArrayRef size_ref(shape.data(), shape.size());
+  torch::stable::Device device(torch::headeronly::DeviceType::CUDA, device_index);
+  return torch::stable::empty(size_ref, dtype,
+                              std::nullopt,  // layout
+                              device,
+                              std::nullopt,  // pin_memory
+                              std::nullopt   // memory_format
+  );
+}
+
+// ============================================================================
+// Input validation helpers
+// ============================================================================
+
+inline void check_fp16_bf16(const torch::stable::Tensor& t, const char* name) {
+  auto st = t.scalar_type();
+  NVTE_CHECK(st == ScalarType::Half || st == ScalarType::BFloat16, name,
+             ": only fp16 and bf16 are supported");
+}
+
+}  // namespace transformer_engine::pytorch::stable
+
+#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_STABLE_COMMON_H_
diff --git a/transformer_engine/pytorch/pyproject.toml b/transformer_engine/pytorch/pyproject.toml
index 0b42b0a8da..f314729132 100755
--- a/transformer_engine/pytorch/pyproject.toml
+++ b/transformer_engine/pytorch/pyproject.toml
@@ -3,7 +3,7 @@
 # See LICENSE for license information.

 [build-system]
-requires = ["setuptools>=61.0", "pip", "torch>=2.1"]
+requires = ["setuptools>=61.0", "pip", "torch>=2.10"]
 # Use legacy backend to import local packages in setup.py
 build-backend = "setuptools.build_meta:__legacy__"
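
---

Illustration of the porting template described in the commit message (not part
of the patch): a follow-up Category A port adds its schema through
STABLE_TORCH_LIBRARY_FRAGMENT, as noted in registration.cpp, and reuses the
stable_common.h helpers. This is a minimal sketch; the op and kernel names
(my_scaled_op, nvte_my_scaled_op) are hypothetical placeholders, and only the
helper calls come from this patch.

#include "../stable_common.h"

namespace transformer_engine::pytorch::stable {

using Tensor = torch::stable::Tensor;

// Hypothetical follow-up port, mirroring scaled_softmax_forward above.
Tensor my_scaled_op(Tensor input, double scale_factor) {
  check_fp16_bf16(input, "my_scaled_op");

  // Allocate the output with the input's dtype on the input's device via the
  // stable ABI.
  auto in_sizes = input.sizes();
  std::vector<int64_t> out_shape(in_sizes.begin(), in_sizes.end());
  auto output = allocateStableTensor(out_shape, input.scalar_type(), input.get_device_index());

  // Wrap both tensors for the NVTE C API.
  auto input_cu = makeTransformerEngineTensor(input);
  auto output_cu = makeTransformerEngineTensor(output);

  // nvte_my_scaled_op stands in for the ported kernel's C entry point.
  nvte_my_scaled_op(input_cu.data(), output_cu.data(), static_cast<float>(scale_factor),
                    getCurrentCUDAStreamRaw(input.get_device_index()));
  return output;
}

}  // namespace transformer_engine::pytorch::stable

// Schema goes in a FRAGMENT so registration.cpp stays the single owner of the
// transformer_engine library namespace.
STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine, m) {
  m.def("my_scaled_op(Tensor input, float scale_factor) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(transformer_engine, CUDA, m) {
  m.impl("my_scaled_op", TORCH_BOX(&transformer_engine::pytorch::stable::my_scaled_op));
}

Python callers would then reach the op as torch.ops.transformer_engine.my_scaled_op,
matching the softmax call sites updated in softmax.py above.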